## 代码

Train.py

import numpy as np
np.set_printoptions(threshold=np.inf)
# threshold: total number of array elements that triggers summarization; np.inf makes numpy print every element.

import os
import time
import torch.nn as nn
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
import torch.optim as optim
import torchvision.models as models
import random
from tools.my_dataset import MyDataset
from tools.unet import UNet
from tools.set_seed import set_seed
# Absolute path of the directory that contains this script.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # 训练数据预处理
# train_transform = transforms.Compose([
#     transforms.Resize((256, 256)),
#     # 添加随机遮挡 旋转 等
#     transforms.ToTensor(),
#     transforms.Normalize(norm_mean, norm_std),
# ])
# # 验证数据预处理
# valid_transform = transforms.Compose([
#     transforms.Resize((256, 256)),
#     transforms.ToTensor(),
#     transforms.Normalize(norm_mean, norm_std),
# ])
#
# # 构建MyDataset实例
# train_data = MyDataset(data_dir=train_dir, transform=train_transform)
# valid_data = MyDataset(data_dir=valid_dir, transform=valid_transform)

set_seed()  # seed the random number generators (default seed)

def compute_dice(y_pred, y_true):
    """Compute the Dice coefficient between a predicted and a reference mask.

    Both inputs are rounded to the nearest integer before comparison, so
    probabilities in [0, 1] are binarised to {0, 1}.

    :param y_pred: predicted mask (array-like or torch tensor), values in [0, 1]
    :param y_true: reference mask (array-like or torch tensor), values in [0, 1]
    :return: 2 * |pred AND true| / (|pred| + |true|); 1.0 when both masks are
        empty (perfect agreement) instead of the nan produced by 0 / 0
    """

    def _to_binary(mask):
        # Tensors on the GPU or tracking gradients cannot be handed to numpy
        # directly; bring them to host memory first.
        if hasattr(mask, "detach"):
            mask = mask.detach().cpu().numpy()
        return np.round(np.asarray(mask)).astype(int)

    pred, true = _to_binary(y_pred), _to_binary(y_true)
    denominator = np.sum(pred) + np.sum(true)
    if denominator == 0:
        # Neither mask contains foreground: treat as a perfect match.
        return 1.0
    return np.sum(pred[true == 1]) * 2.0 / denominator

