Fix script
This commit is contained in:
parent
4b000457df
commit
567f498ee2
24
main.py
24
main.py
@ -2,20 +2,16 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
|
||||||
import torch
|
import torch
|
||||||
try:
|
|
||||||
import lzma
|
|
||||||
except ImportError:
|
|
||||||
from backports import lzma
|
|
||||||
import random
|
import random
|
||||||
|
|
||||||
IN_FILE_NAME = "in.tsv.xz"
|
IN_FILE_NAME = "in.tsv"
|
||||||
OUT_FILE_NAME = "out.tsv"
|
OUT_FILE_NAME = "out.tsv"
|
||||||
TRAIN_PATH = "train"
|
TRAIN_PATH = "train"
|
||||||
EXP_FILE_NAME = "expected.tsv"
|
EXP_FILE_NAME = "expected.tsv"
|
||||||
FILE_SEP = "\t"
|
FILE_SEP = "\t"
|
||||||
# PT_MODEL_NAME = "bert-base-cased"
|
# PT_MODEL_NAME = "bert-base-cased"
|
||||||
PT_MODEL_NAME = "roberta-base"
|
PT_MODEL_NAME = "roberta-base"
|
||||||
DEVICE = "cpu"
|
MODEL_OUT_NAME = "./model.tr"
|
||||||
|
|
||||||
|
|
||||||
class CustomDataset(torch.utils.data.Dataset):
|
class CustomDataset(torch.utils.data.Dataset):
|
||||||
@ -50,15 +46,19 @@ def main(dirnames):
|
|||||||
train_data = list(zip(train_set_features, train_set_labels))
|
train_data = list(zip(train_set_features, train_set_labels))
|
||||||
train_data = random.sample(train_data, 15000)
|
train_data = random.sample(train_data, 15000)
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(PT_MODEL_NAME)
|
mname = PT_MODEL_NAME
|
||||||
|
if os.path.exists(MODEL_OUT_NAME):
|
||||||
|
mname = MODEL_OUT_NAME
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(mname)
|
||||||
model = AutoModelForSequenceClassification.from_pretrained(
|
model = AutoModelForSequenceClassification.from_pretrained(
|
||||||
PT_MODEL_NAME, num_labels=2)
|
mname, num_labels=2)
|
||||||
train_set_enc = tokenizer(
|
train_set_enc = tokenizer(
|
||||||
[text[0] for text in train_data], truncation=True, padding=True)
|
[text[0] for text in train_data], truncation=True, padding=True)
|
||||||
ds = CustomDataset(
|
ds = CustomDataset(
|
||||||
train_set_enc, [int(text[1]) for text in train_data])
|
train_set_enc, [int(text[1]) for text in train_data])
|
||||||
|
|
||||||
device = torch.device(DEVICE)
|
device = torch.device(
|
||||||
|
'cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
@ -70,6 +70,7 @@ def main(dirnames):
|
|||||||
print("Starting training...")
|
print("Starting training...")
|
||||||
|
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
trainer.save_model(MODEL_OUT_NAME)
|
||||||
|
|
||||||
print("Predicting outputs...")
|
print("Predicting outputs...")
|
||||||
|
|
||||||
@ -84,10 +85,7 @@ def main(dirnames):
|
|||||||
|
|
||||||
|
|
||||||
def get_tsv_data(filename: str, compressed=False):
|
def get_tsv_data(filename: str, compressed=False):
|
||||||
if compressed:
|
check_path(filename=filename)
|
||||||
with lzma.open(filename=filename) as f:
|
|
||||||
return f.readlines()
|
|
||||||
else:
|
|
||||||
with open(filename) as f:
|
with open(filename) as f:
|
||||||
return f.readlines()
|
return f.readlines()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user