본문 바로가기
Deep Learning (Computer Vision)/Vision Transformer Architecture

논문 톺아보기 및 코드 구현 [ViT-1] - An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale (ICLR 2021)

by 187cm 2023. 9. 23.
반응형

 필자가 이 논문을 처음 봤을 때가 2022.11.14일인데, 이 때 기준 인용 수가 8829이다. 현재 2023.09.22기준 21446회.. 그 당시에도 한달마다 몇백회씩 인용수가 늘어나는 것이 놀라웠는데, 이젠 그 이상으로 유명해진 것 같다. 그리고 미루고 미루다 DeiT, DeiT-III, FlexiViT 등 다양한 ViT를 발표해야할 것만 같아서 미뤄왔던 정리를 해야만 할 것 같다. 

 그 당시에 Image classification 분야에서 가장 좋은 성능을 보여주는 모델은 무엇일까? 라는 질문에서 시작하여, 가장 좋은 성능을 보여주는 transformer 계열의 모델에 대해 궁금증이 생겼고, Computer vision 분야의 기본인 transformer가 적용된 VIT에 대해 정리하는 것을 목표로 하고 읽었었다.

 Vision Transformer NLP 분야에서의 표준 매뉴얼로 자리 잡은 Transformer에 대해서 설명한 후, 이 Transformer 구조를 Computer vision 분야에 가져오기 위해 나왔던 Related work에 대해 설명을 하며 Introduction을 시작하도록 하겠다.


제목 :  An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

학회 : ICLR2021

링크 : https://arxiv.org/abs/2010.11929

저자 : Alexey Dosovitskiy, Lucas Beyer, et al. 

소속 : Google Research, Brain Team

인용 : 21446


0. Abstract

- ViT의 구조는 Original Transformer와 매우 유사하다!

- Unlike the original Transformer, VIT used Norm layer before MHA/MLP layer

- CLS token added to Output of Positional Encoding (BERT와 유사)

-Input image divide 16x16 resolution patches (Hidden size D is 16x16x3)


1-1. Related Work (Transformer) 

- Transformer기존의 NLP task에서 CNN/RNN 구조의 Network가 가지는 여러가지 문제를 해결.

1. 고정된 길이의 입출력 문제를 해결하기 위해 Encoder/Decoder구조의 Sequence to Sequence 등장.

2. 단어의 입출력 길이가 길어질수록 단어들 간의 의존도가 떨어지는 문제를 해결하기 위해 RNN 위에 Attention을 사용.

- NLP task를 해결하기 위한 다양한 방법이 나왔지만, RNN순차적인 구조를 해결하기 위한 해결책은 제시되지 못함.

- Attention is All you need라는 논문에서 소개된 Transformer는 Encoder/Decoder 구조를 유지하며 이러한 문제를 해결하였을 뿐 아니라 성능 또한 더욱 뛰어난 모습을 보여주었기에 의미가 크다.

가장먼저 왼쪽의 Transformer architecture, 중앙이 Multi-Head Attention 구조, 마지막이 Scaled dot product attention 연산으로 MHA 블록 내부에 존재.

-  위에서 왼쪽 사진은 Transformer의 구조. 왼쪽이 Encoder, 오른쪽이 DecoderEncoder/Decoder 구조 (Seq2Seq)

- NLP Task에서 입력 문장이 들어오면 입력 문장을 token 단위로 쪼개 Input Embedding을 진행하고, 토큰마다 Positional Encoding을 통해 단어의 위치를 파악.

- 그 다음 입력 값을 Query Key Value로 복사하여 fully connected layer를 통과. 이 과정에서 중간의 Multi-Head Attention layer를 보면, Linear layer를 통과하여 같은 값의 Query key Value가 달라지게 된다.

- 따라서, Multi-Head Attention layer를 통과하였을 때 서로 다른 위치에서, 여러가지 기준으로 정보를 뽑아오게 되어 다양한 관점에서의 정보를 얻을 수 있다.

- Multi-Head를 가지기 때문에 병렬처리 또한 가능.

Attention Score를 구하는 방법.

- Multi-Head Attention layer 내부의 Scaled Dot-product Attention layer를 통과하여 단어들 사이의 연관성을 알 수 있다. 이 과정을 수식으로 표현하면 위의 수식과 같다.

- QueryKey의 행렬 곱을 Scale 해준 후 Softmax() 함수를 통해 확률 값으로 만들어 준 후 Value 값을 곱하는 과정이 Scaled Dot-Product Attention layer의 수식. Scaled Dot-product Attention layer가 단어들의 연관성을 학습하는 과정은 아래와 같다.

