Traktor/decisiontree.py

20 lines
486 B
Python
Raw Normal View History

2024-05-18 22:51:09 +02:00
from sklearn.tree import DecisionTreeClassifier
import pandas as pd
def train_decision_tree(data, *, target="podlac", random_state=None):
    """Fit a decision tree classifier on a labelled DataFrame.

    Every column except *target* is used as a feature. Categorical
    (object/category/string) feature columns are one-hot encoded with
    ``pd.get_dummies``, the same encoding ``predict`` applies to new samples,
    so both sides agree on the column layout. Numeric columns pass through
    unchanged, so already-encoded data trains exactly as before.

    Args:
        data: DataFrame holding the feature columns and the label column.
        target: Name of the label column. Defaults to ``"podlac"``.
        random_state: Forwarded to ``DecisionTreeClassifier``; pass an int
            for reproducible fits. ``None`` keeps scikit-learn's default.

    Returns:
        Tuple ``(model, feature_columns)``: the fitted classifier and the
        column labels of the encoded training matrix, which ``predict``
        uses to align new samples.

    Raises:
        KeyError: If *target* is not a column of *data*.
    """
    # Encode categoricals here too; otherwise fit() rejects string features
    # and the training columns would not match the dummy columns predict builds.
    features = pd.get_dummies(data.drop(columns=[target]))
    labels = data[target]
    model = DecisionTreeClassifier(random_state=random_state)
    model.fit(features, labels)
    return model, features.columns
def predict(model, feature_columns, sample):
    """Classify one observation with a model from ``train_decision_tree``.

    Args:
        model: A fitted estimator exposing ``predict``.
        feature_columns: Column labels of the training matrix; the encoded
            sample is aligned to exactly these columns, in this order.
        sample: One observation, presumably a dict of feature name to value;
            it is wrapped in a list to form a single-row DataFrame.

    Returns:
        The result of ``model.predict`` on the aligned one-row frame.
    """
    # Build a one-row frame and one-hot encode its categorical columns.
    encoded = pd.get_dummies(pd.DataFrame([sample]))
    # Add dummy columns the sample lacks (as 0) and drop any the training
    # matrix did not have, so the layout matches feature_columns exactly.
    aligned = encoded.reindex(columns=feature_columns, fill_value=0)
    return model.predict(aligned)