博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
【TensorFlow-windows】(四) CNN(卷积神经网络)进行手写数字识别(mnist)
阅读量:6676 次
发布时间:2019-06-25

本文共 5727 字,大约阅读时间需要 19 分钟。

主要内容:

1.基于CNN的mnist手写数字识别(详细代码注释)
2.该实现中的函数总结

平台:

1.windows 10 64位
2.Anaconda3-4.2.0-Windows-x86_64.exe (当时TF还不支持python3.6,又懒得在高版本的anaconda下配置多个Python环境,于是装了一个3-4.2.0(默认装python3.5),建议装anaconda3的最新版本,TF1.2.0版本已经支持python3.6!)
3.TensorFlow1.1.0

CNN的介绍可以看:

这里用的CNN结构是: 输入层-C1-P1-C2-P2-FC1-Dropout-FC2-softmax(输出层)

代码:

# -*- coding: utf-8 -*-"""Created on Mon Jun 12 16:36:43 2017@author: ASUS"""import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets('MNIST_data/', one_hot = True) #  mnist是一个tensorflow内部的变量sess = tf.InteractiveSession()  # 创建 一个会话# 权值初始化函数,用截断的正态分布,两倍标准差之外的被截断def weight_variable(shape):    initial = tf.truncated_normal(shape, stddev = 0.1)    return tf.Variable(initial)#  偏置初始化函数,偏置初始为0.1def bias_variable(shape):    initial = tf.constant(0.1, shape = shape)    return tf.Variable(initial)# 定义卷积方式,步长是1111,padding的SAME是使得特征图与输入图大小一致def conv2d(x,W):    return tf.nn.conv2d(x, W, strides = [1, 1, 1, 1], padding ='SAME')# 定义池化方式,采用最大池化def max_pool_2x2(x):        return tf.nn.max_pool(x, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1],padding='SAME')# 定义占位符x = tf.placeholder(tf.float32, [None, 784])    y_ = tf.placeholder(tf.float32, [None, 10])# 1D向量(1,784)转2D(28,28)x_image = tf.reshape(x, [-1,28,28,1])  # -1 表示样本数量不固定#---------------第1/4步:定义算法公式-------------------# 定义 卷积层 conv1W_conv1 = weight_variable([5, 5, 1, 32])b_conv1 = bias_variable([32])h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)             h_pool1 = max_pool_2x2(h_conv1)# 定义 卷积层 conv2W_conv2 = weight_variable([5, 5, 32, 64])b_conv2 = bias_variable([64])h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)             h_pool2 = max_pool_2x2(h_conv2)#定义 全连接层 fc1W_fc1 = weight_variable([7*7*64, 1024])b_fc1 = bias_variable([1024])h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])    # 将tensor拉成向量h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)# 定义Dropout层keep_prob = tf.placeholder(tf.float32)h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)# 定义 Softmax层W_fc2 = weight_variable([1024, 10])b_fc2 = bias_variable([10])y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)#---------------第2/4步:定义loss和优化器-------------------# 定义loss 和 参数优化器cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y_conv), reduction_indices = [1]))  # -sigma y_ * log(y)train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)# 准确率验证correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))#---------------第3/4步:训练步骤-------------------# 训练tf.global_variables_initializer().run()for i in range(2000):    batch = mnist.train.next_batch(100)    if i%100 ==0:        train_accuracy = accuracy.eval(feed_dict= {x: batch[0], y_: batch[1], keep_prob:1.0})        print('step %d, training accuracy %g' %(i, train_accuracy))    train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob:0.5})#---------------第4/4步:测试集上评估模型-------------------# 在验证阶段可能出先一个问题就是GPU内存不够的问题,这里是整个test输入,进行计算# GPU内存不够大的话,就会出错(我的GTX 960m,2G)顶不住啊! 所以要分batch的进行# 这里是输入整个test集的# print('test accuracy %g ' % accuracy.eval(feed_dict = {
# x: mnist.test.images,# y_:mnist.test.labels, keep_prob:1.0}))# 这里是分batch验证的accuracy_sum = tf.reduce_sum(tf.cast(correct_prediction, tf.float32))good = 0total = 0for i in range(2): testSet = mnist.test.next_batch(100) if i ==1 : print(testSet[0].shape[0]) good += accuracy_sum.eval(feed_dict = { x: testSet[0], y_: testSet[1], keep_prob: 1.0}) total += testSet[0].shape[0] # testSet[0].shape[0] 是本batch有的样本数量print("test accuracy %g"%(good/total))

这里面出了个小问题就是在测试阶段,书上直接把整个test集放进去了,而我的GPU内存不够大,导致出错。所以这里采用了分batch的方法进行测试,大家可以试一下整个test集放进去测试会出现什么情况。

**

函数总结(续上篇)

**:

