Learning Continuous Image Representation with Local Implicit Image Function

  • abstract
  • Local Implicit Image Function
  • Feature unfolding
  • Local ensemble
  • Cell decoding
  • LIIF class 完全代码




abstract

物理世界以连续的方式呈现视觉图像,但计算机以离散2D像素数组的方式存储和显示图像。此文学习图像的连续表示,使用局部隐式图像函数(Local Implicit Image Function,LIIF)将图像坐标和坐标周围的2D深度特征作为输入,预测输出给定坐标下的RGB值。通过自监督超分辨率任务来训练一个编码器和LIIF表示来生成像素图像的连续表示,可以做到任意倍数的分辨率,甚至可以推算不在训练任务中的30倍以上超分。通过将图像模型化为一个在连续域中的函数,可以恢复和生成任意分辨率的图像。隐式函数的思想是将一个对象表示为一个函数,将坐标映射到相应的信号(如3D对象表面的符号距离,图像中的RGB值)。神经隐式函数采用深度神经网络参数化。为了跨实例共享知识,而不是为每个对象拟合单独的隐式函数,提出了基于编码器的方法来预测每个对象的潜在编码。然后隐式函数由所有对象共享,同时它将潜在代码作为额外的输入。

传统图像超分 python实现 图像超分辨率代码_Image

Local Implicit Image Function

在LIIF表示中,每个连续图像传统图像超分 python实现 图像超分辨率代码_计算机视觉_02由二维特征映射传统图像超分 python实现 图像超分辨率代码_传统图像超分 python实现_03表示。 一个神经隐式函数传统图像超分 python实现 图像超分辨率代码_Image_04(以传统图像超分 python实现 图像超分辨率代码_传统图像超分 python实现_05为其参数)被所有图像共享,它被参数化为传统图像超分 python实现 图像超分辨率代码_传统图像超分 python实现_06并采取传统图像超分 python实现 图像超分辨率代码_深度学习_07(简便省略传统图像超分 python实现 图像超分辨率代码_传统图像超分 python实现_05)形式,其中传统图像超分 python实现 图像超分辨率代码_深度学习_09是一个向量,传统图像超分 python实现 图像超分辨率代码_计算机视觉_10是连续图像域中的二维坐标,传统图像超分 python实现 图像超分辨率代码_深度学习_11是预测信号(即RGB值)。

对于定义的传统图像超分 python实现 图像超分辨率代码_ci_12,每个向量传统图像超分 python实现 图像超分辨率代码_深度学习_09都可以看作是表示函数传统图像超分 python实现 图像超分辨率代码_ci_14传统图像超分 python实现 图像超分辨率代码_ci_15可以看作是一个连续的图像,即映射坐标到RGB值的函数。假设传统图像超分 python实现 图像超分辨率代码_深度学习_16传统图像超分 python实现 图像超分辨率代码_计算机视觉_17特征向量(称为隐码latent codes)均匀分布在传统图像超分 python实现 图像超分辨率代码_计算机视觉_02的连续图像域的2D空间中,并为它们中的每一个分配一个2D坐标。

对于图像传统图像超分 python实现 图像超分辨率代码_计算机视觉_02,坐标传统图像超分 python实现 图像超分辨率代码_Image_20处的RGB值定义为传统图像超分 python实现 图像超分辨率代码_传统图像超分 python实现_21,其中传统图像超分 python实现 图像超分辨率代码_ci_22传统图像超分 python实现 图像超分辨率代码_深度学习_16中与传统图像超分 python实现 图像超分辨率代码_Image_20最近的(欧几里德距离)隐码,传统图像超分 python实现 图像超分辨率代码_Image_25是图像域中潜码传统图像超分 python实现 图像超分辨率代码_ci_22的坐标。 例如传统图像超分 python实现 图像超分辨率代码_深度学习_27是当前定义中传统图像超分 python实现 图像超分辨率代码_Image_20传统图像超分 python实现 图像超分辨率代码_ci_22,而传统图像超分 python实现 图像超分辨率代码_Image_25被定义为传统图像超分 python实现 图像超分辨率代码_深度学习_27的坐标。

传统图像超分 python实现 图像超分辨率代码_传统图像超分 python实现_32


在所有图像共享的隐式函数传统图像超分 python实现 图像超分辨率代码_ci_12下,连续图像由二维特征映射传统图像超分 python实现 图像超分辨率代码_传统图像超分 python实现_34表示,该特征映射被看作是在2D域中均匀分布的传统图像超分 python实现 图像超分辨率代码_传统图像超分 python实现_35隐码。 在传统图像超分 python实现 图像超分辨率代码_深度学习_16中的每个潜在码传统图像超分 python实现 图像超分辨率代码_深度学习_09表示连续图像的局部部分,负责预测与它最近的坐标集的信号。

