diff --git a/pytorch-example-evaluate.py b/pytorch-example-evaluate.py index 7db5957..f6a012a 100644 --- a/pytorch-example-evaluate.py +++ b/pytorch-example-evaluate.py @@ -47,7 +47,7 @@ model = LogisticRegressionModel(input_dim, output_dim) pred = model(fTest) accuracy = accuracy_score(tTest, np.argmax(pred.detach().numpy(), axis = 1)) -f1 = f1_score(tTest, np.argmax(pred.detach().numpy(), axis = 1, average = None)) +f1 = f1_score(tTest, np.argmax(pred.detach().numpy(), axis = 1), average = None) rmse = mean_squared_error(tTest, np.argmax(pred.detach().numpy())) print(f'Accuracy: {accuracy}')