add networks

This commit is contained in:
Tomasz Sidoruk 2023-06-01 11:10:14 +02:00
parent 16b056047c
commit 45aaf233f1
4 changed files with 242 additions and 1 deletions

View File

@ -1,4 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" languageLevel="JDK_19" project-jdk-name="Python 3.9" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (pythonProject)" project-jdk-type="Python SDK" />
</project>

84
NN/Generator.py Normal file
View File

@ -0,0 +1,84 @@
from PIL import Image
import random
plants = [[], [], []]
plants[0].append(Image.open("w1.png"))
plants[0].append(Image.open("w2.png"))
plants[0].append(Image.open("w3.png"))
plants[1].append(Image.open("c1.png"))
plants[1].append(Image.open("c2.png"))
plants[1].append(Image.open("c3.png"))
plants[2].append(Image.open("ca1.png"))
plants[2].append(Image.open("ca2.png"))
plants[2].append(Image.open("ca3.png"))
b = [Image.open("b1.png").convert('RGBA'), Image.open("b2.png").convert('RGBA'), Image.open("b3.png").convert('RGBA')]
def generate(water, fertilizer, plantf):
new_im = None
if water == 1:
new_im = Image.new('RGB', (100, 100),
(160 + random.randint(-10, 10), 80 + random.randint(-10, 10), 40 + random.randint(-10, 10)))
tmp = plants[plantf][random.randint(0, 2)].resize(
(25 + random.randint(-10, 25), 25 + random.randint(-10, 25))).rotate(random.randint(0, 359))
new_im.paste(tmp, (random.randint(0, 50), random.randint(0, 50)), tmp)
if fertilizer:
tmp = b[random.randint(0, 2)].resize(
(20 + random.randint(0, 25), 20 + random.randint(0, 25))).rotate(random.randint(0, 359))
new_im.paste(tmp, (random.randint(25, 75), random.randint(25, 75)), tmp)
else:
if fertilizer:
new_im = Image.new('RGB', (100, 100),
(
50 + random.randint(-10, 10), 25 + random.randint(-10, 10),
0 + random.randint(-10, 10)))
tmp = plants[plantf][random.randint(0, 2)].resize(
(25 + random.randint(-10, 25), 25 + random.randint(-10, 25))).rotate(random.randint(0, 359))
new_im.paste(tmp, (random.randint(0, 50), random.randint(0, 50)), tmp)
tmp = b[random.randint(0, 2)].resize(
(20 + random.randint(0, 25), 20 + random.randint(0, 25))).rotate(random.randint(0, 359))
new_im.paste(tmp, (random.randint(25, 75), random.randint(25, 75)), tmp)
else:
if random.randint(0, 1) == 1:
new_im = Image.new('RGB', (100, 100),
(50 + random.randint(-10, 10), 25 + random.randint(-10, 10),
0 + random.randint(-10, 10)))
else:
new_im = Image.new('RGB', (100, 100),
(160 + random.randint(-10, 10), 80 + random.randint(-10, 10),
40 + random.randint(-10, 10)))
if random.randint(0, 1) == 1: # big
tmp = plants[plantf][random.randint(0, 2)].resize(
(75 + random.randint(-10, 25), 75 + random.randint(-10, 25))).rotate(random.randint(0, 359))
new_im.paste(tmp, (random.randint(0, 15), random.randint(0, 15)), tmp)
else:
tmp = plants[plantf][random.randint(0, 2)].resize(
(random.randint(10, 80), random.randint(10, 80))).rotate(random.randint(0, 359))
datas = tmp.getdata()
new_image_data = []
for item in datas:
# change all white (also shades of whites) pixels to yellow
if item[0] in list(range(190, 256)):
new_image_data.append(
(random.randint(0, 10), 255 + random.randint(-150, 0), random.randint(0, 10)))
else:
new_image_data.append(item)
# update image data
tmp.putdata(new_image_data)
new_im.paste(tmp, (random.randint(0, 30), random.randint(0, 30)), tmp)
return new_im
for x in range(0, 1000):
generate(0, 0, random.randint(0, 2)).save('datasets/00/' + str(x) + '.png')
for x in range(0, 1000):
generate(1, 0, random.randint(0, 2)).save('datasets/10/' + str(x) + '.png')
for x in range(0, 1000):
generate(0, 1, random.randint(0, 2)).save('datasets/01/' + str(x) + '.png')
for x in range(0, 1000):
generate(1, 1, random.randint(0, 2)).save('datasets/11/' + str(x) + '.png')

47
NN/trainer.py Normal file
View File

