diff --git a/ID3.py b/ID3.py new file mode 100644 index 0000000..cdff71d --- /dev/null +++ b/ID3.py @@ -0,0 +1,185 @@ +import pandas as pd +import numpy as np +from pprint import pprint +import dataset + +training_data = pd.DataFrame(data=dataset.training_data, columns=dataset.header) +testing_data = pd.DataFrame(data=dataset.testing_data, columns=dataset.header) + + +def entropy(target_col): + """ + Obliczenie warości entropii dla wskazanej kolumny + """ + values, counts = np.unique(target_col, return_counts=True) + entropy = np.sum( + [(-counts[i] / np.sum(counts)) * np.log2(counts[i] / np.sum(counts)) for i in range(len(values))]) + return entropy + + +def info_gain(data, split_attribute_name, target_name="label"): + """ + Obliczenie wartości przyrostu informacji dla wskazanego atrybutu (split_attribute_name) + w podanym zbiorze (data) + """ + + # Wartość entropii zbioru + total_entropy = entropy(data[target_name]) + + # Wyodrębnienie poszczególnych "podzbiorów" + vals, counts = np.unique(data[split_attribute_name], return_counts=True) + + # Średnia ważona entropii każdego podzbioru + weighted_entropy = np.sum( + [(counts[i] / np.sum(counts)) * entropy(data.where(data[split_attribute_name] == vals[i]).dropna()[target_name]) + for i in range(len(vals))]) + + # Przyrost informacji + information_gain = total_entropy - weighted_entropy + + return information_gain + + +def ID3(data, original_data, features, target_attribute_name="label", parent_node_class=None): + """ + Algorytm ID3 + + parametry: + data zbiór danych, dla którego poszukujemy drzewa decyzyjnego + original_data oryginalny zbiór danych (zwracany gdy data == None) + features lista atrybutów wejściowego zbioru + target_attribute_name docelowy atrybut, który chcemy przewidzieć + parent_node_class nadrzędna wartość + """ + + # Jeżeli wszystkie atrybuty są takie same, zwracamy liść z pierwszą napotkaną wartością + + if len(np.unique(data[target_attribute_name])) <= 1: + return np.unique(data[target_attribute_name])[0] + + elif len(data) == 0: + return np.unique(original_data[target_attribute_name])[ + np.argmax(np.unique(original_data[target_attribute_name], return_counts=True)[1])] + + elif len(features) == 0: + return parent_node_class + + else: + + # Aktualizacja nadrzędnej wartości + parent_node_class = np.unique(data[target_attribute_name])[ + np.argmax(np.unique(data[target_attribute_name], return_counts=True)[1])] + + # Obliczenie przyrostu informacji dla każdego potencjalnego atrybutu, + # według którego nastąpi podział zbioru + item_values = [info_gain(data, feature, target_attribute_name) for feature in + features] + + # Najlepszym atrybutem jest ten o największym przyroście informacji + best_feature_index = np.argmax(item_values) + best_feature = features[best_feature_index] + + # Struktura drzewa + tree = {best_feature: {}} + + # Aktualizacja zbioru atrybutów + features = [i for i in features if i != best_feature] + + # Dla każdej wartości wybranego atrybutu budujemy kolejne poddrzewo + for value in np.unique(data[best_feature]): + + sub_data = data.where(data[best_feature] == value).dropna() + subtree = ID3(sub_data, data, features, target_attribute_name, parent_node_class) + + tree[best_feature][value] = subtree + + return (tree) + + +def predict(query, tree, default='beetroot'): + """ + Przeszukiwanie drzewa w celu przewidzenia wartości atrybutu "label". + W przypadku, gdy dane wejściowe nie pokrywają się z żadnymi wartościami w drzewie + (np pH ziemi zostanie sklasyfikowane jako 'strongly acidic', a dane uczące nie obejmują rekordów dla takiej wartości), + wówczas przewidywana zostaje wartość domyślna, w tym przypadku jest to burak jako warzywo o najmniejszych wymaganiach. + """ + + for key in list(query.keys()): + if key in list(tree.keys()): + try: + result = tree[key][query[key]] + except: + return default + result = tree[key][query[key]] + if isinstance(result, dict): + return predict(query, result) + + else: + return result + + +def test(data, tree): + # Wartości docelowych atrybutów (nazwy warzyw) zostają usunięte + queries = data.iloc[:, :-1].to_dict(orient="records") + + # Przewidywane wartości atrybutów + predicted = pd.DataFrame(columns=["predicted"]) + + # Obliczenie precyzji przewidywań + for i in range(len(data)): + predicted.loc[i, "predicted"] = predict(queries[i], tree, 'beetroot') + print('Precyzja przewidywań: ', (np.sum(predicted["predicted"] == data["label"]) / len(data)) * 100, '%') + + +def predict_data(data): + """ + Funkcja dostosowana do formatu danych, jakimi dysponuje traktor + 'data' jest tutaj listą, która zostaje przekonwertowana do postaci słownika, + aby możliwe było wywołanie procedury 'predict'. + Wyniki zostają zwrócone w postaci listy. + """ + parse_data = [[data[0], categorize_pH(data[1]), categorize_dry_level(data[2]), '']] + #print(parse_data) + + queries = pd.DataFrame(data=parse_data, columns=dataset.header) + predicted = pd.DataFrame(columns=["predicted"]) + dict = queries.iloc[:, :-1].to_dict(orient="records") + + for i in range(len(parse_data)): + predicted.loc[i, "predicted"] = predict(dict[i], tree, 'beetroot') + + predicted_list = predicted.values.tolist() + print("Planted: ", predicted_list[0][0]) + + +def categorize_pH(pH): + if pH <= 4.5: + return 'strongly acidic' + if 4.5 < pH <= 5.5: + return 'acidic' + if 5.5 < pH <= 6.5: + return 'slightly acidic' + if 6.5 < pH <= 7.2: + return 'neutral' + if 7.2 < pH: + return 'alkaline' + + +def categorize_dry_level(dry_level): + if dry_level <= 0.1: + return 'soaking wet' + if 0.1 < dry_level <= 0.4: + return 'wet' + if 0.4 < dry_level <= 0.6: + return 'medium wet' + if 0.6 < dry_level <= 0.8: + return 'dry' + if 0.8 < dry_level: + return 'very dry' + + +# tworzenie, wyświetlanie i testowanie drzewa + +tree = ID3(training_data, training_data, training_data.columns[:-1]) +#pprint(tree) +#test(testing_data, tree) diff --git a/dataset.py b/dataset.py new file mode 100644 index 0000000..a94ddac --- /dev/null +++ b/dataset.py @@ -0,0 +1,71 @@ +header = ['previous', 'soil pH', 'dry level', 'label'] + +training_data = [ + ['carrot', 'alkaline', 'dry', 'beetroot'], + ['carrot', 'slightly acidic', 'dry', 'beetroot'], + ['cabbage', 'alkaline', 'dry', 'beetroot'], + ['none', 'alkaline', 'dry', 'beetroot'], + ['carrot', 'slightly acidic', 'medium wet', 'beetroot'], + ['none', 'slightly acidic', 'dry', 'beetroot'], + ['pumpkin', 'neutral', 'dry', 'beetroot'], + ['beetroot', 'neutral', 'dry', 'beetroot'], + ['cabbage', 'alkaline', 'medium wet', 'beetroot'], + ['none', 'slightly acidic', 'medium wet', 'beetroot'], + ['cabbage', 'acidic', 'dry', 'carrot'], + ['none', 'acidic', 'medium wet', 'carrot'], + ['carrot', 'neutral', 'dry', 'carrot'], + ['beetroot', 'slightly acidic', 'dry', 'carrot'], + ['pumpkin', 'acidic', 'medium wet', 'carrot'], + ['beetroot', 'acidic', 'medium wet', 'carrot'], + ['carrot', 'neutral', 'dry', 'carrot'], + ['pumpkin', 'slightly acidic', 'medium wet', 'carrot'], + ['beetroot', 'neutral', 'wet', 'pumpkin'], + ['none', 'neutral', 'wet', 'pumpkin'], + ['carrot', 'slightly acidic', 'wet', 'pumpkin'], + ['pumpkin', 'neutral', 'wet', 'pumpkin'], + ['cabbage', 'slightly acidic', 'medium wet', 'pumpkin'], + ['carrot', 'neutral', 'wet', 'pumpkin'], + ['cabbage', 'neutral', 'wet', 'pumpkin'], + ['none', 'slightly acidic', 'wet', 'pumpkin'], + ['beetroot', 'slightly acidic', 'medium wet', 'pumpkin'], + ['carrot', 'neutral', 'medium wet', 'cabbage'], + ['pumpkin', 'alkaline', 'wet', 'cabbage'], + ['none', 'alkaline', 'medium wet', 'cabbage'], + ['beetroot', 'neutral', 'medium wet', 'cabbage'], + ['cabbage', 'slightly acidic', 'wet', 'cabbage'], + ['none', 'neutral', 'medium wet', 'cabbage'], + ['cabbage', 'neutral', 'medium wet', 'cabbage'], + ['carrot', 'alkaline', 'wet', 'cabbage'], + ['none', 'alkaline', 'wet', 'cabbage'], + ['pumpkin', 'neutral', 'medium wet', 'cabbage'], + ['carrot', 'neutral', 'soaking wet', 'none'], + ['beetroot', 'alkaline', 'very dry', 'none'], + ['none', 'alkaline', 'soaking wet', 'none'], + ['cabbage', 'acidic', 'medium wet', 'none'], + ['pumpkin', 'acidic', 'soaking wet', 'none'], + ['cabbage', 'slightly acidic', 'soaking wet', 'none'], + ['none', 'slightly acidic', 'soaking wet', 'none'], + ['carrot', 'neutral', 'very dry', 'none'], + ['carrot', 'acidic', 'medium wet', 'none'], + ['pumpkin', 'neutral', 'soaking wet', 'none'] +] + +testing_data = [ + + ['beetroot', 'neutral', 'dry', 'beetroot'], + ['cabbage', 'alkaline', 'medium wet', 'beetroot'], + ['none', 'slightly acidic', 'medium wet', 'beetroot'], + ['cabbage', 'acidic', 'dry', 'carrot'], + ['none', 'acidic', 'medium wet', 'carrot'], + ['carrot', 'neutral', 'dry', 'carrot'], + ['beetroot', 'neutral', 'wet', 'pumpkin'], + ['none', 'neutral', 'wet', 'pumpkin'], + ['carrot', 'slightly acidic', 'wet', 'pumpkin'], + ['carrot', 'neutral', 'medium wet', 'cabbage'], + ['pumpkin', 'alkaline', 'wet', 'cabbage'], + ['none', 'alkaline', 'medium wet', 'cabbage'], + ['carrot', 'neutral', 'soaking wet', 'none'], + ['beetroot', 'alkaline', 'very dry', 'none'], + ['none', 'alkaline', 'soaking wet', 'none'], + +] \ No newline at end of file