A Simple Framework for Contrastive Learning of Visual Representations (SimCLR)
17 Mar 2020 | SSL Contrastive--Learning GoogleTing Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey Hinton (Google research)
(Submitted on 13 Feb 2020)
arXiv:2002.05709
SimCLR:a simple framework for contrastive learning 을 사용한 논문이다.
1. Introduction
인간의 지도 없이 visual representation을 배우는 것은 long-standing problem이다. 대부분의 주류 방식들은 두 가지 중 하나이다: generative 또는 discriminative. Generative 방식은 input space 속 pixel을 생성하거나 모델링하는 것을 배운다. Pixel-level generation은 비싼 계산을 필요로하며 representation learning에 필수적이진 않다. Disciminative 방식은 objective function을 사용하여 representation을 배운다. 또한 input과 label들은 unlabeled dataset에서 나온다. Latent space 속 contrastive learning을 베이스로 하는 discriminative 방식 최근들어 전망이 좋고 SOTA 결과를 보여줬다.
Figure 1. ImageNet top-1 accuracy of linear classifiers
이 논문에서 이들은 SimCLR이라고 불리는 visual presentation의 contrastive learning을 사용한 간단한 프레임워크를 선보인다. Figure 1.에서 볼 수 있듯이 SimCLR은 이전의 성능을 능가할 뿐만 아니라 더욱 심플한 구조를 가진다.
2. Method
2.1. The Contrastive Learning Framework
SimCLR은 latent space 속 contrastive loss를 통해 같은 데이터 이미지의 다른 augmentation이 적용된 이미지들 사이의 agreement를 최대화함으로써 representation을 배운다. (말을 너무 어렵게 설명해놔서 처음에는 무슨 뜻인지 몰랐다. 뒤에 설명을 보면서 차차 이해를 하자.)
Figure 2. A simple framework for contrastive learning of visual representations.
Figure 2.를 보면 프레임워크는 4가지의 주요 부분으로 구성되어 있다.
-
같은 이미지 샘플을 다른 두 가지 버전으로 변형시키는 stochastic data augmentation module이 존재한다. 이 때 두 가지 이미지는 \(\tilde{x}_i\)와 \(\tilde{x}_j\)로 표현한다. 그리고 이 둘을 positive pair라고 생각한다. 이 논문에서는 총 3가지의 간단한 augmentation을 적용하는데 random cropping, random color distortions, random Gaussian blur이다.
-
Neural network base encoder \(f(·)\)이 존재한다. 이는 augmented data example들로부터 representation vector를 뽑아낸다. 이들 framework는 제약조건 없이 아무 architecture를 사용해도 되나 해당 논문에서는 ResNet을 사용하였다. 즉, \(h_i = f(\tilde{x}_i) = ResNet(\tilde{x}_i)\), 여기서 \(h_i ∈ ℝ^d\)가 average pooling layer 이후에 나오는 output이다.
-
작은 neural network projection head \(g(·)\)이 존재한다. 이 네트워크는 representation들을 contrastive loss가 적용된 space로 매핑한다. \(z_i = g(h_i) = W^{(2)}𝝈(W^{(1)}h_i)\)를 얻기위해 1개의 은닉층이 있는 MLP를 사용한다. 이 때 \(𝝈\)는 ReLU nonlinearity이다.
-
Contrastive prediction task를 위해 정의된 contrastive loss function이 존재한다. Positive pair인 \(\tilde{x}_i\)와 \(\tilde{x}_j\)를 포함한 set \(\{\tilde{x}_k\}\) 가 주어졌을 때, contrastive prediction task는 주어진 \(\tilde{x}_i\)에 대하여 \(\{\tilde{x}_k\}_{k≠i}\) 에서 \(\tilde{x}_j\)를 식별하는 것을 목표로 한다.
이들은 랜덤하게 minibatch를 \(N\)개의 examples로 샘플했다. 그리고 미니배치로부터 파생된 augmented examples의 쌍에 대해 contrastive prediction task를 정의했다 고로 \(2N\) data points가 생성됐다. (대략 2N개 이미지에 대해서 학습을 했다는 뜻인 듯)
대신 따로 negative example에 대해서 샘플링을 하진 않았다. 대신에 주어진 positive 쌍이 있을 때 미니배치 내 나머지 \(2(N-1)\)개의 augmented examples를 negative example로 삼았다. \(sim(u,v) = u^𝗧v/\lvert\lvert u \rvert\rvert \lvert\lvert v \rvert\rvert\)은 두 벡터 \(u\)와 \(v\) 사이의 cosine similarity를 의미한다. 그러면 positive pair \((i,j)\)의 loss function은 아래와 같이 정의된다.
이 때 \(𝟙_{[k=1]} ∈ \{0,1\}\)은 \(k ≠ i\) 인 경우에만 1을 반환하는 indicator function이다. 또한 \(τ\)는 temperature parameter를 나타낸다. Loss는 모든 positive paris에서 계산이 되며 이 loss를 NT-Xent(the normalized temperature-scaled cross entropy loss)라고 부르기로 한다.
위 알고리즘은 제시된 방식을 요약한 것이다.
2.2. Training with Large Batch Size
저자들은 memory bank를 사용하지 않았고 대신 트레이닝 배치 사이즈 \(N\)을 256에서부터 8,192까지 다양하게 적용을 시켰다. 8,192의 배치 사이즈는 positive pair마다 16,382개의 negative examples를 제공한다. 그러나 큰 배치 사이즈 학습은 linear learning rate scaling을 사용하는 표준 Momentum/SGD를 사용할 때 불안정할 수 가 있다. 이를 안정화 시키기 위해 이들은 LARS optimizer를 적용했다.
Global BN. Data parallelism을 사용하는 분산 학습에서는 BN mean과 variance가 디바이스마다 로컬로 집계가 된다. 논문의 constrastive learning에서 같은 디바이스 속에서 positive pairs가 계산될 때마다 모델은 representation을 개선시키지 않고 prediction 정확도만 높이기 위해서 로컬 누출 정보를 이용할 수도 있다. 이는 곧 trivial 하게 될 우려가 있기 때문에 저자들은 학습 동안에 BN mean과 variance를 전체 디바이스를 아울러 집계를 하도록 하여 이 문제를 피해갔다.
2.3. Evaluation Protocol
Dataset and Metircs. 논문의 대부분 실험은 ImageNet ILSVRC-2012 데이터셋에 행해졌고 몇몇은 CIFAR-10에도 적용이 됐다. 학습된 representation을 평가하기 위해 널리 사용되고 있는 linear evaluation protocol을 사용했다. 이 프로토콜은 하나의 linear classifier를 base network의 꼭대기에 학습시키며 여기서 나온 test accuracy가 representation의 질을 평가하는 proxy로 사용이 된다.
3. Data Augmentation for Contrastive Representation Learning
Data augmentation defines predictive tasks. Data augmentation이 지도나 비지도학습에 많이 사용은 됐지만 contrastive prediction task를 정의하는 체계적인 방식으로는 이용된 적이 없다. 현재까지 존재하던 논문들에서 architecture를 바꿈으로써 contrastive prediction task를 정의했다.
예를 들어, Hjelm et al. (2018)나 Bachman et al. (2019)에서는 network architecutre안의 receptive field 제한을 통해서 global-to-local view prediction을 달성했다. 반면에 Oordet al. (2018)나 Hénaff et al. (2019)에서는 fixed iamge splitting procedure와 context aggregation network를 통해 neighboring view prediction을 성취했다. (사실 다 안 읽어봐서 모르겠는데 어쨌거나 그냥 network architecutre를 바꿔 사용했다는 뜻인듯..) 아래 Figure 3.에서 위의 두 가지를 볼 수 있다.
Figure 3. Solid rectangles are images, dashed rectangles are random crops.
위의 논문들은 복잡한 과정을 통해 이뤄낸 것을 이 논문의 저자들은 간단한 random cropping(with resizing)을 통해서 위의 두 가지 작업을 포함하는 이미지를 생성하여 사용할 수 있었다.
3.1. Composition of data augmentation operations is crucial for learning good representations
Figure 4. Illustrations of the studied data augmentation operators.
Figure 4.는 이 논문에서 연구된 augmentation을 보여준다. 각각의 augmentation의 효과와 augmentation composition의 중요성을 이해하기 위해 augmentation을 적용했을 때의 프레임워크 성능을 조사했다. ImageNet 이미지는 사이즈가 다 다르기 때문에 저자들은 crop과 resize를 항상 적용했다. 대신 이로 인해서 cropping이 없는 상태에서의 augmentation을 연구하는 건 어려웠기 때문에 이런 이중효과를 제거하기 위해서 asymmetric data transformation setting을 고려하였다.
구체적으로 말하면, 처음에 랜덤하게 이미지를 crop하고 다 같은 크기로 resize를 한다. 그리고 나서 타겟 augmentation(s)을 Figure 2. 속 framework의 오직 단 하나의 branch에만 적용을 시킨다. 그리고 나머지 하나는 그냥 identity로 냅둔다 즉, \(t(x_i) = x_i\). 이러한 asymmetric data augmentation은 성능을 해할 수 있다.
Figure 5. Linear evaluation (ImageNet top-1 accuracy) under individual
or composition of data augmentations, applied only to one branch.
Figure 5.는 transformation들을 적용했을 때 linear evaluation 결과를 보여준다. Contrastive task에서 positive pairs를 완벽하게 찾아내는 모델이더라도 single transformation은 좋은 representation을 만들지 못했다. Composed augmentations의 경우에는 contrastive prediction task가 더 어려워졌지만 representation의 퀄리티는 극적으로 향상됐다.
그리고 하나의 composition이 특출났는데 random cropping과 random color distortion이 합해졌을 때이다. 저자들은 오직 random cropping만을 적용했을 때 이미지 대부분의 패치들이 비슷한 color distribution을 공유한다는 심각한 문제가 있다고 추론해냈다.
Figure 6. Histograms of pixel intensities (over all channels) for different crops of two different images (i.e. two rows).
Figure 6.는 color histogram만으로도 이미지들을 구별할 수 있다는 것을 보여준다. Neural nets은 이러한 shortcut을 사용해서 predictive task를 풀지도 모른다. 그렇기 때문에 일반화될 수 있는 feature들을 배우기 위해서 color distortion을 cropping과 함께 쓰는 것이 중요하다.
3.2. Contrastive learning needs stronger data augmentation than supervised learning
Table 1. Top-1 accuracy of unsupervised ResNet-50 using linear evaluation and supervised ResNet-50
Color augmentation의 중요성을 입증하기 위해서 이들은 Table 1.에 보이는 것과 같이 color augmentation의 강도를 조정했다. 더욱 강한 color augmentation일수록 학습된 unsupervised model의 linear evaluation 성능을 향상시켰다. 같은 augmentations의 세트로 supervised model에 적용해봤을 때 더 강한 color augmentation은 개선하지 못했고 오히려 성능을 해하는 것도 알아냈다. 그러므로 그들의 실험은 unsupervised contrastive learning 은 supervised learning보다 더욱 강한 (color) data augmentation으로부터 더 큰 이득을 얻을 수 있다는 것을 보였다.
4. Architectures for Encoder and Head
4.1. Unsupervised contrastive learning benefits (more) from bigger models
Figure 7. Linear evaluation of models with varied depth and width.
놀랍지는 않겠지만 Figure 7.은 depth와 width의 증가가 모두 성능을 개선시키는 걸 보여준다. 모델 사이즈가 커짐에 따라 supervised 모델과의 갭이 줄어드는 것을 확인했다.
4.2. A nonlinear projection head improves the representation quality of the layer before it
Figure 8. Linear evaluation of representations with
different projection heads \(g(·)\)and various dimensions of \(z = g(h)\).
이번에는 Projection head \(g(h)\)를 포함하는 것의 중요성에 대해 연구를 했다. Figure 8.은 head를 위해 3가지의 다른 architecture를 사용했을 때의 linear evaluation을 보여준다.
- identity mapping (= None)
- linear projection
- default nonlinear projection with one additional hidden layer (and ReLU)
Nonlinear projection이 linear projection보다 (+3%) 성능이 좋았고 projection이 없는 것 보다 (>10%) 훨씬 더 좋았다. Projection head가 사용됐을 때, output dimension과 관계없이 유사한 결과가 관찰되었다. 또한 nonlinear projection을 사용하더라도 projection head 이전의 layer, \(h\)는 레이어 이후인 \(z = g(h)\)보다 훨씬 더 좋았다 (>10%). 즉, projection head 이전의 hidden layer가 head 이후 보다 더 좋은 representa-tion임을 의미한다.
이들은 nonlinear projection 전의 representation 을 사용하는 것이 중요한 이유는 contrastive loss에 의해 유도된 정보의 손실 때문이라고 추측한다. 특히, \(z = g(h)\)은 data transformation에 영향을 받지 않도록 학습이 된다. 그러므로 g는 색이나 오브젝트의 방향과 같은 downstream task에 쓸모있을 수도 있는 정보를 지울 수도 있다. nonlinear transformation \(g(·)\)를 이용함으로써 \(h\)에 더 많은 정보가 형성되고 유지될 수 있다.
이 가정을 입증하기 위해서 저자들은 pretraining동안에 적용되는 transformation을 예측하도록 학습하기 위해 \(h\)또는 \(g(h)\)를 사용했다. 여기서 \(g(h) = W^{(2)}σ(W^{(1)}h)\)이고 input과 output의 차원이 동일(2048)하다. Table 3.는 \(g(h)\)가 정보를 잃는 반면에 \(h\)는 적용된 transformation에 대한 훨씬 더 많은 정보를 포함하고 있다는 것을 보여준다.
Table 3. Accuracy of training additional MLPs on different representations to predict the transformation applied.
5. Loss Functions and Batch Size
5.1. Normalized cross entropy loss with adjustable temperature works better than alternatives
저자들은 logistic loss나 margin loss 같은 흔하게 사용되는 contrastive loss function과 NT-Xent loss를 비교해보았다.
Table 2. Negative loss functions and their gradients. All input vectors, i.e. \(u, v^+, v^−\) are \(l_2\) normalized.
Table 2.는 loss function의 input에 대한 gradient 뿐만 아니라 objective function도 보여준다. Gradient를 보면 다음과 같은 것을 관찰할 수 있다.
- Temperature를 따르는 \(l_2\) normalization은 모델이 어려운 negatives로부터 학습하는 것을 도울 수 있다.
- Cross-entropy와는 다르게 다른 objective function들은 그들의 상대적인 난이도에 의해 negatives에 무게를 주지 않는다.
그 결과로 이러한 loss function들에 대해서 semi-hard negative mining을 적용해줘야만 한다: gradient를 모든 loss term에 대해서 계산하는 대신에 하나는 semi-hard negative terms를 이용해서 gradient를 구할 수 있다. (즉, loss margin 내에 있고 거리가 가까운 그러나 positive examples과는 먼)
비교를 공평하게 하기 위해서 모든 loss function에 대해 다 같은 \(l_2\) normalization을 적용했고 hyperparameter들을 조정하여 그들 각각의 최고 결과를 뽑아냈다. 아래 Table 4.에서 (semi-hard) negative mining이 도움을 줬지만 그래도 best는 여전히 이들의 기본 NT-Xent loss보다 못했다.
Table 4. Linear evaluation (top-1) for models trained with different loss functions.
“sh” means using semi-hard negative mining.
그 다음 저자들은 그들의 기본 NT-Xent loss의 \(l_2\) normalization과 temperature \(τ\)의 중요성을 테스트했다. Table 5.는 normalization과 적절한 temperature 스케일링이 없이는 성능이 매우 나쁘다는 것을 보여준다. \(l_2\) normalization 없이는 contrastive task의 정확도 높았지만 그 결과 representation은 linear evaluation 하에서 성능이 좋지 않았다.
Table 5. Linear evaluation for models trained with different choices of \(l_2\) norm and temperature \(τ\) for NT-Xent loss.
5.2. Contrastive learning benefits (more) from larger batch sizes and longer training
Figure 9. Linear evaluation models (ResNet-50) trained with different batch size and epochs.
Each bar is a single run from scratch.
Figure 9.은 모델이 에포크의 수가 다른 경우 학습될 때 batch size의 효과를 보여준다. 더 많은 트레이닝 스텝/에포크를 사용하면 배치가 무작위로 리샘플링되는 경우 다른 배치 크기 사이의 차이가 줄어들거나 사라집니다. Supervised와는 대조적이게 contrastive learning에서는 더 큰 배치 사이즈가 더 많은 negative examples을 만들어 convergence를 촉진시킨다. 즉, 동일한 accuracy에 대해 더 적은 에포크와 스텝이 필요로 한다. 오래 학습시키는 것 또한 더 많은 negative examples을 만들어 결과를 개선시킨다.
6. Comparison with State-of-the-art
다른 논문들과 비슷하게 여기서도 3가지 다른 hidden layer width를 가진 (multipliers of 1×,1×,4×) ResNet-50를 사용했다. 더 빠른 수렴을 위해 여기서는 1000 에포크 동안 학습을 시켰다.
Linear evaluation.
Table 6. ImageNet accuracies of linear classifiers trained on representations
learned with different self-supervised methods.
Table 6.은 linear evaluation setting에서의 결과를 비교합니다. 구체적으로 디자인된 아키텍처를 필요로하는 이전의 방법들에 비해 상당히 더 나은 결과를 얻기 위해 standard network를 사용할 수 있다. 그들의 ResNet-50(4×)로 얻을 수 있는 최고 결과는 supervised pretrained ResNet-50와 일치할 수 있다.
### Semi-supervised learning.
Table 7. ImageNet accuracy of models trained with few labels
저자들은 레이블이 있는 ILSVRC-12 트레이닝 데이터셋의 1% 또는 10%를 class-balanced 방식으로 샘플링했다 (각 클래스마다 약 12.8 또는 128개의 이미지).그 다음 regularization 없이 labeled 데이터에 전체 base network를 fine-tune을 시켰다. Table 7.은 최근의 논문들과의 비교를 보여준다. 역시나 이들의 방식이 1% 나 10% 레이블 방식에서 모두 SOTA 성능을 넘어섰다.
### Transfer learning.
Linear eavaluation(fixed feature extractor)와 fine-tuning 세팅에서 모두 12개의 natural image 데이터셋에 대해 transfer learning 성능을 평가하였다.
Table 8. Comparison of transfer learning performance of our self-supervised approach with supervised baselines
Table 8.은 ResNet-50(4×) 모델의 결과를 보여준다. Fine-tune했을 때, 이들의 self-supervised model은 5개의 데이터셋에 대해 supervised baseline을 훌쩍 넘어섰다. 반면에 supervised baseline은 오직 2가지 데이터셋 (Pets and Flowers)에서만 우월했다. 나머지 5개의 데이터셋은 비슷비슷했다.
다 읽고..
이 논문에서는 그냥 data augmentation을 사용하고 네트워크의 끝에 nonlinear head를 사용하고 loss function을 적용한 것만으로도 색다른 ssl task를 만들어냈다.
Comments