因为数据分类问题造成的。
在datasets.py中,这是将头部姿态角度进行分类,此处应该是分为了68类,但是在train_hopenet.py中是66类
labels = torch.LongTensor(np.digitize([yaw, pitch, roll], bins))
修改方法,在train_hopenet.py中,将66类数字修改为68即可。
还有生成filename_list的方法如下所示:
# coding:utf-8
import os
path = '/AFLW2000'
file = open('filename_lists.txt', 'w')
list = os.listdir(path)
for row in list:
if row.find('.jpg')>0:
file.write(row.replace('.jpg', '')+'\n')
file.close()
print(list)