added fine tunning
This commit is contained in:
parent
59a86d7607
commit
ded9a74adc
1518
dev-0/out.tsv
1518
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
59
fine_tuning.py
Normal file
59
fine_tuning.py
Normal file
@ -0,0 +1,59 @@
|
||||
import random
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
|
||||
)
|
||||
|
||||
class DataWrapper(torch.utils.data.Dataset):
|
||||
def __init__(self, encodings, labels):
|
||||
self.encodings = encodings
|
||||
self.labels = labels
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
|
||||
item['labels'] = torch.tensor(self.labels[idx])
|
||||
return item
|
||||
|
||||
def __len__(self):
|
||||
return len(self.labels)
|
||||
|
||||
def read_data(file_path):
|
||||
with open(file_path) as f:
|
||||
return f.readlines()
|
||||
|
||||
def wirte_output(file_path, data):
|
||||
with open(file_path, 'w') as writer:
|
||||
for result in trainer.predict(data):
|
||||
writer.write(f"{str(result)}\n")
|
||||
|
||||
print("STEP 1 - READ DATA")
|
||||
X_train = read_data('train/in.tsv')
|
||||
y_train = read_data('train/expected.tsv')
|
||||
X_dev = read_data('dev-0/in.tsv')
|
||||
X_test = read_data('test-A/in.tsv')
|
||||
|
||||
print("STEP 2 - SHUFFLE")
|
||||
data_train = list(zip(X_train, y_train))
|
||||
data_train = random.sample(data_train, 15000)
|
||||
|
||||
|
||||
print("STEP 3 - FINE TUNING")
|
||||
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
|
||||
|
||||
train_encodings = tokenizer([text[0] for text in data_train], truncation=True, padding=True)
|
||||
train_dataset = DataWrapper(train_encodings, [int(text[1]) for text in data_train])
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=2)
|
||||
args = TrainingArguments("model")
|
||||
|
||||
device = torch.device("cpu")
|
||||
# device = torch.device("cuda")
|
||||
model.to(device)
|
||||
|
||||
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
|
||||
trainer.train()
|
||||
|
||||
print("STEP 4 - WRITE OUTPUT")
|
||||
wirte_output('train/out.tsv', X_train)
|
||||
wirte_output('dev-0/out.tsv', X_dev)
|
||||
wirte_output('test-A/out.tsv', X_test)
|
5
output_geval_fine.txt
Normal file
5
output_geval_fine.txt
Normal file
@ -0,0 +1,5 @@
|
||||
Likelihood 0.0000
|
||||
Accuracy 0.8253
|
||||
F1.0 0.7472
|
||||
Precision 0.7659
|
||||
Recall 0.7294
|
2744
test-A/out.tsv
2744
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
289579
train/out.tsv
Normal file
289579
train/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user