PyTorch版本问题
作于2019.10.14
超分辨率的PyTorch实现,要求>=特定版本的PyTorch
本人在最近需要用到超分辨率算法,于是从GitHub上找了开源的项目。
但是本地部署之后发现,导入第三方库的时候有很多报错。
经查阅后,发现在PyTorch1.1.0之后,很多库弃用,或者是进行了整合修改(不在原位置)。这就导致了import报错。简单来说,就是有些时候,PyTorch版本之间不兼容(不向下兼容)的问题。
于是本人从PyTorch官网下载了当时最新的PyTorch1.2.0(中途还出现了几次意外情况,见另一篇博客)。
这样就解决了import报错,但是在实际运行的时候仍然报错,报错信息为ImportError: cannot import name ‘dataloader’
,但是torch.utils.data.dataloader是可以定位并且正常打开的(在PyCharm中按住ctrl再点击dataloader,可以打开dataloader.py文件)。
目前这个报错并没有解决,由于课内事务繁重,只能先暂时搁置这个问题。在这里挖个坑做个记录。
PyTorch更新之后导致旧项目报错
第一个报错
但是在更新PyTorch(1.0.1->1.2.0->1.1.0)之后,发现另一个项目出现了新的warning。
这个warning在1.2.0和1.1.0均存在。
简单来说应该就是定位到torch的utils。估计在实际运行的时候也会报错。(本人目前没有来得及尝试在我电脑的环境下是否会报错,现在这个程序是在学校的服务器上运行,所以报错也是我目前的猜测)。
但是有一点,这个问题并没有出现在其他的项目中。也就是说在其他项目中运行:
import torch
import torch.utils.data.dataloader
是不报错的。所以目前看来,这个warning好像只是存在于cloud_classification
这个项目中。
综上所述,我更怀疑是PyCharm抽风。
第二个报错
另外我又检查了另一个项目,cifar-10.py。发现在PyTorch 1.2.0下会报错。
报错信息如下:
Traceback (most recent call last):
File "E:\Anaconda\lib\site-packages\IPython\core\interactiveshell.py", line 3291, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-2-7623631534a6>", line 1, in <module>
runfile('C:/Users/73416/PycharmProjects/untitled4/cifar-10.py', wdir='C:/Users/73416/PycharmProjects/untitled4')
File "E:\PyCharm 2018.3.4\helpers\pydev\_pydev_bundle\pydev_umd.py", line 197, in runfile
pydev_imports.execfile(filename, global_vars, local_vars) # execute the script
File "E:\PyCharm 2018.3.4\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "C:/Users/73416/PycharmProjects/untitled4/cifar-10.py", line 230, in <module>
for input_data, _ in train_loader:
File "E:\Anaconda\lib\site-packages\torch\utils\data\dataloader.py", line 560, in __next__
batch = self.collate_fn([self.dataset[i] for i in indices])
File "E:\Anaconda\lib\site-packages\torch\utils\data\dataloader.py", line 560, in <listcomp>
batch = self.collate_fn([self.dataset[i] for i in indices])
File "C:/Users/73416/PycharmProjects/untitled4/cifar-10.py", line 205, in __getitem__
target = self.target_transform(label)
File "C:/Users/73416/PycharmProjects/untitled4/cifar-10.py", line 164, in target_transform
target = torch.from_numpy(label).long() # 变为torch.LongTensor
TypeError: expected np.ndarray (got numpy.ndarray)
而且实话实说我并没有看懂这个报错。网上搜了一圈也没有找到和我报错信息相同的。
这个错误是出现在PyTorch 1.1.0下的,之后我把PyTorch 降级到了1.0.1(我之前一直在用的版本),发现报错消失,运行正常。
综上所述,这也是,PyTorch版本之间不兼容(不向下兼容)的问题。
查看PyTorch的更新日志
当然,理论上这些版本之间的不兼容理论上是可以预料的。预料的方式就是查阅PyTorch的更新日志。
查看方式如下:
- 进入PyTorch的GitHub。
- 点击release
- 找到对应的tag
即可获得PyTorch每个版本的更新日志。