20 lines
930 B
Python
20 lines
930 B
Python
import pandas as pd
|
|
from sklearn import tree
|
|
from sklearn.preprocessing import LabelEncoder
|
|
import graphviz
|
|
|
|
def treelearn(csv_path='./data_set.csv', output_name='collect', random_state=None):
    """Train a decision tree on the garbage-collection data set and render it.

    Reads ``csv_path`` and uses the ``collect`` column as the target and every
    other column as a feature. The ``garbage_type`` column is integer-encoded
    with a ``LabelEncoder``; the encoded column is placed LAST in the feature
    matrix. The fitted tree is exported with ``tree.export_graphviz`` and
    rendered through ``graphviz.Source(...).render(output_name)``.

    Args:
        csv_path: Path of the training CSV. Defaults to the previously
            hard-coded ``'./data_set.csv'``.
        output_name: Filename passed to ``graphviz.Source.render``. Defaults
            to the previously hard-coded ``'collect'``.
        random_state: Forwarded to ``DecisionTreeClassifier`` so that
            tie-breaking between equally good splits is reproducible.
            ``None`` keeps the original unseeded behaviour.

    Returns:
        The fitted ``DecisionTreeClassifier``. It was trained on a feature
        matrix whose columns are the CSV columns minus ``collect`` and
        ``garbage_type``, followed by the label-encoded ``garbage_type``.
        Inputs to ``predict`` must use that same order and encoding.
        NOTE(review): the fitted ``LabelEncoder`` is not returned, so callers
        currently cannot reproduce the ``garbage_type`` encoding.
    """
    train_data = pd.read_csv(csv_path)

    # Every column except the target is a candidate feature.
    attributes = train_data.drop('collect', axis='columns')

    # Replace the categorical column with its integer codes. The new column
    # is appended at the end of the frame.
    e_type = LabelEncoder()
    attributes['type_num'] = e_type.fit_transform(attributes['garbage_type'])
    attr_encoded = attributes.drop(['garbage_type'], axis='columns')

    # Derive feature names from the actual column order instead of a
    # hard-coded list. A fixed list silently mislabels tree nodes whenever the
    # CSV column order differs, and breaks when the column set differs.
    attr_names = [
        'garbage_type' if col == 'type_num' else col
        for col in attr_encoded.columns
    ]

    # export_graphviz expects class names in ascending order of
    # classifier.classes_. NOTE(review): this assumes the smallest value in
    # the 'collect' column means "collect" -- confirm against the data set.
    label_names = ['collect', 'no-collect']

    label = train_data['collect']

    classifier = tree.DecisionTreeClassifier(random_state=random_state)
    classifier.fit(attr_encoded.values, label)

    dot_data = tree.export_graphviz(
        classifier,
        out_file=None,
        feature_names=attr_names,
        class_names=label_names,
    )
    graph = graphviz.Source(dot_data)
    # Side effect: writes rendered output files named after `output_name`.
    graph.render(output_name)

    return classifier