CONDA small refactoring
This commit is contained in:
parent
2b8884af80
commit
4768874dde
44
deepl.py
44
deepl.py
@ -52,13 +52,13 @@ def train(epochs):
|
||||
# data_dev = pd.read_csv('data_dev.csv', engine='python', header=None).dropna()
|
||||
# data_test = pd.read_csv('data_test.csv', engine='python', header=None).dropna()
|
||||
|
||||
x_train = data_train[5]
|
||||
x_dev = data_dev[5]
|
||||
x_test = data_test[5]
|
||||
x_train = data_train["company_profile"]
|
||||
x_dev = data_dev["company_profile"]
|
||||
x_test = data_test["company_profile"]
|
||||
|
||||
y_train = data_train[17]
|
||||
y_dev = data_dev[17]
|
||||
y_test = data_test[17]
|
||||
y_train = data_train["fraudulent"]
|
||||
y_dev = data_dev["fraudulent"]
|
||||
y_test = data_test["fraudulent"]
|
||||
|
||||
company_profile = np.array(company_profile)
|
||||
x_train = np.array(x_train)
|
||||
@ -93,7 +93,7 @@ def train(epochs):
|
||||
model = nn.Sequential(
|
||||
nn.Linear(x_train.shape[1], 64),
|
||||
nn.ReLU(),
|
||||
nn.Linear(64, data_train[17].nunique()),
|
||||
nn.Linear(64, data_train["fraudulent"].nunique()),
|
||||
nn.LogSoftmax(dim=1))
|
||||
|
||||
# Define the loss
|
||||
@ -148,7 +148,7 @@ def train(epochs):
|
||||
log_ps = model(x_test)
|
||||
ps = torch.exp(log_ps)
|
||||
top_p, top_class = ps.topk(1, dim=1)
|
||||
descr = np.array(data_test[5])
|
||||
descr = np.array(data_test["fraudulent"])
|
||||
for i, (x, y) in enumerate(zip(np.array(top_class), np.array(y_test.view(*top_class.shape)))):
|
||||
d = descr[i]
|
||||
if x == y:
|
||||
@ -171,22 +171,22 @@ def train(epochs):
|
||||
f.write(f"FP = {len(FP)}\n")
|
||||
f.write(f"FN = {len(FN)}\n")
|
||||
|
||||
f.write(f"TP descriptions:")
|
||||
for i in TP:
|
||||
f.write(i+'\n')
|
||||
f.write(f"TF descriptions:")
|
||||
for i in TF:
|
||||
f.write(i+"\n")
|
||||
f.write(f"FP descriptions:")
|
||||
for i in FP:
|
||||
f.write(i+"\n")
|
||||
f.write(f"FN descriptions:")
|
||||
for i in FN:
|
||||
f.write(i+"\n")
|
||||
f.close()
|
||||
# f.write(f"TP descriptions:")
|
||||
# for i in TP:
|
||||
# f.write(i+'\n')
|
||||
# f.write(f"TF descriptions:")
|
||||
# for i in TF:
|
||||
# f.write(i+"\n")
|
||||
# f.write(f"FP descriptions:")
|
||||
# for i in FP:
|
||||
# f.write(i+"\n")
|
||||
# f.write(f"FN descriptions:")
|
||||
# for i in FN:
|
||||
# f.write(i+"\n")
|
||||
# f.close()
|
||||
|
||||
torch.save(model, 'model')
|
||||
input_example = data_train[:5]
|
||||
# input_example = data_train[:5]
|
||||
# siganture = infer_signature(input_example, np.array(['company_profile']))
|
||||
|
||||
# path = urlparse(mlflow.get_tracking_uri()).scheme
|
||||
|
Loading…
Reference in New Issue
Block a user