import csv
import os
import sys

import pandas as pd
import torch
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)

IN_FILE_NAME = "in.tsv.xz"
OUT_FILE_NAME = "out.tsv"
TRAIN_PATH = "train"
EXP_FILE_NAME = "expected.tsv"
FILE_SEP = "\t"
IN_HEADER_FILE_NAME = "in-header.tsv"
OUT_HEADER_FILE_NAME = "out-header.tsv"
PT_MODEL_NAME = "bert-base-cased"
# PT_MODEL_NAME = "roberta-base"


class CustomDataset(torch.utils.data.Dataset):
    """Wraps tokenizer output (and optional labels) as a PyTorch dataset.

    Labels are optional so the same class can be reused for prediction,
    where only the encodings are available.
    """

    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx])
                for key, val in self.encodings.items()}
        if self.labels is not None:
            item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.encodings['input_ids'])


def main(dirnames):
    check_path(IN_HEADER_FILE_NAME)
    in_cols = pd.read_csv(IN_HEADER_FILE_NAME, sep=FILE_SEP).columns
    check_path(OUT_HEADER_FILE_NAME)
    out_cols = pd.read_csv(OUT_HEADER_FILE_NAME, sep=FILE_SEP).columns

    print("Reading train data...")
    train_set_features = get_tsv_data(
        os.path.join(TRAIN_PATH, IN_FILE_NAME), names=in_cols)
    train_set_labels = get_tsv_data(
        os.path.join(TRAIN_PATH, EXP_FILE_NAME), names=out_cols,
        compression=None)

    print("Reading input data...")
    in_sets = []
    for d in dirnames:
        print(f"\tReading dir: {d}...")
        in_sets.append(get_tsv_data(
            os.path.join(d, IN_FILE_NAME), names=in_cols))

    tokenizer = AutoTokenizer.from_pretrained(PT_MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(
        PT_MODEL_NAME, num_labels=2)

    # Join all input columns into a single text per row before tokenizing.
    train_set_enc = tokenizer(
        list(train_set_features[in_cols].agg(' '.join, axis=1)),
        truncation=True, padding=True)
    dataset = CustomDataset(
        train_set_enc, [int(t) for t in train_set_labels[out_cols[0]]])

    # Epoch and batch-size settings belong to TrainingArguments, not to
    # the Trainer constructor.
    training_args = TrainingArguments(
        output_dir='./res',
        num_train_epochs=5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()

    for i in range(len(in_sets)):
        print(f"\tPredicting for: {os.path.join(dirnames[i], IN_FILE_NAME)}...")
        texts = list(in_sets[i][in_cols].agg(' '.join, axis=1))
        # Trainer.predict expects a dataset, not raw strings: tokenize the
        # texts first, then take the argmax over the returned logits to get
        # the predicted class per row.
        test_enc = tokenizer(texts, truncation=True, padding=True)
        preds = trainer.predict(
            CustomDataset(test_enc)).predictions.argmax(axis=1)
        out_file_path = os.path.join(dirnames[i], OUT_FILE_NAME)
        # Write predictions to out.tsv (the original opened the *input* file
        # in read mode and tried to write the predictions into it).
        with open(out_file_path, 'w') as f:
            f.write('\n'.join(str(p) for p in preds) + '\n')
        print(f"Saved predictions to file: {out_file_path}")


def get_tsv_data(filename: str, names, compression="infer"):
    check_path(filename)
    return pd.read_csv(
        filename,
        sep=FILE_SEP,
        compression=compression,
        # Replaces error_bad_lines=False, which was removed in pandas 2.0.
        on_bad_lines='skip',
        quoting=csv.QUOTE_NONE,
        header=None,
        names=names,
        dtype=str,
    )


def check_path(filename: str):
    if not os.path.exists(filename):
        raise Exception(f"Path {filename} does not exist!")


if __name__ == "__main__":
    if len(sys.argv) < 2:
        raise Exception("Name of working dir not specified!")
    main(sys.argv[1:])
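
# Usage sketch (an assumption based on the file-name constants above, not
# stated in the original script): the layout matches the GEval/gonito
# challenge convention, i.e. a "train/" directory with in.tsv.xz and
# expected.tsv, plus one directory per evaluation set containing in.tsv.xz.
# The script name and the directory names "dev-0"/"test-A" below are
# hypothetical examples:
#
#   python classify.py dev-0 test-A
#
# Each listed directory then receives an out.tsv with one predicted label
# (0 or 1) per input line.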