原文出处:https://blog.csdn.net/simple_the_best/article/details/75267863
import os import struct import numpy as np import matplotlib.pyplot as plt
1.下载数据
MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取, 它包含了四个部分:
Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)
Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)
Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)
Test set labels: t10k-labels-idx1-ubyte.gz (5 KB, 解压后 10 KB, 包含 10,000 个标签)
2.定义函数读取数据
def _resolve_idx_path(path, kind, patterns):
    """Return the first existing file in `path` matching one of `patterns`.

    Each pattern is a %-format string taking `kind`. If none of the
    candidates exists, the first candidate is returned so that the
    subsequent `open()` raises the usual FileNotFoundError for it.
    """
    candidates = [os.path.join(path, pattern % kind) for pattern in patterns]
    for candidate in candidates:
        if os.path.isfile(candidate):
            return candidate
    return candidates[0]


def load_mnist(path, kind='train'):
    """Load MNIST images and labels from uncompressed IDX files in `path`.

    Args:
        path: Directory that contains the extracted IDX files.
        kind: File-name prefix selecting the split, e.g. 'train' or 't10k'.

    Returns:
        A tuple ``(images, labels)``. ``images`` is a uint8 array of shape
        ``(num_images, rows * cols)`` holding one flattened image per row;
        ``labels`` is a uint8 array of shape ``(num_images,)``.

    Raises:
        FileNotFoundError: If a label or image file cannot be found.
        ValueError: If a file has the wrong magic number, or the label
            and image files disagree on the number of samples.
    """
    # The file names on disk depend on how the archives were extracted:
    # some tools produce 'train-labels.idx1-ubyte', while the names listed
    # on the download page use a hyphen. Accept both, dotted name first.
    labels_path = _resolve_idx_path(
        path, kind, ('%s-labels.idx1-ubyte', '%s-labels-idx1-ubyte'))
    images_path = _resolve_idx_path(
        path, kind, ('%s-images.idx3-ubyte', '%s-images-idx3-ubyte'))

    with open(labels_path, 'rb') as lbpath:
        # Header: big-endian magic number and item count.
        magic, n = struct.unpack('>II', lbpath.read(8))
        if magic != 2049:
            raise ValueError(
                'bad magic number %d in label file %s (expected 2049)'
                % (magic, labels_path))
        labels = np.fromfile(lbpath, dtype=np.uint8)

    with open(images_path, 'rb') as imgpath:
        # Header: big-endian magic number, image count, rows, columns.
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        if magic != 2051:
            raise ValueError(
                'bad magic number %d in image file %s (expected 2051)'
                % (magic, images_path))
        if num != len(labels):
            raise ValueError(
                'image file holds %d samples but label file holds %d'
                % (num, len(labels)))
        # Use the dimensions from the header instead of assuming 28x28.
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(num, rows * cols)

    return images, labels


if __name__ == '__main__':
    # Raw string: '\m' and '\d' are invalid escape sequences in a normal
    # string literal; the resulting path is unchanged.
    train_img, train_lab = load_mnist(r'D:\minist\data', kind='train')
    print(train_img.shape)
    print(train_lab.shape)

    # Show the first ten training digits in a 2x5 grid.
    fig, ax = plt.subplots(nrows=2, ncols=5, sharex=True, sharey=True)
    ax = ax.flatten()
    for i in range(10):
        img = train_img[i].reshape(28, 28)
        ax[i].imshow(img, cmap='Greys', interpolation='nearest')

    # The axes are shared, so clearing the ticks on one clears them on all.
    ax[0].set_xticks([])
    ax[0].set_yticks([])
    plt.tight_layout()
    plt.show()