disabled tree plotting; added tree files to .gitignore
This commit is contained in:
parent
e84e71529f
commit
7a5ce8d1cd
2
.gitignore
vendored
2
.gitignore
vendored
@ -142,4 +142,4 @@ dmypy.json
|
||||
cython_debug/
|
||||
|
||||
# local sandbox
|
||||
sandbox/
|
||||
sandbox
|
@ -40,12 +40,12 @@ class DecisionTree:
|
||||
print(tree.export_text(self.clf, feature_names=self.vec.get_feature_names()))
|
||||
|
||||
# plot a tree (not necessary)
|
||||
fig = pyplot.figure(figsize=(50, 40))
|
||||
_ = tree.plot_tree(self.clf,
|
||||
feature_names=self.vec.get_feature_names(),
|
||||
class_names=self.clf.classes_,
|
||||
filled=True)
|
||||
fig.savefig("decistion_tree.png")
|
||||
# fig = pyplot.figure(figsize=(50, 40))
|
||||
# _ = tree.plot_tree(self.clf,
|
||||
# feature_names=self.vec.get_feature_names(),
|
||||
# class_names=self.clf.classes_,
|
||||
# filled=True)
|
||||
# fig.savefig("decistion_tree.png")
|
||||
|
||||
def save(self):
|
||||
dump(self.clf, 'decision_tree.joblib')
|
||||
@ -88,4 +88,4 @@ if __name__ == "__main__":
|
||||
decision_tree = DecisionTree()
|
||||
decision_tree.build("training_set.txt", 15)
|
||||
decision_tree.test()
|
||||
# decision_tree.save()
|
||||
decision_tree.save()
|
||||
|
Loading…
Reference in New Issue
Block a user