From 567f498ee2ecf5e63bd8c8b95da1791a6afcffa4 Mon Sep 17 00:00:00 2001 From: nlitkowski Date: Tue, 22 Jun 2021 20:27:03 +0200 Subject: [PATCH] Fix script --- main.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/main.py b/main.py index a741ac6..f7d74e8 100644 --- a/main.py +++ b/main.py @@ -2,20 +2,16 @@ import os import sys from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer import torch -try: - import lzma -except ImportError: - from backports import lzma import random -IN_FILE_NAME = "in.tsv.xz" +IN_FILE_NAME = "in.tsv" OUT_FILE_NAME = "out.tsv" TRAIN_PATH = "train" EXP_FILE_NAME = "expected.tsv" FILE_SEP = "\t" # PT_MODEL_NAME = "bert-base-cased" PT_MODEL_NAME = "roberta-base" -DEVICE = "cpu" +MODEL_OUT_NAME = "./model.tr" class CustomDataset(torch.utils.data.Dataset): @@ -50,15 +46,19 @@ def main(dirnames): train_data = list(zip(train_set_features, train_set_labels)) train_data = random.sample(train_data, 15000) - tokenizer = AutoTokenizer.from_pretrained(PT_MODEL_NAME) + mname = PT_MODEL_NAME + if os.path.exists(MODEL_OUT_NAME): + mname = MODEL_OUT_NAME + tokenizer = AutoTokenizer.from_pretrained(mname) model = AutoModelForSequenceClassification.from_pretrained( - PT_MODEL_NAME, num_labels=2) + mname, num_labels=2) train_set_enc = tokenizer( [text[0] for text in train_data], truncation=True, padding=True) ds = CustomDataset( train_set_enc, [int(text[1]) for text in train_data]) - device = torch.device(DEVICE) + device = torch.device( + 'cuda') if torch.cuda.is_available() else torch.device('cpu') model.to(device) trainer = Trainer( @@ -70,6 +70,7 @@ def main(dirnames): print("Starting training...") trainer.train() + trainer.save_model(MODEL_OUT_NAME) print("Predicting outputs...") @@ -84,12 +85,9 @@ def main(dirnames): def get_tsv_data(filename: str, compressed=False): - if compressed: - with lzma.open(filename=filename) as f: - return f.readlines() - else: - with open(filename) as f: - return f.readlines() + check_path(filename=filename) + with open(filename) as f: + return f.readlines() def check_path(filename: str):