Notice
Recent Posts
Recent Comments
Link
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
 |  |  |  |  | 1 | 2 |
3 | 4 | 5 | 6 | 7 | 8 | 9 |
10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 18 | 19 | 20 | 21 | 22 | 23 |
24 | 25 | 26 | 27 | 28 | 29 | 30 |
Tags
- Recsys-KR
- 프로그래머스
- 엘리스
- Semantic Segmentation
- TEAM-EDA
- 입문
- 3줄 논문
- 나는리뷰어다
- Segmentation
- MySQL
- 한빛미디어
- 추천시스템
- 나는 리뷰어다
- Machine Learning Advanced
- pytorch
- 큐
- DFS
- Object Detection
- 협업필터링
- Image Segmentation
- 알고리즘
- 스택
- Python
- DilatedNet
- hackerrank
- TEAM EDA
- eda
- 코딩테스트
- 파이썬
- 튜토리얼
Archives
- Today
- Total
TEAM EDA
DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs (DeepLabv2) Code 본문
EDA Study/Image Segmentation
DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs (DeepLabv2) Code
김현우 2021. 9. 23. 16:37
DeepLabv2는 VGG16과 ResNet 두 가지 버전이 있는데, 아래 코드는 ResNet101 기반으로 구현했습니다.
#!/usr/bin/env python
# coding: utf-8
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from types import ModuleType
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand, plus shortcut.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels; the inner width is ``out_ch // 4``.
        stride: Stride of the 3x3 convolution (and of the projection shortcut).
        dilation: Dilation (and padding) of the 3x3 convolution.
        downsample: If truthy, the shortcut is a strided 1x1 conv + BatchNorm
            projection; otherwise the input is added back unchanged, which
            requires ``in_ch == out_ch`` and ``stride == 1``.
    """

    def __init__(self, in_ch, out_ch, stride, dilation, downsample):
        super(Bottleneck, self).__init__()
        mid_ch = out_ch // 4
        # NOTE: dilation on the two 1x1 convolutions has no effect on the
        # result (a 1x1 kernel has no spatial extent); kept as-is.
        self.conv1 = nn.Conv2d(in_channels=in_ch, out_channels=mid_ch, kernel_size=1, stride=1, padding=0, dilation=dilation, bias=False)
        self.bn1 = nn.BatchNorm2d(num_features=mid_ch)
        # padding == dilation keeps the spatial size of a 3x3 conv at stride 1.
        self.conv2 = nn.Conv2d(in_channels=mid_ch, out_channels=mid_ch, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=False)
        self.bn2 = nn.BatchNorm2d(num_features=mid_ch)
        self.conv3 = nn.Conv2d(in_channels=mid_ch, out_channels=out_ch, kernel_size=1, stride=1, padding=0, dilation=dilation, bias=False)
        self.bn3 = nn.BatchNorm2d(num_features=out_ch)
        self.relu = nn.ReLU(inplace=True)
        if downsample:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=1, stride=stride, padding=0, dilation=dilation, bias=False),
                nn.BatchNorm2d(num_features=out_ch),
            )
            self.is_downsample = True
        else:
            self.is_downsample = False

    def forward(self, x):
        """Apply the residual block to ``x`` and return the activated sum."""
        shortcut = self.downsample(x) if self.is_downsample else x

        # A non-linearity follows each of the first two conv+BN stages.
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))

        # Add out-of-place, then activate: the final ReLU comes after the
        # residual sum, and no tensor saved by an in-place ReLU is mutated.
        h = h + shortcut
        return self.relu(h)
class ResLayer(nn.Sequential):
    """Sequence of ``n_layers`` Bottleneck blocks registered as "0", "1", ...

    Only the first block changes the channel count (``in_ch`` -> ``out_ch``),
    applies ``stride`` and uses a projection shortcut; every later block maps
    ``out_ch`` -> ``out_ch`` at stride 1.

    Args:
        n_layers: Number of Bottleneck blocks to stack.
        in_ch: Input channels of the first block.
        out_ch: Output channels of every block.
        stride: Stride of the first block.
        dilation: Base dilation, multiplied per block by the multi-grid factor.
        multi_grid: If truthy, per-block dilation factors are ``[1, 2, 2]``
            (so only the first three entries exist); otherwise every factor is 1.
    """

    def __init__(self, n_layers, in_ch, out_ch, stride, dilation, multi_grid=0):
        super(ResLayer, self).__init__()
        if multi_grid:
            grid = [1, 2, 2]
        else:
            grid = [1] * n_layers

        for idx in range(n_layers):
            is_first = idx == 0
            block = Bottleneck(
                in_ch=in_ch if is_first else out_ch,
                out_ch=out_ch,
                stride=stride if is_first else 1,
                dilation=dilation * grid[idx],
                downsample=is_first,  # projection shortcut only in the first block
            )
            self.add_module(str(idx), block)
class IntermediateLayerGetter(nn.Sequential):
    """ResNet-style feature extractor: stem (conv/bn/relu/maxpool) + four ResLayers.

    Args:
        n_blocks: Four block counts, one per ResLayer (``layer1`` .. ``layer4``).
        ch: Channel widths; indices 0, 2, 3, 4 and 5 are read (``ch[1]`` is unused).
            The stem always outputs 64 channels, so ``ch[0]`` is expected to be 64.
        atrous_rates: Accepted for interface compatibility; not used by this class.
        output_stride: 8 or 16. Selects the per-layer strides and dilations.

    Raises:
        ValueError: If ``output_stride`` is neither 8 nor 16.
    """

    def __init__(self, n_blocks, ch, atrous_rates, output_stride):
        super(IntermediateLayerGetter, self).__init__()

        # Stride and dilation per ResLayer. Validate up front so an unsupported
        # value fails with a clear message instead of an unbound-variable error.
        if output_stride == 8:
            s = [1, 2, 1, 1]
            d = [1, 1, 2, 4]
        elif output_stride == 16:
            s = [1, 2, 2, 1]
            d = [1, 1, 1, 2]
        else:
            raise ValueError(
                "output_stride must be 8 or 16, got {!r}".format(output_stride)
            )

        self.add_module("conv1", nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False))
        self.add_module("bn1", nn.BatchNorm2d(64))
        self.add_module("relu", nn.ReLU(inplace=True))
        self.add_module("maxpool", nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1))

        self.add_module("layer1", ResLayer(n_blocks[0], ch[0], ch[2], s[0], d[0], 0))
        self.add_module("layer2", ResLayer(n_blocks[1], ch[2], ch[3], s[1], d[1], 0))
        self.add_module("layer3", ResLayer(n_blocks[2], ch[3], ch[4], s[2], d[2], 0))
        # Last argument: 1 to enable multi-grid dilation, 0 to disable it.
        self.add_module("layer4", ResLayer(n_blocks[3], ch[4], ch[5], s[3], d[3], 1))
class ASPP(nn.Module):
    """Four parallel 3x3 atrous branches (rates 6/12/18/24) summed into class scores.

    Each branch has its own 3x3 atrous convolution; the two 1x1 convolutions
    and both dropout layers are shared by all branches.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by each atrous convolution and by the
            intermediate 1x1 convolution.
        n_classes: Channels of the returned score map.
    """

    def __init__(self, in_channels, out_channels, n_classes=21):
        super(ASPP, self).__init__()
        # One atrous 3x3 conv per rate; padding == dilation keeps the spatial size.
        for rate in (6, 12, 18, 24):
            atrous = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=rate,
                dilation=rate,
            )
            setattr(self, "conv_3x3_r{}".format(rate), atrous)
        self.drop_conv_3x3 = nn.Dropout2d(0.5)

        self.conv_1x1 = nn.Conv2d(out_channels, out_channels, kernel_size=1)
        self.drop_conv_1x1 = nn.Dropout2d(0.5)

        self.conv_1x1_out = nn.Conv2d(out_channels, n_classes, kernel_size=1)

    def forward(self, feature_map):
        """Return the sum of the four branch score maps for ``feature_map``."""
        atrous_convs = (
            self.conv_3x3_r6,
            self.conv_3x3_r12,
            self.conv_3x3_r18,
            self.conv_3x3_r24,
        )
        branch_scores = []
        for atrous in atrous_convs:
            # atrous 3x3 -> ReLU -> dropout -> 1x1 -> ReLU -> dropout -> 1x1 scores
            hidden = F.relu(atrous(feature_map))
            hidden = self.drop_conv_3x3(hidden)
            hidden = F.relu(self.conv_1x1(hidden))
            hidden = self.drop_conv_1x1(hidden)
            branch_scores.append(self.conv_1x1_out(hidden))
        return sum(branch_scores)
class DeepLabV2(nn.Module if False else nn.Sequential):
    """DeepLabv2 segmentation network: residual backbone + ASPP classifier.

    The backbone is built with block counts ``[3, 4, 23, 3]`` and
    ``output_stride=16``; the classifier maps 2048 channels to ``n_classes``
    score maps, which are upsampled back to the input's spatial size.

    Args:
        n_classes: Number of output score channels.
    """

    def __init__(self, n_classes=21):
        super(DeepLabV2, self).__init__()
        widths = [64 << p for p in range(6)]  # 64, 128, 256, 512, 1024, 2048
        self.backbone = IntermediateLayerGetter(
            n_blocks=[3, 4, 23, 3],
            ch=widths,
            atrous_rates=[6, 12, 18, 24],
            output_stride=16,
        )
        self.classifier = ASPP(2048, 256, n_classes)

    def forward(self, x):
        """Return per-class score maps with the same height/width as ``x``."""
        features = self.backbone(x)
        scores = self.classifier(features)
        # Bilinear upsampling restores the resolution lost in the backbone.
        return F.interpolate(scores, size=x.shape[2:], mode="bilinear", align_corners=False)