从图像得到归一化坐标值和RGB值

def make_coord(shape, ranges=None, flatten=True):
    """ Make coordinates at grid centers.
    """
    coord_seqs = []
    for i, n in enumerate(shape):
        if ranges is None:
            v0, v1 = -1, 1
        else:
            v0, v1 = ranges[i]
        r = (v1 - v0) / (2 * n)
        seq = v0 + r + (2 * r) * torch.arange(n).float()
        coord_seqs.append(seq)
    ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
    if flatten:
        ret = ret.view(-1, ret.shape[-1])
    return ret
    
coord = make_coord((h, w)) #h,w为SR目标的高宽

def to_pixel_samples(img):
    """ Convert the image to coord-RGB pairs.
        img: Tensor, (3, H, W)
    """
    coord = make_coord(img.shape[-2:])   #(h*w,2)--(h*w,[x,y])
    rgb = img.view(3, -1).permute(1, 0)  #(h*w,3)--(h*w,[R,G,B])
    return coord, rgb

Feature unfolding

为了丰富隐码包含的信息,对特征传统图像超分 python实现 图像超分辨率代码_深度学习_16展开得到传统图像超分 python实现 图像超分辨率代码_传统图像超分 python实现_39传统图像超分 python实现 图像超分辨率代码_传统图像超分 python实现_39是在传统图像超分 python实现 图像超分辨率代码_深度学习_16传统图像超分 python实现 图像超分辨率代码_Image_42相邻隐码的合并。

传统图像超分 python实现 图像超分辨率代码_ci_43指的是一组向量的连接时,传统图像超分 python实现 图像超分辨率代码_深度学习_16在其边界外被零向量填充。
传统图像超分 python实现 图像超分辨率代码_计算机视觉_45传统图像超分 python实现 图像超分辨率代码_深度学习_46被以下传统图像超分 python实现 图像超分辨率代码_传统图像超分 python实现_47后变为传统图像超分 python实现 图像超分辨率代码_深度学习_48

feat = F.unfold(feat, 3, padding=1).view(feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3])

Local ensemble

传统图像超分 python实现 图像超分辨率代码_深度学习_07是一个不连续预测,由于传统图像超分 python实现 图像超分辨率代码_Image_20的信号预测是通过查询传统图像超分 python实现 图像超分辨率代码_深度学习_16中最近的隐码传统图像超分 python实现 图像超分辨率代码_ci_22完成的,所以当传统图像超分 python实现 图像超分辨率代码_Image_20在图像域中移动时,传统图像超分 python实现 图像超分辨率代码_ci_22会突然从一个隐码切换到另一个隐码。在传统图像超分 python实现 图像超分辨率代码_ci_22选择切换的那些坐标周围,两个无限接近坐标的信号将从不同的隐码中预测出来,只要学习的隐式函数传统图像超分 python实现 图像超分辨率代码_ci_12不是完美的,在传统图像超分 python实现 图像超分辨率代码_ci_22选择切换的边界处没出现不连续的图形。为了解决这个问题,使用局部集成技术,扩大每个隐码的表示
传统图像超分 python实现 图像超分辨率代码_Image_58
传统图像超分 python实现 图像超分辨率代码_Image_59指左上、右上,左下,右下子空间中最近的隐码,传统图像超分 python实现 图像超分辨率代码_Image_60传统图像超分 python实现 图像超分辨率代码_计算机视觉_61的坐标,传统图像超分 python实现 图像超分辨率代码_传统图像超分 python实现_62传统图像超分 python实现 图像超分辨率代码_Image_20传统图像超分 python实现 图像超分辨率代码_Image_64传统图像超分 python实现 图像超分辨率代码_Image_64传统图像超分 python实现 图像超分辨率代码_Image_66的对角,如00对11,10对01)之间的矩形面积。权重由传统图像超分 python实现 图像超分辨率代码_深度学习_67归一化。特征图传统图像超分 python实现 图像超分辨率代码_深度学习_16在边界外是镜像填充的,因此这也适用于边界附近的坐标。

这是为了让由隐码表示的局部图像块与其相邻块重叠,使得在每个坐标处有四个隐码用于独立预测信号。然后,这四个预测通过用归一化置信度投票来合并,归一化置信度与查询点和其最近的隐码对角对应点之间的矩形面积成比例,因此当查询坐标更近时,置信度变得更高。通过这种投票,它在传统图像超分 python实现 图像超分辨率代码_计算机视觉_69转换坐标(即图中的虚线)处实现了连续过渡。

vx_lst = [-1, 1]
vy_lst = [-1, 1]
eps_shift = 1e-6

rx = 2 / feat.shape[-2] / 2  #2/H/2
ry = 2 / feat.shape[-1] / 2  #2/W/2

