paranormal-or-skeptic-ISI-p.../predict.py
Wojciech Jarmosz 91f262735e Fixes
2021-06-21 23:49:19 +02:00

60 lines
1.9 KiB
Python

from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer
import random
import torch
import numpy as np
def _read_lines(path):
    """Return every line of the text file at *path*, trailing newlines included.

    The file is decoded explicitly as UTF-8. Relying on the platform default
    encoding (e.g. cp1252 on Windows) can raise or silently mis-decode
    non-ASCII characters in the input text.
    """
    with open(path, encoding="utf-8") as f:
        return f.readlines()


# Raw lines for each split, exactly as returned by readlines().
# NOTE(review): each line is later fed to the tokenizer whole; if in.tsv has
# more than one tab-separated column, all columns get tokenized — confirm.
data_train_X = _read_lines('train/in.tsv')
# NOTE(review): data_train_Y is not used by the visible prediction code.
data_train_Y = _read_lines('train/expected.tsv')
data_dev_X = _read_lines('dev-0/in.tsv')
data_test_X = _read_lines('test-A/in.tsv')
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-tokenized encodings and optional labels.

    Args:
        encodings: Mapping from feature name (must include ``"input_ids"``)
            to a per-example indexable sequence. In this script it is the
            output of the tokenizer called on a list of texts.
        labels: Optional per-example labels, indexable like the encodings
            (list, NumPy array, ...). When ``None`` (the default), items
            carry no ``"labels"`` key.
    """

    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return example *idx* as a dict of tensors, plus ``"labels"`` if labels were given."""
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # Compare against None instead of relying on truthiness: ``if labels``
        # raises for multi-element NumPy arrays / tensors.
        if self.labels is not None:
            item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Return the number of examples, taken from the ``input_ids`` feature."""
        return len(self.encodings["input_ids"])
# Fine-tuned classification checkpoint to run inference with.
model_path = "model/checkpoint-1500"
# Tokenizer for the base uncased BERT vocabulary.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained(
    model_path,
    num_labels=2,
)
# Trainer built with default arguments; below it is used only for predict().
trainer = Trainer(model=model)
def _write_predictions(texts, out_path):
    """Classify *texts* with the module-level ``trainer`` and save the labels.

    Args:
        texts: Input strings, one per example.
        out_path: Destination file. It is overwritten with one predicted
            class index per line: the argmax of the raw predictions along
            axis 1 (presumably the classifier logits).
    """
    encodings = tokenizer(texts, truncation=True, padding=True)
    raw_pred = trainer.predict(CustomDataset(encodings)).predictions
    predicted = np.argmax(raw_pred, axis=1)
    # Open (and thereby truncate) the output only after prediction succeeded,
    # so a failure does not leave an empty out.tsv behind.
    with open(out_path, 'w', encoding='utf-8') as writer:
        writer.writelines(f"{label}\n" for label in predicted)


_write_predictions(data_train_X, 'train/out.tsv')
_write_predictions(data_dev_X, 'dev-0/out.tsv')
_write_predictions(data_test_X, 'test-A/out.tsv')