Review notes (text before the first "diff --git" header is ignored by
git apply / patch):
  * This patch switches deepl.py from positional CSV columns (5, 17) to
    named columns ("company_profile", "fraudulent").
  * descr must come from the text column (old column 5 ==
    "company_profile"), not the label column, so the TP/TF/FP/FN lists keep
    holding profile texts.
  * f.close() stays active: the file is still written just above, so it
    must still be closed.
  * Line breaks, blank context lines and indentation were reconstructed
    from a whitespace-collapsed copy so that every hunk matches its @@ line
    counts; verify against deepl.py before applying (or apply with
    --ignore-whitespace). The "index" hashes are those of the original patch.

diff --git a/deepl.py b/deepl.py
index b34f560..05ba415 100644
--- a/deepl.py
+++ b/deepl.py
@@ -52,13 +52,13 @@ def train(epochs):
     # data_dev = pd.read_csv('data_dev.csv', engine='python', header=None).dropna()
     # data_test = pd.read_csv('data_test.csv', engine='python', header=None).dropna()
 
-    x_train = data_train[5]
-    x_dev = data_dev[5]
-    x_test = data_test[5]
+    x_train = data_train["company_profile"]
+    x_dev = data_dev["company_profile"]
+    x_test = data_test["company_profile"]
 
-    y_train = data_train[17]
-    y_dev = data_dev[17]
-    y_test = data_test[17]
+    y_train = data_train["fraudulent"]
+    y_dev = data_dev["fraudulent"]
+    y_test = data_test["fraudulent"]
 
     company_profile = np.array(company_profile)
     x_train = np.array(x_train)
@@ -93,7 +93,7 @@ def train(epochs):
     model = nn.Sequential(
         nn.Linear(x_train.shape[1], 64),
         nn.ReLU(),
-        nn.Linear(64, data_train[17].nunique()),
+        nn.Linear(64, data_train["fraudulent"].nunique()),
         nn.LogSoftmax(dim=1))
 
     # Define the loss
@@ -148,7 +148,7 @@ def train(epochs):
         log_ps = model(x_test)
         ps = torch.exp(log_ps)
         top_p, top_class = ps.topk(1, dim=1)
-        descr = np.array(data_test[5])
+        descr = np.array(data_test["company_profile"])
         for i, (x, y) in enumerate(zip(np.array(top_class), np.array(y_test.view(*top_class.shape)))):
             d = descr[i]
             if x == y:
@@ -171,22 +171,22 @@ def train(epochs):
     f.write(f"FP = {len(FP)}\n")
     f.write(f"FN = {len(FN)}\n")
 
-    f.write(f"TP descriptions:")
-    for i in TP:
-        f.write(i+'\n')
-    f.write(f"TF descriptions:")
-    for i in TF:
-        f.write(i+"\n")
-    f.write(f"FP descriptions:")
-    for i in FP:
-        f.write(i+"\n")
-    f.write(f"FN descriptions:")
-    for i in FN:
-        f.write(i+"\n")
+    # f.write(f"TP descriptions:")
+    # for i in TP:
+    #     f.write(i+'\n')
+    # f.write(f"TF descriptions:")
+    # for i in TF:
+    #     f.write(i+"\n")
+    # f.write(f"FP descriptions:")
+    # for i in FP:
+    #     f.write(i+"\n")
+    # f.write(f"FN descriptions:")
+    # for i in FN:
+    #     f.write(i+"\n")
     f.close()
 
     torch.save(model, 'model')
 
-    input_example = data_train[:5]
+    # input_example = data_train[:5]
     # siganture = infer_signature(input_example, np.array(['company_profile']))
     # path = urlparse(mlflow.get_tracking_uri()).scheme