diff --git a/deepl.py b/deepl.py index 5b6fb7a..b34f560 100644 --- a/deepl.py +++ b/deepl.py @@ -46,11 +46,11 @@ def train(epochs): data = data.dropna() company_profile = data["company_profile"] - # data_train, data_test = train_test_split(data, test_size=3000, random_state=1) - # data_dev, data_test = train_test_split(data_test, test_size=1500, random_state=1) - data_train = pd.read_csv('data_train.csv', engine='python', header=None).dropna() - data_dev = pd.read_csv('data_dev.csv', engine='python', header=None).dropna() - data_test = pd.read_csv('data_test.csv', engine='python', header=None).dropna() + data_train, data_test = train_test_split(data, test_size=3000, random_state=1) + data_dev, data_test = train_test_split(data_test, test_size=1500, random_state=1) + # data_train = pd.read_csv('data_train.csv', engine='python', header=None).dropna() + # data_dev = pd.read_csv('data_dev.csv', engine='python', header=None).dropna() + # data_test = pd.read_csv('data_test.csv', engine='python', header=None).dropna() x_train = data_train[5] x_dev = data_dev[5]