update dllib-mlflow.py
This commit is contained in:
parent
5f9785d4fd
commit
c2a7fbbf25
@ -296,34 +296,34 @@ def my_main(epochs):
|
|||||||
print ("The loss calculated: ", loss)
|
print ("The loss calculated: ", loss)
|
||||||
|
|
||||||
|
|
||||||
# Not using dataloader
|
with mlflow.start_run() as run:
|
||||||
x_train, y_train = Variable(torch.from_numpy(features_train_g)).float(), Variable(torch.from_numpy(labels_train_g)).long()
|
x_train, y_train = Variable(torch.from_numpy(features_train_g)).float(), Variable(torch.from_numpy(labels_train_g)).long()
|
||||||
for epoch in range(1, epochs + 1):
|
for epoch in range(1, epochs + 1):
|
||||||
print("Epoch #", epoch)
|
print("Epoch #", epoch)
|
||||||
y_pred = model(x_train)
|
y_pred = model(x_train)
|
||||||
|
|
||||||
loss = loss_fn(y_pred, y_train.squeeze(-1))
|
loss = loss_fn(y_pred, y_train.squeeze(-1))
|
||||||
print_(loss.item())
|
print_(loss.item())
|
||||||
|
|
||||||
# Zero gradients
|
# Zero gradients
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward() # Gradients
|
loss.backward() # Gradients
|
||||||
optimizer.step() # Update
|
optimizer.step() # Update
|
||||||
|
|
||||||
# Prediction
|
# Prediction
|
||||||
x_test = Variable(torch.from_numpy(features_test_g)).float()
|
x_test = Variable(torch.from_numpy(features_test_g)).float()
|
||||||
pred = model(x_test)
|
pred = model(x_test)
|
||||||
|
|
||||||
pred = pred.detach().numpy()
|
pred = pred.detach().numpy()
|
||||||
|
|
||||||
print("The accuracy is", accuracy_score(labels_test_g, np.argmax(pred, axis=1)))
|
print("The accuracy is", accuracy_score(labels_test_g, np.argmax(pred, axis=1)))
|
||||||
mlflow.log_metric("accuracy", accuracy_score(labels_test_g, np.argmax(pred, axis=1)))
|
mlflow.log_metric("accuracy", accuracy_score(labels_test_g, np.argmax(pred, axis=1)))
|
||||||
|
|
||||||
pred = pd.DataFrame(pred)
|
pred = pd.DataFrame(pred)
|
||||||
|
|
||||||
pred.to_csv('result.csv')
|
pred.to_csv('result.csv')
|
||||||
|
|
||||||
# save model
|
# save model
|
||||||
torch.save(model, "games_model.pkl")
|
torch.save(model, "games_model.pkl")
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user