#pragma once

#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/library.h>

namespace at {

// If an operator doesn't have a batching rule implemented then we fallback
// to this implementation. The fallback only works on out-of-place operators
// that return only tensors with new memory. (e.g., no in-place operators, no
// view operations).
//
// The fallback effectively takes all of the BatchedTensors in `stack`, slices
// them, and runs `op` on all of the corresponding slices to produce slices
// of the outputs. The output slices then get `torch.stack`ed to create the
// final returns.
//
// The performance of the fallback is not very good because it introduces an
// extra copy from stacking the sliced outputs. Because of this, we prefer to
// write batching rules for operators whenever possible.
//
// NOTE(review): this is a boxed-fallback signature — `op` identifies the
// operator being dispatched and `stack` carries its boxed arguments, which
// this function consumes and replaces with the boxed outputs in place.
void batchedTensorForLoopFallback(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack);

} // namespace at