Update 'evaluate.py'
Some checks failed
s449288-evaluation/pipeline/head There was a failure building this commit
s449288-training/pipeline/head There was a failure building this commit

This commit is contained in:
Kacper Dudzic 2022-04-25 23:05:50 +02:00
parent 703918a9c4
commit fcede41857

View File

@ -1,36 +1,36 @@
import tensorflow as tf
from keras.models import load_model
from matplotlib import pyplot as plt
from matplotlib.ticker import MaxNLocator
# Załadowanie modelu z pliku
model = keras.models.load_model('lego_reg_model')
# Załadowanie zbioru testowego
test_piece_counts = np.array(data_test['piece_count'])
test_prices = np.array(data_test['list_price'])
# Prosta ewaluacja (mean absolute error)
test_results = model.evaluate(
test_piece_counts,
test_prices, verbose=0)
# Zapis wartości liczbowej metryki do pliku
with open('eval_results.txt', 'a+') as f:
f.write(test_results)
# Wygenerowanie i zapisanie do pliku wykresu
with open('eval_results.txt') as f:
scores = []
for line in f:
scores.append(float(line))
builds = list(range(1, len(scores) + 1))
plot = plt.plot(builds, scores)
plt.xlabel('Build number')
plt.xticks(range(1, len(scores) + 1))
plt.ylabel('Mean absolute error')
plt.title('Model error by build')
plt.savefig('error_plot.jpg')
plt.show()
import tensorflow as tf
from tensorflow import keras
from matplotlib import pyplot as plt
from matplotlib.ticker import MaxNLocator
# Załadowanie modelu z pliku
model = keras.models.load_model('lego_reg_model')
# Załadowanie zbioru testowego
test_piece_counts = np.array(data_test['piece_count'])
test_prices = np.array(data_test['list_price'])
# Prosta ewaluacja (mean absolute error)
test_results = model.evaluate(
test_piece_counts,
test_prices, verbose=0)
# Zapis wartości liczbowej metryki do pliku
with open('eval_results.txt', 'a+') as f:
f.write(test_results)
# Wygenerowanie i zapisanie do pliku wykresu
with open('eval_results.txt') as f:
scores = []
for line in f:
scores.append(float(line))
builds = list(range(1, len(scores) + 1))
plot = plt.plot(builds, scores)
plt.xlabel('Build number')
plt.xticks(range(1, len(scores) + 1))
plt.ylabel('Mean absolute error')
plt.title('Model error by build')
plt.savefig('error_plot.jpg')
plt.show()