# -*- coding: utf-8 -*-
"""
Created on Tue Jul 3 10:27:15 2018
@author: guokai_liu
"""
import torch.nn as nn
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.init
import itertools
#%%
# Candidate settings, one per row: [input size, kernel size, stride, padding].
# `idx` picks which row the rest of the script uses.
parameters = [
    [12, 3, 2, 1],
    [4, 3, 1, 0],
]
idx = 0
# i: spatial size of the square input, k: kernel size, s: stride, p: padding
i, k, s, p = parameters[idx]
#%%
# Note: the exact output size of a ConvTranspose2d can also be given as an argument (`output_size`, used in the upsample cell below).
# Build an i x i ramp image with values 0, 1, ..., i*i - 1, shaped
# (batch=1, channels=1, i, i). Note: this rebinds `i` from the spatial size
# to the numpy array backing the input tensor.
i = np.arange(i ** 2, dtype='float32').reshape((1, 1, i, i))
input = torch.from_numpy(i)  # shares memory with `i`
downsample = nn.Conv2d(1, 1, k, stride=s, padding=p)
# Use a constant kernel of value w and a ZERO bias so that every output
# element is exactly w times the sum of the input window it covers. Conv2d's
# bias is randomly initialised by default; leaving it in made the plotted
# "Conv Output" differ from that sum, while the '.0f' cell annotations
# silently rounded the offset away.
w = 1
with torch.no_grad():
    downsample.weight.fill_(w)
    downsample.bias.zero_()
h = downsample(input)
h.size()
#%%
# Transposed convolution with the same kernel size, stride and padding as
# `downsample`, mapping `h` back towards the input's spatial size.
upsample = nn.ConvTranspose2d(1, 1, k, stride=s, padding=p)
# Mirror the downsample setup (constant kernel w, zero bias). With the default
# random initialisation the "Transposed Output" panel changed on every run and
# could not be related to anything else shown in the figure; now each output
# element is w times the sum of the hidden values whose kernel footprint
# covers it.
with torch.no_grad():
    upsample.weight.fill_(w)
    upsample.bias.zero_()
# With stride > 1 several input sizes give the same conv output size, so the
# target size is passed explicitly to pick the one matching `input`.
output = upsample(h, output_size=input.size())
output.size()
#%%
# Extract the (batch 0, channel 0) 2-D plane of each 4-D tensor as a numpy
# array for plotting. detach() is required for `h` and `output`, which are
# produced by modules with trainable parameters; it is harmless for the rest.
data_input, data_hidden, data_upsample, data_conv = (
    tensor.detach().numpy()[0, 0]
    for tensor in (input, h, output, downsample.weight)
)
#%%
def show_data(data, ax, fmt='.0f', color='white'):
    """Write the value of every cell of a 2-D array onto an axes.

    Each value ``data[i, j]`` is drawn as ``format(data[i, j], fmt)``,
    centred at x=j (column), y=i (row), so the text lands on the cell
    centres of an image drawn with ``ax.matshow(data)``.

    Parameters
    ----------
    data : 2-D array supporting ``.shape`` and ``[i, j]`` indexing.
    ax : matplotlib Axes (anything with a ``text(x, y, s, **kwargs)`` method).
    fmt : format spec applied to each value; default '.0f' (rounded to an
        integer). Pass e.g. '.2f' to show fractional parts.
    color : text colour; default 'white'.
    """
    n_rows, n_cols = data.shape[0], data.shape[1]
    for i, j in itertools.product(range(n_rows), range(n_cols)):
        # Axes.text takes (x, y): column index first, then row index.
        ax.text(j, i, format(data[i, j], fmt),
                ha='center', va='center', color=color)
#%%
# One row of four panels: input image, conv kernel, conv result and
# transposed-conv result, each annotated with its cell values.
fig = plt.figure(figsize=(18, 6))
panels = [
    (data_input, 'Input'),
    (data_conv, 'Conv_kernel'),
    (data_hidden, 'Conv Output'),
    (data_upsample, 'Transposed Output'),
]
axes = []
for position, (panel_data, title) in enumerate(panels, start=141):
    panel_ax = plt.subplot(position)  # 141, 142, 143, 144
    panel_ax.matshow(panel_data)
    panel_ax.set_title(title, pad=20)
    show_data(panel_data, panel_ax)
    axes.append(panel_ax)
ax1, ax2, ax3, ax4 = axes
plt.tight_layout()