Utworzenie klasy TreeLearn oraz klasy Node

This commit is contained in:
Jakub 2023-05-21 18:46:38 +02:00
parent 3a8a7fdb14
commit 0099d64e54
5 changed files with 65 additions and 0 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

65
main.py
View File

@ -9,6 +9,71 @@ from pygame.locals import *
from datetime import datetime
class TreeLearn:
def __init__(self):
self.tree = None
def train(self, examples, attributes, default_class):
self.tree = self.build_tree(examples, attributes, default_class)
def build_tree(self, examples, attributes, default_class):
if not examples:
return Node(default_class)
if self.all_same_class(examples):
return Node(examples[0][-1])
if not attributes:
class_counts = self.get_class_counts(examples)
default_class = max(class_counts, key=class_counts.get)
return Node(default_class)
best_attribute = self.choose_attribute(examples, attributes)
root = Node(best_attribute)
attribute_values = self.get_attribute_values(examples, best_attribute)
for value in attribute_values:
new_examples = self.filter_examples(examples, best_attribute, value)
new_attributes = attributes[:]
new_attributes.remove(best_attribute)
new_default_class = max(self.get_class_counts(new_examples), key=lambda k: class_counts.get(k, 0))
subtree = self.build_tree(new_examples, new_attributes, new_default_class)
root.add_child(value, subtree)
return root
def all_same_class(self, examples):
return len(set([example[-1] for example in examples])) == 1
def get_class_counts(self, examples):
class_counts = {}
for example in examples:
class_label = example[-1]
class_counts[class_label] = class_counts.get(class_label, 0) + 1
return class_counts
def choose_attribute(self, examples, attributes):
# Placeholder for attribute selection logic
return attributes[0]
def get_attribute_values(self, examples, attribute):
return list(set([example[attribute] for example in examples]))
def filter_examples(self, examples, attribute, value):
return [example for example in examples if example[attribute] == value]
class Node:
def __init__(self, label):
self.label = label
self.children = {}
def add_child(self, value, child):
self.children[value] = child
class Game:
cell_size = 50
cell_number = 15 # horizontally