티스토리 뷰

Deep Learning/CV

U-net 구현

dev.hunmin 2021. 3. 7. 20:52
728x90

1. U-net 이란?

Semantic Segmentation에 가장 기본적으로 사용되던 모델 (U-net)

모델의 형태가 U자로 되어있어서 U-net이라고 불림

 

U-net의 모델 구조는 크게 Encoder, Decoder로 이루어져있다.

Encoder(= Contracting Path), Decoder(= Expanding Path) 라고도 한다.

 

또한, Decoder 구간에서 Upsampling을 할 때, Localization을 좀 더 정확하게 하기 위해서 Encoder의 Feature들을 Concat을 한다.

 

즉 간단히 말해서, Encoder에서는 이미지의 크기를 줄여가면서 Feature는 더 많이 뽑아내며 나아가는 수축의 과정을 진행하고

Decoder에서는 이렇게 추출해온 Feature들을 기반으로 다시 원래의 이미지 크기로 확장을 해나가는 과정을 진행한다.

 

2. U-net 모델의 구조

U-net

그림으로 볼 수 있듯이 왼쪽이 Encoder, 오른쪽이 Decoder로 볼 수 있으며 

맨 아래 가운데부분은 BottleNeck (혹은 전환점)이라고 불린다.

 

 

3. U-net 모델 구현

- Padding을 사용하여 Conv 연산이 일어나는 동안에는 이미지 사이즈에 변동이 없도록 하였다.

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
import torch.nn as nn
import torch.nn.functional as F
 
class Unet(nn.Module):
    def __init__(self):
        super(Unet,self).__init__()
        self.normal_pooling=nn.MaxPool2d(kernel_size=2# Encoder Pooling Net
        # Encoder
        self.encoder_1=self.encoder_block(1,64)
        self.encoder_2=self.encoder_block(64,128)
        self.encoder_3=self.encoder_block(128,256)
        self.encoder_4=self.encoder_block(256,512)
 
        self.bottleneck=self.decoder_block(512,1024,512# Bottleneck
 
        # Decoder
        self.un_pooling_4=self.un_pool(512,512)
        self.decoder_4=self.decoder_block(1024,512,256)
        self.un_pooling_3=self.un_pool(256,256)
        self.decoder_3=self.decoder_block(512,256,128)
        self.un_pooling_2=self.un_pool(128,128)
        self.decoder_2=self.decoder_block(256,128,64)
        self.un_pooling_1=self.un_pool(64,64)
        self.decoder_1=self.decoder_block(128,64,64)
 
        self.output_net=nn.Sequential(nn.Conv2d(in_channels=64,out_channels=1,kernel_size=(1,1),stride=(1,1),padding=0))
 
    def un_pool(self,in_channel,out_channel):
        return nn.Sequential(nn.ConvTranspose2d(in_channels=in_channel,
                                                out_channels=out_channel,
                                                kernel_size=2,
                                                stride=2,
                                                padding=0))
 
    def encoder_block(self,in_channel,out_channel):
        layers=[]
        prev_channel=in_channel
        for _ in range(2):
            layers.append(nn.Conv2d(in_channels=prev_channel,
                                    out_channels=out_channel,
                                    kernel_size=(3,3),
                                    stride=(1,1),
                                    padding=(1,1)))
            layers.append(nn.BatchNorm2d(out_channel))
            layers.append(nn.ReLU(True))
            prev_channel=out_channel
        return nn.Sequential(*layers)
 
    def decoder_block(self,in_channel,hidden_channel,out_channel):
        layers=[]
        layers.append(nn.Conv2d(in_channels=in_channel,
                                    out_channels=hidden_channel,
                                    kernel_size=(3,3),
                                    stride=(1,1),
                                    padding=(1,1)))
        layers.append(nn.BatchNorm2d(hidden_channel))
        layers.append(nn.ReLU(True))
        layers.append(nn.Conv2d(in_channels=hidden_channel,
                                    out_channels=out_channel,
                                    kernel_size=(3,3),
                                    stride=(1,1),
                                    padding=(1,1)))
        layers.append(nn.BatchNorm2d(out_channel))
        layers.append(nn.ReLU(True))
        return nn.Sequential(*layers)
    
    def forward(self,x):
        x=self.encoder_1(x)
        encoder_x1=x
        x=self.normal_pooling(x)
        x=self.encoder_2(x)
        encoder_x2=x
        x=self.normal_pooling(x)
        x=self.encoder_3(x)
        encoder_x3=x
        x=self.normal_pooling(x)
        x=self.encoder_4(x)
        encoder_x4=x
        x=self.normal_pooling(x)
        x=self.bottleneck(x)
        x=self.un_pooling_4(x)
        x=self.decoder_4(torch.cat([x,encoder_x4],dim=1))
        x=self.un_pooling_3(x)
        x=self.decoder_3(torch.cat([x,encoder_x3],dim=1))
        x=self.un_pooling_2(x)
        x=self.decoder_2(torch.cat([x,encoder_x2],dim=1))
        x=self.un_pooling_1(x)
        x=self.decoder_1(torch.cat([x,encoder_x1],dim=1))
        x=self.output_net(x)
        return x
    
    def init_params(self) :
        for m in self.modules() :
            if isinstance(m,nn.Conv2d) :
                nn.init.kaiming_normal_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m,nn.BatchNorm2d) :
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
cs

 

 

전체 소스코드는

github.com/hunmin-hub/DL_Tutorial/tree/main/Unet

 

댓글
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2024/11   »
1 2
3 4 5 6 7 8 9
10 11 12 13 14 15 16
17 18 19 20 21 22 23
24 25 26 27 28 29 30
글 보관함