From 415cad97e2d338206f9349c42d7a49033e0ae0e2 Mon Sep 17 00:00:00 2001
From: nlitkowski
Date: Mon, 21 Jun 2021 21:46:10 +0200
Subject: [PATCH] Add transformer fine-tuning script

Fine-tune a pretrained sequence-classification model (bert-base-cased
by default) on the train set, then write predictions for every input
directory passed on the command line:

    python main.py <dir> [<dir> ...]

Also ignore .vscode directories.
---
 .gitignore |   2 +
 main.py    | 118 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 120 insertions(+)
 create mode 100644 main.py

diff --git a/.gitignore b/.gitignore
index 1c18d74..d8f4b98 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,3 +6,5 @@
 *.o
 .DS_Store
 .token
+
+**/.vscode/*
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..a65ff80
--- /dev/null
+++ b/main.py
@@ -0,0 +1,118 @@
+import pandas as pd
+import os
+import sys
+from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
+import torch
+import csv
+
+IN_FILE_NAME = "in.tsv.xz"
+OUT_FILE_NAME = "out.tsv"
+TRAIN_PATH = "train"
+EXP_FILE_NAME = "expected.tsv"
+FILE_SEP = "\t"
+IN_HEADER_FILE_NAME = "in-header.tsv"
+OUT_HEADER_FILE_NAME = "out-header.tsv"
+PT_MODEL_NAME = "bert-base-cased"
+# PT_MODEL_NAME = "roberta-base"
+
+
+class CustomDataset(torch.utils.data.Dataset):
+    """Wraps tokenizer encodings and labels as a torch Dataset for Trainer."""
+
+    def __init__(self, encodings, labels):
+        self.encodings = encodings
+        self.labels = labels
+
+    def __getitem__(self, idx):
+        item = {key: torch.tensor(val[idx])
+                for key, val in self.encodings.items()}
+        item['labels'] = torch.tensor(self.labels[idx])
+        return item
+
+    def __len__(self):
+        return len(self.labels)
+
+
+def main(dirnames):
+    # Column names come from the challenge's header files
+    check_path(IN_HEADER_FILE_NAME)
+    in_cols = pd.read_csv(IN_HEADER_FILE_NAME, sep=FILE_SEP).columns
+    check_path(OUT_HEADER_FILE_NAME)
+    out_cols = pd.read_csv(OUT_HEADER_FILE_NAME, sep=FILE_SEP).columns
+
+    print("Reading train data...")
+    train_set_features = get_tsv_data(os.path.join(
+        TRAIN_PATH, IN_FILE_NAME), names=in_cols)
+    train_set_labels = get_tsv_data(os.path.join(
+        TRAIN_PATH, EXP_FILE_NAME), names=out_cols, compression=None)
+
+    print("Reading input data...")
+    in_sets = []
+    for d in dirnames:
+        print(f"\tReading dir: {d}...")
+        in_sets.append(get_tsv_data(
+            os.path.join(d, IN_FILE_NAME), names=in_cols))
+
+    tokenizer = AutoTokenizer.from_pretrained(PT_MODEL_NAME)
+    model = AutoModelForSequenceClassification.from_pretrained(
+        PT_MODEL_NAME, num_labels=2)
+    # Concatenate all input columns into a single text per example
+    train_set_enc = tokenizer(
+        list(train_set_features[in_cols].agg(' '.join, axis=1)),
+        truncation=True, padding=True)
+    dataset = CustomDataset(
+        train_set_enc, [int(t) for t in train_set_labels[out_cols[0]]])
+
+    # Hyperparameters belong in TrainingArguments, not in the Trainer call
+    training_args = TrainingArguments(
+        output_dir='./res',
+        num_train_epochs=5,
+        per_device_train_batch_size=16,
+        per_device_eval_batch_size=16,
+    )
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=dataset,
+    )
+
+    trainer.train()
+
+    for i in range(len(in_sets)):
+        p = os.path.join(dirnames[i], IN_FILE_NAME)
+        print(f"\tPredicting for: {p}...")
+        X = list(in_sets[i][in_cols].agg(' '.join, axis=1))
+        # Trainer.predict expects a dataset, so tokenize and wrap the test
+        # texts as well (the zero labels are placeholders and never used)
+        X_enc = tokenizer(X, truncation=True, padding=True)
+        X_dataset = CustomDataset(X_enc, [0] * len(X))
+        preds = trainer.predict(X_dataset).predictions.argmax(axis=1)
+        out_file_path = os.path.join(dirnames[i], OUT_FILE_NAME)
+        with open(out_file_path, 'w') as f:
+            f.write('\n'.join(str(label) for label in preds) + '\n')
+        print(f"Saved predictions to file: {out_file_path}")
+
+
+def get_tsv_data(filename: str, names, compression="infer"):
+    check_path(filename)
+    return pd.read_csv(
+        filename,
+        sep=FILE_SEP,
+        compression=compression,
+        error_bad_lines=False,  # skip malformed rows instead of aborting
+        quoting=csv.QUOTE_NONE,
+        header=None,
+        names=names,
+        dtype=str
+    )
+
+
+def check_path(filename: str):
+    if not os.path.exists(filename):
+        raise FileNotFoundError(f"Path {filename} does not exist!")
+
+
+if __name__ == "__main__":
+    if len(sys.argv) < 2:
+        raise ValueError("Name of working dir not specified!")
+    main(sys.argv[1:])