Fine-tuning GPT-2

import torch

from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import Trainer, TrainingArguments

import lzma

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.__version__, device

Methods

def file_iterator(file_path):
    """Yield the lines of a plain-text or xz-compressed file as decoded strings."""
    if file_path.endswith(".xz"):
        with lzma.open(file_path, mode="r") as fp:
            for line in fp:
                yield line.decode("utf-8")
    else:
        with open(file_path, "r", encoding="utf-8") as fp:
            for line in fp:
                yield line


def clear_line(line):
    """Lowercase a line, replace the corpus' escaped newline markers with spaces,
    and strip surrounding whitespace."""
    return line.lower().replace("\\\\n", " ").strip("\n\t ")


def prepare_training_data(dir_path):
    """Build a plain-text training file: the left context of each gap followed by the expected word."""
    data_iter = file_iterator(dir_path + "/in.tsv.xz")
    expected_iter = file_iterator(dir_path + "/expected.tsv")
    new_file_path = dir_path + "/in.txt"
    with open(new_file_path, "w", encoding="utf-8") as fp:
        for word, line in zip(expected_iter, data_iter):
            left_context = clear_line(line.split("\t")[6])  # column 6 holds the left context
            text = left_context + " " + word.lower().strip() + "\n"
            fp.write(text)
    return new_file_path
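
A quick, optional sanity check (a sketch, not part of the original pipeline): assuming a train/ directory laid out with in.tsv.xz and expected.tsv as above, it builds the training file and previews the first few prepared lines.

# Optional: preview what prepare_training_data() writes (hypothetical quick check).
preview_path = prepare_training_data("train")   # creates train/in.txt
for i, line in enumerate(file_iterator(preview_path)):
    print(repr(line))          # left context + expected word, one example per line
    if i >= 2:
        break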


def train(
    dataset,
    model,
    data_collator,
    batch_size,
    epochs,
    output_path,
    overwrite_output_path=False,
    save_steps=10000,
):
    """Fine-tune `model` on `dataset` with the Hugging Face Trainer and save the result to output_path."""
    training_args = TrainingArguments(
        output_dir=output_path,
        overwrite_output_dir=overwrite_output_path,
        per_device_train_batch_size=batch_size,
        num_train_epochs=epochs,
        logging_steps=save_steps,
        save_steps=save_steps,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=dataset,
    )
    trainer.train()
    trainer.save_model()

Load & prepare data and model

training_data_path = prepare_training_data("train")
MODEL_NAME = "gpt2"
OUTPUT_PATH = "results"
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
tokenizer.save_pretrained(OUTPUT_PATH)

# TextDataset splits the training file into 128-token blocks; the collator
# builds causal-LM batches (mlm=False: the labels are the input ids themselves).
train_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path=training_data_path,
    block_size=128,
)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

model = GPT2LMHeadModel.from_pretrained(MODEL_NAME).to(device)
model.save_pretrained(OUTPUT_PATH)
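
To see what the model will actually be trained on, it can help to decode one 128-token block from the dataset and collate a small batch (a sketch; the field names are the standard DataCollatorForLanguageModeling outputs).

# Optional: inspect one training block and a collated mini-batch.
sample = train_dataset[0]                     # LongTensor of block_size token ids
print(tokenizer.decode(sample))               # the block as plain text
batch = data_collator([train_dataset[0], train_dataset[1]])
print(batch["input_ids"].shape, batch["labels"].shape)   # labels mirror input_ids for causal LM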

Train model

EPOCHS = 1
BATCH_SIZE = 32
train(
    dataset=train_dataset,
    model=model,
    data_collator=data_collator,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    output_path=OUTPUT_PATH,
    save_steps=10000
)
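
If inference runs in a fresh session, the fine-tuned weights and tokenizer saved above can be reloaded from the output directory first (a sketch assuming the OUTPUT_PATH used in this notebook; redundant when everything runs in one session).

# Reload the model/tokenizer written by save_pretrained()/save_model() above.
tokenizer = GPT2Tokenizer.from_pretrained(OUTPUT_PATH)
model = GPT2LMHeadModel.from_pretrained(OUTPUT_PATH).to(device)
model.eval()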

Inference

model.eval()
for file_path, lines_no in (("test-A/in.tsv.xz", 7414), ("dev-0/in.tsv.xz", 10519)):
    with open(file_path.split("/")[0] + "/out.tsv", "w", encoding="utf-8") as fp:
        print(f'Working on file: {file_path}...')
        i = 1
        missed_lines = []
        for line in file_iterator(file_path):
            print(f'\r\t{100.0*i/lines_no:.2f}% ({i}/{lines_no})', end='')
            line = clear_line(line.split("\t")[6])  # left context of the gap
            inputs = tokenizer.encode(line, return_tensors="pt").to(device)
            inputs = inputs[:, -1024:]  # GPT-2 attends to at most 1024 tokens
            with torch.no_grad():
                output = model(inputs)

            # Next-token distribution over the vocabulary at the last position.
            z_dist = output.logits[0, -1]
            prob_dist = torch.softmax(z_dist, dim=0)
            top_k_values, top_k_indices = prob_dist.topk(20)

            # Cap single-word probabilities at 0.7 and assign the remaining mass
            # to the unknown word (the trailing ":<rest>" entry).
            result = [
                (tokenizer.decode(idx).strip(), min(prob.item(), 0.7))
                for prob, idx in zip(top_k_values, top_k_indices)
            ]
            rest = max(1.0 - sum(prob for _, prob in result), 0.0)
            result = " ".join(f"{word}:{prob}" for word, prob in result) + f" :{rest}\n"
            if len(result) < 250:
                # Suspiciously short prediction line: fall back to a fixed distribution of frequent words.
                missed_lines.append(i)
                result = "the:0.5175086259841919 and:0.12364283204078674 ,:0.05142376944422722 of:0.03426751121878624 .:0.028525719419121742 or:0.02097073383629322 :0.014924607239663601 every:0.008976494893431664 each:0.008128014393150806 a:0.007482781074941158 ;:0.005168373696506023 -:0.004823171999305487 holy:0.004624966997653246 one:0.004140088334679604 tho:0.003332334803417325 only:0.0030411879997700453 that:0.002834469312801957 !:0.0022952412255108356 ):0.002251386409625411 t:0.0021530792582780123 :0.14948463439941406\n"
            fp.write(result)
            i += 1 
        print("\t...processing finished\n\tMissed lines:", missed_lines)