PyTorch实现基于U-Net的医学图像分割

目录结构

Pytorch实现基于U-net的医学图像分割_2d

代码

Train.py

import numpy as np
np.set_printoptions(threshold=np.inf)
# threshold: number of array elements above which NumPy prints a summarised
# repr; np.inf makes arrays print in full (handy when dumping masks for debugging).

import os
import time
import torch.nn as nn
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import torch.optim as optim
import torchvision.models as models
import random
from tools.my_dataset import MyDataset
from tools.unet import UNet
from tools.set_seed import set_seed
# Directory that contains this script; the data folders are located relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Disabled torchvision preprocessing pipeline, kept for reference only.
# # Training-data preprocessing
# train_transform = transforms.Compose([
#     transforms.Resize((256, 256)),
#     transforms.RandomCrop(256, padding=4),
#     # add random occlusion, rotation, etc.
#     transforms.ToTensor(),
#     transforms.Normalize(norm_mean, norm_std),
# ])
# # Validation-data preprocessing
# valid_transform = transforms.Compose([
#     transforms.Resize((256, 256)),
#     transforms.ToTensor(),
#     transforms.Normalize(norm_mean, norm_std),
# ])
#
# # Build the MyDataset instances
# train_data = MyDataset(data_dir=train_dir, transform=train_transform)
# valid_data = MyDataset(data_dir=valid_dir, transform=valid_transform)



set_seed()  # fix the random seeds (helper imported from tools.set_seed)

def compute_dice(y_pred, y_true):
    """Compute the Dice coefficient between two binary masks.

    Both inputs are converted to NumPy arrays and rounded to integers, so
    soft values are binarised at 0.5 before the overlap is measured.

    :param y_pred: predicted mask (array-like, values in [0, 1])
    :param y_true: ground-truth mask with the same shape as ``y_pred``
    :return: ``2 * |pred & true| / (|pred| + |true|)``; 1.0 when both masks
        are empty (they agree perfectly and the ratio is otherwise undefined)
    """
    y_pred = np.round(np.array(y_pred)).astype(int)
    y_true = np.round(np.array(y_true)).astype(int)

    denominator = np.sum(y_pred) + np.sum(y_true)
    if denominator == 0:
        # Avoid 0/0 -> NaN, which would poison any running mean of Dice scores.
        return 1.0
    return np.sum(y_pred[y_true == 1]) * 2.0 / denominator


