MNIST data processing

mac, 2024-05-15

Original source: https://blog.csdn.net/simple_the_best/article/details/75267863

 

 

import os
import struct
import numpy as np
import matplotlib.pyplot as plt

1. Download the data

The MNIST dataset is available at http://yann.lecun.com/exdb/mnist/ and consists of four parts:

Training set images: train-images-idx3-ubyte.gz (9.9 MB, 47 MB uncompressed, 60,000 samples)
Training set labels: train-labels-idx1-ubyte.gz (29 KB, 60 KB uncompressed, 60,000 labels)
Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 7.8 MB uncompressed, 10,000 samples)
Test set labels: t10k-labels-idx1-ubyte.gz (5 KB, 10 KB uncompressed, 10,000 labels)
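The downloads are gzip archives, so they have to be decompressed before the raw files can be read. A minimal sketch using only the standard library, assuming the .gz files are already on disk; the helper name `gunzip_file` is hypothetical and not from the original post:

```python
import gzip
import shutil
from pathlib import Path


def gunzip_file(gz_path):
    """Decompress `gz_path` next to itself and return the output path.

    Hypothetical helper for illustration; not part of the original post.
    """
    gz_path = Path(gz_path)
    out_path = gz_path.with_suffix('')  # drop the trailing ".gz"
    with gzip.open(gz_path, 'rb') as src, open(out_path, 'wb') as dst:
        shutil.copyfileobj(src, dst)  # stream the data instead of loading it all
    return out_path
```

Calling `gunzip_file('train-images-idx3-ubyte.gz')` would write `train-images-idx3-ubyte` in the same directory. The archive's hyphenated name is kept, so the file names used when loading have to match whatever ends up on disk.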

2. Define a function to read the data

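def load_mnist(path, kind='train'):
    """Load MNIST images and labels from `path`."""
    # Note: the file names must match the downloaded (decompressed) files.
    labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)

    with open(labels_path, 'rb') as lbpath:
        # Header: magic number and label count, big-endian unsigned 32-bit ints.
        magic, n = struct.unpack('>II', lbpath.read(8))
        labels = np.fromfile(lbpath, dtype=np.uint8)

    with open(images_path, 'rb') as imgpath:
        # Header: magic number, image count, rows, columns.
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        # Flatten each 28 x 28 image into a row of 784 pixels.
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784)

    return images, labels  # two NumPy arrays


# Raw string, so the backslashes in the Windows path are kept literally.
train_img, train_lab = load_mnist(r'D:\minist\data', kind='train')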

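print(train_img.shape)  # expected for the full training set: (60000, 784)
print(train_lab.shape)  # expected for the full training set: (60000,)

# Show the first ten training digits in a 2 x 5 grid.
fig, ax = plt.subplots(
    nrows=2,
    ncols=5,
    sharex=True,
    sharey=True,
)

ax = ax.flatten()
for i in range(10):
    img = train_img[i].reshape(28, 28)
    ax[i].imshow(img, cmap='Greys', interpolation='nearest')

# The axes are shared, so clearing the ticks on one subplot clears them on all.
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()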
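The loading function reads the file headers but never checks them. As an optional safeguard, the sketch below parses a header and verifies the magic number (2049 for label files, 2051 for image files, as described in the file format notes on the MNIST page). The name `read_idx_header` is hypothetical; it accepts any binary file object, so it can be tried on in-memory bytes without the dataset:

```python
import io
import struct

LABEL_MAGIC = 2049  # 0x00000801: unsigned byte data, 1 dimension
IMAGE_MAGIC = 2051  # 0x00000803: unsigned byte data, 3 dimensions


def read_idx_header(fileobj):
    """Read an MNIST file header and return the tuple of dimension sizes.

    Hypothetical helper for illustration; not part of the original post.
    """
    (magic,) = struct.unpack('>I', fileobj.read(4))
    if magic == LABEL_MAGIC:
        ndim = 1
    elif magic == IMAGE_MAGIC:
        ndim = 3
    else:
        raise ValueError('unexpected magic number: %d' % magic)
    # Each dimension size is a big-endian unsigned 32-bit integer.
    return struct.unpack('>' + 'I' * ndim, fileobj.read(4 * ndim))


# Synthetic image header: 2 images of 28 x 28 pixels.
header = struct.pack('>IIII', IMAGE_MAGIC, 2, 28, 28)
dims = read_idx_header(io.BytesIO(header))
print(dims)  # (2, 28, 28)
```

A check like this fails early with a clear message when the wrong file is opened, for example a file that is still compressed, instead of producing a confusing reshape error later.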
