upd scr
This commit is contained in:
parent
1c183a567a
commit
c1271fb458
10
script5_4.py
10
script5_4.py
@ -4,7 +4,7 @@ import torch.nn as nn
|
|||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from sklearn.preprocessing import LabelEncoder
|
from sklearn.preprocessing import LabelEncoder
|
||||||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, mean_squared_error
|
from sklearn.metrics import mean_squared_error
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
# Define the neural network model
|
# Define the neural network model
|
||||||
@ -75,14 +75,6 @@ for inputs, targets_batch in test_dataloader:
|
|||||||
predictions = torch.FloatTensor(predictions).squeeze()
|
predictions = torch.FloatTensor(predictions).squeeze()
|
||||||
targets = torch.FloatTensor(targets).squeeze()
|
targets = torch.FloatTensor(targets).squeeze()
|
||||||
|
|
||||||
accuracy = accuracy_score(targets, torch.round(predictions))
|
|
||||||
precision = precision_score(targets, torch.round(predictions), average='micro')
|
|
||||||
recall = recall_score(targets, torch.round(predictions), average='micro')
|
|
||||||
f1 = f1_score(targets, torch.round(predictions), average='micro')
|
|
||||||
rmse = mean_squared_error(targets, predictions, squared=False)
|
rmse = mean_squared_error(targets, predictions, squared=False)
|
||||||
|
|
||||||
print("Accuracy:", accuracy)
|
|
||||||
print("Micro-average Precision:", precision)
|
|
||||||
print("Micro-average Recall:", recall)
|
|
||||||
print("F1 Score:", f1)
|
|
||||||
print("RMSE:", rmse)
|
print("RMSE:", rmse)
|
||||||
|
Loading…
Reference in New Issue
Block a user