if __name__ == "__main__":
    # Training entry point: builds the data loaders, model, loss and optimizer,
    # runs the train/validate loop, then visualises predictions and plots the
    # loss and Dice curves.

    # ============================ step 0/5: hyper-parameters ============================
    LR = 0.01    # learning rate
    BATCH_SIZE = 8  # training batch size
    max_epoch = 1  # number of training epochs
    start_epoch = 0  # index of the first epoch
    lr_step = 150  # StepLR step_size: the learning rate is scaled by gamma every lr_step epochs
    val_interval = 3  # run validation every val_interval epochs
    checkpoint_interval = 20  # save a checkpoint every checkpoint_interval epochs
    vis_num = 10  # number of validation samples visualised after training
    mask_thres = 0.5  # threshold that turns the network output into a binary mask

    # ============================ step 1/5: data ============================
    # Data folders, resolved relative to this script
    train_dir = os.path.join(BASE_DIR, "..", "data", "blood", "train")
    valid_dir = os.path.join(BASE_DIR, "..", "data", "blood", "valid")

    # No augmentation transform is passed; MyDataset is used with its defaults.
    train_set = MyDataset(train_dir)
    valid_set = MyDataset(valid_dir)

    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    valid_loader = DataLoader(valid_set, batch_size=1, shuffle=True, drop_last=False)

    # ============================ step 2/5: model ============================
    # 3-channel input, 1-channel output; init_features is the channel count of
    # the first feature map.
    net = UNet(in_channels=3, out_channels=1, init_features=32)
    net.to(device)

    # ============================ step 3/5: loss ============================
    # Mean-squared error between the predicted mask and the label
    loss_fn = nn.MSELoss()

    # ============================ step 4/5: optimizer ============================
    optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step, gamma=0.1)

    # ============================ step 5/5: training ============================
    train_curve = list()       # per-iteration training loss
    valid_curve = list()       # per-validation mean loss
    train_dice_curve = list()  # per-iteration training Dice
    valid_dice_curve = list()  # per-validation mean Dice
    for epoch in range(start_epoch, max_epoch):

        train_loss_total = 0.

        net.train()
        for batch_idx, (inputs, labels) in enumerate(train_loader):

            # .to(device) is a no-op on CPU, so no CUDA check is needed
            inputs, labels = inputs.to(device), labels.to(device)

            # forward
            optimizer.zero_grad()
            outputs = net(inputs)

            # backward
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            # bookkeeping
            train_dice = compute_dice(outputs.ge(mask_thres).cpu().data.numpy(), labels.cpu())
            train_dice_curve.append(train_dice)
            train_curve.append(loss.item())

            train_loss_total += loss.item()

            # get_last_lr() is the supported way to read the current learning
            # rate; get_lr() is meant for internal use by the scheduler.
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] running_loss: {:.4f}, mean_loss: {:.4f} "
                  "running_dice: {:.4f} lr:{}".format(epoch, max_epoch, batch_idx + 1, len(train_loader), loss.item(),
                                    train_loss_total/(batch_idx+1), train_dice, scheduler.get_last_lr()))

        scheduler.step()

        # save a checkpoint
        if (epoch + 1) % checkpoint_interval == 0:
            checkpoint = {"model_state_dict": net.state_dict(),
                          "optimizer_state_dict": optimizer.state_dict(),
                          "epoch": epoch}
            path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
            torch.save(checkpoint, path_checkpoint)

        # validate the model
        if (epoch+1) % val_interval == 0:

            net.eval()
            valid_loss_total = 0.
            valid_dice_total = 0.

            with torch.no_grad():
                for j, (inputs, labels) in enumerate(valid_loader):
                    inputs, labels = inputs.to(device), labels.to(device)

                    outputs = net(inputs)
                    loss = loss_fn(outputs, labels)

                    valid_loss_total += loss.item()

                    valid_dice = compute_dice(outputs.ge(mask_thres).cpu().data, labels.cpu())
                    valid_dice_total += valid_dice

                valid_loss_mean = valid_loss_total/len(valid_loader)
                valid_dice_mean = valid_dice_total/len(valid_loader)
                valid_curve.append(valid_loss_mean)
                valid_dice_curve.append(valid_dice_mean)

                print("Valid:\t Epoch[{:0>3}/{:0>3}] mean_loss: {:.4f} dice_mean: {:.4f}".format(
                    epoch, max_epoch, valid_loss_mean, valid_dice_mean))

    # Visualisation. The last epoch may not have run validation, so switch to
    # eval mode explicitly; otherwise BatchNorm would use (and update from)
    # single-sample batch statistics here.
    net.eval()
    with torch.no_grad():  # no gradients needed: saves memory and time
        for idx, (inputs, labels) in enumerate(valid_loader):
            if idx >= vis_num:  # show exactly vis_num samples
                break
            inputs = inputs.to(device)
            outputs = net(inputs)

            mask_pred = outputs.ge(mask_thres).cpu().data.numpy().astype("uint8")

            # CHW float tensor -> HWC uint8 image for matplotlib
            img_hwc = inputs.cpu().data.numpy()[0, :, :, :].transpose((1, 2, 0)).astype("uint8")
            plt.subplot(121).imshow(img_hwc)
            mask_pred_gray = mask_pred.squeeze() * 255
            plt.subplot(122).imshow(mask_pred_gray, cmap="gray")
            plt.show()
            plt.pause(0.5)
            plt.close()

    # loss curve
    train_x = range(len(train_curve))
    train_y = train_curve

    train_iters = len(train_loader)
    # Validation values are recorded once per validation run, so map each
    # record onto the iteration axis of the training curve.
    valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval
    valid_y = valid_curve

    plt.plot(train_x, train_y, label='Train')
    plt.plot(valid_x, valid_y, label='Valid')

    plt.legend(loc='upper right')
    plt.ylabel('loss value')
    plt.xlabel('Iteration')
    plt.title("Plot in {} epochs".format(max_epoch))
    plt.show()

    # dice curve
    train_x = range(len(train_dice_curve))
    train_y = train_dice_curve

    train_iters = len(train_loader)
    # Same iteration-axis conversion as for the loss curve.
    valid_x = np.arange(1, len(valid_dice_curve) + 1) * train_iters * val_interval
    valid_y = valid_dice_curve

    plt.plot(train_x, train_y, label='Train')
    plt.plot(valid_x, valid_y, label='Valid')

    plt.legend(loc='upper right')
    plt.ylabel('dice value')
    plt.xlabel('Iteration')
    plt.title("Plot in {} epochs".format(max_epoch))
    plt.show()
    torch.cuda.empty_cache()


