challenging-america-word-ga.../run.ipynb
2023-06-29 12:28:49 +02:00

8.0 KiB
Raw Blame History

Fine tuning GPT-2

Model dotrenowano z wykorzystaniem odwróconego ciągu tokenów i odgadywanych słów (od prawej do lewej), tj. `f"{word} {right_context}".split()[::-1]`, ignorując lewy kontekst.

https://gonito.net/view-variant/9580

import torch

from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import Trainer, TrainingArguments, PreTrainedModel

import lzma

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Bare expression: shown as the cell output when executed in a notebook.
torch.__version__, device

Methods

def reverse_sentence(sentence):
    """Return *sentence* with its whitespace-separated words in reverse order.

    The text is split on any run of whitespace, so leading/trailing
    whitespace and repeated separators are dropped; the reversed words are
    joined back with single spaces.
    """
    words = sentence.split()
    words.reverse()
    return " ".join(words)


def file_iterator(file_path):
    """Lazily yield the lines of a UTF-8 text file, one at a time.

    Files whose name ends with ``.xz`` are decompressed on the fly with
    ``lzma``; any other path is opened as plain text. Lines are yielded as
    ``str`` with their line terminator kept.

    The file is streamed rather than loaded with ``readlines()``, so memory
    use stays bounded by the longest line even for very large inputs.

    Args:
        file_path: Path to the file, as a string.

    Yields:
        Each line of the file, in order.
    """
    if file_path.endswith(".xz"):
        # Binary mode plus an explicit decode keeps b"\n" as the only line
        # separator, exactly as when the whole file was read up front.
        with lzma.open(file_path, mode="rb") as fp:
            for raw_line in fp:
                yield raw_line.decode("utf-8")
    else:
        with open(file_path, "r", encoding="utf-8") as fp:
            yield from fp


def clear_line(line):
    """Normalize one piece of text for training.

    Steps, in order: lowercase the text, replace every occurrence of the
    literal character sequence backslash-backslash-``n`` with a single
    space, then trim newlines, tabs and spaces from both ends.
    """
    # NOTE(review): the pattern below is two backslashes followed by "n". If
    # the corpus encodes line breaks as a single backslash + "n", this
    # replacement never matches — confirm against the data.
    lowered = line.lower()
    without_breaks = lowered.replace("\\\\n", " ")
    return without_breaks.strip("\n\t ")


def prepare_training_data(dir_path):
    """Build the plain-text training file for a dataset directory.

    Reads ``<dir_path>/in.tsv.xz`` and ``<dir_path>/expected.tsv`` in
    lockstep and writes ``<dir_path>/in.txt``. Each output line is the
    cleaned tab-separated column at index 6 of the input row (treated here
    as the left context), a space, and the lowercased, stripped expected
    word.

    Iteration stops at the end of the shorter of the two files.

    Args:
        dir_path: Directory containing ``in.tsv.xz`` and ``expected.tsv``.

    Returns:
        The path of the written ``in.txt`` file.
    """
    new_file_path = dir_path + "/in.txt"
    rows = file_iterator(dir_path + "/in.tsv.xz")
    targets = file_iterator(dir_path + "/expected.tsv")
    with open(new_file_path, "w", encoding="utf-8") as out:
        for target, row in zip(targets, rows):
            columns = row.split("\t")
            left_context = clear_line(columns[6])
            out.write(f"{left_context} {target.lower().strip()}\n")
    return new_file_path


def train(
    dataset,
    model,
    data_collator,
    batch_size,
    epochs,
    output_path,
    overwrite_output_path=False,
    save_steps=10000,
):
    """Fine-tune *model* on *dataset* with the Hugging Face ``Trainer``.

    Args:
        dataset: Training dataset handed to ``Trainer`` as ``train_dataset``.
        model: Model to train.
        data_collator: Collator used to assemble batches.
        batch_size: Per-device training batch size.
        epochs: Number of training epochs.
        output_path: Output directory passed to ``TrainingArguments``.
        overwrite_output_path: Forwarded as ``overwrite_output_dir``.
        save_steps: Step interval used for both checkpointing and logging.

    After training finishes the model is saved with ``Trainer.save_model()``
    (called without an explicit target directory).
    """
    run_config = TrainingArguments(
        output_dir=output_path,
        overwrite_output_dir=overwrite_output_path,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        # A single interval drives both checkpoint saving and logging.
        save_steps=save_steps,
        logging_steps=save_steps,
    )
    fine_tuner = Trainer(
        model=model,
        args=run_config,
        train_dataset=dataset,
        data_collator=data_collator,
    )
    fine_tuner.train()
    fine_tuner.save_model()

