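Problem
When training a network, a server outage or some other failure often interrupts the run. If no model was saved, training has to start over, which costs time and effort. How can this be handled? Can training resume from where it stopped, including the weights, epoch count, learning rate, loss, and so on?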
Solution
1. Save a checkpoint with torch.save(), storing the model weights, loss, epoch, and IoU. You can save once every few epochs.
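    import os
    import torch

    # model is wrapped in DataParallel here, so model.module is the underlying network.
    state = {
        "net": model.module.state_dict(),
        "loss": val_loss,
        "epoch": epoch,
        "iou": lb,
    }
    # Create the checkpoint directory on first use.
    if not os.path.isdir("checkpoint"):
        os.mkdir("checkpoint")
    torch.save(state, './checkpoint/ckpt_best_%s.pth' % (str(fold)))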
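As a rough sketch of the "save every few epochs" idea, the snippet above can be wrapped in a helper and called on a fixed interval. The names save_checkpoint, should_save, ckpt_dir, and save_every are hypothetical and not from the original code, and the DataParallel check is an assumption about how the model may be wrapped.

```python
import os

import torch
import torch.nn as nn


def save_checkpoint(model, val_loss, epoch, iou, fold, ckpt_dir="checkpoint"):
    """Save model weights plus bookkeeping values for one fold (hypothetical helper)."""
    # Unwrap nn.DataParallel so the saved keys carry no "module." prefix.
    net = model.module if isinstance(model, nn.DataParallel) else model
    state = {
        "net": net.state_dict(),
        "loss": val_loss,
        "epoch": epoch,
        "iou": iou,
    }
    os.makedirs(ckpt_dir, exist_ok=True)
    path = os.path.join(ckpt_dir, "ckpt_best_%s.pth" % fold)
    torch.save(state, path)
    return path


def should_save(epoch, save_every):
    """Return True on every save_every-th epoch, with epochs counted from 0."""
    return (epoch + 1) % save_every == 0
```

Inside the training loop this might be used as `if should_save(epoch, 5): save_checkpoint(model, val_loss, epoch, lb, fold)`, so at most five epochs of work are lost after an interruption.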
2. Check the args.resume flag to decide whether to load the saved checkpoint and continue training.
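    if args.resume:
        # Load the checkpoint saved for this fold.
        checkpoint = torch.load('./checkpoint/ckpt_best_%s.pth' % (str(fold)))
        # The weights were saved from model.module, so the keys have no "module." prefix.
        # If model is wrapped in DataParallel at this point, use model.module.load_state_dict.
        model.load_state_dict(checkpoint['net'])
        best_loss = checkpoint['loss']
        epoch = checkpoint['epoch'] + 1  # continue from the epoch after the saved one
        best_iou = checkpoint['iou']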
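The checkpoint above restores the weights, loss, epoch, and IoU, but not the learning rate asked about in the problem statement. One possible extension, shown here only as a sketch, is to store the optimizer and scheduler state as well. The helpers save_full and load_full and the "optimizer" and "scheduler" keys are assumptions, not part of the original code. The calls state_dict(), load_state_dict(), and the map_location argument of torch.load are standard PyTorch.

```python
import torch
import torch.nn as nn


def save_full(path, model, optimizer, scheduler, epoch, best_loss):
    """Save weights together with optimizer and scheduler state (hypothetical helper)."""
    torch.save(
        {
            "net": model.state_dict(),
            "optimizer": optimizer.state_dict(),  # current lr, momentum buffers
            "scheduler": scheduler.state_dict(),  # position in the lr schedule
            "epoch": epoch,
            "loss": best_loss,
        },
        path,
    )


def load_full(path, model, optimizer, scheduler):
    """Restore all saved state and return (start_epoch, best_loss)."""
    # map_location="cpu" lets a checkpoint written on a GPU machine load anywhere.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["net"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    scheduler.load_state_dict(checkpoint["scheduler"])
    # Resume from the epoch after the one that was saved.
    return checkpoint["epoch"] + 1, checkpoint["loss"]
```

With this, the training loop can start at the returned epoch, for example `for epoch in range(start_epoch, num_epochs)`, and the learning rate continues from its saved value instead of resetting to the initial one.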