본문 바로가기
Deep Learning (Computer Vision)/Model Compression and Optimization

Pytorch Tutorial보다 친절한 Pytorch Pruning Tutorial (2)

by 187cm 2023. 6. 1.
반응형

Pytorch Pruning 1편 보러가기.

 

Pytorch Tutorial보다 친절한 Pytorch Pruning Tutorial (1)

Colab 자료 - https://colab.research.google.com/github/pytorch/tutorials/blob/gh-pages/_downloads/f40ae04715cdb214ecba048c12f8dddf/pruning_tutorial.ipynb#scrollTo=mRMctJEUvqbS 번역본 - https://tutorials.pytorch.kr/intermediate/pruning_tutorial.html -

187cm.tistory.com

이번시간에는 지난시간에 이어서 Pytorch Pruning Tutroial을 이어서 진행해보려 한다.

Iterative Pruning

- 1에서 적용한 Pruning에 이번에 적용하는 pruning을 한번 더 적용했을 때의 얘기를 하고 있다.

- 즉, 중복으로 적용했을 때에도 잘 동작하는 모습을 보여준다.

prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

- ln_structed 모듈을 사용하여 module의 weight에 대해 pruning을 진행합니다. 0.5 = 50%정도의 Pruning을 진행할 것이며, n=2는 L2_norm을 기준으로 0번째 dimension에 대해 pruning을 진행합니다. 

- 0번째 차원에 대해 아래와 같이 Pruning이 된 것을 볼 수 있습니다. 조금 더 친절하게 설명드리자면, module.weight.shape = (6,1,3,3)입니다. 이 때, 여기서 0번째 dimension 즉 빨간색 네모박스들을 기준으로 pruning을 진행합니다.

- module.weight를 실행하면, 파란색 박스에 이어서 빨간색 박스들이 전부 0으로 바뀌는 것을 볼 수 있습니다.

 

- 다음과 같이 직접 L2 Norm을 정의하여 살펴보면, 실제로도 4,5,6번째 Gradient의 Norm이 가장 작은 것을 볼 수 있습니다.

- 그리고 pruning에 대한 hook은 prune.PruningContainer에 저장되어있습니다. 

- 모델에 dictionary 형태로 저장된 key()를 추출하면 아래와 같습니다. 

- 기존의 weight와 bias에 conv1.weight_orig, conv1.bias_orig가 추가되었습니다.

- 앞에서의 pruning과 마찬가지로, weight와 해당하는 mask가 존재하는 것을 확인할 수 있습니다.

- 또한 해당하는 부분이 0으로 바뀌는 것을 볼 수 있습니다.

- Remove 모듈을 이용해 pruning을 진짜 적용했을 때의 그림이다.

- 왼쪽은 첫번째 L1 pruning만 적용되었을 때이며, 두번째는 아래의 빨간색 부분이 L2 Pruning에 의해 제거된 그림이다.

Pruning multiple parameters in a model

- Pruning을 한 층에 대해서만 하는 것이 아닌, 여러 층에 대해서도 적용할 수 있다.

- 아래와 같이 LeNet을 재정의하여 파라미터를 초기화 한 후, L1_norm을 이용한 pruning을 진행한다

- The pruning method can be applied to multiple layers. (* Not just a single layer)

- We initialize the LeNet paramters, follwed by performing pruning using the L1 Norm.

Global pruning

- 이제는 Global pruning에 대해서 알아보도록 하자. 이전꺼는 local pruning.

- 아래와 튜플 안에, 튜플이 들어가도록 정의를 한다. 2번째 튜플에는 Pruning 하고싶은 weight 혹은 bias에 대해 정의한다.

- 그리고 'prune.global_unstructured' 함수를 사용하여 parameter는 위에 사용한 tuple을,  pruning_method에는 어떤 방식으로 pruning을 진행할 것인지, 그리고 얼마나 pruning 할 것인지를 정의한다.

- print 함수를 사용하여 각 layer 당 얼마나 pruning이 되었는지 확인.

- 그리고 마지막에 이를 다 더해서 총 20%에 가까운 Pruning이 진행되었는지 확인

- pruning 결과는 20%에 딱 맞게 된 것을 볼 수 있으며 주로 Fully Connected Layer 부분이 많이 된 것을 볼 수 있다.

반응형