어제 Base Transformer의 파라미터 수를 정리해봤는데, 이번엔 보너스 느낌으로 Base BERT의 파라미터 수를 정리해보고자 한다. BERT의 개념 정리는 나중에 하도록 하겠다.
https://tiabet0929.tistory.com/86
[LLM] Attention is All You Need 의 Base Transformer 파라미터 수 계산
오랜만에 논문을 다시 읽다가 파라미터 수에 꽂혔다. 여기서 베이스 모델의 파라미터가 65M이라고 나와있길래, 재미삼아 GPT에게 물어봤는데, 당연히 위에 사진만 보여주면 환각 현상 때문에 65
tiabet0929.tistory.com
BERT는 인코더로만 이루어져 있는 모델이라 파라미터 수 계산이 훨씬 편하다. 또한 몇 가지 공식을 사용해 계산을 훨씬 빠르게 만들 수 있다.
이전 포스팅의 베이스 트랜스포머에선 모델의 차원이 512, 피드포워드 네트워크의 차원이 2048로 총 4배의 차이가 났다. 즉, 모델의 차원을 $d_{\text{model}}$이라 하면 피드포워드 네트워크의 차원은 $d_{\text{ff}} = 4 d_{\text{model}}$이 된다.
BERT 또한 마찬가지로, 모델의 차원이 768, 네트워크의 차원은 3072로, $d_{\text{ff}} = 4 d_{\text{model}}$이다.
또 하나 응용할 사실은, 어텐션 헤드가 몇 개든 간에 파라미터 수에는 영향을 끼치지 않는다는 것이다. 상기한 포스팅에서 이 이유를 설명했으므로 여기선 생략한다.
이를 응용해서 수식 계산을 더욱 간단하게 만들어보자.
1. 임베딩 단계
임베딩 단계에서 단어 사전의 크기 $V$는 BERT에선 30522라고 정확히 명시되어있다.
각각의 단어를 모델의 차원으로 늘려서 임베딩 벡터로 만들어야 하므로, 임베딩 단계에서는 $V*d_{\text{model}}$ 만큼의 파라미터가 존재한다.
2. 인코더 단계
인코더 단계의 Multi-Head Attention부터 살펴보면, Query, Key, Value에서 각각 $d_{\text{model}}^{2}$만큼의
파라미터가 발생, 그리고 이를 Projection할 때 또 $d_{\text{model}}^{2}$만큼의 파라미터가 발생하므로 총
$4 d_{\text{model}}^2$ 의 파라미터가 발생한다. FFNN 단계도 살펴보면, $d_{\text{ff}} = 4 d_{\text{model}}$ 이므로
$2 d_{\text{model}}* d_{\text{ff}} = 8* d_{\text{model}} ^2$만큼의 파라미터가 존재한다. 따라서 인코더 단계에선 총 $12 d_{\text{model}} ^2$만큼의 파라미터가 존재한다. 그런데 이런 레이어가 총 12개니까, 최종적으로 $144 d_{\text{model}} ^2$만큼의 파라미터가 존재한다.
3. 최종 단계
BERT의 최종단계에선 BertPooler라는 Linear Layer가 하나 존재하는데, 이는 [CLS] 토큰의 임베딩을 구하기 위해 존재한다. 즉 Bert의 주요 Task 중 하나인 [CLS] 토큰의 임베딩을 구하기 위해 존재하는 레이어다. 따라서 $ d_{\text{model}} ^2$ 만큼의 파라미터가 존재한다.
따라서, Base BERT에는 총 $V * d_{\text{model}} + 145* d_{\text{model}} ^2$ 만큼의 파라미터가 존재하며, 여기에 V와 $d_{\text{model}}$ 을 대입하면 논문의 110M과 비슷한 수치인 108M이 나오게 된다.
결론 : $ d_{\text{model}} $로 치환하고, 굳이 어텐션헤드로 나눠주는 과정을 생략하면 파라미터 계산이 빠르다.
'LLM' 카테고리의 다른 글
[LLM] ChatGPT 4o 이미지 생성 모델, 어떻게 만들었는지 원리 탐구 (4) | 2025.03.28 |
---|---|
[LLM] LG의 LLM EXAONE Deep 사용 후기 및 딥시크 R1, OpenAI o1 과의 비교 (6) | 2025.03.18 |
[LLM] Attention is All You Need 의 Base Transformer 파라미터 수 계산 (2) | 2025.02.11 |
[LLM] LLM으로 Tabular Data 학습해보기 3 - Langchain으로 데이터 증강하기 (8) | 2024.11.13 |
[LLM] LLM으로 Tabular Data 학습해보기 - 2. 이진분류 (경정데이터분석) (7) | 2024.09.29 |