feat_coord = make_coord(feat.shape[-2:], flatten=False).cuda() #[LR_H,LR_W,2]
feat_coord = feat_coord.permute(2, 0, 1).unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:])#[N,2,LR_H,LR_W]

preds = []
areas = []
for vx in vx_lst:
    for vy in vy_lst:
        coord_ = coord.clone()#[N,SR_H*SR_W,2]
        coord_[:, :, 0] += vx * rx + eps_shift
        coord_[:, :, 1] += vy * ry + eps_shift
        coord_.clamp_(-1 + 1e-6, 1 - 1e-6)

        q_feat = F.grid_sample(feat, coord_.flip(-1).unsqueeze(1),mode='nearest', align_corners=False)#[N,C*9,1,SR_H*SR_W]
        q_feat = q_feat[:, :, 0, :].permute(0, 2, 1)#[N,SR_H*SR_W,C*9]

        q_coord = F.grid_sample(feat_coord, coord_.flip(-1).unsqueeze(1),mode='nearest', align_corners=False)#[N,2,1,SR_H*SR_W]
        q_coord = q_coord[:, :, 0, :].permute(0, 2, 1)#[N,SR_H*SR_W,2]

        rel_coord = coord - q_coord #[N,SR_H*SR_W,2]
        rel_coord[:, :, 0] *= feat.shape[-2]
        rel_coord[:, :, 1] *= feat.shape[-1]
        inp = torch.cat([q_feat, rel_coord], dim=-1) #[N,SR_H*SR_W,C*9+2]

        if self.cell_decode:
            rel_cell = cell.clone()
            rel_cell[:, :, 0] *= feat.shape[-2]
            rel_cell[:, :, 1] *= feat.shape[-1]
            inp = torch.cat([inp, rel_cell], dim=-1) #[N,SR_H*SR_W,C*9+2+2]

        bs, q = coord.shape[:2] #bs=N q=SR_H*SR_W
        #[N*SR_H*SR_W,C*9+2+2] --> [N*SR_H*SR_W,3]
        pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1) #[N,SR_H*SR_W,3]
        preds.append(pred) #[[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W]]

        area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1])
        areas.append(area + 1e-9) #[[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W]]


tot_area = torch.stack(areas).sum(dim=0) #[N,SR_H*SR_W]
if self.local_ensemble:
    t = areas[0]; areas[0] = areas[3]; areas[3] = t #swap(areas[0],areas[3])
    t = areas[1]; areas[1] = areas[2]; areas[2] = t #swap(areas[1],areas[2])

Cell decoding

为了LIIF能够表示基于像素形式的任意分辨率呈现,假设给定了所需分辨率,一种简单方法是查询连续表示传统图像超分 python实现 图像超分辨率代码_深度学习_70中像素中心坐标处的RGB值,但因为查询像素的预测RGB值与其大小无关,其像素区域中的信息除了中心值都被丢弃,可能不是最佳的。
传统图像超分 python实现 图像超分辨率代码_传统图像超分 python实现_71
传统图像超分 python实现 图像超分辨率代码_传统图像超分 python实现_72包含指定查询像素的高度和宽度两个值,传统图像超分 python实现 图像超分辨率代码_传统图像超分 python实现_73是值传统图像超分 python实现 图像超分辨率代码_传统图像超分 python实现_74传统图像超分 python实现 图像超分辨率代码_深度学习_75的连接(concatenation),传统图像超分 python实现 图像超分辨率代码_深度学习_75是附加输入。
传统图像超分 python实现 图像超分辨率代码_计算机视觉_77能理解为使用形状传统图像超分 python实现 图像超分辨率代码_深度学习_75渲染以坐标传统图像超分 python实现 图像超分辨率代码_传统图像超分 python实现_74为中心的像素的RGB值。对于传统图像超分 python实现 图像超分辨率代码_ci_80的分辨率,传统图像超分 python实现 图像超分辨率代码_深度学习_75是图像宽度的传统图像超分 python实现 图像超分辨率代码_ci_82。逻辑上,当传统图像超分 python实现 图像超分辨率代码_ci_83时,传统图像超分 python实现 图像超分辨率代码_Image_84,即连续图像可以看作像素无限小的图像。

cell = torch.ones_like(coord) #[SR_H*SR_W,2] [1*2/SR_H,1*2/SR_W]
cell[:, 0] *= 2 / h
cell[:, 1] *= 2 / w 

if self.cell_decode:
     rel_cell = cell.clone()
     rel_cell[:, :, 0] *= feat.shape[-2]
     rel_cell[:, :, 1] *= feat.shape[-1]
     inp = torch.cat([inp, rel_cell], dim=-1) #[N,SR_H*SR_W,C*9+2+2]

LIIF class 完全代码

