2021-06-20 22:05:07 +02:00
|
|
|
import pandas as pd
|
|
|
|
from transformers import BertTokenizer, AdamW, AutoModelForSequenceClassification
|
|
|
|
import torch
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
|
|
|
|
import torch.nn as nn
|
|
|
|
from sklearn.utils.class_weight import compute_class_weight
|
|
|
|
import numpy as np
|
|
|
|
from sklearn.metrics import classification_report
|
2021-06-22 17:16:57 +02:00
|
|
|
from sklearn.metrics import accuracy_score, f1_score
|
2021-06-22 14:03:36 +02:00
|
|
|
from transformers import BertTokenizerFast, BertForSequenceClassification
|
|
|
|
from transformers import Trainer, TrainingArguments
|
2021-06-22 17:16:57 +02:00
|
|
|
import csv
|
2021-06-20 22:05:07 +02:00
|
|
|
|
2021-06-22 14:03:36 +02:00
|
|
|
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs tokenizer encodings with labels.

    ``encodings`` maps a field name (e.g. ``"input_ids"``) to a sequence with
    one entry per example; ``labels`` is indexable in the same example order.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings  # field name -> per-example values
        self.labels = labels        # one label per example, same order

    def __getitem__(self, idx):
        example = {}
        for field, per_example in self.encodings.items():
            example[field] = torch.tensor(per_example[idx])
        # The label is wrapped in a one-element list, so "labels" is a
        # shape-(1,) tensor rather than a 0-d scalar tensor.
        example["labels"] = torch.tensor([self.labels[idx]])
        return example

    def __len__(self):
        # The dataset size is defined by the number of labels.
        return len(self.labels)
|
2021-06-20 22:05:07 +02:00
|
|
|
|
2021-06-22 17:16:57 +02:00
|
|
|
def save_tsv_result(path, data):
    """Write every item of ``data`` as its own single-column row at ``path``.

    Each item is converted with ``str`` and written as one row terminated by
    a bare newline character; an existing file at ``path`` is overwritten.
    The file is encoded as UTF-8 regardless of the platform locale.
    """
    # newline="" is required by the csv module: without it the text layer
    # would turn the writer's line terminator into CRLF on Windows.
    with open(path, "w", encoding="utf-8", newline="") as save:
        writer = csv.writer(save, delimiter="\t", lineterminator="\n")
        writer.writerows([str(item)] for item in data)
