Computer_Vision/Chapter16/Deep_Q_Learning_Cart_Pole_balancing.ipynb

{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Deep_Q_Learning_Cart_Pole_balancing.ipynb",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/PacktPublishing/Hands-On-Computer-Vision-with-PyTorch/blob/master/Chapter16/Deep_Q_Learning_Cart_Pole_balancing.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "EAsxgBMbnOlT"
},
"source": [
"import gym\n",
"import numpy as np\n",
"import cv2\n",
"from collections import deque\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import numpy as np\n",
"import random\n",
"from collections import namedtuple, deque\n",
"import torch\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "3ANvz1Z_nWDz"
},
"source": [
"env = gym.make('CartPole-v1')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "NXyrPGJSnXpp"
},
"source": [
"class DQNetwork(nn.Module):\n",
" def __init__(self, state_size, action_size):\n",
" super(DQNetwork, self).__init__()\n",
" \n",
" self.fc1 = nn.Linear(state_size, 24)\n",
" self.fc2 = nn.Linear(24, 24)\n",
" self.fc3 = nn.Linear(24, action_size)\n",
" \n",
" def forward(self, state): \n",
" x = F.relu(self.fc1(state))\n",
" x = F.relu(self.fc2(x))\n",
" x = self.fc3(x)\n",
" return x\n"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Fxxt2zRGnczv"
},
"source": [
"class Agent():\n",
" def __init__(self, state_size, action_size):\n",
" \n",
" self.state_size = state_size\n",
" self.action_size = action_size\n",
" self.seed = random.seed(0)\n",
"\n",
" ## hyperparameters\n",
" self.buffer_size = 2000\n",
" self.batch_size = 64\n",
" self.gamma = 0.99\n",
" self.lr = 0.0025\n",
" self.update_every = 4 \n",
"\n",
" # Q-Network\n",
" self.local = DQNetwork(state_size, action_size).to(device)\n",
" self.optimizer = optim.Adam(self.local.parameters(), lr=self.lr)\n",
"\n",
" # Replay memory\n",
" self.memory = deque(maxlen=self.buffer_size) \n",
" self.experience = namedtuple(\"Experience\", field_names=[\"state\", \"action\", \"reward\", \"next_state\", \"done\"])\n",
" self.t_step = 0\n",
" def step(self, state, action, reward, next_state, done):\n",
" # Save experience in replay memory\n",
" self.memory.append(self.experience(state, action, reward, next_state, done)) \n",
" # Learn every update_every time steps.\n",
" self.t_step = (self.t_step + 1) % self.update_every\n",
" if self.t_step == 0:\n",
" # If enough samples are available in memory, get random subset and learn\n",
" if len(self.memory) > self.batch_size:\n",
" experiences = self.sample_experiences()\n",
" self.learn(experiences, self.gamma)\n",
" def act(self, state, eps=0.):\n",
" # Epsilon-greedy action selection\n",
" if random.random() > eps:\n",
" state = torch.from_numpy(state).float().unsqueeze(0).to(device)\n",
" self.local.eval()\n",
" with torch.no_grad():\n",
" action_values = self.local(state)\n",
" self.local.train()\n",
" return np.argmax(action_values.cpu().data.numpy())\n",
" else:\n",
" return random.choice(np.arange(self.action_size))\n",
" def learn(self, experiences, gamma): \n",
" states, actions, rewards, next_states, dones = experiences\n",
" # Get expected Q values from local model\n",
" Q_expected = self.local(states).gather(1, actions)\n",
"\n",
" # Get max predicted Q values (for next states) from local model\n",
" Q_targets_next = self.local(next_states).detach().max(1)[0].unsqueeze(1)\n",
" # Compute Q targets for current states \n",
" Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))\n",
" \n",
" # Compute loss\n",
" loss = F.mse_loss(Q_expected, Q_targets)\n",
"\n",
" # Minimize the loss\n",
" self.optimizer.zero_grad()\n",
" loss.backward()\n",
" self.optimizer.step()\n",
" def sample_experiences(self):\n",
" experiences = random.sample(self.memory, k=self.batch_size) \n",
" states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device)\n",
" actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device)\n",
" rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)\n",
" next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(device)\n",
" dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device) \n",
" return (states, actions, rewards, next_states, dones)\n",
"agent = Agent(env.observation_space.shape[0], env.action_space.n)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Ov7oTYsYnowd"
},
"source": [
"scores = [] # list containing scores from each episode\n",
"scores_window = deque(maxlen=100) # last 100 scores\n",
"n_episodes=5000\n",
"max_t=5000\n",
"eps_start=1.0\n",
"eps_end=0.001\n",
"eps_decay=0.9995\n",
"eps = eps_start"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "mHOK_FWapNn7",
"outputId": "9a7484a8-2a53-4647-ec5d-d4249ade4ed1",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 893
}
},
"source": [
"for i_episode in range(1, n_episodes+1):\n",
" state = env.reset()\n",
" state_size = env.observation_space.shape[0]\n",
" state = np.reshape(state, [1, state_size])\n",
" score = 0\n",
" for i in range(max_t):\n",
" action = agent.act(state, eps)\n",
" next_state, reward, done, _ = env.step(action)\n",
" next_state = np.reshape(next_state, [1, state_size])\n",
" reward = reward if not done or score == 499 else -10\n",
" agent.step(state, action, reward, next_state, done)\n",
" state = next_state\n",
" score += reward\n",
" if done:\n",
" break \n",
" scores_window.append(score) # save most recent score \n",
" scores.append(score) # save most recent score\n",
" eps = max(eps_end, eps_decay*eps) # decrease epsilon\n",
" print('\\rEpisode {}\\tReward {} \\tAverage Score: {:.2f} \\tEpsilon: {}'.format(i_episode,score,np.mean(scores_window), eps), end=\"\")\n",
" if i_episode % 100 == 0:\n",
" print('\\rEpisode {}\\tAverage Score: {:.2f} \\tEpsilon: {}'.format(i_episode, np.mean(scores_window), eps))\n",
" if i_episode>10 and np.mean(scores[-10:])>450:\n",
" break\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"plt.plot(scores)\n",
"plt.title('Scores over increasing episodes')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Episode 100\tAverage Score: 11.83 \tEpsilon: 0.951217530242334\n",
"Episode 200\tAverage Score: 16.20 \tEpsilon: 0.9048147898403269\n",
"Episode 300\tAverage Score: 15.99 \tEpsilon: 0.8606756897186528\n",
"Episode 400\tAverage Score: 23.94 \tEpsilon: 0.8186898039137951\n",
"Episode 500\tAverage Score: 26.66 \tEpsilon: 0.7787520933134615\n",
"Episode 600\tAverage Score: 31.67 \tEpsilon: 0.7407626428726788\n",
"Episode 700\tAverage Score: 27.90 \tEpsilon: 0.7046264116491338\n",
"Episode 800\tAverage Score: 39.22 \tEpsilon: 0.6702529950324074\n",
"Episode 900\tAverage Score: 31.03 \tEpsilon: 0.637556398572254\n",
"Episode 1000\tAverage Score: 47.30 \tEpsilon: 0.606454822840097\n",
"Episode 1100\tAverage Score: 30.47 \tEpsilon: 0.5768704587855094\n",
"Episode 1200\tAverage Score: 26.48 \tEpsilon: 0.548729293075715\n",
"Episode 1300\tAverage Score: 25.61 \tEpsilon: 0.5219609229311034\n",
"Episode 1400\tAverage Score: 20.64 \tEpsilon: 0.49649837999353363\n",
"Episode 1500\tAverage Score: 25.36 \tEpsilon: 0.4722779627867691\n",
"Episode 1600\tAverage Score: 34.30 \tEpsilon: 0.44923907734991153\n",
"Episode 1700\tAverage Score: 28.26 \tEpsilon: 0.4273240856451275\n",
"Episode 1800\tAverage Score: 43.74 \tEpsilon: 0.406478161360422\n",
"Episode 1900\tAverage Score: 42.58 \tEpsilon: 0.3866491527467055\n",
"Episode 2000\tAverage Score: 36.56 \tEpsilon: 0.3677874521460121\n",
"Episode 2100\tAverage Score: 52.12 \tEpsilon: 0.34984587188445015\n",
"Episode 2200\tAverage Score: 56.85 \tEpsilon: 0.3327795262194029\n",
"Episode 2300\tAverage Score: 65.41 \tEpsilon: 0.31654571904563433\n",
"Episode 2400\tAverage Score: 186.93 \tEpsilon: 0.3011038370793723\n",
"Episode 2500\tAverage Score: 212.44 \tEpsilon: 0.28641524825313086\n",
"Episode 2600\tAverage Score: 246.87 \tEpsilon: 0.27244320506708813\n",
"Episode 2700\tAverage Score: 94.46 \tEpsilon: 0.25915275265522114\n",
"Episode 2800\tAverage Score: 143.46 \tEpsilon: 0.24651064133620196\n",
"Episode 2900\tAverage Score: 102.21 \tEpsilon: 0.23448524343027585\n",
"Episode 3000\tAverage Score: 56.00 \tEpsilon: 0.22304647413401948\n",
"Episode 3100\tAverage Score: 43.23 \tEpsilon: 0.21216571625502262\n",
"Episode 3200\tAverage Score: 56.39 \tEpsilon: 0.2018157486181985\n",
"Episode 3300\tAverage Score: 103.49 \tEpsilon: 0.1919706779646106\n",
"Episode 3400\tAverage Score: 118.43 \tEpsilon: 0.1826058741724434\n",
"Episode 3500\tAverage Score: 144.03 \tEpsilon: 0.17369790863805412\n",
"Episode 3509\tReward 500.0 \tAverage Score: 173.83 \tEpsilon: 0.17291782950789983"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Text(0.5, 1.0, 'Scores over increasing episodes')"
]
},
"metadata": {
"tags": []
},
"execution_count": 8
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAEICAYAAACktLTqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO2dd5xcVfn/389ueiGdkJ4AoUkJIQQCSBepRkG+Un4SEUXFhoASVBBUNIiKIr0XqVIDBEJJQk+FlA1JSO892U3dZLN7fn/cM7t3Zqfcmbkz987d5/167WtvOXPPM2fu/ZznPqeJMQZFURQlWpQFbYCiKIriPyruiqIoEUTFXVEUJYKouCuKokQQFXdFUZQIouKuKIoSQVTclcggIveJyI1B2+EFEekrIttEpDxoW9wUogxF5Hsi8pGf11Qy0yxoA5TMiMgJwN+ArwC1wBzgamPMlEANCxnGmB8HbYNXjDHLgHZB25FIKZWhkh4V95AjInsBrwM/AZ4HWgBfBXb5nE+5MabWz2sWEhFpZozZE5V8FMVvNCwTfg4AMMY8Y4ypNcbsNMa8bYyZGUsgIj8UkTkislVEvhCRwfb4wSIyQUQqRWS2iHzD9ZnHROReERkjItuBU0Skp4i8KCLrRWSxiPzClX6oiEwVkS0islZE/pnKYGvPAhHZJCKjRaSnPX6viPw9Ie2rInKN3U6X/80i8oKI/FdEtgDfS5LvYyLyZ7t9soisEJFrRWSdiKwWkctdaVuLyD9EZKmIVInIR/ZYfxExInKFiCwDxtn037dlvFlExopIP9e1/i0iy23ZTBORr2YqN1c+zez+BBH5k4h8bH/Ht0Wkq+s6l1lbN4rIjSKyREROT1H+LUXk7yKyzOZ5n4i0TiiX34rIBnudS1OUYVcRed3eP5tE5EMRKbPn0t1bXezvvkVEJgP7Jdh3kIi8Y685T0T+z3XubHsPbxWRlSJyXbLvqHjAGKN/If4D9gI2Ao8DZwGdEs5fCKwEjgYE2B/oBzQHFgC/xfH2TwW2Agfazz0GVAHH41TybYBpwE02/b7AIuDrNv2nwHftdjvg2BT2ngpsAAYDLYH/AB/YcycCywGx+52AnUBPa0O6/G8GaoBv2rStk+T9GPBnu30ysAf4oy2Ls4EdsfID7gYmAL2AcuA4a29/wABPAG2B1sBwW5YH47zt/h74xJXv/wO62HPXAmuAVunKzZVPM7s/AViIU5m3tvuj7LlDgG3ACbZs/m7L4vQUv8EdwGigM9AeeA34a0K5/NN+35OA7cTfF7Ey/Ctwny2/5jhvjELme+tZnLfMtsChOPfnR/ZcW5x74HJbXkfi3C+H2POrga+67o/BQT+DpfoXuAH65+FHckTlMWCFfTBHA93tubHAL5N85qtWZMpcx54BbrbbjwFPuM4dAyxLuMYNwKN2+wPgFqBrBlsfBv7m2m9nhai/FYZlwIn23A+BcR7zvxlbSaTJ2y1MJ+NUHM1c59cBx+JUDjuBI5Jcoz+O6O7rOvYmcIVrvwynouiXwo7NsWunKjeSi/vvXeevAt6y2zcBz7jOtQF2k0TcbRlvB/ZzHRsGLHaVyx6grev888CNScrwj8CrwP5e7y2cirIGOMh17i80iPt3gA8Trnc/8Ae7vQz4EbBX0M9dqf9pWKYEMMbMMcZ8zxjTG8cT6gn8y57ug+PxJdITWG6MqXMdW4rjqcZY7truB/S0r9mVIlKJ45l1t+evwPEq54rIFBE5N4W5PW0+Mdu34bx59DLO0/sscLE9fQnwlMf8E+31wkYTHy/fgVPZdAVakbzckuXVD/i3y65NOCLaC0BErrMhmyp7voPNA7yXGziCmWgr2N8ydsIYswOnTJPRDfsW5rL3LXs8xmZjzHbX/lKbRyK343job4vIIhEZ6bYnxb3VDccjX55wLkY/4JiE3/lSYB97/gKct6ylIvK+iAxL8T2VDGiDaolhjJkrIo/heDfgPET7JUm6CugjImWuh7Av8KX7cq7t5Tje3cAU+c4HLrYx1/OBF0SkS4JIxPJ1x6Pb4oQsVtpDz+CIxSgcb/1bXvJPYm8+bACqccpthoe8lgO3GmOeSkxk4+u/AU4DZhtj6kRkM474pyy3LO1dDRzoyrM1TpkmYwPOW8lXjDErU6TpJCJtXb9dX6AiMZExZitOmOlaETkUGCciU0h/b63HeTPoA8x1nYuxHHjfGPO1ZIYZpwfYcBFpDvwM562iT4rvoaRBPfeQYxufrhWR3na/D47nO9EmeQi4TkSOEof9bWPfJBzv7zci0lxETgbOw/GckzEZ2Coi14vTsFguIoeKyNE23/8nIt3sw1xpP1OX5DrPAJeLyCARaYnzSj7JGLMEwBjzOY4APQSMNcbErpU2fz+x3+ER4J/iNOKWi8gwa28y7gNuEJGvAIhIBxG50J5rjyNm64FmInITTjsJNq3XckvHC8B5InKciLTACX9Imu/2IHCHiOxtbeglIl9PSHqLiLSwldO5wP8SryUi59r7SXDaZ2qt7SnvLeP0uHoJuFlE2ojIIcAI12VfBw4Qke/azzYXkaNtA20LEblURDoYY2qALTmUlWJRcQ8/W3E83Eni9GqZiONlXQtgjPkfcCvwtE37CtDZGLMb54E7C0dM7wEuM8bMbZSDc51anId8ELCYBgHuYJOcCcwWkW3Av4GLjDE7k1znXeBG4EUcj3M/4KKEZE8Dp9v/XvP3m+uAWcAUnDDLbaR4HowxL9vzz4rTU6cCp1zBafN4C8drXYrzRuAOSXgqt3QYY2YDP8epmFfjNK6uI3V32OtxwikTrb3v4vL8ccI/m3E88KeAH6e4Lwbaz27DaRi+xxgz3sO99TOckNIanBj+o67vshU4A+eeWGXT3IbTuAvwXWCJtfvHOCEbJQdivRYURSkRRKQdzlvAQGPM4iw/ezLwX9t+o0QY9dwVpQQQkfNsmKMtTlfIWcCSYK1SwoyKu6KUBsNxwhircMIlFxl97VbSoGEZRVGUCKKeu6IoSgQJRT/3rl27mv79+wdthqIoSkkxbdq0DcaYbsnOhULc+/fvz9SpU4M2Q1EUpaQQkaWpzmlYRlEUJYKouCuKokQQFXdFUZQIouKuKIoSQVTcFUVRIogncRdnKa5ZIjJdRKbaY53FWSprvv3fyR4XEblTnGXWZopd8k1RFEUpHtl47qcYYwYZY4bY/ZHAe3b+7ffsPjgzxQ20f1cC9/plrKIoiuKNfPq5D8dZsguc9T0n4Ew1Ohxn+TaDM+VoRxHpYYxZnY+hiqLkzs7dtbxZsZo2Lco5un9nurRryduz1zBt6WaenLiULu1acNA+e1GxsoqHRxzNwT3ac+F9nzJ16WYAWjUvY+iALqzbUs1N5x7Cr56fzjEDunDC/l254KjelJclnV6+Ec9PWc7OmlpGHNe//tgrn6+kY5vmnHzg3oX46mmprTO8OG0F5w/uRbPywkap36pYTac2LRg/bz3PT13Opu27OWif9tx2weEc0aej7/l5FXeDs3qOAe43xjyAs
4ZnTLDX0LAcWi/i57NeYY/FibuIXInj2dO3r3uhFkVR/ObPb3zBU5OWAXBk347870fDuPLJafXnd2zayfJNzjTzZ9/5IR9df0q9sANU19TxwZfrAbjkoUkAjJ6xitEzVlFTV8elx9QvvpWW37w4EyBO3K9+bjoAS0adk+O3y52nJy/jxlcq2LprD1ecMKBg+VTtrOHH//2s0fG5a7YyZcmmQMX9BGPMSruyyzsiEjexvzHGWOH3jK0gHgAYMmSIzl6mKAVk7ZaGdT1WV1ZnXK9wT633R7JyR02OVgXP5u27Aajcsbug+eypTb2g1HeOLswqgp7eQ2JrMRpj1gEvA0OBtSLSA8D+X2eTryR+zcPeNKyfqShKwHgNoSilTUZxF5G2ItI+to2zRFYFMJqGtRFHAK/a7dHAZbbXzLFAlcbbFSVYxKXnzcozi3s2r9KSR10R9JTjUZ7x3EtYpjvwsrNGLs2Ap40xb9lV0J8XkStw1o78P5t+DHA2zhqOO4DLfbdaUZScKc9HjZWck
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
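{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal evaluation sketch (added for illustration, not part of the original notebook): it runs a few greedy episodes with `eps=0` using the trained `agent`, assuming the same old-style `gym` API as the training loop above (`env.reset()` returns the state, `env.step()` returns a 4-tuple)."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Hedged example: greedy evaluation of the trained agent (assumes the cells above have been run)\n",
"eval_scores = []\n",
"for _ in range(5):\n",
"    state = np.reshape(env.reset(), [1, state_size])\n",
"    total = 0\n",
"    for _ in range(500):  # CartPole-v1 caps episodes at 500 steps\n",
"        action = agent.act(state, eps=0.0)  # always exploit the learned Q-values\n",
"        next_state, reward, done, _ = env.step(action)\n",
"        state = np.reshape(next_state, [1, state_size])\n",
"        total += reward\n",
"        if done:\n",
"            break\n",
"    eval_scores.append(total)\n",
"print('Greedy evaluation scores:', eval_scores)"
],
"execution_count": null,
"outputs": []
},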
{
"cell_type": "code",
"metadata": {
"id": "CCoDl_OVpWhl"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}