Inteligentna_smieciarka/treelearn.py

20 lines
930 B
Python
Raw Normal View History

2023-05-26 19:08:22 +02:00
import pandas as pd
from sklearn import tree
from sklearn.preprocessing import LabelEncoder
import graphviz
def treelearn(data_path='./data_set.csv', render_path='collect'):
    """Train a decision tree that predicts the ``collect`` label.

    Reads the training set from ``data_path`` (CSV). The ``collect`` column is
    the label; every other column is a feature. The categorical
    ``garbage_type`` column is converted to integer codes with
    ``LabelEncoder`` and placed as the LAST feature column. The fitted tree is
    exported to DOT and rendered via ``graphviz.Source.render(render_path)``.

    Args:
        data_path: Path to the training CSV. The default ``./data_set.csv``
            is resolved relative to the current working directory.
        render_path: Filename passed to ``graphviz.Source.render``; the DOT
            source and the rendered output are written based on this name.

    Returns:
        The fitted ``sklearn.tree.DecisionTreeClassifier``. It is fitted on a
        plain array (``.values``), so callers must pass features positionally
        in this order: the CSV columns other than ``collect`` and
        ``garbage_type`` (in file order), followed by the encoded
        ``garbage_type``.
    """
    train_data = pd.read_csv(data_path)
    label = train_data['collect']

    # Encode the categorical garbage type as integers and append it as the
    # last feature column -- the same column order the model was always
    # trained with.
    # NOTE(review): the encoder is not returned, so callers cannot reuse its
    # mapping when encoding garbage_type for prediction -- confirm how callers
    # produce that value.
    garbage_type_encoder = LabelEncoder()
    garbage_type_codes = garbage_type_encoder.fit_transform(train_data['garbage_type'])
    features = train_data.drop(columns=['collect', 'garbage_type'])
    features['garbage_type'] = garbage_type_codes

    # Derive the names from the actual columns so the rendered tree labels
    # always match the data the tree was trained on; a hard-coded list
    # silently mislabels splits if the CSV column order ever differs.
    feature_names = list(features.columns)
    # NOTE(review): export_graphviz expects class_names in ascending order of
    # classifier.classes_; confirm that the value meaning "collect" sorts
    # before the one meaning "no-collect" in the CSV's 'collect' column.
    label_names = ['collect', 'no-collect']

    classifier = tree.DecisionTreeClassifier()
    classifier.fit(features.values, label)

    dot_data = tree.export_graphviz(
        classifier,
        out_file=None,
        feature_names=feature_names,
        class_names=label_names,
    )
    graphviz.Source(dot_data).render(render_path)
    return classifier