Inference.py

import os
import time
import random
import torch.nn as nn
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import torch.optim as optim
import torchvision.models as models
from tools.set_seed import set_seed
from tools.my_dataset import MyDataset
from tools.unet import UNet

# Directory that contains this script; the image folder is located relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

set_seed()  # fix the random seeds (helper imported from tools.set_seed)


def compute_dice(y_pred, y_true):
    """Compute the Dice coefficient between two binary masks.

    Both inputs are converted to NumPy arrays and rounded to integers, so
    soft values are binarised at 0.5 before the overlap is measured.

    :param y_pred: predicted mask (array-like, values in [0, 1])
    :param y_true: ground-truth mask with the same shape as ``y_pred``
    :return: ``2 * |pred & true| / (|pred| + |true|)``; 1.0 when both masks
        are empty (they agree perfectly and the ratio is otherwise undefined)
    """
    y_pred = np.round(np.array(y_pred)).astype(int)
    y_true = np.round(np.array(y_true)).astype(int)

    denominator = np.sum(y_pred) + np.sum(y_true)
    if denominator == 0:
        # Avoid 0/0 -> NaN, which would poison any running mean of Dice scores.
        return 1.0
    return np.sum(y_pred[y_true == 1]) * 2.0 / denominator


def get_img_name(img_dir, format="jpg"):
    """
    List the files in a folder whose names end with the given suffix.

    Names ending in "matte.png" are always excluded from the result.
    :param img_dir: str, folder to scan
    :param format: str, required file-name suffix
    :return: list of matching file names (not full paths)
    :raises ValueError: when no file in the folder matches
    """
    img_names = [
        entry
        for entry in os.listdir(img_dir)
        if entry.endswith(format) and not entry.endswith("matte.png")
    ]

    if not img_names:
        raise ValueError("{}下找不到{}格式数据".format(img_dir, format))
    return img_names


def get_model(m_path):
    """
    Build a UNet and load its weights from a checkpoint file.

    The checkpoint is read onto the CPU and must contain a 'model_state_dict'
    entry. A leading 'module.' on any parameter name is stripped before
    loading (presumably left by a DataParallel wrapper at training time).
    :param m_path: str, path to the checkpoint file
    :return: UNet with the checkpoint weights loaded
    """
    from collections import OrderedDict

    prefix = 'module.'
    model = UNet(in_channels=3, out_channels=1, init_features=32)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    saved = torch.load(m_path, map_location="cpu")

    # Rebuild the state dict with the wrapper prefix removed, keeping key order.
    cleaned_state = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, tensor)
        for key, tensor in saved['model_state_dict'].items()
    )

    model.load_state_dict(cleaned_state)
    return model


if __name__ == "__main__":
    # Inference entry point: loads a trained UNet, runs it on a few randomly
    # chosen images and shows each image next to its predicted mask.

    img_dir = os.path.join(BASE_DIR, "..", "..", "data", "PortraitDataset", "valid")
    model_path = "checkpoint_399_epoch.pkl"
    time_total = 0
    num_infer = 5     # number of images to run inference on
    mask_thres = .5   # threshold that turns the network output into a binary mask

    # 1. data
    img_names = get_img_name(img_dir, format="png")
    random.shuffle(img_names)
    num_img = len(img_names)

    # 2. model
    unet = get_model(model_path)
    unet.to(device)
    unet.eval()

    # No gradients are needed for inference; skipping autograd bookkeeping
    # saves memory and time.
    with torch.no_grad():
        for idx, img_name in enumerate(img_names):
            if idx >= num_infer:  # process exactly num_infer images
                break

            path_img = os.path.join(img_dir, img_name)

            # step 1/4 : path --> img_chw
            img_hwc = Image.open(path_img).convert('RGB')
            img_hwc = img_hwc.resize((224, 224))
            img_arr = np.array(img_hwc)
            img_chw = img_arr.transpose((2, 0, 1))

            # step 2/4 : img --> tensor (add the batch dimension)
            img_tensor = torch.tensor(img_chw).to(torch.float)
            img_tensor.unsqueeze_(0)
            img_tensor = img_tensor.to(device)

            # step 3/4 : tensor --> features
            time_tic = time.time()
            outputs = unet(img_tensor)
            time_toc = time.time()

            # step 4/4 : visualization
            mask_pred = outputs.ge(mask_thres).cpu().data.numpy().astype("uint8")

            # img_arr is already the HWC uint8 image that was fed to the network
            plt.subplot(121).imshow(img_arr)
            mask_pred_gray = mask_pred.squeeze() * 255
            plt.subplot(122).imshow(mask_pred_gray, cmap="gray")
            plt.show()
            plt.close()

            time_s = time_toc - time_tic
            time_total += time_s

            print('{:d}/{:d}: {} {:.3f}s '.format(idx + 1, num_img, img_name, time_s))

