MNIST

在本章当中,我们将会使用 MNIST 这个数据集,它有着 70000 张规格较小的手写数字图片,由美国的高中生和美国人口调查局的职员手写而成。这相当于机器学习当中的“Hello World”,人们无论什么时候提出一个新的分类算法,都想知道该算法在这个数据集上的表现如何。机器学习的初学者迟早也会处理 MNIST 这个数据集。

Scikit-Learn 提供了许多辅助函数,以便于下载流行的数据集。MNIST 是其中一个。下面的代码获取 MNIST

  1. >>> from sklearn.datasets import fetch_mldata
  2. >>> mnist = fetch_mldata('MNIST original')
  3. >>> mnist
  4. {'COL_NAMES': ['label', 'data'],
  5. 'DESCR': 'mldata.org dataset: mnist-original',
  6. 'data': array([[0, 0, 0, ..., 0, 0, 0],
  7. [0, 0, 0, ..., 0, 0, 0],
  8. [0, 0, 0, ..., 0, 0, 0],
  9. ...,
  10. [0, 0, 0, ..., 0, 0, 0],
  11. [0, 0, 0, ..., 0, 0, 0],
  12. [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
  13. 'target': array([ 0., 0., 0., ..., 9., 9., 9.])}

一般而言,由 sklearn 加载的数据集有着相似的字典结构,这包括:

  • DESCR键描述数据集
  • data键存放一个数组,数组的一行表示一个样例,一列表示一个特征
  • target键存放一个标签数组

让我们看一下这些数组

  1. >>> X, y = mnist["data"], mnist["target"]
  2. >>> X.shape
  3. (70000, 784)
  4. >>> y.shape
  5. (70000,)

MNIST 有 70000 张图片,每张图片有 784 个特征。这是因为每个图片都是28*28像素的,并且每个像素的值介于 0~255 之间。让我们看一看数据集的某一个数字。你只需要将某个实例的特征向量,reshape28*28的数组,然后使用 Matplotlib 的imshow函数展示出来。

  1. %matplotlib inline
  2. import matplotlib
  3. import matplotlib.pyplot as plt
  4. some_digit = X[36000]
  5. some_digit_image = some_digit.reshape(28, 28)
  6. plt.imshow(some_digit_image, cmap = matplotlib.cm.binary, interpolation="nearest")
  7. plt.axis("off")
  8. plt.show()

5

这看起来像个 5,实际上它的标签告诉我们:

  1. >>> y[36000]
  2. 5.0

图3-1 展示了一些来自 MNIST 数据集的图片。当你处理更加复杂的分类任务的时候,它会让你更有感觉。

图3-1

先等一下!你总是应该先创建测试集,并且在验证数据之前先把测试集晾到一边。MNIST 数据集已经事先被分成了一个训练集(前 60000 张图片)和一个测试集(最后 10000 张图片)

  1. X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

让我们打乱训练集。这可以保证交叉验证的每一折都是相似(你不会期待某一折缺少某类数字)。而且,一些学习算法对训练样例的顺序敏感,当它们在一行当中得到许多相似的样例,这些算法将会表现得非常差。打乱数据集将保证这种情况不会发生。

  1. import numpy as np
  2. shuffle_index = np.random.permutation(60000)
  3. X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]