if __name__ == "__main__":
    # Imported here because the module header does not import DataLoader.
    from torch.utils.data import DataLoader

    # ============================ step 0/5: hyper-parameters ============================
    LR = 0.01  # initial learning rate
    BATCH_SIZE = 8  # mini-batch size
    max_epoch = 1  # index of the last epoch (exclusive)
    start_epoch = 0  # index of the first epoch
    lr_step = 150  # StepLR multiplies the learning rate by gamma every lr_step epochs
    val_interval = 3  # validate every val_interval epochs
    checkpoint_interval = 20  # save a checkpoint every checkpoint_interval epochs
    vis_num = 10  # number of validation batches to visualise

    # ============================ step 1/5: data ============================
    train_dir = os.path.join(BASE_DIR, "..", "data", "blood", "train")
    valid_dir = os.path.join(BASE_DIR, "..", "data", "blood", "valid")

    train_set = MyDataset(train_dir)
    valid_set = MyDataset(valid_dir)

    # The loaders were used below but never created in the original script.
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    valid_loader = DataLoader(valid_set, batch_size=1, shuffle=False)
    train_iters = len(train_loader)  # iterations per epoch, used to align the curves

    # ============================ step 2/5: model ============================
    # init_features is the channel count of the first feature map
    # (64 in the standard U-Net); 3 input channels, 1 output channel.
    net = UNet(in_channels=3, out_channels=1, init_features=32)
    net.to(device)

    # ============================ step 3/5: loss ============================
    loss_fn = nn.MSELoss()  # mean squared error

    # ============================ step 4/5: optimizer ============================
    optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step, gamma=0.1)

    # ============================ step 5/5: training ============================
    train_curve = list()  # per-iteration training loss
    valid_curve = list()  # per-validation mean loss
    train_dice_curve = list()  # per-iteration training dice
    valid_dice_curve = list()  # per-validation mean dice

    for epoch in range(start_epoch, max_epoch):

        train_loss_total = 0.
        train_dice_total = 0.

        net.train()
        for iter_idx, (inputs, labels) in enumerate(train_loader):
            # .to(device) is a no-op when the tensors already live on `device`.
            inputs, labels = inputs.to(device), labels.to(device)

            # Gradients accumulate by default; clear them before each step.
            optimizer.zero_grad()

            # forward
            outputs = net(inputs)

            # backward
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            # Binarise the prediction at 0.5 and compare with the labels on the CPU.
            train_dice = compute_dice(
                (outputs.detach() >= 0.5).float().cpu().numpy(),
                labels.detach().cpu().numpy())

            train_dice_curve.append(train_dice)
            train_curve.append(loss.item())

            train_dice_total += train_dice
            train_loss_total += loss.item()

            # get_last_lr() replaces get_lr(), which is not meant to be called
            # outside of scheduler.step().
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] running_loss: {:.4f}, mean_loss: {:.4f} "
                  "running_dice: {:.4f} lr:{}".format(epoch, max_epoch, iter_idx + 1, len(train_loader), loss.item(),
                                                      train_loss_total / (iter_idx + 1), train_dice,
                                                      scheduler.get_last_lr()))

        scheduler.step()

        # save a checkpoint
        if (epoch + 1) % checkpoint_interval == 0:
            checkpoint = {"model_state_dict": net.state_dict(),
                          "optimizer_state_dict": optimizer.state_dict(),
                          "epoch": epoch}
            path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
            torch.save(checkpoint, path_checkpoint)

        # validate the model
        if (epoch + 1) % val_interval == 0:

            net.eval()
            valid_loss_total = 0.
            valid_dice_total = 0.

            # No gradients are needed for evaluation.
            with torch.no_grad():
                for j, (inputs, labels) in enumerate(valid_loader):
                    inputs, labels = inputs.to(device), labels.to(device)

                    outputs = net(inputs)
                    loss = loss_fn(outputs, labels)

                    valid_loss_total += loss.item()

                    valid_dice = compute_dice(
                        (outputs >= 0.5).float().cpu().numpy(),
                        labels.cpu().numpy())
                    valid_dice_total += valid_dice

            num_valid_batches = max(len(valid_loader), 1)  # guard against an empty loader
            valid_loss_mean = valid_loss_total / num_valid_batches
            valid_dice_mean = valid_dice_total / num_valid_batches

            valid_curve.append(valid_loss_mean)
            valid_dice_curve.append(valid_dice_mean)

            print("Valid:\t Epoch[{:0>3}/{:0>3}] mean_loss: {:.4f} dice_mean: {:.4f}".format(
                epoch, max_epoch, valid_loss_mean, valid_dice_mean))

            # visualise the first vis_num validation batches
            with torch.no_grad():
                for idx, (inputs, labels) in enumerate(valid_loader):
                    if idx >= vis_num:
                        break
                    inputs = inputs.to(device)
                    outputs = net(inputs)

                    mask_pred = (outputs >= 0.5).cpu().numpy().astype("uint8")
                    img_hwc = inputs.cpu().data.numpy()[0, :, :, :].transpose((1, 2, 0)).astype("uint8")

                    # left: input image, right: predicted binary mask
                    plt.subplot(121).imshow(img_hwc)
                    plt.subplot(122).imshow(mask_pred[0, 0] * 255, cmap="gray")
                    plt.show()
                    plt.pause(0.5)
                    plt.close()

    # plot loss curve
    train_x = range(len(train_curve))
    train_y = train_curve

    # valid_curve stores one value per validation run; convert its x-axis to iterations.
    valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval
    valid_y = valid_curve

    plt.plot(train_x, train_y, label='Train')
    plt.plot(valid_x, valid_y, label='Valid')

    plt.legend(loc='upper right')
    plt.ylabel('loss value')
    plt.xlabel('Iteration')
    plt.title("Plot in {} epochs".format(max_epoch))
    plt.show()

    # plot dice curve
    train_x = range(len(train_dice_curve))
    train_y = train_dice_curve

    # Same per-validation to per-iteration conversion as for the loss curve.
    valid_x = np.arange(1, len(valid_dice_curve) + 1) * train_iters * val_interval
    valid_y = valid_dice_curve

    plt.plot(train_x, train_y, label='Train')
    plt.plot(valid_x, valid_y, label='Valid')

    plt.legend(loc='upper right')
    plt.ylabel('dice value')
    plt.xlabel('Iteration')
    plt.title("Plot in {} epochs".format(max_epoch))
    plt.show()
    torch.cuda.empty_cache()



