Update 'src/main.py'
This commit is contained in:
parent
580211e6e9
commit
d66003c818
115
src/main.py
115
src/main.py
@ -1,7 +1,11 @@
|
||||
import os
|
||||
|
||||
import np as np
|
||||
import pygame
|
||||
import random
|
||||
import time
|
||||
|
||||
import struct;print( 8 * struct.calcsize("P"))
|
||||
from mlxtend.data import loadlocal_mnist
|
||||
from src.a_star import a_star_search, WAREHOUSE_MAP
|
||||
from src.training_data.scripts.gen_data_id3 import generate_size, generate_weight, generate_price, generate_flammability, generate_explosiveness
|
||||
from src.tree import predict_result, load_tree_from_structure
|
||||
@ -21,6 +25,115 @@ DOCK = pygame.transform.scale(pygame.image.load('img/dock_left.png'), (75, 75))
|
||||
TRASH = pygame.transform.scale(pygame.image.load('img/trash.png'), (75, 75))
|
||||
get_pix_from_position = {}
|
||||
|
||||
from PIL import Image
|
||||
|
||||
#train_images = np.float32(Image.open('img/background.png').resize( ( 64 , 64 ) ).convert( 'L' )),
|
||||
#np.float32(Image.open('img/stain.png').resize( ( 64 , 64 ) ).convert( 'L' )),
|
||||
#np.float32(Image.open('img/shelf.png').resize( ( 64 , 64 ) ).convert( 'L' ))
|
||||
|
||||
#print(train_images)
|
||||
#train_images /= 255.
|
||||
|
||||
|
||||
|
||||
train_images, train_labels = loadlocal_mnist(
|
||||
images_path='data/train-images.idx3-ubyte',
|
||||
labels_path='data/train-labels.idx1-ubyte')
|
||||
#test_images, test_labels = loadlocal_mnist(
|
||||
# images_path='data/t10k-images.idx3-ubyte',
|
||||
# labels_path='data/t10k-labels.idx1-ubyte')
|
||||
|
||||
import numpy as np
|
||||
from matplotlib.pyplot import imshow
|
||||
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
def plotdigit(image):
|
||||
img = np.reshape(image, (-1, 28))
|
||||
imshow(image, cmap='Greys', vmin=0, vmax=255)
|
||||
|
||||
|
||||
|
||||
train_images = train_images / 255
|
||||
#test_images = test_images / 255
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch import optim
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
train_images = [torch.tensor(image, dtype=torch.float32) for image in train_images]
|
||||
train_labels = [torch.tensor(label, dtype=torch.long) for label in train_labels]
|
||||
#test_images = [torch.tensor(image, dtype=torch.float32) for image in test_images]
|
||||
#test_labels = [torch.tensor(label, dtype=torch.long) for label in test_labels]
|
||||
|
||||
input_dim = 28*28
|
||||
output_dim = 10
|
||||
|
||||
model = nn.Sequential(
|
||||
nn.Linear(input_dim, output_dim),
|
||||
nn.LogSoftmax()
|
||||
)
|
||||
|
||||
def train(model, n_iter):
|
||||
criterion = nn.NLLLoss()
|
||||
optimizer = optim.SGD(model.parameters(), lr=0.001)
|
||||
|
||||
for epoch in range(n_iter):
|
||||
for image, label in zip(train_images, train_labels):
|
||||
optimizer.zero_grad()
|
||||
|
||||
output = model(image)
|
||||
loss = criterion(output.unsqueeze(0), label.unsqueeze(0))
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
print(f'epoch: {epoch:03}')
|
||||
|
||||
train(model, 1)
|
||||
|
||||
Sh = 200
|
||||
St = 10
|
||||
WAREHOUSE_MAP = [
|
||||
[0, 0, 0, St, St, 0, 0, 0, 0],
|
||||
[0, St, 0, St, 0, St, 0, 0, 0],
|
||||
[0, St, 0, 0, 0, Sh, Sh, Sh, Sh],
|
||||
[0, 0, St, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, St, 0, 0, Sh, Sh, Sh, Sh],
|
||||
[0, 0, St, St, 0, 0, 0, 0, 0],
|
||||
[St, 0, 0, 0, St, Sh, Sh, Sh, Sh],
|
||||
[0, 0, St, St, St, 0, 0, 0, 0],
|
||||
[0, St, 0, 0, 0, Sh, Sh, Sh, Sh],
|
||||
]
|
||||
|
||||
AGENT_MAP=[]
|
||||
test=[]
|
||||
x=0
|
||||
y=0
|
||||
while x<9:
|
||||
while y<9:
|
||||
if WAREHOUSE_MAP[x][y]==0:
|
||||
ai = model(train_images[1]).argmax()
|
||||
elif WAREHOUSE_MAP[x][y]==St:
|
||||
ai = model(train_images[0]).argmax()
|
||||
elif WAREHOUSE_MAP[x][y]==Sh:
|
||||
ai = model(train_images[2]).argmax()
|
||||
|
||||
print(ai)
|
||||
if str(ai) == 'tensor(5)':
|
||||
test.append(10)
|
||||
elif str(ai) == 'tensor(4)':
|
||||
test.append(200)
|
||||
else:
|
||||
test.append(0)
|
||||
y = y + 1
|
||||
AGENT_MAP.append(test)
|
||||
test=[]
|
||||
x=x+1
|
||||
y=0
|
||||
for a in AGENT_MAP:
|
||||
print(a)
|
||||
|
||||
def get_position_from_pix(pix_pos):
|
||||
for key, value in get_pix_from_position.items():
|
||||
|
Loading…
Reference in New Issue
Block a user