fix script

This commit is contained in:
nlitkowski 2021-06-22 18:13:51 +02:00
parent c36ba1d489
commit 4b000457df
4 changed files with 36 additions and 300040 deletions

File diff suppressed because one or more lines are too long

73
main.py
View File

@@ -1,19 +1,21 @@
import pandas as pd
import os import os
import sys import sys
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
import torch import torch
import csv try:
import lzma
except ImportError:
from backports import lzma
import random
# File names and directory layout of the dataset.
IN_FILE_NAME = "in.tsv.xz"
OUT_FILE_NAME = "out.tsv"
TRAIN_PATH = "train"
EXP_FILE_NAME = "expected.tsv"
FILE_SEP = "\t"

# Header files that the earlier pandas-based loader read column names from;
# retained so existing references to these names keep resolving.
IN_HEADER_FILE_NAME = "in-header.tsv"
OUT_HEADER_FILE_NAME = "out-header.tsv"

# Name of the pretrained checkpoint passed to ``from_pretrained``.
# PT_MODEL_NAME = "bert-base-cased"
PT_MODEL_NAME = "roberta-base"

# Device string handed to ``torch.device`` before training.
DEVICE = "cpu"
class CustomDataset(torch.utils.data.Dataset): class CustomDataset(torch.utils.data.Dataset):
@@ -32,65 +34,62 @@ class CustomDataset(torch.utils.data.Dataset):
# Upper bound on the number of training examples drawn for fine-tuning.
_MAX_TRAIN_SAMPLES = 15000


def _to_text(line):
    """Return *line* as ``str`` with its trailing line terminator removed."""
    if isinstance(line, bytes):
        line = line.decode("utf-8")
    return line.rstrip("\r\n")


def main(dirnames):
    """Fine-tune the pretrained classifier and write predictions per directory.

    Reads the training inputs and expected labels from ``TRAIN_PATH``,
    fine-tunes ``PT_MODEL_NAME`` on a random subset of at most
    ``_MAX_TRAIN_SAMPLES`` examples, then writes one predicted label per
    line to ``OUT_FILE_NAME`` inside every directory in *dirnames*.

    Args:
        dirnames: Directories that each contain an ``IN_FILE_NAME`` file to
            predict on.
    """
    print("Reading train data...")
    train_set_features = get_tsv_data(os.path.join(
        TRAIN_PATH, IN_FILE_NAME), compressed=True)
    # EXP_FILE_NAME is a plain .tsv file, so it must not go through lzma.
    train_set_labels = get_tsv_data(os.path.join(
        TRAIN_PATH, EXP_FILE_NAME), compressed=False)

    print("Reading input data...")
    in_sets = []
    for d in dirnames:
        print(f"\tReading dir: {d}...")
        raw_lines = get_tsv_data(
            os.path.join(d, IN_FILE_NAME), compressed=True)
        in_sets.append([_to_text(line) for line in raw_lines])

    train_data = list(zip(
        (_to_text(line) for line in train_set_features),
        (_to_text(line) for line in train_set_labels)))
    # random.sample raises ValueError when asked for more items than exist.
    train_data = random.sample(
        train_data, min(_MAX_TRAIN_SAMPLES, len(train_data)))

    tokenizer = AutoTokenizer.from_pretrained(PT_MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(
        PT_MODEL_NAME, num_labels=2)
    train_set_enc = tokenizer(
        [text for text, _ in train_data], truncation=True, padding=True)
    ds = CustomDataset(
        train_set_enc, [int(label) for _, label in train_data])

    device = torch.device(DEVICE)
    model.to(device)

    trainer = Trainer(
        model=model,
        args=TrainingArguments("model"),
        train_dataset=ds
    )
    print("Starting training...")
    trainer.train()

    print("Predicting outputs...")
    for d, texts in zip(dirnames, in_sets):
        p_in = os.path.join(d, IN_FILE_NAME)
        p_out = os.path.join(d, OUT_FILE_NAME)
        print(f"\tPredicting for: {p_in}...")
        predicted = []
        if texts:
            enc = tokenizer(texts, truncation=True, padding=True)
            # NOTE(review): CustomDataset is assumed to take
            # (encodings, labels) as it does for training above; the zeros
            # are placeholder labels only -- confirm against the class.
            pred_ds = CustomDataset(enc, [0] * len(texts))
            # Trainer.predict returns a PredictionOutput; the class scores
            # live in .predictions, so take the highest-scoring class.
            scores = trainer.predict(pred_ds).predictions
            predicted = [str(int(label)) for label in scores.argmax(axis=-1)]
        with open(p_out, "w") as f:
            f.write('\n'.join(predicted))
        print(f"Saved predictions to file: {p_out}")
def get_tsv_data(filename: str, compressed=False):
    """Read a TSV file and return its lines as text.

    Args:
        filename: Path of the file to read.
        compressed: When true the file is opened through ``lzma`` (xz);
            otherwise it is opened as a plain text file.

    Returns:
        A list of ``str`` lines in file order, each still carrying its
        line terminator (``readlines`` semantics).
    """
    # lzma.open defaults to binary mode and would return bytes, while the
    # plain branch returns str; opening both in text mode keeps the return
    # type the same regardless of *compressed*.
    opener = lzma.open if compressed else open
    with opener(filename, mode="rt", encoding="utf-8") as f:
        return f.readlines()
def check_path(filename: str): def check_path(filename: str):

File diff suppressed because one or more lines are too long

289579
train/in.tsv

File diff suppressed because one or more lines are too long