# paranormal-or-skeptic-ISI-p.../generate.py
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import accuracy_score, classification_report
from transformers import BertForSequenceClassification, BertTokenizerFast
# Torch dataset pairing tokenizer encodings with labels; consumed by the
# Trainer fine-tuning sketch further below.
class Dataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        # Trainer expects a scalar label per example, not a 1-element list.
        item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)
# Fall back to CPU when no GPU is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the first 1000 examples of each split; quoting=3 (QUOTE_NONE) keeps raw
# quote characters intact, and on_bad_lines='skip' (pandas >= 1.3) replaces the
# deprecated error_bad_lines=False.
train_texts = pd.read_csv('train/in.tsv.xz', compression='xz', sep='\t',
                          header=None, on_bad_lines='skip', quoting=3)[0].tolist()[:1000]
train_labels = pd.read_csv('train/expected.tsv', sep='\t', header=None, quoting=3)[0].tolist()[:1000]
dev_texts = pd.read_csv('dev-0/in.tsv.xz', compression='xz', sep='\t', header=None, quoting=3)[0].tolist()[:1000]
dev_labels = pd.read_csv('dev-0/expected.tsv', sep='\t', header=None, quoting=3)[0].tolist()[:1000]
model_name = "bert-base-uncased"
# The classification head is freshly initialized, so the predictions below are
# near-chance unless a fine-tuned checkpoint is loaded first.
model = BertForSequenceClassification.from_pretrained(
    model_name, num_labels=len(pd.unique(train_labels))).to(device)
max_length = 512
tokenizer = BertTokenizerFast.from_pretrained(model_name, do_lower_case=True)
# To resume from a saved checkpoint instead of the base weights:
# model = BertForSequenceClassification.from_pretrained(model_path).to(device)
# tokenizer = BertTokenizerFast.from_pretrained(model_path)
train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=max_length)
valid_encodings = tokenizer(dev_texts, truncation=True, padding=True, max_length=max_length)
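# A minimal fine-tuning sketch (left commented out, since this script only
# runs inference). It shows how the Dataset class above could be wired into
# transformers.Trainer; the output_dir, epoch count, and batch size are
# illustrative assumptions, not values taken from this repository.
#
# from transformers import Trainer, TrainingArguments
#
# train_dataset = Dataset(train_encodings, train_labels)
# valid_dataset = Dataset(valid_encodings, dev_labels)
# training_args = TrainingArguments(
#     output_dir='./results',
#     num_train_epochs=3,
#     per_device_train_batch_size=8,
#     evaluation_strategy='epoch',
# )
# trainer = Trainer(model=model, args=training_args,
#                   train_dataset=train_dataset, eval_dataset=valid_dataset)
# trainer.train()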
# Score only the first 100 dev examples in a single forward pass.
input_ids = torch.tensor(valid_encodings['input_ids'])[:100]
attention_mask = torch.tensor(valid_encodings['attention_mask'])[:100]
# Inference only: disable dropout and gradient tracking.
model.eval()
with torch.no_grad():
    preds = model(input_ids.to(device), attention_mask=attention_mask.to(device))
preds = preds.logits.cpu().numpy()
preds = np.argmax(preds, axis=1)  # predicted class index per example
print(preds)
# Evaluate against the matching slice of gold labels; only 100 examples were scored.
print(classification_report(dev_labels[:100], preds))
print(accuracy_score(dev_labels[:100], preds))
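# A minimal sketch for persisting predictions, assuming the challenge expects
# one label per line; the dev-0/out.tsv path is an assumption mirroring the
# dev-0/in.tsv.xz input layout above.
#
# with open('dev-0/out.tsv', 'w') as f:
#     for p in preds:
#         f.write(f"{p}\n")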