Update 'src/main.py'

This commit is contained in:
Bogdan Panov 2021-06-05 20:39:47 +02:00
parent 580211e6e9
commit d66003c818

View File

@ -1,7 +1,11 @@
import os
import np as np
import pygame import pygame
import random import random
import time import time
import struct;print( 8 * struct.calcsize("P"))
from mlxtend.data import loadlocal_mnist
from src.a_star import a_star_search, WAREHOUSE_MAP from src.a_star import a_star_search, WAREHOUSE_MAP
from src.training_data.scripts.gen_data_id3 import generate_size, generate_weight, generate_price, generate_flammability, generate_explosiveness from src.training_data.scripts.gen_data_id3 import generate_size, generate_weight, generate_price, generate_flammability, generate_explosiveness
from src.tree import predict_result, load_tree_from_structure from src.tree import predict_result, load_tree_from_structure
@ -21,6 +25,115 @@ DOCK = pygame.transform.scale(pygame.image.load('img/dock_left.png'), (75, 75))
TRASH = pygame.transform.scale(pygame.image.load('img/trash.png'), (75, 75)) TRASH = pygame.transform.scale(pygame.image.load('img/trash.png'), (75, 75))
get_pix_from_position = {} get_pix_from_position = {}
from PIL import Image
#train_images = np.float32(Image.open('img/background.png').resize( ( 64 , 64 ) ).convert( 'L' )),
#np.float32(Image.open('img/stain.png').resize( ( 64 , 64 ) ).convert( 'L' )),
#np.float32(Image.open('img/shelf.png').resize( ( 64 , 64 ) ).convert( 'L' ))
#print(train_images)
#train_images /= 255.
train_images, train_labels = loadlocal_mnist(
images_path='data/train-images.idx3-ubyte',
labels_path='data/train-labels.idx1-ubyte')
#test_images, test_labels = loadlocal_mnist(
# images_path='data/t10k-images.idx3-ubyte',
# labels_path='data/t10k-labels.idx1-ubyte')
import numpy as np
from matplotlib.pyplot import imshow
from matplotlib import pyplot as plt
def plotdigit(image):
img = np.reshape(image, (-1, 28))
imshow(image, cmap='Greys', vmin=0, vmax=255)
train_images = train_images / 255
#test_images = test_images / 255
import torch
from torch import nn
from torch import optim
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
train_images = [torch.tensor(image, dtype=torch.float32) for image in train_images]
train_labels = [torch.tensor(label, dtype=torch.long) for label in train_labels]
#test_images = [torch.tensor(image, dtype=torch.float32) for image in test_images]
#test_labels = [torch.tensor(label, dtype=torch.long) for label in test_labels]
input_dim = 28*28
output_dim = 10
model = nn.Sequential(
nn.Linear(input_dim, output_dim),
nn.LogSoftmax()
)
def train(model, n_iter):
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)
for epoch in range(n_iter):
for image, label in zip(train_images, train_labels):
optimizer.zero_grad()
output = model(image)
loss = criterion(output.unsqueeze(0), label.unsqueeze(0))
loss.backward()
optimizer.step()
print(f'epoch: {epoch:03}')
train(model, 1)
Sh = 200
St = 10
WAREHOUSE_MAP = [
[0, 0, 0, St, St, 0, 0, 0, 0],
[0, St, 0, St, 0, St, 0, 0, 0],
[0, St, 0, 0, 0, Sh, Sh, Sh, Sh],
[0, 0, St, 0, 0, 0, 0, 0, 0],
[0, 0, St, 0, 0, Sh, Sh, Sh, Sh],
[0, 0, St, St, 0, 0, 0, 0, 0],
[St, 0, 0, 0, St, Sh, Sh, Sh, Sh],
[0, 0, St, St, St, 0, 0, 0, 0],
[0, St, 0, 0, 0, Sh, Sh, Sh, Sh],
]
AGENT_MAP=[]
test=[]
x=0
y=0
while x<9:
while y<9:
if WAREHOUSE_MAP[x][y]==0:
ai = model(train_images[1]).argmax()
elif WAREHOUSE_MAP[x][y]==St:
ai = model(train_images[0]).argmax()
elif WAREHOUSE_MAP[x][y]==Sh:
ai = model(train_images[2]).argmax()
print(ai)
if str(ai) == 'tensor(5)':
test.append(10)
elif str(ai) == 'tensor(4)':
test.append(200)
else:
test.append(0)
y = y + 1
AGENT_MAP.append(test)
test=[]
x=x+1
y=0
for a in AGENT_MAP:
print(a)
def get_position_from_pix(pix_pos): def get_position_from_pix(pix_pos):
for key, value in get_pix_from_position.items(): for key, value in get_pix_from_position.items():