CONDA small refactoring

This commit is contained in:
Mikołaj Pokrywka 2022-05-23 00:15:02 +02:00
parent 2b8884af80
commit 4768874dde

View File

@ -52,13 +52,13 @@ def train(epochs):
# data_dev = pd.read_csv('data_dev.csv', engine='python', header=None).dropna()
# data_test = pd.read_csv('data_test.csv', engine='python', header=None).dropna()
x_train = data_train[5]
x_dev = data_dev[5]
x_test = data_test[5]
x_train = data_train["company_profile"]
x_dev = data_dev["company_profile"]
x_test = data_test["company_profile"]
y_train = data_train[17]
y_dev = data_dev[17]
y_test = data_test[17]
y_train = data_train["fraudulent"]
y_dev = data_dev["fraudulent"]
y_test = data_test["fraudulent"]
company_profile = np.array(company_profile)
x_train = np.array(x_train)
@ -93,7 +93,7 @@ def train(epochs):
model = nn.Sequential(
nn.Linear(x_train.shape[1], 64),
nn.ReLU(),
nn.Linear(64, data_train[17].nunique()),
nn.Linear(64, data_train["fraudulent"].nunique()),
nn.LogSoftmax(dim=1))
# Define the loss
@ -148,7 +148,7 @@ def train(epochs):
log_ps = model(x_test)
ps = torch.exp(log_ps)
top_p, top_class = ps.topk(1, dim=1)
descr = np.array(data_test[5])
descr = np.array(data_test["fraudulent"])
for i, (x, y) in enumerate(zip(np.array(top_class), np.array(y_test.view(*top_class.shape)))):
d = descr[i]
if x == y:
@ -171,22 +171,22 @@ def train(epochs):
f.write(f"FP = {len(FP)}\n")
f.write(f"FN = {len(FN)}\n")
f.write(f"TP descriptions:")
for i in TP:
f.write(i+'\n')
f.write(f"TF descriptions:")
for i in TF:
f.write(i+"\n")
f.write(f"FP descriptions:")
for i in FP:
f.write(i+"\n")
f.write(f"FN descriptions:")
for i in FN:
f.write(i+"\n")
f.close()
# f.write(f"TP descriptions:")
# for i in TP:
# f.write(i+'\n')
# f.write(f"TF descriptions:")
# for i in TF:
# f.write(i+"\n")
# f.write(f"FP descriptions:")
# for i in FP:
# f.write(i+"\n")
# f.write(f"FN descriptions:")
# for i in FN:
# f.write(i+"\n")
# f.close()
torch.save(model, 'model')
input_example = data_train[:5]
# input_example = data_train[:5]
# siganture = infer_signature(input_example, np.array(['company_profile']))
# path = urlparse(mlflow.get_tracking_uri()).scheme