One-Hot 编码
1. F.one_hot
pytorch 现在自带的将标签转成one-hot编码方法
import torch.nn.functional as F
import torch
x=torch.randint(low=0,high=3,size=(2,2))# 随机生成一张2*2的灰度图.一共3个类别数。所以0,1,2
print(x)
y=F.one_hot(x)# 如果不加类别数,会默认使用 输入数据中最大值,作为列别数。一般还是会加的
print(y.shape)
print(y)
# pytorch做模型训练时 中需要进行转置。有点麻烦
y=torch.from_numpy(y.numpy().transpose(2,0,1))
print(y)
结果如下
tensor([[2, 0],
[1, 0]])
torch.Size([2, 2, 3])
tensor([[[0, 0, 1],
[1, 0, 0]],
[[0, 1, 0],
[1, 0, 0]]])
tensor([[[0, 1],
[0, 1]],
[[0, 0],
[1, 0]],
[[1, 0],
[0, 0]]])
也可以使用
y=F.one_hot(x,num_classes=3)
效果一样
加载图片
import torch.nn.functional as F
import torch
import numpy as np
from PIL import Image
# 加载8位图。
x=Image.open(r"../input/lidcidri/LIDC/mask_roi/LIDC_Mask_0000.png").convert("P")
# x的数据是0和225.所以得需要把225 变成1 ,才能做ont-hot
x=np.array(x)
x[x==225]=1# 如果图片已经处理好了,就不需要转标签了,比如VOC数据集。或者自己制作了labelimg的生成的多少8位图
x=torch.from_numpy(x).long()
y=F.one_hot(x,num_classes=2)# 如果不加类别数,会默认使用 输入数据中最大值,作为列别数。一般还是会加的
print(y.shape)
# print(y)
# pytorch做模型训练时 中需要进行转置。有点麻烦
y=torch.from_numpy(y.numpy().transpose(2,0,1))
print(y.shape)
结果
torch.Size([64, 64, 2])
torch.Size([2, 64, 64])
2. torch.scatter_
源代码
import torch
def to_one_hot(mask, n_class):
"""
Transform a mask to one hot
change a mask to n * h* w n is the class
Args:
mask:
n_class: number of class for segmentation
Returns:
y_one_hot: one hot mask
"""
y_one_hot = torch.zeros((n_class, mask.shape[1], mask.shape[2]))
y_one_hot = y_one_hot.scatter(0, mask, 1).long()
return y_one_hot
# 这里一般使用 8位图加载图片信息。然后进行升维 。输入的时候是[1,高,宽]
# 返回结果就是 [num_class,H,W] [类别数,高,宽]
x=torch.randint(low=0,high=3,size=(1,2,2))
print(x)
print(x.shape)
y=to_one_hot(x,n_class=3)
print(y)
print(y.shape)
结果
tensor([[[0, 2],
[1, 0]]])
torch.Size([1, 2, 2])
tensor([[[1, 0],
[0, 1]],
[[0, 0],
[1, 0]],
[[0, 1],
[0, 0]]])
torch.Size([3, 2, 2])
加载图片实例
import torch
import numpy as np
from PIL import Image
def to_one_hot(mask, n_class):
"""
Transform a mask to one hot
change a mask to n * h* w n is the class
Args:
mask:
n_class: number of class for segmentation
Returns:
y_one_hot: one hot mask
"""
y_one_hot = torch.zeros((n_class, mask.shape[1], mask.shape[2]))
y_one_hot = y_one_hot.scatter(0, mask, 1).long()
return y_one_hot
# 这里一般使用 8位图加载图片信息。然后进行升维 。输入的时候是[1,高,宽]
# 返回结果就是 [num_class,H,W] [类别数,高,宽]
x=Image.open(r"../input/lidcidri/LIDC/mask_roi/LIDC_Mask_0000.png").convert("P")
# x的数据是0和225.所以得需要把225 变成1 ,才能做ont-hot
x=np.array(x)
x[x==225]=1
x=torch.from_numpy(x).unsqueeze(0).long() # 把x从numpy--->tensor
# x=torch.randint(low=0,high=3,size=(1,2,2))
print(x)
print(x.shape)
y=to_one_hot(x,n_class=2)
print(y)
print(y.shape)