티스토리 뷰
728x90
1. U-net 이란?
Semantic Segmentation에 가장 기본적으로 사용되던 모델 (U-net)
모델의 형태가 U자로 되어있어서 U-net이라고 불림
U-net의 모델 구조는 크게 Encoder, Decoder로 이루어져있다.
Encoder(= Contracting Path), Decoder(= Expanding Path) 라고도 한다.
또한, Decoder 구간에서 Upsampling을 할 때, Localization을 좀 더 정확하게 하기 위해서 Encoder의 Feature들을 Concat을 한다.
즉 간단히 말해서, Encoder에서는 이미지의 크기를 줄여가면서 Feature는 더 많이 뽑아내며 나아가는 수축의 과정을 진행하고
Decoder에서는 이렇게 추출해온 Feature들을 기반으로 다시 원래의 이미지 크기로 확장을 해나가는 과정을 진행한다.
2. U-net 모델의 구조
그림으로 볼 수 있듯이 왼쪽이 Encoder, 오른쪽이 Decoder로 볼 수 있으며
맨 아래 가운데부분은 BottleNeck (혹은 전환점)이라고 불린다.
3. U-net 모델 구현
- Padding을 사용하여 Conv 연산이 일어나는 동안에는 이미지 사이즈에 변동이 없도록 하였다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
|
import torch
import torch.nn as nn
import torch.nn.functional as F
class Unet(nn.Module):
def __init__(self):
super(Unet,self).__init__()
self.normal_pooling=nn.MaxPool2d(kernel_size=2) # Encoder Pooling Net
# Encoder
self.encoder_1=self.encoder_block(1,64)
self.encoder_2=self.encoder_block(64,128)
self.encoder_3=self.encoder_block(128,256)
self.encoder_4=self.encoder_block(256,512)
self.bottleneck=self.decoder_block(512,1024,512) # Bottleneck
# Decoder
self.un_pooling_4=self.un_pool(512,512)
self.decoder_4=self.decoder_block(1024,512,256)
self.un_pooling_3=self.un_pool(256,256)
self.decoder_3=self.decoder_block(512,256,128)
self.un_pooling_2=self.un_pool(128,128)
self.decoder_2=self.decoder_block(256,128,64)
self.un_pooling_1=self.un_pool(64,64)
self.decoder_1=self.decoder_block(128,64,64)
self.output_net=nn.Sequential(nn.Conv2d(in_channels=64,out_channels=1,kernel_size=(1,1),stride=(1,1),padding=0))
def un_pool(self,in_channel,out_channel):
return nn.Sequential(nn.ConvTranspose2d(in_channels=in_channel,
out_channels=out_channel,
kernel_size=2,
stride=2,
padding=0))
def encoder_block(self,in_channel,out_channel):
layers=[]
prev_channel=in_channel
for _ in range(2):
layers.append(nn.Conv2d(in_channels=prev_channel,
out_channels=out_channel,
kernel_size=(3,3),
stride=(1,1),
padding=(1,1)))
layers.append(nn.BatchNorm2d(out_channel))
layers.append(nn.ReLU(True))
prev_channel=out_channel
return nn.Sequential(*layers)
def decoder_block(self,in_channel,hidden_channel,out_channel):
layers=[]
layers.append(nn.Conv2d(in_channels=in_channel,
out_channels=hidden_channel,
kernel_size=(3,3),
stride=(1,1),
padding=(1,1)))
layers.append(nn.BatchNorm2d(hidden_channel))
layers.append(nn.ReLU(True))
layers.append(nn.Conv2d(in_channels=hidden_channel,
out_channels=out_channel,
kernel_size=(3,3),
stride=(1,1),
padding=(1,1)))
layers.append(nn.BatchNorm2d(out_channel))
layers.append(nn.ReLU(True))
return nn.Sequential(*layers)
def forward(self,x):
x=self.encoder_1(x)
encoder_x1=x
x=self.normal_pooling(x)
x=self.encoder_2(x)
encoder_x2=x
x=self.normal_pooling(x)
x=self.encoder_3(x)
encoder_x3=x
x=self.normal_pooling(x)
x=self.encoder_4(x)
encoder_x4=x
x=self.normal_pooling(x)
x=self.bottleneck(x)
x=self.un_pooling_4(x)
x=self.decoder_4(torch.cat([x,encoder_x4],dim=1))
x=self.un_pooling_3(x)
x=self.decoder_3(torch.cat([x,encoder_x3],dim=1))
x=self.un_pooling_2(x)
x=self.decoder_2(torch.cat([x,encoder_x2],dim=1))
x=self.un_pooling_1(x)
x=self.decoder_1(torch.cat([x,encoder_x1],dim=1))
x=self.output_net(x)
return x
def init_params(self) :
for m in self.modules() :
if isinstance(m,nn.Conv2d) :
nn.init.kaiming_normal_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m,nn.BatchNorm2d) :
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
|
cs |
전체 소스코드는
github.com/hunmin-hub/DL_Tutorial/tree/main/Unet
'Deep Learning > CV' 카테고리의 다른 글
롯데정보통신 Vision AI 경진대회 - Public LB 2nd place Solution (5) | 2021.03.28 |
---|---|
[DACON] Mnist 숫자 사진 분류 (0) | 2021.02.12 |
[DACON] Mnist Fashion 의류 사진 분류 (0) | 2021.02.12 |
댓글
공지사항
최근에 올라온 글
최근에 달린 댓글
- Total
- Today
- Yesterday
링크
TAG
- 부스트캠프 AI Tech
- Vision AI 경진대회
- C++
- 데이터연습
- P-Stage
- 네트워킹데이
- ResNet
- 공공데이터
- 이분탐색
- 프로그래머스
- 데이터핸들링
- Unet
- 알고리즘
- 백트래킹
- 코딩테스트
- DACON
- python
- 브루트포스
- dfs
- 다이나믹프로그래밍
- Unet 구현
- 동적계획법
- 백준
- Data Handling
- NLP 구현
- 그리디
- cnn
- DeepLearning
- AI 프로젝트
- pandas
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | |||||
3 | 4 | 5 | 6 | 7 | 8 | 9 |
10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 18 | 19 | 20 | 21 | 22 | 23 |
24 | 25 | 26 | 27 | 28 | 29 | 30 |
글 보관함