class LIIF(nn.Module):
    def __init__(self, encoder_spec, imnet_spec=None,
                 local_ensemble=True, feat_unfold=True, cell_decode=True):
        super().__init__()
        self.local_ensemble = local_ensemble
        self.feat_unfold = feat_unfold
        self.cell_decode = cell_decode
        self.encoder = models.make(encoder_spec)

        #print("self.encoder.out_dim",self.encoder.out_dim)
        if imnet_spec is not None:
            imnet_in_dim = self.encoder.out_dim     #64
            if self.feat_unfold:
                imnet_in_dim *= 9
            imnet_in_dim += 2 # attach coord 指定查询像素的坐标 [x,y]
            if self.cell_decode:
                imnet_in_dim += 2 #[Cell_h, Cell_w]指定查询像素的高度和宽度的两个值
            self.imnet = models.make(imnet_spec, args={'in_dim': imnet_in_dim})
        else:
            self.imnet = None

    def gen_feat(self, inp):
        self.feat = self.encoder(inp)
        return self.feat

    def query_rgb(self, coord, cell=None):
        #coord [N,SR_H*SR_*W,2]
        #cell [N,SR_H*SR_*W,2]
        feat = self.feat #[N,C,LR_H,LR_W]

        if self.imnet is None:
            ret = F.grid_sample(feat, coord.flip(-1).unsqueeze(1), mode='nearest', align_corners=False)
            ret = ret[:, :, 0, :].permute(0, 2, 1)
            return ret

        if self.feat_unfold:
            # [N,C*3*3,H,W]
            feat = F.unfold(feat, 3, padding=1).view(feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3])

        if self.local_ensemble:
            vx_lst = [-1, 1]
            vy_lst = [-1, 1]
            eps_shift = 1e-6
        else:
            vx_lst, vy_lst, eps_shift = [0], [0], 0

        # field radius (global: [-1, 1])
        rx = 2 / feat.shape[-2] / 2  #2/H/2
        ry = 2 / feat.shape[-1] / 2  #2/W/2

        feat_coord = make_coord(feat.shape[-2:], flatten=False).cuda() #[LR_H,LR_W,2]
        feat_coord = feat_coord.permute(2, 0, 1).unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:])#[N,2,LR_H,LR_W]

        preds = []
        areas = []
        for vx in vx_lst:
            for vy in vy_lst:
                coord_ = coord.clone()#[N,SR_H*SR_W,2]
                coord_[:, :, 0] += vx * rx + eps_shift
                coord_[:, :, 1] += vy * ry + eps_shift
                coord_.clamp_(-1 + 1e-6, 1 - 1e-6)

                q_feat = F.grid_sample(feat, coord_.flip(-1).unsqueeze(1),mode='nearest', align_corners=False)#[N,C*9,1,SR_H*SR_W]
                q_feat = q_feat[:, :, 0, :].permute(0, 2, 1)#[N,SR_H*SR_W,C*9]

                q_coord = F.grid_sample(feat_coord, coord_.flip(-1).unsqueeze(1),mode='nearest', align_corners=False)#[N,2,1,SR_H*SR_W]
                q_coord = q_coord[:, :, 0, :].permute(0, 2, 1)#[N,SR_H*SR_W,2]

                rel_coord = coord - q_coord #[N,SR_H*SR_W,2]
                rel_coord[:, :, 0] *= feat.shape[-2]
                rel_coord[:, :, 1] *= feat.shape[-1]
                inp = torch.cat([q_feat, rel_coord], dim=-1) #[N,SR_H*SR_W,C*9+2]

                if self.cell_decode:
                    rel_cell = cell.clone()
                    rel_cell[:, :, 0] *= feat.shape[-2]
                    rel_cell[:, :, 1] *= feat.shape[-1]
                    inp = torch.cat([inp, rel_cell], dim=-1) #[N,SR_H*SR_W,C*9+2+2]

                bs, q = coord.shape[:2] #bs=N q=SR_H*SR_W
                #[N*SR_H*SR_W,C*9+2+2] --> [N*SR_H*SR_W,3]
                pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1) #[N,SR_H*SR_W,3]
                preds.append(pred) #[[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W]]

                area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1])
                areas.append(area + 1e-9) #[[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W]]


        tot_area = torch.stack(areas).sum(dim=0) #[N,SR_H*SR_W]
        if self.local_ensemble:
            t = areas[0]; areas[0] = areas[3]; areas[3] = t #swap(areas[0],areas[3])
            t = areas[1]; areas[1] = areas[2]; areas[2] = t #swap(areas[1],areas[2])
        ret = 0
        for pred, area in zip(preds, areas):
            ret = ret + pred * (area / tot_area).unsqueeze(-1)
        return ret

    def forward(self, inp, coord, cell):
        self.gen_feat(inp)
        return self.query_rgb(coord, cell)