From 45aaf233f14666f0c054d3ebcb011e8e13e98719 Mon Sep 17 00:00:00 2001 From: Tomasz Sidoruk Date: Thu, 1 Jun 2023 11:10:14 +0200 Subject: [PATCH] add networks --- .idea/misc.xml | 2 +- NN/Generator.py | 84 ++++++++++++++++++++++++++++++++++++ NN/trainer.py | 47 +++++++++++++++++++++ main.py | 110 ++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 242 insertions(+), 1 deletion(-) create mode 100644 NN/Generator.py create mode 100644 NN/trainer.py diff --git a/.idea/misc.xml b/.idea/misc.xml index d37d1d2..a0f56f8 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -1,4 +1,4 @@ - + \ No newline at end of file diff --git a/NN/Generator.py b/NN/Generator.py new file mode 100644 index 0000000..cc54c28 --- /dev/null +++ b/NN/Generator.py @@ -0,0 +1,84 @@ +from PIL import Image +import random + +plants = [[], [], []] +plants[0].append(Image.open("w1.png")) +plants[0].append(Image.open("w2.png")) +plants[0].append(Image.open("w3.png")) +plants[1].append(Image.open("c1.png")) +plants[1].append(Image.open("c2.png")) +plants[1].append(Image.open("c3.png")) +plants[2].append(Image.open("ca1.png")) +plants[2].append(Image.open("ca2.png")) +plants[2].append(Image.open("ca3.png")) +b = [Image.open("b1.png").convert('RGBA'), Image.open("b2.png").convert('RGBA'), Image.open("b3.png").convert('RGBA')] + + +def generate(water, fertilizer, plantf): + new_im = None + if water == 1: + new_im = Image.new('RGB', (100, 100), + (160 + random.randint(-10, 10), 80 + random.randint(-10, 10), 40 + random.randint(-10, 10))) + tmp = plants[plantf][random.randint(0, 2)].resize( + (25 + random.randint(-10, 25), 25 + random.randint(-10, 25))).rotate(random.randint(0, 359)) + new_im.paste(tmp, (random.randint(0, 50), random.randint(0, 50)), tmp) + if fertilizer: + tmp = b[random.randint(0, 2)].resize( + (20 + random.randint(0, 25), 20 + random.randint(0, 25))).rotate(random.randint(0, 359)) + new_im.paste(tmp, (random.randint(25, 75), random.randint(25, 75)), tmp) + else: + if fertilizer: + new_im = Image.new('RGB', (100, 100), + ( + 50 + random.randint(-10, 10), 25 + random.randint(-10, 10), + 0 + random.randint(-10, 10))) + tmp = plants[plantf][random.randint(0, 2)].resize( + (25 + random.randint(-10, 25), 25 + random.randint(-10, 25))).rotate(random.randint(0, 359)) + new_im.paste(tmp, (random.randint(0, 50), random.randint(0, 50)), tmp) + tmp = b[random.randint(0, 2)].resize( + (20 + random.randint(0, 25), 20 + random.randint(0, 25))).rotate(random.randint(0, 359)) + new_im.paste(tmp, (random.randint(25, 75), random.randint(25, 75)), tmp) + + else: + if random.randint(0, 1) == 1: + new_im = Image.new('RGB', (100, 100), + (50 + random.randint(-10, 10), 25 + random.randint(-10, 10), + 0 + random.randint(-10, 10))) + else: + new_im = Image.new('RGB', (100, 100), + (160 + random.randint(-10, 10), 80 + random.randint(-10, 10), + 40 + random.randint(-10, 10))) + if random.randint(0, 1) == 1: # big + tmp = plants[plantf][random.randint(0, 2)].resize( + (75 + random.randint(-10, 25), 75 + random.randint(-10, 25))).rotate(random.randint(0, 359)) + new_im.paste(tmp, (random.randint(0, 15), random.randint(0, 15)), tmp) + else: + tmp = plants[plantf][random.randint(0, 2)].resize( + (random.randint(10, 80), random.randint(10, 80))).rotate(random.randint(0, 359)) + datas = tmp.getdata() + + new_image_data = [] + for item in datas: + # change all white (also shades of whites) pixels to yellow + if item[0] in list(range(190, 256)): + new_image_data.append( + (random.randint(0, 10), 255 + random.randint(-150, 0), random.randint(0, 10))) + else: + new_image_data.append(item) + + # update image data + tmp.putdata(new_image_data) + new_im.paste(tmp, (random.randint(0, 30), random.randint(0, 30)), tmp) + + return new_im + + +for x in range(0, 1000): + generate(0, 0, random.randint(0, 2)).save('datasets/00/' + str(x) + '.png') +for x in range(0, 1000): + generate(1, 0, random.randint(0, 2)).save('datasets/10/' + str(x) + '.png') +for x in range(0, 1000): + generate(0, 1, random.randint(0, 2)).save('datasets/01/' + str(x) + '.png') +for x in range(0, 1000): + generate(1, 1, random.randint(0, 2)).save('datasets/11/' + str(x) + '.png') + diff --git a/NN/trainer.py b/NN/trainer.py new file mode 100644 index 0000000..f27083b --- /dev/null +++ b/NN/trainer.py @@ -0,0 +1,47 @@ +import pathlib +import random + +import torch +from PIL.Image import Image +from torch import nn +from torch.utils.data import DataLoader +from torchvision import datasets, transforms +from torchvision.transforms import Lambda + +device = torch.device('cuda') + +def train(model, dataset, n_iter=100, batch_size=2560000): + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + criterion = nn.NLLLoss() + dl = DataLoader(dataset, batch_size=batch_size) + model.train() + for epoch in range(n_iter): + for images, targets in dl: + optimizer.zero_grad() + out = model(images.to(device)) + loss = criterion(out, targets.to(device)) + loss.backward() + optimizer.step() + if epoch % 10 == 0: + print('epoch: %3d loss: %.4f' % (epoch, loss)) + + +image_path_list = list(pathlib.Path('./').glob("*/*/*.png")) + +random_image_path = random.choice(image_path_list) +data_transform = transforms.Compose([ + transforms.Resize(size=(100, 100)), + transforms.RandomHorizontalFlip(p=0.5), + transforms.ToTensor(), + Lambda(lambda x: x.flatten()) +]) + +train_data = datasets.ImageFolder(root="./datasets", + transform=data_transform, + target_transform=None) + +model1=nn.Sequential(nn.Linear(30000, 10000),nn.ReLU(),nn.Linear(10000,10000),nn.ReLU(),nn.Linear(10000,10000),nn.Linear(10000,4),nn.LogSoftmax(dim=-1)).to(device) +model1.load_state_dict(torch.load("./trained")) +train(model1,train_data) + +torch.save(model1.state_dict(), "./trained") \ No newline at end of file diff --git a/main.py b/main.py index 8933ac0..9ded605 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,13 @@ import pygame +import pathlib +import random + +import torch +from torch import nn +from torch.utils.data import DataLoader +from torchvision import datasets, transforms +from torchvision.transforms import Lambda +from PIL import Image import astar from classes import Field, Plant, Fertilizer, Player @@ -14,8 +23,80 @@ from screen import SCREEN Ucelu = False SCREENX = 500 SCREENY = 500 +device = torch.device('cpu') +model1=nn.Sequential(nn.Linear(30000, 10000),nn.ReLU(),nn.Linear(10000,10000),nn.ReLU(),nn.Linear(10000,10000),nn.Linear(10000,4),nn.LogSoftmax(dim=-1)).to(device) +model1.load_state_dict(torch.load("./NN/trained")) pygame.display.set_caption('Inteligentny Traktor') +plants = [[], [], []] +plants[0].append(Image.open("NN/w1.png")) +plants[0].append(Image.open("NN/w2.png")) +plants[0].append(Image.open("NN/w3.png")) +plants[1].append(Image.open("NN/c1.png")) +plants[1].append(Image.open("NN/c2.png")) +plants[1].append(Image.open("NN/c3.png")) +plants[2].append(Image.open("NN/ca1.png")) +plants[2].append(Image.open("NN/ca2.png")) +plants[2].append(Image.open("NN/ca3.png")) +b = [Image.open("NN/b1.png").convert('RGBA'), Image.open("NN/b2.png").convert('RGBA'), Image.open("NN/b3.png").convert('RGBA')] + +def generate(water, fertilizer, plantf): + new_im = None + if water == 1: + new_im = Image.new('RGB', (100, 100), + (160 + random.randint(-10, 10), 80 + random.randint(-10, 10), 40 + random.randint(-10, 10))) + tmp = plants[plantf][random.randint(0, 2)].resize( + (25 + random.randint(-10, 25), 25 + random.randint(-10, 25))).rotate(random.randint(0, 359)) + new_im.paste(tmp, (random.randint(0, 50), random.randint(0, 50)), tmp) + if fertilizer: + tmp = b[random.randint(0, 2)].resize( + (20 + random.randint(0, 25), 20 + random.randint(0, 25))).rotate(random.randint(0, 359)) + new_im.paste(tmp, (random.randint(25, 75), random.randint(25, 75)), tmp) + else: + if fertilizer: + new_im = Image.new('RGB', (100, 100), + ( + 50 + random.randint(-10, 10), 25 + random.randint(-10, 10), + 0 + random.randint(-10, 10))) + tmp = plants[plantf][random.randint(0, 2)].resize( + (25 + random.randint(-10, 25), 25 + random.randint(-10, 25))).rotate(random.randint(0, 359)) + new_im.paste(tmp, (random.randint(0, 50), random.randint(0, 50)), tmp) + tmp = b[random.randint(0, 2)].resize( + (20 + random.randint(0, 25), 20 + random.randint(0, 25))).rotate(random.randint(0, 359)) + new_im.paste(tmp, (random.randint(25, 75), random.randint(25, 75)), tmp) + + else: + if random.randint(0, 1) == 1: + new_im = Image.new('RGB', (100, 100), + (50 + random.randint(-10, 10), 25 + random.randint(-10, 10), + 0 + random.randint(-10, 10))) + else: + new_im = Image.new('RGB', (100, 100), + (160 + random.randint(-10, 10), 80 + random.randint(-10, 10), + 40 + random.randint(-10, 10))) + if random.randint(0, 1) == 1: # big + tmp = plants[plantf][random.randint(0, 2)].resize( + (75 + random.randint(-10, 25), 75 + random.randint(-10, 25))).rotate(random.randint(0, 359)) + new_im.paste(tmp, (random.randint(0, 15), random.randint(0, 15)), tmp) + else: + tmp = plants[plantf][random.randint(0, 2)].resize( + (random.randint(10, 80), random.randint(10, 80))).rotate(random.randint(0, 359)) + datas = tmp.getdata() + + new_image_data = [] + for item in datas: + # change all white (also shades of whites) pixels to yellow + if item[0] in list(range(190, 256)): + new_image_data.append( + (random.randint(0, 10), 255 + random.randint(-150, 0), random.randint(0, 10))) + else: + new_image_data.append(item) + + # update image data + tmp.putdata(new_image_data) + new_im.paste(tmp, (random.randint(0, 30), random.randint(0, 30)), tmp) + + return new_im # COLORS WHITE = (255, 255, 255) @@ -399,6 +480,35 @@ def eventHandler(kbdObj,mouseObj): pygame.time.wait(DELAY) # If Key_x is pressed, spawn tree + if kbdObj[pygame.K_t]: + w = random.randint(0, 1) + f=random.randint(0, 1) + print(w) + print(f) + img = generate(w,f,random.randint(0,2)) + img.save('./test/00/test.png') + + data_transform = transforms.Compose([ + transforms.Resize(size=(100, 100)), + transforms.RandomHorizontalFlip(p=0.5), + transforms.ToTensor(), + Lambda(lambda x: x.flatten()) + ]) + datasets.ImageNet + train_data = datasets.ImageFolder(root="./test", + transform=data_transform, + target_transform=None) + model1.eval() + res = model1(train_data[0][0]) + if res[0] ==res.max(): + print("0 0") + if res[1] ==res.max(): + print("0 1") + if res[2] ==res.max(): + print("1 0") + if res[3] ==res.max(): + print("1 1") + #img.show() if kbdObj[pygame.K_x]: obs = Obstacle(mouseObj) obstacleObjects[obstacles] = obs