본문 바로가기
NLP

[NLP Study] - LSTM

by Tiabet 2024. 3. 22.

이번 포스팅에선 LSTM에 대해 공부해보면서 어떻게 RNN보다 좋은 모델로 평가받을 수 있었는지에 대해 살펴보고자 한다.

 

참고자료

https://wikidocs.net/22888

 

08-02 장단기 메모리(Long Short-Term Memory, LSTM)

바닐라 아이스크림이 가장 기본적인 맛을 가진 아이스크림인 것처럼, 앞서 배운 RNN을 가장 단순한 형태의 RNN이라고 하여 바닐라 RNN(Vanilla RNN)이라고 합니다. (…

wikidocs.net

https://deeplearning.cs.cmu.edu/F23/document/readings/LSTM.pdf

LSTM 논문

 

Vanilla RNN의 문제

순환형 신경망 중 가장 기초적인 신경망을 Vanilla RNN (바닐라 RNN) 이라고 한다.

https://tiabet0929.tistory.com/54

 

[NLP Study] - RNN

트랜스포머가 무엇이 대단한지를 이해하려면, Seq2Seq부터 이해해야 하고, 결국엔 그 전의 자연어 처리가 어떠한 식으로 이루어졌는지를 완전히 이해해야 할 것 같다. 그래서 한 달 동안 RNN부터

tiabet0929.tistory.com

바닐라 RNN엔 심각한 문제가 하나 있는데 바로 장기 의존성 문제, Long Term Dependency Problem 이다. RNN의 구조를 짧게 설명하자면 Input Token 하나를 Time 하나로 바라보고, 앞의 Token이 뒤의 Token의 가중치 계산에 영향을 주는 구조로 이루어져 있다. 시계열 모델로 치면 어제의 일이 오늘, 내일, 모레 등 미래에도 쭉 영향을 미치는 것으로 바라보는 것이다. 

그러나 앞에 있는 Token일수록 뒤에 있는 Token에까지 영향력은 계속하여 떨어질 수밖에 없다. 하지만 이는 시계열에서는 몰라도 앞의 단어와 맨 뒤의 단어가 큰 상관이 있을 수 있는 자연어에선 큰 약점이다. 

 

예를 들면, '도둑은 신용카드와 현금이 가득 든 지갑을 훔쳤다' 라는 문장을 보자. 이 문장의 핵심은 '도둑' 과 '훔쳤다' 이며 이 둘은 아주 밀접한 연관이 있는 단어다. 그러나 RNN에 '도둑은 신용카드와 현금이 가득 든 아름다운 지갑을' 이라는 문장을 주면, '도둑은'의 가중치를 거의 망각해버려서 '구매했다', '만들었다' 등의 엉뚱한 대답을 내뱉을 수 있다는 것이다. 바로 이런 면에서 바닐라 RNN이 기억력이 나쁘고, 장기 의존성 문제가 발생하는 것이다.

 

LSTM은 무엇이 다른가

LSTM(Long Short Term Memory)은 바닐라 RNN을 약간 개조한 모델이다. 

 

LSTM은 RNN의 기억력 부족 현상을 해결하기 위해 메모리 셀(은닉 노드)에 셀 상태라는 값을 추가하여 지워도 될 기억과 잊으면 안 될 단어를 계산하는 과정을 추가했다. 이때 RNN처럼 이 셀 상태 값은 다음 시점의 셀 상태로 넘어가 가중치 계산에 영향을 주게 된다. 위 그림에선 C가 셀 상태를 의미하게 된다.

 

이 셀 상태라는 것을 계산하기 위하여 입력 게이트, 망각 게이트가 추가되었으며 출력 층으로 보내는 역할을 하는 출력 게이트도 추가되었다. 각각의 게이트를 간단하게 살펴보고자 한다.

 

입력 게이트

입력 게이트

이전 시점의 은닉 상태를 받아와서 현재 시점의 입력 값과 함께 계산에 사용하는 것이 RNN의 변형이라는 것이 확실히 느껴진다. 이때 g와 i 두 값이 있는데, 식은 아래와 같다.

보면 i와 g 모두 각자의 가중치 행렬을 갖고 있으며 i는 시그모이드 함수를, g는 탄젠트 하이퍼볼릭 함수를 적용하여 값을 계산하는 것을 알 수 있다. 이 값들은 셀 상태를 결정하는 데에 사용된다.

 

삭제 게이트

필요 없는 정보를 삭제하는 삭제 게이트이다.

역시나 별도의 가중치 행렬을 갖고 있으며 시그모이드 함수를 사용하는 것을 알 수 있다. 이 값이 1에 가까울 수록 정보 삭제가 덜 된 것이고 0에 가까울 수록 많은 정보가 삭제된 것이라고 한다. 

 

셀 상태

삭제 게이트와 입력 게이트에서 계산된 값들이 만나게 되는 셀 상태다. 삭제 게이트의 값은 이전 셀 상태와 결합되고 입력 게이트의 두 값은 서로 결합되어 현재의 셀 상태를 결정하게 된다. 이때 결합은 벡터의 같은 위치의 값끼리 곱하는 원소별 곱을 의미한다. 

 

삭제 게이트의 결과값 f가 0이라면, 이전 셀 상태를 완전히 무시하고 현재의 입력값으로만 셀 상태를 결정하겠다는 의미가 된다. 즉, 삭제 게이트의 역할은 현재 상태를 삭제할 것인지를 정하는 것이 아닌, 바로 이전 상태를 얼마나 삭제할 것인가를 결정하는 것이다. i나 g가 0이라면, 현재 입력을 아예 반영하지 않는다는 의미이므로 현재의 입력을 삭제해버리는 효과를 발생시킨다. 이렇게 두 게이트의 값들이 조정되면서 현재의 셀 상태를 결정하게 된다.

 

출력 게이트

그런데 특이하게도 셀 상태의 값이 그대로 출력 층으로 가는 건 아니다. 셀 상태의 값은 탄젠트 하이퍼볼릭을 한 번 거친 다음, 출력 게이트에서 입력값과 이전은닉값을 시그모이드 함수로 게산한 값과 다시 합쳐서 계산한 값과 합쳐져서 드디어 현재의 은닉 상태를 완성하게 된다. 이 현재의 은닉 상태는 출력층과 다음층으로 넘어가게 된다.

 

 

이렇게 LSTM의 내용을 정리해보았다. 확실히 단어 하나 = 시점 하나 라는 개념을 이해하고 나니까 RNN과 LSTM이 더 와닿는 느낌이 든다.

 

공부하면서 궁금했던 점은 LSTM의 용량이 가중치가 많다 보니까 아주 클 것 같은데, RNN과 비교하면 어느 정도일까? 였는데, 입력 크기와 은닉 상태의 크기가 같다고 가정했을 때 LSTM 이 RNN보다 4배의 메모리를 차지하게 된다고 한다. 역시 성능이 좋은 만큼 더 많은 계산을 요구하는 건 어쩔 수 없나 보다.