Inference.py

import os
import time
import random
import torch.nn as nn
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
import torch.optim as optim
import torchvision.models as models
from tools.set_seed import set_seed
from tools.my_dataset import MyDataset
from tools.unet import UNet

# Absolute path of the directory that contains this script.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

set_seed()  # seed the random number generators (default seed)

def compute_dice(y_pred, y_true):
    """Compute the Dice coefficient between a predicted and a reference mask.

    Both inputs are rounded to the nearest integer before comparison, so
    probabilities in [0, 1] are binarised to {0, 1}.

    :param y_pred: predicted mask (array-like or torch tensor), values in [0, 1]
    :param y_true: reference mask (array-like or torch tensor), values in [0, 1]
    :return: 2 * |pred AND true| / (|pred| + |true|); 1.0 when both masks are
        empty (perfect agreement) instead of the nan produced by 0 / 0
    """

    def _to_binary(mask):
        # Tensors on the GPU or tracking gradients cannot be handed to numpy
        # directly; bring them to host memory first.
        if hasattr(mask, "detach"):
            mask = mask.detach().cpu().numpy()
        return np.round(np.asarray(mask)).astype(int)

    pred, true = _to_binary(y_pred), _to_binary(y_true)
    denominator = np.sum(pred) + np.sum(true)
    if denominator == 0:
        # Neither mask contains foreground: treat as a perfect match.
        return 1.0
    return np.sum(pred[true == 1]) * 2.0 / denominator

def get_img_name(img_dir, format="jpg"):
    """
    List the files in a directory whose names end with the given suffix.

    Names ending in "matte.png" are excluded from the result.

    :param img_dir: str, directory to scan
    :param format: str, file-name suffix to keep
    :return: list of matching file names, in os.listdir order
    :raises ValueError: if no file matches
    """
    matches = [
        name
        for name in os.listdir(img_dir)
        if name.endswith(format) and not name.endswith("matte.png")
    ]

    if not matches:
        raise ValueError("{}下找不到{}格式数据".format(img_dir, format))
    return matches

def get_model(m_path):
    """Build a UNet and load its weights from a training checkpoint.

    :param m_path: str, path to a checkpoint file containing a
        'model_state_dict' entry
    :return: UNet with the checkpoint weights loaded (on the CPU)
    """
    from collections import OrderedDict

    unet = UNet(in_channels=3, out_channels=1, init_features=32)

    # The original never loaded the file, so `checkpoint` was undefined.
    # NOTE(security): torch.load unpickles the file; only open checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(m_path, map_location="cpu")

    # Strip the "module." prefix that nn.DataParallel adds to parameter names.
    prefix = "module."
    new_state_dict = OrderedDict()
    for k, v in checkpoint['model_state_dict'].items():
        namekey = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[namekey] = v

    # Without this call the network would keep its random initial weights.
    unet.load_state_dict(new_state_dict)

    return unet

