ireland-news-headlines/roberta_no_year/03_train.py

77 lines
2.1 KiB
Python
Raw Permalink Normal View History

2021-07-11 08:34:28 +02:00
import pickle
from config import LABELS_LIST, MODEL
def _load_pickle(path):
    """Return the object unpickled from the file at *path*.

    Only use this on files produced by this pipeline itself: unpickling can
    execute arbitrary code, so it must never be pointed at untrusted data.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Pre-built datasets, presumably written by an earlier preprocessing step.
# The small eval split is used for periodic evaluation during training; the
# full eval split and the test split are used for the final predictions.
train_dataset = _load_pickle('train_dataset.pickle')
eval_dataset_small = _load_pickle('eval_dataset_small.pickle')
eval_dataset_full = _load_pickle('eval_dataset_full.pickle')
test_dataset = _load_pickle('test_dataset.pickle')
from transformers import AutoModelForSequenceClassification

# Size the classification head from the label list, so the number of output
# classes always matches the LABELS_LIST[pred] lookup used when predictions are
# written out (a hard-coded count can silently drift from the config).
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL,
    num_labels=len(LABELS_LIST),
)
from transformers import TrainingArguments

# The first positional argument is the output directory where checkpoints are saved.
training_args = TrainingArguments(
    "test_trainer",
    # Optimisation: 4 examples per device, gradients accumulated over 2 steps
    # (an effective train batch of 8 per device), for a single epoch.
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,
    num_train_epochs=1,
    learning_rate=1e-6,
    warmup_steps=4,  # earlier runs used 4_000
    # Evaluate and checkpoint on the same 20k-step schedule (earlier runs used
    # 2_000); load_best_model_at_end needs the two schedules to line up.
    # NOTE(review): newer transformers releases rename `evaluation_strategy`
    # to `eval_strategy`; confirm against the installed version.
    evaluation_strategy='steps',
    eval_steps=20_000,
    save_steps=20_000,
    load_best_model_at_end=True,
)
import numpy as np


def compute_metrics(eval_pred):
    """Return classification accuracy for one ``Trainer`` evaluation pass.

    Args:
        eval_pred: ``(logits, labels)`` pair supplied by ``Trainer``. The
            arg-max is taken over the last axis of ``logits``, presumably
            ``(n_examples, n_classes)``.

    Returns:
        ``{"accuracy": fraction of examples whose arg-max class equals the
        label}``. This is the same key and value that the ``datasets``
        "accuracy" metric returned.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    # Computed directly with NumPy: ``datasets.load_metric`` was deprecated and
    # later removed from ``datasets``, and it also had to fetch a metric script.
    accuracy = np.mean(predictions == np.asarray(labels))
    return {"accuracy": float(accuracy)}
from transformers import Trainer

# Periodic in-training evaluation runs on the small eval split (presumably to
# keep the 20k-step evaluations cheap) and reports accuracy via compute_metrics.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset_small,
)
# Fine-tune starting from the pretrained weights. To continue an interrupted
# run from the latest checkpoint in the output directory instead, call
# trainer.train(resume_from_checkpoint=True).
trainer.train()
trainer.save_model("./roberta-retrained")

# Final evaluation on the trainer's eval_dataset. With the default tqdm
# progress callback, metrics logged after training has ended are typically not
# echoed to the console, so print the returned metrics instead of discarding them.
eval_metrics = trainer.evaluate()
print(eval_metrics)
def _write_predictions(predictions, path):
    """Write one label name per line to *path*, mapping each class id through LABELS_LIST."""
    # Explicit encoding, so the output does not depend on the platform's locale.
    with open(path, 'w', encoding='utf-8') as f_out:
        f_out.writelines(LABELS_LIST[pred] + '\n' for pred in predictions)


# Arg-max over axis 1 of the prediction outputs (presumably the per-class
# logits) gives one predicted class id per example.
eval_predictions = trainer.predict(eval_dataset_full).predictions.argmax(1)
_write_predictions(eval_predictions, '../dev-0/out.tsv')

test_predictions = trainer.predict(test_dataset).predictions.argmax(1)
_write_predictions(test_predictions, '../test-A/out.tsv')