class BasicBlockGroup(nn.Module):
    """ResNet-style basic residual block built from grouped convolutions.

    The residual branch is two 3x3 convolutions (the first one carries
    ``stride``), each optionally followed by BatchNorm. The shortcut is the
    identity unless the stride or channel count changes, in which case ``x``
    is projected with a 1x1 grouped convolution. The block returns
    ``relu(branch(x) + shortcut(x))``.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels of each 3x3 convolution.
        stride: Stride of the first 3x3 convolution and of the projection
            shortcut.
        groups: ``groups`` argument passed to every convolution in the block;
            ``nn.Conv2d`` requires it to divide both ``in_planes`` and
            ``planes``.
        bn: If True, apply ``BatchNorm2d`` after each 3x3 convolution.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, groups=2, bn=False):
        super().__init__()
        self.bn = bn
        # Submodule names/registration order are kept stable so existing
        # state_dict keys (conv1, bn1, conv2, bn2, shortcut.0) still load.
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1,
            bias=False, groups=groups)
        if bn:
            self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1,
            bias=False, groups=groups)
        if bn:
            self.bn2 = nn.BatchNorm2d(planes)

        out_planes = self.expansion * planes
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != out_planes:
            # Shape changes, so project x to match the residual branch.
            # NOTE: the projection has no BatchNorm, even when bn=True.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False, groups=groups),
            )

    def forward(self, x):
        """Return ``relu(branch(x) + shortcut(x))`` for input tensor ``x``."""
        out = self.conv1(x)
        if self.bn:
            out = self.bn1(out)
        out = F.relu(out)

        out = self.conv2(out)
        if self.bn:
            out = self.bn2(out)
        # No ReLU before the sum in either mode: the non-linearity is applied
        # once, after adding the shortcut.

        # Out-of-place add: an in-place `+=` on a tensor autograd saved for
        # backward (e.g. a ReLU output) makes backward() raise.
        out = out + self.shortcut(x)
        return F.relu(out)