본문 바로가기

옵티마이저 설명 SAM: Sharpness-Aware Minimization (2020, training loss를 이용한 일반화 성능 올리기)

반응형

관련글


옵티마이저 정리 및 논문 리뷰
 

Optimizer Optimizer_Name (연도)


논문 제목: Sharpness-Aware Minimization for Efficiently Improving Generalization
연도: 2020
링크: https://arxiv.org/abs/2010.01412
관련 개념: Generalization, Geometry of loss landscape, 모집단, 표본
 

서론


블로그를 포스팅하는 2023년도 기준으로, 현대의 DNN 기반의 인공지능은 수많은 분야(task)에서 기존의 기계학습 방법보다 뛰어난 성능을 보인다. 이 DNN 기반의 인공지능은 파라미터의 개수가 늘어남에 따라 성능이 올라가는 경향을 보이는데(반드시는 아님), 최근의 인공지능은 필요 이상의 파라미터를 가지고 있다고 여겨진다. 이를 overparamterization이라고 부르는데, 필요 이상의 파라미터를 사용하면 state-of-art(최신, 최고) 성능을 보이는 경우가 많다.
인공지능은 training dataset이라는 데이터를 학습하고, validation 혹은 test dataset을 통해 실제 성능을 검증한다. training dataset과 validation / test dataset 사이의 성능 차이를 generalization gap이라고 부르는데, 이 간극을 줄이는 것은 상당히 어려운 작업이다.(독자들도 알 것이라 생각한다.)
많이들 사용하는 Adam의 경우 training dataset에는 성능이 좋지만, validation/test dataset에서는 성능이 안 좋게 나오는 경우가 많다. 요즘에는 Adam과 SGD를 비교하는 논문들이 나오며, 왜 SGD가 Adam보다 일반화 성능이 좋은지에 대한 연구가 진행되고 있다. 참고 논문. Image 계열에서 유명한 EfficientNet만해도 Adam이 아닌, RMSProp로 학습된 것을 보면 특정 optimizer가 만능은 아니라는 것을 보여준다.
 

그럼에도 불구하고 training dataset만을 이용하여 general한 성능을 이끌어야 한다!

우리에게 주어진 데이터는 training dataset이며, 우리는 해당 dataset만을 이용해 general한 성능을 이끌어 내야 한다. 어떻게 가능할까?
기존 방법론

  1. Dropout: 확률적으로 모델 일부를 학습시키지 않는 기법
  2. Batch normalization: 모델이 한 방향으로 편향되지 않도록 normalization 기법을 이용하여 모델을 학습시키는 기법
  3. Data augmentation: training dataset에 변형을 줘 학습 데이터를 늘리는 방법

상기 방법론은 현재에도 많이 사용되는 기법으로, 그 효과가 경험적으로 어느정도 증명된 기법들이다.(하지만 만능은 아니며, 반드시 성능이 올라가는 것도 아니다.)
이번 포스트에서 소개할 SAM(Sharpness-Aware Minimization)은 위의 방법론들과는 다른 기법으로,  loss를 관찰하는 방식으로 동작한다. 이것이 무슨 의미일까?
인공지능 분야에서 사용되는 loss landscape는 쉽게 말해 loss 그래프를 의미한다. 그래프를 지형이라는 단어를 사용한 것뿐이다. 현대의 일반화에 대한 연구의 많은 의견은 'training loss landscape에서 loss landscape가 평평할 수록 일반화 성능이 올라갈 가능성이 높다'다. 즉, training loss를 그래프로 그렸을 때, 인공지능 모델이 loss가 평평한 부분에 있다면, 그 모델의 일반화 성능은 높을 가능성이 높다는 의미다. 다른 측면에서 말하면, 뾰족한 부분을 알고, 그런 부분을 줄여야 한다는 의미다.
어떻게 평평한 loss landscape를 찾을지는 SAM의 방법론을 자세히 봐야만 알 수 있다. 이는 아래와 본문에서 계속해 소개하도록 한다.

SAM의 특징

  • SAM은 loss value를 줄이면서(training 성능 올리기) loss의 뽀족함(sharpness)를 줄인다.
  • SAM은 모델 파라미터의 주변에 존재하는 파라미터들을 확인하면서 low loss value와 sharpness를 확인한다.
  • -> 현재 파라미터의 loss가 N 정도의 loss고, 주변 파라미터들의 loss도 N정도라면 이는 현재 파라미터에서 loss가 평평하다는 의미이기 때문.

 
실제 SAM 알고리즘을 보면, SAM은 파라미터 그 자체를 찾는다기보다는, SAM의 목적에 맞는 파라미터로 향하는 gradient를 찾는 것으로 보인다. (Section 2의 Algorithm 1: SAM algorithm 참고)
SAM의 사용법은 매 iteration마다 optimizer를 적용한 후, 사용하는 것으로 보인다.(GitHub 참고 (PyTorch, TensorFlow))  

 
 

