MLP-Mixer: An all-MLP Architecture for Vision

|


Ilya Tolstikhin∗, Neil Houlsby∗, Alexander Kolesnikov∗, Lucas Beyer∗, Xiaohua Zhai, Thomas Unterthiner, Jessica Yung, Daniel Keysers, Jakob Uszkoreit, Mario Lucic, Alexey Dosovitskiy
[Google Brain]
[Submitted on 4 May 2021]
arXiv: https://arxiv.org/abs/2105.01601

Abstract

Convolution과 attention 메커니즘이 둘다 좋은 성능을 발휘하지만 반드시 필요한 것은 아니다.
이 논문에서는 MLP-Mixer라는 새로운 아키텍처를 이용하여 이 둘 없이도 비슷한 성능을 보이겠다.
MLP-Mixer는 token mixing과 channel mixing 두 가지 multi-layer perceptron을 이용한다.


1. Introduction & 2. Mixer Architecture

Mixer relies only on basic matrix multiplication routines, changes to data layout (reshapes and transpositions), and scalar non-linearities.

현대 deep vision 아키텍처들은 다음과 같이 feature들을 믹스하는 layer들로 구성되어 있음.

(1) at a given spatial location

(2) between different spatial locations

or both at once.

CNN의 경우,

  • pooling layer를 통해 (2)를 수행.
  • 1x1 pooling layer 로 (1) 수행
  • NxN conv (N>1) 로 (1), (2) 수행

Vision Transformer나 다른 attention-based 아키텍처의 경우,

  • self-attention layer가 (1), (2) 모두 수행
  • MLP-block이 (1)을 수행.

→ Mixer 아키텍처는 이 두 가지 연산을 완전히 분리시켜서 수행한다. Untitled 1

Mixer는 S개의 이미지 패치들을 input sequence로 받고 각각의 패치들은 hidden dim C를 가짐.

Untitled 2

원래 이미지가 (H, W) 사이즈였다면 각각의 패치크기는 (P, P)이며 패치 개수는 $S = HW/P^2$ 가 됨.


Mixer layer equation

Untitled 3

  • Token Mixing : 다른 spatial location (tokens) 간 communication을 함. 각 채널에 대해 독립적으로 수행되며 테이블의 각 column을 input으로 받음.

    Untitled 4

  • Channel Mixing : 다른 채널 간 communication을 함. 각의 토큰에 대해 독립적으로 수행되며 input으로 테이블의 각 row 를 받음.

    Untitled 5

각각의 MLP block들은 2개의 fully-connected layer와 non-linearity (GELU)가 포함된다.

  • $D_S$, $D_C$는 token, channel mixing 각각의 hidden width.
  • $D_S$는 input 패치들의 개수와 독립적으로 선택이 가능.

    → ViT는 패치들의 개수와 관계가 있기 때문에 연산량이 quadratic인 반면 얘는 input 패치들의 개수에 linear하다.

  • $D_C$도 패치 사이즈와 관계없이 독립적으로 선택 가능. (as for a typical CNN)

Tying Parameter (weight sharing)

Channel-mixnig MLP의 파라미터들을 tying하는 것은 convolution의 특징인 positional invariance를 제공하는 것 처럼 natural하다고 볼 수 있다 (= conv filter들이 sliding 하면서 같은 weight를 계속 유지하는 것을 의미).

그러나 token-mixing에서의 tying은 much less common하다. (얘는 밑에서 depth-wise conv 다룰 때 설명할 예정)

→ 결국 sharing을 하면서 상당한 메모리를 아낄 수 있다.

  • Channel mixing : 1 x 1 convolution (한 row를 받아 (하나의 토큰) 다시 하나의 row로 내뱉음)
  • Token mixing : full receptive field를 가진 single-channel depth-wise convolution

MLP-Mixer-1

Mixer는 MLP layer 외에도 skip-connection이나 layer normalization을 사용함.

그리고 ViT와 달리 position embedding을 사용하지 않음. 아래 그림에서 순서대로보면은 그냥 패치 순서들을 알 수 있기 때문.

Untitled 7

마지막에는 GAP layer에 넣고 linear classifier로 classification을 수행.


3. Experiments

관심요소

  1. downstream task에서의 성능
  2. pre-training에서의 total computational cost
  3. inference time에서의 처리량.

Downstream task

Untitled 8

Pre-training data

Untitled 9

Pre-traning details

Untitled 11

Fine-tuning details

패치 사이즈는 고정시키기 때문에 input 이미지 크기가 커지면 패치의 개수가 커지게 됨 ($S -> S’)$

weight를 fine-tuning에 사용하기 위해서 hidden layer width도 패치 개수에 비례하여 증가시키고

($D_S -> D_S’$) 더 커진 weight matrix $W_2’∈ R^{D_S’×S’}$ 을 초기화할 때 기존의 W2를 diagonal에 붙여버림. (??)

Untitled 12

Untitled 13

Metrics

  1. Computational cost
    1. Total pre-training time on TPU-v3 accelerators
    2. Throughput in images/sec/core on TPU-v3

여러개의 배치 사이즈로 실험. {32, 64, …, 8192}

  1. Model Quality

: we focus on top-1 downstream accuracy after fine-tuning

Models

Untitled 14

Figure&Table annotation

  • MLP Mixer
  • CNN based
  • Attention based

