sztuczna_inteligencja_2023_.../machine_learning/decisionTree.py
2023-05-28 02:50:07 +02:00

49 lines
1.4 KiB
Python

import os
from trainingData import TrainingData
from sklearn import tree
import joblib
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
import numpy as np
def _read_training_data() -> TrainingData:
    """Load the training set from ``training_data.csv`` next to this module.

    The first line of the file is skipped as a header. Every remaining
    non-blank line is split on commas: the last field is taken as the class
    label and all preceding fields as the attributes. Values are kept as
    strings exactly as they appear in the file.

    Returns:
        A ``TrainingData`` built from the list of attribute rows and the
        list of class labels, both in file order.

    Raises:
        OSError: If ``training_data.csv`` cannot be opened.
    """
    attributes = []
    classes = []
    # Resolve the CSV relative to this source file rather than the caller's
    # working directory.
    location = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
    # The context manager guarantees the handle is closed even if reading fails.
    with open(os.path.join(location, 'training_data.csv')) as file:
        next(file, None)  # skip the header line (no-op on an empty file)
        for line in file:
            actual_row = line.rstrip('\n')
            # A blank line (e.g. a trailing newline at the end of the file)
            # would otherwise yield a row with no attributes and an empty
            # class label.
            if not actual_row.strip():
                continue
            values = actual_row.split(',')
            attributes.append(values[:-1])
            classes.append(values[-1])
    return TrainingData(attributes, classes)
# --- Train a decision tree on the CSV data and persist it to disk ----------
# Module-level names are kept as in the original script (including the
# spelling of ``trainning_data``) so that any code importing them keeps working.
trainning_data = _read_training_data()
X = trainning_data.attributes
Y = trainning_data.classes

# Columns 0, 3 and 4 of each attribute row are label-encoded; the encoder
# names suggest they hold shape, flexibility and color.
le_shape = LabelEncoder()
le_flexibility = LabelEncoder()
le_color = LabelEncoder()

# Fit and transform each column in a single call instead of calling
# ``transform`` once per row inside the comprehension below.
shape_codes = le_shape.fit_transform([x[0] for x in X])
flexibility_codes = le_flexibility.fit_transform([x[3] for x in X])
color_codes = le_color.fit_transform([x[4] for x in X])

# Columns 1 and 2 are passed through unchanged (still strings from the CSV).
X_encoded = np.array([
    [shape, x[1], x[2], flexibility, color]
    for x, shape, flexibility, color in zip(
        X, shape_codes, flexibility_codes, color_codes
    )
])

# ``sparse`` was renamed to ``sparse_output`` in scikit-learn 1.2 and the old
# keyword was later removed; try the new name first and fall back to the old
# one so the script runs on both older and newer releases.
try:
    encoder = OneHotEncoder(categories='auto', sparse_output=False)
except TypeError:
    encoder = OneHotEncoder(categories='auto', sparse=False)
X_encoded = encoder.fit_transform(X_encoded)

model = tree.DecisionTreeClassifier()
model.fit(X_encoded, Y)

# NOTE(review): only the tree is written out; the label encoders and the
# one-hot encoder fitted above are not saved by this script, so any consumer
# of model.pkl has to reproduce the same encoding itself.
# The path is relative, so the file lands in the current working directory.
joblib.dump(model, 'model.pkl')