1. sess = tf.InteractiveSession() 将sess注册为默认的session
2. tf.placeholder() , Placeholder是输入数据的地方,也称为占位符,通俗的理解就是给输入数据(此例中的图片x)和真实标签(y_)提供一个入口,或者是存放地。(个人理解,可能不太正确,后期对TF有深入认识的话再回来改~~)
3. tf.Variable() Variable是用来存储模型参数,与存储数据的tensor不同,tensor一旦使用掉就消失
4. tf.matmul() 矩阵相乘函数
5. tf.reduce_mean 和tf.reduce_sum 是缩减维度的计算均值,以及缩减维度的求和
6. tf.argmax() 是寻找tensor中值最大的元素的序号 ,此例中用来判断类别
7. tf.cast() 用于数据类型转换
————————————–我是分割线(一)———————————–

tf.random_uniform 生成均匀分布的随机数

tf.train.AdamOptimizer() 创建优化器,优化方法为Adam(adaptive moment estimation,Adam优化方法根据损失函数对每个参数的梯度的一阶矩估计和二阶矩估计动态调整针对于每个参数的学习速率)
tf.placeholder “占位符”,只要是对网络的输入,都需要用这个函数这个进行“初始化”
tf.random_normal 生成正态分布
tf.add 和 tf.matmul 数据的相加 、相乘
tf.reduce_sum 缩减维度的求和
tf.pow 求幂函数
tf.subtract 数据的相减
tf.global_variables_initializer 定义全局参数初始化
tf.Session 创建会话.
tf.Variable 创建变量,是用来存储模型参数的变量。是有别于模型的输入数据的
tf.train.AdamOptimizer (learning_rate = 0.001) 采用Adam进行优化,学习率为 0.001
————————————–我是分割线(二)———————————–
1. hidden1_drop = tf.nn.dropout(hidden1, keep_prob) 给 hindden1层增加Droput,返回新的层hidden1_drop,keep_prob是 Droput的比例
2. mnist.train.next_batch() 来详细讲讲 这个函数。一句话概括就是,打乱样本顺序,然后按顺序读取batch_size 个样本 进行返回。
具体看代码及其注释,首先要找到函数定义,在tensorflow\contrib\learn\python\learn\datasets 下的mnist.py
————————————–我是分割线(三)———————————–
1. tf.nn.conv2d(x, W, strides = [1, 1, 1, 1], padding =’SAME’)对于这个函数主要理解 strides和padding,首先明确,x是输入,W是卷积核,并且它们的维数都是4(发现strides里有4个元素没,没错!就是一一对应的)
先说一下卷积核W也是一个四维张量,各维度表示的信息是:[filter_height, filter_width, in_channels, out_channels]

输入x,x是一个四维张量 ,各维度表示的信息是:[batch, in_height, in_width, in_channels]

strides里的每个元素就是对应输入x的四个维度的步长,因为第2,3维是图像的长和宽,所以平时用的strides就在这里设置,而第1,4维一般不用到,所以是1

padding只有两种取值方式,一个是 padding=[‘VALID’] 一个是padding=[‘SAME’]

valid:采用丢弃的方式,只要移动一步时,最右边有超出,则这一步不移动,并且剩余的进行丢弃。如下图,图片长13,卷积核长6,步长是5,当移动一步之后,已经卷积核6-11,再移动一步,已经没有足够的像素点了,所以就不能移动,因此 12,13被丢弃。
same:顾名思义,就是保持输入的大小不变,方法是在图像边缘处填充全0的像素
这里写图片描述

转载于:https://www.cnblogs.com/TensorSense/p/7413314.html

你可能感兴趣的文章
阿里巴巴上线静态开源站点搭建工具 Docsite
查看>>
如何使用Data Lake Analytics创建分区表
查看>>
您对TOP Server的德语、中文和日语语言支持了解吗?(一)
查看>>
基于 Spring Boot 和 Spring Cloud 实现微服务架构
查看>>
Qt之添加菜单项&状态栏
查看>>
负载均衡在分布式架构中是怎么玩起来的?
查看>>
Java程序员在工作的同时应该具备什么样的能力?
查看>>
Dubbo深入分析之Cluster层
查看>>
分析Padavan源代码,二
查看>>
WordPress的WPML外挂出问题恐出现安全漏洞
查看>>
Django 调试技巧
查看>>
Spring Boot和thymeleaf , freemarker , jsp三个前端模块的运用
查看>>
phalcon-入门篇3(优美的URL与Config)
查看>>
单表60亿记录等大数据场景的MySQL优化和运维之道
查看>>
sql学习笔记
查看>>
maven编译时出现There are test failures
查看>>
SpringBoot | 第三十一章:MongoDB的集成和使用
查看>>
网络学习笔记2
查看>>
JPA--多对多关系
查看>>
配置sharepoint 2010错误:Microsoft.SharePoint.Upgrad...
查看>>