본문 바로가기

논문 리뷰

The fully convolutional transformer for medical image segmentation 리뷰

Tragakis, Athanasios, Chaitanya Kaul, Roderick Murray-Smith, and Dirk Husmeier. "The fully convolutional transformer for medical image segmentation." In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, pp. 3660-3669. 2023. 

 

WACV 2023 Open Access Repository

The Fully Convolutional Transformer for Medical Image Segmentation Athanasios Tragakis, Chaitanya Kaul, Roderick Murray-Smith, Dirk Husmeier; Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV), 2023, pp. 3660-3669 Abstr

openaccess.thecvf.com

 

Abstract

트랜스포머는 medical image analysis에 아직 초기단계이고 fine-grained segmentation을 위한 능력을 보유하고 있지 않다. Fully Convolutional Transformer(FCT)는 이미지 표현을 학습하기 위한 CNN의 장점과 long-term dependencies를 포착하기 위한 트랜스포머의 장점을 결합하였다. FCT는 medical imaging에서 첫 fully convolutional transformer이며 long-range semantic을 학습하기 위한 1단계와 계층적으로 전역정보를 학습하기 위한 2단계로 나누어 진다. FCT는 사전훈련없이도 우월한 성능을 보였다.

 

The Fully Convolutional Transformer

이미지 데이터 \(X\)와 그에 해당하는 semantic/binary segmentation maps \(Y\)로 이루어진 데이터셋 \(\{X, Y\}\)를 가정한다. 각각의 이미지는 \(x_i \in R^{H\times W\times C}\), \(C=\{3,...,N\}\)이고 \(y_i\in R^{H\times W\times K}\), \(K\in\{1,...,D\}\) 이다. FCT의 입력으로는 3D이미지의 조각들인 2D 패치이다. FCT는 기존 방식들과 다르게 over-lapping 패치로부터 특성을 추출하고 임베딩 후 Multi-head attention을 수행한다. 이후 Wide-Focus 모듈을 이용하여 이미지 내 세부 정보를 추출하여 projection한다.

FCT Layer

FCT는 가장 먼저 LayerNorm-Conv-Conv-MaxPool 의 순으로 시행된다(Convolutional Attention블록이 수행되기 전). Conv는 연속된 3x3 kernel size로 이미지의 특징을 추출하며 Gelu 활성화 함수가 뒤따른다. 다른 proposed block들과의 첫번째 차이점은 Convolutional Attention의 적용이다.

 

MaxPool이후에는 Depthwise-Convolution으로 토큰맵을 생성하는데 3x3 kernel size, \(s\times s\) stride, valid padding으로 입력과 출력의 크기를 고정시키고 다시 LayerNorm이 뒤따른다(그림에는 존재하지 않지만 코드로 확인해보면 Convolutional Projection, MHSA에서 중간중간 계속 LayerNorm이 존재함).

 

MHSA에서 기존 블록들과 또 다른 차이점이 나오는데 우리는 Linear projection이 아닌 Depthwise-Convolution을 사용하였다(위에랑 다른 Depthwise임). 이는 계산비용을 줄이면서 더 나은 spatial context를 추출할 수 있다.

 

마지막으로 Wide-Focus로 불리는 모듈은 medical image 속 fine-grained information을 추출하기 위한 모듈이다. 그림과 같이 총 3갈래의 conv layer로 나눠진 후 합쳐지는데 이는 2개의 dilated conv(dilation=2, 3)와 일반 conv의 출력을 합친 후 다시 추가적인 conv layer를 통해 특징을 병합한다.

 

Encoder

인코더는 총 4개의 FCT layer로 구성되어 있으며 각기 다른 크기에서 각 클래스들의 ROI feature를 획득하기 위해 피라미드 스타일의 multiple 이미지 input을 목표로 하지만 이같은 동작이 없어도 뛰어난 성능을 보였다.

 

Decoder

디코더는 bottleneck의 출력을 입력받아 binary/segmentation maps을 resample하도록 학습된다. 인코더와 똑같은 구조를 가지고 있으며 encoder의 출력과 skip-connection으로 연결되어 더해진다. (28x28)크기가 되는 가장 낮은 구간에서는 supervision을 사용하지 않았는데, 이는 ROI가 segmentation되기 너무 작고 모델의 성능을 저하시키기 때문이다.

 

Implementation Details

loss function : equally weighted combination of cross-entropy and dice loss

optimizer : Adam

learning rate : 1e-3

scheduler : plateau (thorugh monitoring the validation loss)

epochs : warm up 50 + training 250

data augmentation

- random rotation(0 ~ 360)

- zoom range (max 0.2)

- shear range (max 0.1)

- horizontal/vertical shift (max 0.3)

- horizontal/vertical flip

모든 모델은 사전훈련 가중치를 사용하지 않고 random initialzed 되었다.

 

자세한 실험결과는 논문을 참고하길 바람