Self-Supervised Learning of Pretext-Invariant Representations (PIRL)
05 Apr 2020 | SSL Contrastive--Learning FacebookIshan Misra, Laurens van der Maaten
(Submitted on 4 Dec 2019)
arXiv:1912.01991
내가 아는 현재 기준 SOTA인 SimCLR이 나오기 전의 SOTA인 PIRL 관련 논문이다. 여기서 PIRL은 진주인 pearl과 동일하게 발음하라고 논문에 친절히 적혀있다 ^^ㅋㅋ
지금까지 나온 pretext task들을 생각해보자. Rotation, Jigsaw, Affine 등 거의 대부분 원본 이미지를 변형 시키고 해당 변형의 속성을 찾는 것과 연관되어 있다. 즉, Rotation 같은 경우에는 이미지를 돌려놓고 몇 도가 돌아갔는지, rotation transformation의 속성을 찾아내는 task이다. 이 말인 즉슨, 현재까지의 pretext task들은 이미지 representation이 transformation과 covariant하도록 즉, 공변성을 가지도록 장려하고 있다.
하지만 여기서 저자들은 아주 당연하지만 그 누구도 쉽게 생각해내지 못한 점을 발견한다. 이는 원래 representation이라 함은 transformation에 invariant해야 한다는 것이다. Transformation은 우리 눈에 보이는 이미지들을 바꿀 뿐이지 그 본질적인 visual semantic 까지는 바꾸지 못하기 때문이다. 그래서 이들은 Pretext-Invariant Representation Learning (PIRL) 이란 획기적인 방법을 제시했다.
Figure 1: Pretext-Invariant Representation Learning (PIRL).
PIRL과 기존의 pretext learning과의 차이는 Figure 1.에서 볼 수 있다. 이들의 방법은 어느 pretext task에도 적용이 가능하다. 논문에서는 Jigsaw puzzle 문제를 사용하기로 했다. (Jigsaw puzzle 문제는 여기에 정리되어 있다.)
메인 아이디어는 다음과 같다. 원래 이미지와 transformation을 적용한 후의 이미지는 비슷해야하고 다른 이미지와는 또 달라야한다.
Loss function을 보자. 두 개의 이미지 representation들의 유사성을 측정하기 위해서 cosine similarity \(s(·,·)\)을 사용한다. Cosine similarity는 전에도 사용한 적이 있는데 이는 두 이미지가 비슷할수록 1에 가까워지고 다를수록 -1에 가까워진다. 이제 이 score를 noise contrastive estimator (NCE)에 적용하려한다.
우선 닮지않은 이미지와의 비교를 위해서 데이터셋 내의 현재 이미지와 다른 이미지들은 모두 negatives로 간주한다. 그러면 Positive sample (\(I, I^t\))는 \(N\)개의 negative 샘플들을 가진다. 이 negative 샘플들은 \(I\)가 아닌 다른 \(I'\)들을 변형시켜 얻어낸 feature들이다. Score을 사용한 NCE 모델은 다음과 같다:
분모의 오른쪽 항은 \(I\)가 아닌 \(I'\)의 representation과 transformed image \(I\)의 representaion과의 비교를 의미한다. 전체적으로 보면 수많은 데이터들 사이에 \(v_I\)와 \(v_{I^t}\)가 비슷한 확률을 의미한다. 이 식을 \(h(v_I, v_{I^t})\)라 정의한다.
실제로는 convolutional feature인 \(v\)를 바로 위 식에 적용하지는 않고 score를 계산하기 전 head를 통과한 feature들을 score안에 넣는다. 원본 이미지 \(I\)에 적용하는 헤드는 \(f(·)\)로 이를 통해 feature \(v_I\)가 계산되고 transform된 이미지 \(I^t\)에는 헤드 \(g(·)\)를 적용하여 feature \(v_{I^t}\)를 얻어낸다. 그림으로 정리하면 아래와 같다.
위에서 얻은 features를 가지고 NCE는 아래 loss를 최소화한다:
이 loss는 이미지 \(I\)의 representation과 transformed counterpart \(I^t\)의 representation이 비슷해지도록 반면에 다른 이미지들인 \(I'\)와는 달라지도록 장려한다. 위 식을 이해하려면 -log 함수의 그래프를 생각해보자. \(y = -log(x)\) 함수에서는 x가 커질수록 y의 값이 작아진다. 때문에 위의 loss를 줄이기 위해서는 log 안의 두 식이 커져야한다.
- \(h(f(v_I), g(v_{I^t}))\)가 커져야 한다 : h()는 cosine similarity score이므로 이 둘의 값이 커져야한다는 의미는 둘이 비슷해져야한다.
- \(1 - h(g(v_{I^t}), f(v_{I'}))\)가 커져야 한다 : 즉, \(h(g(v_{I^t}), f(v_{I'}))\)는 작아져야 한다. 고로 둘의 cosine similarity score가 작아져야 하므로 둘의 representation은 달라져야 한다.
더 좋은 representation을 얻기 위해서는 최대한 많은 negative들과 비교를 하며 학습을 하는 것이 좋다. 이 때 mini-batch SGD optimizer는 batch size를 크게 늘리지않고는 많은 negatives를 얻을 수 없다. 그러나 한번에 너무 큰 batch를 로딩하는 것은 자원의 한계에 다랄 수 있다. 저자들은 이 문제를 해결하기 위해 “cached” features가 담긴 memory bank를 사용하여 더 많은 negatives 이미지들을 이용했다. 처음 cache를 위해 데이터셋의 모든 원본 이미지들의 representation들을 memory bank에 cache 한다.
Memory bank \(M\)은 데이터셋 \(D\)안의 모든 이미지들에 대한 feature representation들을 저장하고 있다. 현재 에포크에 대한 representation \(m_I\)는 이전 에포크에서 계산된 feature representation들인 \(f(v_I)\)들의 exponential moving average의 결과이다. Exponential moving average는 최근의 값에 높은 가중치를 주지만 오래된 과거라도 비록 낮은 영향력이지만 가중치를 두어 함께 고려하도록 한 방법이다. 또한 여기서 중요한 것은 memory bank 안의 모든 representation들은 original image들로 부터 얻은 것이지 transformed image로부터 계산된 것이 아님을 알아야한다.
앞선 loss 식에서 우리는 아직 \(I\)와 \(I'\)의 비교를 하지 않았다. 이제 이 두 가지 비교를 고려하기 위해 저자들은 2개의 NCE loss function의 convex combination을 사용했다. 그리고 앞서 loss에서 사용했던 \(f(v_I)\)와 \(f(v_{I}')\) 는 memory bank representaion에 따라 \(m_I\)와 \(m_I'\)로 표현된다.
Convex combination이란?
\(\alpha_1 x_1 + \alpha_2 x_2 + ... + \alpha_n x_n\) 식에서 \(\sum_i^n{\alpha_i} = 1\) 이고 모든 \(i\)에 대해서 \(0 \le \alpha_i \le 1\) 인 combination을 말한다.
위의 식에서 \(m_I\)는 negatives도 포함한 전체적인 memory bank의 representation이라고 보면 된다. 우선 식의 첫번째 항을 보면 앞서 정의했던 eq (4).와 동일한 의미를 가진다. 두번째 항을 보면 이는 두 가지 일을 한다:
(1) representation \(f(v_I)\)가 memory representation \(m_I\)와 비슷해지도록 학습하고;
(2) representation \(f(v_I)\)와 \(f(v_I')\)가 달라지도록 학습한다.
즉, 앞의 항은 pretext invariant representation을 배우도록 하고 뒤의 항의 어느정도 covariant representation 을 학습하도록 한다. 여기서 만약 \(\lambda\)가 0이 되면 뒤의 항만 남게 되어 invariant representation을 배우지 않게 되는 NPID task가 되어버린다. 실험을 통해 저자들은 \(\lambda\)값이 0.5 일 때 가장 성능이 좋았다고 한다.
위의 loss를 그림으로 나타내면 위와 같다. 논문에서는 encoder network 로 ResNet-50 을 사용하고 있다. 또한 downstream task에 적용하기 위해 뒤의 head 부분을 떼고 인코더를 통해 나온 representation들을 사용한다.
Figure 2: ImageNet classification with linear models.
이미 SimCLR 이라는 SOTA가 존재하므로 실험내용은 따로 설명하지 않고 위의 표로 대체한다.
다 읽고..
현재는 SOTA가 아니긴하지만 그래도 아직까지도 충분히 놀라운 방법이다. 너무 당연하게 생각하면서도 누구도 실험하지 않았던 부분을 이용하여 당시 최고의 성능을 보였다니 놀랍다. 당연한 것을 당연하게 생각하지말고 조금 더 연구해봐야겠다는 생각이 든다. :)
Comments