60 lines
1.8 KiB
Python
60 lines
1.8 KiB
Python
|
import random
|
||
|
import torch
|
||
|
from transformers import (
|
||
|
AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
|
||
|
)
|
||
|
|
||
|
class DataWrapper(torch.utils.data.Dataset):
    """Map-style torch Dataset over tokenizer output and optional integer labels.

    Args:
        encodings: Mapping from feature name to a per-example sequence, such as
            the object returned by a Hugging Face tokenizer called on a list of
            texts. Each value is indexed by example position.
        labels: Per-example labels, or ``None`` for unlabelled (inference) data.
    """

    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return example *idx* as a dict of tensors, adding ``labels`` when known."""
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        if self.labels is not None:
            item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Return the number of examples in the dataset."""
        if self.labels is not None:
            return len(self.labels)
        # Unlabelled data: assumes every encoding column has one entry per example.
        first_column = next(iter(self.encodings.values()), None)
        return 0 if first_column is None else len(first_column)
|
||
|
|
||
|
def read_data(file_path, encoding="utf-8"):
    """Return all lines of a text file, each keeping its trailing newline.

    Args:
        file_path: Path of the file to read.
        encoding: Text encoding of the file. Defaults to UTF-8 so the result
            does not depend on the platform's locale default.

    Returns:
        list[str]: The file's lines exactly as ``readlines()`` yields them.
    """
    with open(file_path, encoding=encoding) as f:
        return f.readlines()
|
||
|
|
||
|
def wirte_output(file_path, data, model_trainer=None, model_tokenizer=None):
    """Predict a class label for every input line and write one label per line.

    The (misspelled) name is kept because the script calls ``wirte_output``.

    Args:
        file_path: Destination file; it is overwritten.
        data: Raw input texts, one per example (e.g. the result of ``read_data``).
        model_trainer: Trainer used for prediction. Defaults to the module-level
            ``trainer`` created by the script.
        model_tokenizer: Tokenizer used to encode *data*. Defaults to the
            module-level ``tokenizer`` created by the script.
    """
    if model_trainer is None:
        model_trainer = trainer
    if model_tokenizer is None:
        model_tokenizer = tokenizer

    texts = list(data)
    if not texts:
        # Nothing to predict; still produce an (empty) output file.
        open(file_path, 'w', encoding='utf-8').close()
        return

    # Trainer.predict needs a tokenized dataset, not raw strings.
    encodings = model_tokenizer(texts, truncation=True, padding=True)
    # DataWrapper requires labels; zeros are placeholders and do not affect
    # the predicted logits.
    dataset = DataWrapper(encodings, [0] * len(texts))

    # predict() returns (predictions, label_ids, metrics); only the logits matter.
    logits = model_trainer.predict(dataset).predictions
    if isinstance(logits, tuple):
        # Some model configs return extra outputs; logits come first.
        logits = logits[0]
    predicted = logits.argmax(axis=-1)

    with open(file_path, 'w', encoding='utf-8') as writer:
        writer.writelines(f"{label}\n" for label in predicted)
|
||
|
|
||
|
print("STEP 1 - READ DATA")
X_train = read_data('train/in.tsv')
y_train = read_data('train/expected.tsv')
X_dev = read_data('dev-0/in.tsv')
X_test = read_data('test-A/in.tsv')

# zip() below would silently drop unmatched lines if the two files disagree.
if len(X_train) != len(y_train):
    raise ValueError(
        f"train/in.tsv has {len(X_train)} lines but "
        f"train/expected.tsv has {len(y_train)}"
    )

print("STEP 2 - SHUFFLE")
# Upper bound on the number of training examples used for fine-tuning.
TRAIN_SAMPLE_SIZE = 15000
# Fixed seed so the sampled training subset is reproducible across runs.
SAMPLE_SEED = 42
data_train = list(zip(X_train, y_train))
# random.sample raises ValueError when asked for more items than exist,
# so cap the sample size at the available data.
data_train = random.Random(SAMPLE_SEED).sample(
    data_train, min(TRAIN_SAMPLE_SIZE, len(data_train))
)

print("STEP 3 - FINE TUNING")
tokenizer = AutoTokenizer.from_pretrained("roberta-base")

train_encodings = tokenizer([text[0] for text in data_train], truncation=True, padding=True)
train_dataset = DataWrapper(train_encodings, [int(text[1]) for text in data_train])

model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=2)
args = TrainingArguments("model")

# Trainer normally moves the model to args.device on its own, so a hard-coded
# CPU device here would be overridden whenever CUDA is available; choose the
# GPU when present instead of pretending to force the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()

print("STEP 4 - WRITE OUTPUT")
wirte_output('train/out.tsv', X_train)
wirte_output('dev-0/out.tsv', X_dev)
wirte_output('test-A/out.tsv', X_test)
|