TensorFlow (14): The Learning Rate in Neural Networks

mac, 2022-06-30

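A well-chosen dynamic learning rate can improve how a neural network trains: larger steps early on speed up progress, and smaller steps later let the weights settle instead of bouncing around a minimum. The example below sets the rate to 0.001 × 0.95^epoch at the start of every epoch while training a small fully connected network on MNIST.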
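The per-epoch schedule can be written down on its own, independent of TensorFlow. The following is a minimal sketch in plain Python; the function name `decayed_learning_rate` is a hypothetical helper introduced here for illustration, and the constants 0.001 and 0.95 are the ones used in this post.

```python
def decayed_learning_rate(initial_lr, decay_rate, epoch):
    """Exponential decay: initial_lr * decay_rate ** epoch."""
    return initial_lr * decay_rate ** epoch


# One value per epoch for the 51 epochs trained in this post (epochs 0..50).
schedule = [decayed_learning_rate(0.001, 0.95, epoch) for epoch in range(51)]

for epoch in (0, 1, 2, 50):
    print(epoch, schedule[epoch])
```

Each epoch shrinks the rate by 5%, so after 50 epochs it has fallen to roughly 8% of its starting value.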

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load the MNIST dataset with one-hot labels
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# Size of each mini-batch and number of batches per epoch
batch_size = 100
n_batch = mnist.train.num_examples // batch_size

# Placeholders for the input images, the labels, and the dropout keep probability
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)
# The learning rate lives in a variable so it can be reassigned every epoch
lr = tf.Variable(0.001, dtype=tf.float32)

# A simple fully connected network: 784 -> 500 -> 300 -> 10
W1 = tf.Variable(tf.truncated_normal([784, 500], stddev=0.1))
b1 = tf.Variable(tf.zeros([500]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
L1_drop = tf.nn.dropout(L1, keep_prob)

W2 = tf.Variable(tf.truncated_normal([500, 300], stddev=0.1))
b2 = tf.Variable(tf.zeros([300]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
L2_drop = tf.nn.dropout(L2, keep_prob)

W3 = tf.Variable(tf.truncated_normal([300, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]) + 0.1)
logits = tf.matmul(L2_drop, W3) + b3
prediction = tf.nn.softmax(logits)

# Cross-entropy loss; this op applies softmax itself, so it takes the raw logits
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
# The optimizer reads lr, so the per-epoch decay reaches the update step.
# Assumption: Adam stands in for plain gradient descent here, because 0.001 is a
# typical Adam step size and a very small one for vanilla gradient descent.
train_step = tf.train.AdamOptimizer(lr).minimize(loss)

# Initialize all variables (created after the optimizer so its slot variables are included)
init = tf.global_variables_initializer()

# Boolean vector of per-example hits; argmax returns the index of the largest value along axis 1
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
# Accuracy is the mean of the hits cast to float
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(51):
        # Decay the learning rate by 5% at the start of each epoch
        sess.run(tf.assign(lr, 0.001 * (0.95 ** epoch)))
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 1.0})

        learning_rate = sess.run(lr)
        acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
        print("Iter" + str(epoch) + ",Testing Accuracy" + str(acc) + ",Learning_rate" + str(learning_rate))
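A decayed rate only changes training when it is the value the update rule actually multiplies the gradient by. That effect can be seen in isolation on a toy problem. The sketch below is plain Python rather than TensorFlow, and everything in it (the function `minimize_quadratic`, the objective f(w) = w², the starting rate 1.0 and the decay factor 0.9) is an assumption chosen for illustration, not part of the MNIST example.

```python
def minimize_quadratic(initial_lr, decay_rate, steps=50, w=1.0):
    """Run gradient descent on f(w) = w**2, decaying the rate at every step."""
    for step in range(steps):
        lr = initial_lr * decay_rate ** step  # rate used for this update
        grad = 2.0 * w                        # derivative of w**2
        w = w - lr * grad                     # the update consumes the decayed rate
    return w


# With a fixed rate of 1.0 every update maps w to -w, so the iterate never settles.
w_fixed = minimize_quadratic(initial_lr=1.0, decay_rate=1.0)

# With the same starting rate decayed by 10% per step, the iterate shrinks toward 0.
w_decayed = minimize_quadratic(initial_lr=1.0, decay_rate=0.9)

print("fixed rate:  ", w_fixed)
print("decayed rate:", w_decayed)
```

The fixed-rate run keeps oscillating between 1 and -1, while the decayed run ends close to the minimum at 0. That is the behaviour a decay schedule is meant to buy on a real network.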

Output:

Iter0,Testing Accuracy0.9224,Learning_rate0.001
Iter1,Testing Accuracy0.9353,Learning_rate0.00095
Iter2,Testing Accuracy0.9428,Learning_rate0.0009025
Iter3,Testing Accuracy0.948,Learning_rate0.000857375
Iter4,Testing Accuracy0.9523,Learning_rate0.00081450626
Iter5,Testing Accuracy0.956,Learning_rate0.0007737809
Iter6,Testing Accuracy0.9593,Learning_rate0.0007350919
Iter7,Testing Accuracy0.9607,Learning_rate0.0006983373
Iter8,Testing Accuracy0.9626,Learning_rate0.0006634204
Iter9,Testing Accuracy0.9635,Learning_rate0.0006302494
Iter10,Testing Accuracy0.9656,Learning_rate0.0005987369
Iter11,Testing Accuracy0.9658,Learning_rate0.0005688001
Iter12,Testing Accuracy0.9679,Learning_rate0.0005403601
Iter13,Testing Accuracy0.9685,Learning_rate0.0005133421
Iter14,Testing Accuracy0.9699,Learning_rate0.000487675
Iter15,Testing Accuracy0.9701,Learning_rate0.00046329122
Iter16,Testing Accuracy0.9692,Learning_rate0.00044012666
Iter17,Testing Accuracy0.9718,Learning_rate0.00041812033
Iter18,Testing Accuracy0.9715,Learning_rate0.00039721432
Iter19,Testing Accuracy0.9714,Learning_rate0.0003773536
Iter20,Testing Accuracy0.9728,Learning_rate0.00035848594
Iter21,Testing Accuracy0.972,Learning_rate0.00034056162
Iter22,Testing Accuracy0.9721,Learning_rate0.00032353355
Iter23,Testing Accuracy0.9732,Learning_rate0.00030735688
Iter24,Testing Accuracy0.9732,Learning_rate0.000291989
Iter25,Testing Accuracy0.9739,Learning_rate0.00027738957
Iter26,Testing Accuracy0.9734,Learning_rate0.0002635201
Iter27,Testing Accuracy0.9742,Learning_rate0.00025034408
Iter28,Testing Accuracy0.9743,Learning_rate0.00023782688
Iter29,Testing Accuracy0.9741,Learning_rate0.00022593554
Iter30,Testing Accuracy0.9745,Learning_rate0.00021463877
Iter31,Testing Accuracy0.9742,Learning_rate0.00020390682
Iter32,Testing Accuracy0.9743,Learning_rate0.00019371149
Iter33,Testing Accuracy0.9744,Learning_rate0.0001840259
Iter34,Testing Accuracy0.9758,Learning_rate0.00017482461
Iter35,Testing Accuracy0.9753,Learning_rate0.00016608338
Iter36,Testing Accuracy0.9748,Learning_rate0.00015777921
Iter37,Testing Accuracy0.9747,Learning_rate0.00014989026
Iter38,Testing Accuracy0.9754,Learning_rate0.00014239574
Iter39,Testing Accuracy0.9753,Learning_rate0.00013527596
Iter40,Testing Accuracy0.976,Learning_rate0.00012851215
Iter41,Testing Accuracy0.9752,Learning_rate0.00012208655
Iter42,Testing Accuracy0.9755,Learning_rate0.00011598222
Iter43,Testing Accuracy0.9756,Learning_rate0.00011018311
Iter44,Testing Accuracy0.9762,Learning_rate0.000104673956
Iter45,Testing Accuracy0.976,Learning_rate9.944026e-05
Iter46,Testing Accuracy0.9759,Learning_rate9.446825e-05
Iter47,Testing Accuracy0.976,Learning_rate8.974483e-05
Iter48,Testing Accuracy0.9762,Learning_rate8.525759e-05
Iter49,Testing Accuracy0.9761,Learning_rate8.099471e-05
Iter50,Testing Accuracy0.9762,Learning_rate7.6944976e-05

The final test accuracy reaches 0.9762.
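The printed log is compact enough to parse back into numbers, which makes it easy to confirm that the learning-rate column follows 0.001 × 0.95^epoch. The following is a rough sketch in plain Python; the regular expression and the helper `parse_log` are hypothetical names written for the exact `Iter…,Testing Accuracy…,Learning_rate…` format shown above, and `sample` holds three entries copied from that log.

```python
import re

# Matches one entry such as "Iter45,Testing Accuracy0.976,Learning_rate9.944026e-05"
LOG_ENTRY = re.compile(
    r"Iter(\d+),Testing Accuracy([0-9.]+),Learning_rate([0-9.e+-]+)"
)


def parse_log(text):
    """Return a list of (epoch, accuracy, learning_rate) tuples found in text."""
    return [
        (int(m.group(1)), float(m.group(2)), float(m.group(3)))
        for m in LOG_ENTRY.finditer(text)
    ]


sample = """Iter0,Testing Accuracy0.9224,Learning_rate0.001
Iter1,Testing Accuracy0.9353,Learning_rate0.00095
Iter45,Testing Accuracy0.976,Learning_rate9.944026e-05"""

rows = parse_log(sample)

# Largest relative gap between the logged rate and 0.001 * 0.95 ** epoch
max_rel_err = max(
    abs(rate - 0.001 * 0.95 ** epoch) / (0.001 * 0.95 ** epoch)
    for epoch, _, rate in rows
)
print(rows)
print("max relative error:", max_rel_err)
```

The small remaining gap comes from the rate being stored and printed as a 32-bit float.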
