文献标识码:A
DOI:10.16157/j.issn.0258-7998.182249
中文引用格式:黄睿,陆许明,邬依林. 基于TensorFlow深度学习手写体数字识别及应用[J].电子技术应用,2018,44(10):6-10.
英文引用格式:Huang Rui,Lu Xuming,Wu Yilin. Handwriting digital recognition and application based on TensorFlow deep learning[J]. Application of Electronic Technique,2018,44(10):6-10.
0 引言
随着科技的发展,人工智能识别技术已广泛应用于各个领域,同时也推动着计算机的应用朝着智能化发展。一方面,以深度学习、神经网络为代表的人工智能模型获国内外学者的广泛关注;另一方面,人工智能与机器学习的系统开源,构建了开放的技术平台,促进人工智能研究的开发。本文基于TensorFlow深度学习框架,构建Softmax、CNN模型,并完成手写体数字的识别。
LECUN Y等提出了一种LeNet-5的多层神经网络用于识别0~9的手写体数字,该研究模型通过反向传播(Back Propagation,BP)算法进行学习,建立起CNN应用的最早模型[1-2]。随着人工智能图像识别的出现,CNN成为研究热点,近年来主要运用于图像分类[3]、目标检测[4]、目标跟踪[5]、文本识别[6]等方面,其中AlexNet[7]、GoogleNet[8]和ResNet[9]等算法取得了较大的成功。
本文基于Google第二代人工智能开源平台TensorFlow,结合深度学习框架,对Softmax回归算法和CNN模型进行对比验证,最后对训练的模型基于Android平台下进行应用。
1 TensorFlow简介
2015年11月9日,Google发布并开源第二代人工智能学习系统TensorFlow[10]。Tensor表示张量(由N维数组成),Flow(流)表示基于数据流图的计算,TensorFlow表示为将张量从流图的一端流动到另一端的计算。TensorFlow支持短期记忆网络(Long Short Term Memory Networks,LSTMN)、循环神经网络(Recurrent Neural Networks,RNN)和卷积神经网络(CNN)等深度神经网络模型。TensorFlow的基本架构如图1所示。
由图1可知,TensorFlow的基本架构可分为前端和后端。前端:基于支持多语言的编程环境,通过调用系统API来访问后端的编程模型。后端:提供运行环境,由分布式运行环境、内核、网络层和设备层组成。
2 Softmax回归
Softmax回归算法能将二分类的Logistic回归问题扩展至多分类。假设回归模型的样本由K个类组成,共有m个,则训练集可由式(1)表示:
式中,x(i)∈R(n+1),y(i)∈{1,2,…,K},n+1为特征向量x的维度。对于给定的输入值x,输出的K个估计概率由式(2)表示:
对参数θ1,θ2,…,θk进行梯度下降,得到Softmax回归模型,在TensorFlow中的实现如图2所示。
对图2进行矩阵表达,可得式(5):
将测试集数据代入式(5),并计算所属类别的概率,则概率最大的类别即为预测结果。
3 CNN
卷积神经网络(CNN)是一种前馈神经网络,通常包含数据输入层、卷积计算层、ReLU激励层、池化层、全连接层等,是由卷积运算来代替传统矩阵乘法运算的神经网络。CNN常用于图像的数据处理,常用的LenNet-5神经网络模型图如图3所示。
该模型由2个卷积层、2个抽样层(池化层)、3个全连接层组成。
3.1 卷积层
卷积层是通过一个可调参数的卷积核与上一层特征图进行滑动卷积运算,再加上一个偏置量得到一个净输出,然后调用激活函数得出卷积结果,通过对全图的滑动卷积运算输出新的特征图,如式(6)~式(7)所示:
3.2 抽样层
抽样层是将输入的特征图用n×n窗口划分成多个不重叠的区域,然后对每个区域计算出最大值或者均值,使图像缩小了n倍,最后加上偏置量通过激活函数得到抽样数据。其中,最大值法、均值法及输出函数如式(8)~式(10)所示:
3.3 全连接输出层
全连接层则是通过提取的特征参数对原始图片进行分类。常用的分类方法如式(11)所示:
4 实验分析
本文基于TensorFlow深度学习框架,数据源使用MNIST数据集,分别采用Softmax回归算法和CNN深度学习进行模型训练,然后将训练的模型进行对比验证,并在Android平台上进行应用。
4.1 MNIST数据集
MNIST数据集包含60 000行的训练数据集(train-images-idx3)和10 000行的测试数据集(test-images-idx3)。每个样本都有唯一对应的标签(label),用于描述样本的数字,每张图片包含28×28个像素点,如图4所示。
由图4可知,每一幅样本图片由28×28个像素点组成,可由一个长度为784的向量表示。MNIST的训练数据集可转换成[60 000,784]的张量,其中,第一个维度数据用于表示索引图片,第二个维度数据用于表示每张图片的像素点。而样本对应的标签(label)是介于0到9的数字,可由独热编码(one-hot Encoding)进行表示。一个独热编码除了某一位数字是1以外,其余维度数字都是0,如标签0表示为[1,0,0,0,0,0,0,0,0,0],所以,样本标签为[60 000,10]的张量。
4.2 Softmax模型实现
根据式(5),可以将Softmax模型分解为矩阵基本运算和Softmax调用,该模型实现方式如下:(1)使用符号变量创建可交互的操作单元;(2)创建权重值和偏量;(3)根据式(5),实现Softmax回归。
4.3 CNN模型实现
结合LenNet-5神经网络模型,基于TensorFlow深度学习模型实现方式如下:
(1)初始化权重和偏置;
(2)创建卷积和池化模板;
(3)进行两次的卷积、池化;
(4)进行全连接输出;
(5)Softmax回归。
4.4 评估指标
采用常用的成本函数“交叉熵”(cross-entropy),如式(12)所示:
4.5 模型检验
预测结果检验方法如下:
(1)将训练后的模型进行保存;
(2)输入测试样本进行标签预测;
(3)调用tf.argmax函数获取预测的标签值;
(4)与实际标签值进行匹配,最后计算识别率。
根据上述步骤,分别采用Softmax模型和卷积神经网络对手写数字0~9的识别数量、识别率分别如图5、表1所示。
根据表1的模型预测结果可知,Softmax模型对数字1的识别率为97.9%,识别率最高。对数字3和数字8的识别率相对较小,分别是84.9%、87.7%。Softmax模型对手写体数字0~9的整体识别率达91.57%。
结合图5和表1可知,基于CNN模型的整体识别率高于Softmax模型,其中对数字3的识别率提高了14.7%,对数字1的识别率只提高了1.7%。基于深度学习CNN模型对手写体数字0~9的整体识别率高达99.17%,比Softmax模型整体提高了7.6%。
4.6 模型应用
通过模型的对比验证可知,基于深度学习CNN的识别率优于Softmax模型。现将训练好的模型移植到Android平台,进行跨平台应用,实现方式如下。
(1)UI设计
用1个Bitmap控件显示用户手写触屏的轨迹,并将2个Button控件分别用于数字识别和清屏。
(2)TensorFlow引用
首先编译需要调用的TensorFlow的jar包和so文件。其次将训练好的模型(.pb)导入到Android工程。
(3)接口实现
①接口定义及初始化:
inferenceInterface.initializeTensorFlow(getAssets(), MODEL_FILE);
②接口的调用:
inferenceInterface.fillNodeFloat(INPUT_NODE, new int[]{1, HEIGHT, WIDTH, CHANNEL}, inputs);
③获取预测结果
inferenceInterface.readNodeFloat(OUTPUT_NODE, outputs);
通过上述步骤即可完成基于Android平台环境的搭建及应用,首先利用Android的触屏功能捕获并记录手写轨迹,手写完成后单击识别按钮,系统将调用模型进行识别,并将识别结果输出到用户界面。识别完成后,单击清除按钮,循环上述操作步骤可进行手写数字的再次识别,部分手写数字的识别效果如图6所示。
由图6可知,在Android平台上完成了基于TensorFlow深度学习手写数字的识别,并且采用CNN的训练模型有较好的识别效果,实现了TensorFlow训练模型跨平台的应用。
5 结论
本文基于TensorFlow深度学习框架,采用Softmax回归和CNN等算法进行手写体数字训练,并将模型移植到Android平台,进行跨平台应用。实验数据表明,基于Softmax回归模型的识别率为91.57%,基于CNN模型的识别率高达99.17%。表明基于深度学习的手写体数字识别在人工智能识别方面具有一定的参考意义。
参考文献
[1] HUBEL D H,WIESEL T N.Receptive fields and functional architecture of monkey striate cortex[J].Journal of Physiology,1968,195(1):215-243.
[2] LECUN Y,BOTTOU L,BENGIO Y,et al.Gradient-based learning applied to document recognition[J].Proceedings of the IEEE,1998,86(11):2278-2324.
[3] ZEILER M D,FERGUS R.Visualizing and understanding convolutional networks[J].arXiv:1311.2901[cs.CV].
[4] 贺康建,周冬明,聂仁灿,等.基于局部控制核的彩色图像目标检测方法[J].电子技术应用,2016,42(12):89-92.
[5] LI H,LI Y,PORIKLI F.DeepTrack:learning discriminative feature representations online for robust visual tracking[J].IEEE Transactions on Image Processing,2015,25(4):1834-1848.
[6] GOODFELLOW I J,BULATOV Y,IBARZ J,et al.Multi-digit number recognition from street view imagery using deep convolutional neural networks[J].arXiv:1312.6082[cs.CV].
[7] KRIZHEVSKY A,SUTSKEVER I,HINTON G E.ImageNet classifycation with deep convolutional neural networks[C].International Conference on Neural Information Processing Systems.Curran Associates Inc.,2012:1097-1105.
[8] SZEGEDY C,LIU W,JIA Y,et al.Going deeper with convoluteons[C].IEEE Conference on Computer Vision and Pattern Recognition.IEEE,2015:1-9.
[9] HE K,ZHANG X,REN S,et al.Deep residual learning for image recognition[J].arXiv:1512.03385[cs.CV].
[10] ABADI M,AGARWAL A,BARHAM P,et al.TensorFlow:large-scale machine learning on heterogeneous distributed systems[J].arXiv:1603.04467[cs.DC].
作者信息:
黄 睿,陆许明,邬依林
(广东第二师范学院 计算机科学系,广东 广州510303)