synthetic_errors/preprocess_dataset.py
2022-04-26 19:24:23 +02:00

23 lines
836 B
Python

import os

import pandas as pd
from sklearn.model_selection import train_test_split
def read_input_file(input_filename):
    """Lazily yield the lines of a UTF-8 text file.

    Args:
        input_filename: Path of the file to read.

    Yields:
        Each line of the file in order, exactly as the text-mode file
        iterator produces it: the trailing newline is kept, and the last
        line has none if the file does not end with one.

    The file is opened when iteration starts and is closed once the
    generator is exhausted (or closed/garbage-collected), so consume the
    result fully, e.g. with ``list(...)``.
    """
    # Named `source_file` rather than `input` so the builtin is not shadowed.
    with open(input_filename, encoding="utf-8", mode='r') as source_file:
        yield from source_file
def save_file(input, file_name):
    """Write each item of *input* to *file_name* as one UTF-8 line.

    Args:
        input: Iterable of strings, one per output line. The name shadows
            the builtin ``input`` but is kept so existing keyword callers
            keep working.
        file_name: Path of the file to create or overwrite. Its parent
            directory must already exist.

    Items that already end with a newline are written unchanged. An item
    without one gets a newline appended: the last line of a source file
    often lacks it, and once the lines have been shuffled that item would
    otherwise be fused with the following one, shifting every later line
    by one relative to the parallel file.
    """
    with open(file_name, encoding="utf-8", mode='w') as output_file:
        for line in input:
            output_file.write(line if line.endswith("\n") else line + "\n")
# Split the parallel files ./train.er and ./train.co into train / test / dev
# parts. Both files are read fully into lists and passed to the same
# train_test_split call, so line i of one file stays paired with line i of
# the other (train_test_split raises ValueError if their lengths differ).
# NOTE(review): ".er" / ".co" presumably hold erroneous and corrected
# sentences - confirm against the data-generation step.
#
# First 0.1% of all pairs is held out as the test set; then 0.1% of the
# remainder is held out as the dev set. random_state is fixed so repeated
# runs produce the same split.
X_train, X_test, y_train, y_test = train_test_split(
    list(read_input_file('./train.er')),
    list(read_input_file('./train.co')),
    test_size=0.001,
    random_state=1,
)
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size=0.001, random_state=1
)

# open(..., mode='w') does not create missing directories; make sure the
# output directory exists so the writes below cannot fail on a fresh checkout.
os.makedirs("./data", exist_ok=True)

save_file(X_train, "./data/train.er")
save_file(y_train, "./data/train.co")
save_file(X_test, "./data/test.er")
save_file(y_test, "./data/test.co")
save_file(X_val, "./data/dev.er")
save_file(y_val, "./data/dev.co")