3RNN/Lib/site-packages/tensorflow/include/xla/service/transfer_manager.h
2024-05-26 19:49:15 +02:00

316 lines
14 KiB
C++

/* Copyright 2017 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_SERVICE_TRANSFER_MANAGER_H_
#define XLA_SERVICE_TRANSFER_MANAGER_H_
#include <map>
#include <set>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "xla/literal.h"
#include "xla/service/executable.h"
#include "xla/service/shaped_buffer.h"
#include "xla/statusor.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/types.h"
#include "xla/xla_data.pb.h"
namespace xla {
// The TransferManager interface lets backends provide platform-specific
// mechanisms for constructing literals from given device memory handles.
// This lets each platform customize how literals are transferred to/from the
// device in terms of padding, leading dimension, etc.
class TransferManager {
public:
virtual ~TransferManager() = default;
// Returns the ID of the platform that this transfer manager acts on.
virtual se::Platform::Id PlatformId() const = 0;
// Returns the shape of the on-device representation for the given shape on
// the host. This is intended for use with ShapedBuffer where buffers are
// pre-allocated by the host, e.g. TransferLiteralToDevice, without the user
// needing to consider device-specific behaviors.
virtual Shape HostShapeToDeviceShape(const Shape& host_shape) const {
// Strips off any preexisting tiling or memory space information.
// TODO(phawkins): fix clients not to including tiling or memory space
// information in shapes passed to this function and turn this into an
// assertion.
return ShapeUtil::DeviceShapeToHostShape(host_shape);
}
// Base class for specifying platform specific transfer metadata that can be
// used to tell the underlying implementation to perform specific optimization
// to a transfer. Actual metadata passed to supported transfer methods should
// subclass this class.
class TransferMetadata {
public:
virtual ~TransferMetadata() = default;
};
// Returns a literal containing the data held in the given ShapedBuffer
// using the provided executor. This operation is performed synchronously
// without waiting for any other operation on a stream to complete.
//
// This function should be avoided in favor of the asynchronous version below.
//
// Optionally caller can specify platform-specific transfer metadata that
// tells the actual implementation to do something special.
StatusOr<Literal> TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer,
const TransferMetadata* transfer_metadata = nullptr);
Status TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer,
const MutableBorrowingLiteral& literal,
const TransferMetadata* transfer_metadata = nullptr);
// Begins transferring a literal containing the data held in the given
// ShapedBuffer using the provided executor.
//
// This operation is performed asynchronously on the given stream. It returns
// once the transfer is enqueued. 'done' is invoked with the result when
// complete.
//
// device_buffer is copied by reference and must live at least until done() is
// invoked.
//
// Optionally caller can specify platform-specific transfer metadata that
// tells the actual implementation to do something special.
virtual void TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer,
MutableBorrowingLiteral literal, std::function<void(Status)> done,
const TransferMetadata* transfer_metadata) = 0;
void TransferLiteralFromDevice(se::Stream* stream,
const ShapedBuffer& device_buffer,
MutableBorrowingLiteral literal,
std::function<void(Status)> done) {
return TransferLiteralFromDevice(stream, device_buffer, literal, done,
nullptr);
}
// Transfers the given literal into the previously allocated device memory
// represented by the given ShapedBuffer using the given executor. The shape
// of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
// but need not have the same layout.
//
// This operation is performed synchronously without waiting for any other
// operation on a stream to complete. This function should be avoided in favor
// of the asynchronous version below.
//
// Optionally caller can specify platform-specific transfer metadata that
// tells the actual implementation to do something special.
Status TransferLiteralToDevice(
se::Stream* stream, const LiteralSlice& literal,
const ShapedBuffer& device_buffer,
const TransferMetadata* transfer_metadata = nullptr);
// Transfers the given literal into the previously allocated device memory
// represented by the given ShapedBuffer using the given executor. The shape
// of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
// but need not have the same layout.
//
// This operation is performed asynchronously on the given stream. It returns
// once the transfer is enqueued, and may return before the transfer has
// completed.
//
// The caller may free the data structures 'literal' and 'device_buffer'
// immediately after this function returns, however their constituent buffers
// on both host and device must remain valid until the enqueued transfer has
// completed on 'stream'.
//
// Optionally caller can specify platform-specific transfer metadata that
// tells the actual implementation to do something special.
virtual Status TransferLiteralToDeviceAsync(
se::Stream* stream, const LiteralSlice& literal,
const ShapedBuffer& device_buffer,
const TransferMetadata* transfer_metadata) = 0;
Status TransferLiteralToDeviceAsync(se::Stream* stream,
const LiteralSlice& literal,
const ShapedBuffer& device_buffer) {
return TransferLiteralToDeviceAsync(stream, literal, device_buffer,
nullptr);
}
// Convenience methods for transferring an array to or from the device at a
// known address. This avoids having to construct a ShapedBuffer just to
// transfer an array at a known address.
//
// Optionally caller can specify platform-specific transfer metadata that
// tells the actual implementation to do something special.
Status TransferArrayToDevice(
se::Stream* stream, const LiteralSlice& literal,
const se::DeviceMemoryBase& dest,
const TransferMetadata* transfer_metadata = nullptr);
Status TransferArrayToDeviceAsync(
se::Stream* stream, const LiteralSlice& literal,
const se::DeviceMemoryBase& dest,
const TransferMetadata* transfer_metadata = nullptr);
StatusOr<Literal> TransferArrayFromDevice(
se::Stream* stream, const Shape& shape,
const se::DeviceMemoryBase& source,
const TransferMetadata* transfer_metadata = nullptr);
// Read from a device buffer and update the dynamic dimension sizes of
// `host_shape` and `device_shape`. The function takes in bounded dynamic
// shapes, and returns static shapes with dynamic shapes updated.
// The shape of the buffer also have to be compatible with the host shape and
// device shape.
virtual Status ReadDynamicShapes(se::Stream* stream,
const ShapedBuffer* device_buffer,
Shape* device_shape);
// Transfers the given literal into the Infeed interface of the device,
// using the given executor.
virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor,
const LiteralSlice& literal) = 0;
// Transfers the given literal from the Outfeed interface of the device,
// using the given executor. The shape and layout are determined by the
// shape and layout of `literal`.
virtual Status TransferLiteralFromOutfeed(
se::StreamExecutor* executor, MutableBorrowingLiteral literal) = 0;
// Resets the devices associated with this transfer manager.
virtual Status ResetDevices(
absl::Span<se::StreamExecutor* const> executor) = 0;
// Given an allocated ShapedBuffer, constructs the tuple index table(s) in
// each buffer of the given ShapedBuffer corresponding to tuple shapes. If the
// ShapedBuffer is array-shaped this method does nothing.
Status WriteTupleIndexTables(se::Stream* stream,
const ShapedBuffer& device_buffer);
Status WriteTupleIndexTablesAsync(se::Stream* stream,
const ShapedBuffer& device_buffer);
// Writes a tuple index buffer for the root of 'device_buffer', which must
// be a tuple. Unlike WriteTupleIndexTables, only writes the root buffer,
// rather than writing all subbuffers. This method is always asynchronous.
Status WriteRootTupleIndexTable(se::Stream* stream,
const ShapedBuffer& device_buffer);
Status WriteRootTupleIndexTable(
se::Stream* stream,
const ShapeTree<MaybeOwningDeviceMemory>& buffer_tree);
// Determines the byte size requirement for the given shape on the underlying
// architecture. This will be used to allocate an appropriately sized memory
// region for a host-to-device transfer.
virtual int64_t GetByteSizeRequirement(const Shape& shape) const = 0;
// Chooses a compact layout for 'shape', ignoring any existing layout on
// 'shape'. What "reasonable" means is left up to the backend. The
// intended use case is to choose a layout that avoids excessive padding on
// devices that have tiled memory architectures.
// The default implementation always picks a default (major-to-minor) layout.
// Fails if 'shape' cannot be represented by the device.
virtual StatusOr<Shape> ChooseCompactLayoutForShape(
const Shape& host_shape) const;
// For the given shape, chooses a layout for infeed. The returned shape
// has the same dimensions as the original shape, and only the layout is
// changed.
virtual Shape ChooseGoodInfeedLayout(const Shape& shape) const;
typedef std::function<Shape(const Shape&)> DeviceShapeRepresentationFn;
// Allocates a ScopedShapedBuffer which can hold data with the given on-host
// shape. The on-device shape may be different as indicated by
// HostShapeToDeviceShape.
StatusOr<ScopedShapedBuffer> AllocateScopedShapedBuffer(
const Shape& on_host_shape, se::DeviceMemoryAllocator* allocator,
int device_ordinal,
DeviceShapeRepresentationFn shape_representation_fn = nullptr);
// The given ShapedBuffer holds a handle to allocated memory, but it is not
// in the general case legal to immediately copy or access that allocated
// memory because queued operations on the device may alias that memory.
// Memory ordering is enforced by the Stream's happens-before relationship
// which allows eager deallocation and reallocation of buffers host-side even
// if the device hasn't finished with them.
//
// In certain cases, it can be known that a ShapedBuffer does not have any
// conflicting accesses on the device and thus is eligible to be accessed at
// any time from the host.
//
// This function returns true if device_buffer can be accessed immediately
// without waiting for the Stream's previously enqueued items. This only
// returns true if all subbuffers in device_buffer can be accessed
// immediately.
virtual bool CanShapedBufferBeAccessedNow(
se::StreamExecutor* executor, const ShapedBuffer& device_buffer) const {
return false;
}
// Equivalent to CanShapedBufferBeAccessedNow but for a single device buffer.
virtual bool CanBufferBeAccessedNow(
se::StreamExecutor* executor,
const se::DeviceMemoryBase& device_buffer) const {
return false;
}
/////
// The TransferManager class also serves as a point to register objects for
// the various platforms.
// Registers the TransferManager singleton for the platform kind. This is
// assumed to be a singleton, so no ownership is transferred.
//
// Precondition: a platform kind must not be registered more than once.
typedef std::unique_ptr<TransferManager> (*TransferManagerCreationFunction)();
static void RegisterTransferManager(
se::Platform::Id platform_id,
TransferManagerCreationFunction transfer_manager);
// Returns the transfer manager singleton pointer if it is available for the
// given platform, or an error status if it is not.
static StatusOr<TransferManager*> GetForPlatform(
const se::Platform* platform);
// Writes the given device-memory pointers in 'elements' to the given region
// to construct a tuple index table in the platform-specific tuple
// representation.
virtual Status WriteSingleTupleIndexTable(
se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
const Shape& shape, se::DeviceMemoryBase* region) = 0;
private:
// The mutex that guards the platform-to-transfer manager map.
static absl::Mutex platform_transfer_manager_mutex_;
// State kept for each kind of TransferManager. Registration functions
// set up creation_function, and then we use that to lazily create
// "manager" the first time GetForPlatform is invoked for a particular id.
struct State {
std::unique_ptr<TransferManager> manager;
TransferManagerCreationFunction creation_function = nullptr;
};
// Map from platform kind to transfer manager singleton.
static absl::flat_hash_map<se::Platform::Id, State>*
GetPlatformTransferManagers();
};
} // namespace xla
#endif // XLA_SERVICE_TRANSFER_MANAGER_H_