21个项目玩转深度学习:基于TensorFlow的实践详解
上QQ阅读APP看本书,新人免费读10天
设备和账号都新为新人

1.1 MNIST数据集

1.1.1 简介

首先介绍MNIST数据集。如图1-1所示,MNIST数据集主要由一些手写数字的图片和相应的标签组成,图片一共有10类,分别对应从0~9,共10个阿拉伯数字。

图1-1 MNIST数据集图片示例

原始的MNIST数据库一共包含下面4个文件,见表1-1。

表1-1 原始的MNIST数据集包含的文件

在表1-1中,图像数据是指很多张手写字符的图像,图像的标签是指每一张图像实际对应的数字是几,也就是说,在MNIST数据集中的每一张图像都事先标明了对应的数字。

在MNIST数据集中有两类图像:一类是训练图像(对应文件train-images-idx3-ubyte.gz和train-labels-idx1-ubyte.gz),另一类是测试图像(对应文件t10k-images-idx3-ubyte.gz和t10k-labels-idx1-ubyte.gz)。训练图像一共有60000张,供研究人员训练出合适的模型。测试图像一共有10000张,供研究人员测试训练的模型的性能。在TensorFlow中,可以使用下面的Python代码下载MNIST数据(在随书附赠的代码中,该代码对应的文件是donwload.py)。

    # coding:utf-8
    # 从tensorflow.examples.tutorials.mnist引入模块
    # 这是TensorFlow为了教学MNIST而提前编制的程序
    from tensorflow.examples.tutorials.mnist import input_data
    # 从MNIST_data/中读取MNIST数据。这条语句在数据不存在时,会自动执行下载
    mnist=input_data.read_data_sets("MNIST_data/", one_hot=True)

在执行语句mnist=input_data.read_data_sets("MNIST_data/", one_hot=True)时,TensorFlow会检测数据是否存在。当数据不存在时,系统会自动将数据下载到MNIST_data/文件夹中。当执行完语句后,读者可以自行前往MNIST_data/文件夹下查看上述4个文件是否已经被正确地下载若因网络问题无法正常下载,可以前往MNIST官网http://yann.lecun.com/exdb/mnist/使用下载工具下载上述4个文件,并将它们复制到MNIST_data/文件夹中。

成功加载MNIST数据集后,得到了一个mnist对象,可以通过mnist对象的属性访问到MNIST数据集,见表1-2。

表1-2 mnist对象中各个属性的含义和大小

运行下列代码可以查看各个变量的形状大小:

    # 查看训练数据的大小
    print(mnist.train.images.shape) # (55000, 784)
    print(mnist.train.labels.shape) # (55000, 10)

    # 查看验证数据的大小
    print(mnist.validation.images.shape) # (5000, 784)
    print(mnist.validation.labels.shape) # (5000, 10)

    # 查看测试数据的大小
    print(mnist.test.images.shape) # (10000, 784)
    print(mnist.test.labels.shape) # (10000, 10)

原始的MNIST数据集中包含了60000张训练图片和10000张测试图片。而在TensorFlow中,又将原先的60000张训练图片重新划分成了新的55000张训练图片和5000张验证图片。所以在mnist对象中,数据一共分为三部分:mnist.train是训练图片数据,mnist.validation是验证图片数据,mnist.test是测试图片数据,这正好对应了机器学习中的训练集、验证集和测试集。一般来说,会在训练集上训练模型,通过模型在验证集上的表现调整参数,最后通过测试集确定模型的性能。

1.1.2 实验:将MNIST数据集保存为图片

在原始的MNIST数据集中,每张图片都由一个28×28的矩阵表示,如图1-2所示。

图1-2 单张图片样本的矩阵表示

在TensorFlow中,变量mnist.train.images是训练样本,它的形状为(55000, 784)。其中,5000是训练图像的个数,而784实际为单个样本的维数,即每张图片都由一个784维的向量表示(784正好等于28×28)。可以使用以下代码打印出第0张训练图片对应的向量表示:

    # 打印出第0张图片的向量表示
    print(mnist.train.images[0, :])

为了加深对这种表示的理解,下面完成一个简单的程序:将MNIST数据集读取出来,并保存为图片文件。对应的代码文件为save_pic.py。

    #coding: utf-8
    from tensorflow.examples.tutorials.mnist import input_data
    import scipy.misc
    import os

    # 读取MNIST数据集。如果不存在会事先下载
    mnist=input_data.read_data_sets("MNIST_data/", one_hot=True)

    # 把原始图片保存在MNIST_data/raw/文件夹下
    # 如果没有这个文件夹,会自动创建
    save_dir='MNIST_data/raw/'
    if os.path.exists(save_dir) is False:
      os.makedirs(save_dir)

    # 保存前20张图片
    for i in range(20):
      # 请注意,mnist.train.images[i, :]就表示第i张图片(序号从0开始)
      image_array=mnist.train.images[i, :]
      # TensorFlow中的MNIST图片是一个784维的向量,我们重新把它还原为28×28维的图像
      image_array=image_array.reshape(28, 28)
      # 保存文件的格式为:
      # mnist_train_0.jpg, mnist_train_1.jpg, ... , mnist_train_19.jpg
      filename=save_dir+'mnist_train_%d.jpg' % i
      # 将image_array保存为图片
      # 先用scipy.misc.toimage转换为图像,再调用save直接保存
      scipy.misc.toimage(image_array, cmin=0.0, cmax=1.0).save(filename)

运行此程序后,在MNIST_data/raw/文件夹下就可以看到MNIST数据集中训练集的前20张图片。读者可以修改上述程序打印更多的图片。

1.1.3 图像标签的独热表示

变量mnist.train.labels表示训练图像的标签,它的形状是(55000, 10)。原始的图像标签是数字0~9,我们完全可以用一个数字来存储图像标签,但为什么这里每个训练标签是一个10维的向量呢?其实,这个10维的向量是原先类别号的独热(one-hot)表示。

所谓独热表示,就是“一位有效编码”。我们用N维的向量来表示N个类别,每个类别占据独立的一位,任何时候独热表示中只有一位是1,其他都为0。读者可以直接从表1-3中理解独热表示。

表1-3 类别的原始表示和独热表示

运行下面的代码可以打印出第0张训练图片的标签:

    # 打印出第0张训练图片的标签
    print(mnist.train.labels[0, :])

代码运行的结果是[ 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.],也就是说第0张图片对应的标签为数字“7”。

此外,我们可以打印出前20张图片的标签(对应程序label.py),读者可以尝试与第1.1.2节中保存的图片对照,查看图像与图像的标签是否正确地对应上了。

    # coding: utf-8
    from tensorflow.examples.tutorials.mnist import input_data
    import numpy as np
    # 读取MNIST数据集。如果不存在会事先下载
    mnist=input_data.read_data_sets("MNIST_data/", one_hot=True)
    # 看前20张训练图片的label
    for i in range(20):
      # 得到独热表示,形如(0, 1, 0, 0, 0, 0, 0, 0, 0, 0)
      one_hot_label=mnist.train.labels[i, :]
      # 通过np.argmax,可以直接获得原始的label
      # 因为只有1位为1,其他都是0
      label=np.argmax(one_hot_label)
      print('mnist_train_%d.jpg label: %d' % (i, label))

至此,读者应当对变量mnist.train.images和mnist.train.labels很熟悉了。剩下的mnist.validation.images、mnist.validation.labels、mnist.test.images、mnist.test.labels四个变量与它们非常类似,唯一的区别只是图像的个数不同,本章就不再做更详细的解释了。