|
|
|
|
|
|
|
|
def predictions_for_set(inputs, masks, *, batch_size=60, classifier=None,
                        target_device=None):
    """Predict a class index for every row of ``inputs``, in mini-batches.

    Args:
        inputs: token-id tensor whose first dimension indexes examples
            (presumably shaped (num_examples, seq_len) -- TODO confirm).
        masks: attention-mask tensor aligned row-for-row with ``inputs``.
        batch_size: number of rows sent to the model per forward pass.
        classifier: model to run; defaults to the module-level ``model``.
            It is called as ``classifier(ids, attention_mask=mask)`` and its
            result must expose a ``logits`` tensor of shape (rows, classes).
        target_device: device each batch is moved to; defaults to the
            module-level ``device``.

    Returns:
        A list of Python ints: the arg-max class index per row, in the same
        order as ``inputs``.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    # Fall back to the script-level globals so existing callers are unchanged.
    classifier = model if classifier is None else classifier
    target_device = device if target_device is None else target_device

    predictions = []
    # Inference only: disabling autograd avoids building graphs and saves memory.
    with torch.no_grad():
        for start in range(0, len(inputs), batch_size):
            stop = start + batch_size
            output = classifier(
                inputs[start:stop].to(target_device),
                attention_mask=masks[start:stop].to(target_device),
            )
            # Take the arg-max where the logits live and copy back only the
            # class indices (ties resolve to the first maximal index).
            predictions.extend(output.logits.argmax(dim=1).cpu().tolist())
    return predictions
|
|
|
|
|
2021-06-22 14:03:36 +02:00
|
|
|
# Use the GPU when CUDA is available; otherwise fall back to the CPU so the
# script still runs (more slowly) on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
2021-06-20 22:05:07 +02:00
|
|
|
|
2021-06-22 17:16:57 +02:00
|
|
|
def _read_first_column(path):
    """Return the first tab-separated field of every line of an xz file.

    Lines are taken verbatim: no quote processing, no NaN/number conversion,
    and no skipping of blank or malformed lines. The result therefore has
    exactly one string per input line, which keeps the prediction files
    written later aligned row-for-row with the input.
    """
    import lzma

    # newline="\n": only a newline character ends a line, so a stray
    # carriage return inside a text cannot split one example into two.
    with lzma.open(path, "rt", encoding="utf-8", newline="\n") as source:
        return [line.rstrip("\n").split("\t", 1)[0] for line in source]


# Training data is not read in this evaluation-only run:
# train_texts = _read_first_column('train/in.tsv.xz')
# train_labels = pd.read_csv(
#     'train/expected.tsv', sep='\t', header=None, quoting=3)[0].tolist()

dev_texts = _read_first_column('dev-0/in.tsv.xz')
dev_labels = pd.read_csv('dev-0/expected.tsv', sep='\t',
                         header=None, quoting=3)[0].tolist()
# Every test line must be kept: dropping a row would shift all later
# predictions in test-A/out.tsv relative to the input.
test_texts = _read_first_column('test-A/in.tsv.xz')
|
|
|
|
|
|
|
|
# Checkpoint identifier (presumably a local directory with the fine-tuned
# weights); it supplies both the classifier and the tokenizer below.
model_name = "bert-base-uncased-pretrained"

# Size the classification head from the distinct labels seen in the dev set.
# NOTE(review): assumes the dev set contains every class the checkpoint was
# trained with -- otherwise the head will not match the saved weights.
model = BertForSequenceClassification.from_pretrained(
    model_name, num_labels=len(set(dev_labels))).to(device)
# from_pretrained already returns the model in eval mode; stating it here
# guarantees dropout stays disabled during prediction.
model.eval()

max_length = 512  # maximum tokens per example passed to the tokenizer
tokenizer = BertTokenizerFast.from_pretrained(model_name, do_lower_case=True)
|
2021-06-20 22:05:07 +02:00
|
|
|
|
2021-06-22 17:16:57 +02:00
|
|
|
# train_encodings = tokenizer(
#     train_texts, truncation=True, padding=True, max_length=max_length)

# Tokenize each split, truncating to max_length tokens and padding every
# example to the longest one in that split.
valid_encodings = tokenizer(dev_texts, max_length=max_length,
                            truncation=True, padding=True)
test_encodings = tokenizer(test_texts, max_length=max_length,
                           truncation=True, padding=True)

# Turn the padded id/mask lists into tensors for the model.
input_ids_val, attention_mask_val = (
    torch.tensor(valid_encodings["input_ids"]),
    torch.tensor(valid_encodings["attention_mask"]),
)
input_ids_test, attention_mask_test = (
    torch.tensor(test_encodings["input_ids"]),
    torch.tensor(test_encodings["attention_mask"]),
)
|
|
|
|
|
|
|
|
# Dev set: gold labels are available, so report metrics before saving.
predictions = predictions_for_set(input_ids_val, attention_mask_val)
print("Predictions for dev set:")
print(classification_report(dev_labels, predictions))
print(accuracy_score(dev_labels, predictions))
# f1_score's default average="binary" raises ValueError for multiclass
# targets. The classifier head is sized from the dev labels, so use a macro
# average whenever more than two classes appear in labels or predictions.
f1_average = "binary" if len(set(dev_labels) | set(predictions)) <= 2 else "macro"
print(f1_score(dev_labels, predictions, average=f1_average))
save_tsv_result("dev-0/out.tsv", predictions)

# Test set: no labels, so only write the predictions, one per line.
predictions = predictions_for_set(input_ids_test, attention_mask_test)
save_tsv_result("test-A/out.tsv", predictions)
|