Fine-tuning GPT-2
import torch
from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import Trainer, TrainingArguments
import lzma
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.__version__, device
Methods
def file_iterator(file_path):
    """Yield lines from a plain-text or xz-compressed file, lazily."""
    if file_path.endswith(".xz"):
        with lzma.open(file_path, mode="rt", encoding="utf-8") as fp:
            for line in fp:
                yield line
    else:
        with open(file_path, "r", encoding="utf-8") as fp:
            for line in fp:
                yield line
def clear_line(line):
    # The corpus escapes real newlines as a literal backslash-n sequence;
    # replace those with spaces, then trim surrounding whitespace.
    return line.lower().replace("\\n", " ").strip("\n\t ")
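As a quick sanity check, the helper can be exercised on a made-up input line (the sample below is illustrative, not taken from the corpus):

# Illustrative only: the corpus stores newlines as a literal backslash-n sequence.
sample = "The Quick\\nBrown Fox\t\n"
print(clear_line(sample))  # -> the quick brown fox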
def prepare_training_data(dir_path):
    """Build a plain-text training file: each line is a left context followed by the expected word."""
    data_iter = file_iterator(dir_path + "/in.tsv.xz")
    expected_iter = file_iterator(dir_path + "/expected.tsv")
    new_file_path = dir_path + "/in.txt"
    with open(new_file_path, "w", encoding="utf-8") as fp:
        for word, line in zip(expected_iter, data_iter):
            left_context = clear_line(line.split("\t")[6])  # column 7 holds the left context
            text = left_context + " " + word.lower().strip() + "\n"
            fp.write(text)
    return new_file_path
def train(
    dataset,
    model,
    data_collator,
    batch_size,
    epochs,
    output_path,
    overwrite_output_path=False,
    save_steps=10000,
):
    training_args = TrainingArguments(
        output_dir=output_path,
        overwrite_output_dir=overwrite_output_path,
        per_device_train_batch_size=batch_size,
        num_train_epochs=epochs,
        logging_steps=save_steps,
        save_steps=save_steps,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=dataset,
    )
    trainer.train()
    trainer.save_model()
Load & prepare data and model
training_data_path = prepare_training_data("train")
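A quick way to eyeball the generated file is to print its first few lines (a minimal sketch):

# Peek at the first training examples to verify the context/word concatenation.
with open(training_data_path, encoding="utf-8") as fp:
    for _ in range(3):
        print(repr(fp.readline()))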
MODEL_NAME = "gpt2"
OUTPUT_PATH = "results"
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
tokenizer.save_pretrained(OUTPUT_PATH)
# TextDataset is deprecated in recent transformers releases, but it still
# works for this simple block-wise language-modeling setup.
train_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path=training_data_path,
    block_size=128,
)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME).to(device)
model.save_pretrained(OUTPUT_PATH)
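Optionally, decoding the first 128-token block back to text confirms the dataset was tokenized as intended (a sketch; TextDataset returns each block as a tensor of token ids):

# Sanity check: round-trip the first block through the tokenizer.
print(tokenizer.decode(train_dataset[0]))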
Train model
EPOCHS = 1
BATCH_SIZE = 32
train(
    dataset=train_dataset,
    model=model,
    data_collator=data_collator,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    output_path=OUTPUT_PATH,
    save_steps=10000,
)
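Before running full inference, it can be useful to smoke-test the fine-tuned model with greedy generation. A minimal sketch (the prompt is made up, and max_new_tokens assumes a reasonably recent transformers version):

# Hypothetical prompt, used only to eyeball the model's continuations.
model.eval()
prompt = "the president of the united states"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    generated = model.generate(
        input_ids,
        max_new_tokens=10,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,  # silence the missing-pad-token warning
    )
print(tokenizer.decode(generated[0]))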
Inference
model.eval()

for file_path, lines_no in (("test-A/in.tsv.xz", 7414), ("dev-0/in.tsv.xz", 10519)):
    with open(file_path.split("/")[0] + "/out.tsv", "w", encoding="utf-8") as fp:
        print(f"Working on file: {file_path}...")
        i = 1
        missed_lines = []
        for line in file_iterator(file_path):
            print(f"\r\t{100.0 * i / lines_no:.2f}% ({i}/{lines_no})", end="")
            line = clear_line(line.split("\t")[6])
            inputs = tokenizer.encode(line, return_tensors="pt").to(device)
            # GPT-2's context window is 1024 tokens; keep only the rightmost part of long contexts.
            inputs = inputs[:, -1024:]
            with torch.no_grad():
                output = model(inputs)
            # Logits at the last position give the distribution over the next token.
            z_dist = output[0][0][-1]
            prob_dist = torch.softmax(z_dist, dim=0)
            top_k_values, top_k_indices = prob_dist.topk(20)
            # Cap each probability at 0.7 so no single token dominates; summing the
            # capped values keeps the leftover mass below from going negative.
            probs = []
            pairs = []
            for prob, idx in zip(top_k_values.tolist(), top_k_indices.tolist()):
                prob = min(prob, 0.7)
                probs.append(prob)
                pairs.append((tokenizer.decode(idx).strip(), prob))
            result = (
                " ".join(f"{token}:{prob}" for token, prob in pairs)
                + f" :{1.0 - sum(probs)}\n"
            )
            if len(result) < 250:
                # Suspiciously short output: fall back to a fixed default distribution.
                missed_lines.append(i)
                result = "the:0.5175086259841919 and:0.12364283204078674 ,:0.05142376944422722 of:0.03426751121878624 .:0.028525719419121742 or:0.02097073383629322 :0.014924607239663601 every:0.008976494893431664 each:0.008128014393150806 a:0.007482781074941158 ;:0.005168373696506023 -:0.004823171999305487 holy:0.004624966997653246 one:0.004140088334679604 tho:0.003332334803417325 only:0.0030411879997700453 that:0.002834469312801957 !:0.0022952412255108356 ):0.002251386409625411 t:0.0021530792582780123 :0.14948463439941406\n"
            fp.write(result)
            i += 1
        print("\t...processing finished\n\tMissed lines:", missed_lines)
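Since both the tokenizer and the model were saved to OUTPUT_PATH earlier, a later session can reload them directly (a sketch):

# Reload the fine-tuned artifacts written by save_pretrained()/save_model().
tokenizer = GPT2Tokenizer.from_pretrained(OUTPUT_PATH)
model = GPT2LMHeadModel.from_pretrained(OUTPUT_PATH).to(device)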