0. Intro
우리는 Rethinking the value of the pruning에서 pruning에 대한 많은 인사이트를 얻을 수 있었다. 앞으로 나올 내용들이 이에 기반한 내용이니 이하 논문이라고 표현하도록 하겠다. unstructured pruning은 weight의 숫자를 급격하게 줄일 순 있지만, 결과로 도출된 희소행렬이 하드웨어 혹은 특정 라이브러리의 도움 없이는 inference time과 메모리 자체에는 변화를 가져다 주지 못했음을 확인할 수 있었다.
https://do-my-best.tistory.com/80
이에 따라 structured pruning으로 눈길을 돌렸고, 앞서 언급한 논문의 결과를 바탕으로 흐름 및 구조가 automatic structured pruning 대신 효과가 유사하며, 더 명확하고 구현 난이도가 비교적 쉬운 Uniform structured pruning을 진행하였다. (뒤에서 자세히 알아보자)
SOLOv2가 Pytorch를 활용하여 구현된 API인 MMDetection을 확장한것이므로 앞으로 다룰 내용은 Pytorch와 MMDetection의 내용과 밀접한 연관이 있다.
1. Structured pruning
구현에 앞서 Structured pruning이란 무엇인가? 미리 정의된 네트워크에서 일정한 기준을 바탕으로 Kernel의 갯수를 줄이고, 줄어든 Kernel에 따라 Depth를 조정하는 방식으로 구조적인 pruning을 하는 것이다. unstructured pruning과 다르게 실제 모델을 깎으므로, 메모리를 아낄 수 있고, 더 빠른 Inference time을 기대할 수 있다.
처음에는 단순하게 Kernel의 갯수만 줄이면 될 줄 알았으나, Kernel의 갯수가 현재 layer output의 depth(즉, 다음 layer의 input depth)이므로 구조적으로 다음 input의 depth를 줄여주어야만 했다. 이때 아무거나 줄이는 것이 아니라 삭제한 kernel의 순서에 맞게 삭제해준다.
// 질문 ) 하지만 이것은 Fine-tuning을 위한 것일 뿐 Rethinking the value of the pruning에서 나온대로 scratch부터 학습을 진행한다면 의미 없지 않은가? 즉, 그저 kernel의 갯수와 input depth만 맞춰주면 되는거 아닌가?
이렇게 channel들을 줄이고, depth를 맞춰주면 된다. SOLOv2의 backbone은 Residual block을 사용하는 ResNeXt와 ResNet이다. 논문에서 나왔던것과 같이 이와 같이 얇고, 깊은 커널을 갖는 최신 모델들은 Automatic pruning과 Uniform pruning의 pruning경향이 매우 유사하므로 성능차이가 크지 않았음을 알 수 있었다. 가장 먼저 시도한 방법은 이에 착안하여 backbone만, 일정 퍼센티지만큼 uniform하게 channel을 pruning하는 것이었다.
backbone이 SOLOv2 파라미터의 약 65%의 양을 차지하므로 이는 합리적이라고 생각하고 모든 kernel을 1/4만큼 줄이는 것을 목표로 시작하였다.
가장 먼저 시도한 Backbone은 mAP가 가장 높은 ResNeXt 64 x 4d를 활용하는 모델이었다. 하지만 진행 과정에서 이는 해내기 어려울 것이라고 생각하였는데, 그 이유는 ResNeXt의 특성때문이었다.
ResNeXt는 ResNet의 Residual block에서 작은 shape을, 많은 cardinality로 구성한 모델이다. 하지만 이렇게 나누는 과정에서 과거에 pruning되었던 kernel의 위치를 유실하게 되고 cardinality를 정상적으로 구성할 수 없다고 생각하였다. (이 부분은 확실하지 않다. 어떻게 나누는지를 분석하면 될거같긴 하지만... )
// 질문 ) 하지만 scatch부터 한다고 하면 그냥 4크기의 depth를 갖고 있는 모든 cardinality의 input을 일정 비율만큼 동일하게 자르면 되지 않는가?
그래서 과감하게 버리고 ResNet으로 진행하였다.
2. Reconstruct model
그 과정에서 사용되었던 방법은 크게 두가지로 나눌 수 있는데, 다음은 그 방법이다.
1. Torch의 module API사용
pretrained model을 불러와서 바로 변경하는 방식을 사용한다.
model._modules['layer_name'] = th.nn.some_layer(...)
이렇게 하면 지정한 layer를 선언한 layer로 임의로 바꿀 수 있다. 물론 이는 임시적인것이므로
2. MMdetection의 Config file및 Model file수정
MMdetection의 Config file에는 어떤 것을 입력으로 받을지, 몇개의 block을 정의할지 등 모델 생성 및 training, test에 대한 정보들이 정의되어 있다.
또한 model file은 모델을 어떻게 구성할 지 직접적으로 나와있기 때문에 이를 활용하면 쉽게 구조를 바꿀 수 있다.
mmdetection이 어떻게 model을 표현하고 어떤 흐름을 갖는지 코드를 훑고, 그게 어떤 의미인지 파악하는데 시간이 좀 걸렸지, 그다지 어렵진 않다. 게다가 ResNet과 ResNeXt이 Bottleneck의 반복이고, 그 Bottleneck이 모듈로 표현되어 있으며 규칙이 있기 때문에 이 Kernel의 갯수와 Depth의 크기에 대한 규칙을 변경할 수 있다.
3. Prune weight
이제 weight를 load할 차례이다. 기존 pretrained model에서 channel별로 L1 norm을 구하고, 오름차순으로 정렬한 이후에, 상위 x%만큼 pruning mask를 생성한다. 이후, 이 pruning mask를 바탕으로 다음으로 전달할 kernel을 만든다.
흠... 개념 자체는 쉬웠지만 pytorch와 python의 특성때문에 꽤나 어려웠다. 특히 새로운 weight를 만들 때, 차원의 문제 때문에 원하는 대로 kernel tensor들을 이어붙일 수 없었다던가, depth를 pruning할때 차원문제가 생긴다던가, pytorch에선 parameter의 값이나 형태를 변경할 수 없다던가 (엄밀하게 말하면 값은 변경할 순 있다. 하지만 변경 이후 업데이트가 되지 않는 경우가 존재한다. 즉, weight값을 직접 배정으로써 바꿔주는게 아닌, tensor를 copy하거나 parameter를 새로 선언하는 방향으로 진행해야 한다.)
4. Load weight
그냥 불러들이면 된다. 하지만 아까 서술한 바와 같이 그냥 무작정 바꾸면 적용되지 않는 경우가 존재한다. tensor를 copy하거나 parameter를 새로 선언하고 값 뿐만 아니라 parameter 전체를 배정하는 방향으로 가야한다.
뭐 어떤 성질이 있는진 잘 모르겠다.... 그저 단순하게 배정만 하면 안된다는것 같다.
5. Fine-tuning or Learning
이제 불러들인 모델을 학습하는 일만 남았다. Rethinking the value of pruning에 따르면 structured pruning은 pruned weight에 효과가 있는게 아닌, 새로 만든 아키텍처의 구조가 영향을 준다고 한다더라. 그래서 scratch부터 학습을 하는것이 오히려 더 성능이 좋다고 하는데... 그럴꺼면 로드도 필요없고, pruning 기준도 필요없으며 구조만 바꿔주면 되는거 아닌가 싶긴하다.
6. Review
첫 과제다 보니 한 것에 비해서 꽤나 시간이 많이 걸렸다. 그 과정에서 그래도 모델에 대한 전반적인 이해도가 높아졌음을 느꼈으며, pytorch, mmdetection 등 API에 친숙하게 되었다.
'ML' 카테고리의 다른 글
FairMOT : On the Fairness of Detection andRe-Identification in Multiple Object Tracking 논문 리뷰 (0) | 2021.09.01 |
---|---|
인공신경망(SOLOv2)의 data pipeline의 구성과 Resized model의 효과 (0) | 2021.08.13 |
Network pruning (0) | 2021.07.27 |
HOG (Histogram of Oriented Gradient for human detecting) (0) | 2021.07.21 |
SIFT (Scale-Invariant-Feature TRansform)를 활용한 이미지 특징 추출 및 매칭 알고리즘 (1) | 2021.07.20 |