ekstrakcja-bert/bert_infer.py
2021-06-22 22:15:01 +02:00

36 lines
1016 B
Python

import pandas as pd
import torch
from transformers import BertForSequenceClassification, BertTokenizerFast
# Checkpoint identifier handed to from_pretrained() for both the model and the
# tokenizer (presumably a locally fine-tuned checkpoint directory -- confirm).
model_path = "bert-base-uncased-2k"
# Token limit passed to the tokenizer; with truncation enabled, longer inputs
# are cut down to this many tokens.
max_length = 512
# Dataset directory names. Each is used as the prefix of the `in.tsv.xz` input
# path and of the `out.tsv` output path.
DEV = 'dev-0'
TEST = 'test-A'
# Two-label sequence classifier, loaded once at import time and moved to the
# "cuda" device.
model = BertForSequenceClassification.from_pretrained(model_path, num_labels=2).to("cuda")
# Tokenizer loaded from the same checkpoint as the model.
tokenizer = BertTokenizerFast.from_pretrained(model_path)
def get_prediction(text):
    """Classify a single text and return the predicted label index.

    Args:
        text: The input string to classify. Inputs longer than ``max_length``
            tokens are truncated by the tokenizer.

    Returns:
        A zero-dimensional integer tensor holding the index of the highest
        scoring class. Call ``.item()`` on it to obtain a plain ``int``.
    """
    # Put the encoded inputs on whatever device the model lives on rather than
    # hard-coding "cuda", so the two can never disagree.
    inputs = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    ).to(model.device)
    # Inference only: without no_grad() autograd keeps every intermediate
    # activation alive for a backward pass that never happens.
    with torch.no_grad():
        outputs = model(**inputs)
    # outputs[0] holds the classification logits. argmax() without a dim works
    # on the flattened tensor, which is only correct for a single input text.
    return outputs[0].softmax(1).argmax()
def get_predictions_for(dataset):
    """Predict a label for every input row of a dataset and save the results.

    Reads the first tab-separated column of ``{dataset}/in.tsv.xz``, runs
    ``get_prediction`` on each value and writes one predicted label per line
    to ``{dataset}/out.tsv``, overwriting any existing file.

    Args:
        dataset: Path of the directory that contains ``in.tsv.xz``.
    """
    # `error_bad_lines=False` was deprecated in pandas 1.3 and removed in 2.0.
    # `on_bad_lines='warn'` (pandas >= 1.3) keeps the same behaviour: malformed
    # rows are reported and skipped.
    # NOTE: every skipped row makes out.tsv one line shorter than in.tsv.
    frame = pd.read_csv(
        f'{dataset}/in.tsv.xz',
        compression='xz',
        sep='\t',
        on_bad_lines='warn',
        header=None,
        quoting=3,  # csv.QUOTE_NONE: treat quote characters as ordinary text
    )
    texts = frame[0].tolist()
    # Convert each result to a plain int immediately so result tensors are not
    # kept alive for the whole run. All predictions are computed before the
    # output file is opened, so a failure never leaves a partial file behind.
    labels = [get_prediction(text).item() for text in texts]
    with open(f'{dataset}/out.tsv', 'w') as file:
        file.writelines(f'{label}\n' for label in labels)
# Generate out.tsv for the development split first, then for the test split.
for split in (DEV, TEST):
    get_predictions_for(split)