卷积神经网络UNET学习
nn.Sequential
一个有序的容器,神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行,同时以神经网络模块为元素的有序字典也可以作为传入参数。
Sequential传送门
Batchnorm
首先,此部分也即是讲为什么深度网络会需要 b a t c h n o r m batchnorm batchnorm,我们都知道,深度学习的话尤其是在CV上都需要对数据做归一化,因为深度神经网络主要就是为了学习训练数据的分布,并在测试集上达到很好的泛化效果,但是,如果我们每一个batch输入的数据都具有不同的分布,显然会给网络的训练带来困难。另一方面,数据经过一层层网络计算后,其数据分布也在发生着变化,此现象称为 I n t e r n a l Internal Internal C o v a r i a t e Covariate Covariate S h i f t Shift Shift,接下来会详细解释,会给下一层的网络学习带来困难。 b a t c h n o r m batchnorm batchnorm直译过来就是批规范化,就是为了解决这个分布变化问题。
Batchnorm传送门
Conv2d
in_channels:网络输入的通道数。
out_channels:网络输出的通道数。
kernel_size:卷积核的大小,如果该参数是一个整数q,那么卷积核的大小是qXq。
stride:步长。是卷积过程中移动的步长。默认情况下是1。一般卷积核在输入图像上的移动是自左至右,自上至下。如果参数是一个整数那么就默认在水平和垂直方向都是该整数。如果参数是stride=(2, 1),2代表着高(h)进行步长为2,1代表着宽(w)进行步长为1。
padding:填充,默认是0填充。
dilation:扩张。一般情况下,卷积核与输入图像对应的位置之间的计算是相同尺寸的,也就是说卷积核的大小是3X3,那么它在输入图像上每次作用的区域是3X3,这种情况下dilation=0。
Conv2d传送门
ModuleList
nn.ModuleList 这个类,你可以把任意 nn.Module 的子类 (比如 nn.Conv2d, nn.Linear 之类的) 加到这个 list 里面,方法和 Python 自带的 list 一样,无非是 extend,append 等操作。但不同于一般的 list,加入到 nn.ModuleList 里面的 module 是会注册到整个网络上的,同时 module 的 parameters 也会自动添加到整个网络中
ModuleList传送门
MaxPool2d
kernel_size(int or tuple) - max pooling的窗口大小
stride(int or tuple, optional) - max pooling的窗口移动的步长。默认值是kernel_size
padding(int or tuple, optional) - 输入的每一条边补充0的层数
dilation(int or tuple, optional) – 一个控制窗口中元素步幅的参数
return_indices - 如果等于True,会返回输出最大值的序号,对于上采样操作会有帮助
ceil_mode - 如果等于True,计算输出信号大小的时候,会使用向上取整,代替默认的向下取整的操作
tqdm
tqdm是 python的一个关于进度条的扩展包,在深度学习进程中可以将训练过程用进度条的形式展现出来,会让训练界面更加的美观。下面介绍一下常见的tqdm的函数以及使用方式。
参考自官方文档,tqdm的常见参数有:
desc(‘str’): 传入进度条的前缀
mininterval(float):最小的更新时间 [default: 0.1] seconds
maxinterval(float):最大的更新时间 [default: 10] seconds. 只有在dynamic_miniters
miniters(int or float):最小的展示更新进度,如果设置为0或者是dynamic_miniters程序会自动的调整去让miniterval与它项适应
ascii(bool or str):如果调整为True的话会使用ASCII(美国信息交换标准代码)码,默认为False会使用unicode
ncols(int):整个输出信息的宽度
nrows(int):进度条的高速
dynamic_ncols(bool):会在环境中持续改变ncols和nrows
smoothing(float):会平均移动因素和预计的时间
bar_format(str):可以自己定义一个
position(int):设置打印进度条的位置,可以设置多个bar
colour(str):进度条的颜色
set_postfix : 设置信息
model.py
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
class DoubleConv(nn.Module):
def __init__(self,in_channels,out_channels):
super(DoubleConv,self).__init__()
self.conv=nn.Sequential(
nn.Conv2d(in_channels,out_channels,3,1,1,bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, 3, 1 , 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
def forward(self,x):
return self.conv(x)
class UNET(nn.Module):
def __init__(
self,in_channels=3,out_channels=1,features=[64,128,256,512],
):
super(UNET,self).__init__()
self.ups=nn.ModuleList()
self.downs=nn.ModuleList()
self.pool=nn.MaxPool2d(kernel_size=2,stride=2)
#Down part of UNET
for feature in features:
self.downs.append(DoubleConv(in_channels,feature))
in_channels=feature
# UP part of UNET
for feature in reversed(features):
self.ups.append(
nn.ConvTranspose2d(
feature*2,feature,kernel_size=2,stride=2,
)
)
self.ups.append(DoubleConv(feature*2,feature))
self.bottleneck=DoubleConv(features[-1],features[-1]*2)
self.final_conv=nn.Conv2d(features[0],out_channels,kernel_size=1)
def forward(self,x):
skip_connections=[]
for down in self.downs:
x=down(x)
skip_connections.append(x)
x=self.pool(x)
x=self.bottleneck(x)
skip_connections=skip_connections[::-1]
for idx in range(0,len(self.ups),2):
x=self.ups[idx](x)
skip_connection=skip_connections[idx//2]
if x.shape != skip_connection.shape:
x=TF.resize(x,size=skip_connection.shape[2:])
concat_skip=torch.cat((skip_connection,x),dim=1)
x=self.ups[idx+1](concat_skip)
return self.final_conv(x)
def test():
x=torch.randn((3,1,512,384))
print(x.shape)
model=UNET(in_channels=1,out_channels=1)
preds=model(x)
print(preds.shape)
assert preds.shape==x.shape
if __name__ == '__main__':
test()
train.py
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from model import UNET
from utils import (
load_checkpoint,
save_checkpoint,
get_loaders,
check_accuracy,
save_predictions_as_imgs,
)
# Hyperparameters etc.
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4
NUM_EPOCHS = 15
NUM_WORKERS = 2
IMAGE_WIDTH = 512 # 1918 originally
IMAGE_HEIGHT = 384 # 1280 originally
PIN_MEMORY = True
LOAD_MODEL = False
TRAIN_IMG_DIR = "mydata/data/train_images/"
TRAIN_MASK_DIR = "mydata/data/train_masks/"
VAL_IMG_DIR = "mydata/data/val_images/"
VAL_MASK_DIR = "mydata/data/val_masks/"
def train_fn(loader, model, optimizer, loss_fn, scaler):
loop = tqdm(loader)
for batch_idx, (data, targets) in enumerate(loop):
data = data.to(device=DEVICE)
targets = targets.float().unsqueeze(1).to(device=DEVICE)
# forward
with torch.cuda.amp.autocast():
predictions = model(data)
loss = loss_fn(predictions, targets)
# backward
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
# update tqdm loop
loop.set_postfix(loss=loss.item())
def main():
train_transform = A.Compose(
[
A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
A.Rotate(limit=35, p=1.0),
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.1),
A.Normalize(
mean=[0.0, 0.0, 0.0],
std=[1.0, 1.0, 1.0],
max_pixel_value=255.0,
),
ToTensorV2(),
],
)
val_transforms = A.Compose(
[
A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
A.Normalize(
mean=[0.0, 0.0, 0.0],
std=[1.0, 1.0, 1.0],
max_pixel_value=255.0,
),
ToTensorV2(),
],
)
model = UNET(in_channels=3, out_channels=1).to(DEVICE)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
train_loader, val_loader = get_loaders(
TRAIN_IMG_DIR,
TRAIN_MASK_DIR,
VAL_IMG_DIR,
VAL_MASK_DIR,
BATCH_SIZE,
train_transform,
val_transforms,
NUM_WORKERS,
PIN_MEMORY,
)
if LOAD_MODEL:
load_checkpoint(torch.load("mydata/my_checkpoint.pth.tar"), model)
check_accuracy(val_loader, model, device=DEVICE)
scaler = torch.cuda.amp.GradScaler()
max_score = 0
for epoch in range(NUM_EPOCHS):
train_fn(train_loader, model, optimizer, loss_fn, scaler)
# check accuracy
Dice_score=check_accuracy(val_loader, model, device=DEVICE)
if Dice_score>max_score:
max_score=Dice_score
# save model
checkpoint = {
"state_dict": model.state_dict(),
# "optimizer":optimizer.state_dict(),
}
save_checkpoint(checkpoint)
# print some examples to a folder
save_predictions_as_imgs(
val_loader, model, folder="mydata/saved_images/", device=DEVICE
)
if __name__ == "__main__":
man()
utils.py
import torch
import torchvision
from dataset import CarvanaDataset
from torch.utils.data import DataLoader
def save_checkpoint(state, filename="mydata/my_checkpoint.pth.tar"):
print("=> Saving checkpoint")
torch.save(state, filename)
def load_checkpoint(checkpoint, model):
print("=> Loading checkpoint")
model.load_state_dict(checkpoint["state_dict"])
def get_loaders(
train_dir,
train_maskdir,
val_dir,
val_maskdir,
batch_size,
train_transform,
val_transform,
num_workers=4,
pin_memory=True,
):
train_ds = CarvanaDataset(
image_dir=train_dir,
mask_dir=train_maskdir,
transform=train_transform,
)
train_loader = DataLoader(
train_ds,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
shuffle=True,
)
val_ds = CarvanaDataset(
image_dir=val_dir,
mask_dir=val_maskdir,
transform=val_transform,
)
val_loader = DataLoader(
val_ds,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
shuffle=False,
)
return train_loader, val_loader
def check_accuracy(loader, model, device="cuda"):
num_correct = 0
num_pixels = 0
dice_score = 0
model.eval()
with torch.no_grad():
for x, y in loader:
x = x.to(device)
y = y.to(device).unsqueeze(1)
preds = torch.sigmoid(model(x))
preds = (preds > 0.5).float()
num_correct += (preds == y).sum()
num_pixels += torch.numel(preds)
dice_score += (2 * (preds * y).sum()) / (
(preds + y).sum() + 1e-8
)
Dice_score=dice_score/len(loader)
print(
f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
)
print(f"Dice score: {Dice_score}")
model.train()
return Dice_score
def save_predictions_as_imgs(
loader, model, folder="mydata/saved_images/", device="cuda"
):
model.eval()
for idx, (x, y) in enumerate(loader):
x = x.to(device=device)
with torch.no_grad():
preds = torch.sigmoid(model(x))
preds = (preds > 0.5).float()
torchvision.utils.save_image(
preds, f"{folder}/pred_{idx}.png"
)
torchvision.utils.save_image(y.unsqueeze(1), f"{folder}{idx}.png")
model.train()
dataset.py
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
class CarvanaDataset(Dataset):
def __init__(self, image_dir, mask_dir, transform=None):
self.image_dir = image_dir
self.mask_dir = mask_dir
self.transform = transform
self.images = os.listdir(image_dir)
def __len__(self):
return len(self.images)
def __getitem__(self, index):
img_path = os.path.join(self.image_dir, self.images[index])
mask_path = os.path.join(self.mask_dir, self.images[index])
# mask_path = os.path.join(self.mask_dir, self.images[index].replace(".jpg", "_mask.gif"))
image = np.array(Image.open(img_path).convert("RGB"))
mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
mask[mask != 0.0] = 1.0
if self.transform is not None:
augmentations = self.transform(image=image, mask=mask)
image = augmentations["image"]
mask = augmentations["mask"]
return image, mask
提示:如果要训练自己的代码请下载kaggle数据集