# paranormal-or-skeptic-ISI-p.../fine_tuning.py
#
# 47 lines
# 1.4 KiB
# Python
# Raw Normal View History
#
# 2021-06-21 23:49:19 +02:00
from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer
2021-06-21 01:42:25 +02:00
import random
import torch
def _read_lines(path):
    """Return every line of the text file at *path*, trailing newlines included."""
    # An explicit encoding keeps decoding independent of the platform's locale
    # default (e.g. cp1252 on Windows). Assumes the data files are UTF-8 encoded.
    with open(path, encoding="utf-8") as f:
        return f.readlines()


# One example per line; line i of train/expected.tsv is the integer label for
# line i of train/in.tsv (they are zipped together further below).
data_train_X = _read_lines('train/in.tsv')
data_train_Y = _read_lines('train/expected.tsv')
data_dev_X = _read_lines('dev-0/in.tsv')
data_test_X = _read_lines('test-A/in.tsv')
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset over tokenizer output, with optional labels.

    ``encodings`` is a mapping from feature name to a per-example sequence
    (in this script, the object returned by the tokenizer); example ``idx``
    is assembled from position ``idx`` of every feature.
    """

    def __init__(self, encodings, labels=None):
        """Store the encodings and, if given, the parallel sequence of labels.

        Args:
            encodings: Mapping from feature name to a sequence indexed by
                example. It must contain ``"input_ids"``, which defines the
                dataset length.
            labels: Optional sequence of labels aligned with the encodings,
                or ``None`` for an unlabeled dataset.
        """
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict mapping feature names to tensors.

        A ``"labels"`` entry is added only when the dataset has labels.
        """
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # Compare against None instead of relying on truthiness. `if labels:`
        # raises for NumPy arrays / multi-element tensors ("ambiguous truth
        # value") and would treat an empty label list as "no labels".
        if self.labels is not None:
            item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Return the number of examples, i.e. ``len(encodings["input_ids"])``."""
        return len(self.encodings["input_ids"])
# Fine-tune a pretrained BERT with a two-class classification head on a
# reproducible random subset of the training data.
TRAIN_SAMPLE_SIZE = 50000  # upper bound on the number of training pairs used
SAMPLE_SEED = 42  # fixed seed so the sampled subset is identical across runs
PRETRAINED_MODEL = "bert-base-uncased"

if len(data_train_X) != len(data_train_Y):
    # zip() would silently drop the unmatched tail, hiding a broken data split.
    raise ValueError(
        f"{len(data_train_X)} training inputs but {len(data_train_Y)} labels"
    )
data_train = list(zip(data_train_X, data_train_Y))
# A private, seeded Random keeps the subset reproducible without touching the
# global RNG. The size is capped because random.sample raises ValueError when
# asked for more items than the population holds.
data_train = random.Random(SAMPLE_SEED).sample(
    data_train, min(TRAIN_SAMPLE_SIZE, len(data_train))
)

tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL)
# padding=True pads every sequence to the longest one in the sample;
# truncation=True caps sequences at the model's maximum input length.
train_X = tokenizer([text for text, _ in data_train], truncation=True, padding=True)
# int() ignores the trailing newline that readlines() leaves on each label.
train_Y = [int(label) for _, label in data_train]

train_dataset = CustomDataset(train_X, train_Y)

# Two output classes for the sequence-classification head.
model = BertForSequenceClassification.from_pretrained(PRETRAINED_MODEL, num_labels=2)

# The first positional argument of TrainingArguments is output_dir.
training_args = TrainingArguments("model")

trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()