SI_InteligentnyWozekWidlowy/tree/DecisionTree.py

86 lines
2.6 KiB
Python
Raw Normal View History

2022-05-14 15:03:29 +02:00
import csv
import pandas
import sklearn
from sklearn import metrics, preprocessing
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
2022-05-22 16:27:36 +02:00
from util.ClientParamsFactory import ClientParamsFactory
2022-05-14 15:03:29 +02:00
class DecisionTree:
    """Build a scikit-learn decision tree that predicts a client's PRIORITY.

    The class also knows how to dump randomly generated client parameters
    to a CSV file using the same feature columns the classifier is trained
    on.
    """

    # Feature columns, in the order they are written to / read from CSV and
    # in which they appear as columns of the feature matrix ``X``.
    _FEATURE_HEADERS = ['DELAY', 'PAYED', 'NET-WORTH', 'INFLUENCE',
                        'SKARBOWKA', 'MEMBER', 'HAT', 'SIZE']

    # Name of the target column expected in the training CSV.
    _TARGET_HEADER = "PRIORITY"

    # Column index of 'SIZE' inside the feature matrix.
    _SIZE_COLUMN = 7

    # Every textual company-size value the encoder is fitted on.
    _SIZE_LABELS = ['CompanySize.NO', 'CompanySize.SMALL',
                    'CompanySize.NORMAL', 'CompanySize.BIG',
                    'CompanySize.HUGE', 'CompanySize.GIGANTISHE']

    def __init__(self) -> None:
        super().__init__()

    def generate_data(self, generator: ClientParamsFactory, n: int,
                      path: str = "data/TEST/generatedData.csv") -> None:
        """Write ``n`` generated client records to a CSV file.

        The file starts with the feature header row, followed by one row per
        call to ``generator.get_client_params()``. Only feature columns are
        written; no PRIORITY column is produced here.

        Args:
            generator: Object whose ``get_client_params()`` returns a record
                exposing the attributes written below.
            n: Number of records to generate.
            path: Destination CSV file (overwritten if it exists).
        """
        # newline='' lets the csv module control line endings itself.
        with open(path, 'w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(self._FEATURE_HEADERS)
            for _ in range(n):
                data = generator.get_client_params()
                writer.writerow([data.payment_delay,
                                 data.payed,
                                 data.net_worth,
                                 data.infuence_rate,  # attribute name as spelled by the generator
                                 data.is_skarbowka,
                                 data.membership,
                                 data.is_hat,
                                 data.company_size])
        # The ``with`` block closes the file; no explicit close() is needed.

    def get_normalized_data(self, X):
        """Replace the textual company size column with integer codes.

        Column 7 of ``X`` is encoded with a ``LabelEncoder`` fitted on the
        six known ``CompanySize.*`` labels. ``X`` is modified in place and
        also returned.

        NOTE(review): LabelEncoder numbers its classes in sorted label
        order, so the resulting codes are presumably not ordered by actual
        company size -- confirm this is acceptable for the tree's splits.

        Args:
            X: 2-D array supporting ``X[:, 7]`` indexing and assignment.

        Returns:
            The same ``X`` object with column 7 encoded.
        """
        size_encoder = preprocessing.LabelEncoder()
        size_encoder.fit(self._SIZE_LABELS)
        column = self._SIZE_COLUMN
        X[:, column] = size_encoder.transform(X[:, column])
        return X

    def print_logs(self, x, y, prediction):
        """Print every sample with its predicted and actual label, then accuracy.

        Args:
            x: 2-D feature array; one row is printed per prediction.
            y: Actual labels, aligned with ``prediction``.
            prediction: Labels predicted by the classifier.
        """
        for i, (row, predicted, actual) in enumerate(zip(x, prediction, y)):
            print("{}. {} \n predicted: {}, actual: {}".format(i, row, predicted, actual))
        print("\nDecisionTrees's Accuracy: ", metrics.accuracy_score(y, prediction))

    def get_decision_tree(self, path: str = 'data/TEST/importedData.csv') -> DecisionTreeClassifier:
        """Train a decision tree on a CSV data set and return it.

        The CSV must contain the eight feature columns plus a ``PRIORITY``
        column. The rows are split 80/20 into train and test parts; the
        test predictions, their accuracy and a text rendering of the tree
        are printed.

        Args:
            path: Comma-separated training file.

        Returns:
            The fitted ``DecisionTreeClassifier`` (entropy criterion,
            maximum depth 4).
        """
        data_input = pandas.read_csv(path, delimiter=",")
        feature_names = list(self._FEATURE_HEADERS)

        X = data_input[feature_names].values
        Y = data_input[self._TARGET_HEADER]
        X = self.get_normalized_data(X)

        X_train, X_test, y_train, y_test = train_test_split(
            X, Y, test_size=0.2, train_size=0.8)

        tree = DecisionTreeClassifier(criterion="entropy", max_depth=4)
        tree.fit(X_train, y_train)

        predicted = tree.predict(X_test)
        self.print_logs(X_test, y_test.to_list(), predicted)
        print(sklearn.tree.export_text(tree, feature_names=feature_names))

        return tree