23 lines
676 B
Python
23 lines
676 B
Python
|
from joblib import dump, load
from sklearn import tree
from sklearn.feature_extraction import DictVectorizer

# Train a depth-limited decision tree on dictionary-encoded samples, save the
# fitted model (and the encoder it depends on) with joblib, and print the
# learned rules as text.

# X is a list of dictionaries with samples (feature name -> value);
# Y is a list of the samples' results, aligned index-by-index with X.
X = []
Y = []

# TODO: load training data

# Fail early with a readable message instead of an opaque scikit-learn error
# when the data has not been loaded yet or X and Y are misaligned.
if not X:
    raise ValueError("No training data: populate X and Y before training.")
if len(X) != len(Y):
    raise ValueError(
        f"X and Y must have the same length (got {len(X)} and {len(Y)})."
    )

# vec transforms X (a list of dictionaries) into a feature matrix for the tree:
# string values are one-hot encoded as binary "name=value" columns.
vec = DictVectorizer()
features = vec.fit_transform(X)

# Create and fit the tree classifier. DictVectorizer returns a sparse matrix,
# which scikit-learn trees accept directly; densifying it with .toarray() would
# only waste memory on the mostly-zero one-hot columns.
clf = tree.DecisionTreeClassifier(max_depth=3)
clf.fit(features, Y)  # fit() trains clf in place (and returns it)

# Save the decision tree to file. The fitted vectorizer is saved alongside it:
# new samples must be encoded with the same feature-to-column mapping before
# they can be passed to the loaded tree's predict().
dump(clf, 'decision_tree.joblib')
dump(vec, 'dict_vectorizer.joblib')

# Print the tree (not necessary; for inspection only).
# get_feature_names() was removed in scikit-learn 1.2; prefer the *_out
# variant and fall back for releases older than 1.0. Converting to a plain
# list keeps export_text happy across versions.
try:
    feature_names = vec.get_feature_names_out()
except AttributeError:
    feature_names = vec.get_feature_names()
print(tree.export_text(clf, feature_names=list(feature_names)))