248 lines
11 KiB
C
248 lines
11 KiB
C
|
/* Copyright 2017 The OpenXLA Authors.
|
||
|
|
||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
you may not use this file except in compliance with the License.
|
||
|
You may obtain a copy of the License at
|
||
|
|
||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
||
|
Unless required by applicable law or agreed to in writing, software
|
||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
See the License for the specific language governing permissions and
|
||
|
limitations under the License.
|
||
|
==============================================================================*/
|
||
|
|
||
|
#ifndef XLA_CLIENT_CLIENT_H_
|
||
|
#define XLA_CLIENT_CLIENT_H_
|
||
|
|
||
|
#include <memory>
|
||
|
#include <string>
|
||
|
#include <utility>
|
||
|
#include <vector>
|
||
|
|
||
|
#include "absl/types/span.h"
|
||
|
#include "xla/client/global_data.h"
|
||
|
#include "xla/client/xla_computation.h"
|
||
|
#include "xla/literal.h"
|
||
|
#include "xla/service/hlo.pb.h"
|
||
|
#include "xla/service_interface.h"
|
||
|
#include "xla/statusor.h"
|
||
|
#include "xla/types.h"
|
||
|
#include "xla/xla.pb.h"
|
||
|
#include "xla/xla_data.pb.h"
|
||
|
|
||
|
namespace xla {
|
||
|
|
||
|
// XLA service's client object -- wraps the service with convenience and
|
||
|
// lifetime-oriented methods.
|
||
|
class Client {
|
||
|
public:
|
||
|
explicit Client(ServiceInterface* stub);
|
||
|
virtual ~Client();
|
||
|
|
||
|
// Compile the computation with the given argument shapes and returns the
|
||
|
// handle to the compiled executable. The compiled executable is cached on the
|
||
|
// service, and the returned handle can be used for execution without
|
||
|
// re-compile.
|
||
|
// * The shape and layout of the arguments being executed with will affect how
|
||
|
// the computation is compiled. If argument_shapes is empty, the parameters'
|
||
|
// shape and layout will be used in the compilation.
|
||
|
// * If execution_options is not nullptr, these options are passed to the
|
||
|
// service to affect how it compiles our computation. (The pointer does not
|
||
|
// need to live beyond this call.)
|
||
|
// * If execution_options.device_handles should be empty. If you need
|
||
|
// non-empty device handles, call 'Execute' instead.
|
||
|
//
|
||
|
// TODO(b/122731460): This call caches the resulting Executable in the Service
|
||
|
// *forever*. If you're only going to run the computation once, you may want
|
||
|
// to call the Execute(const XlaComputation&) overload. If you're going to
|
||
|
// run the computation more than once but you want control over when the
|
||
|
// Executable is unloaded, use the LocalClient API.
|
||
|
StatusOr<ExecutionHandle> Compile(
|
||
|
const XlaComputation& computation,
|
||
|
absl::Span<const Shape> argument_shapes,
|
||
|
const ExecutionOptions* execution_options = nullptr);
|
||
|
|
||
|
// Executes the compiled executable for the given handle with the given
|
||
|
// arguments and returns the global data that was produced from the execution.
|
||
|
// * If execution_profile is not nullptr then the pointed-to ExecutionProfile
|
||
|
// will be filled with profile data from the execution.
|
||
|
StatusOr<std::unique_ptr<GlobalData>> Execute(
|
||
|
const ExecutionHandle& handle, absl::Span<GlobalData* const> arguments,
|
||
|
ExecutionProfile* execution_profile = nullptr);
|
||
|
|
||
|
// Executes the computation with the given arguments and returns the global
|
||
|
// data that was produced from the execution.
|
||
|
// * If execution_options is not nullptr, these options are passed to the
|
||
|
// service to affect how it compiles our computation. (The pointer does not
|
||
|
// need to live beyond this call.)
|
||
|
// * If execution_options.device_handles is not empty, the computation is
|
||
|
// executed on the devices associated with the handles by partitioning the
|
||
|
// computation based on the attached sharding attributes. Otherwise, a
|
||
|
// device is chosen by the service.
|
||
|
// * If execution_profile is not nullptr then the pointed-to ExecutionProfile
|
||
|
// will be filled with profile data from the execution.
|
||
|
//
|
||
|
// TODO(b/122731460): The given computation is compiled and then thrown away
|
||
|
// immediately after it's run. If you want control over how long the
|
||
|
// resulting Executable lives, use the LocalClient API.
|
||
|
StatusOr<std::unique_ptr<GlobalData>> Execute(
|
||
|
const XlaComputation& computation,
|
||
|
absl::Span<GlobalData* const> arguments,
|
||
|
const ExecutionOptions* execution_options = nullptr,
|
||
|
ExecutionProfile* execution_profile = nullptr);
|
||
|
|
||
|
// A struct to represent a computation instance to be executed.
|
||
|
// * If execution_options.device_handles is not empty, the computation is
|
||
|
// executed on the devices associated with the handles by partitioning the
|
||
|
// computation based on the attached sharding attributes. Otherwise, a
|
||
|
// device is chosen by the service.
|
||
|
struct XlaComputationInstance {
|
||
|
const XlaComputation& computation;
|
||
|
std::vector<GlobalData*> arguments;
|
||
|
ExecutionOptions execution_options;
|
||
|
ExecutionProfile* execution_profile;
|
||
|
|
||
|
XlaComputationInstance(const XlaComputation& computation,
|
||
|
std::vector<GlobalData*> arguments,
|
||
|
ExecutionOptions execution_options,
|
||
|
ExecutionProfile* execution_profile)
|
||
|
: computation(computation),
|
||
|
arguments(std::move(arguments)),
|
||
|
execution_options(execution_options),
|
||
|
execution_profile(execution_profile) {}
|
||
|
};
|
||
|
|
||
|
// Executes a list XlaComputationInstances and returns global data produced
|
||
|
// from each computation.
|
||
|
//
|
||
|
StatusOr<std::vector<std::unique_ptr<GlobalData>>> ExecuteParallel(
|
||
|
absl::Span<const XlaComputationInstance> computations);
|
||
|
|
||
|
// Requests device_count device handles available on the target. The returned
|
||
|
// device handles are used to specify the devices to execute the computations
|
||
|
// (see ExecuteParallel) or to transfer data (see TransferToServer or
|
||
|
// TransferToInfeed).
|
||
|
StatusOr<std::vector<DeviceHandle>> GetDeviceHandles(int64_t device_count);
|
||
|
|
||
|
// Transfer the global data provided to this client process, which is
|
||
|
// returned in the provided literal. Use sparingly to avoid transfer
|
||
|
// overheads.
|
||
|
//
|
||
|
// If shape_with_layout is not nullptr, it points to a shape whose layout will
|
||
|
// be the layout of the returned literal.
|
||
|
StatusOr<Literal> Transfer(const GlobalData& data,
|
||
|
const Shape* shape_with_layout = nullptr);
|
||
|
|
||
|
// Transfer the given literal to the server. This allocates memory on the
|
||
|
// device and copies the literal's contents over. Returns a global data handle
|
||
|
// that can be used to refer to this value from the client.
|
||
|
//
|
||
|
// If device_handle is not nullptr, data is transferred to the associated
|
||
|
// device (and its replicas if replication is enabled). Otherwise, data is
|
||
|
// transferred to the default device (and its replicas).
|
||
|
StatusOr<std::unique_ptr<GlobalData>> TransferToServer(
|
||
|
const LiteralSlice& literal, const DeviceHandle* device_handle = nullptr);
|
||
|
|
||
|
// Transfer the given literal to the Infeed interface of the device.
|
||
|
//
|
||
|
// device_handle and replica_id together specify a particular device; a device
|
||
|
// assigned for the given replica_id among the replicas that the given device
|
||
|
// handle belongs to.
|
||
|
Status TransferToInfeed(const LiteralSlice& literal, int64_t replica_id = 0,
|
||
|
const DeviceHandle* device_handle = nullptr);
|
||
|
|
||
|
// Transfers from the Outfeed of the device.
|
||
|
//
|
||
|
// device_handle and replica_id together specify a particular device; a device
|
||
|
// assigned for the given replica_id among the replicas that the given device
|
||
|
// handle belongs to.
|
||
|
StatusOr<Literal> TransferFromOutfeed(
|
||
|
const Shape* shape_with_layout, int64_t replica_id = 0,
|
||
|
const DeviceHandle* device_handle = nullptr);
|
||
|
|
||
|
// Resets the device, clearing all existing state on the device.
|
||
|
Status ResetDevice();
|
||
|
|
||
|
// Executes the computation with the given arguments and transfers the result
|
||
|
// to the client as a literal. Parameters are defined the same as for
|
||
|
// Execute() and Transfer().
|
||
|
StatusOr<Literal> ExecuteAndTransfer(
|
||
|
const XlaComputation& computation,
|
||
|
absl::Span<GlobalData* const> arguments,
|
||
|
const ExecutionOptions* execution_options = nullptr,
|
||
|
ExecutionProfile* execution_profile = nullptr);
|
||
|
|
||
|
// Computes the value of the given computation using a non-optimized
|
||
|
// interpreter on the host.
|
||
|
//
|
||
|
// The computation must not depend on any parameters, or on stateful operators
|
||
|
// such as `RngNormal` or `Infeed`.
|
||
|
//
|
||
|
// This functionality can be useful when translating a computation into XLA
|
||
|
// where something that looked dynamic is required by XLA to be specified as a
|
||
|
// constant. E.g. the source computation (outside of XLA) may include a
|
||
|
// dynamic computation of the shape of something and ComputeConstant lets you
|
||
|
// determine what the value of that computation is in the case where the value
|
||
|
// can be determined at compile time.
|
||
|
//
|
||
|
// If output_layout is non-null, then the output of the computation will be
|
||
|
// stored using that layout.
|
||
|
StatusOr<Literal> ComputeConstant(
|
||
|
const XlaComputation& computation,
|
||
|
const Layout* output_layout = nullptr) const;
|
||
|
|
||
|
// Unregister the memory for the given GlobalData on the device.
|
||
|
Status Unregister(const GlobalData& data);
|
||
|
|
||
|
// Returns a vector of global data handles that point to the tuple elements.
|
||
|
StatusOr<std::vector<std::unique_ptr<GlobalData>>> DeconstructTuple(
|
||
|
const GlobalData& data);
|
||
|
|
||
|
// Retrieves the statistics of the given computation.
|
||
|
StatusOr<ComputationStats> GetComputationStats(
|
||
|
const XlaComputation& computation,
|
||
|
const DebugOptions& debug_options) const;
|
||
|
|
||
|
// Returns the Shape of the given array specified by 'data'. The shape
|
||
|
// includes the Layout of the array as it is stored on the service.
|
||
|
StatusOr<Shape> GetShape(const GlobalData& data);
|
||
|
|
||
|
// As above, but returns the shape of the provided computation (parameter
|
||
|
// types/names and return type).
|
||
|
StatusOr<std::unique_ptr<ProgramShape>> GetComputationShape(
|
||
|
const XlaComputation& computation);
|
||
|
|
||
|
// Creates a channel handle that can be used to transfer data between two
|
||
|
// computations on different devices via a pair of Send and Recv instructions.
|
||
|
StatusOr<ChannelHandle> CreateChannelHandle();
|
||
|
|
||
|
// Create a channel for communicating with the host via a SendtoHost or
|
||
|
// RecvFromHost operation.
|
||
|
StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle();
|
||
|
StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle();
|
||
|
|
||
|
StatusOr<XlaComputation> LoadSnapshot(const HloSnapshot& module);
|
||
|
|
||
|
ServiceInterface* stub() { return stub_; }
|
||
|
|
||
|
private:
|
||
|
// Returns the execution statistics (e.g., gflop/s) as a string from the
|
||
|
// ExecutionProfile returned from an execution of the computation.
|
||
|
StatusOr<std::string> ExecutionStatsAsString(
|
||
|
const XlaComputation& computation, const ExecutionProfile& profile);
|
||
|
|
||
|
StatusOr<ChannelHandle> CreateChannelHandleByType(
|
||
|
ChannelHandle::ChannelType type);
|
||
|
|
||
|
ServiceInterface* stub_; // Stub that this client is connected on.
|
||
|
|
||
|
Client(const Client&) = delete;
|
||
|
Client& operator=(const Client&) = delete;
|
||
|
};
|
||
|
|
||
|
} // namespace xla
|
||
|
|
||
|
#endif // XLA_CLIENT_CLIENT_H_
|