if __name__ == "__main__":

    # NOTE(review): this points at PortraitDataset/*.png while the training
    # script reads data/blood — confirm the intended dataset.
    img_dir = os.path.join(BASE_DIR, "..", "..", "data", "PortraitDataset", "valid")
    model_path = "checkpoint_399_epoch.pkl"
    time_total = 0
    num_infer = 5  # maximum number of images to run inference on

    # 1. data
    img_names = get_img_name(img_dir, format="png")
    random.shuffle(img_names)
    num_img = len(img_names)

    # 2. model
    unet = get_model(model_path)
    unet.to(device)
    unet.eval()

    num_done = 0
    for idx, img_name in enumerate(img_names):
        # `idx > num_infer` processed num_infer + 1 images.
        if idx >= num_infer:
            break

        path_img = os.path.join(img_dir, img_name)

        # step 1/4 : path --> img_chw
        img_hwc = Image.open(path_img).convert('RGB')
        img_hwc = img_hwc.resize((224, 224))
        img_arr = np.array(img_hwc)
        img_chw = img_arr.transpose((2, 0, 1))

        # step 2/4 : img --> tensor
        img_tensor = torch.tensor(img_chw).to(torch.float)
        img_tensor.unsqueeze_(0)  # add the batch dimension
        img_tensor = img_tensor.to(device)

        # step 3/4 : tensor --> prediction (no gradients needed for inference)
        with torch.no_grad():
            time_tic = time.time()
            outputs = unet(img_tensor)
            if torch.cuda.is_available():
                # CUDA kernels run asynchronously; wait so the timing is meaningful.
                torch.cuda.synchronize()
            time_toc = time.time()

        # step 4/4 : visualization — left: input image, right: predicted mask
        mask_pred = (outputs >= 0.5).cpu().numpy().astype("uint8")
        img_hwc = img_tensor.cpu().data.numpy()[0, :, :, :].transpose((1, 2, 0)).astype("uint8")
        plt.subplot(121).imshow(img_hwc)
        plt.subplot(122).imshow(mask_pred[0, 0] * 255, cmap="gray")
        plt.show()
        plt.close()

        time_s = time_toc - time_tic
        time_total += time_s
        num_done += 1

        print('{:d}/{:d}: {} {:.3f}s '.format(idx + 1, num_img, img_name, time_s))

    # time_total was accumulated but never reported in the original.
    if num_done:
        print('average inference time: {:.3f}s over {:d} images'.format(time_total / num_done, num_done))



my_dataset.py

import numpy as np
np.set_printoptions(threshold=np.inf)
# threshold: total number of array elements that triggers summarization; np.inf makes numpy print every element.
import numpy as np
import torch
import os
import random
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
# Seed Python's `random` module at import time so the shuffle in
# MyDataset._get_img_path is repeatable.
random.seed(1)

