DINO v1 에서 사용한 방법 중 BYOL이라는 모델에 대한 논문 리뷰이다.
Abstract
Boostrap Your Own Latent은 자기주도 학습을 위한 방법으로 online, target이라는 2개의 네트워크가 상호작용하며 학습한다. 하나의 이미지를 서로 다르게 증강시킨 이미지에 대한 target의 출력을 online이 예측하도록 학습된다. 동시에 target 네트워크는 online 네트워크의 slow-moving average 값을 이용하여 학습된다.
Introduction
contrastive learning에서는 이미지 증강 기법에 따라 모델의 성능이 좌우되지만 BYOL은 그렇지 않다. 또한 target의 출력을 예측하도록 훈련된다면 모든 이미지에 대해 똑같은 출력을 내는 등 모델이 수렴할 수 없을 수도 있는데 저자들은 여러 경험을 통해 BYOL은 그렇지 않다는 것을 발견했다. 저자들은 Resnet 구조를 이용하여 모델을 설계했고 ImageNet 데이터셋을 이용해 모델을 학습시켯다.
Method
대부분의 contrastive method들은 하나의 이미지에 대해 여러 증강들과 비교하고, negative pair들과 비교하여 이미지의 표현들을 구별하는 것을 학습한다. 특히 증강 기법으로 cropping이 사용되는데 서로 다른 cropped 이미지들을 이용해 하나의 이미지를 예측하는 과정은 collapse representation으로 이어지기 쉽다. 많은 negative pair들과의 비교는 이미지 표현을 구별하기 어렵게 만드는데 우리는 negative pair가 collapse를 방지하기 위해 필수적인지 확인해 보았다.
BYOL은 online, target 네트워크로 구성되어 이미지 표현 \(y\theta\)를 학습하도록 훈련된다. online은 encoder \(f\theta\), projector \(g\theta\), predictor \(q\theta\) 로 이루어져 있다. target은 online을 학습시키기 위해 regression target을 제공한다. target을 학습시키기 위한 파라미터 \(\xi\)는 online weight \(\theta\)의 지수 이동 평균으로 (1)과 같다.
각 네트워크의 파라미터 \(\theta, \xi\)에 대한 손실함수는 (2)와 같다.
간단히 설명하면 normalized predictions \(q\theta(z\theta)\)과 target projections \(sg(z^\prime_\xi)\)간에 mean squared error을 구하는 공식이라고 한다.
BYOL은 negative pairs와 같은 collapse를 방지하기 위해 명백하게 정의된 항을 사용하지 않기 때문에 위 loss를 최소화 하는 방향으로 수렴할 것으로 생각된다. 하지만 target parameter \(\xi\)은 손실함수의 진행방향과 다르기 때문에 \(L_\theta,_\xi\) 은 없다고 가설을 제시한다.
수식들과 함께 더 많은 설명들이 있는데 사실 잘 이해가 되지 않는다. 이전에 리뷰했던 DINO v1 논문이 BYOL을 채택하였는데 해당 논문에서 해석한 바로는 online의 weight를 학습시키며 지수 이동평균을 target에 복사하는 식으로 훈련을 진행한다. 지수이동평균을 사용하지 않고 hard-copy 형식으로 weight를 복사하여도 어느정도의 훈련 효과가 있을것이라 했지만 이는 DINO의 실험결과에서 모델이 수렴하지 못하는 모습을 보여줬다.
Experiment evaluation and results
BYOL은 ImageNet ILSVRC-2012 dataset을 이용하여 자기주도학습의 형태로 사전훈련된 이후 전이학습을 통해 각 task에 맞게 미세 조정 되었다. 자세한 성능 결과는 논문을 참고하길 바라며 필자가 주의 깊게 본 내용만 하단에 담았다.
Batch Size : contrastive method 들은 작은 배치사이즈에서 적은 양의 negative pair을 사용하면 성능저하가 일어나지만 BYOL은 negative pair를 사용하지 않기 때문에 이에 대해 robust 하다고 한다.
Image augmentations : SimCLR에서는 color distortion 과 같은 증강 기법을 제거하면 모델의 성능에 매우 민감하게 반응하듯이 contrastive method들은 이미지 증강 기법의 사용이 중요하다. 또한 같은 이미지로부터 cropping은 color histogram을 공유하게 되는데 이는 이미지들마다 모두 다르기 때문에 contrastive method가 random crop만에 의존한다면 이는 color histogram에만 의존하게 되는것이다. 대신에 BYOL은 target representation을 online network로 주입시키기 때문에 예측 성능을 개선한다고 믿는다.
SimCLR에 비해 배치사이즈의 변화에 성능 저하가 거의 일어나지 않지만 256 -> 128 에서의 저하는 인코더의 BatchNorm 때문이라고 한다.
또한 BYOL은 color distortion(Remove color) 기법에 의한 영향이 적고 (BYOL : -9.1%, SimCLR : -22.2%) Crop만 사용하였을때도 그리 심각한 성능저하는 발생하지 않는다고 한다.
Bootstrapping : target network는 online network의 지수이동평균에 의한 가중치를 이용하기 때문에 delayed represent와 안정적인 가중치를 지닌다. target decay rate가 1일 경우 초기화된 가중치에서 업데이트가 일어나지 않고 상수의 형태로 유지된다. 반대로 0일 경우에는 online network와 똑같은 가중치로 매번 갱신된다.
0일 경우 학습이 전혀 일어나지 않으며 1일 경우 매우 낮은 성능을 보인다. 0.9 이상으로 300 epoch 학습했을 때 모두 68.4% 이상의 정확도를 보였다.