add networks
This commit is contained in:
parent
16b056047c
commit
45aaf233f1
@ -1,4 +1,4 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<project version="4">
|
<project version="4">
|
||||||
<component name="ProjectRootManager" version="2" languageLevel="JDK_19" project-jdk-name="Python 3.9" project-jdk-type="Python SDK" />
|
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (pythonProject)" project-jdk-type="Python SDK" />
|
||||||
</project>
|
</project>
|
84
NN/Generator.py
Normal file
84
NN/Generator.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
from PIL import Image
|
||||||
|
import random
|
||||||
|
|
||||||
|
plants = [[], [], []]
|
||||||
|
plants[0].append(Image.open("w1.png"))
|
||||||
|
plants[0].append(Image.open("w2.png"))
|
||||||
|
plants[0].append(Image.open("w3.png"))
|
||||||
|
plants[1].append(Image.open("c1.png"))
|
||||||
|
plants[1].append(Image.open("c2.png"))
|
||||||
|
plants[1].append(Image.open("c3.png"))
|
||||||
|
plants[2].append(Image.open("ca1.png"))
|
||||||
|
plants[2].append(Image.open("ca2.png"))
|
||||||
|
plants[2].append(Image.open("ca3.png"))
|
||||||
|
b = [Image.open("b1.png").convert('RGBA'), Image.open("b2.png").convert('RGBA'), Image.open("b3.png").convert('RGBA')]
|
||||||
|
|
||||||
|
|
||||||
|
def generate(water, fertilizer, plantf):
|
||||||
|
new_im = None
|
||||||
|
if water == 1:
|
||||||
|
new_im = Image.new('RGB', (100, 100),
|
||||||
|
(160 + random.randint(-10, 10), 80 + random.randint(-10, 10), 40 + random.randint(-10, 10)))
|
||||||
|
tmp = plants[plantf][random.randint(0, 2)].resize(
|
||||||
|
(25 + random.randint(-10, 25), 25 + random.randint(-10, 25))).rotate(random.randint(0, 359))
|
||||||
|
new_im.paste(tmp, (random.randint(0, 50), random.randint(0, 50)), tmp)
|
||||||
|
if fertilizer:
|
||||||
|
tmp = b[random.randint(0, 2)].resize(
|
||||||
|
(20 + random.randint(0, 25), 20 + random.randint(0, 25))).rotate(random.randint(0, 359))
|
||||||
|
new_im.paste(tmp, (random.randint(25, 75), random.randint(25, 75)), tmp)
|
||||||
|
else:
|
||||||
|
if fertilizer:
|
||||||
|
new_im = Image.new('RGB', (100, 100),
|
||||||
|
(
|
||||||
|
50 + random.randint(-10, 10), 25 + random.randint(-10, 10),
|
||||||
|
0 + random.randint(-10, 10)))
|
||||||
|
tmp = plants[plantf][random.randint(0, 2)].resize(
|
||||||
|
(25 + random.randint(-10, 25), 25 + random.randint(-10, 25))).rotate(random.randint(0, 359))
|
||||||
|
new_im.paste(tmp, (random.randint(0, 50), random.randint(0, 50)), tmp)
|
||||||
|
tmp = b[random.randint(0, 2)].resize(
|
||||||
|
(20 + random.randint(0, 25), 20 + random.randint(0, 25))).rotate(random.randint(0, 359))
|
||||||
|
new_im.paste(tmp, (random.randint(25, 75), random.randint(25, 75)), tmp)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if random.randint(0, 1) == 1:
|
||||||
|
new_im = Image.new('RGB', (100, 100),
|
||||||
|
(50 + random.randint(-10, 10), 25 + random.randint(-10, 10),
|
||||||
|
0 + random.randint(-10, 10)))
|
||||||
|
else:
|
||||||
|
new_im = Image.new('RGB', (100, 100),
|
||||||
|
(160 + random.randint(-10, 10), 80 + random.randint(-10, 10),
|
||||||
|
40 + random.randint(-10, 10)))
|
||||||
|
if random.randint(0, 1) == 1: # big
|
||||||
|
tmp = plants[plantf][random.randint(0, 2)].resize(
|
||||||
|
(75 + random.randint(-10, 25), 75 + random.randint(-10, 25))).rotate(random.randint(0, 359))
|
||||||
|
new_im.paste(tmp, (random.randint(0, 15), random.randint(0, 15)), tmp)
|
||||||
|
else:
|
||||||
|
tmp = plants[plantf][random.randint(0, 2)].resize(
|
||||||
|
(random.randint(10, 80), random.randint(10, 80))).rotate(random.randint(0, 359))
|
||||||
|
datas = tmp.getdata()
|
||||||
|
|
||||||
|
new_image_data = []
|
||||||
|
for item in datas:
|
||||||
|
# change all white (also shades of whites) pixels to yellow
|
||||||
|
if item[0] in list(range(190, 256)):
|
||||||
|
new_image_data.append(
|
||||||
|
(random.randint(0, 10), 255 + random.randint(-150, 0), random.randint(0, 10)))
|
||||||
|
else:
|
||||||
|
new_image_data.append(item)
|
||||||
|
|
||||||
|
# update image data
|
||||||
|
tmp.putdata(new_image_data)
|
||||||
|
new_im.paste(tmp, (random.randint(0, 30), random.randint(0, 30)), tmp)
|
||||||
|
|
||||||
|
return new_im
|
||||||
|
|
||||||
|
|
||||||
|
for x in range(0, 1000):
|
||||||
|
generate(0, 0, random.randint(0, 2)).save('datasets/00/' + str(x) + '.png')
|
||||||
|
for x in range(0, 1000):
|
||||||
|
generate(1, 0, random.randint(0, 2)).save('datasets/10/' + str(x) + '.png')
|
||||||
|
for x in range(0, 1000):
|
||||||
|
generate(0, 1, random.randint(0, 2)).save('datasets/01/' + str(x) + '.png')
|
||||||
|
for x in range(0, 1000):
|
||||||
|
generate(1, 1, random.randint(0, 2)).save('datasets/11/' + str(x) + '.png')
|
||||||
|
|
47
NN/trainer.py
Normal file
47
NN/trainer.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
import pathlib
|
||||||
|
import random
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL.Image import Image
|
||||||
|
from torch import nn
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torchvision import datasets, transforms
|
||||||
|
from torchvision.transforms import Lambda
|
||||||
|
|
||||||
|
device = torch.device('cuda')
|
||||||
|
|
||||||
|
def train(model, dataset, n_iter=100, batch_size=2560000):
|
||||||
|
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||||
|
criterion = nn.NLLLoss()
|
||||||
|
dl = DataLoader(dataset, batch_size=batch_size)
|
||||||
|
model.train()
|
||||||
|
for epoch in range(n_iter):
|
||||||
|
for images, targets in dl:
|
||||||
|
optimizer.zero_grad()
|
||||||
|
out = model(images.to(device))
|
||||||
|
loss = criterion(out, targets.to(device))
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
if epoch % 10 == 0:
|
||||||
|
print('epoch: %3d loss: %.4f' % (epoch, loss))
|
||||||
|
|
||||||
|
|
||||||
|
image_path_list = list(pathlib.Path('./').glob("*/*/*.png"))
|
||||||
|
|
||||||
|
random_image_path = random.choice(image_path_list)
|
||||||
|
data_transform = transforms.Compose([
|
||||||
|
transforms.Resize(size=(100, 100)),
|
||||||
|
transforms.RandomHorizontalFlip(p=0.5),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
Lambda(lambda x: x.flatten())
|
||||||
|
])
|
||||||
|
|
||||||
|
train_data = datasets.ImageFolder(root="./datasets",
|
||||||
|
transform=data_transform,
|
||||||
|
target_transform=None)
|
||||||
|
|
||||||
|
model1=nn.Sequential(nn.Linear(30000, 10000),nn.ReLU(),nn.Linear(10000,10000),nn.ReLU(),nn.Linear(10000,10000),nn.Linear(10000,4),nn.LogSoftmax(dim=-1)).to(device)
|
||||||
|
model1.load_state_dict(torch.load("./trained"))
|
||||||
|
train(model1,train_data)
|
||||||
|
|
||||||
|
torch.save(model1.state_dict(), "./trained")
|
110
main.py
110
main.py
@ -1,4 +1,13 @@
|
|||||||
import pygame
|
import pygame
|
||||||
|
import pathlib
|
||||||
|
import random
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torchvision import datasets, transforms
|
||||||
|
from torchvision.transforms import Lambda
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
import astar
|
import astar
|
||||||
from classes import Field, Plant, Fertilizer, Player
|
from classes import Field, Plant, Fertilizer, Player
|
||||||
@ -14,8 +23,80 @@ from screen import SCREEN
|
|||||||
Ucelu = False
|
Ucelu = False
|
||||||
SCREENX = 500
|
SCREENX = 500
|
||||||
SCREENY = 500
|
SCREENY = 500
|
||||||
|
device = torch.device('cpu')
|
||||||
|
model1=nn.Sequential(nn.Linear(30000, 10000),nn.ReLU(),nn.Linear(10000,10000),nn.ReLU(),nn.Linear(10000,10000),nn.Linear(10000,4),nn.LogSoftmax(dim=-1)).to(device)
|
||||||
|
model1.load_state_dict(torch.load("./NN/trained"))
|
||||||
|
|
||||||
pygame.display.set_caption('Inteligentny Traktor')
|
pygame.display.set_caption('Inteligentny Traktor')
|
||||||
|
plants = [[], [], []]
|
||||||
|
plants[0].append(Image.open("NN/w1.png"))
|
||||||
|
plants[0].append(Image.open("NN/w2.png"))
|
||||||
|
plants[0].append(Image.open("NN/w3.png"))
|
||||||
|
plants[1].append(Image.open("NN/c1.png"))
|
||||||
|
plants[1].append(Image.open("NN/c2.png"))
|
||||||
|
plants[1].append(Image.open("NN/c3.png"))
|
||||||
|
plants[2].append(Image.open("NN/ca1.png"))
|
||||||
|
plants[2].append(Image.open("NN/ca2.png"))
|
||||||
|
plants[2].append(Image.open("NN/ca3.png"))
|
||||||
|
b = [Image.open("NN/b1.png").convert('RGBA'), Image.open("NN/b2.png").convert('RGBA'), Image.open("NN/b3.png").convert('RGBA')]
|
||||||
|
|
||||||
|
def generate(water, fertilizer, plantf):
|
||||||
|
new_im = None
|
||||||
|
if water == 1:
|
||||||
|
new_im = Image.new('RGB', (100, 100),
|
||||||
|
(160 + random.randint(-10, 10), 80 + random.randint(-10, 10), 40 + random.randint(-10, 10)))
|
||||||
|
tmp = plants[plantf][random.randint(0, 2)].resize(
|
||||||
|
(25 + random.randint(-10, 25), 25 + random.randint(-10, 25))).rotate(random.randint(0, 359))
|
||||||
|
new_im.paste(tmp, (random.randint(0, 50), random.randint(0, 50)), tmp)
|
||||||
|
if fertilizer:
|
||||||
|
tmp = b[random.randint(0, 2)].resize(
|
||||||
|
(20 + random.randint(0, 25), 20 + random.randint(0, 25))).rotate(random.randint(0, 359))
|
||||||
|
new_im.paste(tmp, (random.randint(25, 75), random.randint(25, 75)), tmp)
|
||||||
|
else:
|
||||||
|
if fertilizer:
|
||||||
|
new_im = Image.new('RGB', (100, 100),
|
||||||
|
(
|
||||||
|
50 + random.randint(-10, 10), 25 + random.randint(-10, 10),
|
||||||
|
0 + random.randint(-10, 10)))
|
||||||
|
tmp = plants[plantf][random.randint(0, 2)].resize(
|
||||||
|
(25 + random.randint(-10, 25), 25 + random.randint(-10, 25))).rotate(random.randint(0, 359))
|
||||||
|
new_im.paste(tmp, (random.randint(0, 50), random.randint(0, 50)), tmp)
|
||||||
|
tmp = b[random.randint(0, 2)].resize(
|
||||||
|
(20 + random.randint(0, 25), 20 + random.randint(0, 25))).rotate(random.randint(0, 359))
|
||||||
|
new_im.paste(tmp, (random.randint(25, 75), random.randint(25, 75)), tmp)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if random.randint(0, 1) == 1:
|
||||||
|
new_im = Image.new('RGB', (100, 100),
|
||||||
|
(50 + random.randint(-10, 10), 25 + random.randint(-10, 10),
|
||||||
|
0 + random.randint(-10, 10)))
|
||||||
|
else:
|
||||||
|
new_im = Image.new('RGB', (100, 100),
|
||||||
|
(160 + random.randint(-10, 10), 80 + random.randint(-10, 10),
|
||||||
|
40 + random.randint(-10, 10)))
|
||||||
|
if random.randint(0, 1) == 1: # big
|
||||||
|
tmp = plants[plantf][random.randint(0, 2)].resize(
|
||||||
|
(75 + random.randint(-10, 25), 75 + random.randint(-10, 25))).rotate(random.randint(0, 359))
|
||||||
|
new_im.paste(tmp, (random.randint(0, 15), random.randint(0, 15)), tmp)
|
||||||
|
else:
|
||||||
|
tmp = plants[plantf][random.randint(0, 2)].resize(
|
||||||
|
(random.randint(10, 80), random.randint(10, 80))).rotate(random.randint(0, 359))
|
||||||
|
datas = tmp.getdata()
|
||||||
|
|
||||||
|
new_image_data = []
|
||||||
|
for item in datas:
|
||||||
|
# change all white (also shades of whites) pixels to yellow
|
||||||
|
if item[0] in list(range(190, 256)):
|
||||||
|
new_image_data.append(
|
||||||
|
(random.randint(0, 10), 255 + random.randint(-150, 0), random.randint(0, 10)))
|
||||||
|
else:
|
||||||
|
new_image_data.append(item)
|
||||||
|
|
||||||
|
# update image data
|
||||||
|
tmp.putdata(new_image_data)
|
||||||
|
new_im.paste(tmp, (random.randint(0, 30), random.randint(0, 30)), tmp)
|
||||||
|
|
||||||
|
return new_im
|
||||||
|
|
||||||
# COLORS
|
# COLORS
|
||||||
WHITE = (255, 255, 255)
|
WHITE = (255, 255, 255)
|
||||||
@ -399,6 +480,35 @@ def eventHandler(kbdObj,mouseObj):
|
|||||||
pygame.time.wait(DELAY)
|
pygame.time.wait(DELAY)
|
||||||
|
|
||||||
# If Key_x is pressed, spawn tree
|
# If Key_x is pressed, spawn tree
|
||||||
|
if kbdObj[pygame.K_t]:
|
||||||
|
w = random.randint(0, 1)
|
||||||
|
f=random.randint(0, 1)
|
||||||
|
print(w)
|
||||||
|
print(f)
|
||||||
|
img = generate(w,f,random.randint(0,2))
|
||||||
|
img.save('./test/00/test.png')
|
||||||
|
|
||||||
|
data_transform = transforms.Compose([
|
||||||
|
transforms.Resize(size=(100, 100)),
|
||||||
|
transforms.RandomHorizontalFlip(p=0.5),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
Lambda(lambda x: x.flatten())
|
||||||
|
])
|
||||||
|
datasets.ImageNet
|
||||||
|
train_data = datasets.ImageFolder(root="./test",
|
||||||
|
transform=data_transform,
|
||||||
|
target_transform=None)
|
||||||
|
model1.eval()
|
||||||
|
res = model1(train_data[0][0])
|
||||||
|
if res[0] ==res.max():
|
||||||
|
print("0 0")
|
||||||
|
if res[1] ==res.max():
|
||||||
|
print("0 1")
|
||||||
|
if res[2] ==res.max():
|
||||||
|
print("1 0")
|
||||||
|
if res[3] ==res.max():
|
||||||
|
print("1 1")
|
||||||
|
#img.show()
|
||||||
if kbdObj[pygame.K_x]:
|
if kbdObj[pygame.K_x]:
|
||||||
obs = Obstacle(mouseObj)
|
obs = Obstacle(mouseObj)
|
||||||
obstacleObjects[obstacles] = obs
|
obstacleObjects[obstacles] = obs
|
||||||
|
Loading…
Reference in New Issue
Block a user