diff --git a/algorithms/bfs.py b/algorithms/bfs.py index b2ed3d3..959a684 100644 --- a/algorithms/bfs.py +++ b/algorithms/bfs.py @@ -51,7 +51,7 @@ def graphsearch(initial_state: State, map, goal_list, fringe: List[Node] = None, explored_states = set() fringe_states = set() - # root Node + # train Node fringe.append(Node(initial_state)) fringe_states.add((initial_state.row, initial_state.column, initial_state.direction)) @@ -71,7 +71,7 @@ def graphsearch(initial_state: State, map, goal_list, fringe: List[Node] = None, parent = element.parent while parent is not None: - # root's action will be None, don't add it + # train's action will be None, don't add it if parent.action is not None: actions_sequence.append(parent.action) parent = parent.parent diff --git a/algorithms/neural_network/data/test/grass/grass1.png b/algorithms/neural_network/data/test/grass/grass1.png new file mode 100644 index 0000000..dd88981 Binary files /dev/null and b/algorithms/neural_network/data/test/grass/grass1.png differ diff --git a/algorithms/neural_network/data/test/grass/grass2.png b/algorithms/neural_network/data/test/grass/grass2.png new file mode 100644 index 0000000..c07ab8f Binary files /dev/null and b/algorithms/neural_network/data/test/grass/grass2.png differ diff --git a/algorithms/neural_network/data/test/grass/grass3.png b/algorithms/neural_network/data/test/grass/grass3.png new file mode 100644 index 0000000..4fac085 Binary files /dev/null and b/algorithms/neural_network/data/test/grass/grass3.png differ diff --git a/algorithms/neural_network/data/test/grass/grass4.png b/algorithms/neural_network/data/test/grass/grass4.png new file mode 100644 index 0000000..74cb3d1 Binary files /dev/null and b/algorithms/neural_network/data/test/grass/grass4.png differ diff --git a/algorithms/neural_network/data/test/sand/sand.png b/algorithms/neural_network/data/test/sand/sand.png new file mode 100644 index 0000000..51c072d Binary files /dev/null and b/algorithms/neural_network/data/test/sand/sand.png differ diff --git a/algorithms/neural_network/data/test/tree/grass_with_tree.jpg b/algorithms/neural_network/data/test/tree/grass_with_tree.jpg new file mode 100644 index 0000000..56af6b8 Binary files /dev/null and b/algorithms/neural_network/data/test/tree/grass_with_tree.jpg differ diff --git a/algorithms/neural_network/data/test/water/water.png b/algorithms/neural_network/data/test/water/water.png new file mode 100644 index 0000000..28b45db Binary files /dev/null and b/algorithms/neural_network/data/test/water/water.png differ diff --git a/algorithms/neural_network/learnedNetwork.pt b/algorithms/neural_network/learnedNetwork.pt new file mode 100644 index 0000000..3f40e13 Binary files /dev/null and b/algorithms/neural_network/learnedNetwork.pt differ diff --git a/algorithms/neural_network/neural_network.py b/algorithms/neural_network/neural_network.py new file mode 100644 index 0000000..d17a223 --- /dev/null +++ b/algorithms/neural_network/neural_network.py @@ -0,0 +1,22 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class NeuralNetwork(nn.Module): + def __init__(self, num_classes=4): + super(NeuralNetwork, self).__init__() + self.conv1 = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)) + self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.fc1 = nn.Linear(20*9*9, num_classes) + + def forward(self, x): + x = F.relu(self.conv1(x)) + x = self.pool(x) + x = F.relu(self.conv2(x)) + x = self.pool(x) + x = x.reshape(x.shape[0], -1) + x = self.fc1(x) + + return x diff --git a/algorithms/neural_network/neural_network_interface.py b/algorithms/neural_network/neural_network_interface.py new file mode 100644 index 0000000..3a3b1c6 --- /dev/null +++ b/algorithms/neural_network/neural_network_interface.py @@ -0,0 +1,91 @@ +import torch +from common.constants import device, batch_size, num_epochs, learning_rate, setup_photos, id_to_class +from watersandtreegrass import WaterSandTreeGrass +from torch.utils.data import DataLoader +from neural_network import NeuralNetwork +from torchvision.io import read_image, ImageReadMode +import torch.nn as nn +from torch.optim import Adam + +CNN = NeuralNetwork().to(device) + + +def train(model): + model.train() + trainset = WaterSandTreeGrass('./data/train_csv_file.csv', './data/train/all', transform=setup_photos) + train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True) + + criterion = nn.CrossEntropyLoss() + optimizer = Adam(model.parameters(), lr=learning_rate) + + for epoch in range(num_epochs): + for batch_idx, (data, targets) in enumerate(train_loader): + data = data.to(device=device) + targets = targets.to(device=device) + + scores = model(data) + loss = criterion(scores, targets) + + optimizer.zero_grad() + loss.backward() + + optimizer.step() + + if epoch % 2 == 0: + print("epoch: %3d loss: %.4f" % (epoch, loss.item())) + + print("FINISHED!") + print("Checking accuracy.") + check_accuracy(train_loader) + + torch.save(model.state_dict(), "./learnedNetwork.pt") + + +def check_accuracy(loader): + num_correct = 0 + num_samples = 0 + model = NeuralNetwork() + + model.load_state_dict(torch.load("./learnedNetwork.pt")) + model = model.to(device) + + with torch.no_grad(): + model.eval() + for x, y in loader: + x = x.to(device=device) + y = y.to(device=device) + + scores = model(x) + + _, predictions = scores.max(1) + num_correct += (predictions == y).sum() + num_samples += predictions.size(0) + + print(f"Got {num_correct}/{num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}") + + +testset_loader = DataLoader( + WaterSandTreeGrass('./data/test_csv_file.csv', './data/test/all', transform=setup_photos), + batch_size=batch_size +) + + +def what_is_it(img_path): + image = read_image(img_path, mode=ImageReadMode.RGB) + image = setup_photos(image).unsqueeze(0) + model = NeuralNetwork() + + model.load_state_dict(torch.load("./learnedNetwork.pt")) + model = model.to(device) + image = image.to(device) + + with torch.no_grad(): + model.eval() + idx = int(model(image).argmax(dim=1)) + return id_to_class[idx] + + + +check_accuracy(testset_loader) + +print(what_is_it('./data/test/water/water.png')) diff --git a/algorithms/neural_network/watersandtreegrass.py b/algorithms/neural_network/watersandtreegrass.py new file mode 100644 index 0000000..835d540 --- /dev/null +++ b/algorithms/neural_network/watersandtreegrass.py @@ -0,0 +1,28 @@ +import torch +from torch.utils.data import Dataset +import pandas as pd +from torchvision.io import read_image, ImageReadMode +from common.helpers import createCSV +import os + + +class WaterSandTreeGrass(Dataset): + def __init__(self, annotations_file, img_dir, transform=None): + createCSV() + self.img_labels = pd.read_csv(annotations_file) + self.img_dir = img_dir + self.transform = transform + + def __len__(self): + return len(self.img_labels) + + def __getitem__(self, idx): + img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0]) + image = read_image(img_path, mode=ImageReadMode.RGB) + label = torch.tensor(int(self.img_labels.iloc[idx, 1])) + + if self.transform: + image = self.transform(image) + + return image, label + diff --git a/common/constants.py b/common/constants.py index caf995f..e346f82 100644 --- a/common/constants.py +++ b/common/constants.py @@ -1,4 +1,6 @@ from enum import Enum +import torchvision.transforms as transforms +import torch GAME_TITLE = 'WMICraft' WINDOW_HEIGHT = 800 @@ -67,3 +69,22 @@ ACTION = { BAR_ANIMATION_SPEED = 1 BAR_WIDTH_MULTIPLIER = 0.9 # (0;1> BAR_HEIGHT_MULTIPLIER = 0.1 + +#NEURAL_NETWORK +learning_rate = 0.001 +batch_size = 7 +num_epochs = 10 + +device = torch.device('cuda') +classes = ['grass', 'sand', 'tree', 'water'] + +setup_photos = transforms.Compose([ + transforms.Resize(36), + transforms.CenterCrop(36), + transforms.ToPILImage(), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) +]) + +id_to_class = {i: j for i, j in enumerate(classes)} +class_to_id = {value: key for key, value in id_to_class.items()} diff --git a/common/helpers.py b/common/helpers.py index dd13e31..22e44b7 100644 --- a/common/helpers.py +++ b/common/helpers.py @@ -1,5 +1,7 @@ import pygame -from common.constants import GRID_CELL_PADDING, GRID_CELL_SIZE, COLUMNS, ROWS +from common.constants import GRID_CELL_PADDING, GRID_CELL_SIZE, COLUMNS, ROWS, classes, class_to_id +import csv +import os def draw_text(text, color, surface, x, y, text_size=30, is_bold=False): @@ -12,6 +14,35 @@ def draw_text(text, color, surface, x, y, text_size=30, is_bold=False): textrect.topleft = (x, y) surface.blit(textobj, textrect) +def createCSV(): + train_csvfile = open('./data/train_csv_file.csv', 'w', newline="") + writer = csv.writer(train_csvfile) + writer.writerow(["filename", "type"]) + + train_data_path = './data/train' + test_data_path = './data/test' + + for class_name in classes: + class_dir = train_data_path + "/" + class_name + for filename in os.listdir(class_dir): + f = os.path.join(class_dir, filename) + if os.path.isfile(f): + writer.writerow([filename, class_to_id[class_name]]) + + test_csvfile = open('./data/test_csv_file.csv', 'w', newline="") + writer = csv.writer(test_csvfile) + writer.writerow(["filename", "type"]) + + for class_name in classes: + class_dir = test_data_path + "/" + class_name + for filename in os.listdir(class_dir): + f = os.path.join(class_dir, filename) + if os.path.isfile(f): + writer.writerow([filename, class_to_id[class_name]]) + + test_csvfile.close() + train_csvfile.close() + def print_numbers(): display_surface = pygame.display.get_surface() diff --git a/logic/health_bar.py b/logic/health_bar.py index c362fc7..0b8bc3a 100644 --- a/logic/health_bar.py +++ b/logic/health_bar.py @@ -46,7 +46,7 @@ class HealthBar: def heal(self, amount): if self.current_hp + amount < self.max_hp: self.current_hp += amount - elif self.current_hp + amount > self.max_hp: + elif self.current_hp + amount >= self.max_hp: self.current_hp = self.max_hp def show(self): diff --git a/requirements.txt b/requirements.txt index ad0de46..f2fe57a 100644 Binary files a/requirements.txt and b/requirements.txt differ