@ -0,0 +1,47 @@
import pathlib
import random
import torch
from PIL.Image import Image
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import Lambda
device = torch.device('cuda')
def train(model, dataset, n_iter=100, batch_size=2560000):
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.NLLLoss()
dl = DataLoader(dataset, batch_size=batch_size)
model.train()
for epoch in range(n_iter):
for images, targets in dl:
optimizer.zero_grad()
out = model(images.to(device))
loss = criterion(out, targets.to(device))
loss.backward()
optimizer.step()
if epoch % 10 == 0:
print('epoch: %3d loss: %.4f' % (epoch, loss))
image_path_list = list(pathlib.Path('./').glob("*/*/*.png"))
random_image_path = random.choice(image_path_list)
data_transform = transforms.Compose([
transforms.Resize(size=(100, 100)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToTensor(),
Lambda(lambda x: x.flatten())
])
train_data = datasets.ImageFolder(root="./datasets",
transform=data_transform,
target_transform=None)
model1=nn.Sequential(nn.Linear(30000, 10000),nn.ReLU(),nn.Linear(10000,10000),nn.ReLU(),nn.Linear(10000,10000),nn.Linear(10000,4),nn.LogSoftmax(dim=-1)).to(device)
model1.load_state_dict(torch.load("./trained"))
train(model1,train_data)
torch.save(model1.state_dict(), "./trained")

110
main.py
View File

@ -1,4 +1,13 @@
import pygame
import pathlib
import random
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import Lambda
from PIL import Image
import astar
from classes import Field, Plant, Fertilizer, Player
@ -14,8 +23,80 @@ from screen import SCREEN
Ucelu = False
SCREENX = 500
SCREENY = 500
device = torch.device('cpu')
model1=nn.Sequential(nn.Linear(30000, 10000),nn.ReLU(),nn.Linear(10000,10000),nn.ReLU(),nn.Linear(10000,10000),nn.Linear(10000,4),nn.LogSoftmax(dim=-1)).to(device)
model1.load_state_dict(torch.load("./NN/trained"))
pygame.display.set_caption('Inteligentny Traktor')
plants = [[], [], []]
plants[0].append(Image.open("NN/w1.png"))
plants[0].append(Image.open("NN/w2.png"))
plants[0].append(Image.open("NN/w3.png"))
plants[1].append(Image.open("NN/c1.png"))
plants[1].append(Image.open("NN/c2.png"))
plants[1].append(Image.open("NN/c3.png"))
plants[2].append(Image.open("NN/ca1.png"))
plants[2].append(Image.open("NN/ca2.png"))
plants[2].append(Image.open("NN/ca3.png"))
b = [Image.open("NN/b1.png").convert('RGBA'), Image.open("NN/b2.png").convert('RGBA'), Image.open("NN/b3.png").convert('RGBA')]
def generate(water, fertilizer, plantf):
new_im = None
if water == 1:
new_im = Image.new('RGB', (100, 100),
(160 + random.randint(-10, 10), 80 + random.randint(-10, 10), 40 + random.randint(-10, 10)))
tmp = plants[plantf][random.randint(0, 2)].resize(
(25 + random.randint(-10, 25), 25 + random.randint(-10, 25))).rotate(random.randint(0, 359))
new_im.paste(tmp, (random.randint(0, 50), random.randint(0, 50)), tmp)
if fertilizer:
tmp = b[random.randint(0, 2)].resize(
(20 + random.randint(0, 25), 20 + random.randint(0, 25))).rotate(random.randint(0, 359))
new_im.paste(tmp, (random.randint(25, 75), random.randint(25, 75)), tmp)
else:
if fertilizer:
new_im = Image.new('RGB', (100, 100),
(
50 + random.randint(-10, 10), 25 + random.randint(-10, 10),
0 + random.randint(-10, 10)))
tmp = plants[plantf][random.randint(0, 2)].resize(
(25 + random.randint(-10, 25), 25 + random.randint(-10, 25))).rotate(random.randint(0, 359))
new_im.paste(tmp, (random.randint(0, 50), random.randint(0, 50)), tmp)
tmp = b[random.randint(0, 2)].resize(
(20 + random.randint(0, 25), 20 + random.randint(0, 25))).rotate(random.randint(0, 359))
new_im.paste(tmp, (random.randint(25, 75), random.randint(25, 75)), tmp)
else:
if random.randint(0, 1) == 1:
new_im = Image.new('RGB', (100, 100),
(50 + random.randint(-10, 10), 25 + random.randint(-10, 10),
0 + random.randint(-10, 10)))
else:
new_im = Image.new('RGB', (100, 100),
(160 + random.randint(-10, 10), 80 + random.randint(-10, 10),
40 + random.randint(-10, 10)))
if random.randint(0, 1) == 1: # big
tmp = plants[plantf][random.randint(0, 2)].resize(
(75 + random.randint(-10, 25), 75 + random.randint(-10, 25))).rotate(random.randint(0, 359))
new_im.paste(tmp, (random.randint(0, 15), random.randint(0, 15)), tmp)
else:
tmp = plants[plantf][random.randint(0, 2)].resize(
(random.randint(10, 80), random.randint(10, 80))).rotate(random.randint(0, 359))
datas = tmp.getdata()
new_image_data = []
for item in datas:
# change all white (also shades of whites) pixels to yellow
if item[0] in list(range(190, 256)):
new_image_data.append(
(random.randint(0, 10), 255 + random.randint(-150, 0), random.randint(0, 10)))
else:
new_image_data.append(item)
# update image data
tmp.putdata(new_image_data)
new_im.paste(tmp, (random.randint(0, 30), random.randint(0, 30)), tmp)
return new_im
# COLORS
WHITE = (255, 255, 255)
@ -399,6 +480,35 @@ def eventHandler(kbdObj,mouseObj):
pygame.time.wait(DELAY)
# If Key_x is pressed, spawn tree
if kbdObj[pygame.K_t]:
w = random.randint(0, 1)
f=random.randint(0, 1)
print(w)
print(f)
img = generate(w,f,random.randint(0,2))
img.save('./test/00/test.png')
data_transform = transforms.Compose([
transforms.Resize(size=(100, 100)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToTensor(),
Lambda(lambda x: x.flatten())
])
datasets.ImageNet
train_data = datasets.ImageFolder(root="./test",
transform=data_transform,
target_transform=None)
model1.eval()
res = model1(train_data[0][0])
if res[0] ==res.max():
print("0 0")
if res[1] ==res.max():
print("0 1")
if res[2] ==res.max():
print("1 0")
if res[3] ==res.max():
print("1 1")
#img.show()
if kbdObj[pygame.K_x]:
obs = Obstacle(mouseObj)
obstacleObjects[obstacles] = obs