Fix script

This commit is contained in:
nlitkowski 2021-06-22 20:27:03 +02:00
parent 4b000457df
commit 567f498ee2

28
main.py
View File

@ -2,20 +2,16 @@ import os
import sys
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
import torch
try:
import lzma
except ImportError:
from backports import lzma
import random
IN_FILE_NAME = "in.tsv.xz"
IN_FILE_NAME = "in.tsv"
OUT_FILE_NAME = "out.tsv"
TRAIN_PATH = "train"
EXP_FILE_NAME = "expected.tsv"
FILE_SEP = "\t"
# PT_MODEL_NAME = "bert-base-cased"
PT_MODEL_NAME = "roberta-base"
DEVICE = "cpu"
MODEL_OUT_NAME = "./model.tr"
class CustomDataset(torch.utils.data.Dataset):
@ -50,15 +46,19 @@ def main(dirnames):
train_data = list(zip(train_set_features, train_set_labels))
train_data = random.sample(train_data, 15000)
tokenizer = AutoTokenizer.from_pretrained(PT_MODEL_NAME)
mname = PT_MODEL_NAME
if os.path.exists(MODEL_OUT_NAME):
mname = MODEL_OUT_NAME
tokenizer = AutoTokenizer.from_pretrained(mname)
model = AutoModelForSequenceClassification.from_pretrained(
PT_MODEL_NAME, num_labels=2)
mname, num_labels=2)
train_set_enc = tokenizer(
[text[0] for text in train_data], truncation=True, padding=True)
ds = CustomDataset(
train_set_enc, [int(text[1]) for text in train_data])
device = torch.device(DEVICE)
device = torch.device(
'cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
trainer = Trainer(
@ -70,6 +70,7 @@ def main(dirnames):
print("Starting training...")
trainer.train()
trainer.save_model(MODEL_OUT_NAME)
print("Predicting outputs...")
@ -84,12 +85,9 @@ def main(dirnames):
def get_tsv_data(filename: str, compressed=False):
if compressed:
with lzma.open(filename=filename) as f:
return f.readlines()
else:
with open(filename) as f:
return f.readlines()
check_path(filename=filename)
with open(filename) as f:
return f.readlines()
def check_path(filename: str):