104 lines
2.9 KiB
Python
104 lines
2.9 KiB
Python
import os
|
|
import sys
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
|
|
import torch
|
|
import random
|
|
|
|
# Input file read from TRAIN_PATH and from every working directory.
IN_FILE_NAME = "in.tsv"
# Predictions file written into every working directory.
OUT_FILE_NAME = "out.tsv"
# Gold labels read from TRAIN_PATH (one label per line of IN_FILE_NAME).
EXP_FILE_NAME = "expected.tsv"

# Directory holding the training split.
TRAIN_PATH = "train"

# Tab separator for TSV data (not referenced by the visible code).
FILE_SEP = "\t"

# Base checkpoint to fine-tune; "bert-base-cased" was tried as an alternative.
PT_MODEL_NAME = "roberta-base"

# Location of the fine-tuned model; when it already exists, training is skipped.
MODEL_OUT_NAME = "./model.tr"
|
|
|
|
|
|
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer encodings with integer labels.

    ``encodings`` maps each feature name (e.g. ``input_ids``) to a sequence
    holding one entry per example, as produced by the tokenizer call in
    ``main``; ``labels`` is a sequence aligned with those examples.
    """

    def __init__(self, encodings, labels):
        self.encodings, self.labels = encodings, labels

    def __len__(self):
        # The number of examples is defined by the label sequence.
        return len(self.labels)

    def __getitem__(self, idx):
        # One tensor per encoded feature, followed by the example's label.
        sample = {}
        for name, column in self.encodings.items():
            sample[name] = torch.tensor(column[idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample
|
|
|
|
|
|
def _read_training_pairs(sample_size):
    """Load (text, label) line pairs from TRAIN_PATH and randomly subsample them."""
    features = get_tsv_data(os.path.join(TRAIN_PATH, IN_FILE_NAME), compressed=True)
    labels = get_tsv_data(os.path.join(TRAIN_PATH, EXP_FILE_NAME), compressed=True)
    pairs = list(zip(features, labels))
    if not pairs:
        raise ValueError(f"No training data found in {TRAIN_PATH}")
    # random.sample raises ValueError when asked for more items than exist,
    # so never request more than the training set actually contains.
    return random.sample(pairs, min(sample_size, len(pairs)))


def _predict_labels(trainer, tokenizer, lines):
    """Return one predicted class id (as str) for each input line."""
    texts = [line.rstrip("\n") for line in lines]
    if not texts:
        return []
    encodings = tokenizer(texts, truncation=True, padding=True)
    # CustomDataset requires labels; these placeholders only feed the loss
    # that Trainer reports and do not influence the predicted logits.
    dataset = CustomDataset(encodings, [0] * len(texts))
    logits = trainer.predict(dataset).predictions
    if isinstance(logits, tuple):
        # Some models return extra outputs after the logits.
        logits = logits[0]
    return [str(label) for label in logits.argmax(axis=-1)]


def main(dirnames, train_sample_size=15000):
    """Fine-tune (or reuse) a two-class sequence classifier and write predictions.

    If MODEL_OUT_NAME does not exist, up to ``train_sample_size`` randomly
    sampled line pairs from TRAIN_PATH (IN_FILE_NAME texts, EXP_FILE_NAME
    labels) are used to fine-tune PT_MODEL_NAME. The resulting model and
    tokenizer are saved to MODEL_OUT_NAME. Otherwise the saved model is
    loaded and training is skipped. Afterwards, every directory in
    ``dirnames`` receives OUT_FILE_NAME with one predicted class id per line
    of its IN_FILE_NAME.

    Args:
        dirnames: directories containing IN_FILE_NAME to predict for.
        train_sample_size: maximum number of training pairs to sample.

    Raises:
        Exception: if an input file does not exist (via get_tsv_data).
        ValueError: if training is required but the training set is empty.
    """
    pretrained = os.path.exists(MODEL_OUT_NAME)

    # Training data is only needed (and only read) when no saved model exists.
    train_data = []
    if not pretrained:
        print("Reading train data...")
        train_data = _read_training_pairs(train_sample_size)

    print("Reading input data...")
    in_sets = []
    for d in dirnames:
        print(f"\tReading dir: {d}...")
        in_sets.append(get_tsv_data(os.path.join(d, IN_FILE_NAME), compressed=True))

    # Fine-tuning never changes the tokenizer. Prefer the copy saved next to
    # the model, but fall back to the base checkpoint for models that were
    # saved without tokenizer files.
    saved_tokenizer = os.path.join(MODEL_OUT_NAME, "tokenizer_config.json")
    if pretrained and os.path.exists(saved_tokenizer):
        tokenizer_name = MODEL_OUT_NAME
    else:
        tokenizer_name = PT_MODEL_NAME
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_OUT_NAME if pretrained else PT_MODEL_NAME, num_labels=2)

    train_dataset = None
    if not pretrained:
        encodings = tokenizer([text.rstrip("\n") for text, _ in train_data],
                              truncation=True, padding=True)
        # Labels are parsed with int(); with num_labels=2 they are presumably 0/1.
        train_dataset = CustomDataset(encodings, [int(label) for _, label in train_data])

    # Trainer places the model on the available device itself.
    trainer = Trainer(
        model=model,
        args=TrainingArguments("model"),
        train_dataset=train_dataset,
    )

    print("Starting training...")
    if not pretrained:
        trainer.train()
        trainer.save_model(MODEL_OUT_NAME)
        # save_model only writes the tokenizer when Trainer was given one, so
        # save it explicitly to keep MODEL_OUT_NAME self-contained.
        tokenizer.save_pretrained(MODEL_OUT_NAME)

    print("Predicting outputs...")
    for dirname, lines in zip(dirnames, in_sets):
        p_in = os.path.join(dirname, IN_FILE_NAME)
        p_out = os.path.join(dirname, OUT_FILE_NAME)
        print(f"\tPredicting for: {p_in}...")
        # Predict before opening the output so a failure does not leave a
        # truncated out file behind.
        predictions = _predict_labels(trainer, tokenizer, lines)
        with open(p_out, "w", encoding="utf-8") as f:
            f.writelines(f"{label}\n" for label in predictions)
        print(f"Saved predictions to file: {p_out}")
|
|
|
|
|
|
def get_tsv_data(filename: str, compressed=False) -> list[str]:
    """Return all lines of a text/TSV file, newline characters included.

    Args:
        filename: path of the file to read; it must exist.
        compressed: accepted for call compatibility but currently ignored;
            the file is always read as plain UTF-8 text.

    Raises:
        Exception: if ``filename`` does not exist (raised by ``check_path``).
    """
    check_path(filename=filename)
    # Explicit encoding: the default is locale-dependent (e.g. cp1252 on
    # Windows), which would mis-decode or reject non-ASCII UTF-8 input.
    with open(filename, encoding="utf-8") as f:
        return f.readlines()
|
|
|
|
|
|
def check_path(filename: str):
    """Ensure ``filename`` exists on the filesystem.

    Returns None when the path exists.

    Raises:
        Exception: with the message ``Path <filename> does not exist!``.
    """
    if os.path.exists(filename):
        return
    raise Exception(f"Path {filename} does not exist!")
|
|
|
|
|
|
if __name__ == "__main__":
    # Every command-line argument is a directory passed on to main().
    working_dirs = sys.argv[1:]
    if not working_dirs:
        raise Exception("Name of working dir not specified!")
    main(working_dirs)
|