- 실제로 이 과정이 ViT에도 적용 될 경우 단어가 하나의 패치로 바뀌는 것이기에 미리 보고 가면 좋을 것 같다.

CS224U (Natural Language Understanding)에서 설명하는 Transformer는 아래와 같다.

 

Stanford CS224U Natural Language Understanding 강의요약 - XCS224U: NLU I Contextual Word Representations, Part 2: Transforme

이전 강의 요약 - XCS224U: NLU I Intro & Evolution of Natural Language Understanding, Part. 1 I Spring 2023 이전 강의 요약 - XCS224U: NLU I Course Overview, Part. 2 I Spring 2023 이전 강의 요약 - XCS224U: NLU I Contextual Word Representati

187cm.tistory.com


Q. How to calculate Attention score?

예시 1) I am a student 라는 문장과 I 라는 문장과의 Attention Score를 구하는 방법. (1:K) example 

- I am a Student라는 4개의 토큰을 가진 문장에서의 Query를 I, Key를 I am a student라는 문장이라고 할 때, 처음에는 값이 같았으나, Linear layer를 통과함으로써 값이 달라졌다. 따라서 위와 같은 값을 가진다고 하자.

- 우리는 위에서 본 Attention score를 구하는 수식을 바탕으로 Query (I token) 과 Key (I am a student an 4개의 토큰)의 Dot-product연산을 수행하게 된다.

- 그러면 우리의 I 토큰, 즉 Query는 모든 key (자기 자신)에 대해서 확률 값, scalar 값을 알 수 있게 된다. 왜냐하면 우리는 softmax를 통과하니까! 0-1 사이의 값이 나올 것이다. 

- 위에서는 dk가 2인 예시를 사용했다. 이 vector의 크기에 따라 Gradient가 너무 커지거나 작아지는 문제가 발생 따라서 우리는 scaling을 진행.

- 실제로 I am a student는 아니지만 Hello I love you 라는 문장이 각각 Query key로 들어가는 K:K 예시일 경우 우측의 이미지와 같은 결과가 나올 것이다.

- 이어서, 위의 그림과 같이 Query와 Key의 행렬 곱을 한 결과를 확률 값으로 만들어 준 후이 값을 다시 Value(자기 자신)에 곱해줌으로써 자기 자신과 연관성이 있는 단어는 강조되며, 연관성이 없는 단어는 낮은 확률 값이 곱해지므로 단어들의 연관성을 학습하는데 도움이 됩니다.

 - 결국 Query Key Value를 통해 자기 자신과 같은 행렬에 대해서 연관성을 확인 후, Attention Score를 자기 자신에 적용해주기 때문에 Self Attention이라고 불린다.


1-2. Related Work (Computer vision 분야에 적용된 Self-Attention) 

Non-local Network

- ResNet의 저자 Kamming He가 교신저자로 들어간 논문. Video processing에 Attention을 적용했다고 볼 수 있다.

SENet

- Channel Attention 이라고도 불리는 논문이다. 이 SE-ResNet이 현재까지 ResNet중 가장 성능이 좋았다고 기억한다.

On the Relationship between Self-Attention and Convolutional Layers.

- ViT가 가장 영감을 많이 받은 논문 중 하나.

- 아래의 링크에 3가지를 다 정리해두었다. 이거까지 한번에 쓰면 내용이 길어지니 궁금하면 아래에서 보고 오자.

 

