{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "Operations_on_tensors.ipynb", "provenance": [], "include_colab_link": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:27:01.577164Z", "start_time": "2020-09-25T19:27:01.286695Z" }, "id": "TqJn6S6MaXSJ", "outputId": "365578bb-f483-40f4-d08f-adb8c1982840", "colab": { "base_uri": "https://localhost:8080/", "height": 51 } }, "source": [ "import torch\n", "x = torch.tensor([[1,2,3,4], [5,6,7,8]]) \n", "print(x * 10)" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "tensor([[10, 20, 30, 40],\n", " [50, 60, 70, 80]])\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:27:01.581350Z", "start_time": "2020-09-25T19:27:01.578406Z" }, "id": "oPoA4yptaY2N", "outputId": "89d72f4b-71b8-45ae-9c50-4fb1e73ae6c0", "colab": { "base_uri": "https://localhost:8080/", "height": 51 } }, "source": [ "x = torch.tensor([[1,2,3,4], [5,6,7,8]]) \n", "y = x.add(10)\n", "print(y)" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "tensor([[11, 12, 13, 14],\n", " [15, 16, 17, 18]])\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:27:01.586320Z", "start_time": "2020-09-25T19:27:01.582720Z" }, "id": "fHmRXqMcadet" }, "source": [ "y = torch.tensor([2, 3, 1, 0]) # y.shape == (4,)\n", "y = y.view(4,1) # y.shape == (4, 1)" ], "execution_count": null, "outputs": [] }, { 
"cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:27:01.591222Z", "start_time": "2020-09-25T19:27:01.587669Z" }, "id": "rr5Gs-QMaf-H", "outputId": "28504802-789e-4372-83f8-feff86ed66e0", "colab": { "base_uri": "https://localhost:8080/", "height": 51 } }, "source": [ "x = torch.randn(10,1,10)\n", "z1 = torch.squeeze(x, 1) # similar to np.squeeze()\n", "# The same operation can be performed directly on\n", "# x by calling squeeze with the dimension to squeeze out\n", "z2 = x.squeeze(1)\n", "assert torch.all(z1 == z2) # all the elements in both tensors are equal\n", "print('Squeeze:\\n', x.shape, z1.shape)" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "Squeeze:\n", " torch.Size([10, 1, 10]) torch.Size([10, 10])\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:27:01.596220Z", "start_time": "2020-09-25T19:27:01.592251Z" }, "id": "jnIQNMH5ajlF", "outputId": "f53812a0-2391-4746-9af9-22b29eb8867a", "colab": { "base_uri": "https://localhost:8080/", "height": 68 } }, "source": [ "x = torch.randn(10,10)\n", "print(x.shape)\n", "# torch.Size([10, 10])\n", "z1 = x.unsqueeze(0)\n", "print(z1.shape)\n", "# torch.Size([1, 10, 10])\n", "# The same can be achieved using [None] indexing\n", "# Indexing with None inserts a new size-1 dim at the\n", "# specified axis\n", "x = torch.randn(10,10)\n", "z2, z3, z4 = x[None], x[:,None], x[:,:,None]\n", "print(z2.shape, z3.shape, z4.shape)" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "torch.Size([10, 10])\n", "torch.Size([1, 10, 10])\n", "torch.Size([1, 10, 10]) torch.Size([10, 1, 10]) torch.Size([10, 10, 1])\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:27:01.600694Z", "start_time": "2020-09-25T19:27:01.597443Z" }, "id": "SWxKXdP6am9D", "outputId": "8cc993b3-d461-4aa6-9bf3-c0f56af01037", "colab": { "base_uri": 
"https://localhost:8080/", "height": 51 } }, "source": [ "x = torch.tensor([[1,2,3,4], [5,6,7,8]])\n", "print(torch.matmul(x, y))" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "tensor([[11],\n", " [35]])\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:27:01.603782Z", "start_time": "2020-09-25T19:27:01.601641Z" }, "id": "VtZmPZOEapyc", "outputId": "170d0165-e8b0-46a8-f01a-1bfa93a93e2c", "colab": { "base_uri": "https://localhost:8080/", "height": 51 } }, "source": [ "print(x@y)" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "tensor([[11],\n", " [35]])\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:27:01.608841Z", "start_time": "2020-09-25T19:27:01.605190Z" }, "id": "al6kKt4dasVv", "outputId": "d0761ad3-ddce-432a-a1be-8284ffef7276", "colab": { "base_uri": "https://localhost:8080/", "height": 51 } }, "source": [ "import torch\n", "x = torch.randn(10,10,10)\n", "z = torch.cat([x,x], axis=0) # np.concatenate()\n", "print('Cat axis 0:', x.shape, z.shape)\n", "# Cat axis 0: torch.Size([10, 10, 10]) torch.Size([20, 10, 10])\n", "z = torch.cat([x,x], axis=1) # np.concatenate()\n", "print('Cat axis 1:', x.shape, z.shape)\n", "# Cat axis 1: torch.Size([10, 10, 10]) torch.Size([10, 20, 10])" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "Cat axis 0: torch.Size([10, 10, 10]) torch.Size([20, 10, 10])\n", "Cat axis 1: torch.Size([10, 10, 10]) torch.Size([10, 20, 10])\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:27:01.612500Z", "start_time": "2020-09-25T19:27:01.609931Z" }, "id": "vv1DtZ2qb_qu", "outputId": "bafdaba7-c0b5-4b8d-a7bb-beb20b8cfd7e", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "x = torch.arange(25).reshape(5,5)\n", "print('Max:', x.shape, x.max()) " 
], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "Max: torch.Size([5, 5]) tensor(24)\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:27:01.620199Z", "start_time": "2020-09-25T19:27:01.613427Z" }, "id": "DO2nx2glcNPQ", "outputId": "1b7fbdb5-1f41-4bd7-8c77-3b494e90161b", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "x.max(dim=0)" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "torch.return_types.max(\n", "values=tensor([20, 21, 22, 23, 24]),\n", "indices=tensor([4, 4, 4, 4, 4]))" ] }, "metadata": { "tags": [] }, "execution_count": 10 } ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:27:01.624220Z", "start_time": "2020-09-25T19:27:01.621298Z" }, "id": "3O-_2LwQcOv6", "outputId": "5b84364e-3453-4265-8e10-59ede36033e9", "colab": { "base_uri": "https://localhost:8080/", "height": 51 } }, "source": [ "m, argm = x.max(dim=1) \n", "print('Max in axis 1:\\n', m, argm) " ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "Max in axis 1:\n", " tensor([ 4, 9, 14, 19, 24]) tensor([4, 4, 4, 4, 4])\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:27:01.627717Z", "start_time": "2020-09-25T19:27:01.625155Z" }, "id": "0qwAEb6BcQJB", "outputId": "d89be36a-3a97-4f3e-da1a-64cf3ff4cffa", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "x = torch.randn(10,20,30)\n", "z = x.permute(2,0,1) # similar to np.transpose()\n", "print('Permute dimensions:', x.shape, z.shape)" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "Permute dimensions: torch.Size([10, 20, 30]) torch.Size([30, 10, 20])\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:35:03.872412Z", "start_time": 
"2020-09-25T19:35:03.861902Z" }, "id": "mCeCjaZo0arI", "outputId": "f8718838-01ed-4426-9ff1-49f346ffe131" }, "source": [ "dir(torch.Tensor)" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "['T',\n", " '__abs__',\n", " '__add__',\n", " '__and__',\n", " '__array__',\n", " '__array_priority__',\n", " '__array_wrap__',\n", " '__bool__',\n", " '__class__',\n", " '__contains__',\n", " '__cuda_array_interface__',\n", " '__deepcopy__',\n", " '__delattr__',\n", " '__delitem__',\n", " '__dict__',\n", " '__dir__',\n", " '__div__',\n", " '__doc__',\n", " '__eq__',\n", " '__float__',\n", " '__floordiv__',\n", " '__format__',\n", " '__ge__',\n", " '__getattribute__',\n", " '__getitem__',\n", " '__gt__',\n", " '__hash__',\n", " '__iadd__',\n", " '__iand__',\n", " '__idiv__',\n", " '__ifloordiv__',\n", " '__ilshift__',\n", " '__imul__',\n", " '__index__',\n", " '__init__',\n", " '__init_subclass__',\n", " '__int__',\n", " '__invert__',\n", " '__ior__',\n", " '__ipow__',\n", " '__irshift__',\n", " '__isub__',\n", " '__iter__',\n", " '__itruediv__',\n", " '__ixor__',\n", " '__le__',\n", " '__len__',\n", " '__long__',\n", " '__lshift__',\n", " '__lt__',\n", " '__matmul__',\n", " '__mod__',\n", " '__module__',\n", " '__mul__',\n", " '__ne__',\n", " '__neg__',\n", " '__new__',\n", " '__nonzero__',\n", " '__or__',\n", " '__pow__',\n", " '__radd__',\n", " '__rdiv__',\n", " '__reduce__',\n", " '__reduce_ex__',\n", " '__repr__',\n", " '__reversed__',\n", " '__rfloordiv__',\n", " '__rmul__',\n", " '__rpow__',\n", " '__rshift__',\n", " '__rsub__',\n", " '__rtruediv__',\n", " '__setattr__',\n", " '__setitem__',\n", " '__setstate__',\n", " '__sizeof__',\n", " '__str__',\n", " '__sub__',\n", " '__subclasshook__',\n", " '__truediv__',\n", " '__weakref__',\n", " '__xor__',\n", " '_backward_hooks',\n", " '_base',\n", " '_cdata',\n", " '_coalesced_',\n", " '_dimI',\n", " '_dimV',\n", " '_grad',\n", " '_grad_fn',\n", " '_indices',\n", " 
'_is_view',\n", " '_make_subclass',\n", " '_nnz',\n", " '_update_names',\n", " '_values',\n", " '_version',\n", " 'abs',\n", " 'abs_',\n", " 'absolute',\n", " 'absolute_',\n", " 'acos',\n", " 'acos_',\n", " 'acosh',\n", " 'acosh_',\n", " 'add',\n", " 'add_',\n", " 'addbmm',\n", " 'addbmm_',\n", " 'addcdiv',\n", " 'addcdiv_',\n", " 'addcmul',\n", " 'addcmul_',\n", " 'addmm',\n", " 'addmm_',\n", " 'addmv',\n", " 'addmv_',\n", " 'addr',\n", " 'addr_',\n", " 'align_as',\n", " 'align_to',\n", " 'all',\n", " 'allclose',\n", " 'angle',\n", " 'any',\n", " 'apply_',\n", " 'argmax',\n", " 'argmin',\n", " 'argsort',\n", " 'as_strided',\n", " 'as_strided_',\n", " 'as_subclass',\n", " 'asin',\n", " 'asin_',\n", " 'asinh',\n", " 'asinh_',\n", " 'atan',\n", " 'atan2',\n", " 'atan2_',\n", " 'atan_',\n", " 'atanh',\n", " 'atanh_',\n", " 'backward',\n", " 'baddbmm',\n", " 'baddbmm_',\n", " 'bernoulli',\n", " 'bernoulli_',\n", " 'bfloat16',\n", " 'bincount',\n", " 'bitwise_and',\n", " 'bitwise_and_',\n", " 'bitwise_not',\n", " 'bitwise_not_',\n", " 'bitwise_or',\n", " 'bitwise_or_',\n", " 'bitwise_xor',\n", " 'bitwise_xor_',\n", " 'bmm',\n", " 'bool',\n", " 'byte',\n", " 'cauchy_',\n", " 'ceil',\n", " 'ceil_',\n", " 'char',\n", " 'cholesky',\n", " 'cholesky_inverse',\n", " 'cholesky_solve',\n", " 'chunk',\n", " 'clamp',\n", " 'clamp_',\n", " 'clamp_max',\n", " 'clamp_max_',\n", " 'clamp_min',\n", " 'clamp_min_',\n", " 'clone',\n", " 'coalesce',\n", " 'conj',\n", " 'contiguous',\n", " 'copy_',\n", " 'cos',\n", " 'cos_',\n", " 'cosh',\n", " 'cosh_',\n", " 'cpu',\n", " 'cross',\n", " 'cuda',\n", " 'cummax',\n", " 'cummin',\n", " 'cumprod',\n", " 'cumsum',\n", " 'data',\n", " 'data_ptr',\n", " 'deg2rad',\n", " 'deg2rad_',\n", " 'dense_dim',\n", " 'dequantize',\n", " 'det',\n", " 'detach',\n", " 'detach_',\n", " 'device',\n", " 'diag',\n", " 'diag_embed',\n", " 'diagflat',\n", " 'diagonal',\n", " 'digamma',\n", " 'digamma_',\n", " 'dim',\n", " 'dist',\n", " 'div',\n", " 'div_',\n", " 
'dot',\n", " 'double',\n", " 'dtype',\n", " 'eig',\n", " 'element_size',\n", " 'eq',\n", " 'eq_',\n", " 'equal',\n", " 'erf',\n", " 'erf_',\n", " 'erfc',\n", " 'erfc_',\n", " 'erfinv',\n", " 'erfinv_',\n", " 'exp',\n", " 'exp_',\n", " 'expand',\n", " 'expand_as',\n", " 'expm1',\n", " 'expm1_',\n", " 'exponential_',\n", " 'fft',\n", " 'fill_',\n", " 'fill_diagonal_',\n", " 'flatten',\n", " 'flip',\n", " 'fliplr',\n", " 'flipud',\n", " 'float',\n", " 'floor',\n", " 'floor_',\n", " 'floor_divide',\n", " 'floor_divide_',\n", " 'fmod',\n", " 'fmod_',\n", " 'frac',\n", " 'frac_',\n", " 'gather',\n", " 'ge',\n", " 'ge_',\n", " 'geometric_',\n", " 'geqrf',\n", " 'ger',\n", " 'get_device',\n", " 'grad',\n", " 'grad_fn',\n", " 'gt',\n", " 'gt_',\n", " 'half',\n", " 'hardshrink',\n", " 'has_names',\n", " 'histc',\n", " 'ifft',\n", " 'imag',\n", " 'index_add',\n", " 'index_add_',\n", " 'index_copy',\n", " 'index_copy_',\n", " 'index_fill',\n", " 'index_fill_',\n", " 'index_put',\n", " 'index_put_',\n", " 'index_select',\n", " 'indices',\n", " 'int',\n", " 'int_repr',\n", " 'inverse',\n", " 'irfft',\n", " 'is_coalesced',\n", " 'is_complex',\n", " 'is_contiguous',\n", " 'is_cuda',\n", " 'is_distributed',\n", " 'is_floating_point',\n", " 'is_leaf',\n", " 'is_meta',\n", " 'is_mkldnn',\n", " 'is_nonzero',\n", " 'is_pinned',\n", " 'is_quantized',\n", " 'is_same_size',\n", " 'is_set_to',\n", " 'is_shared',\n", " 'is_signed',\n", " 'is_sparse',\n", " 'isclose',\n", " 'isfinite',\n", " 'isinf',\n", " 'isnan',\n", " 'istft',\n", " 'item',\n", " 'kthvalue',\n", " 'layout',\n", " 'le',\n", " 'le_',\n", " 'lerp',\n", " 'lerp_',\n", " 'lgamma',\n", " 'lgamma_',\n", " 'log',\n", " 'log10',\n", " 'log10_',\n", " 'log1p',\n", " 'log1p_',\n", " 'log2',\n", " 'log2_',\n", " 'log_',\n", " 'log_normal_',\n", " 'log_softmax',\n", " 'logaddexp',\n", " 'logaddexp2',\n", " 'logcumsumexp',\n", " 'logdet',\n", " 'logical_and',\n", " 'logical_and_',\n", " 'logical_not',\n", " 'logical_not_',\n", " 
'logical_or',\n", " 'logical_or_',\n", " 'logical_xor',\n", " 'logical_xor_',\n", " 'logsumexp',\n", " 'long',\n", " 'lstsq',\n", " 'lt',\n", " 'lt_',\n", " 'lu',\n", " 'lu_solve',\n", " 'map2_',\n", " 'map_',\n", " 'masked_fill',\n", " 'masked_fill_',\n", " 'masked_scatter',\n", " 'masked_scatter_',\n", " 'masked_select',\n", " 'matmul',\n", " 'matrix_power',\n", " 'max',\n", " 'mean',\n", " 'median',\n", " 'min',\n", " 'mm',\n", " 'mode',\n", " 'mul',\n", " 'mul_',\n", " 'multinomial',\n", " 'mv',\n", " 'mvlgamma',\n", " 'mvlgamma_',\n", " 'name',\n", " 'names',\n", " 'narrow',\n", " 'narrow_copy',\n", " 'ndim',\n", " 'ndimension',\n", " 'ne',\n", " 'ne_',\n", " 'neg',\n", " 'neg_',\n", " 'nelement',\n", " 'new',\n", " 'new_empty',\n", " 'new_full',\n", " 'new_ones',\n", " 'new_tensor',\n", " 'new_zeros',\n", " 'nonzero',\n", " 'norm',\n", " 'normal_',\n", " 'numel',\n", " 'numpy',\n", " 'orgqr',\n", " 'ormqr',\n", " 'output_nr',\n", " 'permute',\n", " 'pin_memory',\n", " 'pinverse',\n", " 'polygamma',\n", " 'polygamma_',\n", " 'pow',\n", " 'pow_',\n", " 'prelu',\n", " 'prod',\n", " 'put_',\n", " 'q_per_channel_axis',\n", " 'q_per_channel_scales',\n", " 'q_per_channel_zero_points',\n", " 'q_scale',\n", " 'q_zero_point',\n", " 'qr',\n", " 'qscheme',\n", " 'rad2deg',\n", " 'rad2deg_',\n", " 'random_',\n", " 'real',\n", " 'reciprocal',\n", " 'reciprocal_',\n", " 'record_stream',\n", " 'refine_names',\n", " 'register_hook',\n", " 'reinforce',\n", " 'relu',\n", " 'relu_',\n", " 'remainder',\n", " 'remainder_',\n", " 'rename',\n", " 'rename_',\n", " 'renorm',\n", " 'renorm_',\n", " 'repeat',\n", " 'repeat_interleave',\n", " 'requires_grad',\n", " 'requires_grad_',\n", " 'reshape',\n", " 'reshape_as',\n", " 'resize',\n", " 'resize_',\n", " 'resize_as',\n", " 'resize_as_',\n", " 'retain_grad',\n", " 'rfft',\n", " 'roll',\n", " 'rot90',\n", " 'round',\n", " 'round_',\n", " 'rsqrt',\n", " 'rsqrt_',\n", " 'scatter',\n", " 'scatter_',\n", " 'scatter_add',\n", " 
'scatter_add_',\n", " 'select',\n", " 'set_',\n", " 'shape',\n", " 'share_memory_',\n", " 'short',\n", " 'sigmoid',\n", " 'sigmoid_',\n", " 'sign',\n", " 'sign_',\n", " 'sin',\n", " 'sin_',\n", " 'sinh',\n", " 'sinh_',\n", " 'size',\n", " 'slogdet',\n", " 'smm',\n", " 'softmax',\n", " 'solve',\n", " 'sort',\n", " 'sparse_dim',\n", " 'sparse_mask',\n", " 'sparse_resize_',\n", " 'sparse_resize_and_clear_',\n", " 'split',\n", " 'split_with_sizes',\n", " 'sqrt',\n", " 'sqrt_',\n", " 'square',\n", " 'square_',\n", " 'squeeze',\n", " 'squeeze_',\n", " 'sspaddmm',\n", " 'std',\n", " 'stft',\n", " 'storage',\n", " 'storage_offset',\n", " 'storage_type',\n", " 'stride',\n", " 'sub',\n", " 'sub_',\n", " 'sum',\n", " 'sum_to_size',\n", " 'svd',\n", " 'symeig',\n", " 't',\n", " 't_',\n", " 'take',\n", " 'tan',\n", " 'tan_',\n", " 'tanh',\n", " 'tanh_',\n", " 'to',\n", " 'to_dense',\n", " 'to_mkldnn',\n", " 'to_sparse',\n", " 'tolist',\n", " 'topk',\n", " 'trace',\n", " 'transpose',\n", " 'transpose_',\n", " 'triangular_solve',\n", " 'tril',\n", " 'tril_',\n", " 'triu',\n", " 'triu_',\n", " 'true_divide',\n", " 'true_divide_',\n", " 'trunc',\n", " 'trunc_',\n", " 'type',\n", " 'type_as',\n", " 'unbind',\n", " 'unflatten',\n", " 'unfold',\n", " 'uniform_',\n", " 'unique',\n", " 'unique_consecutive',\n", " 'unsqueeze',\n", " 'unsqueeze_',\n", " 'values',\n", " 'var',\n", " 'view',\n", " 'view_as',\n", " 'volatile',\n", " 'where',\n", " 'zero_']" ] }, "metadata": { "tags": [] }, "execution_count": 24 } ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:35:08.396527Z", "start_time": "2020-09-25T19:35:08.394059Z" }, "id": "jhiL6isOcSJP", "outputId": "1bbe6e0f-b453-47d9-c687-a1526681aa8b" }, "source": [ "help(torch.Tensor.view)" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "Help on method_descriptor:\n", "\n", "view(...)\n", " view(*shape) -> Tensor\n", " \n", " Returns a new tensor with the same data as the 
:attr:`self` tensor but of a\n", " different :attr:`shape`.\n", " \n", " The returned tensor shares the same data and must have the same number\n", " of elements, but may have a different size. For a tensor to be viewed, the new\n", " view size must be compatible with its original size and stride, i.e., each new\n", " view dimension must either be a subspace of an original dimension, or only span\n", " across original dimensions :math:`d, d+1, \\dots, d+k` that satisfy the following\n", " contiguity-like condition that :math:`\\forall i = d, \\dots, d+k-1`,\n", " \n", " .. math::\n", " \n", " \\text{stride}[i] = \\text{stride}[i+1] \\times \\text{size}[i+1]\n", " \n", " Otherwise, it will not be possible to view :attr:`self` tensor as :attr:`shape`\n", " without copying it (e.g., via :meth:`contiguous`). When it is unclear whether a\n", " :meth:`view` can be performed, it is advisable to use :meth:`reshape`, which\n", " returns a view if the shapes are compatible, and copies (equivalent to calling\n", " :meth:`contiguous`) otherwise.\n", " \n", " Args:\n", " shape (torch.Size or int...): the desired size\n", " \n", " Example::\n", " \n", " >>> x = torch.randn(4, 4)\n", " >>> x.size()\n", " torch.Size([4, 4])\n", " >>> y = x.view(16)\n", " >>> y.size()\n", " torch.Size([16])\n", " >>> z = x.view(-1, 8) # the size -1 is inferred from other dimensions\n", " >>> z.size()\n", " torch.Size([2, 8])\n", " \n", " >>> a = torch.randn(1, 2, 3, 4)\n", " >>> a.size()\n", " torch.Size([1, 2, 3, 4])\n", " >>> b = a.transpose(1, 2) # Swaps 2nd and 3rd dimension\n", " >>> b.size()\n", " torch.Size([1, 3, 2, 4])\n", " >>> c = a.view(1, 3, 2, 4) # Does not change tensor layout in memory\n", " >>> c.size()\n", " torch.Size([1, 3, 2, 4])\n", " >>> torch.equal(b, c)\n", " False\n", "\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "OYYOtiFn0arN" }, "source": [ "" ], "execution_count": null, "outputs": [] } ] }