/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_

#include <memory>
#include <string>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/synchronization/notification.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
// API notes:
// PjRt stands for "Pretty much Just another RunTime".

namespace xla {

using PjRtPlatformId = uint64;

constexpr char kCpuName[] = "cpu";
constexpr char kGpuName[] = "gpu";
constexpr char kTpuName[] = "tpu";
static const PjRtPlatformId kCpuId = tensorflow::Fingerprint64(kCpuName);
static const PjRtPlatformId kGpuId = tensorflow::Fingerprint64(kGpuName);
static const PjRtPlatformId kTpuId = tensorflow::Fingerprint64(kTpuName);

class PjRtClient;
class PjRtDevice {
 public:
  virtual ~PjRtDevice() {}

  // Return the client that owns this device.
  virtual PjRtClient* client() const = 0;

  // Whether the client can issue commands to this device.
  virtual bool IsAddressable() const = 0;

  // The ID of this device. IDs are unique among devices of this type
  // (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all
  // hosts' devices.  This is the ID that should be used in a DeviceAssignment.
  virtual int id() const = 0;

  // The task ID of this device according to TpuTopology. This is not always
  // identical to PjRtClient::task_id() in a multi-task setting, where each
  // client can see devices from all tasks, but only a subset of them are
  // addressable and have the same task_id as the client.
  virtual int task_id() const = 0;

  // Opaque hardware ID, e.g., the CUDA device number, useful for identifying
  // which GPU when interacting with non-JAX code. In general, not guaranteed
  // to be dense, and -1 if undefined.
  virtual int local_hardware_id() const = 0;

  // A vendor-dependent string that uniquely identifies the kind of device,
  // e.g., "Tesla V100-SXM2-16GB". May be used to determine whether two GPUs
  // are compatible for compilation.
  virtual absl::string_view device_kind() const = 0;

  virtual std::string DebugString() const = 0;

  // Transfer the given literal to the infeed queue.
  virtual Status TransferToInfeed(const LiteralSlice& literal) = 0;

  // Transfer and return a value of the given shape from the outfeed queue.
  virtual Status TransferFromOutfeed(MutableBorrowingLiteral literal) = 0;
};
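
// A minimal usage sketch for the infeed/outfeed methods above, assuming
// `device` is an addressable PjRtDevice* and the running computation contains
// matching infeed/outfeed ops (error handling elided):
//
//   Literal input = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
//   TF_RETURN_IF_ERROR(device->TransferToInfeed(input));
//
//   Literal output(ShapeUtil::MakeShape(F32, {3}));
//   TF_RETURN_IF_ERROR(
//       device->TransferFromOutfeed(MutableBorrowingLiteral(&output)));
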
99 
100 // Forward declaration.
101 class PjRtBuffer;
102 
103 // Helper struct for cross host transfers, returned by the callback from a call
104 // to PjRtBuffer::MakeCrossHostReceiveBuffers.
105 struct PjRtCrossHostRecvBuffer {
106   // serialized_descriptor should be transmitted to the sender and passed to a
107   // call to src_buffer->CopyToRemoteDevice.
108   std::string serialized_descriptor;
109   // The buffer that will hold the result of the transfer.
110   std::unique_ptr<PjRtBuffer> buffer;
111 };
112 using PjRtCrossHostRecvNotifier =
113     std::function<void(StatusOr<std::vector<PjRtCrossHostRecvBuffer>>&&)>;
114 
115 struct CompileOptions {
116   // The layouts of the arguments that the computation should expect.
117   absl::optional<std::vector<Shape>> argument_layouts;
118 
119   // If true, the supplied computation expects its arguments to be wrapped in a
120   // tuple and passed as a single parameter.
121   bool parameter_is_tupled_arguments = false;
122 
123   // XLA's compilation time options.
124   ExecutableBuildOptions executable_build_options;
125 
126   // If true, the executable can be run on any device. May only be true if
127   // !executable_build_options.has_device_assignment(), so only applies to
128   // single-device executables. Beware: on GPUs, sometimes an executable
129   // compiled for one device doesn't run on another.
130   bool compile_portable_executable = false;
131 };

class PjRtExecutable;

// Encapsulates the state of a Python session with XLA.
//
// It is the responsibility of the client of this API to keep the PjRtClient
// alive as long as any of the other runtime objects are alive.
class PjRtClient {
 public:
  virtual ~PjRtClient() = default;

  // Return the task id of this client. In a single-task setting, always 0.
  virtual int task_id() const = 0;

  // Return the number of devices in the entire computation. In a multi-headed
  // client setting, some are addressable by this client and some are not. In a
  // single-client setting, this is equal to the number of addressable devices.
  virtual int device_count() const = 0;

  // Return the number of addressable devices. Addressable devices are those
  // that the client can issue commands to.
  virtual int addressable_device_count() const = 0;

  // Return all devices in the entire computation, including addressable and
  // non-addressable devices.
  virtual absl::Span<PjRtDevice* const> devices() const = 0;

  // Return only addressable devices.
  virtual absl::Span<PjRtDevice* const> addressable_devices() const = 0;

  // Lookup any PjRtDevice for a given PjRtDevice::id().
  virtual StatusOr<PjRtDevice*> LookupDevice(int device_id) const = 0;

  // Return an addressable PjRtDevice for a given
  // PjRtDevice::local_hardware_id().
  virtual StatusOr<PjRtDevice*> LookupAddressableDevice(
      int local_hardware_id) const = 0;

  // Return an ID that identifies the platform (CPU/GPU/TPU).
  virtual PjRtPlatformId platform_id() const = 0;

  // Returns a string that identifies the platform (CPU/GPU/TPU).
  virtual absl::string_view platform_name() const = 0;

  // Return a device-specific default device assignment, e.g., GPU and TPU may
  // be different.
  virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
      int num_replicas, int num_partitions) const = 0;

  // Returns a backend-specific HLO cost analysis visitor.
  virtual StatusOr<std::unique_ptr<HloCostAnalysis>> GetHloCostAnalysis() = 0;

  // Compile `computation` with the given `options`.
  virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
      const XlaComputation& computation, CompileOptions options) = 0;
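
  // A minimal sketch of compiling an already-built XlaComputation with the
  // Compile() method above; `client` is assumed to be a PjRtClient* and error
  // handling is elided:
  //
  //   CompileOptions options;
  //   options.executable_build_options.set_num_replicas(1);
  //   StatusOr<std::unique_ptr<PjRtExecutable>> executable =
  //       client->Compile(computation, std::move(options));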

  // Generates a unique fingerprint for `executable`. The result may be
  // absl::nullopt.
  virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
      const PjRtExecutable& executable) const = 0;

  // Creates a buffer on the device without initializing or copying any data.
  virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
      const Shape& shape, PjRtDevice* device) = 0;

  // Describes the semantics the caller of BufferFromHostBuffer expects from
  // the runtime, in a total order from most restrictive to least restrictive.
  enum class HostBufferSemantics {
    // The runtime may not hold references to `data` after the call to
    // `BufferFromHostBuffer` completes. The caller promises that `data` is
    // immutable and will not be freed for the duration of the
    // BufferFromHostBuffer call. `on_done_with_host_buffer` will be called
    // before `BufferFromHostBuffer` returns.
    kImmutableOnlyDuringCall,

    // The runtime may hold onto `data` after the call to `BufferFromHostBuffer`
    // returns while the runtime completes a transfer to the device. The caller
    // promises not to mutate or free `data` until the transfer completes, at
    // which point the runtime will call `on_done_with_host_buffer`. It is also
    // correct to wait on the host (directly or indirectly) for the buffer's
    // definition event to complete.
    kImmutableUntilTransferCompletes,

    // The PjRtBuffer may alias `data` internally and the runtime may use the
    // `data` contents as long as the buffer is alive. The caller promises to
    // keep `data` alive and not to mutate its contents as long as the buffer is
    // alive; to notify the caller that the buffer may be freed, the runtime
    // will call `on_done_with_host_buffer` when the PjRtBuffer is freed. On
    // non-CPU platforms this acts identically to
    // kImmutableUntilTransferCompletes.
    kZeroCopy,
  };
  // `on_done_with_host_buffer` is optional and may be null; if non-null, it
  // will be called iff an OK status is returned.
  virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
      const void* data, const Shape& shape,
      HostBufferSemantics host_buffer_semantics,
      std::function<void()> on_done_with_host_buffer, PjRtDevice* device) = 0;
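
  // A minimal sketch of BufferFromHostBuffer with the most restrictive
  // semantics (the host data only needs to outlive the call itself), assuming
  // `client` and `device` are valid:
  //
  //   std::vector<float> host_data(16, 0.0f);
  //   Shape shape = ShapeUtil::MakeShape(F32, {16});
  //   StatusOr<std::unique_ptr<PjRtBuffer>> buffer =
  //       client->BufferFromHostBuffer(
  //           host_data.data(), shape,
  //           PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
  //           /*on_done_with_host_buffer=*/nullptr, device);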

  // Note that `literal` must remain in scope until the transfer has completed,
  // so the caller should, for example, wait until BlockHostUntilReady()
  // completes on the return value before letting `literal` go out of scope.
  virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
      const LiteralSlice& literal, PjRtDevice* device) = 0;
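
  // A minimal sketch of keeping the literal alive until the transfer is done,
  // assuming `client` and `device` are valid (error handling elided):
  //
  //   Literal literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f});
  //   std::unique_ptr<PjRtBuffer> buffer =
  //       client->BufferFromHostLiteral(literal, device).ValueOrDie();
  //   TF_CHECK_OK(buffer->BlockHostUntilReady());
  //   // Only now may `literal` safely go out of scope.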

  // Creates a PjRtBuffer that is a non-owned view of an on-device
  // buffer (typically allocated by another library).
  // on_delete_callback is called when the PjRtBuffer is done with the on-device
  // buffer. The buffer may be mutated, for example, if the buffer is donated
  // to an Execute operation.
  // TODO(phawkins): Currently this API assumes the buffer is ready to use
  // immediately on the device. Extend it to support, for example, waiting for a
  // CUDA stream/event.
  virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateViewOfDeviceBuffer(
      void* device_ptr, const Shape& shape, PjRtDevice* device,
      std::function<void()> on_delete_callback) = 0;

  // Asynchronously makes a vector of PjRtBuffers that can be used to receive
  // cross host transfers using `client` on `device`. `shapes` must be the exact
  // shapes, with identical layouts, corresponding to the buffers that will be
  // sent. When resources for the transfer are available, `notifier` will be
  // called with a vector of PjRtCrossHostRecvBuffer structs, one for each
  // shape in `shapes`. Each struct contains a buffer that will contain the
  // received value, and an opaque string that should be transmitted to the
  // sending host and used in a call to CopyToRemoteDevice. None of the recv
  // buffers will become ready until *all* of the sends have completed.
  virtual void MakeCrossHostReceiveBuffers(
      absl::Span<const Shape> shapes, PjRtDevice* device,
      PjRtCrossHostRecvNotifier&& notifier) = 0;
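
  // A minimal sketch of the receive side of a cross-host transfer, assuming
  // `client`, `device`, and `shape` are valid; the sender later calls
  // src_buffer->CopyToRemoteDevice() with the descriptor transmitted
  // out-of-band:
  //
  //   std::vector<Shape> shapes = {shape};
  //   client->MakeCrossHostReceiveBuffers(
  //       shapes, device,
  //       [](StatusOr<std::vector<PjRtCrossHostRecvBuffer>>&& result) {
  //         if (!result.ok()) return;
  //         for (PjRtCrossHostRecvBuffer& recv : result.ValueOrDie()) {
  //           // Send recv.serialized_descriptor to the sending host and keep
  //           // recv.buffer; it becomes ready once all sends complete.
  //         }
  //       });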

  // Create ChannelHandles for XLA send/recv.
  virtual StatusOr<ChannelHandle> CreateChannelHandle() = 0;
  virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() = 0;
  virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() = 0;
};

// Holds a reference from Python to a tuple of device buffers. A PjRtBuffer
// can be either valid or invalid. An invalid buffer is one that has never been
// initialized, or a buffer that has been deleted (e.g., by calling Delete, or
// by donating it to a computation that aliases an input parameter to an
// output). We allow PjRtBuffer objects to outlive the underlying device
// buffers so we can decouple buffer lifetimes from the corresponding Python
// references if needed. Thread-safe.
class PjRtBuffer {
 public:
  virtual ~PjRtBuffer() = default;

  virtual const Shape& on_device_shape() const = 0;
  virtual PjRtDevice* device() const = 0;
  virtual PjRtClient* client() const = 0;

  // Returns the size of the on-device representation of this buffer in bytes.
  virtual int64 OnDeviceSizeInBytes() const = 0;

  // ExternalReference is a potentially long-lived reference held while a buffer
  // is being shared by an external framework, e.g., NumPy. A client acquires an
  // external reference by calling PjRtBuffer::AcquireExternalReference() and
  // releases it by deleting the ExternalReference. The external framework
  // should not modify the underlying buffer unless it is confident via its own
  // synchronization that modifications do not race with reads from the
  // PjRtBuffer.
  class ExternalReference {
   public:
    virtual ~ExternalReference() = 0;
    // Return an opaque device memory pointer to the root buffer.
    void* OpaqueDeviceMemoryDataPointer() const { return data_ptr_; }

   protected:
    void* data_ptr_;
  };
  virtual StatusOr<std::unique_ptr<ExternalReference>>
  AcquireExternalReference() = 0;

  // Copies the buffer's value into `literal`. Calls `on_ready` when the value
  // (or an error) is ready. The transfer respects the layout of `literal`; to
  // specify a particular layout, set the layout before calling `ToLiteral`.
  virtual void ToLiteral(MutableLiteralBase* literal,
                         std::function<void(Status)> on_ready) = 0;

  // Synchronous overload of ToLiteral, as a convenience.
  Status ToLiteral(MutableLiteralBase* literal) {
    absl::Notification done;
    Status status;
    ToLiteral(literal, [&](Status s) {
      status = std::move(s);
      done.Notify();
    });
    done.WaitForNotification();
    return status;
  }

  // Convenience synchronous overload that allocates a literal with a default
  // layout.
  StatusOr<std::shared_ptr<Literal>> ToLiteral() {
    auto literal = std::make_shared<Literal>(
        ShapeUtil::DeviceShapeToHostShape(on_device_shape()));
    TF_RETURN_IF_ERROR(ToLiteral(literal.get()));
    return literal;
  }
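
  // A minimal sketch of reading a buffer back to the host with the
  // convenience overload above, assuming `buffer` is a valid PjRtBuffer*:
  //
  //   StatusOr<std::shared_ptr<Literal>> literal = buffer->ToLiteral();
  //   if (literal.ok()) {
  //     // *literal.ValueOrDie() now holds the buffer's value in a default
  //     // host layout.
  //   }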

  // Drops the buffer's reference to its associated device memory, leaving the
  // buffer in an invalid state. The memory will be freed lazily when all async
  // operations using the buffer have completed, according to the allocation
  // semantics of the underlying platform. Delete may briefly block if another
  // thread is in the process of enqueuing an operation on this buffer, but it
  // will never block for a stream operation to complete. If an external
  // framework holds a reference to the underlying device memory via
  // AcquireExternalReference, the memory will not be freed until the external
  // framework drops the reference.
  virtual void Delete() = 0;

  // Similar to Delete, drops the buffer's reference to its associated device
  // memory, leaving the buffer in an invalid state, but transfers the device
  // memory ownership out via an ExternalReference rather than
  // freeing the device memory, so that another framework can take ownership of
  // it. A return value of nullptr indicates that the PjRtBuffer has been
  // deleted. The buffer returned from ReleaseDeviceMemoryOwnership may be
  // safely dropped at any time even if it still has pending async operations.
  // The client should call BlockHostUntilReady before calling
  // ReleaseDeviceMemoryOwnership with wait_for_operations_to_complete=false,
  // to ensure that the host has synchronized past any outstanding write
  // operations to the buffer. If wait_for_operations_to_complete=true the host
  // will block until any potentially outstanding asynchronous operations have
  // completed before returning, in which case it is safe to read or mutate the
  // returned buffer. If the buffer was shared via an external reference it is
  // the client's responsibility that accesses via that reference do not
  // interfere with accesses via the buffer returned from
  // ReleaseDeviceMemoryOwnership.
  virtual StatusOr<std::unique_ptr<ExternalReference>>
  ReleaseDeviceMemoryOwnership(bool wait_for_operations_to_complete) = 0;
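
  // A minimal sketch of handing device memory ownership to another framework,
  // assuming `buffer` is a valid PjRtBuffer* (error handling elided):
  //
  //   std::unique_ptr<ExternalReference> ref =
  //       buffer->ReleaseDeviceMemoryOwnership(
  //                 /*wait_for_operations_to_complete=*/true)
  //           .ValueOrDie();
  //   void* device_ptr = ref->OpaqueDeviceMemoryDataPointer();
  //   // The caller now owns the device memory; `buffer` is left invalid.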

  // True if and only if Delete or ReleaseDeviceMemoryOwnership has previously
  // been called.
  virtual bool IsDeleted() = 0;

  // Copies the buffer to device `dst_device`, performing a d2d transfer when
  // `dst_device` is on the same Client as this buffer, and performing a d2h
  // and h2d copy if `dst_device` lives on a different Client.
  // Returns an error if the buffer is already on dst_device.
  virtual StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
      PjRtDevice* dst_device) = 0;

  // Copies the buffer to the remote device encoded in serialized_descriptor.
  // This call must be preceded by a call to MakeCrossHostReceiveBuffers on the
  // remote host's destination device. MakeCrossHostReceiveBuffers takes an
  // array of shapes to construct the destination buffers, and a callback
  // supplies an array containing both the destination buffers, and a serialized
  // descriptor for each buffer. For each destination buffer there should be a
  // matching call to src->CopyToRemoteDevice on a remote host for a src buffer
  // of the corresponding shape. serialized_descriptor is the string returned by
  // the callback along with the corresponding destination buffer.
  virtual Status CopyToRemoteDevice(
      absl::string_view serialized_descriptor) = 0;

  // Blocks the host until the buffer's value has been computed and is ready for
  // immediate use on the device. Useful in particular for timing benchmarks.
  virtual Status BlockHostUntilReady() = 0;

  // Whether this buffer is on CPU and thus allows for certain optimizations.
  virtual bool IsOnCpu() const = 0;
};

class ExecuteContext {
 public:
  virtual ~ExecuteContext() = default;
};

struct ExecuteOptions {
  // If true, the client must pass a single PjRtBuffer which contains all of
  // the arguments as a single XLA tuple, otherwise each argument must be
  // passed in its own PjRtBuffer. May only be true if the executable was
  // compiled with parameter_is_tupled_arguments==true.
  bool arguments_are_tupled = false;
  // If true, the computation must return a tuple, which will be destructured
  // into its elements.
  bool untuple_result = false;
  // If non-zero, identifies this execution as part of a potentially
  // multi-device launch. This can be used to detect scheduling errors, e.g. if
  // multi-host programs are launched in different orders on different hosts,
  // the launch IDs may be used by the runtime to detect the mismatch.
  int32 launch_id = 0;
  // If non-null, an opaque context passed to an execution that may be used to
  // supply additional arguments to a derived class of PjRtExecutable.
  const ExecuteContext* context = nullptr;
};

// Represents a compiled computation that can be executed given handles to
// device-allocated literals. If any input/output alias has been specified in
// the computation, the parameter containing the input buffer will be donated
// when passed to the execution.
class PjRtExecutable {
 public:
  virtual ~PjRtExecutable() = default;

  virtual PjRtClient* client() const = 0;

  // Unique name for this executable, e.g., the HloModule name.
  virtual absl::string_view name() const = 0;

  virtual int num_replicas() const = 0;

  virtual int num_partitions() const = 0;

  virtual int64 SizeOfGeneratedCodeInBytes() const = 0;

  virtual const DeviceAssignment& device_assignment() const = 0;

  // The replica and partition indices of device_assignment to be run by this
  // client. On single-host platforms without partitioning, this is all
  // replicas (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may
  // not be the case on multi-host platforms. For example, with 4 replicas and
  // 2 partitions on a single-host platform, the size of
  // addressable_device_logical_ids_ is 4*2 = 8.
  struct LogicalDeviceIds {
    int replica;
    int partition;
  };
  virtual absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
      const = 0;

  // An addressable_device is one which the client can issue commands to.
  // addressable_devices()[i] is the Device to which
  // addressable_device_logical_ids()[i] is assigned.
  virtual absl::Span<PjRtDevice* const> addressable_devices() const = 0;

  // Return an (optimized) HloModule per partition.
  virtual StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
      const = 0;

  // Executes on devices addressable by the client. Requires that the
  // executable has a device_assignment and that all devices in the
  // device_assignment are addressable by the client.
  virtual StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
  Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
          const ExecuteOptions& options) = 0;
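
  // A minimal sketch of a single-device launch with Execute() above, assuming
  // `executable` was compiled with one replica and one partition and `input`
  // is a std::unique_ptr<PjRtBuffer> on the assigned device:
  //
  //   ExecuteOptions options;
  //   options.untuple_result = true;
  //   std::vector<std::vector<PjRtBuffer*>> args = {{input.get()}};
  //   auto results = executable->Execute(args, options);
  //   // On success, results.ValueOrDie()[0] holds the output buffers for
  //   // that device.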

  // Execute the assigned replica/partition on a given `device`. Requires that
  // the executable has a device_assignment, and that `device` is present in
  // the device_assignment and addressable by the client.
  virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
      absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
      const ExecuteOptions& options) = 0;

  // Execute on a given `device`. Requires `device` to be addressable by the
  // client. Requires that the executable has exactly 1 replica and 1 partition
  // and no device_assignment (thus portable).
  virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
      absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
      const ExecuteOptions& options) = 0;

  // Asynchronously free resources after the last execution completes.
  virtual void Delete() = 0;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_