class MyDataset(Dataset):
    """Segmentation dataset that pairs each label mask with its source image.

    The image path is derived from the label path by dropping the last
    ``_LABEL_SUFFIX_LEN`` characters and appending ``_IMG_EXT``; both files
    therefore live in ``data_dir``.
    """

    # Trailing characters removed from a label path to obtain the image stem.
    _LABEL_SUFFIX_LEN = 9
    # Extension of the image file that belongs to a label.
    _IMG_EXT = ".tif"

    def __init__(self, data_dir, transform=None, in_size = 224):
        """
        :param data_dir: str, directory holding images and label masks
        :param transform: optional callable applied to the CHW image array and
            to the 1HW label array
        :param in_size: int, images and labels are resized to in_size x in_size
        """
        super(MyDataset, self).__init__()
        self.data_dir = data_dir
        self.transform = transform
        self.label_path_list = list()
        self.in_size = in_size

        self._get_img_path()

    @classmethod
    def _img_path_for(cls, path_label):
        """Return the image path that corresponds to a label path."""
        return path_label[:-cls._LABEL_SUFFIX_LEN] + cls._IMG_EXT

    @staticmethod
    def _to_float_tensor(data):
        """Convert an ndarray (or tensor) to a float32 torch tensor."""
        if isinstance(data, torch.Tensor):
            return data.float()
        # ascontiguousarray also copes with negatively-strided arrays
        # (e.g. produced by flips), which torch.from_numpy rejects.
        return torch.from_numpy(np.ascontiguousarray(data)).float()

    def __getitem__(self, index):
        """Return (image, label) as float tensors of shape (3, H, W) and (1, H, W)."""
        path_label = self.label_path_list[index]
        path_img = self._img_path_for(path_label)
        size = (self.in_size, self.in_size)

        # Image: PIL yields HWC arrays, the network expects CHW.
        with Image.open(path_img) as img_file:
            img_pil = img_file.convert('RGB').resize(size, Image.BILINEAR)
        img_chw = np.array(img_pil).transpose((2, 0, 1))

        # Label: single-channel grayscale; nearest-neighbour resizing keeps
        # the mask free of interpolated values.
        with Image.open(path_label) as label_file:
            label_pil = label_file.convert('L').resize(size, Image.NEAREST)
        label_hw = np.array(label_pil)
        label_hw[label_hw != 0] = 1  # binarise: any non-zero pixel is foreground
        label_chw = label_hw[np.newaxis, :, :]

        if self.transform is not None:
            # The arrays are already ndarrays; the original called .numpy() on
            # them, which raised AttributeError.
            # NOTE(review): a random transform is applied to image and label
            # independently — confirm it is deterministic or shares its state.
            img_chw = self.transform(img_chw)
            label_chw = self.transform(label_chw)

        return self._to_float_tensor(img_chw), self._to_float_tensor(label_chw)

    def __len__(self):
        """Return the number of label/image pairs."""
        return len(self.label_path_list)

    def _get_img_path(self):
        """Collect the label paths in data_dir whose paired image exists.

        :raises Exception: if the directory is empty or holds no label file
            with a matching image
        """
        # Sort first: os.listdir order is arbitrary, so the seeded shuffle
        # would otherwise not be reproducible.
        file_list = sorted(os.listdir(self.data_dir))
        path_list = [os.path.join(self.data_dir, name) for name in file_list]
        if len(path_list) == 0:
            raise Exception("\ndata_dir:{} is a empty dir! Please checkout your path to images!".format(self.data_dir))

        # The directory holds images as well as labels; keep only the entries
        # whose derived image path exists, i.e. the real label files.
        label_paths = [path for path in path_list if os.path.isfile(self._img_path_for(path))]
        if len(label_paths) == 0:
            raise Exception("\ndata_dir:{} contains no label file with a matching image!".format(self.data_dir))

        random.shuffle(label_paths)
        self.label_path_list = label_paths


set_seed.py

import random
import torch
import numpy as np

