# Convert images and labels into viewable pictures to inspect the effect of image augmentation.

import os
import random
import torch
import torchvision.transforms.functional as F
import torchvision.transforms as T
from torch.utils.data import DataLoader
from PIL import Image
import numpy as np

# 51CTO 进阶的西红柿原创
# 定义反归一化函数,用于恢复原始图像
def denormalize(tensor, mean, std):
    """Undo a per-channel ``T.Normalize(mean, std)`` in place.

    Every channel is multiplied by its standard deviation and then shifted
    by its mean, i.e. ``x * std + mean``.

    Args:
        tensor: Tensor to restore. For tensors with three or more
            dimensions the channel axis is dim ``-3`` (``(..., C, H, W)``,
            the layout ``T.Normalize`` operates on), so batched input is
            handled per channel rather than per sample. Tensors with fewer
            dimensions use dim 0.
        mean: Per-channel means that were subtracted during normalization.
        std: Per-channel standard deviations that were divided out.

    Returns:
        The same ``tensor`` object, modified in place. Pass a clone when
        the normalized values are still needed afterwards.
    """
    channel_dim = -3 if tensor.dim() >= 3 else 0
    channel_ids = range(tensor.size(channel_dim))
    # zip stops at the shortest input, so surplus channels or surplus
    # statistics are left untouched.
    for c, m, s in zip(channel_ids, mean, std):
        # select() returns a view, so the in-place ops write into `tensor`.
        tensor.select(channel_dim, c).mul_(s).add_(m)
    return tensor

# 自定义数据集
class CustomDataset(torch.utils.data.Dataset):
    """Paired image/label dataset with random flip augmentation.

    Each item is loaded from ``image_paths[idx]`` and ``label_paths[idx]``,
    flipped horizontally and/or vertically (each with an independent coin
    toss, applied identically to image and label), and returned as a
    normalized image tensor plus an ``int64`` label tensor.

    Args:
        image_paths: Sequence of image file paths.
        label_paths: Sequence of label file paths, index-aligned with
            ``image_paths``.
        mean: Per-channel means used to normalize the image.
        std: Per-channel standard deviations used to normalize the image.
        transform: Stored on the instance as ``self.transform``.
            NOTE(review): ``__getitem__`` never applies it; confirm whether
            it is meant to be used.
    """

    def __init__(self, image_paths, label_paths, mean, std, transform=None):
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.mean = mean
        self.std = std
        self.transform = transform

    def __len__(self):
        """Return the number of image paths."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, augment and tensorize the sample at ``idx``.

        Returns:
            A ``(img, mask)`` tuple: ``img`` is the normalized image tensor
            and ``mask`` is the label array as a ``torch.int64`` tensor.
        """
        # Context managers close the files deterministically; convert() and
        # copy() return images that no longer depend on the open file.
        with Image.open(self.image_paths[idx]) as src:
            image = src.convert('RGB')
        with Image.open(self.label_paths[idx]) as src:
            label = src.copy()

        # Augmentation: the same flip is applied to image and label so they
        # stay pixel-aligned.
        if random.random() > 0.5:
            image = F.hflip(image)
            label = F.hflip(label)

        if random.random() > 0.5:
            image = F.vflip(image)
            label = F.vflip(label)

        # Functional calls avoid rebuilding a Compose pipeline per sample
        # and always read the current self.mean / self.std.
        img = F.normalize(F.to_tensor(image), mean=self.mean, std=self.std)
        mask = torch.as_tensor(np.array(label), dtype=torch.int64)

        return img, mask

# Mean and standard deviation shared by normalization and denormalization
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Example image and label paths (plain raw strings: nothing is interpolated)
image_paths = [r"E:\programs\dataloader\src\image\tu1.png", r"E:\programs\dataloader\src\image\tu2.png"]
label_paths = [r"E:\programs\dataloader\src\label\tu1.png", r"E:\programs\dataloader\src\label\tu2.png"]

# Build the dataset
dataset = CustomDataset(image_paths, label_paths, mean, std)

# One sample per batch, in file order, so output index i matches input index i
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

# Output folders for the augmented images and labels
output_image_dir = r"E:\programs\dataloader\new\IMAGE"
output_label_dir = r"E:\programs\dataloader\new\LABEL"
os.makedirs(output_image_dir, exist_ok=True)
os.makedirs(output_label_dir, exist_ok=True)

# Load the samples one by one and save the augmented image and label
for i, (image, label) in enumerate(dataloader):
    # Drop the batch dimension and undo the normalization on a copy
    denorm_image = denormalize(image.clone().squeeze(0), mean, std)
    # Float round-off can leave values marginally outside [0, 1]; clamp so
    # the 8-bit conversion below stays in range.
    denorm_image.clamp_(0.0, 1.0)

    image_save_path = os.path.join(output_image_dir, f"augmented_image_{i}.png")
    label_save_path = os.path.join(output_label_dir, f"augmented_label_{i}.png")

    # Save the denormalized image as PNG
    image_pil = F.to_pil_image(denorm_image)
    image_pil.save(image_save_path)

    # Save the label (assumes a single-channel label whose values fit in uint8)
    label_np = label.squeeze(0).numpy().astype(np.uint8)
    label_pil = Image.fromarray(label_np)
    label_pil.save(label_save_path)

    print(f"Saved augmented image to {image_save_path} and corresponding label to {label_save_path}")

# 51CTO 进阶的西红柿 原创 (original work)