서론
프로젝트를 진행하면서 구글의 BERT를 사용하여 쇼핑 리뷰의 감성을 이진 분류하는 작업을 하고 있었다.
https://yonghee.io/bert_binary_classification_naver/
작업에 굉장히 많은 도움을 받은 코드이다. 영화 리뷰데이터를 쇼핑 리뷰데이터로 바꾸고 필요한 부분만 가져다가 쓰면서 열심히 작업을 하였고, 결과까지 원활하게 도출해낼 수 있었다. 그런데 가장 중요한 부분에서 내가 무언가 착각을 하였는지 내 생각과는 다르게 전개되었다.
모델을 훈련하고 테스트한 결과 0.94의 정확도를 보였으나, 긍정-부정으로 분류한 결과를 테스트셋에 병합을 하고, 알맞게 분류가 되었는지 확인했더니 아래와 같은 결과가 나온 것이다.
총 4만개의 테스트데이터 중 무려 20077개가 오류를 범했음을 보인 것이다. Accuracy는 0.94로 높게 나오는데 실제로는 왜 절반 가량이 틀리게 확인되는지, Accuracy를 측정한 flat_accuracy 함수 (위 블로그의 원작자님이 직접 정의한 함수다.) 도 뜯어보고, tolist()나 flatten() 에서 잘못 된건가 싶어서 이것저것 해보다가 정답을 찾을 수 있었다.
test_data = TensorDataset(test_inputs, test_masks, test_labels)
test_sampler = RandomSampler(test_data)
test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=BATCH_SIZE)
BERT 에 사용할 수 있도록 데이터를 처리하는 과정이다. 여기서 test_sampler 를 무의식적으로 RandomSampler를 사용하였는데, 이거 때문에 DataLoader에서 Batch의 순서가 뒤죽박죽이 되어버렸고, 이 뒤죽박죽이 된 순서대로 테스트셋에 이어붙이니 절반 가량밖에 맞추지 못한 것이었다.
BATCH_SIZE = 32
train_data = TensorDataset(train_inputs, train_masks, train_labels)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=BATCH_SIZE)
validation_data = TensorDataset(validation_inputs, validation_masks, validation_labels)
validation_sampler = SequentialSampler(validation_data)
validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=BATCH_SIZE)
코드를 조금 더 꼼꼼하게 봤다면 여기서 validation_sampler에 대해서 SequentialSampler로 선언한 것을 눈치를 챘겠으나... 별 생각 없이 지나친 내 실수였다.
그래서 이번 기회에 Pytorch의 DataLoader에서 쓰이는 여러 Sampler 들에 대해 간단하게 정리해보고자 한다.
RandomSampler
from torch.utils.data import RandomSampler
위와 같은 형식으로 불러올 수 있다. 이하 샘플러들도 모두 동일한 방식으로 로드할 수 있으므로 생략하도록 한다.
RandomSampler는 그 이름에서 알 수 있듯이, 데이터셋에서 데이터들을 랜덤하게 뽑아 배치로 만들게 한다. 배치로 만드는 과정은 DataLoader 함수가 해준다. 중요한 점은 Replacement가 False로 디폴트값으로 설정되어 있어 이미 뽑힌 데이터는 사용하지 않는다는 점이다. 그래서 중복을 원한다면 Replacement 파라미터를 True로 설정해주어야 한다. 그리고 Random하게 섞는 과정은 1epoch 이 진행될 때마다 진행된다고 한다.
그렇다면 RandomSampler를 사용하는 이유는 뭘까? 여러 레퍼런스를 찾아본 결과 데이터의 순서에 따라 학습 과정에서 편향이 생길 수도 있기 때문이라는 결론을 얻었다. 그래서 많은 분들이 Train 과정에선 RandomSampler, Validation이나 Test 과정에선 SequentialSampler를 사용하는 것으로 보인다.
SequentialSampler
내가 문제를 해결할 수 있었던 것이 바로 Test DataLoader에 대해 RandomSampler를 SequentialSampler로 바꾸었기 때문이다. 이는 Sequential이라는 의미 답게 데이터셋에서 순서대로 데이터를 뽑아 DataLoader에서 배치로 만든다. 나같이 모델의 결과로 데이터셋에 작업을 해줘야 할 땐 Random보다 SequentialSampler를 사용해줘야 할 것이다.
여기서부턴 Pytorch의 레퍼런스 문서를 읽고 찾아본 Sampler들이다. 정확하게 알고 싶으신 분들은 공식 문서를 읽어보시면 될 것 같다.
https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler
SubsetRandomSampler
subset_indices = [2, 5, 7, 10, ...] # Specify the indices you want to include
subset_random_sampler = SubsetRandomSampler(subset_indices)
data_loader = DataLoader(dataset, batch_size=batch_size, sampler=subset_random_sampler)
위와같이 로드할 수 있으며, 데이터셋에서 특정 순서가 중요할 때 사용하는 Sampler다. 사실 왜 필요한지는 잘 모르겠다. 홀수 번째의 데이터, 짝수 번째의 데이터만 사용하고 싶은 경우가 있으면 유용할 것 같다. 예시 코드는 ChatGPT를 통해 생성했다.
WeightedRandomSampler
weights = [0.1, 0.5, 0.3, ...] # Specify the weights for each sample
weighted_random_sampler = WeightedRandomSampler(weights, num_samples=len(weights))
data_loader = DataLoader(dataset, batch_size=batch_size, sampler=weighted_random_sampler)
데이터의 순서에 가중치를 주고 싶을 때 사용하는 Sampler다. 이 역시 언제 사용하는 지는 잘 모르겠다. 다만 가중치를 설정해놓으면, 가중치가 큰 쪽을 우선으로 데이터를 뽑아낸다. 가중치의 합은 1일 필요도 없고, replacement 역시 설정할 수 있다.
BatchSampler
custom_batch_sampler = BatchSampler(sampler, batch_size=batch_size, drop_last=False)
data_loader = DataLoader(dataset, batch_sampler=custom_batch_sampler)
BatchSampler 는 특이하게 애당초 Batch 형식으로 데이터를 뽑아내주는 Sampler다. 따라서 batch_size를 DataLoader에서 선언하지 않은 것을 볼 수 있다. drop_last를 True로 하게 되면, 데이터셋의 크기가 10이고 batch_size가 3인 경우, 1이 남게 되는데 이를 drop하느냐 마냐를 설정하는 것이다.
이렇게 PyTorch의 여러 Sampler들을 정리해보았다.
'딥러닝' 카테고리의 다른 글
[딥러닝] 역전파를 단 한 줄로 가능하게 해주는 backward() 함수 탐구 (1) | 2024.11.20 |
---|---|
[딥러닝] 생성형 AI (LLM) 에서 Loss는 어떻게 계산될까 (0) | 2024.06.22 |
[딥러닝] 활성화 함수 정리 (ReLU, softmax) (0) | 2024.05.05 |
Colab 에서 cuda error: device-side assert triggered 등 CUDA error 해결 (huggingface 관련) (0) | 2023.09.30 |
Trainer API 에서 compute_metrics 사용할 때 CUDA out of memory 해결법 (0) | 2023.09.05 |