手写数字识别功能完善

释放双眼,带上耳机,听听看~!
本篇教程将介绍如何在深度学习模型中完善手写数字识别功能,并对准确率进行计算和优化。通过TensorFlow搭建全连接神经网络,解决数字识别问题。

1. 手写数字识别功能完善

本文将继续进行手写数字识别案例开发,主要针对于模型功能的完善进行案例演示,主要考虑从以下几个方面对上一篇文章介绍的案例进行完善。

  • 考虑增加准确率计算
  • 开启Tensorboard显示变量
  • 模型保存与加载
  • 模型预测并输出结果

准确率计算思路:

  • 模型经过softmax层输出结果为10个概率值
  • 找到最大概率值所在位置和真实值one-hot编码标签最大值所在位置
  • 若两者位置相同,则预测正确
    • 两者所在位置相同返回1
    • 两个所在位置不一致返回0

准确率计算将用到如下函数:

  • np.argmax():返回最大值所在位置
  • tf.argmax(y_true, 1):可以查看最大真实值在列中所在位置(1表示按列求最大值位置)
  • tf.argmax(y_predict, 1):求最大的预测值在列中所在位置
  • tf.equal():可以进行判断两者是否一致,若一致返回true,若不一致返回false(返回布尔数据类型)
  • tf.cast():数据类型转换,转换成想要的数据类型。因为tf.equal会返回布尔类型的长列表,我们想让其返回浮点型长列表,以便计算准确率,需要对数据类型进行转换。

2. 案例演示

接下来,我们将在上一篇文章搭建的模型基础上进行完善,为其计算每一轮训练的准确率,并加以输出。

全连接神经网络模型张量变化过程:x[batch,784]∗w[784,10]+bias=ypredict[batch,10]x[batch,784]∗w[784,10]+bias=ypredict[batch,10]

  • 在优化损失步骤以后加入准确率计算
  • 准确率计算先比较位置是否一致,返回布尔数据类型的列表
  • 然后,将其转换为浮点数类型,再进行平均值计算
  • 注意:还需要在会话中运行求平均值后的变量,才能使其有具体值
def full_connection():
    """
    用全连接神经网络识别手写数字
    """
    # 1. 准备数据
    mnist = input_data.read_data_sets("./mnist_data", one_hot=True)
    x = tf.placeholder(dtype=tf.float32, shape=(None, 784)) # 784列:一张图片由784个像素组成,一次传入N张图片
    y_true = tf.placeholder(dtype=tf.float32, shape=(None, 10)) # 10列
    
    # 2. 构建模型
    weights = tf.Variable(initial_value=tf.random_normal(shape=(784, 10)))
    bias = tf.Variable(initial_value=tf.random_normal(shape=[10])) # 10个标量
    y_predict = tf.matmul(x, weights) + bias
    
    # 3. 构造损失函数
    error = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict)) # 交叉熵损失函数,再求平均值
    
    # 4. 优化损失
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.02).minimize(error) # 梯度下降优化器,最小化损失
    
    # 5. 准确率计算
    # 5.1 比较输出的结果最大值所在位置和真实值最大值所在位置
    # y_true的形状为(N, 10),有N个样本行,10列
    equal_list = tf.equal(tf.argmax(y_true, 1),
                         tf.argmax(y_predict, 1))
    
    # 5.2 求平均
    accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))
    
    # 初始化变量
    init = tf.global_variables_initializer()
    
    # 开启会话
    with tf.Session() as sess:
        sess.run(init)
        image, label = mnist.train.next_batch(100)
        
        
        print("训练之前的损失loss为:%f" % sess.run(error, feed_dict={x:image, y_true:label}))
        
        # 开始训练
        for i in range(500):
            _, loss, accuracy_value = sess.run([optimizer,error, accuracy], feed_dict={x:image, y_true:label}) # optimizer是一个操作,不需要其返回值,用_(None)来接收
            print("第%d次训练, 损失为%f, 准确率为%f" % (i+1, loss, accuracy_value))
            
    return None
        
full_connection()      

运行结果如下图所示:经过500轮训练,最终准确率大约为0.65,可以进一步尝试增大训练轮数,观察其准确率是否还可以上升,并且可以考虑适当调节学习率来提高准确率。

手写数字识别功能完善

3. 总结

本案例使用一层全连接层完成了分类问题,注意是分类问题,我们在最后一层的输出需要加入softmax激活函数,以此来解决多分类问题。本项目中使用的损失函数为交叉熵损失函数,使用梯度下降优化器来优化损失。本文重点介绍的是准确率计算:通过比对真实值和预测值的最大概率所在位置是否一致(若位置一致返回1;不一致返回0),再根据返回结果求平均值,以此来计算出准确率。

本文正在参加「金石计划 . 瓜分6万现金大奖」

本网站的内容主要来自互联网上的各种资源,仅供参考和信息分享之用,不代表本网站拥有相关版权或知识产权。如您认为内容侵犯您的权益,请联系我们,我们将尽快采取行动,包括删除或更正。
AI教程

本地Stable Diffusion AI 绘画漫画写实风

2023-12-17 10:14:14

AI教程

Gradio实践教程:从RGB转灰度到文本分类

2023-12-17 10:31:14

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索