1、下载mnist数据集
地址:http://yann.lecun.com/exdb/mnist/
下面这四个都要下载,下载完成后,解压到同一个目录,我是解压到“E:/fashion_mnist/”这个目录里面,好和下面的代码目录一致
解压完成后,需要修改一下文件名,如(修改原因:保持和下面代码一样,避免出现其它问题):
修改前:t10k-images.idx3-ubyte
修改后:t10k-images-idx3-ubyte
我是第一次弄这玩意,所以尽量弄得白痴些,走弯路很烦,有时候一点点小问题就弄半天,其实就是别人有那么一点没讲清楚,然后就会搞很久
2、执行原文1里面的这段代码。
这段代码里面,需要先用pip安装skimage、torch、torchvision,前两篇文章有安装步骤。
importosfrom skimage importioimporttorchvision.datasets.mnist as mnist
root="E:/fashion_mnist/"train_set=(
mnist.read_image_file(os.path.join(root,‘train-images-idx3-ubyte‘)),
mnist.read_label_file(os.path.join(root,‘train-labels-idx1-ubyte‘))
)
test_set=(
mnist.read_image_file(os.path.join(root,‘t10k-images-idx3-ubyte‘)),
mnist.read_label_file(os.path.join(root,‘t10k-labels-idx1-ubyte‘))
)print("training set :",train_set[0].size())print("test set :",test_set[0].size())def convert_to_img(train=True):if(train):
f=open(root+‘train.txt‘,‘w‘)
data_path=root+‘/train/‘
if(notos.path.exists(data_path)):
os.makedirs(data_path)for i, (img,label) in enumerate(zip(train_set[0],train_set[1])):
img_path=data_path+str(i)+‘.jpg‘io.imsave(img_path,img.numpy())
f.write(img_path+‘ ‘+str(label)+‘\n‘)
f.close()else:
f= open(root + ‘test.txt‘, ‘w‘)
data_path= root + ‘/test/‘
if (notos.path.exists(data_path)):
os.makedirs(data_path)for i, (img,label) in enumerate(zip(test_set[0],test_set[1])):
img_path= data_path+ str(i) + ‘.jpg‘io.imsave(img_path, img.numpy())
f.write(img_path+ ‘ ‘ + str(label) + ‘\n‘)
f.close()
convert_to_img(True)
convert_to_img(False)