30 lines
617 B
Python
30 lines
617 B
Python
|
|
||
|
from sklearn.tree import DecisionTreeClassifier
|
||
|
import pandas as pd
|
||
|
|
||
|
|
||
|
def train_decision_tree(data: pd.DataFrame, *, target="podlac", random_state=None):
    """Fit a decision-tree classifier on ``data``.

    Every column except ``target`` is used as a feature. Columns that
    ``pd.get_dummies`` encodes by default (object/string/category dtypes)
    are one-hot encoded. This is the same encoding ``predict`` applies to a
    single sample, so the returned column index is exactly what ``predict``
    must reindex against.

    Args:
        data: Training table containing the feature columns and ``target``.
        target: Name of the label column (default ``"podlac"``).
        random_state: Forwarded to ``DecisionTreeClassifier``. Pass an int for
            reproducible trees; ``None`` (the default) keeps the unseeded
            behaviour.

    Returns:
        A ``(model, feature_columns)`` tuple: the fitted classifier and the
        column index of the encoded feature matrix it was trained on.

    Raises:
        KeyError: If ``target`` is not a column of ``data``.
    """
    # Encode here as well as in predict(). Without this, string feature
    # columns make fit() fail, and predict()'s dummy columns (e.g. "col_val")
    # would never line up with the training columns.
    X = pd.get_dummies(data.drop(columns=[target]))
    y = data[target]

    model = DecisionTreeClassifier(random_state=random_state)
    model.fit(X, y)
    return model, X.columns
|
||
|
|
||
|
|
||
|
def predict(model, feature_columns, sample):
    """Classify a single sample with a model from ``train_decision_tree``.

    The sample becomes a one-row DataFrame and is one-hot encoded with
    ``pd.get_dummies``. It is then aligned to ``feature_columns``: columns
    the model was trained on but the sample lacks are filled with 0, and
    columns the sample has but the model never saw are dropped.

    Args:
        model: Fitted estimator exposing ``predict``.
        feature_columns: Column labels the model was trained on, in order.
        sample: One record. Presumably a mapping of column name to value
            (it is wrapped as ``pd.DataFrame([sample])``); confirm with callers.

    Returns:
        Whatever ``model.predict`` returns for the one-row frame.
    """
    aligned = (
        pd.DataFrame([sample])
        .pipe(pd.get_dummies)
        # NOTE(review): unseen category values and missing features both end
        # up as 0 here without any warning.
        .reindex(columns=feature_columns, fill_value=0)
    )
    return model.predict(aligned)
|
||
|
|
||
|
|
||
|
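# Vegetable codes (Polish name, English translation in parentheses):
# Marchew (carrot) = 1
# Ziemniaki (potatoes) = 2
# Pomidor (tomato) = 3
# Salata (lettuce) = 4
# Cebula (onion) = 5
# Papryka (bell pepper) = 6
# Buraki (beetroot) = 7
# Brukselka (Brussels sprouts) = 8
# Rzepak (rapeseed) = 9
# Szpinak (spinach) = 10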
|