159 lines
5.9 KiB
Plaintext
159 lines
5.9 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2020-11-27T10:33:47.924609Z",
|
||
|
"start_time": "2020-11-27T10:33:47.300615Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"%load_ext autoreload\n",
|
||
|
"%autoreload 2\n",
|
||
|
"from model import DQNetworkImageSensor\n",
|
||
|
"from actor import Actor\n",
|
||
|
"from torch_snippets import *\n",
|
||
|
"from collections import deque"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2020-11-27T10:39:56.962607Z",
|
||
|
"start_time": "2020-11-27T10:39:34.850120Z"
|
||
|
},
|
||
|
"code_folding": []
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import gym\n",
|
||
|
"import gym_carla\n",
|
||
|
"import carla\n",
|
||
|
"# Configuration passed to the gym-carla environment.\n",
"params = {\n",
"    'number_of_vehicles': 10,\n",
"    'number_of_walkers': 0,\n",
"    'display_size': 384,  # screen size of bird-eye render\n",
"    'max_past_step': 1,  # the number of past steps to draw\n",
"    'dt': 0.1,  # time interval between two frames\n",
"    'discrete': True,  # whether to use discrete control space\n",
"    'discrete_acc': [-3.0, 0, 3],  # discrete value of accelerations\n",
"    'discrete_steer': [-0.2, 0.0, 0.2],  # discrete value of steering angles\n",
"    'continuous_accel_range': [-3.0, 3.0],  # continuous acceleration range\n",
"    'continuous_steer_range': [-0.3, 0.3],  # continuous steering angle range\n",
"    'ego_vehicle_filter': 'vehicle.lincoln*',  # filter for defining ego vehicle\n",
"    'port': 2000,  # connection port\n",
"    'town': 'Town03',  # which town to simulate\n",
"    'task_mode': 'random',  # mode of the task, [random, roundabout (only for Town03)]\n",
"    'max_time_episode': 1000,  # maximum timesteps per episode\n",
"    'max_waypt': 12,  # maximum number of waypoints\n",
"    'obs_range': 32,  # observation range (meter)\n",
"    'lidar_bin': 0.125,  # bin size of lidar sensor (meter)\n",
"    'd_behind': 12,  # distance behind the ego vehicle (meter)\n",
"    'out_lane_thres': 2.0,  # threshold for out of lane\n",
"    'desired_speed': 8,  # desired speed (m/s)\n",
"    'max_ego_spawn_times': 200,  # maximum times to spawn ego vehicle\n",
"    'display_route': True,  # whether to render the desired route\n",
"    'pixor_size': 64,  # size of the pixor labels\n",
"    'pixor': False,  # whether to output PIXOR observation\n",
"}\n",
"\n",
"# Set gym-carla environment\n",
"env = gym.make('carla-v0', params=params)\n",
"\n",
"\n",
"def preprocess(im):\n",
"    \"\"\"Move the last axis of a 3-D array to the front and divide by 255.\n",
"\n",
"    Presumably this turns an HWC uint8 image into CHW floats in [0, 1] --\n",
"    TODO confirm against the observation format returned by the environment.\n",
"    The torch equivalent is `torch.as_tensor(im).permute(2, 0, 1) / 255.`\n",
"    \"\"\"\n",
"    return im.transpose(2, 0, 1) / 255.\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2020-11-27T10:42:04.437292Z",
|
||
|
"start_time": "2020-11-27T10:39:56.964484Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Checkpoint to resume from (set to None to start from scratch) and where to save.\n",
"load_path = 'fast-car-v2.pth'\n",
"save_path = 'fast-car-v2.1.pth'\n",
"\n",
"actor = Actor()\n",
"if load_path is not None:\n",
"    # Read the checkpoint from disk once and initialise both networks from it;\n",
"    # load_state_dict copies the values into each module.\n",
"    checkpoint = torch.load(load_path)\n",
"    actor.qnetwork_local.load_state_dict(checkpoint)\n",
"    actor.qnetwork_target.load_state_dict(checkpoint)\n",
"\n",
"n_episodes = 1000\n",
"log = Report(n_episodes)\n",
|
||
|
"def _to_state_dict(observation):\n",
"    \"\"\"Build the actor's input dict from a raw environment observation.\n",
"\n",
"    Reads the 'camera', 'lidar' and 'state' entries of `observation`, runs\n",
"    `preprocess` on the first two and returns all three under the keys\n",
"    'image', 'lidar' and 'sensor'.\n",
"    \"\"\"\n",
"    return {\n",
"        'image': preprocess(observation['camera']),\n",
"        'lidar': preprocess(observation['lidar']),\n",
"        'sensor': observation['state'],\n",
"    }\n",
"\n",
"\n",
"def dqn(n_episodes=n_episodes, max_t=1000, eps_start=0.1, eps_end=0.01, eps_decay=0.995):\n",
"    \"\"\"Run the training loop of the global `actor` on the global `env`.\n",
"\n",
"    Args:\n",
"        n_episodes: number of episodes to run.\n",
"        max_t: maximum number of environment steps per episode.\n",
"        eps_start: initial epsilon passed to `actor.act`.\n",
"        eps_end: lower bound for epsilon.\n",
"        eps_decay: factor epsilon is multiplied by after every episode.\n",
"\n",
"    Returns:\n",
"        List with the total reward of every episode, in order.\n",
"\n",
"    Side effects:\n",
"        Records progress through `log` and writes the state dict of\n",
"        `actor.qnetwork_local` to `save_path` every 100 episodes and after\n",
"        the last episode.\n",
"    \"\"\"\n",
"    scores = []                        # total reward of every episode\n",
"    scores_window = deque(maxlen=100)  # last 100 scores\n",
"    eps = eps_start                    # initialize epsilon\n",
"    for i_episode in range(1, n_episodes+1):\n",
"        state_dict = _to_state_dict(env.reset())\n",
"        score = 0\n",
"        for t in range(max_t):\n",
"            action = actor.act(state_dict, eps)\n",
"            next_state, reward, done, _ = env.step(action)\n",
"            next_state_dict = _to_state_dict(next_state)\n",
"            actor.step(state_dict, action, reward, next_state_dict, done)\n",
"            state_dict = next_state_dict\n",
"            score += reward\n",
"            if done:\n",
"                break\n",
"        scores_window.append(score)  # save most recent score\n",
"        scores.append(score)\n",
"        eps = max(eps_end, eps_decay*eps)  # decrease epsilon\n",
"        log.record(i_episode, score=score, end='\\r')\n",
"        # Checkpoint periodically and always after the final episode, so a run\n",
"        # whose length is not a multiple of 100 still saves its last weights.\n",
"        if i_episode % 100 == 0 or i_episode == n_episodes:\n",
"            log.record(i_episode, mean_score=np.mean(scores_window))\n",
"            torch.save(actor.qnetwork_local.state_dict(), save_path)\n",
"    return scores\n",
"\n",
"dqn()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.7.6"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 4
|
||
|
}
|