# 语义分割多分类的loss 计算和one-hot 编码
# 本文验证了语义分割任务下,单通道输出和多通道输出时,使用交叉熵计算损失值的细节问题。
# 对比验证了使用简单的函数和自带损失函数的结果,通过验证,进一步加强了对交叉熵的理解。
import torch
import torch.nn as nn
import torch.nn.functional as F

# 首先,假设我们研究的是一个二分类语义分割问题。
# 网络的输入是一个 2×2 的图像,设置 batch_size 为 2,网络输出单通道特征图。
# 网络的标签也是一个 2×2 的二进制掩模图(即只有0和1的单通道图像)。
# 假设pred 一个 [batch_size=2, channel=1, height=2, width=2] 格式的张量 x1

x1 = torch.tensor([[[[0.43, -0.25],
                     [-0.32, 0.69]]],
                     [[[-0.29, 0.37],
                     [0.54,  -0.72]]]])
# 假设标签图像为与 x1 同型的张量 y1
y1 = torch.tensor(
    [[[[0., 0.],
    [0., 1.]]],
    [[[0., 0.],
    [1.,  1.]]]])
# print("x1:",x1.shape) # [2, 1, 2, 2]
# print("y1:",y1.shape) # [2, 1, 2, 2]
#
# BCEloss = torch.nn.BCEWithLogitsLoss()
# #二进制的 sigmoid + 交叉熵 + 均值
# bceloss = BCEloss(x1,y1)
# print(bceloss.item())

# ----------------单张 图片的预测值的处理--------------------
# 单张图像的预测,要对它利用unsqueeze增加一个batch维度,送入网络预测出
# 结果后要利用squeeze去除batch维度。
# x = torch.rand(2,1,2,2).cuda()
# x = x[0]
# x = torch.unsqueeze(x,dim=0) # 去除bs
# pr = F.softmax(x.permute(1,2,0),dim =-1).cpu().numpy() # 取出最大值索引(对通道进行索引)

# 多通道输出时的交叉熵损失计算
# 网络的输入是一个 2×2 的图像,设置 batch_size 为 2,网络输出多(二)通道特征图。网络的标签也是一个 2 ×2 的二进制掩模图(即只有0和1的单通道图像)。
# 假设输出一个[batch_size=2, channel=2, height=2, width=2]格式的张量 x1
x1 = torch.tensor([[[[ 0.3164, -0.1922],
          [ 0.4326, -1.2193]],

         [[ 0.6873,  0.6838],
          [ 0.2244,  0.5615]]],


        [[[-0.2516, -0.8875],
          [-0.6289, -0.1796]],

         [[ 0.0411, -1.7851],
          [-0.3069, -1.0379]]]])

# 假设标签图像为与x1同型,然后去掉channel的张量 y1 (注意两点,channel没了,格式为LongTensor)
y1 = torch.LongTensor([[[0., 1.],
                        [1., 0.]],
                        [[1., 1.],
                        [0., 1.]]])

print("x1:",x1.shape) # [2, 2, 2, 2]
print("y1:",y1.shape) # [2, 2, 2] 灰度 (n 分类的灰度图)
CrossE = nn.CrossEntropyLoss()  #
ce_loss = CrossE(x1,y1)
print("ce_loss:",ce_loss.item())

# 等价

s1 = torch.softmax(x1,dim=1) # 预测值先进行softmax
print("s1:",s1.shape)
# 标签进行one_hot (标签是灰度时)
y1_one_hot = torch.zeros_like(x1).scatter_(dim=1,index=y1.unsqueeze(dim=1),src=torch.ones_like(x1))
print("y1_one_hot:",y1_one_hot.shape) # [2, 2, 2, 2]

# 进行loss 交叉熵计算 手动计算
loss_cal = -1 *(y1_one_hot * torch.log(s1))
loss_cal_mean = loss_cal.sum(dim=1).mean() # 在batch维度下计算每个样本的交叉熵
print("loss_cal_mean:",loss_cal_mean.item())

# 通过softmax和argmax取通道最大的索引之后,我们获得的预测结果就是和输入图像一样大小的预测图,
# 图中的每个位置的像素点的值是网络对该像素点预测的类别,其数值为0-N(N为网络需要预测的类别),
# 这些数值人眼是无法进行观察的,所以需要对它进行染色,转换成便于观察的彩色图像。对于将预测结果进行后处理的代码主要如下:
import numpy as np
num_classes = 2
pr = y1[0]
if num_classes <= 21:
    colors = [(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128),
                   (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128),
                   (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128),
                   (128, 64, 12)]

seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3))
for c in range(num_classes):
    seg_img[:, :, 0] += ((pr[:, :] == c ) * colors[c][0]).astype('uint8')
    seg_img[:, :, 1] += ((pr[:, :] == c ) * colors[c][1]).astype('uint8')
    seg_img[:, :, 2] += ((pr[:, :] == c ) * colors[c][2]).astype('uint8')