This post continues the previous article, "pytorch学习记录02——多分支网络" (PyTorch learning notes 02: multi-branch networks), and presents the complete code for reference. The network's structure diagram and its parameters were covered in the previous article, so here we go straight to the code.
First, import the required libraries:
# import the libraries
import random

import torch
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch import nn, optim
Next, define the model structure:
# define the model structure
class ThreeInputsNet(nn.Module):
    def __init__(self):
        super(ThreeInputsNet, self).__init__()
        # branch 1: input 3 x 64 x 64
        self.conv1_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pooling1_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv1_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv1_3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv1_4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        # branch 2: input 3 x 64 x 64
        self.conv2_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pooling2_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv2_3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv2_4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        # branch 3: input 3 x 64 x 64
        self.conv3_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pooling3_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3_3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3_4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        # each branch ends at 128 x 4 x 4; concatenating the three branches
        # along the channel dimension gives 128*3 x 4 x 4
        self.outlayer1 = nn.Linear(3 * 128 * 4 * 4, 128 * 5)
        self.outlayer2 = nn.Linear(128 * 5, 256)
        self.outlayer3 = nn.Linear(256, 5)  # set the second number to the number of classes; ours is a 5-class task, so it is 5

    # forward takes three inputs, one per branch
    def forward(self, input1, input2, input3):
        out1 = self.pooling1_1(self.conv1_1(input1))
        out1 = self.pooling1_1(self.conv1_2(out1))
        out1 = self.pooling1_1(self.conv1_3(out1))
        out1 = self.pooling1_1(self.conv1_4(out1))
        out2 = self.pooling2_1(self.conv2_1(input2))
        out2 = self.pooling2_1(self.conv2_2(out2))
        out2 = self.pooling2_1(self.conv2_3(out2))
        out2 = self.pooling2_1(self.conv2_4(out2))
        out3 = self.pooling3_1(self.conv3_1(input3))
        out3 = self.pooling3_1(self.conv3_2(out3))
        out3 = self.pooling3_1(self.conv3_3(out3))
        out3 = self.pooling3_1(self.conv3_4(out3))
        # concatenate the three branch outputs along the channel dimension
        out = torch.cat((out1, out2, out3), dim=1)
        out = out.view(out.size(0), -1)  # [B, C, H, W] --> [B, C*H*W]
        out = self.outlayer1(out)
        out = self.outlayer2(out)
        out = self.outlayer3(out)
        return out
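Before wiring the model into the training loop, it is worth sanity-checking the tensor shapes with a dummy forward pass. The snippet below is a quick check added for this write-up (it is not part of the original code): three random 3x64x64 batches go in, and a [2, 5] logit tensor should come out, confirming the 64 -> 32 -> 16 -> 8 -> 4 pooling chain and the 3 * 128 * 4 * 4 flatten size.

# shape check with random inputs (batch size 2)
model = ThreeInputsNet()
x1 = torch.randn(2, 3, 64, 64)
x2 = torch.randn(2, 3, 64, 64)
x3 = torch.randn(2, 3, 64, 64)
print(model(x1, x2, x3).shape)  # expected: torch.Size([2, 5])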
Next, define the training procedure:
def main():
    device = DEVICE
    model = ThreeInputsNet().to(device)
    loss_func = nn.CrossEntropyLoss().to(device)  # use cross-entropy as the loss function
    optimizer = optim.Adam(model.parameters(), lr=LR)  # optimizer; add weight_decay=0.1 here if desired
    # train the model
    top_acc = 0.0
    for epoch in range(EPOCHS):
        train_loss = 0
        train_acc = 0
        model.train()  # switch to train mode
        # iterate over the three dataloaders in lockstep
        for (num1, (input_L, label_L)), (num2, (input_R, label_R)), (num3, (input_M, label_M)) in zip(
                enumerate(train_loader_L, start=1),
                enumerate(train_loader_R, start=1),
                enumerate(train_loader_M, start=1)):
            input_L, label_L = input_L.to(device), label_L.to(device)
            input_R, label_R = input_R.to(device), label_R.to(device)
            input_M, label_M = input_M.to(device), label_M.to(device)
            # check that the three labels match (do this if the three branches' inputs must correspond sample by sample)
            assert torch.equal(label_L, label_R) and torch.equal(label_R, label_M), "training labels differ"
            y_ = model(input_L, input_R, input_M)
            loss = loss_func(y_, label_M)  # compute the loss
            optimizer.zero_grad()  # clear the gradients
            loss.backward()  # backpropagate the loss
            optimizer.step()
            # accumulate the loss
            train_loss += loss.item()
            # compute the classification accuracy
            out_t = y_.argmax(dim=1)  # index of the largest logit is the predicted class
            num_correct = (out_t == label_M).sum().item()
            acc = num_correct / input_M.shape[0]
            train_acc += acc
            # print training progress (enumerate starts at 1, so no +1 is needed here)
            rate = num1 / len(train_loader_M)
            a = "*" * int(rate * 50)
            b = '.' * int((1 - rate) * 50)
            print("\rtrain loss:{:^3.0f}%[{}->{}]{:.4f}".format(int(rate * 100), a, b, loss.item()), end="")
        print("\nEpoch:", epoch + 1, 'train_loss:', train_loss / len(train_loader_M), " train_acc:",
              train_acc / len(train_loader_M))
        # evaluate the model
        model.eval()  # switch to eval mode
        with torch.no_grad():  # no computation graph is built inside this block
            # test
            total_correct = 0  # number of correct predictions
            total_num = 0
            for (num1, (x1, label1)), (num2, (x2, label2)), (num3, (x3, label3)) in zip(
                    enumerate(test_loader_L, start=1),
                    enumerate(test_loader_R, start=1),
                    enumerate(test_loader_M, start=1)):
                x1, label1 = x1.to(device), label1.to(device)
                x2, label2 = x2.to(device), label2.to(device)
                x3, label3 = x3.to(device), label3.to(device)
                # check that the labels match
                assert torch.equal(label1, label2) and torch.equal(label2, label3), "test labels differ"
                y_ = model(x1, x2, x3)
                pred = y_.argmax(dim=1)  # index of the max logit is the predicted class
                correct = torch.eq(pred, label1).float().sum().item()  # count predictions equal to the label
                total_correct += correct
                total_num += x1.size(0)
            acc = total_correct / total_num
            if acc >= top_acc:
                top_acc = acc
            print("Epoch:", epoch + 1, '; test_top_acc:', top_acc * 100.0, "; test_acc:", acc * 100.0)
Finally, set the hyperparameters and feed the data into the network:
if __name__ == '__main__':
    # hyperparameters
    BATCH_SIZE = 8
    EPOCHS = 200
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # DEVICE = 'cpu'
    LR = 1e-3
    print(DEVICE)
    # prepare the data
    # preprocessing
    train_transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])
    test_transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])
    # load the datasets (replace the placeholder paths with your own)
    train_dataset_L = datasets.ImageFolder('path/to/train_L', train_transform)
    train_dataset_R = datasets.ImageFolder('path/to/train_R', train_transform)
    train_dataset_M = datasets.ImageFolder('path/to/train_M', train_transform)
    test_dataset_L = datasets.ImageFolder('path/to/test_L', test_transform)
    test_dataset_R = datasets.ImageFolder('path/to/test_R', test_transform)
    test_dataset_M = datasets.ImageFolder('path/to/test_M', test_transform)
    print(train_dataset_L.class_to_idx)
    print(test_dataset_L.class_to_idx)
    # build the dataloaders
    seed = random.randint(0, 100)  # pick a fresh shuffling seed on every run
    g = torch.Generator()
    g.manual_seed(seed)  # seeding the three generators identically keeps the three loaders' shuffle orders aligned;
    # without manual_seed, a new torch.Generator() starts from the same default state, so every run would shuffle identically
    train_loader_L = torch.utils.data.DataLoader(train_dataset_L, batch_size=BATCH_SIZE, shuffle=True, generator=g)
    g = torch.Generator()
    g.manual_seed(seed)
    train_loader_R = torch.utils.data.DataLoader(train_dataset_R, batch_size=BATCH_SIZE, shuffle=True, generator=g)
    g = torch.Generator()
    g.manual_seed(seed)
    train_loader_M = torch.utils.data.DataLoader(train_dataset_M, batch_size=BATCH_SIZE, shuffle=True, generator=g)
    # for the training set: check that the data are shuffled and that the three branches' labels still line up
    # for (num1, (input_L, label_L)), (num2, (input_R, label_R)), (num3, (input_M, label_M)) \
    #         in zip(enumerate(train_loader_L, start=1),
    #                enumerate(train_loader_R, start=1),
    #                enumerate(train_loader_M, start=1)):
    #     print(num1)
    #     print("label_L:{}".format(label_L))
    #     print("label_R:{}".format(label_R))
    #     print("label_M: {}".format(label_M))
    #     if num1 == 5:
    #         break
    g2 = torch.Generator()
    g2.manual_seed(seed)
    test_loader_L = torch.utils.data.DataLoader(test_dataset_L, batch_size=BATCH_SIZE, shuffle=True, generator=g2)
    g2 = torch.Generator()
    g2.manual_seed(seed)
    test_loader_R = torch.utils.data.DataLoader(test_dataset_R, batch_size=BATCH_SIZE, shuffle=True, generator=g2)
    g2 = torch.Generator()
    g2.manual_seed(seed)
    test_loader_M = torch.utils.data.DataLoader(test_dataset_M, batch_size=BATCH_SIZE, shuffle=True, generator=g2)
    # for the test set: the same shuffle-and-alignment check
    # for (num1, (x1, label1)), (num2, (x2, label2)), (num3, (x3, label3)) \
    #         in zip(enumerate(test_loader_L, start=1),
    #                enumerate(test_loader_R, start=1),
    #                enumerate(test_loader_M, start=1)):
    #     print(num1)
    #     print("test_l1:{}".format(label1))
    #     print("test_l2:{}".format(label2))
    #     print("test_l3:{}".format(label3))
    #     if num1 == 5:
    #         break
    # start training
    main()
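The shared-seed trick keeps the three loaders aligned, but it quietly depends on the three ImageFolder directories containing exactly the same files in the same order. An alternative, sketched below under that same assumption (the class name ThreeViewDataset is mine, not from the original post), is a single Dataset that returns all three views at once, so the alignment is structural and only one loader needs shuffling:

from torch.utils.data import Dataset, DataLoader

class ThreeViewDataset(Dataset):
    # wraps three ImageFolder datasets that share the same sample order
    def __init__(self, ds_L, ds_R, ds_M):
        assert len(ds_L) == len(ds_R) == len(ds_M), "datasets must have the same length"
        self.ds_L, self.ds_R, self.ds_M = ds_L, ds_R, ds_M

    def __len__(self):
        return len(self.ds_L)

    def __getitem__(self, idx):
        x_L, y_L = self.ds_L[idx]
        x_R, y_R = self.ds_R[idx]
        x_M, y_M = self.ds_M[idx]
        assert y_L == y_R == y_M, "labels differ across views"
        return x_L, x_R, x_M, y_L

# one loader, one shuffle, no seed juggling:
# train_loader = DataLoader(ThreeViewDataset(train_dataset_L, train_dataset_R, train_dataset_M),
#                           batch_size=BATCH_SIZE, shuffle=True)

With this design the training loop iterates over a single loader, and the per-batch label assertions become redundant.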
For easy copying, the complete runnable program is simply the snippets above (imports, model definition, training procedure, and the __main__ block) concatenated in order.