0. 요약
그냥 제가 같은 작업 자꾸자꾸 반복하는 거 좀 체계화/단일화 하기 위해서 순서 정리하는 글이에요. "다른 사람은 어떻게 하나"를 알기가 힘들고, 특히 저와 같이 한국에서 명문대를 안 나온 사람일수록 다른 사람 연구 방법론을 귀동냥하기도 힘들어요 (오히려 그런 사람들이 더 많이 귀동냥을 해야 하는데도 말이죠). 그래서 이 글은 그런 사람들한테 도움이 되었으면 하는 바람에서 씁니다.
목차
1. 단계들
1단계: 시작하기
일단 찬물을 마셔서 정신을 차리고, 모든 걸 담을 수 있는(?) 폴더를 하나 만든다. 사람이 준비 안 되어 있으면 기계는 일하지 않는다. 폴더에는 일단은 raw data만 담아둔다.
raw data라고 함은, 입력과 출력 쌍을 말한다. 예를 들어 한국어의 위치동화를 학습하는 모델을 만들고자 한다면, 입력 데이터는 "눈물, 국물, 한글 ...." 이 될 것이고 출력 데이터는 "눔물, 궁물, 항글 ..."이 될 것이다. 물론 음운론 모델은 소리의 추상적 표상형태를 학습의 대상으로 삼기 때문에 IPA나 기타 전사기호를 사용하고 한글을 사용하지 않는다.
2단계: 모델 훈련하기
Google colab이든 Amazon AWS든 Microsoft Azure든 본인 취향에 맞는 cloud computing resources 에서 모델을 훈련한다. 나는 아직 모델 훈련 전단계가 익숙하지 않기 때문에 raw data 읽어들이거나 preprocessing할 때 중간중간 printout 할 수 있게 Jupyter notebook을 선호한다. 예시: 1https://colab.research.google.com/drive/1nu6lU1qNl1OjUSMuUJKMKnVuS_L35Sxz?usp=sharing (혹은 더 원천적인?? 예시는 https://colab.research.google.com/drive/1N6SncVXUe8dtdQm3EzjStOApkz3Wq0ma
이때 best performing checkpoint 10개? 20개? 정도만 뽑아낸다. Fairseq은 validation accuracy에 따른 evaluation 제공하지 않는데 loss에 따르는 것이 그나마 만족할 만하다. 하지만 궁극적으로는 valid accuracy로 evaluate하고 최종모델 결정해야 하기 때문에, 일단 loss 기준으로 best performing 10개 정도 뽑아두고 그것만을 대상으로 accuracy 계산한다.
모델 훈련이 잘 되었나, overfit이나 underfit이 발생하지 않았나, 등등 훈련 과정을 기술하려면 loss curve를 그려야 한다.
뭔가 자동으로 해주는 솔루션이 있긴 할텐데, 어짜피 fairseq에서 export해주는 logs에 모든 정보가 담겨있기 때문에 내 입맛에 맞게 아래 스크립트로 loss curve를 그렸다. 이 스크립트에는 최종 x가 epoch만 대상으로 loss curve 그려주는 기능이 있다. 또 loss 기준 최고의 epoch 20개 출력하기도 하는데, 이건 사실상 쓸모없다. 왜냐하면 training command 자체에 최고 epoch 개수 parameter가 있기 때문이다.
# quickly visualize validation and training loss from fairseq training log
import json
from collections import OrderedDict
import matplotlib.pyplot as plt
def parse_json(line, key_to_find):
json_data = json.loads(line)
return json_data.get(key_to_find)
def parse_file(logs: str) -> tuple:
n = 1
logs = logs.split('\n')
valid_losses = []
train_losses = []
valid_best = {}
flag = [False, False] # train_loss and valid_loss info duplicate. flag to check getting info. if true, ignore subsequent duplicates
for line in logs:
if f'"epoch": {n}, "valid_loss":' in line:
if not flag[0]:
valid_loss = parse_json(line, 'valid_loss')
valid_loss = float(valid_loss)
valid_losses.append(valid_loss)
this_epoch_val_best = parse_json(line, 'valid_best_loss')
try:
min_valid_best = min(valid_best.items(), key=lambda x: x[1])
except ValueError:
# ValueError when valid_best is initialized and empty
min_valid_best = {1: 1}
if this_epoch_val_best is not None:
this_epoch_val_best = float(this_epoch_val_best)
if this_epoch_val_best < min_valid_best[1]:
valid_best[n] = this_epoch_val_best
flag[0] = True
elif f'"epoch": {n}, "train_loss":' in line:
if not flag[1]:
train_loss = parse_json(line, 'train_loss')
train_loss = float(train_loss)
train_losses.append(train_loss)
flag[1] = True
if all(flag):
flag = [False, False]
n += 1
return train_losses, valid_losses, valid_best
def show(total_n, train_loss: list, valid_loss: list) -> None:
index = range(total_n-len(train_loss), total_n)
plt.plot(index, train_loss, label='training loss')
plt.plot(index, valid_loss, label='validation loss')
# Adding labels and legend
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
# Display the plot
plt.show()
def main():
file_path = '[your path]'
# Display the selected file path
if file_path:
with open(file_path, 'r') as file:
logs = file.read()
train_losses, valid_losses, valid_best = parse_file(logs)
print('Epoch\ttrain_loss\tvalid_loss')
for i, losses in enumerate(zip(train_losses, valid_losses), 1):
print(f'Epoch {i}:\t {losses[0]}\t{losses[1]}')
print('\n Best 20 checkpoints')
print('Epoch\tvalid_loss')
valid_best_twenty = sorted(valid_best.items(), key=lambda x: x[1])[:20]
for epoch_n, val_loss in valid_best_twenty:
print(f'Epoch {epoch_n}:\t{val_loss}')
while True:
trunc = input(f"Need trunc (# of epoch: {len(train_losses)})\n 'Q' to end? ")
if trunc.lower() == 'q':
break
elif trunc != '':
try:
num_trunc = int(trunc)
train_losses_to_plot = train_losses[-num_trunc:]
valid_losses_to_plot = valid_losses[-num_trunc:]
show(len(train_losses), train_losses_to_plot, valid_losses_to_plot)
except ValueError:
print("Please enter a number")
else:
print("No file selected")
if __name__ == '__main__':
main()
loss curves와 hyperparameter setting 하는 것들은 기술적인 부분이고, 컴퓨터과학 컴퓨터공학 전공자 분들이 이 부분에 대해서는 여기저기에서 더 자세히 설명하신다. NLP를 포함한 머신러닝 전반의 분위기는 기본적인 테크닉과 방법론을 숨기지 않고 적극적으로 공유하는 것이다. 튜토리얼들도 여기저기에 참 많이 있다. 그러니까 적극적으로 찾아보면 배울 수 있다.
3단계: validation accuracy 구하기
Computing resource는 비싸기 때문에 로컬로 valid accuracy 연산을 한다.
돈 얘기가 나온김에, 되도록이면 cloud computing resource 사용 최소한으로 하기 위해 아예 전처리까지 다 로컬로 하는 것도 고려할 만하다. 그러나 training 자체를 로컬로 돌리기에는 너무 시간과 노력이 아깝다. (물론 노트북이나 GPU 빵빵한 cloud server나 튜링완전 측면에선 동등하고 속도차이만 있을 뿐이므로 궁극적으로 결과는 동일하게 나온다. 결국 지갑의 화폐를 지불하느냐 시간을 지불하느냐의 문제다.) 2
어쨌든, 모델이랑 evaluation data 읽어들인 후 model.translate()
돌린다음 accuracy metric 계산하는 적절한 Python script 짜서 evaluation accuracy 구하면 된다. 아래는 예시. 내가 맨날 헷갈려서 코멘트를 잘 붙여놨네. 과거의 나를 칭찬하는 바이다.
import os
import csv
from fairseq.models.transformer import TransformerModel
DATA_BIN = os.path.join(CWD, 'bin')
MODEL = os.path.join(CWD, 'model_output_transformer')
VAL_ITEMS = os.path.join(CWD, 'dev.ur-sr.ur')
VAL_GOLD = os.path.join(CWD, 'dev.ur-sr.sr')
EXPORT_PATH = os.path.join(CWD, 'validation_accuracy.csv')
def accuracy(predictions, targets):
# Ensure that the number of predictions and targets match
if len(predictions) != len(targets):
raise ValueError(f"Number of predictions and targets must be the same. "
f"\nTargets #: {len(targets)} \nPredictions #: {len(predictions)}")
# Count the number of correct predictions
correct_predictions = sum(p == t for p, t in zip(predictions, targets))
# Calculate accuracy as a percentage
accuracy = (correct_predictions / len(targets)) * 100.0
return accuracy
def export_csv(list_tuples, predictions=False, export_path=EXPORT_PATH):
# list_tuples: list of tuples
# predictions: Bool. True when exporting predictions of an epoch. Otherwise False
dir = os.path.dirname(export_path)
if not os.path.exists(dir):
os.makedirs(dir)
print(f"Created {dir}")
with open(export_path, 'w', newline='') as file:
writer = csv.writer(file)
header = ['epoch', 'accuracy']
if predictions:
header = ['prediction', 'target']
writer.writerow(header)
writer.writerows(list_tuples) # write data
print(f"Done exporting {export_path}")
def main(MODEL=MODEL, export_translation=False):
checkpoints = os.listdir(MODEL) # get the list of .pt files in the model directory
r = [] # container for (checkpoint_name, accuracy_score)
# load the validation items
with open(VAL_ITEMS, 'r') as file:
items = [line.strip() for line in file]
# load the gold standard
with open(VAL_GOLD, 'r') as file:
targets = [line.strip() for line in file]
for checkpoint_path in checkpoints:
checkpoint, _ = os.path.splitext(checkpoint_path)
ur2sr = TransformerModel.from_pretrained(
MODEL,
checkpoint_file=checkpoint_path,
data_name_or_path=DATA_BIN
)
predictions = []
for item in items:
predictions.append(ur2sr.translate(item))
pred_target = list(zip(predictions, targets)) # tuple
export_csv(pred_target, predictions=True, export_path=predictions_path)
accuracy_score = accuracy(predictions, targets)
r.append((checkpoint, accuracy_score))
export_csv(r)
if __name__ == '__main__':
main()
4단계: best model의 attention 뽑아내기
3단계 결과를 가지고 최고의 성능을 보여준 모델을 선정한다. 이 모델은 특정 관심현상이 아니라 음운부 그자체를 대표(represent)하는 언어모델이다.
자 이제 그 모델한테 "왜! 이 규칙을 어떠어떠하게 적용했나요?" 라고 물어본다.......가 가능할리가. 기계한테 자연 언어로 물어보면 대답을 안 해준다. 한국어든 영어든 마찬가지다. 사실 쉽게 쉽게 말로 해서는 무엇도 얻을 수 없다. 그래서 모델을 고문해야한다.
고문하는 방식은 다른 글에서 소개했다.
5단계: attention에 대한 통계적 분석
개별 단어가 아니라 많은 단어에서 동일한 attention pattern을 보여야 음운론적으로 유의미하다. 그니까 통계처리해야한다.
예를 들어 '학문', '거짓말', '밥물'을 [항문], [거진말], [밤물] 등으로 발음하는데, 이렇게 음절말음 비음화를 할 때, 후행자음이 영향을 준다. 신경망을 이용한 연구 방법을 이용해서, 정말로 이처럼 선행연구(그리고 일반적 상식)처럼 후행자음의 비음성이 영향을 주는 건지 확인하려면, 모델의 cross-attention이 학문의 ㅁ, 거짓말의 ㅁ, 그리고 밥물의 ㅁ에 집중되어 있다는 걸 보여야 한다.
2. 결론
딱히 결론이 없다. 당연하지만 Attention is all you need 이다.
그러니까 attention에 집중하자.
그러니까 뉴진스 노래가 생각난다. You got me looking for attention!
그러니까 결론은 뉴진스 화이팅
- 아래에 댓글창이 열려있습니다. 로그인 없이도 댓글 다실 수 있습니다.
- 글과 관련된 것, 혹은 글을 읽고 궁금한 것이라면 무엇이든 댓글을 달아주세요.
- 반박이나 오류 수정을 특히 환영합니다.
- 로그인 없이 비밀글을 다시면, 거기에 답변이 달려도 보실 수 없습니다. 답변을 받기 원하시는 이메일 주소 등을 비밀글로 남겨주시면 이메일로 답변드리겠습니다.
'Bouncing ideas 생각 작업실 > exp sharing 경험.실험 나누기' 카테고리의 다른 글
ChatGPT는 Praat Script 짤줄 몰라 (0) | 2024.05.27 |
---|---|
음성 데이터만 있어도 분석해버리기 (0) | 2024.05.14 |
fairseq translation task cross-attention 접근 쉽게하기 (0) | 2024.04.10 |
Never assume anything (0) | 2024.04.02 |
Python으로 textgrid 생성했는데 왜 먹지를 못하니 (0) | 2024.03.22 |