Update 'src/main.py'
This commit is contained in:
parent
580211e6e9
commit
d66003c818
115
src/main.py
115
src/main.py
@ -1,7 +1,11 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import np as np
|
||||||
import pygame
|
import pygame
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
import struct;print( 8 * struct.calcsize("P"))
|
||||||
|
from mlxtend.data import loadlocal_mnist
|
||||||
from src.a_star import a_star_search, WAREHOUSE_MAP
|
from src.a_star import a_star_search, WAREHOUSE_MAP
|
||||||
from src.training_data.scripts.gen_data_id3 import generate_size, generate_weight, generate_price, generate_flammability, generate_explosiveness
|
from src.training_data.scripts.gen_data_id3 import generate_size, generate_weight, generate_price, generate_flammability, generate_explosiveness
|
||||||
from src.tree import predict_result, load_tree_from_structure
|
from src.tree import predict_result, load_tree_from_structure
|
||||||
@ -21,6 +25,115 @@ DOCK = pygame.transform.scale(pygame.image.load('img/dock_left.png'), (75, 75))
|
|||||||
TRASH = pygame.transform.scale(pygame.image.load('img/trash.png'), (75, 75))
|
TRASH = pygame.transform.scale(pygame.image.load('img/trash.png'), (75, 75))
|
||||||
get_pix_from_position = {}
|
get_pix_from_position = {}
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
#train_images = np.float32(Image.open('img/background.png').resize( ( 64 , 64 ) ).convert( 'L' )),
|
||||||
|
#np.float32(Image.open('img/stain.png').resize( ( 64 , 64 ) ).convert( 'L' )),
|
||||||
|
#np.float32(Image.open('img/shelf.png').resize( ( 64 , 64 ) ).convert( 'L' ))
|
||||||
|
|
||||||
|
#print(train_images)
|
||||||
|
#train_images /= 255.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
train_images, train_labels = loadlocal_mnist(
|
||||||
|
images_path='data/train-images.idx3-ubyte',
|
||||||
|
labels_path='data/train-labels.idx1-ubyte')
|
||||||
|
#test_images, test_labels = loadlocal_mnist(
|
||||||
|
# images_path='data/t10k-images.idx3-ubyte',
|
||||||
|
# labels_path='data/t10k-labels.idx1-ubyte')
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from matplotlib.pyplot import imshow
|
||||||
|
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
|
def plotdigit(image):
|
||||||
|
img = np.reshape(image, (-1, 28))
|
||||||
|
imshow(image, cmap='Greys', vmin=0, vmax=255)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
train_images = train_images / 255
|
||||||
|
#test_images = test_images / 255
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch import optim
|
||||||
|
|
||||||
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
train_images = [torch.tensor(image, dtype=torch.float32) for image in train_images]
|
||||||
|
train_labels = [torch.tensor(label, dtype=torch.long) for label in train_labels]
|
||||||
|
#test_images = [torch.tensor(image, dtype=torch.float32) for image in test_images]
|
||||||
|
#test_labels = [torch.tensor(label, dtype=torch.long) for label in test_labels]
|
||||||
|
|
||||||
|
input_dim = 28*28
|
||||||
|
output_dim = 10
|
||||||
|
|
||||||
|
model = nn.Sequential(
|
||||||
|
nn.Linear(input_dim, output_dim),
|
||||||
|
nn.LogSoftmax()
|
||||||
|
)
|
||||||
|
|
||||||
|
def train(model, n_iter):
|
||||||
|
criterion = nn.NLLLoss()
|
||||||
|
optimizer = optim.SGD(model.parameters(), lr=0.001)
|
||||||
|
|
||||||
|
for epoch in range(n_iter):
|
||||||
|
for image, label in zip(train_images, train_labels):
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
output = model(image)
|
||||||
|
loss = criterion(output.unsqueeze(0), label.unsqueeze(0))
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
print(f'epoch: {epoch:03}')
|
||||||
|
|
||||||
|
train(model, 1)
|
||||||
|
|
||||||
|
Sh = 200
|
||||||
|
St = 10
|
||||||
|
WAREHOUSE_MAP = [
|
||||||
|
[0, 0, 0, St, St, 0, 0, 0, 0],
|
||||||
|
[0, St, 0, St, 0, St, 0, 0, 0],
|
||||||
|
[0, St, 0, 0, 0, Sh, Sh, Sh, Sh],
|
||||||
|
[0, 0, St, 0, 0, 0, 0, 0, 0],
|
||||||
|
[0, 0, St, 0, 0, Sh, Sh, Sh, Sh],
|
||||||
|
[0, 0, St, St, 0, 0, 0, 0, 0],
|
||||||
|
[St, 0, 0, 0, St, Sh, Sh, Sh, Sh],
|
||||||
|
[0, 0, St, St, St, 0, 0, 0, 0],
|
||||||
|
[0, St, 0, 0, 0, Sh, Sh, Sh, Sh],
|
||||||
|
]
|
||||||
|
|
||||||
|
AGENT_MAP=[]
|
||||||
|
test=[]
|
||||||
|
x=0
|
||||||
|
y=0
|
||||||
|
while x<9:
|
||||||
|
while y<9:
|
||||||
|
if WAREHOUSE_MAP[x][y]==0:
|
||||||
|
ai = model(train_images[1]).argmax()
|
||||||
|
elif WAREHOUSE_MAP[x][y]==St:
|
||||||
|
ai = model(train_images[0]).argmax()
|
||||||
|
elif WAREHOUSE_MAP[x][y]==Sh:
|
||||||
|
ai = model(train_images[2]).argmax()
|
||||||
|
|
||||||
|
print(ai)
|
||||||
|
if str(ai) == 'tensor(5)':
|
||||||
|
test.append(10)
|
||||||
|
elif str(ai) == 'tensor(4)':
|
||||||
|
test.append(200)
|
||||||
|
else:
|
||||||
|
test.append(0)
|
||||||
|
y = y + 1
|
||||||
|
AGENT_MAP.append(test)
|
||||||
|
test=[]
|
||||||
|
x=x+1
|
||||||
|
y=0
|
||||||
|
for a in AGENT_MAP:
|
||||||
|
print(a)
|
||||||
|
|
||||||
def get_position_from_pix(pix_pos):
|
def get_position_from_pix(pix_pos):
|
||||||
for key, value in get_pix_from_position.items():
|
for key, value in get_pix_from_position.items():
|
||||||
|
Loading…
Reference in New Issue
Block a user