added fine tunning

This commit is contained in:
Dawid 2021-06-22 10:32:27 +02:00
parent 59a86d7607
commit ded9a74adc
5 changed files with 293146 additions and 759 deletions

File diff suppressed because it is too large Load Diff

59
fine_tuning.py Normal file
View File

@ -0,0 +1,59 @@
import random
import torch
from transformers import (
AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
)
class DataWrapper(torch.utils.data.Dataset):
def __init__(self, encodings, labels):
self.encodings = encodings
self.labels = labels
def __getitem__(self, idx):
item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
item['labels'] = torch.tensor(self.labels[idx])
return item
def __len__(self):
return len(self.labels)
def read_data(file_path):
with open(file_path) as f:
return f.readlines()
def wirte_output(file_path, data):
with open(file_path, 'w') as writer:
for result in trainer.predict(data):
writer.write(f"{str(result)}\n")
print("STEP 1 - READ DATA")
X_train = read_data('train/in.tsv')
y_train = read_data('train/expected.tsv')
X_dev = read_data('dev-0/in.tsv')
X_test = read_data('test-A/in.tsv')
print("STEP 2 - SHUFFLE")
data_train = list(zip(X_train, y_train))
data_train = random.sample(data_train, 15000)
print("STEP 3 - FINE TUNING")
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
train_encodings = tokenizer([text[0] for text in data_train], truncation=True, padding=True)
train_dataset = DataWrapper(train_encodings, [int(text[1]) for text in data_train])
model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=2)
args = TrainingArguments("model")
device = torch.device("cpu")
# device = torch.device("cuda")
model.to(device)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()
print("STEP 4 - WRITE OUTPUT")
wirte_output('train/out.tsv', X_train)
wirte_output('dev-0/out.tsv', X_dev)
wirte_output('test-A/out.tsv', X_test)

5
output_geval_fine.txt Normal file
View File

@ -0,0 +1,5 @@
Likelihood 0.0000
Accuracy 0.8253
F1.0 0.7472
Precision 0.7659
Recall 0.7294

File diff suppressed because it is too large Load Diff

289579
train/out.tsv Normal file

File diff suppressed because it is too large Load Diff