3d gradient descent plot

This commit is contained in:
patrycjalazna 2021-06-28 00:02:15 +02:00
parent d37090af7f
commit 067212226b

View File

@ -582,6 +582,92 @@
"plot_cost_function(logs_5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Wizualizacja metody najszybszego spadku dla wielomianu drugiego stopnia"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"degree = 1\n",
"initial_theta = np.matrix([0] * (degree + 1)).reshape(degree + 1, 1)\n",
"m, n_plus_1 = data_matrix.shape\n",
"n = n_plus_1 - 1\n",
"X = (np.ones((m, 1)))\n",
"\n",
"for i in range(1, degree + 1):\n",
" Xn = np.power(data_matrix[:, 0:n], i)\n",
" Xn /= np.amax(Xn, axis=0)\n",
" X = np.concatenate((X, Xn), axis=1)\n",
"\n",
"X = np.matrix(X).reshape(m, degree * n + 1)\n",
"Y = np.matrix(data_matrix[:, -1])\n",
"\n",
"\n",
"steepest_descent_deg_2, logs_deg_2 = steepest_descent(X, Y, initial_theta, epochs = 60)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cost_history = [row[0] for row in logs_deg_2]\n",
"all_thetas = [row[1] for row in logs_deg_2]\n",
"theta0_history = [row[0].item() for row in all_thetas]\n",
"theta1_history = [row[1].item() for row in all_thetas]\n",
"\n",
"cost_history = np.array(cost_history)\n",
"theta0_history = np.array(theta0_history)\n",
"theta1_history = np.array(theta1_history)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"theta0_vals = theta0_history\n",
"theta1_vals = theta1_history\n",
"J_vals = np.zeros((len(theta0_vals), len(theta1_vals)))\n",
"\n",
"c1=0\n",
"c2=0\n",
"pom = 0\n",
"for i in theta0_vals:\n",
" for j in theta1_vals:\n",
" t = np.array([i, j])\n",
" J_vals[c1][c2] = cost_history[pom]\n",
" c2=c2+1\n",
" c1=c1+1\n",
" pom += 1\n",
" c2=0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import plotly.graph_objects as go\n",
"\n",
"fig = go.Figure(data=[go.Surface(x=theta0_vals, y=theta1_vals, z=J_vals)])\n",
"fig.update_layout(title='Loss function for different thetas', autosize=True,\n",
" width=600, height=600, xaxis_title='theta0', \n",
" yaxis_title='theta1')\n",
"fig.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},