Zaktualizuj 'evaluation.py'

This commit is contained in:
Kornelia Girejko 2022-05-06 22:41:47 +02:00
parent 5ae45d867e
commit 3351af1628
1 changed files with 33 additions and 9 deletions

View File

@ -5,13 +5,16 @@ import pandas as pd
import numpy as np
import sys
import os
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from sklearn.preprocessing import StandardScaler
from datetime import datetime
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
import matplotlib.pyplot as plt
from sklearn import svm, datasets
from sklearn.metrics import classification_report
scaler = StandardScaler()
EPOCHS = int(sys.argv[1])
# Model
class Model(nn.Module):
def __init__(self):
@ -54,13 +57,12 @@ X_testing = torch.from_numpy(X_testing.astype(np.float32))
y_training = torch.from_numpy(y_training.astype(np.float32))
y_testing = torch.from_numpy(y_testing.astype(np.float32))
model = Model()
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Trening
num_epochs = EPOCHS
num_epochs = 1000
for epoch in range(num_epochs):
y_predicted = model(X_training)
loss = criterion(y_predicted,y_training)
@ -75,8 +77,30 @@ with torch.no_grad():
y_predicted = model(X_testing)
y_predicted_cls = y_predicted.round()
acc = y_predicted_cls.eq(y_testing).sum()/float(y_testing.shape[0])
print(f'{acc:.4f}')
result = open("output",'w+')
result.write(f'{y_predicted}')
#print(f'{acc:.4f}')
torch.save(model, "modelP.pkl")
rmse = mean_squared_error(y_testing, y_predicted)
#print(rmse)
mae = mean_absolute_error(y_testing, y_predicted)
#print(mae)
with open('metrics.txt', 'a+') as f:
f.write('Root mean squared error:' + str(rmse) + '\n')
f.write('Mean absolute error:' + str(mae) + '\n')
#count = [float(line) for line in f if line]
#builds = list(range(1, len(count)))
with open('metric.txt', 'a+') as f:
f.write(str(rmse) + '\n')
with open('metric.txt') as file:
y_rmse = [float(line) for line in file if line]
x_builds = list(range(1, len(y_rmse) + 1))
plt.xlabel('Build')
plt.ylabel('RMSE')
plt.plot(x_builds, y_rmse, label='RMSE')
plt.legend()
plt.show()
plt.savefig('metrics.png')