Add script
This commit is contained in:
parent
756ef4277a
commit
415cad97e2
2
.gitignore
vendored
2
.gitignore
vendored
@ -6,3 +6,5 @@
|
||||
*.o
|
||||
.DS_Store
|
||||
.token
|
||||
|
||||
**/.vscode/*
|
||||
|
105
main.py
Normal file
105
main.py
Normal file
@ -0,0 +1,105 @@
|
||||
import pandas as pd
|
||||
import os
|
||||
import sys
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
|
||||
import torch
|
||||
import csv
|
||||
|
||||
IN_FILE_NAME = "in.tsv.xz"
|
||||
OUT_FILE_NAME = "out.tsv"
|
||||
TRAIN_PATH = "train"
|
||||
EXP_FILE_NAME = "expected.tsv"
|
||||
FILE_SEP = "\t"
|
||||
IN_HEADER_FILE_NAME = "in-header.tsv"
|
||||
OUT_HEADER_FILE_NAME = "out-header.tsv"
|
||||
PT_MODEL_NAME = "bert-base-cased"
|
||||
# PT_MODEL_NAME = "roberta-base"
|
||||
|
||||
|
||||
class CustomDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, encodings, labels):
|
||||
self.encodings = encodings
|
||||
self.labels = labels
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = {key: torch.tensor(val[idx])
|
||||
for key, val in self.encodings.items()}
|
||||
item['labels'] = torch.tensor(self.labels[idx])
|
||||
return item
|
||||
|
||||
def __len__(self):
|
||||
return len(self.labels)
|
||||
|
||||
|
||||
def main(dirnames):
|
||||
check_path(IN_HEADER_FILE_NAME)
|
||||
in_cols = (pd.read_csv(IN_HEADER_FILE_NAME, sep=FILE_SEP)).columns
|
||||
check_path(OUT_HEADER_FILE_NAME)
|
||||
out_cols = (pd.read_csv(OUT_HEADER_FILE_NAME, sep=FILE_SEP)).columns
|
||||
|
||||
print("Reading train data...")
|
||||
train_set_features = get_tsv_data(os.path.join(
|
||||
TRAIN_PATH, IN_FILE_NAME), names=in_cols)
|
||||
train_set_labels = get_tsv_data(os.path.join(
|
||||
TRAIN_PATH, EXP_FILE_NAME), names=out_cols, compression=None)
|
||||
|
||||
print("Reading input data...")
|
||||
in_sets = []
|
||||
for d in dirnames:
|
||||
print(f"\tReading dir: {d}...")
|
||||
in_sets.append(get_tsv_data(
|
||||
os.path.join(d, IN_FILE_NAME), names=in_cols))
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(PT_MODEL_NAME)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
PT_MODEL_NAME, num_labels=2)
|
||||
train_set_enc = tokenizer(
|
||||
[t for t in train_set_features[in_cols].agg(' '.join, axis=1)], truncation=True, padding=True)
|
||||
dataset = CustomDataset(
|
||||
train_set_enc, [int(t) for t in train_set_labels[out_cols[0]]])
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=TrainingArguments('./res'),
|
||||
train_dataset=dataset,
|
||||
num_train_epochs=5,
|
||||
per_device_train_batch_size=16,
|
||||
per_device_eval_batch_size=16,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
for i in range(len(in_sets)):
|
||||
p = os.path.join(dirnames[i], IN_FILE_NAME)
|
||||
with open(p) as f:
|
||||
print(
|
||||
f"\tPredicting for: {p}...")
|
||||
X = [t for t in in_sets[i][in_cols].agg(' '.join, axis=1)]
|
||||
out_file_path = os.path.join(dirnames[i], OUT_FILE_NAME)
|
||||
f.write('\n'.join(trainer.predict(X)))
|
||||
print(f"Saved predictions to file: {out_file_path}")
|
||||
|
||||
|
||||
def get_tsv_data(filename: str, names, compression="infer"):
|
||||
check_path(filename)
|
||||
return pd.read_csv(
|
||||
filename,
|
||||
sep=FILE_SEP,
|
||||
compression=compression,
|
||||
error_bad_lines=False,
|
||||
quoting=csv.QUOTE_NONE,
|
||||
header=None,
|
||||
names=names,
|
||||
dtype=str
|
||||
)
|
||||
|
||||
|
||||
def check_path(filename: str):
|
||||
if not os.path.exists(filename):
|
||||
raise Exception(f"Path {filename} does not exist!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
raise Exception("Name of working dir not specified!")
|
||||
main(sys.argv[1:])
|
Loading…
Reference in New Issue
Block a user