U-Net: Convolutional Networks for Biomedical Image Segmentation 논문 요약
Abstract
좋은 딥러닝 네트워크를 구축하기 위해서는 수많은 데이터가 필요하다. 본 논문에서는 데이터 증강에 의존하여 주어진 데이터를 효율적으로 사용한 네트워크와 학습 방법에 대해 소개한다. 네트워크는 특징을 추출하는 contracting(수축) path와 localization을 수행하는 symmetric(확장) path로 이루어져 있다. 또한, 기존보다 빠른 네트워크로 ISBI cell tracking challenge 2015에서 큰 격차로 우승했다.
Details
도입부
- AlexNet 연구를 통해 8 layers의 깊이를 가진 복잡한 네트워크를 백만개에 달하는 ImageNet 데이터를 사용하여 학습시키면서 이전보다 더 깊고 큰 네트워크를 학습시킬 수 있다는 것이 입증되었다.
- 일반적으로 CNN은 분류 task에서 사용되었지만 많은 CV 분야에서(특히, 의학 분야)는 localization이 꼭 필요하다. 따라서, Ciresan은 sliding window를 통해 localization을 가능하게 했고, patch 개념을 활용하여 기존의 훈련 데이터보다 더 많은 데이터를 사용하는 효과를 냈다.
- 하지만, Ciresan의 연구에서는 네트워크가 각각의 패치마다 따로 학습되어 학습이 느리다는 것과 localization accuracy와 context의 파악 관계가 trade-off 한다는 한계가 존재했다. localization accuracy를 높이기 위해 작은 patches를 사용하면 이미지 영역 전체에 대해 context를 잘 파악하기 힘들다는 것이 예시이다. 반대로, patches를 크게 사용하면 이미지의 패턴은 잘 파악하겠지만, 세분화된 localization을 수행하기 어렵다.
- 본 논문에서는 Ciresan의 네트워크를 수정한 fully convolutional network라고 불리는 네트워크를 사용하여 아주 적은 이미지로 더 좋은 segmentation을 수행하는 모델을 만들었다. 이 네트워크는 pooling 연산자가 upsampling 연산자로 대체되어 contracting network를 보완한 것이 핵심이다. 또한, localization을 잘 수행하기 위해 높은 해상도를 갖는 contracting path의 특징맵을 upsampling path의 특징맵과 결합했다.
- upsampling 구간에 깊은 채널을 사용하여 context 정보가 higher resolution layers에 잘 전달되게 했다.
- 분류가 아닌 Segmentation이 필요하기때문에 fully connected layer보단 fully convolutional layer가 더 적합하다. Fully connected layer는 피처맵을 일자로 쭉펴서 공간정보가 사라지는 반면, fully convolutional layer는 합성곱을 사용해서 공간정보를 여전히 유지할 수 있기때문에 segmentation과 같이 보다 정밀한 task에는 특징 추출 후 분류기에서 fully convolutional layer를 사용한다.
- 하나의 이미지를 작은 patches 단위로 분할해서 연산을 수행한 뒤 결과를 합치는 방법을 사용했다. 이미지를 patches 단위로 분할하면 겹치는 부분이 생기게되는데 겹치는 부분에 대해서는 정보를 공유하고 결합하여 정확성을 향상시킬 수 있다.
- U-net의 구조상 패딩을 사용하지 않아 각 layers마다 3x3 필터를 거치며 나온 출력 이미지의 크기는 입력 이미지의 크기보다 작게되는데 이 문제를 해결하기 위해 작아진 이미지의 테두리 부분에 미러링을 적용하여 기존의 이미지와 같은 크기의 이미지로 만드는 방법을 사용했다. 아래의 왼쪽 그림은 오른쪽처럼 입력보다 이미지가 작아졌을 때 미러링 기법을 사용하여 크기를 복원한 예시이다.
- 데이터가 많지 않고 세포 구조상 elastic deformation(랜덤하게 이미지를 뒤틀리게 변형)을 수행해도 현실적인 이미지가 만들어지기때문에 데이터 증강을 활용하여 성능을 높였다.
- 세포 간 경계(배경)을 확실히 하기 위해서 인접한 세포가 가까우면 가중치를 높이고, 세포와의 거리가 멀어질수록 가중치를 줄이는 손실 함수를 사용하여 경계를 명확히 구분했다.
Network Architecture
- Figure 1 그림을 참고하여 왼쪽을 contracting path, 오른쪽을 expansive path라고 한다. Contracting path는 일반적인 convolutional 구조를 따르고 각 layers마다 3x3 필터를 2번 반복한다. 그후, ReLU와 2x2 max pooling(stride 2)을 사용한다. downsampling을 할 때마다 채널은 2배가 된다.
- Upsampling이 이루어지는 expansive path에서는 2x2 convolution을 사용하고 채널을 절반으로 줄인다. Concatenation이 이루어지는 구간에서는 contracing path의 피처맵의 크기를 잘라 피처맵의 크기를 expansive와 동일하게 맞춘 뒤 병합한다. 병합후에는 3x3 필터가 2번 사용되고 ReLU 함수를 통과한다.
- 마지막 layer에서는 1x1 필터를 사용하여 64개의 채널을 클래스 수만큼 변형시킨다.
- 네트워크는 총 23개의 convolutional layers로 이루어져있다.
- Patches별로 분할을 잘 하기 위해서 2x2 max-pooling을 사용할 때 이미지의 크기가 짝수가 되어야한다.
Training
1. Training Methodology
- Optimization SGD, Momentum 0.99 사용
- GPU를 최대한 활용하기 위해 많은 이미지보다 하나의 이미지를 타일로 분할해서 사용했다. 즉, 입력 데이터의 단위가 이미지가 아닌 이미지를 분할시킨 타일이 된다.
- Segmentation task이기 때문에 loss function으로는 픽셀마다 softmax를 적용하고 이를 cross entropy에 혼합하여 사용한다.
- $K$는 클래스 수, $k$는 특정 클래스, $X$는 픽셀의 위치(x,y니까 2차원 위치), $a_{k}(X)$는 $k$번째 채널의 $X$ 위치의 출력값이다. 쉽게 말해, 전체 클래스에 속할 확률로 해당 위치가 특정 클래스에 속할 확률을 나눈 것이다. 만약, 해당 위치의 클래스가 5일 때, 5일 확률을 100%로 정확하게 맞췄다면 1/1로 $p_{k}(X)=1$이 될 것이다.
- 해당 확률값이 구해졌으면, 이를 cross entropy에 적용하여 error를 구할 수 있다.
- 위에서 구한 확률값 $p_{l(x)}(X)$에 log를 씌워 손실을 계산한다. log를 씌우면 곱셈을 덧셈으로 변환할 수 있어 연산이 간단해지고 매우 크거나 작은 수의 격차가 줄어들어 연산에 안정성이 생긴다는 장점이 있다. 위의 식을 보면 일반적인 cross entropy와는 달리 결과값에 곱해지는 $w(X)$라는 가중치가 있는데 이 가중치는 다음과 같이 계산할 수 있다.
- 위의 식에서 $w_{c}$는 클래스의 빈도, $d_{1}$은 가장 가까운 세포와의 거리, $d_{2}$는 두번째로 가까운 세포와의 거리이다. 지수함수 exp()에 마이너스가 있기때문에 인접한 세포와의 거리가 가까울수록 가중치가 커지고, 멀수록 가중치가 작아진게 된다. 즉, 해당 위치가 셀과 얼마나 가까이 있냐에 따라 세포간의 경계를 정확히 분류하는 방법이다. 가중치가 클수록 해당 위치가 다른 셀과 근접해있기 때문에 경계에 점점 가까워진다고 할 수 있다. 아래 그림 Figure 3에서 b를 보면 여러 세포들이 경계를 이루고있는데 U-Net에서는 이 경계를 구분하기 위해 인접 세포와의 거리 정보를 가중치로 사용해 손실함수에 이를 반영한다. 본 논문에서는 $w_{0} = 10$, $\sigma$는 5에 근사하게 설정했다.
- 네트워크가 깊어짐에 따라 가중치를 초기화하는 방법이 매우 중요해졌다. 본 논문에서는 표준편차 $\sqrt 2/N$(N은 layers로 들어오는 노드 수)을 갖는 가우시안 분포를 활용했다. 예를 들어, 3x3x64 필터를 활용하면 N은 9x64=576이 된다.
2. Data Augmentation
- 많은 양의 데이터를 사용하지 않았기때문에 gray, 수평, 회전 변환 등과 같은 데이터 증강 기법이 필수적이었다.
- 추가로, 각 Grid마다 비선형적으로 변형을 주는 증강 기법인 elastic deformations을 활용했는데 smooth한 변환을 위해 3x3 격자 내에서 10 픽셀의 표준 편차를 가진 가우시안 분포의 난수를 샘플링하여 랜덤하게 이미지를 변환시켰다.
- 수축경로 끝에 있는 Drop-out layer에는 추가적인 다른 변환도 수행했다.
Experiments
- U-Net은 EM segmentation challenge에서 warping error 부분 가장 우수한 성적을 거두었다. Input data를 7개의 rotate 버전으로 증강하여 사용했고, 추가적인 전처리는 없었는데도 좋은 성적을 거두었다. 이는 sliding window 개념을 활용한 Ciresan(Group 3. IDSIA)보다 좋은 성적이다.
- 다른 대회의 이미지 데이터를 사용했을 때도 가장 좋은 성능(IoU Score 사용)을 기록했다. 위의 Figure 4는 ISBI 대회로 b,d에서 노란색 부분이 정답, 색칠된 부분이 U-net이 예측한 결과이다. 정답과 유사하게 segmentation을 수행한 것을 알 수 있다. PhC-U373, DIC-HeLa 데이터셋에서 모두 2위와의 격차를 크게 벌리며 좋은 성적을 낸 것을 알 수 있다.
Conclusion
- 적은 이미지를 사용했음에도 데이터 증강 기법을 잘 활용하여 좋은 성능을 기록했다. 학습 데이터의 크기가 크지 않아서 GPU도 효율적으로 사용할 수 있었다.
- U-Net 아키텍처가 다른 많은 태스크에도 사용될 것을 확신한다.
개인적인 생각
- U-Net이라는 기존에 없던 새로운 형태의 구조와 적은 입력 데이터로도 좋은 성능을 기록했다는 점에서 의미있는 연구였다.
- Contracting path에서 사용한 피처맵을 Expansive path에서도 사용하는 방법론이 다른 논문에도 영향을 끼치진 않았을까. ResNet에서 입력 데이터를 출력에 더해 잔차를 학습하는 방법론처럼.
- 세포간의 경계를 명확히 잘 구분하기 위해 거리를 기반으로 한 가중치를 손실함수에 사용했는데 다른 segmentation 논문에서는 어떤 방법을 활용할지 궁금하다.
- 의학 데이터를 기반으로 연구되었기때문에 도메인이 한정적인데 다른 도메인에서도 좋은 성능을 낼 수 있을까?차량의 스크래치를 찾아내는 프로젝트에서는 그리 좋은 성능을 기록하진 못했다.
- 작아진 이미지의 크기를 복원시키기위해 미러링 기법을 활용했는데 제로 패딩을 활용하는 것이 연산량의 측면에서 더 효율적이지 않을까라는 생각이 들었다. 데이터셋의 크기가 커진다면 두 가지 방법 중 어떤 방법이 더 효율적인 방법인지 비교해볼 수 있을 것이다.
구현
U-Net을 pytorch로 구현해보자.
import torch
import torch.nn as nn
def convolution_and_relu(in_channels, out_channles, kernel_size = 3, stride = 1, padding = 0, bias = False):
layers = []
layers += [nn.Conv2d(in_channels = in_channels, out_channels=out_channles, kernel_size=kernel_size, stride=stride,
padding=padding, bias=bias)]
layers += [nn.BatchNorm2d(num_features=out_channles)]
layers += [nn.ReLU()]
out = nn.Sequential(*layers)
return out
class UNet(nn.Module):
def __init__(self, in_channels, num_classes):
super(UNet, self).__init__()
# Contracting path
self.cont1_1 = convolution_and_relu(in_channels = in_channels, out_channles = 64) # 570 x 570 x 64
self.cont1_2 = convolution_and_relu(64, 64) # 568 x 568 x 64
self.pool1 = nn.MaxPool2d(kernel_size=2) # ((input-kernel) / stride) + 1 -> 284 x 284 x 64
# pool에서 stride를 지정해주지 않을시 stride = kernel size
self.cont2_1 = convolution_and_relu(64, 128) # 282 x 282 x 128
self.cont2_2 = convolution_and_relu(128, 128) # 280 x 280 x 128
self.pool2 = nn.MaxPool2d(kernel_size=2) # 140 x 140x 128
self.cont3_1 = convolution_and_relu(128, 256) # 138 x 138 x 256
self.cont3_2 = convolution_and_relu(256, 256) # 136 x 136 x 256
self.pool3 = nn.MaxPool2d(kernel_size=2) # 68 x 68 x 256
self.cont4_1 = convolution_and_relu(256, 512) # 66 x 66 x 512
self.cont4_2 = convolution_and_relu(512, 512) # 64 x 64 x 512
self.pool4 = nn.MaxPool2d(kernel_size=2) # 32 x 32 x 512
self.cont5_1 = convolution_and_relu(512, 1024) # 30 x 30 x 1024
self.cont5_2 = convolution_and_relu(1024, 1024) # 28 x 28 x 1024
# Expansive Path
# ConvTranspose2d를 사용해 n배의 크기를 갖는 결과를 얻고 싶으면 k=2n, s=n, p=(1/2)n
# U-Net에서는 padding이 0이니까 k=n, stirde=n
self.up4 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride = 2) # k + (i-1)*s - 2p -> 56 x 56 x 512
self.exp4_1 = convolution_and_relu(in_channels= 512 * 2, out_channles=512) # 54 x 54 x 512
self.exp4_2 = convolution_and_relu(in_channels= 512, out_channles=512) # 52 x 52 x 512
self.up3 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride = 2) # 104 x 104 x 256
self.exp3_1 = convolution_and_relu(in_channels= 512, out_channles=256) # 102 x 102 x 256
self.exp3_2 = convolution_and_relu(in_channels= 256, out_channles=256) # 100 x 100 x 256
self.up2 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride = 2) # 200 x 200 x 128
self.exp2_1 = convolution_and_relu(in_channels= 256, out_channles=128) # 198 x 198 x 128
self.exp2_2 = convolution_and_relu(in_channels= 128, out_channles=128) # 196 x 196 x 128
self.up1 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride = 2) # 392 x 392 x 64
self.exp1_1 = convolution_and_relu(in_channels= 128, out_channles=64) # 390 x 390 x 64
self.exp1_2 = convolution_and_relu(in_channels= 64, out_channles=64) # 388 x 388 x 64
# 이진 분류면 out_channels=1. num_classses를 1로 사용.
self.fc = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1, stride=1) # 388 x 388
def forward(self, x):
x = self.cont1_1(x)
x = self.cont1_2(x)
x = self.pool1(x)
x = self.cont2_1(x)
x = self.cont2_2(x)
x = self.pool2(x)
x = self.cont3_1(x)
x = self.cont3_2(x)
x = self.pool3(x)
x = self.cont4_1(x)
x = self.cont4_2(x)
x = self.pool4(x)
x = self.cont5_1(x)
x = self.cont5_2(x)
x = self.up4(x)
x = self.exp4_1(x)
x = self.exp4_2(x)
x = self.up3(x)
x = self.exp3_1(x)
x = self.exp3_2(x)
x = self.up2(x)
x = self.exp2_1(x)
x = self.exp2_2(x)
x = self.up1(x)
x = self.exp1_1(x)
x = self.exp1_2(x)
x = self.fc(x)
return x