import os
import random
import sys

import torch
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)

IN_FILE_NAME = "in.tsv"
OUT_FILE_NAME = "out.tsv"
TRAIN_PATH = "train"
EXP_FILE_NAME = "expected.tsv"
FILE_SEP = "\t"
# PT_MODEL_NAME = "bert-base-cased"
PT_MODEL_NAME = "roberta-base"
MODEL_OUT_NAME = "./model.tr"
TRAIN_SAMPLE_SIZE = 15000


class CustomDataset(torch.utils.data.Dataset):
    """Wraps tokenizer encodings and integer labels for use with Trainer."""

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)


def main(dirnames):
    print("Reading train data...")
    train_set_features = get_tsv_data(os.path.join(TRAIN_PATH, IN_FILE_NAME))
    train_set_labels = get_tsv_data(os.path.join(TRAIN_PATH, EXP_FILE_NAME))

    print("Reading input data...")
    in_sets = []
    for d in dirnames:
        print(f"\tReading dir: {d}...")
        in_sets.append(get_tsv_data(os.path.join(d, IN_FILE_NAME)))

    # Pair features with labels and subsample to keep training time manageable.
    train_data = list(zip(train_set_features, train_set_labels))
    train_data = random.sample(train_data, min(TRAIN_SAMPLE_SIZE, len(train_data)))

    # Resume from a previously saved model if one exists, otherwise start
    # from the pretrained checkpoint.
    mname = MODEL_OUT_NAME if os.path.exists(MODEL_OUT_NAME) else PT_MODEL_NAME

    tokenizer = AutoTokenizer.from_pretrained(mname)
    model = AutoModelForSequenceClassification.from_pretrained(mname, num_labels=2)

    train_set_enc = tokenizer([text for text, _ in train_data], truncation=True, padding=True)
    ds = CustomDataset(train_set_enc, [int(label) for _, label in train_data])

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)

    trainer = Trainer(
        model=model,
        args=TrainingArguments("model"),
        train_dataset=ds,
    )

    print("Starting training...")
    trainer.train()
    trainer.save_model(MODEL_OUT_NAME)

    print("Predicting outputs...")
    for i in range(len(in_sets)):
        p_in = os.path.join(dirnames[i], IN_FILE_NAME)
        p_out = os.path.join(dirnames[i], OUT_FILE_NAME)
        print(f"\tPredicting for: {p_in}...")
        # Trainer.predict expects a Dataset, not raw strings: tokenize the
        # inputs, wrap them with dummy labels, and take the argmax of the
        # returned logits as the predicted class.
        in_enc = tokenizer(in_sets[i], truncation=True, padding=True)
        in_ds = CustomDataset(in_enc, [0] * len(in_sets[i]))
        preds = trainer.predict(in_ds).predictions.argmax(axis=-1)
        with open(p_out, "w") as f:
            f.write('\n'.join(str(p) for p in preds))
        print(f"Saved predictions to file: {p_out}")


def get_tsv_data(filename: str):
    """Read a TSV file and return its lines with trailing newlines stripped."""
    check_path(filename=filename)
    with open(filename) as f:
        return [line.rstrip("\n") for line in f]


def check_path(filename: str):
    if not os.path.exists(filename):
        raise FileNotFoundError(f"Path {filename} does not exist!")


if __name__ == "__main__":
    if len(sys.argv) < 2:
        raise Exception("Name of working dir not specified!")
    main(sys.argv[1:])