20 lines
486 B
Python
20 lines
486 B
Python
|
|
||
|
from sklearn.tree import DecisionTreeClassifier
|
||
|
import pandas as pd
|
||
|
|
||
|
|
||
|
def train_decision_tree(data, target="podlac", **model_params):
    """Fit a decision tree classifier on ``data``.

    Feature columns with object/string/category dtype are one-hot encoded
    with ``pd.get_dummies``, which is the same encoding ``predict`` applies
    to incoming samples. That way the returned feature columns line up with
    the encoded sample.

    Args:
        data: DataFrame holding the feature columns plus the target column.
        target: Name of the label column. Defaults to ``"podlac"``.
        **model_params: Extra keyword arguments forwarded to
            ``DecisionTreeClassifier`` (e.g. ``max_depth``, ``random_state``).

    Returns:
        A ``(model, feature_columns)`` tuple: the fitted classifier and the
        encoded feature column index, meant to be passed to ``predict``.

    Raises:
        KeyError: If ``target`` is not a column of ``data``.
    """
    # Encode features exactly as predict() encodes a sample. Without this,
    # string features make fit() fail, and predict()'s dummy columns
    # (e.g. "color_red") never match the raw training column names.
    X = pd.get_dummies(data.drop(columns=[target]))
    y = data[target]

    model = DecisionTreeClassifier(**model_params)
    model.fit(X, y)

    return model, X.columns
|
||
|
|
||
|
|
||
|
def predict(model, feature_columns, sample):
    """Classify a single sample with a model returned by training.

    The sample is wrapped in a one-row DataFrame and one-hot encoded with
    ``pd.get_dummies``. It is then aligned to ``feature_columns``:
    - Columns the model was trained on but the sample lacks are filled with 0.
    - Columns the model never saw are dropped.

    Args:
        model: Fitted estimator exposing a ``predict`` method.
        feature_columns: Column labels the model was fitted on, in order.
        sample: A single record, e.g. a dict mapping column name to value.

    Returns:
        Whatever ``model.predict`` returns for the one aligned row.
    """
    encoded = pd.get_dummies(pd.DataFrame([sample]))
    # Align to the training layout so the estimator sees the same columns
    # in the same order it was fitted with.
    aligned = encoded.reindex(columns=feature_columns, fill_value=0)
    return model.predict(aligned)
|
||
|
|