본문 바로가기
딥러닝

Trainer API 에서 compute_metrics 사용할 때 CUDA out of memory 해결법

by Tiabet 2023. 9. 5.

간밤에 KoGPT 를 사용해서 Colab에서 koGPT 를 파인튜닝하고 있었는데, 자꾸만 아래 오류가 발생했다.

처음에는 배치 사이즈나 tokenizer 과정에서 문제가 생겼나 싶어서 이 부분을 고쳐봤으나 인터넷에서 찾을 수 있는 모든 자료를 다 시도해봐도 오류를 벗어날 수 없었다.

 

내가 시도해본 방법은

1) 커널 재시작

2) batch_size 1까지도 줄여서 적용

3) Garbage Collect

import gc
gc.collect()

4) 캐시 청소

torch.cuda.empty_cache()

5)

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:<32000>"

 

이 정도 방법들이었는데 한국어로 찾아볼 수 있는 거의 모든 방법을 시도해봤으므로 그냥 모델 용량이 커서 어쩔 수 없나? 하는 생각이 들었다.

 

그러다 Trainer API 내에서 evaluation이 작동하다가 오류가 발생하는 것을 발견하였고, 코랩 환경은 전세계에서 널리 사용하기 때문에 비슷한 문제가 분명히 다른 사람에게도 나타났을 것이라 생각했다.

 

그렇게 열심히 구글링을 한 결과, 아래 자료에서 답을 얻을 수 있었다.

 

https://discuss.huggingface.co/t/cuda-out-of-memory-when-using-trainer-with-compute-metrics/2941/1

 

CUDA out of memory when using Trainer with compute_metrics

Recently, I want to fine-tuning Bart-base with Transformers (version 4.1.1). The fine-tuning process is very smooth with compute_metrics=None in Trainer. However, when I implement a function of computing metrics and offer this function to Trainer, I receiv

discuss.huggingface.co

 

정확히 나랑 같은 상황이었다. Trainer 내에 compute_metrics 함수를 집어넣어 evaluation을 시도하다 나타난 오류였다. 댓글에서도 많은 사람들이 애를 먹고 있었는데 한 분이 해결법을 제시해주셨다.

 

중요한 부분은 preprocess_logis_for_metrics 함수다. 설명은 허깅페이스의 Trainer 설명 부분에 잘 나와있다.

https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Trainer.preprocess_logits_for_metrics

 

Trainer

When using gradient accumulation, one step is counted as one step with backward pass. Therefore, logging, evaluation, save will be conducted every gradient_accumulation_steps * xxx_step training examples.

huggingface.co

간단히 정리하자면, 평가 단계로 넘어가기 전에 compute_metrics 의 결과를 미리 전처리한다는 것이다. 추측하건데 데이터셋의 크기가 크면 logit의 크기도 굉장히 클 수밖에 없는데, preprocess_logis_for_metrics 를 적용하지 않으면 이 모든 logit 을 gpu로 보내버린 뒤 처리하여 gpu의 용량을 많이 잡아먹는 것으로 보인다. 그래서 gpu로 넘기기 전, 가장 값이 큰 logit을 걸러내어 이것과 labels 만 gpu로 보내는 것이다. 이러면 logits의 용량을 아주 크게 줄일 수 있어 gpu의 부담을 덜어주는 것으로 보인다.

 

def preprocess_logits_for_metrics(logits, labels):

    pred_ids = torch.argmax(logits[0], dim=-1)
    return pred_ids, labels
from transformers import Trainer

trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics = evaluate,
    preprocess_logits_for_metrics = preprocess_logits_for_metrics
)

이렇게 코드를 수정한 끝에 오류를 해결할 수 있었다 ! 한글로 이런 자료를 정리해놓으신 분이 없는 것 같아서 짧게라도 공유해본다.