"""Train a decision tree on ``training_data.csv`` and save it to ``model.pkl``.

Running (or importing) this module reads the CSV that sits next to this file,
encodes every categorical attribute as a small integer code, fits a
``DecisionTreeClassifier`` and dumps the fitted model with joblib into the
current working directory.
"""

import os

import joblib
import numpy as np
from sklearn import tree
from sklearn.preprocessing import LabelEncoder, OneHotEncoder

from trainingData import TrainingData

# Integer codes for each categorical value.  The spellings (including
# 'Longitiudonal') must match the values stored in the CSV exactly.
_SHAPE_CODES = {'Longitiudonal': 0, 'Round': 1, 'Flat': 2, 'Irregular': 3}
_LEVEL_CODES = {'Low': 0, 'Medium': 1, 'High': 2}
_COLOR_CODES = {'Transparent': 0, 'Light': 1, 'Dark': 2, 'Colorful': 3}
_YES_CODES = {'Yes': 0}

# One (codes, fallback) pair per attribute column, in column order.
# A fallback of None means "emit nothing for an unrecognised value";
# the two Yes/No columns fall back to 1 for anything that is not "Yes".
_FEATURE_ENCODINGS = (
    (_SHAPE_CODES, None),
    (_LEVEL_CODES, None),
    (_YES_CODES, 1),
    (_LEVEL_CODES, None),
    (_LEVEL_CODES, None),
    (_COLOR_CODES, None),
    (_LEVEL_CODES, None),
    (_YES_CODES, 1),
)


def _read_training_data() -> TrainingData:
    """Load ``training_data.csv`` from the directory containing this module.

    The first line of the file is treated as a header and skipped.  On every
    remaining row the last comma-separated field is the class label (stripped
    of surrounding whitespace) and all preceding fields are the attributes.
    Blank rows are ignored.

    Returns:
        A ``TrainingData`` built from the attribute rows and the class labels.

    Raises:
        OSError: if the CSV file cannot be opened or read.
    """
    attributes = []
    classes = []
    location = os.path.realpath(
        os.path.join(os.getcwd(), os.path.dirname(__file__))
    )

    # The context manager closes the handle even if reading fails.
    with open(os.path.join(location, 'training_data.csv')) as file:
        lines = file.readlines()[1:]

    for line in lines:
        actual_row = line.replace('\n', '')
        if not actual_row.strip():
            # A blank row would yield an empty attribute list and crash the
            # encoder later with an IndexError, so skip it here.
            continue
        values = actual_row.split(',')
        attributes.append(values[:-1])
        classes.append(values[-1].strip())

    return TrainingData(attributes, classes)


def _attributes_to_floats(attributes: list[str]) -> list[float]:
    """Encode one row of eight categorical attributes as numeric codes.

    Args:
        attributes: the attribute strings of a single row; indices 0-7 are
            read in order.

    Returns:
        The numeric codes, in column order.

    Raises:
        IndexError: if ``attributes`` has fewer than eight entries.
    """
    output: list[float] = []
    for index, (codes, fallback) in enumerate(_FEATURE_ENCODINGS):
        code = codes.get(attributes[index], fallback)
        # NOTE(review): an unrecognised value in a column without a fallback
        # is skipped, so the returned vector is shorter than eight entries.
        # This is kept for backward compatibility; confirm whether raising
        # would be preferable.
        if code is not None:
            output.append(code)
    return output


trainning_data = _read_training_data()
X = trainning_data.attributes
Y = trainning_data.classes

model = tree.DecisionTreeClassifier()
encoded = [_attributes_to_floats(x) for x in X]
dtc = model.fit(encoded, Y)

# Written relative to the current working directory, not this file's location.
joblib.dump(model, 'model.pkl')