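# 将图片和标签转为可视化图片,查看图像增强的效果
# Convert images and labels into viewable image files to inspect the effect of data augmentation.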
import os
import random
import torch
import torchvision.transforms.functional as F
import torchvision.transforms as T
from torch.utils.data import DataLoader
from PIL import Image
import numpy as np
# 51CTO 进阶的西红柿原创
# Undo per-channel normalization to recover the original image values.
def denormalize(tensor, mean, std):
    """Invert ``T.Normalize`` in place, computing ``x * std + mean`` per channel.

    Args:
        tensor: floating-point tensor whose channel axis is dim -3, i.e. shape
            (C, H, W) or (N, C, H, W). It is modified in place.
        mean: per-channel means that were used for normalization (length C).
        std: per-channel standard deviations that were used (length C).

    Returns:
        The same ``tensor`` object, now denormalized.

    Raises:
        TypeError: if ``tensor`` is not floating point.
        ValueError: if ``tensor`` has fewer than 3 dims, or its channel count
            differs from ``len(mean)`` / ``len(std)``.
    """
    if not tensor.is_floating_point():
        raise TypeError(f"denormalize expects a floating-point tensor, got {tensor.dtype}")
    if tensor.dim() < 3:
        raise ValueError(f"denormalize expects (C, H, W) or (N, C, H, W), got shape {tuple(tensor.shape)}")
    channels = tensor.shape[-3]
    if len(mean) != channels or len(std) != channels:
        # zip() used to silently stop at the shorter sequence; fail loudly instead.
        raise ValueError(
            f"tensor has {channels} channels but got {len(mean)} means and {len(std)} stds"
        )
    # Shape (C, 1, 1) broadcasts over H and W, and over a leading batch dim if present.
    mean_t = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    return tensor.mul_(std_t).add_(mean_t)  # multiply by std, then add mean
# Custom dataset
class CustomDataset(torch.utils.data.Dataset):
    """Paired image/label dataset with random horizontal and vertical flips.

    Each item is an ``(img, mask)`` pair:

    - ``img`` is the RGB image after ``T.ToTensor()`` and
      ``T.Normalize(mean, std)``.
    - ``mask`` is the label image converted to an ``int64`` tensor via NumPy.

    The same random flips are applied to both so they stay spatially aligned.

    Args:
        image_paths: paths of the input images.
        label_paths: paths of the label images, matched to ``image_paths`` by
            position.
        mean: per-channel means passed to ``T.Normalize``.
        std: per-channel standard deviations passed to ``T.Normalize``.
        transform: stored as ``self.transform`` but not used by
            ``__getitem__``. NOTE(review): confirm whether it was meant to
            replace or extend the built-in pipeline.

    Raises:
        ValueError: if ``image_paths`` and ``label_paths`` differ in length.
    """

    def __init__(self, image_paths, label_paths, mean, std, transform=None):
        if len(image_paths) != len(label_paths):
            # A mismatch would otherwise surface as an IndexError mid-epoch,
            # or silently ignore extra labels.
            raise ValueError(
                f"got {len(image_paths)} image paths but {len(label_paths)} label paths"
            )
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.mean = mean
        self.std = std
        self.transform = transform
        # Build the tensor-conversion pipeline once, not on every item.
        self._to_tensor = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=self.mean, std=self.std),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Open inside context managers so the file handles are released
        # promptly. convert()/copy() produce in-memory images that remain
        # valid after the files are closed.
        with Image.open(self.image_paths[idx]) as src:
            image = src.convert('RGB')
        with Image.open(self.label_paths[idx]) as src:
            label = src.copy()  # keeps the label's original mode (e.g. palette indices)
        # Joint augmentation: each flip happens with probability ~0.5, applied
        # to image and label together.
        if random.random() > 0.5:
            image = F.hflip(image)
            label = F.hflip(label)
        if random.random() > 0.5:
            image = F.vflip(image)
            label = F.vflip(label)
        img = self._to_tensor(image)
        mask = torch.as_tensor(np.array(label), dtype=torch.int64)
        return img, mask
# Mean and std used both for normalization (inside the dataset) and for undoing it below.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# Example image and label paths; the i-th label belongs to the i-th image.
image_paths = [r"E:\programs\dataloader\src\image\tu1.png", r"E:\programs\dataloader\src\image\tu2.png"]
label_paths = [r"E:\programs\dataloader\src\label\tu1.png", r"E:\programs\dataloader\src\label\tu2.png"]
# Build the dataset and a loader that yields one (image, label) pair per batch, in order.
dataset = CustomDataset(image_paths, label_paths, mean, std)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
# Create the folders that receive the augmented images and labels.
output_image_dir = r"E:\programs\dataloader\new\IMAGE"
output_label_dir = r"E:\programs\dataloader\new\LABEL"
os.makedirs(output_image_dir, exist_ok=True)
os.makedirs(output_label_dir, exist_ok=True)
to_pil = T.ToPILImage()
# Load each augmented pair and write it back to disk for visual inspection.
for i, (image, label) in enumerate(dataloader):
    # Drop the batch dim and undo Normalize on a copy, so `image` is untouched.
    # Then keep values within [0, 1] before converting to a PIL image.
    denorm_image = denormalize(image.clone().squeeze(0), mean, std).clamp_(0.0, 1.0)
    image_save_path = os.path.join(output_image_dir, f"augmented_image_{i}.png")
    label_save_path = os.path.join(output_label_dir, f"augmented_label_{i}.png")
    # Save the denormalized image as PNG.
    image_pil = to_pil(denorm_image)
    image_pil.save(image_save_path)
    # Save the label as an 8-bit image. Refuse values that a uint8 cast would
    # silently wrap (e.g. from 16-bit label files).
    label_np = label.squeeze(0).numpy()
    if label_np.size and (label_np.min() < 0 or label_np.max() > 255):
        raise ValueError(
            f"label values in {label_paths[i]} fall outside 0..255 and cannot be saved as uint8"
        )
    label_pil = Image.fromarray(label_np.astype(np.uint8))
    label_pil.save(label_save_path)
    print(f"Saved augmented image to {image_save_path} and corresponding label to {label_save_path}")
# 51CTO 进阶的西红柿 原创 (original post by 进阶的西红柿 on 51CTO)