1. 手写数字识别功能完善
本文将继续进行手写数字识别案例开发,主要针对于模型功能的完善进行案例演示,主要考虑从以下几个方面对上一篇文章介绍的案例进行完善。
- 考虑增加准确率计算
- 开启Tensorboard显示变量
- 模型保存与加载
- 模型预测并输出结果
准确率计算思路:
- 模型经过softmax层输出结果为10个概率值
- 找到最大概率值所在位置和真实值one-hot编码标签最大值所在位置
- 若两者位置相同,则预测正确
- 两者所在位置相同返回1
- 两个所在位置不一致返回0
准确率计算将用到如下函数:
- np.argmax():返回最大值所在位置
- tf.argmax(y_true, 1):可以查看最大真实值在列中所在位置(1表示按列求最大值位置)
- tf.argmax(y_predict, 1):求最大的预测值在列中所在位置
- tf.equal():可以进行判断两者是否一致,若一致返回true,若不一致返回false(返回布尔数据类型)
- tf.cast():数据类型转换,转换成想要的数据类型。因为tf.equal会返回布尔类型的长列表,我们想让其返回浮点型长列表,以便计算准确率,需要对数据类型进行转换。
2. 案例演示
接下来,我们将在上一篇文章搭建的模型基础上进行完善,为其计算每一轮训练的准确率,并加以输出。
全连接神经网络模型张量变化过程:x[batch,784]∗w[784,10]+bias=ypredict[batch,10]x[batch,784]∗w[784,10]+bias=ypredict[batch,10]
- 在优化损失步骤以后加入准确率计算
- 准确率计算先比较位置是否一致,返回布尔数据类型的列表
- 然后,将其转换为浮点数类型,再进行平均值计算
- 注意:还需要在会话中运行求平均值后的变量,才能使其有具体值
def full_connection():
"""
用全连接神经网络识别手写数字
"""
# 1. 准备数据
mnist = input_data.read_data_sets("./mnist_data", one_hot=True)
x = tf.placeholder(dtype=tf.float32, shape=(None, 784)) # 784列:一张图片由784个像素组成,一次传入N张图片
y_true = tf.placeholder(dtype=tf.float32, shape=(None, 10)) # 10列
# 2. 构建模型
weights = tf.Variable(initial_value=tf.random_normal(shape=(784, 10)))
bias = tf.Variable(initial_value=tf.random_normal(shape=[10])) # 10个标量
y_predict = tf.matmul(x, weights) + bias
# 3. 构造损失函数
error = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict)) # 交叉熵损失函数,再求平均值
# 4. 优化损失
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.02).minimize(error) # 梯度下降优化器,最小化损失
# 5. 准确率计算
# 5.1 比较输出的结果最大值所在位置和真实值最大值所在位置
# y_true的形状为(N, 10),有N个样本行,10列
equal_list = tf.equal(tf.argmax(y_true, 1),
tf.argmax(y_predict, 1))
# 5.2 求平均
accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))
# 初始化变量
init = tf.global_variables_initializer()
# 开启会话
with tf.Session() as sess:
sess.run(init)
image, label = mnist.train.next_batch(100)
print("训练之前的损失loss为:%f" % sess.run(error, feed_dict={x:image, y_true:label}))
# 开始训练
for i in range(500):
_, loss, accuracy_value = sess.run([optimizer,error, accuracy], feed_dict={x:image, y_true:label}) # optimizer是一个操作,不需要其返回值,用_(None)来接收
print("第%d次训练, 损失为%f, 准确率为%f" % (i+1, loss, accuracy_value))
return None
full_connection()
运行结果如下图所示:经过500轮训练,最终准确率大约为0.65,可以进一步尝试增大训练轮数,观察其准确率是否还可以上升,并且可以考虑适当调节学习率来提高准确率。
3. 总结
本案例使用一层全连接层完成了分类问题,注意是分类问题,我们在最后一层的输出需要加入softmax激活函数,以此来解决多分类问题。本项目中使用的损失函数为交叉熵损失函数,使用梯度下降优化器来优化损失。本文重点介绍的是准确率计算:通过比对真实值和预测值的最大概率所在位置是否一致(若位置一致返回1;不一致返回0),再根据返回结果求平均值,以此来计算出准确率。
本文正在参加「金石计划 . 瓜分6万现金大奖」