Load & prepare data and model

# Notebook shell escape: preview the first five lines of train/in.txt.
# NOTE(review): train/in.txt is the file written by
# prepare_training_data("train") on the next line, so on a fresh checkout
# this preview presumably has to run after it — confirm the cell order.
!cat train/in.txt | head -n 5
# Generate train/in.txt from train/in.tsv.xz and train/expected.tsv.
training_data_path = prepare_training_data("train")
# Name of the pretrained checkpoint loaded by both from_pretrained() calls.
MODEL_NAME = "gpt2"
# Directory receiving the tokenizer, the initial model and the training output.
OUTPUT_PATH = "results"
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
# Store the tokenizer files in the same directory as the model.
tokenizer.save_pretrained(OUTPUT_PATH)

# Tokenized view of the prepared text file, built with block_size=128.
train_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path=training_data_path,
    block_size=128,
)
# mlm=False: masked-language-model batching is switched off.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Load the pretrained weights and move them to the selected device.
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME).to(device)
# Saves the not-yet-fine-tuned model; the later train() call is given the
# same OUTPUT_PATH as its output directory.
model.save_pretrained(OUTPUT_PATH)

Train model

# Fine-tuning hyperparameters for this run.
BATCH_SIZE = 32
EPOCHS = 1

# Fine-tune the loaded model on the prepared dataset; 10000 is the step
# interval train() uses for both checkpointing and logging.
train(
    model=model,
    dataset=train_dataset,
    data_collator=data_collator,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    output_path=OUTPUT_PATH,
    save_steps=10000,
)

Inference

# To run inference from a saved checkpoint instead of the in-memory model:
# model = GPT2LMHeadModel.from_pretrained('results/checkpoint-48000/')

# Number of candidate tokens considered for every input line.
TOP_K = 20
# Upper bound on the probability reported for any single token.
MAX_TOKEN_PROB = 0.7

# The model may still be in training mode after trainer.train(); eval()
# switches dropout off so predictions are deterministic.
model.eval()
# Longest token sequence the model's position embeddings support.
max_positions = model.config.n_positions

# For each split, write one prediction line per input line to <split>/out.tsv.
# Output format: "token:prob token:prob ... :remaining_mass".
for file_path in ("test/in.tsv.xz", "dev/in.tsv.xz"):
    with open(file_path.split("/")[0] + "/out.tsv", "w", encoding="utf-8") as fp:
        for line in file_iterator(file_path):
            line = reverse_sentence(line.lower().strip("\n").replace("\\\\n", " "))
            if not line:
                # Nothing to condition on: assign all mass to the wildcard so
                # the output stays aligned line-for-line with the input.
                fp.write(":1.0\n")
                continue

            inputs = tokenizer.encode(line, return_tensors="pt").to(device)
            # The prediction is read at the last position, so when the input
            # is too long keep the tail, which is closest to that position.
            inputs = inputs[:, -max_positions:]

            # No gradients are needed for inference; skipping the autograd
            # graph saves memory and time.
            with torch.no_grad():
                output = model(inputs)

            # Logits at the last position -> distribution over the next token.
            z_dist = output[0][0][-1]
            prob_dist = torch.softmax(z_dist, dim=0)
            top_k_values, top_k_indices = prob_dist.topk(TOP_K)

            entries = []
            reported_mass = 0.0
            for prob, idx in zip(top_k_values.tolist(), top_k_indices.tolist()):
                token = tokenizer.decode(idx).strip()
                # An empty token would collide with the ":prob" wildcard entry,
                # and inner whitespace or ":" would break the "token:prob"
                # format; such mass is left to the wildcard instead.
                if len(token.split()) != 1 or ":" in token:
                    continue
                prob = min(prob, MAX_TOKEN_PROB)
                # Sum the value actually written (after capping) so the whole
                # line, wildcard included, adds up to 1.
                reported_mass += prob
                entries.append(f"{token}:{prob}")

            # Clamp at zero to absorb floating-point rounding.
            remainder = max(1.0 - reported_mass, 0.0)
            entries.append(f":{remainder}")
            fp.write(" ".join(entries) + "\n")