Ref:

A guide to convolution arithmetic for deep learning

GitHub: Convolution arithmetic

[Note 1] In PyTorch's transposed convolution layer nn.ConvTranspose2d, the kernel size and stride are set to the same values as in the preceding convolution layer, but an additional zero-padding step is required.

[Figure: convolution]
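
With identical parameters, a convolution maps an input of size i to o = floor((i + 2p - k)/s) + 1, and the transposed layer maps o back to (o - 1)*s - 2p + k. When s > 1, the floor makes several input sizes collapse onto the same o, so extra zeros (the output_padding argument, or equivalently the output_size argument used in the script below) are needed to select the intended size.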


[Note 2] A transposed convolution does not guarantee that the tensor recovers the values it had before the convolution; it only guarantees that the size before the convolution is recovered.

[Figure: transposed convolution]
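
A minimal sketch of this point (layer sizes borrowed from Example 1 below; both layers keep their random initial weights):

import torch
import torch.nn as nn

conv = nn.Conv2d(1, 1, kernel_size=3, stride=2, padding=1)
deconv = nn.ConvTranspose2d(1, 1, kernel_size=3, stride=2, padding=1)

x = torch.randn(1, 1, 12, 12)
z = deconv(conv(x), output_size=x.size())

print(z.size() == x.size())   # True: the size is restored
print(torch.allclose(z, x))   # False (in general): the values are not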


[Note 3] A transposed convolution can be understood as padding the convolution's output with zeros and then convolving it again; in practice, however, it is not computed this way, for efficiency.
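
A sketch of that equivalence for the single-channel, no-bias case (the parameter values here are only illustrative): insert s - 1 zeros between input elements, pad k - 1 - p zeros on each border, and convolve with the flipped kernel.

import torch
import torch.nn.functional as F

i, k, s, p = 4, 3, 2, 1
x = torch.randn(1, 1, i, i)
w = torch.randn(1, 1, k, k)

ref = F.conv_transpose2d(x, w, stride=s, padding=p)        # PyTorch's direct result

dilated = torch.zeros(1, 1, s*(i - 1) + 1, s*(i - 1) + 1)  # s-1 zeros between pixels
dilated[:, :, ::s, ::s] = x
padded = F.pad(dilated, [k - 1 - p] * 4)                   # k-1-p zeros on each side
manual = F.conv2d(padded, torch.flip(w, dims=[2, 3]))      # ordinary stride-1 conv, flipped kernel

print(torch.allclose(ref, manual, atol=1e-6))              # True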

  • Convolution: blue is the input, green is the output, dark blue is the kernel
  • Transposed convolution: blue is the input, green is the output, gray is the kernel

Example 1: (input = 12x12, kernel = 3x3, stride = 2, padding = 1)

[Figure: Example 1]
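
For Example 1, the convolution gives floor((12 + 2*1 - 3)/2) + 1 = 6, while the plain transposed formula only gives back (6 - 1)*2 - 2*1 + 3 = 11. This is the stride-2 ambiguity mentioned above; passing output_size=input.size() in the script below makes PyTorch add one extra row/column of output_padding to reach 12.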

Example 2: (input = 4x4, kernel = 3x3, stride = 1, padding = 0)

[Figure: Example 2]
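
For Example 2, (4 + 0 - 3)/1 + 1 = 2 on the way down and (2 - 1)*1 - 0 + 3 = 4 on the way back, so with stride 1 the size is recovered exactly and no output_padding is needed.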

# -*- coding: utf-8 -*-
"""
Created on Tue Jul 3 10:27:15 2018

@author: guokai_liu
"""
import torch.nn as nn
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.init
import itertools



#%%
# [input size, kernel size, stride, padding] for the two examples above
parameters = [[12, 3, 2, 1],   # Example 1
              [4, 3, 1, 0]]    # Example 2
idx = 0                        # 0 -> Example 1, 1 -> Example 2
i = parameters[idx][0]         # input size
k = parameters[idx][1]         # kernel size
s = parameters[idx][2]         # stride
p = parameters[idx][3]         # padding

#%%
# Build a ramp input (0, 1, 2, ...) so every pixel is distinct, then convolve.
x = np.arange(i**2).astype('float32').reshape((1, 1, i, i))
input = torch.from_numpy(x)
downsample = nn.Conv2d(1, 1, k, stride=s, padding=p)
downsample.weight.data.fill_(1.)   # all-ones kernel; the bias keeps its random init
h = downsample(input)
print(h.size())
#%%
# The exact output size can also be passed as an argument (output_size), which
# resolves the size ambiguity introduced by a stride > 1 convolution.
upsample = nn.ConvTranspose2d(1, 1, k, stride=s, padding=p)
output = upsample(h, output_size=input.size())
print(output.size())
#%%
data_input = input.numpy()[0,0,:,:]
data_hidden = h.detach().numpy()[0,0,:,:]
data_upsample = output.detach().numpy()[0,0,:,:]
data_conv = downsample.weight.data.numpy()[0,0,:,:]
#%%
def show_data(data, ax):
    # Overlay each cell's numeric value on top of a matshow image.
    fmt = '.0f'   # use '.2f' to show decimals
    cm = data
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, format(cm[i, j], fmt),
                horizontalalignment="center",
                va="center",
                color="white")

#%%
# Plot the input, the kernel, the conv output, and the transposed-conv output side by side.
fig = plt.figure(figsize=(18, 6))
ax1 = plt.subplot(141)
ax1.matshow(data_input)
ax1.set_title('Input',pad=20)
show_data(data_input,ax1)

ax2 = plt.subplot(142)
ax2.matshow(data_conv)
ax2.set_title('Conv_kernel',pad=20)
show_data(data_conv,ax2)

ax3 = plt.subplot(143)
ax3.matshow(data_hidden)
ax3.set_title('Conv Output',pad=20)
show_data(data_hidden,ax3)


ax4 = plt.subplot(144)
ax4.matshow(data_upsample)
ax4.set_title('Transposed Output',pad=20)
show_data(data_upsample,ax4)

plt.tight_layout()
plt.show()
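
Switching idx from 0 to 1 in the parameters cell reruns the same pipeline with the Example 2 settings (input = 4x4, kernel = 3x3, stride = 1, padding = 0).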