2022-05-14 15:03:29 +02:00
|
|
|
import csv
|
|
|
|
|
2022-06-09 21:54:18 +02:00
|
|
|
import numpy as np
|
2022-05-14 15:03:29 +02:00
|
|
|
import pandas
|
|
|
|
import sklearn
|
|
|
|
from sklearn import metrics, preprocessing
|
|
|
|
from sklearn.model_selection import train_test_split
|
|
|
|
from sklearn.tree import DecisionTreeClassifier
|
|
|
|
|
2022-06-09 21:54:18 +02:00
|
|
|
from InitialStateFactory import InitialStateFactory
|
|
|
|
from data.ClientParams import ClientParams
|
|
|
|
from data.Order import Order
|
|
|
|
from data.enum.CompanySize import CompanySize
|
|
|
|
from data.enum.Priority import Priority
|
2022-05-22 16:27:36 +02:00
|
|
|
from util.ClientParamsFactory import ClientParamsFactory
|
2022-05-14 15:03:29 +02:00
|
|
|
|
|
|
|
|
|
|
|
class DecisionTree:
    """Assign priorities to orders with a scikit-learn decision tree.

    The tree is trained on client parameters read from
    ``data/TEST/importedData.csv``; :meth:`generate_data` writes unlabeled
    client parameters to ``data/TEST/generatedData.csv`` using the same
    feature columns.
    """

    # Feature columns, in the order they are written to the generated CSV and
    # selected from the training CSV.
    FEATURE_COLUMNS = ('DELAY', 'PAYED', 'NET-WORTH', 'INFLUENCE',
                       'SKARBOWKA', 'MEMBER', 'HAT', 'SIZE')

    # CompanySize member names; the position of a name is its numeric code.
    # This single table is used for both training and prediction so that the
    # SIZE feature means the same thing on both sides.
    COMPANY_SIZE_ORDER = ('NO', 'SMALL', 'NORMAL', 'BIG', 'HUGE', 'GIGANTISHE')

    def __init__(self) -> None:
        super().__init__()

    def generate_data(self, generator: "ClientParamsFactory", n: int):
        """Write ``n`` generated client-parameter rows to the generated CSV.

        Args:
            generator: object whose ``get_client_params()`` returns the
                client parameters for one row.
            n: number of data rows to write after the header row.
        """
        # The ``with`` block closes the file; no explicit close() is needed.
        with open("data/TEST/generatedData.csv", 'w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(self.FEATURE_COLUMNS)
            for _ in range(n):
                data = generator.get_client_params()
                # The company size is written as-is (not numerically encoded).
                writer.writerow(
                    self._base_features(data) + [data.company_size])

    def get_normalized_data(self, X):
        """Replace the textual SIZE column of ``X`` with numeric codes.

        The codes follow ``COMPANY_SIZE_ORDER`` (``'CompanySize.NO'`` -> 0 ...
        ``'CompanySize.GIGANTISHE'`` -> 5), i.e. the same codes that
        :meth:`get_data_good` feeds to the trained tree.  A ``LabelEncoder``
        is deliberately not used here: it numbers classes in sorted label
        order, which would not match the prediction-time codes.

        Args:
            X: 2-D array indexable as ``X[:, col]``; modified in place.

        Returns:
            The same ``X`` object, with the SIZE column encoded.

        Raises:
            ValueError: if the SIZE column contains an unknown label.
        """
        size_column = self.FEATURE_COLUMNS.index('SIZE')
        codes = {
            'CompanySize.{}'.format(name): code
            for code, name in enumerate(self.COMPANY_SIZE_ORDER)
        }
        try:
            X[:, size_column] = [codes[label] for label in X[:, size_column]]
        except KeyError as err:
            raise ValueError(
                "unknown company size label: {!r}".format(err.args[0])
            ) from err
        return X

    def print_logs(self, x, y, prediction):
        """Print every sample with its predicted and actual label, then the accuracy.

        Args:
            x: 2-D array of feature rows, indexed as ``x[i, :]``.
            y: actual labels, indexable by position.
            prediction: predicted labels, one per row of ``x``.
        """
        for i, predicted in enumerate(prediction):
            print("{}. {} \n predicted: {}, actual: {}".format(i, x[i, :], predicted, y[i]))

        print("\nDecisionTrees's Accuracy: ", metrics.accuracy_score(y, prediction))

    def get_data_good(self, orders: "list[Order]") -> "list[Order]":
        """Predict a priority for every order and store it on the order.

        A fresh tree is obtained from :meth:`get_decision_tree` on each call.
        Predicted labels ``"LOW"``, ``"MEDIUM"`` and ``"HIGH"`` are stored as
        the matching ``Priority`` member; any other label is stored unchanged.

        Args:
            orders: orders whose ``client_params`` supply the features.

        Returns:
            The same list, with each order's ``priority`` updated.

        Raises:
            ValueError: if an order has a company size that is not one of the
                members named in ``COMPANY_SIZE_ORDER``.
        """
        if not orders:
            # Nothing to classify; avoid training a model for no input.
            return orders

        features = np.array(
            [self._feature_row(order.client_params) for order in orders])

        tree = self.get_decision_tree()
        predictions = tree.predict(features)

        label_to_priority = {
            "LOW": Priority.LOW,
            "MEDIUM": Priority.MEDIUM,
            "HIGH": Priority.HIGH,
        }
        for order, label in zip(orders, predictions):
            # Priority before and after reassignment is printed, as in the
            # original implementation.
            print(order.priority)
            order.priority = label_to_priority.get(label, label)
            print(order.priority)

        return orders

    def get_decision_tree(self, *, verbose: bool = False) -> "DecisionTreeClassifier":
        """Train and return a decision tree from ``data/TEST/importedData.csv``.

        The CSV must contain the ``FEATURE_COLUMNS`` and a ``PRIORITY`` column.
        The tree is fitted on a random 80% of the rows; the remaining 20% are
        only used for the optional report.

        Args:
            verbose: when true, print per-sample predictions and the accuracy
                on the held-out rows via :meth:`print_logs`.

        Returns:
            The fitted ``DecisionTreeClassifier``.
        """
        data_input = pandas.read_csv('data/TEST/importedData.csv', delimiter=",")

        # pandas needs a list (not a tuple) to select several columns.
        feature_names = list(self.FEATURE_COLUMNS)
        X = self.get_normalized_data(data_input[feature_names].values)
        Y = data_input["PRIORITY"]

        X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, train_size=0.8)

        classifier = DecisionTreeClassifier(criterion="entropy", max_depth=4)
        classifier.fit(X_train, y_train)

        if verbose:
            self.print_logs(X_test, y_test.to_list(), classifier.predict(X_test))

        return classifier

    def _base_features(self, params: "ClientParams") -> list:
        """Return the seven non-size feature values, in FEATURE_COLUMNS order."""
        return [
            params.payment_delay,
            params.payed,
            params.net_worth,
            params.infuence_rate,
            params.is_skarbowka,
            params.membership,
            params.is_hat,
        ]

    def _feature_row(self, params: "ClientParams") -> list:
        """Return the full numeric feature row used for prediction."""
        return self._base_features(params) + [
            self._company_size_code(params.company_size)]

    def _company_size_code(self, size: "CompanySize") -> int:
        """Return the numeric code of a CompanySize member, or raise ValueError."""
        for code, name in enumerate(self.COMPANY_SIZE_ORDER):
            if size == getattr(CompanySize, name):
                return code
        raise ValueError("unknown company size: {!r}".format(size))
|
|
|
|
|
2022-05-14 15:03:29 +02:00
|
|
|
|
|
|
|
|
2022-06-09 21:54:18 +02:00
|
|
|
# decision_tree = DecisionTree()
|
|
|
|
# decision_tree.get_data_good(InitialStateFactory.generate_order_list(50))
|