个人笔记,不喜勿喷。
一、流程简述
1、将网络状态设置为训练模式 2、将训练数据中的输入和标签变量化 3、利用网络向前计算,包含计算输出、计算损失值 4、反向传播调整参数,包含梯度清零、误差反向传播、梯度下降更新 5、将网络状态设置为测试模式,继续2-4步
二、伪代码
for 迭代次数 net = net.train()#设置为训练模式 for 样本批次训练 #将训练数据中的输入和标签变量化 im = Variable(im) label = Variable(label) #向前计算 output = net(im)#计算输出 loss = criterion(output, label)#计算损失值 #反向传播 optimizer.zero_grad() # 梯度清零 loss.backward() # 误差反向传播 optimizer.step() # 梯度下降(更新)