Update 'evaluation.py'
parent 5ae45d867e
commit 3351af1628
@@ -5,13 +5,16 @@ import pandas as pd
import numpy as np
import sys
import os
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from sklearn.preprocessing import StandardScaler
from datetime import datetime
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
import matplotlib.pyplot as plt
from sklearn import svm, datasets
from sklearn.metrics import classification_report
scaler = StandardScaler()

EPOCHS = int(sys.argv[1])

# Model
class Model(nn.Module):
    def __init__(self):
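The hunk cuts off before the layer definitions of Model. A minimal sketch of a binary classifier that would fit the nn.BCELoss, the sigmoid-range outputs implied by .round(), and the single-column targets used further down might look as follows; the input width is an assumption for illustration, not a value taken from the commit:

import torch
import torch.nn as nn  # the script presumably imports these above the hunk shown here

class Model(nn.Module):
    def __init__(self, n_features=8):  # n_features is hypothetical; the real width comes from the training data
        super().__init__()
        self.linear = nn.Linear(n_features, 1)  # one output unit for binary classification

    def forward(self, x):
        # sigmoid keeps predictions in (0, 1), as nn.BCELoss expects
        return torch.sigmoid(self.linear(x))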
@@ -54,13 +57,12 @@ X_testing = torch.from_numpy(X_testing.astype(np.float32))
y_training = torch.from_numpy(y_training.astype(np.float32))
y_testing = torch.from_numpy(y_testing.astype(np.float32))


model = Model()
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Training
num_epochs = EPOCHS
num_epochs = 1000
for epoch in range(num_epochs):
    y_predicted = model(X_training)
    loss = criterion(y_predicted, y_training)
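The diff skips the rest of the loop body before the next hunk. The lines below are only a sketch of the standard PyTorch update step that usually completes such a loop; they are not shown in this commit. Note also that the second assignment, num_epochs = 1000, overrides the EPOCHS value read from sys.argv.

    # assumed continuation of the loop body, not part of the displayed hunk
    optimizer.zero_grad()   # clear gradients from the previous step
    loss.backward()         # backpropagate the BCE loss
    optimizer.step()        # apply the SGD update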
@@ -75,8 +77,30 @@ with torch.no_grad():
    y_predicted = model(X_testing)
    y_predicted_cls = y_predicted.round()
    acc = y_predicted_cls.eq(y_testing).sum() / float(y_testing.shape[0])
    print(f'{acc:.4f}')
    result = open("output", 'w+')
    result.write(f'{y_predicted}')
    #print(f'{acc:.4f}')

torch.save(model, "modelP.pkl")
rmse = mean_squared_error(y_testing, y_predicted)
#print(rmse)

mae = mean_absolute_error(y_testing, y_predicted)
#print(mae)

with open('metrics.txt', 'a+') as f:
    f.write('Root mean squared error:' + str(rmse) + '\n')
    f.write('Mean absolute error:' + str(mae) + '\n')
    #count = [float(line) for line in f if line]
    #builds = list(range(1, len(count)))

with open('metric.txt', 'a+') as f:
    f.write(str(rmse) + '\n')

with open('metric.txt') as file:
    y_rmse = [float(line) for line in file if line]
    x_builds = list(range(1, len(y_rmse) + 1))

plt.xlabel('Build')
plt.ylabel('RMSE')
plt.plot(x_builds, y_rmse, label='RMSE')
plt.legend()
plt.show()
plt.savefig('metrics.png')
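Two details of this final block are easy to misread. mean_squared_error returns the plain MSE, so the value logged as 'Root mean squared error' is not actually a root unless the square root is taken, and calling plt.savefig after plt.show can write an empty metrics.png because show may release the figure first. A hedged sketch of the corresponding lines, assuming a scikit-learn version that accepts squared=False:

rmse = mean_squared_error(y_testing, y_predicted, squared=False)  # actual RMSE; older versions can use np.sqrt() on the MSE instead

plt.xlabel('Build')
plt.ylabel('RMSE')
plt.plot(x_builds, y_rmse, label='RMSE')
plt.legend()
plt.savefig('metrics.png')  # save before show(), so the written file is not blank
plt.show()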