my_dataset.py

import numpy as np
np.set_printoptions(threshold=np.inf)
# threshold: number of array elements above which NumPy prints a summarised
# repr; np.inf makes arrays print in full (handy when dumping masks for debugging).
import numpy as np
import torch
import os
import random
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
# Seed Python's RNG at import time so the path shuffling done with
# random.shuffle in this module is reproducible.
random.seed(1)



class MyDataset(Dataset):
    """Dataset of image / binary-mask pairs stored side by side in one folder.

    Every file in ``data_dir`` whose name ends with ``_mask.gif`` is treated
    as a label; its image is the same path with that suffix replaced by
    ``.tif``. Each item is an ``(image, label)`` pair of float tensors with
    shapes ``(3, in_size, in_size)`` and ``(1, in_size, in_size)`` when no
    transform is used; label values are 0 or 1.
    """

    def __init__(self, data_dir, transform=None, in_size=224):
        """
        :param data_dir: folder containing the images and their masks
        :param transform: optional callable applied to the CHW image array and
            to the CHW label array; it must return an array-like or a tensor
        :param in_size: side length both image and label are resized to
        :raises Exception: when ``data_dir`` contains no ``_mask.gif`` file
        """
        super(MyDataset, self).__init__()
        self.data_dir = data_dir
        self.transform = transform
        self.label_path_list = list()
        self.in_size = in_size

        # Collect the mask paths; image paths are derived from them on access.
        self._get_img_path()

    def __getitem__(self, index):
        """Load, resize and convert the ``index``-th image/label pair."""
        path_label = self.label_path_list[index]
        path_img = path_label[:-9] + ".tif"  # drop the 9-char "_mask.gif" suffix
        size = (self.in_size, self.in_size)

        # Context managers make sure the underlying files are closed.
        with Image.open(path_img) as img_file:
            img_pil = img_file.convert('RGB').resize(size, Image.BILINEAR)
        # PIL/NumPy images are HWC; the network expects CHW.
        img_chw = np.array(img_pil).transpose((2, 0, 1))

        # Label: single-channel greyscale, resized with NEAREST so no new
        # intermediate values are introduced.
        with Image.open(path_label) as label_file:
            label_pil = label_file.convert('L').resize(size, Image.NEAREST)
        label_hw = np.array(label_pil)
        label_hw[label_hw != 0] = 1  # binarise: any non-zero pixel is foreground
        label_chw = label_hw[np.newaxis, :, :]

        if self.transform is not None:
            # NOTE: the transform is called separately on image and label, so
            # a random transform may not stay aligned between the two.
            img_chw = self.transform(img_chw)
            label_chw = self.transform(label_chw)

        img_chw_tensor = self._to_float_tensor(img_chw)
        label_chw_tensor = self._to_float_tensor(label_chw)

        return img_chw_tensor, label_chw_tensor

    def __len__(self):
        return len(self.label_path_list)

    @staticmethod
    def _to_float_tensor(data):
        """Convert a tensor or array-like to a float32 tensor."""
        if isinstance(data, torch.Tensor):
            return data.float()
        # ascontiguousarray also copes with negatively-strided views (e.g.
        # flips), which torch.from_numpy rejects.
        return torch.from_numpy(np.ascontiguousarray(data)).float()

    def _get_img_path(self):
        """Fill ``label_path_list`` with the shuffled mask paths in ``data_dir``."""
        path_list = [
            os.path.join(self.data_dir, name)
            for name in os.listdir(self.data_dir)
            if name.endswith("_mask.gif")  # files with this suffix are masks
        ]
        random.shuffle(path_list)
        if not path_list:
            raise Exception("\ndata_dir:{} is a empty dir! Please checkout your path to images!".format(self.data_dir))
        self.label_path_list = path_list

