95 lines
2.5 KiB
Python
95 lines
2.5 KiB
Python
import os
|
|
from trainingData import TrainingData
|
|
from sklearn import tree
|
|
import joblib
|
|
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
|
|
import numpy as np
|
|
|
|
def _read_training_data() -> TrainingData:
    """Load ``training_data.csv`` from the directory containing this module.

    The first row is treated as a header and skipped. In every remaining
    row the last column is the class label and all preceding columns are
    the attributes. Surrounding whitespace is stripped from every field,
    and rows whose fields are all empty are ignored.

    Returns:
        A ``TrainingData`` built from the list of attribute rows and the
        parallel list of class labels.
    """
    import csv  # local import: only this loader parses CSV

    attributes: list[list[str]] = []
    classes: list[str] = []

    # Directory of this source file, resolved to an absolute real path, so
    # the CSV is found regardless of the current working directory.
    location = os.path.realpath(os.path.dirname(__file__))
    path = os.path.join(location, 'training_data.csv')

    # `with` guarantees the file is closed even if parsing fails; newline=''
    # is what the csv module expects and also handles \r\n line endings.
    with open(path, newline='', encoding='utf-8') as file:
        reader = csv.reader(file)
        next(reader, None)  # skip the header row (no-op on an empty file)
        for row in reader:
            fields = [field.strip() for field in row]
            if not any(fields):
                # Blank / trailing empty lines would otherwise become rows
                # with no attributes and crash the encoder later.
                continue
            attributes.append(fields[:-1])
            classes.append(fields[-1])

    return TrainingData(attributes, classes)
|
|
|
|
def _attributes_to_floats(attributes: list[str]) -> list[float]:
|
|
output: list[float] = []
|
|
if attributes[0] == 'Longitiudonal':
|
|
output.append(0)
|
|
elif attributes[0] == 'Round':
|
|
output.append(1)
|
|
elif attributes[0] == 'Flat':
|
|
output.append(2)
|
|
elif attributes[0] == 'Irregular':
|
|
output.append(3)
|
|
|
|
if attributes[1] == 'Low':
|
|
output.append(0)
|
|
elif attributes[1] == 'Medium':
|
|
output.append(1)
|
|
elif attributes[1] == 'High':
|
|
output.append(2)
|
|
|
|
|
|
if attributes[2] == "Yes":
|
|
output.append(0)
|
|
else:
|
|
output.append(1)
|
|
|
|
if attributes[3] == 'Low':
|
|
output.append(0)
|
|
elif attributes[3] == 'Medium':
|
|
output.append(1)
|
|
elif attributes[3] == 'High':
|
|
output.append(2)
|
|
|
|
if attributes[4] == 'Low':
|
|
output.append(0)
|
|
elif attributes[4] == 'Medium':
|
|
output.append(1)
|
|
elif attributes[4] == 'High':
|
|
output.append(2)
|
|
|
|
if attributes[5] == 'Transparent':
|
|
output.append(0)
|
|
elif attributes[5] == 'Light':
|
|
output.append(1)
|
|
elif attributes[5] == 'Dark':
|
|
output.append(2)
|
|
elif attributes[5] == "Colorful":
|
|
output.append(3)
|
|
|
|
if attributes[6] == 'Low':
|
|
output.append(0)
|
|
elif attributes[6] == 'Medium':
|
|
output.append(1)
|
|
elif attributes[6] == 'High':
|
|
output.append(2)
|
|
|
|
if attributes[7] == "Yes":
|
|
output.append(0)
|
|
else:
|
|
output.append(1)
|
|
return output
|
|
|
|
|
|
|
|
# Train a decision tree on the bundled CSV and persist it.
# NOTE(review): this runs at import time; the names below are module globals.
trainning_data = _read_training_data()

# Raw categorical attribute rows and their parallel class labels.
X = trainning_data.attributes
Y = trainning_data.classes

# Turn every categorical row into its numeric feature vector.
encoded = list(map(_attributes_to_floats, X))

model = tree.DecisionTreeClassifier()
# fit() returns the estimator itself, so dtc refers to the same object as model.
dtc = model.fit(encoded, Y)

# Relative path: the pickle is written to the current working directory.
joblib.dump(model, 'model.pkl')