관리 메뉴

TEAM EDA

Rethinking Atrous Convolution for Semantic Image Segmentation (DeepLabv3) Code 본문

EDA Study/Image Segmentation

Rethinking Atrous Convolution for Semantic Image Segmentation (DeepLabv3) Code

김현우 2021. 9. 23. 17:38
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from types import ModuleType
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 (optionally atrous) -> 1x1 expand.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels; the inner width is ``out_ch // 4``.
        stride: Stride of the 3x3 convolution and of the projection shortcut.
        dilation: Dilation passed to every convolution. The 3x3 convolution
            pads by the same amount, so at stride 1 the spatial size is kept.
        downsample: If true, the shortcut is a 1x1 convolution + batch norm
            projection of the input; otherwise the input is added unchanged
            (which requires ``in_ch == out_ch`` and ``stride == 1``).
    """

    def __init__(self, in_ch, out_ch, stride, dilation, downsample):
        super(Bottleneck, self).__init__()
        mid_ch = out_ch // 4
        self.conv1 = nn.Conv2d(in_channels=in_ch, out_channels=mid_ch, kernel_size=1, stride=1, padding=0, dilation=dilation, bias=False)
        self.bn1 = nn.BatchNorm2d(num_features=mid_ch)
        self.conv2 = nn.Conv2d(in_channels=mid_ch, out_channels=mid_ch, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=False)
        self.bn2 = nn.BatchNorm2d(num_features=mid_ch)
        self.conv3 = nn.Conv2d(in_channels=mid_ch, out_channels=out_ch, kernel_size=1, stride=1, padding=0, dilation=dilation, bias=False)
        self.bn3 = nn.BatchNorm2d(num_features=out_ch)
        self.relu = nn.ReLU(inplace=True)
        if downsample:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=1, stride=stride, padding=0, dilation=dilation, bias=False),
                nn.BatchNorm2d(num_features=out_ch)
            )
            self.is_downsample = True
        else:
            self.is_downsample = False

    def forward(self, x):
        """Return ``relu(residual_branch(x) + shortcut(x))``."""
        # Compute the shortcut first so it always sees the untouched input.
        identity = self.downsample(x) if self.is_downsample else x

        # ReLU follows each of the first two conv+BN pairs; the last
        # conv+BN stays linear until after the residual sum.
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))

        # Out-of-place sum: adding in place onto the output of an in-place
        # ReLU overwrites a tensor autograd saved and breaks backward().
        h = h + identity
        return self.relu(h)

class ResLayer(nn.Sequential):
    """One residual stage: ``n_layers`` Bottleneck blocks chained in sequence.

    Args:
        n_layers: Number of Bottleneck blocks in the stage.
        in_ch: Input channels of the first block.
        out_ch: Output channels of every block (and input of blocks after the first).
        stride: Stride of the first block; later blocks use stride 1.
        dilation: Base dilation, scaled per block by the multi-grid multiplier.
        multi_grid: If truthy, per-block multipliers are ``[1, 2, 2]`` (indexed
            by block position); otherwise every block uses multiplier 1.
    """

    def __init__(self, n_layers, in_ch, out_ch, stride, dilation, multi_grid=0):
        super().__init__()
        if multi_grid:
            rate_multipliers = [1, 2, 2]
        else:
            rate_multipliers = [1] * n_layers

        for idx in range(n_layers):
            # Only the first block changes channel count / resolution, so it
            # is the only one that gets the stage stride and a projection shortcut.
            is_first = idx == 0
            block = Bottleneck(
                in_ch=in_ch if is_first else out_ch,
                out_ch=out_ch,
                stride=stride if is_first else 1,
                dilation=dilation * rate_multipliers[idx],
                downsample=is_first,
            )
            self.add_module(str(idx), block)


class IntermediateLayerGetter(nn.Sequential):
    """Backbone: a 7x7 conv stem with max-pooling followed by four residual stages.

    Args:
        n_blocks: Number of Bottleneck blocks in each of the four stages.
        ch: Channel table; entries 0, 2, 3, 4 and 5 are read as the stage
            input/output widths. The stem itself always outputs 64 channels.
        atrous_rates: Accepted for interface compatibility; not read here.
        output_stride: 8 or 16. Selects the per-stage strides and dilations;
            together with the stride-2 stem and stride-2 max-pool these give
            the requested overall downsampling factor.

    Raises:
        ValueError: If ``output_stride`` is neither 8 nor 16.
    """

    def __init__(self, n_blocks, ch, atrous_rates, output_stride):
        super(IntermediateLayerGetter, self).__init__()
        # Stride and dilation of layer1..layer4
        if output_stride == 8:
            s = [1, 2, 1, 1]
            d = [1, 1, 2, 4]
        elif output_stride == 16:
            s = [1, 2, 2, 1]
            d = [1, 1, 1, 2]
        else:
            # Fail early with a clear message instead of hitting unbound
            # `s`/`d` further down.
            raise ValueError(
                "output_stride must be 8 or 16, got {!r}".format(output_stride)
            )

        self.add_module("conv1", nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False))
        self.add_module("bn1", nn.BatchNorm2d(64))
        self.add_module("relu", nn.ReLU(inplace=True))
        self.add_module("maxpool", nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1))
        self.add_module("layer1", ResLayer(n_blocks[0], ch[0], ch[2], s[0], d[0], 0))
        self.add_module("layer2", ResLayer(n_blocks[1], ch[2], ch[3], s[1], d[1], 0))
        self.add_module("layer3", ResLayer(n_blocks[2], ch[3], ch[4], s[2], d[2], 0))
        # Last argument: 1 enables multi-grid dilation in this stage, 0 disables it.
        self.add_module("layer4", ResLayer(n_blocks[3], ch[4], ch[5], s[3], d[3], 1))

class ASPPConv(nn.Module):
    """One ASPP branch: bias-free stride-1 convolution, then batch norm, then ReLU.

    Args:
        inplanes: Input channels.
        outplanes: Output channels.
        kernel_size: Convolution kernel size.
        padding: Convolution padding.
        dilation: Convolution dilation.
    """

    def __init__(self, inplanes, outplanes, kernel_size, padding, dilation):
        super().__init__()
        self.atrous_conv = nn.Conv2d(
            in_channels=inplanes,
            out_channels=outplanes,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(num_features=outplanes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply convolution, batch norm and ReLU to ``x`` in that order."""
        return self.relu(self.bn(self.atrous_conv(x)))

