PyTorch入门实战教程笔记(三):手写数字问题引入

MNIST数据集

MNIST数据集是机器学习领域中非常经典的一个数据集,由60000个训练样本和10000个测试样本组成,每个样本都是一张28 * 28像素的灰度手写数字图片。

下载数据集

  网址:http://yann.lecun.com/exdb/mnist/   数据集简介:

  1、共有4数据集,下载之后保存在磁盘中(最好放在代码执行目录下,方便后期使用。)如新建一个文件夹D:*****\MNIST_data存放数据。  

pytorch输入csv pytorch输入手写数字_深度学习框架


  2、数据集中,训练样本:共60000个,其中55000个用于训练,另外5000个用于验证,测试样本:共10000个

Input Or Output

输入:
  前面介绍了MNIST数据集中每个样本为28*28个像素,用矩阵表示即[28,28],将其打平为一个维度(后一行接在前一行的后面),即[784],也就是此时x是一个一维度的矩阵。那么这样子,输入x为[b,784],这里的b代表有多少证张图片,例如有一张,b=1.
输出:
  分类怎么表示呢,采用one-hot方式:把每一种具体的类别理解为一个具体的节点(node)的输出,如下:

dog = [1,0,0,...]
cat = [0,1,0,...]
fish = [0,0,1,...]

然后每一个node都有一个具体的实数值,然后想办法将该实数值的范围归一为(0,1)之间,也就是说,能够代表每一类的概率值。并且想办法将0-9这10个数字的输出的概率和总是为1。十个概率中肯定有一个最大概率,将最大的概率理解为当前类别的置信度,取最大概率所在的node的名字作为当前识别的种类的归类。

  将regression 写成矩阵形式,如下:(@表示两个矩阵相乘),假设输出的out中 P(y=1 | x)=0.8,那么如果是0-9这十个数字,就可以理解为输出的类别就是数字1.

pytorch输入csv pytorch输入手写数字_pytorch输入csv_02


  与简单的线性回归问题不同,手写数字高维图片线性模型不能完成手写数字识别的任务,所以做下列操作:在计算后加入一个relu激活函数,并且用三层来嵌套。

pytorch输入csv pytorch输入手写数字_数据集_03


  使用梯度下降的方法解决,目标:找到三组W,b的参数,使得W,b,x经过算后更接近与真实的y,使它们之间的误差越小越好。

pytorch输入csv pytorch输入手写数字_深度学习框架_04


  然后对于一个新的X,经过下面的计算,梯度下降的回归,最终输出下图绿色的矩阵,代表P(y=1 | x)=0.8,那么在经过argmax()函数,可以输出最终的结果就是1,即X就是数字1.

pytorch输入csv pytorch输入手写数字_pytorch输入csv_05