본론


SAM은 다음과 같은 문제를 해결하여 적당한 파라미터를 찾는다.

Eq (1). Sharpness-Aware Minimization (SAM) problem.

p는 하이퍼파라미터로 0 이상의 상수이며, p-norm에서 말하는 p로 생각된다. 람다 또한 하이퍼파라미터이며, 해당 항은 SAM 논문의 Theorem 1이 h 항을 의미한다. S는 현재 학습하고 있는 training dataset(or Batch)를 의미한다.
Eq (1)에서 L SAM(w)는 현재 파라미터 주변의 파라미터에서 loss가 가장 클 때의 loss를 의미하며(Sharpness와 관련돼 있다고 생각), 이를 전체적으로 w에 대해 낮추는 방향으로 동작함을 보이고 있다.(low loss value).
L SAM(w)을 최소화하기 위한 방법으로 논문에서는 L SAM(w)에 대해 SGD를 적용한다고 나와 있다. 이를 위해 입실론에 대해 L SAM(w)를 1차 Taylor 전개를 진행한다.

L SAM(w)에 대해 Taylor 1 차 전개 진행

논문에서는 위의 수식 자체를 해결하기 보다는, 고전적인 optimization 방법론 중 하나인 dual space에서 진행하는 방식으로 진행한다. (dual space에서 최적화한 값과 원래 space에서 최적화한 것이 동일하기 때문)

Eq (2). The classical dual norm problem.

여기서 q는 1/p + 1/q = 1를 만족한다.
위의 Eq (2)를 Eq (1)에 적용한다면, 다음과 같은 미분방정식이 나온다.

Eq (2)를 Eq (1)에 적용

물결 표시는 Taylor approximation(1차 전개)에 의해 성립된 것이며, 첫 번째 등호는 미분의 chain rule을 적용한 것으로 보인다. 두 번째 등호는 미분의 선형성을 시용하여 d(w+e^(w))을 분리한 모습이다.
위의 수식 자체는 풀기 어려워 보이지만, 우리가 사용하는 TensorFlow, JAX, PyTorch는 자동미분이라는 어마어마한 기능을 제공하기에 쉽게 계산할 수 있다.
위의 수식의 마지막 항에는 e^에 대한 w의 미분과 Ls(w)의 미분이 들어있다. 여기서 e^는 Eq (2)에 알 수 있듯 Ls(w)에 대한 식이기 때문에, 결론적으로 마지막 항은 Ls(w)에 대한 Hessian이라고 할 수 있다. -> 즉, computation이 어마어마 하다. SAM 논문에서는 이 문제를 해결하기 위해 깔끔하게 마지막 항을 제거하여 최종적으로 다음과 같은 수식을 도출했다. (신기하게도 Hessian 항을 추가하면 성능이 하락한다고 한다. 논문의 C.4 참고. 이는 후속 연구로 진행할 수 있을 것 같다.) 이로써 SAM은 computation과 정확도 둘 모두를 잡았다.

Eq (3) 최종 SAM 수식. Hessian 항이 사라졌음을 알 수 있다.

결론


SAM 방법론의 바탕 지식에는 'training loss가 평평할 수록 일반화 성능이 올라갈 가능성이 높다'이다.
SAM은 loss landscape가 평평한 혹은 뾰족하지 않은 부분을 찾는 방식으로 generalization 성능을 올리려 노력한다. Hessian 항을 제거하여 효율적인 computation을 얻었으며, 부가적인 효과로 성능 효과를 얻었다. SAM을 EfficientNet에 적용하여 SOTA를 달성하기도 하였다.
SAM의 또다른 측면으로는 Label에 noise가 존재하는 상황에서도 강건함(robustness)를 유지한다고 한다.
 

장점

기존 방법론과 계열이 다른 새로운 접근법이다. (기존 방법론과 같이 사용 가능하다)
일반화 성능을 이끌어 낼 가능성이 있다.
EfficientNet에 대해 SOTA를 찍었다.
Google에서 만들었다.
optimizer와 같이 사용할 수 있다.

단점

한 iteration에서 두 번의 optimization이 필요하다 (한 번은 기존 optimizer, 한 번은 SAM을 위한 optimization)
'training loss가 평평할 수록 일반화 성능이 올라간다'에 대한 찬반이 아직도 존재한다.
 

 
 
 

참고


https://arxiv.org/pdf/2010.01412.pdf
https://proceedings.neurips.cc/paper/2020/file/f3f27a324736617f20abbf2ffd806f6d-Paper.pdf
 

반응형