目录
第1章 torchvision与预训练模型的自动下载
第1章 torchvision与预训练模型的自动下载
第2章 预训练模型的手工下载
第3章 网络介绍
第4章 前置条件:系统库的导入
import torch # torch基础库
import torchvision.models as models # torchvision模型库
print("Hello World")
print(torch.__version__)
print(torch.cuda.is_available())
第5章 预训练模型的导入
5.1 模型的创建
## 创建模型
net = models.resnet101()
print(net)
5.2 模型参数的导入
##导入模型参数
net_params_path = "models/resnet101.pth"
net_params = torch.load(model_params_path)
print(net_params)
5.3 模型参数的应用
# 把加载的参数应用到模型中
net.load_state_dict(net_params)
print(net)
5.4 模型的简单测试
(1)测试1
print("定义测试数据")
input = torch.randn(1, 3, 224, 224)
print("input shape = ", input.shape)
output = net(input)
print("output shape = ", output.shape)
定义测试数据
input shape = torch.Size([1, 3, 224, 224])
output shape = torch.Size([1, 1000])
(2)测试2:
print("定义测试数据")
input = torch.randn(1, 3, 224, 224)
print("input shape = ", input.shape)
output = net(input)
print("output shape = ", output.shape)
定义测试数据
input shape = torch.Size([1, 3, 224, 224])
output shape = torch.Size([1, 1000])
此时,可以使用该模型对图片进行预测了!!!