HaloNet : 3x3 conv대신에 local self-attention layer를 사용한 ResNet-like structure를 가진 attention based model임. 이 논문에서는 conv랑 attention이 섞인 hybrid 버전인 “HaloNet-H4 (base 128, Conv-12)” model을 비교함.

Big Transfer (BiT) models are ResNets optimized for transfer learning, pre-trained on ImageNet-21k or JFT-300M.

NFNets are normalizer-free ResNets with several optimizations for ImageNet classification. We consider the NFNet-F4+ model variant.


MPL and ALIGN for EfficientNet architectures.

MPL is pre-trained at very large-scale on JFT-300M images, using meta-pseudo labeling from ImageNet instead of the original labels. We compare to the EfficientNet-B6-Widemodel variant.

ALIGN pre-train image encoder and language encoder on noisy web image text pairs in a contrastive way. We compare to their best EfficientNet-L2 image encoder.


3.1 Main results

Untitled 15

Untitled 16 Untitled 17 Untitled 18

→ training time에 대해 순차적으로 best 들을 점선으로 그어놨음.

3.2 The role of the model scale

더 작은 믹서 모델도 실험해보자!

2가지 서로 다른 방식으로 모델을 scale할 수 있음.

  1. pre-training할 때 model size를 늘리자 (# layers, hidden dimension, MLP widths)
  2. fine-tuning할 때 input image resolution을 키우자.

1의 경우 pre-training compute와 test-time throughput 둘 다에 영향을 미침.

2의 경우는 오직 throughput에만 영향을 미침.

(따로 언급하지 않으면 fine-tuning resolution은 224를 사용)

Untitled 19

  • IN을 scratch로 학습시킬 때, Mixer-B/16 모델이 top-1 acc 76.44%를 달성. (ViT-B/16보다 3프로 못함)
  • 논문에는 없지만 얘들이 주장하는 바로는 training curve가 두 모델 모두 매우 비슷한 value를 가진 training loss를 보였다 함. 즉, Mixer-B/16이 ViT-B/16보다 오버핏이 더 일어났다는 얘기.

  • pre-training

Untitled 20

Untitled 21


3.3 The role of the pre-training dataset size

pre-traning에 사용되는 데이터가 더 큰 데이터셋일수록 mixer의 성능은 더 높아짐을 바로 위에서 확인.

이를 더 자세히 보고자, pre-training에 사용되는 이미지의 양을 subset으로 조절하여 실험.

pretrain dataset은 JFT-300M을 사용했고 3%, 10%, 30%, 100% 서브셋을 활용.

모든 모델의 전체 step의 수를 동일하게 유지하기 위해 각 서브셋에 대해 233, 70, 23, 7 에폭씩 학습시킴.

모델은 Mixer-B/32, Mixer-L/32, and Mixer-L/16 models 가 사용됨.

linear 5-shot top-1 accuracy on ImageNet 실험. (early stopping 사용 → 근데 그러면 step수가 의미가 있나)

Untitled 22

JFT300M의 아주 작은 서브셋에 대해 학습할때는 모든 mixer모델들이 오버피팅됨. 대신에 더 큰 데이터셋을 사용할 때 mixer는 ViT보다도 더 효과를 봤음 (성능 향상 그래프의 기울기가 더 가파름)

효과를 본 이유에 대해서 저자들은 대충 끼워맞추기 설명을 하긴 함. (?????)

Untitled 23

3.4 Visualization

CNN의 첫번째 layer는 이미지의 local region들에 대한 Gabor-like detector를 학습하는 것은 자명하다. 반면 Mixer는 token-mixing을 통해 global한 정보 교환이 일어나는데 이 또한 어떤 정보를 포함하는지 궁금하여 실험을 진행하였다.

그림은 JFT-300M에 학습된 MIxer의 첫번째 몇개의 token-mixing MLP들의 weight를 보인것이다. Token-mixing은 다른 spatial location사이의 communication을 허락하기 때문에 학습된 feature들 중 일부는 전체 이미지에 대해 학습했을 수도 있고 아니면 작은 지역에 대해서만 학습이 되었을 수 있다.

Untitled 24

아마 위처럼 적힌 이유는 뒤의 레이어일수록 서로 많은 token들에 대해 communication을 했기 때문에 더 많은 정보가 섞이게 되어 명확하게 식별이 불가한 구조를 가지는 것 같다.

Untitled 25

일단 나는 아래 그림처럼 하나의 유닛에 대해 들어오는 196개의 weight들을 visualization했다고 생각하는데 저렇게 정사각형으로 뭉친 기준은 모르겠음. 순서대로 뭉친건가.. 글고 이렇게 visualization한게 맞는지도 잘 모르겠음 ㅎ 근데 또 fc dim은 196개가 맞나..?

Untitled 26


5. Conclusion 및 내 생각

Convolution이나 attention 없이 MLP만 사용하여 SOTA는 아니더라도 그에 견줄만한 성능을 보이는 모델을 만들어냈으며 accuracy와 computational resources의 trade-off 관점에서 굉장히 좋은 성능을 보임.

그러나 결과적으로 좋은 성능을 보이긴 했지만 왜 이렇게 잘 굴러가는지에 대해 명확하게 설명되어 있지 않음.. 왜 이 실험을 진행했는지가 없어서 논문자제가 조금 결과론적이지 않나 생각하게 됨.

주변에 의견을 물어본 결과, 패치단위로 잘랐기 때문에 잘되었을 것이라는 말도 있었음.

Untitled 27

Comments