add networks
This commit is contained in:
parent
16b056047c
commit
45aaf233f1
@ -1,4 +1,4 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" languageLevel="JDK_19" project-jdk-name="Python 3.9" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (pythonProject)" project-jdk-type="Python SDK" />
|
||||
</project>
|
84
NN/Generator.py
Normal file
84
NN/Generator.py
Normal file
@ -0,0 +1,84 @@
|
||||
from PIL import Image
|
||||
import random
|
||||
|
||||
plants = [[], [], []]
|
||||
plants[0].append(Image.open("w1.png"))
|
||||
plants[0].append(Image.open("w2.png"))
|
||||
plants[0].append(Image.open("w3.png"))
|
||||
plants[1].append(Image.open("c1.png"))
|
||||
plants[1].append(Image.open("c2.png"))
|
||||
plants[1].append(Image.open("c3.png"))
|
||||
plants[2].append(Image.open("ca1.png"))
|
||||
plants[2].append(Image.open("ca2.png"))
|
||||
plants[2].append(Image.open("ca3.png"))
|
||||
b = [Image.open("b1.png").convert('RGBA'), Image.open("b2.png").convert('RGBA'), Image.open("b3.png").convert('RGBA')]
|
||||
|
||||
|
||||
def generate(water, fertilizer, plantf):
|
||||
new_im = None
|
||||
if water == 1:
|
||||
new_im = Image.new('RGB', (100, 100),
|
||||
(160 + random.randint(-10, 10), 80 + random.randint(-10, 10), 40 + random.randint(-10, 10)))
|
||||
tmp = plants[plantf][random.randint(0, 2)].resize(
|
||||
(25 + random.randint(-10, 25), 25 + random.randint(-10, 25))).rotate(random.randint(0, 359))
|
||||
new_im.paste(tmp, (random.randint(0, 50), random.randint(0, 50)), tmp)
|
||||
if fertilizer:
|
||||
tmp = b[random.randint(0, 2)].resize(
|
||||
(20 + random.randint(0, 25), 20 + random.randint(0, 25))).rotate(random.randint(0, 359))
|
||||
new_im.paste(tmp, (random.randint(25, 75), random.randint(25, 75)), tmp)
|
||||
else:
|
||||
if fertilizer:
|
||||
new_im = Image.new('RGB', (100, 100),
|
||||
(
|
||||
50 + random.randint(-10, 10), 25 + random.randint(-10, 10),
|
||||
0 + random.randint(-10, 10)))
|
||||
tmp = plants[plantf][random.randint(0, 2)].resize(
|
||||
(25 + random.randint(-10, 25), 25 + random.randint(-10, 25))).rotate(random.randint(0, 359))
|
||||
new_im.paste(tmp, (random.randint(0, 50), random.randint(0, 50)), tmp)
|
||||
tmp = b[random.randint(0, 2)].resize(
|
||||
(20 + random.randint(0, 25), 20 + random.randint(0, 25))).rotate(random.randint(0, 359))
|
||||
new_im.paste(tmp, (random.randint(25, 75), random.randint(25, 75)), tmp)
|
||||
|
||||
else:
|
||||
if random.randint(0, 1) == 1:
|
||||
new_im = Image.new('RGB', (100, 100),
|
||||
(50 + random.randint(-10, 10), 25 + random.randint(-10, 10),
|
||||
0 + random.randint(-10, 10)))
|
||||
else:
|
||||
new_im = Image.new('RGB', (100, 100),
|
||||
(160 + random.randint(-10, 10), 80 + random.randint(-10, 10),
|
||||
40 + random.randint(-10, 10)))
|
||||
if random.randint(0, 1) == 1: # big
|
||||
tmp = plants[plantf][random.randint(0, 2)].resize(
|
||||
(75 + random.randint(-10, 25), 75 + random.randint(-10, 25))).rotate(random.randint(0, 359))
|
||||
new_im.paste(tmp, (random.randint(0, 15), random.randint(0, 15)), tmp)
|
||||
else:
|
||||
tmp = plants[plantf][random.randint(0, 2)].resize(
|
||||
(random.randint(10, 80), random.randint(10, 80))).rotate(random.randint(0, 359))
|
||||
datas = tmp.getdata()
|
||||
|
||||
new_image_data = []
|
||||
for item in datas:
|
||||
# change all white (also shades of whites) pixels to yellow
|
||||
if item[0] in list(range(190, 256)):
|
||||
new_image_data.append(
|
||||
(random.randint(0, 10), 255 + random.randint(-150, 0), random.randint(0, 10)))
|
||||
else:
|
||||
new_image_data.append(item)
|
||||
|
||||
# update image data
|
||||
tmp.putdata(new_image_data)
|
||||
new_im.paste(tmp, (random.randint(0, 30), random.randint(0, 30)), tmp)
|
||||
|
||||
return new_im
|
||||
|
||||
|
||||
for x in range(0, 1000):
|
||||
generate(0, 0, random.randint(0, 2)).save('datasets/00/' + str(x) + '.png')
|
||||
for x in range(0, 1000):
|
||||
generate(1, 0, random.randint(0, 2)).save('datasets/10/' + str(x) + '.png')
|
||||
for x in range(0, 1000):
|
||||
generate(0, 1, random.randint(0, 2)).save('datasets/01/' + str(x) + '.png')
|
||||
for x in range(0, 1000):
|
||||
generate(1, 1, random.randint(0, 2)).save('datasets/11/' + str(x) + '.png')
|
||||
|
47
NN/trainer.py
Normal file
47
NN/trainer.py
Normal file
@ -0,0 +1,47 @@
|
||||
import pathlib
|
||||
import random
|
||||
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import datasets, transforms
|
||||
from torchvision.transforms import Lambda
|
||||
|
||||
device = torch.device('cuda')
|
||||
|
||||
def train(model, dataset, n_iter=100, batch_size=2560000):
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||
criterion = nn.NLLLoss()
|
||||
dl = DataLoader(dataset, batch_size=batch_size)
|
||||
model.train()
|
||||
for epoch in range(n_iter):
|
||||
for images, targets in dl:
|
||||
optimizer.zero_grad()
|
||||
out = model(images.to(device))
|
||||
loss = criterion(out, targets.to(device))
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
if epoch % 10 == 0:
|
||||
print('epoch: %3d loss: %.4f' % (epoch, loss))
|
||||
|
||||
|
||||
image_path_list = list(pathlib.Path('./').glob("*/*/*.png"))
|
||||
|
||||
random_image_path = random.choice(image_path_list)
|
||||
data_transform = transforms.Compose([
|
||||
transforms.Resize(size=(100, 100)),
|
||||
transforms.RandomHorizontalFlip(p=0.5),
|
||||
transforms.ToTensor(),
|
||||
Lambda(lambda x: x.flatten())
|
||||
])
|
||||
|
||||
train_data = datasets.ImageFolder(root="./datasets",
|
||||
transform=data_transform,
|
||||
target_transform=None)
|
||||
|
||||
model1=nn.Sequential(nn.Linear(30000, 10000),nn.ReLU(),nn.Linear(10000,10000),nn.ReLU(),nn.Linear(10000,10000),nn.Linear(10000,4),nn.LogSoftmax(dim=-1)).to(device)
|
||||
model1.load_state_dict(torch.load("./trained"))
|
||||
train(model1,train_data)
|
||||
|
||||
torch.save(model1.state_dict(), "./trained")
|
110
main.py
110
main.py
@ -1,4 +1,13 @@
|
||||
import pygame
|
||||
import pathlib
|
||||
import random
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import datasets, transforms
|
||||
from torchvision.transforms import Lambda
|
||||
from PIL import Image
|
||||
|
||||
import astar
|
||||
from classes import Field, Plant, Fertilizer, Player
|
||||
@ -14,8 +23,80 @@ from screen import SCREEN
|
||||
Ucelu = False
|
||||
SCREENX = 500
|
||||
SCREENY = 500
|
||||
device = torch.device('cpu')
|
||||
model1=nn.Sequential(nn.Linear(30000, 10000),nn.ReLU(),nn.Linear(10000,10000),nn.ReLU(),nn.Linear(10000,10000),nn.Linear(10000,4),nn.LogSoftmax(dim=-1)).to(device)
|
||||
model1.load_state_dict(torch.load("./NN/trained"))
|
||||
|
||||
pygame.display.set_caption('Inteligentny Traktor')
|
||||
plants = [[], [], []]
|
||||
plants[0].append(Image.open("NN/w1.png"))
|
||||
plants[0].append(Image.open("NN/w2.png"))
|
||||
plants[0].append(Image.open("NN/w3.png"))
|
||||
plants[1].append(Image.open("NN/c1.png"))
|
||||
plants[1].append(Image.open("NN/c2.png"))
|
||||
plants[1].append(Image.open("NN/c3.png"))
|
||||
plants[2].append(Image.open("NN/ca1.png"))
|
||||
plants[2].append(Image.open("NN/ca2.png"))
|
||||
plants[2].append(Image.open("NN/ca3.png"))
|
||||
b = [Image.open("NN/b1.png").convert('RGBA'), Image.open("NN/b2.png").convert('RGBA'), Image.open("NN/b3.png").convert('RGBA')]
|
||||
|
||||
def generate(water, fertilizer, plantf):
|
||||
new_im = None
|
||||
if water == 1:
|
||||
new_im = Image.new('RGB', (100, 100),
|
||||
(160 + random.randint(-10, 10), 80 + random.randint(-10, 10), 40 + random.randint(-10, 10)))
|
||||
tmp = plants[plantf][random.randint(0, 2)].resize(
|
||||
(25 + random.randint(-10, 25), 25 + random.randint(-10, 25))).rotate(random.randint(0, 359))
|
||||
new_im.paste(tmp, (random.randint(0, 50), random.randint(0, 50)), tmp)
|
||||
if fertilizer:
|
||||
tmp = b[random.randint(0, 2)].resize(
|
||||
(20 + random.randint(0, 25), 20 + random.randint(0, 25))).rotate(random.randint(0, 359))
|
||||
new_im.paste(tmp, (random.randint(25, 75), random.randint(25, 75)), tmp)
|
||||
else:
|
||||
if fertilizer:
|
||||
new_im = Image.new('RGB', (100, 100),
|
||||
(
|
||||
50 + random.randint(-10, 10), 25 + random.randint(-10, 10),
|
||||
0 + random.randint(-10, 10)))
|
||||
tmp = plants[plantf][random.randint(0, 2)].resize(
|
||||
(25 + random.randint(-10, 25), 25 + random.randint(-10, 25))).rotate(random.randint(0, 359))
|
||||
new_im.paste(tmp, (random.randint(0, 50), random.randint(0, 50)), tmp)
|
||||
tmp = b[random.randint(0, 2)].resize(
|
||||
(20 + random.randint(0, 25), 20 + random.randint(0, 25))).rotate(random.randint(0, 359))
|
||||
new_im.paste(tmp, (random.randint(25, 75), random.randint(25, 75)), tmp)
|
||||
|
||||
else:
|
||||
if random.randint(0, 1) == 1:
|
||||
new_im = Image.new('RGB', (100, 100),
|
||||
(50 + random.randint(-10, 10), 25 + random.randint(-10, 10),
|
||||
0 + random.randint(-10, 10)))
|
||||
else:
|
||||
new_im = Image.new('RGB', (100, 100),
|
||||
(160 + random.randint(-10, 10), 80 + random.randint(-10, 10),
|
||||
40 + random.randint(-10, 10)))
|
||||
if random.randint(0, 1) == 1: # big
|
||||
tmp = plants[plantf][random.randint(0, 2)].resize(
|
||||
(75 + random.randint(-10, 25), 75 + random.randint(-10, 25))).rotate(random.randint(0, 359))
|
||||
new_im.paste(tmp, (random.randint(0, 15), random.randint(0, 15)), tmp)
|
||||
else:
|
||||
tmp = plants[plantf][random.randint(0, 2)].resize(
|
||||
(random.randint(10, 80), random.randint(10, 80))).rotate(random.randint(0, 359))
|
||||
datas = tmp.getdata()
|
||||
|
||||
new_image_data = []
|
||||
for item in datas:
|
||||
# change all white (also shades of whites) pixels to yellow
|
||||
if item[0] in list(range(190, 256)):
|
||||
new_image_data.append(
|
||||
(random.randint(0, 10), 255 + random.randint(-150, 0), random.randint(0, 10)))
|
||||
else:
|
||||
new_image_data.append(item)
|
||||
|
||||
# update image data
|
||||
tmp.putdata(new_image_data)
|
||||
new_im.paste(tmp, (random.randint(0, 30), random.randint(0, 30)), tmp)
|
||||
|
||||
return new_im
|
||||
|
||||
# COLORS
|
||||
WHITE = (255, 255, 255)
|
||||
@ -399,6 +480,35 @@ def eventHandler(kbdObj,mouseObj):
|
||||
pygame.time.wait(DELAY)
|
||||
|
||||
# If Key_x is pressed, spawn tree
|
||||
if kbdObj[pygame.K_t]:
|
||||
w = random.randint(0, 1)
|
||||
f=random.randint(0, 1)
|
||||
print(w)
|
||||
print(f)
|
||||
img = generate(w,f,random.randint(0,2))
|
||||
img.save('./test/00/test.png')
|
||||
|
||||
data_transform = transforms.Compose([
|
||||
transforms.Resize(size=(100, 100)),
|
||||
transforms.RandomHorizontalFlip(p=0.5),
|
||||
transforms.ToTensor(),
|
||||
Lambda(lambda x: x.flatten())
|
||||
])
|
||||
datasets.ImageNet
|
||||
train_data = datasets.ImageFolder(root="./test",
|
||||
transform=data_transform,
|
||||
target_transform=None)
|
||||
model1.eval()
|
||||
res = model1(train_data[0][0])
|
||||
if res[0] ==res.max():
|
||||
print("0 0")
|
||||
if res[1] ==res.max():
|
||||
print("0 1")
|
||||
if res[2] ==res.max():
|
||||
print("1 0")
|
||||
if res[3] ==res.max():
|
||||
print("1 1")
|
||||
#img.show()
|
||||
if kbdObj[pygame.K_x]:
|
||||
obs = Obstacle(mouseObj)
|
||||
obstacleObjects[obstacles] = obs
|
||||
|
Loading…
Reference in New Issue
Block a user