sztuczna_inteligencja_2023_.../machine_learning/decisionTree.py

49 lines
1.4 KiB
Python
Raw Normal View History

2023-05-27 11:34:26 +02:00
import os
from trainingData import TrainingData
2023-05-28 02:50:07 +02:00
from sklearn import tree
import joblib
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
import numpy as np
2023-05-27 11:34:26 +02:00
def _read_training_data() -> TrainingData:
    """Load attribute rows and class labels from ``training_data.csv``.

    The CSV is looked up in the directory that contains this module. Its
    first line is skipped as a header. In every remaining non-blank line the
    last comma-separated value is taken as the class label and the values
    before it as the attributes. All values are kept as strings.

    Returns:
        A ``TrainingData`` built from the list of attribute rows and the
        parallel list of class labels.

    Raises:
        OSError: If ``training_data.csv`` cannot be opened.
    """
    attributes = []
    classes = []
    location = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
    # The context manager closes the file even when reading raises.
    with open(os.path.join(location, 'training_data.csv')) as file:
        next(file, None)  # Skip the header line; harmless on an empty file.
        for line in file:
            actual_row = line.replace('\n', '')
            if not actual_row.strip():
                # A blank line (e.g. a trailing newline at EOF) holds no
                # sample; without this guard it would yield an empty
                # attribute row and an empty class label.
                continue
            values = actual_row.split(',')
            attributes.append(values[:-1])
            classes.append(values[-1])
    return TrainingData(attributes, classes)
2023-05-28 02:50:07 +02:00
2023-05-27 11:34:26 +02:00
# Training script, executed at import time: reads the training set, encodes
# the attributes, fits a decision tree and writes it to ``model.pkl`` in the
# current working directory.
trainning_data = _read_training_data()
X = trainning_data.attributes
Y = trainning_data.classes

# Columns 0, 3 and 4 of every attribute row are label-encoded; the encoder
# names suggest they hold shape, flexibility and color values.
le_shape = LabelEncoder()
le_flexibility = LabelEncoder()
le_color = LabelEncoder()

# Encode each column in a single call instead of one transform() per row.
shape_codes = le_shape.fit_transform([x[0] for x in X])
flexibility_codes = le_flexibility.fit_transform([x[3] for x in X])
color_codes = le_color.fit_transform([x[4] for x in X])

# Columns 1 and 2 are passed through exactly as read from the CSV.
# NOTE(review): those values are strings, so numpy presumably coerces the
# whole matrix to a string dtype and the one-hot step below treats every
# column (including 1 and 2) as categorical - confirm this is intended.
X_encoded = np.array([
    [shape_code, x[1], x[2], flexibility_code, color_code]
    for x, shape_code, flexibility_code, color_code
    in zip(X, shape_codes, flexibility_codes, color_codes)
])

try:
    # scikit-learn >= 1.2 spells the dense-output switch ``sparse_output``;
    # the older ``sparse`` keyword was removed in 1.4.
    encoder = OneHotEncoder(categories='auto', sparse_output=False)
except TypeError:
    # Older scikit-learn releases only accept ``sparse``.
    encoder = OneHotEncoder(categories='auto', sparse=False)
X_encoded = encoder.fit_transform(X_encoded)

model = tree.DecisionTreeClassifier()
model.fit(X_encoded, Y)

# NOTE(review): only the tree is persisted here; the fitted label encoders
# and the one-hot encoder are not saved by this script.
joblib.dump(model, 'model.pkl')