set_seed.py

import random
import torch
import numpy as np

def set_seed(seed=1):
    """Seed the Python, NumPy and PyTorch random number generators.

    :param seed: integer seed applied to every generator
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every CUDA device, not just the current one, and
    # is safely ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)

unet.py

from collections import OrderedDict

import torch
import torch.nn as nn


class UNet(nn.Module):
    """U-Net with four encoder stages, a bottleneck and four decoder stages.

    Each stage is a pair of 3x3 convolutions with batch norm and ReLU. The
    encoder halves the spatial size with 2x2 max pooling after every stage;
    the decoder doubles it with stride-2 transposed convolutions and
    concatenates the matching encoder output along the channel axis. A 1x1
    convolution followed by a sigmoid produces the output.
    """

    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        """
        :param in_channels: channels of the input tensor
        :param out_channels: channels of the output tensor
        :param init_features: channels of the first stage; doubled at each
            deeper stage
        """
        super().__init__()

        f1 = init_features
        f2, f4, f8, f16 = f1 * 2, f1 * 4, f1 * 8, f1 * 16

        # Contracting path
        self.encoder1 = UNet._block(in_channels, f1, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(f1, f2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(f2, f4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(f4, f8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(f8, f16, name="bottleneck")

        # Expanding path: each decoder sees the upsampled features plus the
        # skip connection, hence twice the stage's channel count as input.
        self.upconv4 = nn.ConvTranspose2d(f16, f8, kernel_size=2, stride=2)
        self.decoder4 = UNet._block(f8 * 2, f8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(f8, f4, kernel_size=2, stride=2)
        self.decoder3 = UNet._block(f4 * 2, f4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(f4, f2, kernel_size=2, stride=2)
        self.decoder2 = UNet._block(f2 * 2, f2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(f2, f1, kernel_size=2, stride=2)
        self.decoder1 = UNet._block(f1 * 2, f1, name="dec1")

        # 1x1 convolution mapping the last feature map to the output channels
        self.conv = nn.Conv2d(in_channels=f1, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        """Return the sigmoid-activated output for input batch ``x``."""
        # Encoder: keep every stage's output for the skip connections
        skip1 = self.encoder1(x)
        skip2 = self.encoder2(self.pool1(skip1))
        skip3 = self.encoder3(self.pool2(skip2))
        skip4 = self.encoder4(self.pool3(skip3))

        # Bottleneck
        y = self.bottleneck(self.pool4(skip4))

        # Decoder: upsample, concatenate the skip connection, convolve
        stages = (
            (self.upconv4, self.decoder4, skip4),
            (self.upconv3, self.decoder3, skip3),
            (self.upconv2, self.decoder2, skip2),
            (self.upconv1, self.decoder1, skip1),
        )
        for upconv, decoder, skip in stages:
            y = decoder(torch.cat((upconv(y), skip), dim=1))

        return self.conv(y).sigmoid()

    @staticmethod
    def _block(in_channels, features, name):
        """Build one stage: two (conv3x3 -> batch norm -> ReLU) units.

        Sub-modules are named ``<name>conv1``, ``<name>norm1``, ``<name>relu1``
        and likewise with suffix 2.
        """
        layers = OrderedDict()
        # The first conv maps in_channels -> features, the second keeps features.
        for unit, conv_in in enumerate((in_channels, features), start=1):
            suffix = str(unit)
            layers[name + "conv" + suffix] = nn.Conv2d(
                in_channels=conv_in,
                out_channels=features,
                kernel_size=3,
                padding=1,
                bias=False,
            )
            layers[name + "norm" + suffix] = nn.BatchNorm2d(num_features=features)
            layers[name + "relu" + suffix] = nn.ReLU(inplace=True)
        return nn.Sequential(layers)