class ASPPPooling(nn.Module):
    """Image-level ASPP branch: global average pool, 1x1 conv, batch norm, ReLU.

    Args:
        inplanes: Input channels.
        outplanes: Output channels.

    The output keeps the pooled 1x1 spatial size; this module does not
    upsample it.
    """

    def __init__(self, inplanes, outplanes):
        super().__init__()
        self.globalavgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.conv = nn.Conv2d(
            in_channels=inplanes,
            out_channels=outplanes,
            kernel_size=1,
            stride=1,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(num_features=outplanes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Pool ``x`` to 1x1 and run it through conv, batch norm and ReLU."""
        pooled = self.globalavgpool(x)
        return self.relu(self.bn(self.conv(pooled)))


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling with five parallel branches and a projection.

    Branches: one 1x1 convolution, three 3x3 convolutions with dilations
    6, 12 and 18 (padding equal to the dilation), and one global-average-pool
    branch. Their outputs are concatenated on the channel axis and projected
    back to ``outplanes`` channels by a 1x1 conv + batch norm + ReLU + dropout.

    Args:
        inplanes: Input channels.
        outplanes: Channels produced by each branch and by the projection.
    """

    def __init__(self, inplanes, outplanes):
        super().__init__()
        # Branch 1 is the plain 1x1 convolution (dilation 1, no padding).
        self.aspp1 = ASPPConv(inplanes, outplanes, 1, padding=0, dilation=1)
        # Branches 2-4 are 3x3 atrous convolutions; padding == dilation.
        for offset, rate in enumerate((6, 12, 18)):
            branch = ASPPConv(inplanes, outplanes, 3, padding=rate, dilation=rate)
            setattr(self, "aspp{}".format(offset + 2), branch)
        self.global_avg_pool = ASPPPooling(inplanes, outplanes)

        n_branches = 5
        self.project = nn.Sequential(
            nn.Conv2d(outplanes * n_branches, outplanes, kernel_size=1, bias=False),
            nn.BatchNorm2d(outplanes),
            nn.ReLU(),
            nn.Dropout(p=0.5),
        )

    def forward(self, x):
        """Run all branches on ``x``, concatenate them and apply the projection."""
        conv_branches = (self.aspp1, self.aspp2, self.aspp3, self.aspp4)
        features = [branch(x) for branch in conv_branches]

        # The pooled branch is resized to the input's spatial size so it can
        # be concatenated with the convolutional branches.
        pooled = self.global_avg_pool(x)
        features.append(
            F.interpolate(pooled, size=x.shape[2:], mode='bilinear', align_corners=True)
        )

        return self.project(torch.cat(features, dim=1))

class DeepLabHead(nn.Sequential):
    """Segmentation head: ASPP, a 3x3 conv + batch norm + ReLU, then a 1x1 classifier.

    Args:
        ch: Channel table; only its last entry is read, as the ASPP input width.
        out_ch: Width of the ASPP output and of the 3x3 convolution.
        n_classes: Number of output channels of the final 1x1 convolution.
    """

    def __init__(self, ch, out_ch, n_classes):
        super().__init__()
        stages = (
            ASPP(ch[-1], out_ch),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, n_classes, kernel_size=1, stride=1),
        )
        # Sub-modules are registered under their positional index ("0".."4").
        for position, stage in enumerate(stages):
            self.add_module(str(position), stage)


class DeepLabV3(nn.Sequential):
    """DeepLabv3: backbone, ASPP classifier head, and bilinear upsampling to input size.

    Args:
        n_classes: Number of output channels (classes) of the classifier head.
        n_blocks: Number of bottleneck blocks in each of the four backbone
            stages, e.g. ``[3, 4, 23, 3]``.
        atrous_rates: Forwarded to the backbone constructor.
            NOTE(review): the backbone in this file accepts but does not read
            it, and the ASPP head uses fixed rates - confirm whether this
            should be wired through.
    """

    def __init__(self, n_classes, n_blocks, atrous_rates):
        super(DeepLabV3, self).__init__()
        # Channel table shared by backbone and head: 64, 128, ..., 2048.
        ch = [64 * 2 ** p for p in range(6)]
        # Honour the constructor arguments instead of hard-coded values.
        self.backbone = IntermediateLayerGetter(n_blocks=n_blocks, ch=ch, atrous_rates=atrous_rates, output_stride=16)
        self.classifier = DeepLabHead(ch=ch, out_ch=256, n_classes=n_classes)

    def forward(self, x):
        """Return per-class score maps resized to the spatial size of ``x``."""
        h = self.backbone(x)
        h = self.classifier(h)
        # The head runs at reduced resolution; resize back to the input size.
        h = F.interpolate(h, size=x.shape[2:], mode="bilinear", align_corners=False)
        return h