目录
Fashion MNIST数据库
分类模型的建立
模型预测
总体代码
主要介绍基于tf.keras的Fashion MNIST数据库分类,
官方文档地址为:https://tensorflow.google.cn/tutorials/keras/basic_classification
文本分类类似,官网文档地址为https://tensorflow.google.cn/tutorials/keras/basic_text_classification
首先是函数的调用,对于tensorflow只有在版本1.2以上的版本才有tf.keras库。另外推荐使用python3,而不是python2。
Fashion MNIST数据库
fashion mnist数据库是mnist数据库的一个拓展。目的是取代mnist数据库,类似MINST数据库,fashion mnist数据库为训练集60000张,测试集10000张的28X28大小的服装彩色图片。具体分类如下:
标注编号 | 描述 |
0 | T-shirt/top(T恤) |
1 | Trouser(裤子) |
2 | Pullover(套衫) |
3 | Dress(裙子) |
4 | Coat(外套) |
5 | Sandal(凉鞋) |
6 | Shirt(汗衫) |
7 | Sneaker(运动鞋) |
8 | Bag(包) |
9 | Ankle boot(踝靴) |
样本描述如下:
名称 | 描述 | 样本数量 | 文件大小 | 链接 |
| 训练集的图像 | 60,000 | 26 MBytes | 下载 |
| 训练集的类别标签 | 60,000 | 29 KBytes | 下载 |
| 测试集的图像 | 10,000 | 4.3 MBytes | 下载 |
| 测试集的类别标签 | 10,000 | 5.1 KBytes | 下载 |
单张图像展示代码:
效果图:
样本的展示代码:
效果图:
分类模型的建立
检测模型输入数据为28X28,1个隐藏层节点数为128,输出类别10类,代码如下:
模型训练参数设置:
模型的训练:
模型预测
预测函数:
分类器是softmax分类器,输出的结果一个predictions是一个长度为10的数组,数组中每一个数字的值表示其所对应分类的概率值。如下所示:
对于predictions[0]其中第10个值最大,则该值对应的分类为class[9]ankle boot。
前25张图的分类效果展示:
效果图,绿色标签表示分类正确,红色标签表示分类错误:
对于单个图像的预测,需要将图像28X28的输入转换为1X28X28的输入,转换函数为np.expand_dims。函数使用如下:https://www.zhihu.com/question/265545749
总体代码