big commit (ostroznie!)
This commit is contained in:
parent
32c50f27fc
commit
9a910a9284
@ -1,5 +1,6 @@
|
||||
import os
|
||||
import json
|
||||
from matplotlib import pyplot
|
||||
from joblib import dump, load
|
||||
from sklearn import tree
|
||||
from sklearn.feature_extraction import DictVectorizer
|
||||
@ -38,6 +39,14 @@ class DecisionTree:
|
||||
# print a tree (not necessary)
|
||||
print(tree.export_text(self.clf, feature_names=self.vec.get_feature_names()))
|
||||
|
||||
# plot a tree (not necessary)
|
||||
fig = pyplot.figure(figsize=(50, 40))
|
||||
_ = tree.plot_tree(self.clf,
|
||||
feature_names=self.vec.get_feature_names(),
|
||||
class_names=self.clf.classes_,
|
||||
filled=True)
|
||||
fig.savefig("decistion_tree.png")
|
||||
|
||||
def save(self):
|
||||
dump(self.clf, 'decision_tree.joblib')
|
||||
dump(self.vec, 'dict_vectorizer.joblib')
|
||||
@ -79,4 +88,3 @@ if __name__ == "__main__":
|
||||
decision_tree = DecisionTree()
|
||||
decision_tree.build("training_set.txt", 15)
|
||||
decision_tree.test()
|
||||
decision_tree.save()
|
||||
|
Loading…
Reference in New Issue
Block a user