Vision 분야에 적용된 Self-Attention 알아보기 [Non-local Neural Networks, SENet, On the Relationship between Self-Atten

1. Non-local Neural Networks 제목 : Non-local Neural Networks 저자 : Xiaolong Wang , Kamming He et al 소속 : Carnegie Mellon University & Facebook AI Research 학회 : CVPR2018 인용 : 9010 (2023.09.24 기준) 링크 : https://arxiv.org/abs/1711.079

187cm.tistory.com


2. Introduce

왼쪽 이미지는 224x224에서 16x16으로 짤렸다는 것을 보여주기 힘들기에, 실제 imagenet dataset을 가지고 시각화.

- ViT(Vision Transformer) Architecture최대한 원본 Transformer와 유사한 모습을 보여준다. ViT 저자는 네트워크의 구조를 변경하였을 때, 하드웨어적으로 이점이 없었기 때문에 최대한 Transformer의 구조와 유사하게 가져갔다고 설명.

- 원본 Transformer와 유사한 구조를 가져가기 위해 이미지를 단어처럼 취급하는 Related Work의 3번째 방식과 비슷한 방법을 사용. (자세한 것은 위의 링크 참조)

++ 사실 필자가 느끼기엔 위의 저자가 2x2를 32x32이미지에 시도가 아닌 16x16을 그대로 224x224에 시도했으면 ViT와 유사한 논문이 나오지 않았을까 생각한다. 그치만 Dosovitskiy가 이 ViT를 JFT-300M 까지 끌어와서 가능하게 한 것이 의의가 크다고 생각한다.

- 가장 큰 차이점이 있다면 32x32 크기의 CIFAR10 Dataset 이미지에서 ImageNet 224x224 크기의 이미지를 사용. 

1. 이에 따라 2x2의 크기의 Patch에서 이미지를 위의 우측의 이미지 처럼 16x16 크기의 Patch로 큰 사이즈의 패치 사용.

2. Linear Projection을 통해 이미지를 1차원으로 flatten.

3. patch에 위치정보를 더해주는 Positional embedding. (VIT는 절대위치를 1D로 더해줌. * 표시는 CLS Token) (2-2 참조)

4. 3의 과정 후, TransformerEncoder로 들어가며, Encoder는 왼쪽의 우측 그림처럼 이루어져 있다.

5. Encoder는 앞에서 봤던 Transformer와 매우 유사하다. 차이가 있다면 Norm layer가 먼저 등장하였는데, 이 이유에 대해서는 뒤쪽에서 설명 (2-4 참조)

6. Multi-Head Attention을 진행하여 패치들 간의 연관성을 학습 후 MLP layer를 거쳐서 MLP의 0번째 위치 정보를 가지고 있는 CLS 토큰을 활용해서 Classification을 진행.

2-1. VIT Linear Projection of Flatten Patches

- 가장 먼저 이미지를 패치 사이즈인 P*P 크기로 쪼개서 1차원으로 펴는 과정

- VIT는 P(Patch resolution) 크기를 16으로 설정하여 224x224크기의 큰 이미지에 대해서도 Self attention이 가능하게 함.

- 따라서 Linear Projection에서 2차원 이미지의 Shape이 다음과 같을 때, [B C H W] 크기의 이미지를 [B N (P*P*C)]로 쪼개 1차원으로 피게 됩니다. (N은 패치의 수, P*P*C는 패치의 크기).

+ Einops 라이브러리를 활용하면 einops.rearrange 라이브러리 사용해서 그대로 구현하면 된다.

++ 추가적으로 논문에서는 1차원으로 Flatten 작업과 동시에 학습을 진행한다고 되어있다. 따라서 Patch로 쪼개고 Flatten 시킨 다음 linear layer를 붙여주는 형식으로 구현하거나, 혹은 Conv layer를 Stride = Filter size = 16으로 해서 잘라도 된다.

- 위의 그림으로 설명하면 다음과 같다. 그냥 Patch 단위로 Split이 아니라 [Split + Fc layer] 혹은 Conv layer가 붙는 다는 것을 명심하자. 

++ [196, 1, (16x16x3=768)] 에서 16x16x3 부분을 RGB로 나누어서 표기했지만, 실제로는 R|G|B 가 아니라 R-G-B-R-G-B .. 의 반복이다.

- 우측 그림은 실제 ViT에서 Appendix에 첨부한 Conv2D Embedding filter의 모습. Vertical, Horizontal한 filter 뿐 아니라, 중앙에 있는 물체를 중심적으로 보는 filter도 존재.

2-2. VIT Positional Embedding

- Patch Positional Encoding을 추가하는 과정.

- 이미지를 Patch 단위로 나누고 Flatten을 해주는 과정에서 1차원 vector로 변환(Flatten)하기 때문에 위치 정보가 사라지게 된다.

- VIT에서는 기존의 CNN을 연산을 Non-convolutional Transformer Encoder로 대체했기 때문에 위치 정보의 손실은 성능 저하로 이어지게 됩니다. 따라서 패치 단위로 위치 정보를 더해주는 과정을 거치게 된다.

- 2D 이미지를 학습하는 Image Classification Task에서 x,y 좌표에 대한 Positional encoding을 해주어야 하지만 VIT의 저자는 1D Positional Encoding을 진행.
- 아래는 VIT 저자가 Appendix에 첨부한 표입니다. Positional encoding을 하지 않았을 때와 했을 때의 성능 차이는 명확하며, 1D Positional Encoding이 2D보다 더 좋은 성능을 보이고 있다.

- 저자는 2D Positional Encoding을 할 경우 224x224 크기의 전체 이미지를 봐야하기 때문에 봐야할 정보가 1D Positional Encoding(14x14)보다 더 많아져서 오히려 성능 저하가 생긴다고 서술하고 있다. (아래 그림 참조)

++ 아무래도 2D는 위치정보는 모델을 더 복잡하게 만들어서 문제가 생기지 않을까? 생각한다.


2-3. VIT Positional Embedding (CLS Token)

 

- CLS token (Special Classification token)BERT (Bidirectional Encoder Representation Transformer)에서 사용된 토큰으로, BERT에서는 CLS 토큰과 제일 뒤에 SEP 토큰을 통해 문장 Classification을 수행한다.

(SEP 토큰은 여기서 생략, 빨간 네모의 * 표시가 CLS Token)

- BERT의 원리를 설명하기 위해 상단의 그림을 참고하면, CLS 토큰은 문장의 시작을 알리는 토큰이지만, 학습이 되고 난 이후의 CLS 토큰은 모든 문장을 요약한 토큰이 된다.

- BERT에서는 문장의 정보를 요약한 CLS 토큰을 이용해서 Classification을 진행. VIT에서도 BERT와 동일한 용도로 CLS Token을 사용.

- CLS Token이 학습이 잘 되었다면, CLS 토큰은 이미지의 표현을 요약하는 Token이 될 것.

- 논문에서는 CLS Token 대신 Global Average Pooling 이후 Fully connected layer를 사용할 수도 있지만 이 경우 성능을 유지하기 위해 Learning rate를 수정하여야 한다고 함. 반면 CLS Token을 사용할 경우 추가적인 Learning rate수정 없이 성능을 유지하며 학습이 가능

2-4. VIT Encoder

 

- 먼저 기존 Transformer와 비교해볼 부분은 Norm layer가 기존 Transformer와 비교했을 때 앞으로 빠져온 것이다.

- Learning deep transformer models for machine translation. In ACL, 2019 (Qiang Wang)가 제시한 논문에서는 Layer Norm을 앞으로 빼고 Residual block을 사용했을 때 더 성능이 좋았다는 논문을 제시. 따라서 저자는 이 구조를 채택

- 전체적인 흐름은 다음과 같다.

1. (B, 197, 768) 크기의 Tensor가 입력으로 들어가 Norm Layer를 거치게 되며

2. 그 다음 앞서 설명했던 Transformer의 Multi-Head Attention과 유사하게 Multi-Head Attention을 수행

3. Residual 연산 후 다시 앞부분과 동일하게 Norm Layer를 거치며

4. 그 다음 MLP Layer를 거쳐 Residual 연산을 진행.

수식으로 나타내면 위의 1-4와 같다.

  1. Patch로 쪼갠 후 1차원으로 Flatten과 동시에 학습가능한 Projection 진행 및 Positional Encoding과 CLS Token을 더해줍니다.
  2. Norm layer -> Multi-Head Attention -> Residual connection 연산 진행.
  3. Norm layer -> MLP Layer -> Residual connection 연산 진행
  4. 학습이 잘 된 CLS Token을 가지고 Norm Layer통과

4번을 통과한 Tensor MLP Head에서 Fully Connected layer 통과하며 Classification을 진행.

++ (5,6,7,8번의 수식은 Multi-Head Attention layer(2번의 상세)의 수식이며 Transformer와 유사합니다.)

 

5. Layer Norm을 통과한 X는 Query, Key, Value로 복사되어 z(fully connected layer)와 곱해집니다.

6. Query와 Key의 행렬 곱을 dk 로 나누어 Scaling을 진행 후, Softmax()를 통해 확률 값으로 변환합니다.

7. Value와 곱을 통해 원래 픽셀에 Self Attention을 통해 획득한 확률 값을 곱해줍니다

8. Multi-Head Attention이므로 Head로 흩어졌던 친구들 다시 Concat 후 Fully connected layer 통과. (Transformer도 MHA Concat 후 Fully Connected layer통과합니다)


2-5. 잠시 정리 및 Self-Attention layer in ViT

- 위에서 진행한 Linear Projection과 Positional Embedding + CLS Token을 더하고 나면 (B, 197, 768) 크기의 Tensor가 나오게 되며, 이 Tensor는 Transformer Encoder를 통과하게 된다.

- 이 R,G,B 가 따로 따로 Concat이 되는 것이 아닌 R-G-B-R-G-B .. ect 형태로 768개의 1D vector를 구성하게 될 것이지만 위에서는 시각화의 편의를 위해 이렇게 작성했다.

- (B, 197, 768) 크기의 Tensor는 Query, Key, Value라는 이름의 3개의 Tensor로 복사되며, 그 다음 Head의 수만큼 Data 재배치가 된다. (Head 수는 8이라 가정)

- 아직 Head에 대한 내용을 설명하진 않았지만, Multi-Head Attention layer 봤잖아? 위에서 그러면 이해해보자! 어짜피 뒤에서 설명 하겠지만..

그런김에 Multi-Head Attention layer에서 발생하는 Self-Attention을 알아보자

- 그 다음 Query와 Key의 내적을 통해 Patch끼리 행렬 곱을 수행하여 Patch끼리의 연관성을 확인. 그 다음 Scaling 및 Softmax() 함수를 통해 확률 값으로 만들어 준 후 원래 차원과 동일한 Value와 곱하여 Patch 별 Self-Attention을 수행.

- Transformer에서 본 Self Attention 연산을 수행하면 위와 같다고 볼 수 있다.

 

- 사실 이렇게 Block 단위로 표기 한 이유는 Code 구현한 내용을 설명하려고 만든 내용이다. Code 구현은 따로 올릴 예정이다.


2-6. VIT MLP Head

- 앞서 설명한 과정을 바탕으로, CLS Token이 학습이 잘 되었다면 Patch의 특징을 잘 요약하는 Token이 되었을 것, BERT와 동일하게 VIT에서도 따라서 이 Token을 가지고 Classification을 진행

- 다시 보면 이제 CLS 토큰 위에 MLP Head가 달린 것을 볼 수 있다.


2-7. VIT Training

- VIT의 성능을 내기 위해서는 pre-trained 모델이 필수적. 그 이유는 VIT가 CNN과 비교하였을 때 Inductive bias*가 훨씬 적기 때문이라고 저자는 얘기한다.

++ Inductive bias* : 주어지지 않은 입력을 예측하는 것. Unseen data를 예측하기 위한 추가적인 가설

- CNN은 지역적인 특징을 가진 네트워크이며, 2개의 차원이 이웃한 구조이고, 각 Layer는 translation equivariance*하기 때문에 이미지에 대한 Inductive bias가 충분하다.

++ translation equivariance*: CNN filter가 한 칸 옆으로 옮겨져도, 같은 weight을 가지는 CNN filter의 출력은 같은 것.

- 하지만 ViT의 경우는 Encoder를 MLP와 Self-attention으로 구성하여 기존 CNN의 Convolution layer를 대체하므로 Inductive bias가 적어 3억장의 JFT-300M Dataset을 사용하여 Pretrain을 진행. 

- 우측 상단의 표를 보자, 위에서 설명한 Inductive bias에 의해 ImageNet (130만장), ImageNet 21K (1400만장)을 사용하여 Pretrain을 진행하였을 때 성능을 비교한 표를 보면 JFT-300M에서는 90%에 가까운 성능을 보여주지만, ImageNet에서는 BiT(Resnet) 보다 낮은 성능을 기록.

- 또한 JFT Dataset을 바탕으로 사전 학습 데이터의 성능을 비교하였을 때, 10M만 사용하여 Pretrain 진행 시 ResNet보다도 성능이 낮지만, 300M을 모두 사용하였을 때는 좋은 성능을 보여주고 있는 것을 확인.

- ViTDataset별 성능은 다음과 같다. 기존의 Noisy Student, BiT-L 모델과 VIT의 성능을 비교하였을 때 VIT의 학습 시간이 더 적게 소요되며, 정확도 또한 더 높은 것을 알 수 있습니다. 또한 Papers with code에서 성능을 확인하였을 때 CIFAR10 dataset에 대해서 SOTA를 유지하고 있는 것을 볼 수 있습니다.


3. 그 외 Attention Map Visualization

- 그 외에 2편에서 소스코드의 설명과 함께 다룰 Attention Map visualization이다. 왼쪽 아래 이미지는 각각 Query, Key, Value, Attention Map, 실제 이미지 이다.

- 그리고 우측의 이미지는 시각화 결과이다. 

- 코드만 빠르게 보고싶다면 Github를 참조하자.

 

GitHub - younghoonNa/ViT_Attention_Map_Visualization: ViT Attention map visualization (using Custom ViT and Pytorch timm module)

ViT Attention map visualization (using Custom ViT and Pytorch timm module) - GitHub - younghoonNa/ViT_Attention_Map_Visualization: ViT Attention map visualization (using Custom ViT and Pytorch timm...

github.com

 

반응형