Traktor/decisiontree.py

30 lines
617 B
Python

from sklearn.tree import DecisionTreeClassifier
import pandas as pd
def train_decision_tree(data, *, target_column="podlac", random_state=None):
    """Fit a decision-tree classifier on a labelled DataFrame.

    Args:
        data: pandas DataFrame containing the feature columns together with
            the target column. The features are handed to scikit-learn
            unchanged (assumption: the caller has already encoded any
            categorical columns numerically -- confirm against the caller).
        target_column: Name of the label column to learn. Defaults to
            ``"podlac"``, the column this function always used.
        random_state: Forwarded to ``DecisionTreeClassifier``. The default
            ``None`` keeps scikit-learn's unseeded behaviour; pass an int to
            make repeated training runs reproducible.

    Returns:
        A ``(model, feature_columns)`` tuple: the fitted classifier and the
        pandas Index of feature column names in training order, which
        ``predict`` needs to align a sample with the model's inputs.

    Raises:
        KeyError: If ``target_column`` is not a column of ``data``.
    """
    features = data.drop(columns=[target_column])
    labels = data[target_column]
    model = DecisionTreeClassifier(random_state=random_state)
    model.fit(features, labels)
    # Return the column order too: the model itself only sees positions.
    return model, features.columns
def predict(model, feature_columns, sample):
    """Run ``model.predict`` on a single sample aligned to the training columns.

    Args:
        model: Object exposing ``predict(DataFrame)``.
        feature_columns: Column labels, in order, that the model expects.
        sample: One record (e.g. a dict of feature name -> value); it becomes
            the only row of the frame passed to the model.

    Returns:
        Whatever ``model.predict`` returns for the one-row frame.
    """
    # One-hot encode any non-numeric values in the single-row frame.
    encoded_row = pd.get_dummies(pd.DataFrame([sample]))
    # Match the training layout: missing columns become 0, unknown ones are dropped.
    aligned_row = encoded_row.reindex(columns=feature_columns, fill_value=0)
    return model.predict(aligned_row)
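# Numeric codes for crops (Polish name, English gloss):
# Marchew (carrot) = 1
# Ziemniaki (potatoes) = 2
# Pomidor (tomato) = 3
# Salata (lettuce) = 4
# Cebula (onion) = 5
# Papryka (pepper) = 6
# Buraki (beets) = 7
# Bruksela (Brussels sprouts) = 8
# Rzepak (rapeseed) = 9
# Szpinak (spinach) = 10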