Pytorch实现断点续训

mac2026-01-15  5

问题

在进行训练网络模型时,经常会遇到服务器中断或其他原因导致正在训练的模型中断,如果没有保存模型,就要重新训练,费时费力。这种情况怎么解决呢?可以继续原先的模型训练程度继续训练吗?包括weights,epochs,lr,loss等等…

解决

1.通过torch.save()方法保存模型,包括model,loss,epoch,IoU。可以设置每隔几个epochs保存一次。 state = { "net":model.module.state_dict(), "loss":val_loss, "epoch":epoch, "iou":lb, } if not os.path.isdir("checkpoint"): os.mkdir("checkpoint") torch.save(state,'./checkpoint/ckpt_best_%s.pth' % (str(fold))) 2.通过args.resume()方法判断是否可以加载保存的模型继续训练 if args.resume == True: checkpoint = torch.load('./checkpoint/ckpt_best_%s.pth' % (str(fold))) model.load_state_dict(checkpoint['net']) best_loss = checkpoint['loss'] epoch = checkpoint['epoch'] + 1 best_iou = checkpoint['iou']
最新回复(0)