diff --git a/decisiontree.py b/decisiontree.py index efa6805..0425d41 100644 --- a/decisiontree.py +++ b/decisiontree.py @@ -41,4 +41,14 @@ y = df['Decision'] dtree = DecisionTreeClassifier() dtree = dtree.fit(X, y) +data = tree.export_graphviz(dtree, out_file=None, feature_names=features) +graph = pydotplus.graph_from_dot_data(data) +graph.write_png('mydecisiontree.png') + +img = pltimg.imread('mydecisiontree.png') +imgplot = plt.imshow(img) +plt.show() + + + #print(dtree.predict([[0, 1, 0, 0, 0, 1]])) \ No newline at end of file