记录一下用改写VGG16,加上word2vec,在cifar10数据集上训练image2text的过程:
1.语言模型
选用已经训练好的text8,直接加载模型,对cifar10标签的各个单词进行词向量的映射即可。
2.image的模型
选用VGG16,对VGG16的后3层进行改写:
搭建自己的网络:my_vgg
然后加载VGG16训练好的参数矩阵VGG16.npy,初始化取前5层的参数后,对自己定义的my_vgg模型在新的数据集(cifar10)上重新进行训练。
加载npy文件中的预训练参数,且不要6,7,8三个全连接层,在自定义的网络上重新进行训练。训练生成一个输入图片,输出图像标签的模型,保存为model.ckpt.
3.joint_model的训练和预测
joint_model:
joint_model 需要完全复用my_vgg中定义的结构,计算出输入图片的one_hot标签,
同时在joint_model这个新的模型中,只取my_vgg模型的前几层,4096以后的几个全连接层都去掉,重新添加两个新的全连接层,
最终joint_model模型的输出与输入图片对应标签的词向量维度保持一致,此处为(1,200)
该模型的loss通过
计算。
最后模型测试效果见下:
准确率为: 0.8081
4.joint_model训练和预测过程中遇到的问题
训练过程中有时候训练到一半中断了,可以加载之前保存的模型,在该模型基础上训练,在这个joint_model中需要加载两个模型,
第一个模型是必须加载的模型,即在第二步里训练好的my_vgg.py保存的模型model.ckpt,从这个模型中计算输入图片,输出的one_hot。此时需要声明一个saver对象,加载它。
其次,还要声明一个saver2对象,来加载我们训练一半的joint_model中的模型。
这样,才可以恢复整个图中的参数。
进行预测的时候也是同样的道理,我在预测时犯了一个错误,刚开始只加载了joint_model训练后保存的模型,然而预测效果却很差,但是在训练的时候明明准确率很高了,后来检查了下,发现没有加载my_vgg对应的模型导致的。最后用两个saver对象分别加载这两个训练好的模型,问题就没有了。