欢迎光临

我们一直在努力
当前位置:首页 > 科技数码 >

mnist:我们建了个模型,搞定了 MNIST 数字识别任务

日期: 来源:收集编辑:
mnist

(公众号:(公众号:))按:本文为字幕组编译的技术博客,原标题 A simple 2D CNN for MNIST digit recognition ,作者为Sambit Mahapatra。

翻译 | 王祎霍雷刚 整理 | MY

对于图像分类任务,当前最先进的架构是卷积神经网络 (CNNs) 。无论是面部识别、自动驾驶还是目标检测,CNN 得到广泛使用。在本文中,针对著名的 MNIST 数字识别任务,我们设计了一个以 tensorflow 为后台技术、基于 keras 的简单 2D 卷积神经网络 (CNN) 模型。整个工作流程如下:

1. 准备数据

2. 创建模型并编译

3. 训练模型并评估

4. 将模型存盘以便下次使用

1. 准备数据

数据集就使用上文所提到的 MNIST 数据集。MNIST 数据集 ( Modified National Institute of Standards and Technoloy 数据集) 是一个大型的手写数字(0 到 9)数据集。该数据集包含 大小为 28x28 的图片 7 万张,其中 6 万张训练图片、1 万张测试图片。第一步,加载数据集,这一步可以很容易地通过 keras api 来实现。

其中,X_train 包含 6 万张 大小为 28x28 的训练图片,y_train 包含这些图片对应的标签。与之类似,X_test 包含了 1 万张大小为 28x28 的测试图片,y_test 为其对应的标签。我们将一部分训练数据可视化一下,来对深度学习模型的目标有一个认识吧。

如上所示,左上角图为「5」的图片数据被存在 X_train[0] 中,y_train[0] 中存储其对应的标签「5」。我们的深度学习模型应该能够仅仅通过手写图片预测实际写下的数字。 现在为了准备数据,我们需要对这些图片做一些诸如调整大小、像素值归一化之类的处理。

对图片数据做了必要的处理之后,需要将 y_train 和 y_test 标签数据进行转换,转换成分类的格式。例如,模型构建时,3 应该被转换成向量 [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]。

创建模型并编译

数据加载进模型之后,我们需要定义模型结构,并优化函数、损失函数和性能指标。

接下来定义的架构为 2 个卷积层,分别在每个卷积层后接续一个池化层,一个全连接层和一个 softmax 层。在每一层卷积层上都会使用多个滤波器来提取不同类型的特征。直观的解释是,第一个滤波器有助于检测图片中的直线,第二个滤波器有助于检测图片中的圆形,等等。关于每一层技术实现的解释,将会在后续的帖子中进行讲解。如果想要更好的理解每一层的含义,可以参考 http://cs231n.github.io/convolutional-networks/ 。

在最大池化和全连接层之后,在我们的模型中引入 dropout 来进行正则化,用以消除模型的过拟合问题。

确定模型架构之后需要对模型进行编译。这是项多类别的分类问题,因此我们需要使用 categorical_crossentropy 作为损失函数。由于所有的标签都带有相似的权重,我们更喜欢使用精确度作为性能指标。AdaDelta 是一个很常用的梯度下降方法,我们使用这个方法来优化模型参数。

训练模型并评估

在定义模型架构和编译模型之后,要使用训练集去训练模型,使得模型可以识别手写数字。这里,我们将使用 X_train 和 y_train 来拟合模型。

其中,一个 epoch 表示一次全量训练样例的前向和后向传播。batch_size 就是在一次前向/后向传播过程用到的训练样例的数量。训练输出结果如下:

现在,我们来评估训练得到模型的性能。

测试准确率达到了 99%+,这意味着这个预测模型训练的很成功。如果查看整个训练日志,就会发现随着 epoch 次数的增多,模型在训练数据、测试数据上的损失和准确率逐渐收敛,最终趋于稳定。

将模型存盘以便下次使用

现在需要将训练过的模型进行序列化。模型的架构或者结构保存在 json 文件,权重保存在 hdf 5 文件。

模型被保存后,可以被重用,也可以很方便地移植到其它环境中使用。在以后的帖子中,我们将会演示如何在生产环境中部署这个模型。

享受深度学习吧!

参考文献:

Guide to the Sequential model - Keras DocumentationGetting started with the Keras Sequential modelkeras.io

CS231n Convolutional Neural Networks for Visual RecognitionCourse materials and notes for Stanford class CS231n: Convolutional Neural Networks for Visual Recognition.cs231n.github.io

原文链接:https://towardsdatascience.com/a-simple-2d-cnn-for-mnist-digit-recognition-a998dbc1e79a

我们建了个模型,搞定了 MNIST 数字识别任务

相关阅读

热门文章

最新文章

  • 麻将游戏的三个赢牌方法

  • 打麻将是运气与牌技的结合,即使有时运气不好,运气是一时不是一世,高手都往往通过自己过人的技术来赢得牌
  • 麻将大师指导技巧分享

  • 这些都是一些麻将高手总结出来的麻将口诀,相信对提高大家的麻将技术会有一定的帮助的,在实战中也要多观察
  • 为什么有人打麻将总是赢牌?

  • 麻将就是一种游戏,一种娱乐生活的方式,大家一定要以一种平稳的心态去对待它,不以物喜、不以己悲,不要太
  • 麻将的起源

  • 中国作为一个多文化大国,其内涵之丰富。当然,除了自然风景文化,在几千年的人类酝酿中,"赌"文化也不断