forked from s444519/Waiter_group
Individual Project #2; s442720
parent a73862b48b
commit 71dc3e81a2

main_training.py (new file, 489 lines)

@@ -0,0 +1,489 @@
from __future__ import print_function
import os, sys, time, datetime, json, random
import numpy as np
from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.optimizers import SGD, Adam, RMSprop
from keras.layers.advanced_activations import PReLU
import matplotlib.pyplot as plt
import pickle

visited_mark = 0.8  # Cells visited by the rat will be painted gray 0.8
rat_mark = 0.5      # The current rat cell will be painted gray 0.5
LEFT = 0
UP = 1
RIGHT = 2
DOWN = 3

# Actions dictionary
actions_dict = {
    LEFT: 'left',
    UP: 'up',
    RIGHT: 'right',
    DOWN: 'down',
}

num_actions = len(actions_dict)

# Exploration factor
epsilon = 0.1
file_name_num = 1
win_targets = [(4, 4), (4, 9), (4, 14), (9, 4)]

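# Note: the four win_targets above appear to correspond to serving cells next
# to the first four tables of the 16x16 restaurant grid that is built in the
# __main__ block at the bottom of this file.
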
class Qmaze(object):
    def __init__(self, maze, rat=(12, 12)):
        global win_targets
        self._maze = np.array(maze)
        nrows, ncols = self._maze.shape
        #self.target = (nrows-1, ncols-1)   # target cell where the "cheese" is
        self.target = win_targets[0]
        self.free_cells = [(r, c) for r in range(nrows) for c in range(ncols) if self._maze[r, c] == 1.0]
        self.free_cells.remove(win_targets[-1])
        if self._maze[self.target] == 0.0:
            raise Exception("Invalid maze: target cell cannot be blocked!")
        if rat not in self.free_cells:
            raise Exception("Invalid Rat Location: must sit on a free cell")
        self.reset(rat)

    def reset(self, rat):
        global win_targets
        self.rat = rat
        self.maze = np.copy(self._maze)
        nrows, ncols = self.maze.shape
        row, col = rat
        self.maze[row, col] = rat_mark
        self.state = (row, col, 'start')
        self.min_reward = -0.5 * self.maze.size
        self.total_reward = 0
        self.visited = list()
        self.curr_win_targets = win_targets[:]

    def update_state(self, action):
        nrows, ncols = self.maze.shape
        nrow, ncol, nmode = rat_row, rat_col, mode = self.state

        if self.maze[rat_row, rat_col] > 0.0:
            self.visited.append((rat_row, rat_col))  # mark visited cell

        valid_actions = self.valid_actions()

        if not valid_actions:
            nmode = 'blocked'
        elif action in valid_actions:
            nmode = 'valid'
            if action == LEFT:
                ncol -= 1
            elif action == UP:
                nrow -= 1
            elif action == RIGHT:
                ncol += 1
            elif action == DOWN:
                nrow += 1
        else:  # invalid action, no change in rat position
            nmode = 'invalid'

        # new state
        self.state = (nrow, ncol, nmode)

    def get_reward(self):
        win_target_row, win_target_col = self.target
        rat_row, rat_col, mode = self.state
        nrows, ncols = self.maze.shape
        if rat_row == win_target_row and rat_col == win_target_col:
            return 1.0
        if mode == 'blocked':  # moved into a blocked cell
            return -1.0
        if (rat_row, rat_col) in self.visited:
            return -0.5   # default -0.25 -> -0.5
        if mode == 'invalid':
            return -0.75  # default -0.75, move into the boundary
        if mode == 'valid':   # default -0.04 -> -0.1
            return -0.04
        if (rat_row, rat_col) in self.curr_win_targets:
            return 1.0

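    # Reward scheme used above (values taken from the code): +1.0 for reaching
    # the current target, -1.0 for a blocked state, -0.5 for revisiting a cell,
    # -0.75 for an invalid move, and -0.04 per ordinary valid step. A game is
    # lost once total_reward drops below min_reward = -0.5 * maze.size
    # (see game_status below).
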
    def act(self, action):
        self.update_state(action)
        reward = self.get_reward()
        self.total_reward += reward
        status = self.game_status()
        envstate = self.observe()
        return envstate, reward, status

    def observe(self):
        canvas = self.draw_env()
        envstate = canvas.reshape((1, -1))
        return envstate

    def draw_env(self):
        canvas = np.copy(self.maze)
        nrows, ncols = self.maze.shape
        # clear all visual marks
        for r in range(nrows):
            for c in range(ncols):
                if canvas[r, c] > 0.0:
                    canvas[r, c] = 1.0
        # draw the rat
        row, col, valid = self.state
        canvas[row, col] = rat_mark
        return canvas

    def game_status(self):
        if self.total_reward < self.min_reward:
            return 'lose'
        rat_row, rat_col, mode = self.state
        nrows, ncols = self.maze.shape

        curPos = (rat_row, rat_col)

        if curPos in self.curr_win_targets:
            self.curr_win_targets.remove(curPos)
            if len(self.curr_win_targets) == 0:
                return 'win'
            else:
                self.target = self.curr_win_targets[0]

        return 'not_over'

    def valid_actions(self, cell=None):
        if cell is None:
            row, col, mode = self.state
        else:
            row, col = cell
        actions = [0, 1, 2, 3]
        nrows, ncols = self.maze.shape
        if row == 0:
            actions.remove(1)
        elif row == nrows - 1:
            actions.remove(3)

        if col == 0:
            actions.remove(0)
        elif col == ncols - 1:
            actions.remove(2)

        if row > 0 and self.maze[row - 1, col] == 0.0:
            actions.remove(1)
        if row < nrows - 1 and self.maze[row + 1, col] == 0.0:
            actions.remove(3)

        if col > 0 and self.maze[row, col - 1] == 0.0:
            actions.remove(0)
        if col < ncols - 1 and self.maze[row, col + 1] == 0.0:
            actions.remove(2)

        return actions

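    # The numeric codes removed above are the action encoding declared at the
    # top of the file: 0 = LEFT, 1 = UP, 2 = RIGHT, 3 = DOWN.
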
def show(qmaze):
    win_target_row, win_target_col = qmaze.target
    plt.grid('on')
    nrows, ncols = qmaze.maze.shape
    ax = plt.gca()
    ax.set_xticks(np.arange(0.5, nrows, 1))
    ax.set_yticks(np.arange(0.5, ncols, 1))
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    canvas = np.copy(qmaze.maze)
    for row, col in qmaze.visited:
        canvas[row, col] = 0.6
    rat_row, rat_col, _ = qmaze.state
    canvas[rat_row, rat_col] = 0.3  # rat cell
    canvas[win_target_row, win_target_col] = 0.9  # cheese cell
    img = plt.imshow(canvas, interpolation='none', cmap='gray')
    return img

def save_pic(qmaze):
    global file_name_num
    win_target_row, win_target_col = qmaze.target
    plt.grid('on')
    nrows, ncols = qmaze.maze.shape
    ax = plt.gca()
    ax.set_xticks(np.arange(0.5, nrows, 1))
    ax.set_yticks(np.arange(0.5, ncols, 1))
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    canvas = np.copy(qmaze.maze)
    for row, col in qmaze.visited:
        canvas[row, col] = 0.6
    rat_row, rat_col, _ = qmaze.state
    canvas[rat_row, rat_col] = 0.3  # rat cell
    canvas[win_target_row, win_target_col] = 0.9  # cheese cell
    plt.imshow(canvas, interpolation='none', cmap='gray')
    plt.savefig(str(file_name_num) + ".png")
    file_name_num += 1

def output_route(qmaze):
    win_target_row, win_target_col = qmaze.target
    print(qmaze._maze)

def play_game(model, qmaze, rat_cell):
    qmaze.reset(rat_cell)
    envstate = qmaze.observe()
    while True:
        prev_envstate = envstate
        # get next action
        q = model.predict(prev_envstate)
        action = np.argmax(q[0])

        # apply action, get rewards and new state
        envstate, reward, game_status = qmaze.act(action)
        if game_status == 'win':
            return True
        elif game_status == 'lose':
            return False

def completion_check(model, qmaze):
    for cell in qmaze.free_cells:
        if not qmaze.valid_actions(cell):
            return False
        if not play_game(model, qmaze, cell):
            return False
    return True

class Experience(object):
    def __init__(self, model, max_memory=100, discount=0.9):
        self.model = model
        self.max_memory = max_memory
        self.discount = discount
        self.memory = list()
        self.num_actions = model.output_shape[-1]

    def remember(self, episode):
        # episode = [envstate, action, reward, envstate_next, game_over]
        # memory[i] = episode
        # envstate == flattened 1d maze cells info, including rat cell (see method: observe)
        self.memory.append(episode)
        if len(self.memory) > self.max_memory:
            del self.memory[0]

    def predict(self, envstate):
        return self.model.predict(envstate)[0]

    def get_data(self, data_size=10):
        env_size = self.memory[0][0].shape[1]  # envstate 1d size (1st element of episode)
        mem_size = len(self.memory)
        data_size = min(mem_size, data_size)
        inputs = np.zeros((data_size, env_size))
        targets = np.zeros((data_size, self.num_actions))
        for i, j in enumerate(np.random.choice(range(mem_size), data_size, replace=False)):
            envstate, action, reward, envstate_next, game_over = self.memory[j]
            inputs[i] = envstate
            # There should be no target values for actions not taken.
            targets[i] = self.predict(envstate)
            # Q_sa = derived policy = max quality env/action = max_a' Q(s', a')
            Q_sa = np.max(self.predict(envstate_next))
            if game_over:
                targets[i, action] = reward
            else:
                # reward + gamma * max_a' Q(s', a')
                targets[i, action] = reward + self.discount * Q_sa
        return inputs, targets

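    # Each sampled batch therefore pairs an observed state with a target
    # Q-vector that equals the network's current prediction everywhere except
    # at the action actually taken, where the Bellman estimate
    # reward + discount * max_a' Q(s', a') is substituted.
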
def qtrain(model, maze, **opt):
    global epsilon
    n_epoch = opt.get('n_epoch', 15000)
    max_memory = opt.get('max_memory', 1000)
    data_size = opt.get('data_size', 50)
    weights_file = opt.get('weights_file', "")
    name = opt.get('name', 'model')
    start_time = datetime.datetime.now()

    # If you want to continue training from a previous model,
    # just supply the h5 file name to weights_file option
    if weights_file:
        print("loading weights from file: %s" % (weights_file,))
        model.load_weights(weights_file)

    # Construct environment/game from numpy array: maze (see above)
    qmaze = Qmaze(maze)

    # Initialize experience replay object
    experience = Experience(model, max_memory=max_memory)

    win_history = []  # history of win/lose game
    n_free_cells = len(qmaze.free_cells)
    hsize = qmaze.maze.size // 2  # history window size
    win_rate = 0.0
    imctr = 1
    pre_episodes = 2**31 - 1

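    # hsize is the sliding window over which win_rate is computed further
    # down; for the 16x16 grid built in __main__ this is 256 // 2 = 128 games.
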
    for epoch in range(n_epoch):
        loss = 0.0
        #rat_cell = random.choice(qmaze.free_cells)
        #rat_cell = (0, 0)
        rat_cell = (12, 12)

        qmaze.reset(rat_cell)
        game_over = False

        # get initial envstate (1d flattened canvas)
        envstate = qmaze.observe()

        n_episodes = 0
        while not game_over:
            valid_actions = qmaze.valid_actions()
            if not valid_actions:
                break
            prev_envstate = envstate
            # Get next action
            if np.random.rand() < epsilon:
                action = random.choice(valid_actions)
            else:
                action = np.argmax(experience.predict(prev_envstate))

            # Apply action, get reward and new envstate
            envstate, reward, game_status = qmaze.act(action)
            if game_status == 'win':
                print("win")
                win_history.append(1)
                game_over = True
                # save_pic(qmaze)
                if n_episodes <= pre_episodes:
                    # output_route(qmaze)
                    print(qmaze.visited)
                    with open('res.data', 'wb') as filehandle:
                        pickle.dump(qmaze.visited, filehandle)
                    pre_episodes = n_episodes

            elif game_status == 'lose':
                print("lose")
                win_history.append(0)
                game_over = True
                # save_pic(qmaze)
            else:
                game_over = False

            # Store episode (experience)
            episode = [prev_envstate, action, reward, envstate, game_over]
            experience.remember(episode)
            n_episodes += 1

            # Train neural network model
            inputs, targets = experience.get_data(data_size=data_size)
            h = model.fit(
                inputs,
                targets,
                epochs=8,
                batch_size=16,
                verbose=0,
            )
            loss = model.evaluate(inputs, targets, verbose=0)

        if len(win_history) > hsize:
            win_rate = sum(win_history[-hsize:]) / hsize

        dt = datetime.datetime.now() - start_time
        t = format_time(dt.total_seconds())

        template = "Epoch: {:03d}/{:d} | Loss: {:.4f} | Episodes: {:d} | Win count: {:d} | Win rate: {:.3f} | time: {}"
        print(template.format(epoch, n_epoch - 1, loss, n_episodes, sum(win_history), win_rate, t))
        # we simply check if training has exhausted all free cells and if in all
        # cases the agent won
        if win_rate > 0.9:
            epsilon = 0.05
        train_max = 192
        # print(sum(win_history[-192*1.5:]))
        # print(192)
        if sum(win_history[-train_max:]) >= train_max:
            print("Reached 100%% win rate at epoch: %d" % (epoch,))
            break

    # Save trained model weights and architecture, this will be used by the visualization code
    h5file = name + ".h5"
    json_file = name + ".json"
    model.save_weights(h5file, overwrite=True)
    with open(json_file, "w") as outfile:
        json.dump(model.to_json(), outfile)
    end_time = datetime.datetime.now()
    dt = datetime.datetime.now() - start_time
    seconds = dt.total_seconds()
    t = format_time(seconds)
    print('files: %s, %s' % (h5file, json_file))
    print("n_epoch: %d, max_mem: %d, data: %d, time: %s" % (epoch, max_memory, data_size, t))
    return seconds

# This is a small utility for printing readable time strings:
def format_time(seconds):
    if seconds < 400:
        s = float(seconds)
        return "%.1f seconds" % (s,)
    elif seconds < 4000:
        m = seconds / 60.0
        return "%.2f minutes" % (m,)
    else:
        h = seconds / 3600.0
        return "%.2f hours" % (h,)

def build_model(maze, lr=0.001):
    model = Sequential()
    model.add(Dense(maze.size, input_shape=(maze.size,)))
    model.add(PReLU())
    model.add(Dense(maze.size))
    model.add(PReLU())
    model.add(Dense(num_actions))
    model.compile(optimizer='adam', loss='mse')
    return model

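# The network maps the flattened canvas (maze.size inputs, 256 for the 16x16
# grid) through two PReLU hidden layers of the same width to num_actions (4)
# linear Q-value outputs, trained with MSE loss. The lr argument is currently
# unused because the optimizer is passed as the string 'adam'.
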
class Table:
    def __init__(self, coordinate_i, coordinate_j):
        self.coordinate_i = coordinate_i
        self.coordinate_j = coordinate_j
        change_value(coordinate_i, coordinate_j, 2, 0.)

    def get_destination_coor(self):
        return [self.coordinate_i, self.coordinate_j - 1]


class Kitchen:
    def __init__(self, coordinate_i, coordinate_j):
        self.coordinate_i = coordinate_i
        self.coordinate_j = coordinate_j
        change_value(coordinate_i, coordinate_j, 3, 0.)

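# Both constructors rely on change_value() and the module-level grid defined
# inside the __main__ block below: instantiating a Table blocks a 2x2 square
# of cells and a Kitchen a 3x3 square, so these objects must only be created
# after grid and change_value exist.
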
if __name__ == "__main__":

    def change_value(i, j, width, n):
        for r in range(i, i + width):
            for c in range(j, j + width):
                grid[r][c] = n

    grid = [[1 for x in range(16)] for y in range(16)]
    table1 = Table(2, 2)
    table2 = Table(2, 7)
    table3 = Table(2, 12)
    table4 = Table(7, 2)
    table5 = Table(7, 7)
    table6 = Table(7, 12)
    table7 = Table(12, 2)
    table8 = Table(12, 7)

    kitchen = Kitchen(13, 13)
    maze = np.array(grid)

    # print(maze)
    # maze = np.array([
    #     [ 1.,  0.,  1.,  1.,  1.,  1.,  1.,  1.],
    #     [ 1.,  1.,  1.,  0.,  0.,  1.,  0.,  1.],
    #     [ 1.,  1.,  1.,  1.,  1.,  1.,  0.,  1.],
    #     [ 1.,  1.,  1.,  1.,  0.,  0.,  1.,  1.],
    #     [ 1.,  0.,  0.,  0.,  1.,  1.,  1.,  1.],
    #     [ 1.,  0.,  1.,  1.,  1.,  1.,  1.,  1.],
    #     [ 1.,  1.,  1.,  0.,  1.,  1.,  1.,  1.]
    # ])
    # print(maze)

    # qmaze = Qmaze(maze)
    # show(qmaze)

    model = build_model(maze)
    # The epoch count is passed as n_epoch so qtrain actually picks it up
    # (it reads opt.get('n_epoch', 15000)).
    qtrain(model, maze, n_epoch=1000, max_memory=8 * maze.size, data_size=32)
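
# Illustrative helpers (a minimal sketch, assuming the training above has
# produced the default artifacts "model.json", "model.h5" and "res.data"):
# they rebuild the saved network and read back the pickled best route. They
# are defined only and never called here.
def load_trained_model(name="model"):
    from keras.models import model_from_json
    # model.to_json() was written through json.dump, so json.load returns the
    # architecture string expected by model_from_json.
    with open(name + ".json") as f:
        loaded = model_from_json(json.load(f))
    loaded.load_weights(name + ".h5")
    loaded.compile(optimizer='adam', loss='mse')
    return loaded


def load_best_route(path="res.data"):
    # The route is the list of visited cells dumped when a win was recorded.
    with open(path, "rb") as f:
        return pickle.load(f)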