def set_seed(seed=1):
    """Seed Python, NumPy and PyTorch random number generators.

    :param seed: int, seed value shared by all generators
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed only covers the current GPU; manual_seed_all seeds every
    # visible device and is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)



unet.py

from collections import OrderedDict

import torch
import torch.nn as nn

class UNet(nn.Module):

def __init__(self, in_channels=3, out_channels=1, init_features=32):
    """Create the encoder, bottleneck, decoder and 1x1 output convolution.

    :param in_channels: channels of the input image
    :param out_channels: channels produced by the final 1x1 convolution
    :param init_features: channel width of the first encoder block; each
        deeper level doubles it, up to 16x at the bottleneck
    """
    super(UNet, self).__init__()

    # Channel widths per level: 1x, 2x, 4x, 8x, and 16x at the bottleneck.
    c1 = init_features
    c2, c3, c4, c5 = c1 * 2, c1 * 4, c1 * 8, c1 * 16

    make_block = UNet._block
    # Kernel 2 / stride 2 is shared by the pooling and the transposed convolutions.
    by_two = dict(kernel_size=2, stride=2)

    # Contracting path: block followed by 2x2 max-pooling.
    self.encoder1 = make_block(in_channels, c1, name="enc1")
    self.pool1 = nn.MaxPool2d(**by_two)
    self.encoder2 = make_block(c1, c2, name="enc2")
    self.pool2 = nn.MaxPool2d(**by_two)
    self.encoder3 = make_block(c2, c3, name="enc3")
    self.pool3 = nn.MaxPool2d(**by_two)
    self.encoder4 = make_block(c3, c4, name="enc4")
    self.pool4 = nn.MaxPool2d(**by_two)

    self.bottleneck = make_block(c4, c5, name="bottleneck")

    # Expanding path: each decoder block takes twice the channels of its
    # transposed convolution because forward() concatenates the skip connection.
    self.upconv4 = nn.ConvTranspose2d(c5, c4, **by_two)
    self.decoder4 = make_block(c4 * 2, c4, name="dec4")
    self.upconv3 = nn.ConvTranspose2d(c4, c3, **by_two)
    self.decoder3 = make_block(c3 * 2, c3, name="dec3")
    self.upconv2 = nn.ConvTranspose2d(c3, c2, **by_two)
    self.decoder2 = make_block(c2 * 2, c2, name="dec2")
    self.upconv1 = nn.ConvTranspose2d(c2, c1, **by_two)
    self.decoder1 = make_block(c1 * 2, c1, name="dec1")

    # 1x1 convolution mapping the last feature map to the output channels.
    self.conv = nn.Conv2d(in_channels=c1, out_channels=out_channels, kernel_size=1)

def forward(self, x):
    """Run the U-Net on a batch and return the per-pixel prediction.

    :param x: input batch, fed to encoder1
    :return: sigmoid of the 1x1 output convolution applied to the last
        decoder feature map. The original method had no return statement,
        so it returned None and never used self.conv.
    """
    # encoder
    enc1 = self.encoder1(x)
    enc2 = self.encoder2(self.pool1(enc1))
    enc3 = self.encoder3(self.pool2(enc2))
    enc4 = self.encoder4(self.pool3(enc3))

    # bottleneck
    bottleneck = self.bottleneck(self.pool4(enc4))

    # decoder: upsample, concatenate the matching encoder output along the
    # channel dimension (skip connection), then apply the decoder block
    dec4 = self.upconv4(bottleneck)
    dec4 = torch.cat((dec4, enc4), dim=1)
    dec4 = self.decoder4(dec4)

    dec3 = self.upconv3(dec4)
    dec3 = torch.cat((dec3, enc3), dim=1)
    dec3 = self.decoder3(dec3)

    dec2 = self.upconv2(dec3)
    dec2 = torch.cat((dec2, enc2), dim=1)
    dec2 = self.decoder2(dec2)

    dec1 = self.upconv1(dec2)
    dec1 = torch.cat((dec1, enc1), dim=1)
    dec1 = self.decoder1(dec1)

    # The sigmoid keeps the output in [0, 1], the range compute_dice documents
    # for its inputs. NOTE(review): confirm a sigmoid head is intended.
    return torch.sigmoid(self.conv(dec1))

@staticmethod
def _block(in_channels, features, name):
return nn.Sequential(
OrderedDict(
[
(
name + "conv1",
nn.Conv2d(
in_channels=in_channels, # 确定卷积核的深度
out_channels=features, # 确实输出的特征图深度，即卷积核组的多少
kernel_size=3,
bias=False,
),
),
(name + "norm1", nn.BatchNorm2d(num_features=features)),
(name + "relu1", nn.ReLU(inplace=True)),
(
name + "conv2",
nn.Conv2d(
in_channels=features,
out_channels=features,
kernel_size=3,