博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
计算机视觉—CNN识别手写数字(11)
阅读量:6411 次
发布时间:2019-06-23

本文共 3697 字,大约阅读时间需要 12 分钟。

一、加载MNIST数据

TensorFlow已经准备了一个脚本来自动下载和导入MNIST数据集。它会自动创建一个'MNIST_data'的目录来存储数据。

import tensorflow as tfimport numpy as npimport random from tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets('MNIST_data',one_hot=True)# one_hot介绍 :https://blog.csdn.net/lanhaier0591/article/details/78702558复制代码

这里,mnist是一个轻量级的类。它以Numpy数组的形式存储着训练、校验和测试数据集。同时提供了一个函数,用于在迭代中获得minibatch,后面我们将会用到。

原文链接:http://www.tensorfly.cn/tfdoc/tutorials/mnist_download.html

二、输入与占位符

placeholder_inputs()函数将生成两个tf.placeholder操作,定义传入图表中的shape参数,shape参数中包括batch_size值,后续还会将实际的训练用例传入图表。

imageInput = tf.placeholder(tf.float32,[None,784]) # 训练图像labeInput = tf.placeholder(tf.float32,[None,10]) # 训练标签复制代码

三、构建一个多层卷积网络

1、权重初始化

reshape(tensor, shape, name=None)

参数

  • tensor,被调整维度的张量
  • shape,要调整为的形状
imageInputReshape = tf.reshape(imageInput,[-1,28,28,1])# 2维转变为4维复制代码

tf.truncated_normal(shape, mean, stddev)

参数

  • shape表示生成张量的维度,
  • mean是均值,
  • stddev是标准差。
w0 = tf.Variable(tf.truncated_normal([5,5,1,32],stddev = 0.1))# 求标准差b0 = tf.Variable(tf.constant(0.1,shape=[32]))# 生成一个32维的张量复制代码

2、激励函数+卷积运算

tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, name=None)

参数

  • input:指需要做卷积的输入图像,它要求是一个Tensor,具有[batch, in_height, in_width, in_channels]这样的shape,具体含义是[训练时一个batch的图片数量, 图片高度, 图片宽度, 图像通道数],注意这是一个4维的Tensor,要求类型为float32和float64其中之一
  • filter:相当于CNN中的卷积核,它要求是一个Tensor,具有[filter_height, filter_width, in_channels, out_channels]这样的shape,具体含义是[卷积核的高度,卷积核的宽度,图像通道数,卷积核个数],要求类型与参数input相同,有一个地方需要注意,第三维in_channels,就是参数input的第四维
  • strides:卷积时在图像每一维的步长,这是一个一维的向量,长度4
  • padding:string类型的量,只能是"SAME","VALID"其中之一,这个值决定了不同的卷积方式
  • use_cudnn_on_gpu:bool类型,是否使用cudnn加速,默认为true

输出:

  • 结果返回一个Tensor,这个输出,就是我们常说的feature map,shape仍然是[batch, height, width, channels]这种形式。

tf.nn.max_pool(value, ksize, strides, padding, name=None)

参数

  • value:池化的输入,一般池化层接在卷积层的后面,所以输出通常为feature map。feature map依旧是[batch, in_height, in_width, in_channels]这样的参数。
  • ksize:池化窗口的大小,参数为四维向量,通常取[1, height, width, 1],因为我们不想在batch和channels上做池化,所以这两个维度设为了1。ps:估计面tf.nn.conv2d中stries的四个取值也有 相同的意思。
  • stries:步长,同样是一个四维向量。
  • padding:填充方式同样只有两种不重复了。
layer1 = tf.nn.relu(tf.nn.conv2d(imageInputReshape,w0,strides=[1,1,1,1],padding='SAME')+b0)# layer1:激励函数+卷积运算# imageInputReshape : M*28*28*1  w0:5,5,1,32  # layer1:M*28*28*32复制代码

3、池化

layer1_pool = tf.nn.max_pool(layer1,ksize=[1,4,4,1],strides=[1,4,4,1],padding='SAME')# pool采样:数据量减少很多M*28*28*32 => M*7*7*32复制代码

4、激励函数+乘加运算

# layer2 out : softmax(激励函数 + 乘加运算)w1 = tf.Variable(tf.truncated_normal([7*7*32,1024],stddev=0.1))b1 = tf.Variable(tf.constant(0.1,shape=[1024]))h_reshape = tf.reshape(layer1_pool,[-1,7*7*32])h1 = tf.nn.relu(tf.matmul(h_reshape,w1)+b1)# [N*7*7*32]  [7*7*32,1024] = N*1024复制代码

5、输出层

最后,我们添加一个softmax层

w2 = tf.Variable(tf.truncated_normal([1024,10],stddev=0.1))b2 = tf.Variable(tf.constant(0.1,shape=[10]))pred = tf.nn.softmax(tf.matmul(h1,w2)+b2)# N*1024  1024*10 = N*10复制代码

6、损失函数

loss0 = labeInput*tf.log(pred)loss1 = 0for m in range(0,500):    for n in range(0,10):        loss1 = loss1 - loss0[m,n]loss = loss1/500复制代码

7、训练和评估模型

train = tf.train.GradientDescentOptimizer(0.01).minimize(loss)# 让误差尽可能缩小with tf.Session() as sess:    sess.run(tf.global_variables_initializer())    for i in range(100):        images,labels = mnist.train.next_batch(500)        sess.run(train,feed_dict={imageInput:images,labeInput:labels})                pred_test = sess.run(pred,feed_dict={imageInput:mnist.test.images,labeInput:labels})        acc = tf.equal(tf.arg_max(pred_test,1),tf.arg_max(mnist.test.labels,1))        acc_float = tf.reduce_mean(tf.cast(acc,tf.float32))        acc_result = sess.run(acc_float,feed_dict={imageInput:mnist.test.images,labeInput:mnist.test.labels})        print(acc_result)复制代码
你可能感兴趣的文章
企业应用架构模式阅读笔记 - Martin Fowler
查看>>
PostgreSQL缓存
查看>>
iOS开发技巧 - 使用和定制开关控件(UISwitch)
查看>>
音乐闹钟
查看>>
JQuery模板插件jquery.tmpl-动态ajax扩展
查看>>
QT小滑块
查看>>
iis7.5 发布mvc出错的解决办法
查看>>
职称英语
查看>>
用JavaScript生成Android SDK的下载地址(4)——按“API Level”分类
查看>>
SQL Server 自动增长清零
查看>>
多核与云计算
查看>>
C++中的头文件和源文件
查看>>
SQLite在Android中使用
查看>>
Spring 3 MVC And RSS Feed Example
查看>>
【转】Linux 下修改Tomcat使用的JVM内存大小
查看>>
【uTenux实验】事件标志
查看>>
利用Python进行数据分析(15) pandas基础: 字符串操作
查看>>
busybox inetd tftpd
查看>>
函数可重入性及编写规范
查看>>
Scribe应用实例
查看>>