2022-05-02 00:10:35 +02:00
|
|
|
#!/usr/bin/python
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from torch import nn
|
|
|
|
import pandas as pd
|
|
|
|
from sklearn import preprocessing
|
|
|
|
import numpy as np
|
|
|
|
from torch.autograd import Variable
|
|
|
|
from sklearn.metrics import accuracy_score, f1_score
|
|
|
|
from csv import DictWriter
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import sys
|
2022-05-02 01:33:14 +02:00
|
|
|
import os
|
2022-05-02 02:08:51 +02:00
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
import json
|
2022-05-02 00:10:35 +02:00
|
|
|
|
2022-05-16 01:31:02 +02:00
|
|
|
|
2022-05-02 00:10:35 +02:00
|
|
|
class Model(nn.Module):
    """Three-layer fully connected classifier with five output classes.

    Architecture: input_dim -> 100 -> 60 -> 5, with ReLU after the two hidden
    layers and a softmax over the last dimension of the output.

    Args:
        input_dim: Number of input features per sample.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, 100)
        self.layer2 = nn.Linear(100, 60)
        self.layer3 = nn.Linear(60, 5)

    def forward(self, x):
        """Return class probabilities for ``x``.

        Args:
            x: Float tensor whose last dimension has size ``input_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 5 that sums to 1.
        """
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        # Normalise over the class dimension explicitly. Calling softmax
        # without ``dim`` is deprecated and picks the axis implicitly, which
        # is the wrong axis for inputs with extra leading dimensions.
        # NOTE(review): if training uses a loss that expects raw logits
        # (e.g. nn.CrossEntropyLoss), this softmax would be applied twice --
        # confirm against the training script.
        x = F.softmax(self.layer3(x), dim=-1)
        return x
|
|
|
|
|
2022-05-16 01:31:02 +02:00
|
|
|
|
2022-05-02 00:10:35 +02:00
|
|
|
def prepare_labels_features(dataset):
    """Split a DataFrame into encoded labels and min-max scaled features.

    Column ``'0'`` is used as the label column and is encoded to integers
    with a ``LabelEncoder``; the fitted class names are printed. All other
    columns are converted to a NumPy array and rescaled with a
    ``MinMaxScaler`` that is fitted on this same data.

    Args:
        dataset: pandas DataFrame containing a column named ``'0'``.

    Returns:
        Tuple ``(lab, feat)``: the encoded labels and the scaled features.
    """
    # Encode the label column as integer class ids.
    label_values = np.array(dataset['0'])
    encoder = preprocessing.LabelEncoder().fit(label_values)
    print(list(encoder.classes_))
    encoded_labels = encoder.transform(label_values)

    # Everything except the label column is a feature.
    raw_features = dataset.drop(columns=['0']).to_numpy()
    scaled_features = preprocessing.MinMaxScaler().fit_transform(raw_features)

    return encoded_labels, scaled_features
|
|
|
|
|
|
|
|
|
|
|
|
def print_metrics(test_labels, predictions):
    """Print F1 and accuracy for ``predictions`` and append them to a CSV.

    Args:
        test_labels: True labels, one per sample.
        predictions: 2-D array of per-class scores, one row per sample. The
            predicted class of a row is the column with the highest score.

    Side effects:
        Prints both metrics. Then, best-effort, reads the build number from
        ``sys.argv[1]`` and appends a ``BUILD_NUMBER, F1, ACCURACY`` row to
        ``./metrics.csv`` (writing a header first if the file is new). Any
        error in this step (e.g. missing command-line argument) is printed
        instead of raised.
    """
    # take column with max predicted score
    predicted_classes = np.argmax(predictions, axis=1)

    # Both metrics are computed from the ``test_labels`` argument so the
    # function does not depend on any module-level variable.
    f1 = f1_score(test_labels, predicted_classes, average='weighted')
    accuracy = accuracy_score(test_labels, predicted_classes)
    print(f"The F1_score metric is: {f1}")
    print(f"The accuracy metric is: {accuracy}")

    try:
        build_number = sys.argv[1]
        print(f"Build number: {build_number}")

        field_names = ['BUILD_NUMBER', 'F1', 'ACCURACY']
        row = {'BUILD_NUMBER': build_number, 'F1': f1, 'ACCURACY': accuracy}

        filename = "./metrics.csv"
        # Must be checked before opening: mode 'a' creates a missing file.
        file_exists = os.path.isfile(filename)

        # The ``with`` block closes the file; no explicit close() is needed.
        with open(filename, 'a') as metrics_file:
            writer = DictWriter(metrics_file, fieldnames=field_names)
            if not file_exists:
                writer.writeheader()
            writer.writerow(row)
    except Exception as e:
        # Recording metrics is best-effort; report the problem and carry on.
        print(e)
|
|
|
|
|
|
|
|
|
2022-05-02 02:08:51 +02:00
|
|
|
def draw_plot():
    """Plot accuracy and F1 score per build from ``metrics.csv``.

    Reads ``metrics.csv`` from the current directory, where column 0 is the
    build identifier, column 1 the F1 score and column 2 the accuracy. The
    figure is saved to ``metrics.png`` and then shown.
    """
    metrics_path = 'metrics.csv'

    # print_metrics writes a 'BUILD_NUMBER,F1,ACCURACY' header when it
    # creates the file, but older files may lack one. Detect it so the
    # header text is never treated as a data row.
    with open(metrics_path) as metrics_file:
        has_header = metrics_file.readline().startswith('BUILD_NUMBER')

    metrics = pd.read_csv(
        metrics_path,
        delimiter=',',
        header=0 if has_header else None,
    )
    # Address columns by position whether or not a header was present.
    metrics.columns = range(metrics.shape[1])

    build_axis = metrics[0]
    plt.xlabel('Build')
    plt.ylabel('Score')
    plt.plot(build_axis, metrics[2], label='Accuracy')
    plt.plot(build_axis, metrics[1], label='F1 Score')
    plt.legend()

    # Save before show(): with an interactive backend the figure is
    # discarded once the window is closed, so saving afterwards would
    # write an empty image.
    plt.savefig('metrics.png')
    plt.show()
|
|
|
|
|
|
|
|
|
2022-05-02 00:10:35 +02:00
|
|
|
# Evaluation script: load the trained model and run it on the dev split.

# NOTE(review): torch.load unpickles the file, and unpickling can execute
# arbitrary code -- only load model files that come from a trusted source.
model = torch.load("CarPrices_pytorch_model.pkl")

cars_dev = pd.read_csv(
    './Car_Prices_Poland_Kaggle_dev.csv',
    usecols=[1, 4, 5, 6, 10],
    sep=',',
    names=[str(i) for i in range(5)],
)

# Keep only rows whose label column ('0') is one of the five supported values.
cars_dev = cars_dev.loc[
    cars_dev['0'].isin(['audi', 'bmw', 'ford', 'opel', 'volkswagen'])
]

labels_test, features_test = prepare_labels_features(cars_dev)

# Plain tensors are sufficient; the autograd.Variable wrapper is deprecated.
x_test = torch.from_numpy(features_test).float()

# Inference only: skip building the autograd graph.
with torch.no_grad():
    pred = model(x_test)
pred = pred.detach().numpy()

# print_metrics(labels_test, pred)

# draw_plot()
|
2022-05-02 02:08:51 +02:00
|
|
|
|
2022-05-02 00:10:35 +02:00
|
|
|
|
|
|
|
|