Fix script

This commit is contained in:
nlitkowski 2021-06-22 20:27:03 +02:00
parent 4b000457df
commit 567f498ee2

24
main.py
View File

@ -2,20 +2,16 @@ import os
import sys import sys
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
import torch import torch
try:
import lzma
except ImportError:
from backports import lzma
import random import random
IN_FILE_NAME = "in.tsv.xz" IN_FILE_NAME = "in.tsv"
OUT_FILE_NAME = "out.tsv" OUT_FILE_NAME = "out.tsv"
TRAIN_PATH = "train" TRAIN_PATH = "train"
EXP_FILE_NAME = "expected.tsv" EXP_FILE_NAME = "expected.tsv"
FILE_SEP = "\t" FILE_SEP = "\t"
# PT_MODEL_NAME = "bert-base-cased" # PT_MODEL_NAME = "bert-base-cased"
PT_MODEL_NAME = "roberta-base" PT_MODEL_NAME = "roberta-base"
DEVICE = "cpu" MODEL_OUT_NAME = "./model.tr"
class CustomDataset(torch.utils.data.Dataset): class CustomDataset(torch.utils.data.Dataset):
@ -50,15 +46,19 @@ def main(dirnames):
train_data = list(zip(train_set_features, train_set_labels)) train_data = list(zip(train_set_features, train_set_labels))
train_data = random.sample(train_data, 15000) train_data = random.sample(train_data, 15000)
tokenizer = AutoTokenizer.from_pretrained(PT_MODEL_NAME) mname = PT_MODEL_NAME
if os.path.exists(MODEL_OUT_NAME):
mname = MODEL_OUT_NAME
tokenizer = AutoTokenizer.from_pretrained(mname)
model = AutoModelForSequenceClassification.from_pretrained( model = AutoModelForSequenceClassification.from_pretrained(
PT_MODEL_NAME, num_labels=2) mname, num_labels=2)
train_set_enc = tokenizer( train_set_enc = tokenizer(
[text[0] for text in train_data], truncation=True, padding=True) [text[0] for text in train_data], truncation=True, padding=True)
ds = CustomDataset( ds = CustomDataset(
train_set_enc, [int(text[1]) for text in train_data]) train_set_enc, [int(text[1]) for text in train_data])
device = torch.device(DEVICE) device = torch.device(
'cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device) model.to(device)
trainer = Trainer( trainer = Trainer(
@ -70,6 +70,7 @@ def main(dirnames):
print("Starting training...") print("Starting training...")
trainer.train() trainer.train()
trainer.save_model(MODEL_OUT_NAME)
print("Predicting outputs...") print("Predicting outputs...")
@ -84,10 +85,7 @@ def main(dirnames):
def get_tsv_data(filename: str, compressed=False): def get_tsv_data(filename: str, compressed=False):
if compressed: check_path(filename=filename)
with lzma.open(filename=filename) as f:
return f.readlines()
else:
with open(filename) as f: with open(filename) as f:
return f.readlines() return f.readlines()