#pragma once #include #include #include namespace at { /// MatrixRef - Like an ArrayRef, but with an extra recorded strides so that /// we can easily view it as a multidimensional array. /// /// Like ArrayRef, this class does not own the underlying data, it is expected /// to be used in situations where the data resides in some other buffer. /// /// This is intended to be trivially copyable, so it should be passed by /// value. /// /// For now, 2D only (so the copies are actually cheap, without having /// to write a SmallVector class) and contiguous only (so we can /// return non-strided ArrayRef on index). /// /// P.S. dimension 0 indexes rows, dimension 1 indexes columns template class MatrixRef { public: typedef size_t size_type; private: /// Underlying ArrayRef ArrayRef arr; /// Stride of dim 0 (outer dimension) size_type stride0; // Stride of dim 1 is assumed to be 1 public: /// Construct an empty Matrixref. /*implicit*/ MatrixRef() : arr(nullptr), stride0(0) {} /// Construct an MatrixRef from an ArrayRef and outer stride. /*implicit*/ MatrixRef(ArrayRef arr, size_type stride0) : arr(arr), stride0(stride0) { TORCH_CHECK( arr.size() % stride0 == 0, "MatrixRef: ArrayRef size ", arr.size(), " not divisible by stride ", stride0) } /// @} /// @name Simple Operations /// @{ /// empty - Check if the matrix is empty. bool empty() const { return arr.empty(); } const T* data() const { return arr.data(); } /// size - Get size a dimension size_t size(size_t dim) const { if (dim == 0) { return arr.size() / stride0; } else if (dim == 1) { return stride0; } else { TORCH_CHECK( 0, "MatrixRef: out of bounds dimension ", dim, "; expected 0 or 1"); } } size_t numel() const { return arr.size(); } /// equals - Check for element-wise equality. bool equals(MatrixRef RHS) const { return stride0 == RHS.stride0 && arr.equals(RHS.arr); } /// @} /// @name Operator Overloads /// @{ ArrayRef operator[](size_t Index) const { return arr.slice(Index * stride0, stride0); } /// Disallow accidental assignment from a temporary. /// /// The declaration here is extra complicated so that "arrayRef = {}" /// continues to select the move assignment operator. template std::enable_if_t, MatrixRef>& operator=( U&& Temporary) = delete; /// Disallow accidental assignment from a temporary. /// /// The declaration here is extra complicated so that "arrayRef = {}" /// continues to select the move assignment operator. template std::enable_if_t, MatrixRef>& operator=( std::initializer_list) = delete; }; } // end namespace at