sztuczna_inteligencja_2023_.../machine_learning/decisionTree.py
2023-05-29 12:00:46 +02:00

95 lines
2.5 KiB
Python

import os
from trainingData import TrainingData
from sklearn import tree
import joblib
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
import numpy as np
def _read_training_data() -> TrainingData:
    """Load the decision-tree training set from ``training_data.csv``.

    The CSV is looked up in the directory that contains this module, not the
    current working directory. Its first line is treated as a header and is
    skipped. Every remaining non-blank line is split on commas: all columns
    except the last are attribute values, and the last column is the class
    label. Surrounding whitespace is stripped from every cell.

    Returns:
        A ``TrainingData`` built from the list of attribute rows and the
        parallel list of class labels.
    """
    attributes: list[list[str]] = []
    classes: list[str] = []
    location = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
    # ``with`` guarantees the handle is closed even if reading fails.
    with open(os.path.join(location, 'training_data.csv'), encoding='utf-8') as file:
        lines = file.readlines()[1:]  # drop the header row
    for line in lines:
        row = line.strip()
        if not row:
            # A blank line would otherwise become an empty attribute row with
            # an empty class label, which the attribute encoder cannot handle.
            continue
        *line_attributes, line_class = (cell.strip() for cell in row.split(','))
        attributes.append(line_attributes)
        classes.append(line_class)
    return TrainingData(attributes, classes)
def _attributes_to_floats(attributes: list[str]) -> list[float]:
output: list[float] = []
if attributes[0] == 'Longitiudonal':
output.append(0)
elif attributes[0] == 'Round':
output.append(1)
elif attributes[0] == 'Flat':
output.append(2)
elif attributes[0] == 'Irregular':
output.append(3)
if attributes[1] == 'Low':
output.append(0)
elif attributes[1] == 'Medium':
output.append(1)
elif attributes[1] == 'High':
output.append(2)
if attributes[2] == "Yes":
output.append(0)
else:
output.append(1)
if attributes[3] == 'Low':
output.append(0)
elif attributes[3] == 'Medium':
output.append(1)
elif attributes[3] == 'High':
output.append(2)
if attributes[4] == 'Low':
output.append(0)
elif attributes[4] == 'Medium':
output.append(1)
elif attributes[4] == 'High':
output.append(2)
if attributes[5] == 'Transparent':
output.append(0)
elif attributes[5] == 'Light':
output.append(1)
elif attributes[5] == 'Dark':
output.append(2)
elif attributes[5] == "Colorful":
output.append(3)
if attributes[6] == 'Low':
output.append(0)
elif attributes[6] == 'Medium':
output.append(1)
elif attributes[6] == 'High':
output.append(2)
if attributes[7] == "Yes":
output.append(0)
else:
output.append(1)
return output
# --- Training script (executes at import time) -----------------------------
# Load the CSV rows, turn each attribute row into numeric codes, fit a
# decision tree on them and serialise the fitted estimator with joblib.
# NOTE(review): 'model.pkl' is a relative path, so it is written to the
# current working directory, whereas _read_training_data locates its CSV
# next to this module. Confirm this asymmetry is intended before relying on it.
trainning_data = _read_training_data()
X, Y = trainning_data.attributes, trainning_data.classes
encoded = list(map(_attributes_to_floats, X))
model = tree.DecisionTreeClassifier()
# scikit-learn's fit() returns the fitted estimator itself, so dtc is model.
dtc = model.fit(encoded, Y)
joblib.dump(model, 'model.pkl')