#pragma once

#include <ATen/Parallel.h>
#include <ATen/TensorUtils.h>
#include <c10/util/irange.h>
#include <cstring>
#include <limits>
#include <utility>

namespace at {

/*
 * The basic strategy for apply is as follows:
 *
 * 1. Starting with the outermost index, loop until we reach a dimension where
 * the data is no longer contiguous, i.e. the stride at that dimension is not
 * equal to the size of the tensor defined by the outer dimensions. Let's call
 * this outer (contiguous) tensor A. Note that if the Tensor is contiguous,
 * then A is equal to the entire Tensor. Let's call the inner tensor B.
 *
 * 2. We loop through the indices in B, starting at its outermost dimension.
 * For example, if B is a 2x2 matrix, then we do:
 *
 * B[0][0]
 * B[0][1]
 * B[1][0]
 * B[1][1]
 *
 * We set the offset into the underlying storage as (storageOffset + stride_B *
 * index_B), i.e. basically we compute the offset into the storage as we would
 * normally for a Tensor. But because we are guaranteed the subsequent data is
 * contiguous in memory, we can simply loop for sizeof(A) iterations and
 * perform the operation, without having to follow the order described by the
 * strides of A.
 *
 * 3. As an optimization, we merge dimensions of A that are contiguous in
 * memory. For example, if A is a 3x3x3x3 tensor narrowed from a 3x3x4x3
 * tensor, then the first two dimensions can be merged for the purposes of
 * APPLY, reducing the number of nested loops.
 */
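// ---------------------------------------------------------------------------
// Illustrative sketch of the dimension merging described in step 3 above (an
// editorial addition, not part of the original file). Two adjacent dimensions
// i and i+1 of a strided tensor can be fused whenever
//   strides[i] == strides[i+1] * sizes[i+1],
// i.e. one step along dimension i covers exactly one full traversal of
// dimension i+1. The helper name demo_merge_contiguous_dims is made up for
// illustration; the iterators below rely on collapse_dims() for the real
// thing.
// ---------------------------------------------------------------------------
inline int64_t demo_merge_contiguous_dims(
    std::vector<int64_t>& sizes,
    std::vector<int64_t>& strides) {
  if (sizes.empty()) {
    return 0;
  }
  size_t write = 0;
  for (size_t read = 1; read < sizes.size(); ++read) {
    if (strides[write] == strides[read] * sizes[read]) {
      // Dimension `write` is contiguous with dimension `read`: fuse them.
      sizes[write] *= sizes[read];
      strides[write] = strides[read];
    } else {
      // Contiguity is broken here; start a new output dimension.
      ++write;
      sizes[write] = sizes[read];
      strides[write] = strides[read];
    }
  }
  sizes.resize(write + 1);
  strides.resize(write + 1);
  return static_cast<int64_t>(write + 1);
}
// For instance, a 3x3x3x3 view narrowed out of a contiguous 3x3x4x3 tensor
// (sizes {3, 3, 3, 3}, strides {36, 12, 3, 1}) collapses to sizes {9, 9} and
// strides {12, 1}: two nested loops instead of four.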
inline Tensor sort_strides(Tensor& tensor_) {
  IntArrayRef strides = tensor_.strides();
  std::vector<int64_t> indices;
  indices.reserve(tensor_.ndimension());
  for (const auto i : c10::irange(tensor_.ndimension())) {
    indices.push_back(i);
  }
  std::sort(indices.begin(), indices.end(), [&strides](int64_t i1, int64_t i2) {
    return strides[i1] > strides[i2];
  });
  Tensor tensor = tensor_.permute(indices);
  return tensor;
}

template <typename T, int N>
struct strided_tensor_iter_fixed {
 public:
  T* data_ = NULL;
  int64_t dim_ = 0;

  int64_t counter_[N] = {0};
  int64_t sizes_[N] = {0};
  int64_t strides_[N] = {0};

  strided_tensor_iter_fixed(strided_tensor_iter_fixed const&) = delete;
  void operator=(strided_tensor_iter_fixed const& x) = delete;
  strided_tensor_iter_fixed(strided_tensor_iter_fixed&&) = default;
  strided_tensor_iter_fixed(
      Tensor& tensor,
      C10_UNUSED bool sort_strides = false)
      : data_(tensor.data_ptr<T>()) {
    std::memset(counter_, 0, sizeof(int64_t) * N);
    if (tensor.dim() > 0) {
      std::memcpy(
          sizes_, tensor.sizes().data(), tensor.dim() * sizeof(int64_t));
      std::memcpy(
          strides_, tensor.strides().data(), tensor.dim() * sizeof(int64_t));
    }
    dim_ = std::get<1>(collapse_dims(sizes_, strides_, tensor.ndimension()));
  }
};

template <typename T>
struct strided_tensor_iter {
 private:
 public:
  T* data_ = NULL;
  int64_t dim_;

  std::vector<int64_t> counter_;
  std::vector<int64_t> sizes_;
  std::vector<int64_t> strides_;

  strided_tensor_iter(strided_tensor_iter const&) = delete;
  void operator=(strided_tensor_iter const& x) = delete;
  strided_tensor_iter(strided_tensor_iter&&) = default;
  strided_tensor_iter(Tensor& tensor)
      : data_(tensor.data_ptr<T>()),
        dim_(tensor.ndimension()),
        counter_(dim_, 0),
        sizes_(tensor.sizes().vec()),
        strides_(tensor.strides().vec()) {
    dim_ = std::get<1>(collapse_dims(sizes_.data(), strides_.data(), dim_));
  }
};

inline bool _all_equal_numel(at::ArrayRef<Tensor> tensors) {
  if (tensors.empty())
    return true;
  int64_t all_numel = tensors[0].numel();
  for (const auto i : c10::irange(1, tensors.size())) {
    if (tensors[i].numel() != all_numel)
      return false;
  }
  return true;
}

inline std::string _all_equal_numel_error(at::ArrayRef<Tensor> tensors) {
  std::ostringstream oss;
  oss << "inconsistent tensor size, expected ";
  for (size_t i = 0; i < tensors.size() - 1; i++) {
    oss << tensors[i].sizes() << ", ";
  }
  oss << "and " << tensors[tensors.size() - 1].sizes()
      << " to have the same number of elements, but got ";
  for (size_t i = 0; i < tensors.size() - 1; i++) {
    oss << tensors[i].numel() << ", ";
  }
  oss << "and " << tensors[tensors.size() - 1].numel()
      << " elements respectively";
  return oss.str();
}

inline bool _apply_preamble(ArrayRef<Tensor> tensors) {
  checkDeviceType("CPU_tensor_apply", tensors, kCPU);
  checkLayout("CPU_tensor_apply", tensors, kStrided);
  if (!_all_equal_numel(tensors))
    AT_ERROR(_all_equal_numel_error(tensors));
  // An empty tensor has no elements
  for (auto& t : tensors)
    if (t.numel() == 0)
      return false;
  return true;
}

inline int64_t _max_dim_tensors(ArrayRef<Tensor> tensors) {
  int64_t dim = 0;
  for (auto& t : tensors)
    dim = std::max(dim, t.ndimension());
  return dim;
}

inline void iterate(int64_t /*size*/){};

template <typename Arg, typename... Args>
inline void iterate(int64_t size, Arg& iter, Args&... iter_tail) {
  iter.counter_[iter.dim_ - 1] += size;
  iter.data_ = iter.data_ + size * iter.strides_[iter.dim_ - 1];
  iterate(size, iter_tail...);
}

inline bool iterate_continue() {
  return true;
};

template <typename Arg, typename... Args>
inline bool iterate_continue(Arg& iter, Args&... iter_tail) {
  return iter.counter_[iter.dim_ - 1] < iter.sizes_[iter.dim_ - 1] &&
      iterate_continue(iter_tail...);
}

inline int64_t max_iterate_size() {
  return std::numeric_limits<int64_t>::max();
};

template <typename Arg, typename... Args>
inline int64_t max_iterate_size(Arg& iter, Args&... iter_tail) {
  return std::min(
      (iter.sizes_[iter.dim_ - 1] - iter.counter_[iter.dim_ - 1]),
      max_iterate_size(iter_tail...));
}

inline void iterate_overflow(){};

template <typename Arg, typename... Args>
inline void iterate_overflow(Arg& iter, Args&... iter_tail) {
  if (iter.counter_[iter.dim_ - 1] == iter.sizes_[iter.dim_ - 1]) {
    for (int64_t i = iter.dim_ - 1; i > 0; i--) {
      if (iter.counter_[i] == iter.sizes_[i]) {
        iter.counter_[i] = 0;
        iter.counter_[i - 1]++;
        iter.data_ = iter.data_ - (iter.sizes_[i] * iter.strides_[i]) +
            iter.strides_[i - 1];
      }
    }
  }
  iterate_overflow(iter_tail...);
}

inline void forward(int64_t /*offset*/){};

template <typename Arg, typename... Args>
inline void forward(int64_t offset, Arg& iter, Args&... iter_tail) {
  int64_t multi = offset;
  for (int64_t i = iter.dim_ - 1; i >= 0; i--) {
    int64_t inc = multi % iter.sizes_[i];
    multi = multi / iter.sizes_[i];
    iter.data_ = iter.data_ + inc * iter.strides_[i];
    iter.counter_[i] += inc;
  }
  forward(offset, iter_tail...);
}

inline int64_t max_dim() {
  return 0;
}

template <typename Arg, typename... Args>
inline int64_t max_dim(Arg& iter, Args&... iter_tail) {
  return std::max(iter.dim_, max_dim(iter_tail...));
}

inline void apply_op(){};

template <typename Op, typename... Args>
inline void apply_op(
    int64_t numel,
    int64_t offset,
    const Op& op,
    Args... iters) {
  // For 0-dim tensors
  if (numel == 1 && max_dim(iters...) == 0) {
    op(*iters.data_...);
    return;
  }
  if (offset > 0)
    forward(offset, iters...);
  // Splitting this into chunks helps the compiler create faster assembly
  for (int64_t i = 0; i < numel;) {
    for (; iterate_continue(iters...) && i < numel;) {
      op(*iters.data_...);
      iterate(1, iters...);
      i++;
    }
    iterate_overflow(iters...);
  }
}
/*
  Apply a pointwise operator to a sequence of tensors.

  The calling convention for op is a function/functor that takes the same
  number of scalar references as the number of given tensors (apply_op calls
  it with the dereferenced data pointers). For example, to compute a = b * c,
  op would be of the form:
  [](scalar& a_val, const scalar& b_val, const scalar& c_val) {
    a_val = b_val * c_val;
  };
*/

template <typename scalar1, typename scalar2, typename Op>
inline void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
  if (!_apply_preamble({tensor1, tensor2}))
    return;
  if (_max_dim_tensors({tensor1, tensor2}) <= 8) {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter_fixed<scalar1, 8>(tensor1),
        strided_tensor_iter_fixed<scalar2, 8>(tensor2));
  } else {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter<scalar1>(tensor1),
        strided_tensor_iter<scalar2>(tensor2));
  }
}

template <typename scalar1, typename scalar2, typename scalar3, typename Op>
inline void CPU_tensor_apply3(
    Tensor tensor1,
    Tensor tensor2,
    Tensor tensor3,
    const Op op) {
  if (!_apply_preamble({tensor1, tensor2, tensor3}))
    return;
  if (_max_dim_tensors({tensor1, tensor2, tensor3}) <= 8) {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter_fixed<scalar1, 8>(tensor1),
        strided_tensor_iter_fixed<scalar2, 8>(tensor2),
        strided_tensor_iter_fixed<scalar3, 8>(tensor3));
  } else {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter<scalar1>(tensor1),
        strided_tensor_iter<scalar2>(tensor2),
        strided_tensor_iter<scalar3>(tensor3));
  }
}

template <
    typename scalar1,
    typename scalar2,
    typename scalar3,
    typename scalar4,
    typename Op>
inline void CPU_tensor_apply4(
    Tensor tensor1,
    Tensor tensor2,
    Tensor tensor3,
    Tensor tensor4,
    const Op op) {
  if (!_apply_preamble({tensor1, tensor2, tensor3, tensor4}))
    return;
  if (_max_dim_tensors({tensor1, tensor2, tensor3, tensor4}) <= 8) {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter_fixed<scalar1, 8>(tensor1),
        strided_tensor_iter_fixed<scalar2, 8>(tensor2),
        strided_tensor_iter_fixed<scalar3, 8>(tensor3),
        strided_tensor_iter_fixed<scalar4, 8>(tensor4));
  } else {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter<scalar1>(tensor1),
        strided_tensor_iter<scalar2>(tensor2),
        strided_tensor_iter<scalar3>(tensor3),
        strided_tensor_iter<scalar4>(tensor4));
  }
}

} // namespace at
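// ---------------------------------------------------------------------------
// Illustrative usage sketch (an editorial addition, not part of the original
// header): filling `dst` with the elementwise square of `src` through
// CPU_tensor_apply2. The helper name example_square_out is hypothetical; the
// template arguments must match the tensors' actual scalar type (float is
// assumed here), and both tensors must be CPU, strided, and have the same
// number of elements, or _apply_preamble will raise an error.
// ---------------------------------------------------------------------------
inline void example_square_out(at::Tensor& dst, const at::Tensor& src) {
  at::CPU_tensor_apply2<float, float>(
      dst, src, [](float& dst_val, const float& src_val) {
        // Called once per element; dst_val and src_val are references into
        // the tensors' storage at matching logical positions.
        dst_val = src_val * src_val;
      });
}
// A caller would invoke it as example_square_out(out, in) with two float CPU
// tensors of equal numel; arbitrary (non-contiguous) strides are handled by
// the iterators above.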