본문 바로가기

논문 리뷰

The fully convolutional transformer for medical image segmentation 리뷰

Tragakis, Athanasios, Chaitanya Kaul, Roderick Murray-Smith, and Dirk Husmeier. "The fully convolutional transformer for medical image segmentation." In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, pp. 3660-3669. 2023. 

 

WACV 2023 Open Access Repository

The Fully Convolutional Transformer for Medical Image Segmentation Athanasios Tragakis, Chaitanya Kaul, Roderick Murray-Smith, Dirk Husmeier; Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV), 2023, pp. 3660-3669 Abstr

openaccess.thecvf.com

 

Abstract

트랜스포머는 medical image analysis에 아직 초기단계이고 fine-grained segmentation을 위한 능력을 보유하고 있지 않다. Fully Convolutional Transformer(FCT)는 이미지 표현을 학습하기 위한 CNN의 장점과 long-term dependencies를 포착하기 위한 트랜스포머의 장점을 결합하였다. FCT는 medical imaging에서 첫 fully convolutional transformer이며 long-range semantic을 학습하기 위한 1단계와 계층적으로 전역정보를 학습하기 위한 2단계로 나누어 진다. FCT는 사전훈련없이도 우월한 성능을 보였다.

 

The Fully Convolutional Transformer

이미지 데이터 X와 그에 해당하는 semantic/binary segmentation maps Y로 이루어진 데이터셋 {X,Y}를 가정한다. 각각의 이미지는 xiRH×W×C, C={3,...,N}이고 yiRH×W×K, K{1,...,D} 이다. FCT의 입력으로는 3D이미지의 조각들인 2D 패치이다. FCT는 기존 방식들과 다르게 over-lapping 패치로부터 특성을 추출하고 임베딩 후 Multi-head attention을 수행한다. 이후 Wide-Focus 모듈을 이용하여 이미지 내 세부 정보를 추출하여 projection한다.

FCT Layer

FCT는 가장 먼저 LayerNorm-Conv-Conv-MaxPool 의 순으로 시행된다(Convolutional Attention블록이 수행되기 전). Conv는 연속된 3x3 kernel size로 이미지의 특징을 추출하며 Gelu 활성화 함수가 뒤따른다. 다른 proposed block들과의 첫번째 차이점은 Convolutional Attention의 적용이다.

 

MaxPool이후에는 Depthwise-Convolution으로 토큰맵을 생성하는데 3x3 kernel size, s×s stride, valid padding으로 입력과 출력의 크기를 고정시키고 다시 LayerNorm이 뒤따른다(그림에는 존재하지 않지만 코드로 확인해보면 Convolutional Projection, MHSA에서 중간중간 계속 LayerNorm이 존재함).

 

MHSA에서 기존 블록들과 또 다른 차이점이 나오는데 우리는 Linear projection이 아닌 Depthwise-Convolution을 사용하였다(위에랑 다른 Depthwise임). 이는 계산비용을 줄이면서 더 나은 spatial context를 추출할 수 있다.

 

마지막으로 Wide-Focus로 불리는 모듈은 medical image 속 fine-grained information을 추출하기 위한 모듈이다. 그림과 같이 총 3갈래의 conv layer로 나눠진 후 합쳐지는데 이는 2개의 dilated conv(dilation=2, 3)와 일반 conv의 출력을 합친 후 다시 추가적인 conv layer를 통해 특징을 병합한다.

 

Encoder

인코더는 총 4개의 FCT layer로 구성되어 있으며 각기 다른 크기에서 각 클래스들의 ROI feature를 획득하기 위해 피라미드 스타일의 multiple 이미지 input을 목표로 하지만 이같은 동작이 없어도 뛰어난 성능을 보였다.

 

Decoder

디코더는 bottleneck의 출력을 입력받아 binary/segmentation maps을 resample하도록 학습된다. 인코더와 똑같은 구조를 가지고 있으며 encoder의 출력과 skip-connection으로 연결되어 더해진다. (28x28)크기가 되는 가장 낮은 구간에서는 supervision을 사용하지 않았는데, 이는 ROI가 segmentation되기 너무 작고 모델의 성능을 저하시키기 때문이다.

 

Implementation Details

loss function : equally weighted combination of cross-entropy and dice loss

optimizer : Adam

learning rate : 1e-3

scheduler : plateau (thorugh monitoring the validation loss)

epochs : warm up 50 + training 250

data augmentation

- random rotation(0 ~ 360)

- zoom range (max 0.2)

- shear range (max 0.1)

- horizontal/vertical shift (max 0.3)

- horizontal/vertical flip

모든 모델은 사전훈련 가중치를 사용하지 않고 random initialzed 되었다.

 

자세한 실험결과는 논문을 참고하길 바람