옵티마이저 Lookahead (2019, 안정적인 학습 도모)
관련글
Optimizer Lookahead (2019)
논문 제목: Lookahead Optimizer: k steps forward, 1 step back
연도: 2019
링크: https://arxiv.org/pdf/1907.08610.pdf#page=10&zoom=100,144,604
관련 개념:
서론
현재 나온 수많은 optimizer는 SGD에 그 기반을 두고 있습니다. 예를 들어 Momentum이라고 알려진 Polyak heavy-ball 알고리즘은 SGD에 관성이라는 개념을 도입했으며, Adam이나 AdaGrad 등은 개별적인 learning rate라는 개념을 SGD에 도입했습니다. 이러한 알고리즘은 안정적인 학습을 위해 hyperparameter를 직접 조정해야 합니다. 하지만 이러한 과정은 시간이 많이 들며, 사람의 노력이 많이 요구되는 작접이기도 합니다.
저자들이 발표한 Lookahead라는 기법은 기존의 optimizer에 적용할 수 있는 기법으로, 정확히 말하자면 optimizer라고는 할 수 없습니다. 다만, 안정적인 학습을 도와주는 추가적인 알고리즘이라고 할 수 있죠.
저자들은 CIFAR, ImageNet dataset에 ResNet50, ResNet152을 사용하고 Lookahead 기법을 적용했으며, LSTM 기반의 언어 모델과 Transformer-based 모델에도 Lookahead 기법을 적용했습니다. 결과는 기존의 optimizer보다 더 나은 수렴성을 보였으며, 때로는 일반화 성능도 높아졌다고 합니다.
본론
알고리즘 자체는 굉장히 단순합니다. K번 동안 weight를 update하고 업데이트된 weight와 기존의 K번 전의 weight 사이에 있는 weight로 최종 weight를 결정하는 것입니다. 쉽게 말하자면, original weight와 K번의 루프로 updated weight가 있을 때, original weight와 updated weight 사이 어딘가에 있는 값으로 최종 weight를 결정한다는 뜻입니다. 여기서 우리는 updated된 weight를 fast weight, 최종 weight를 slow weight라고 부릅니다. 이 과정을 여러번 반복하는 것이 Lookahead 기법입니다.
위의 알고리즘에서 theta는 fast weight로, 두 번째 for loop 덕분에 k 횟수 동안 update가 됩니다. 위의 알고리즘에서 fai는 slow weight로, k번 update된 fast weight과 alpha라는 hyperparameter와 연산되어 업데이트 되고 있습니다.
즉, fast weight는 매번 update되며, slow weight는 k번째마다 update되기 때문에 fast와 slow라는 수식어가 붙은 것 같습니다. 다른 식으로 해석하면 fast weight와 slow weight는 k번째마다 동일해집니다(synchronization).
Computational complexity
Lookahead 기법은 일정한 연산 복잡도를 가진다고 합니다. (연산 복잡도는 논문의 3page에 나와있습니다.)
Selecting the Slow Weights Step Size
알고리즘에서 추가된 hyperparameter에는 alpha가 있습니다. 이것을 slow weights step size라고 부르는데, 논문에서는 이것을 자동으로 결정해주는 알고리즘을 소개합니다.
Loss function인 L은 MSE 등의 계열을 사용했다 가정했으며, 그때의 최적의 alpha 값을 계산하는 방법이 위에 나와 있습니다. 만약 MSE를 Loss function으로 사용했을 경우, A는 Identity matrix가 되겠네요. 이것 외에 논문에서는 clip을 사용한 방법도 나와있습니다.
위의 그림은 Lookahead 기법을 적용한 optimizer와 그렇지 않은 SGD, Adam Polyak(=momentum)의 성능을 보여줍니다. 보시다시피 Lookahead를 사용한 것이 좋다고 나와있네요.
결론
장점
1. 간단하다. |
단점
1. weight를 전부 복사 해야하기 때문에 메모리 사용량이 기존 optimizers의 2배이다. (즉, proxy를 만들어야 하기 때문에 메모리 사용량이 기존 optimizers의 2배이다.) |
참고
https://arxiv.org/abs/1907.08610