1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Implementation notes:
17 //
18 // Asynchronous execution:
19 // -----------------------
20 //
21 // Computations and host-to-device transfers do not need to block the host
22 // waiting for the operation to complete but instead return control to the host
23 // immediately. This allows client logic to overlap with device-side
24 // computation.
25 //
26 // For a good user experience, we must be careful only to enqueue operations
27 // that are unlikely to fail; as a rule error checking must be done eagerly
28 // before returning control to the client.
29 //
// The degree to which the client can enqueue operations ahead of the device
// is limited by a semaphore. There are two modes: asynchronous, where we
32 // allow the client to enqueue up to 32 executions ahead of the device, and
33 // synchronous, where we limit the client to having one enqueued operation at
34 // a time. The value of 32 is arbitrary.
35 //
36 // Even in asynchronous mode, it is important that we do not permit
37 // unbounded queue-ahead. Firstly it is problematic when the user does something
38 // like the following in Python:
39 // %timeit run_computation()
// To the timeit logic, run_computation() appears to be extremely cheap since
// it is deferring all of its real work and not blocking, and so the %timeit
// will run run_computation() many (e.g., 10000) times to get better timing
// resolution, even though in reality it may be expensive. Secondly, on CPU the
// allocator is synchronized with the head of the compute stream, and we
// allocate buffers for all of the enqueued programs without any reuse (unlike
// GPU). This means that the memory usage is proportional to the queue size.
47 //
48 // Multi-stream execution:
49 // -----------------------
50 //
51 // We use a multistream execution design, where different Streams are used for
52 // host-to-device transfers, device-to-host transfers, and compute. This allows
53 // us to overlap transfers on and off the device with computation.
54 //
55 // Synchronization between streams occurs via BufferSequencingEvents that
56 // describe when the contents of a logical buffer are known to be valid on
57 // a particular stream, and when a buffer's uses have all completed.
58 //
59 // Synchronous vs asynchronous deallocation:
60 // -----------------------------------------
61 //
62 // See the comment on LocalDeviceState::AllocationModel for a discussion of the
63 // different allocation semantics on CPU, GPU, and TPU.
64 
65 #include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
66 
67 #include <cstddef>
68 #include <cstdlib>
69 #include <memory>
70 #include <string>
71 #include <vector>
72 
73 #include "absl/base/casts.h"
74 #include "absl/container/flat_hash_set.h"
75 #include "absl/container/inlined_vector.h"
76 #include "absl/memory/memory.h"
77 #include "absl/strings/str_format.h"
78 #include "absl/synchronization/mutex.h"
79 #include "absl/time/time.h"
80 #include "absl/types/optional.h"
81 #include "absl/types/span.h"
82 #include "tensorflow/compiler/xla/client/local_client.h"
83 #include "tensorflow/compiler/xla/client/xla_computation.h"
84 #include "tensorflow/compiler/xla/cpu_function_runtime.h"
85 #include "tensorflow/compiler/xla/executable_run_options.h"
86 #include "tensorflow/compiler/xla/layout.h"
87 #include "tensorflow/compiler/xla/literal.h"
88 #include "tensorflow/compiler/xla/literal_util.h"
89 #include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h"
90 #include "tensorflow/compiler/xla/pjrt/event_pool.h"
91 #include "tensorflow/compiler/xla/pjrt/local_device_state.h"
92 #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
93 #include "tensorflow/compiler/xla/pjrt/utils.h"
94 #include "tensorflow/compiler/xla/service/executable.h"
95 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
96 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
97 #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
98 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
99 #include "tensorflow/compiler/xla/shape_util.h"
100 #include "tensorflow/compiler/xla/util.h"
101 #include "tensorflow/compiler/xla/xla_data.pb.h"
102 #include "tensorflow/core/platform/cpu_info.h"
103 #include "tensorflow/core/platform/errors.h"
104 #include "tensorflow/core/platform/fingerprint.h"
105 #include "tensorflow/core/platform/mem.h"
106 #include "tensorflow/core/platform/status.h"
107 #include "tensorflow/core/platform/types.h"
108 #include "tensorflow/core/profiler/lib/connected_traceme.h"
109 #include "tensorflow/core/profiler/lib/traceme.h"
110 #include "tensorflow/core/profiler/lib/traceme_encode.h"
111 #include "tensorflow/stream_executor/device_memory.h"
112 #include "tensorflow/stream_executor/device_memory_allocator.h"
113 #include "tensorflow/stream_executor/event.h"
114 #include "tensorflow/stream_executor/host/host_platform_id.h"
115 #include "tensorflow/stream_executor/lib/statusor.h"
116 #include "tensorflow/stream_executor/stream.h"
117 
118 namespace xla {
119 
platform_id() const120 PjRtPlatformId PjRtStreamExecutorDevice::platform_id() const {
121   return client_->platform_id();
122 }
platform_name() const123 absl::string_view PjRtStreamExecutorDevice::platform_name() const {
124   return client_->platform_name();
125 }
126 
GetLocalDeviceState() const127 StatusOr<LocalDeviceState*> PjRtStreamExecutorDevice::GetLocalDeviceState()
128     const {
129   if (local_device_state_) {
130     return local_device_state_.get();
131   }
132   return InvalidArgument("Device %s is not a local device.", DebugString());
133 }
134 
DebugString() const135 std::string PjRtStreamExecutorDevice::DebugString() const {
136   return absl::StrCat(platform_name(), ":", id());
137 }
138 
DevicesToDeviceAssignment(absl::Span<const std::vector<PjRtDevice * >> devices)139 StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
140     absl::Span<const std::vector<PjRtDevice*>> devices) {
141   if (devices.empty()) {
142     return InvalidArgument(
143         "Device assignment passed to Compile() must be non-empty.");
144   }
145   if (devices[0].empty()) {
146     return InvalidArgument(
147         "Device assignment passed to Compile() must have a nonzero number of "
148         "partitions per replica; replica 0 had 0 partitions.");
149   }
150   DeviceAssignment xla_assignment(devices.size(), devices[0].size());
151   for (int replica = 0; replica < devices.size(); ++replica) {
152     if (devices[replica].size() != devices[0].size()) {
153       return InvalidArgument(
154           "Device assignment passed to Compile() has different numbers of "
155           "partitions between replicas; %d partitions for replica %d versus %d "
156           "partitions for replica 0.",
157           devices[replica].size(), replica, devices[0].size());
158     }
159     for (int partition = 0; partition < devices[replica].size(); ++partition) {
160       if (devices[0][0]->client()->platform_id() !=
161           devices[replica][partition]->client()->platform_id()) {
162         return InvalidArgument(
163             "Device assignment passed to Compile() must have devices of a "
164             "single kind, got %s for replica 0 partition 0 and %s for replica "
165             "%d partition %d.",
166             devices[0][0]->client()->platform_name(),
167             devices[replica][partition]->client()->platform_name(), replica,
168             partition);
169       }
170       xla_assignment(replica, partition) = devices[replica][partition]->id();
171     }
172   }
173   return xla_assignment;
174 }
175 
176 class CpuAllocator : public tensorflow::Allocator {
177  public:
178   CpuAllocator() = default;
179 
Name()180   std::string Name() override { return "cpu"; }
181 
AllocateRaw(size_t alignment,size_t num_bytes)182   void* AllocateRaw(size_t alignment, size_t num_bytes) override {
183     return tensorflow::port::AlignedMalloc(num_bytes, alignment);
184   }
DeallocateRaw(void * ptr)185   void DeallocateRaw(void* ptr) override {
186     return tensorflow::port::AlignedFree(ptr);
187   }
188 };
189 
DefaultThreadPoolSize()190 static int DefaultThreadPoolSize() {
191   // Google's CI system exposes an environment variable NPROC that describes
192   // a CPU reservation for tests.
193   // TODO(phawkins): expose a better thought-out set of knobs to control
194   // parallelism.
195   const char* nproc_str = std::getenv("NPROC");
196   int nproc = 0;
197   if (nproc_str && absl::SimpleAtoi(nproc_str, &nproc)) {
198     return std::max(0, nproc);
199   }
200   return tensorflow::port::MaxParallelism();
201 }
202 
PjRtStreamExecutorClient(std::string platform_name,LocalClient * client,std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,int task_id,std::unique_ptr<se::DeviceMemoryAllocator> allocator,std::unique_ptr<tensorflow::Allocator> host_memory_allocator,bool should_stage_host_to_device_transfers,std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options)203 PjRtStreamExecutorClient::PjRtStreamExecutorClient(
204     std::string platform_name, LocalClient* client,
205     std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int task_id,
206     std::unique_ptr<se::DeviceMemoryAllocator> allocator,
207     std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
208     bool should_stage_host_to_device_transfers,
209     std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options)
210     : platform_id_(tensorflow::Fingerprint64(platform_name)),
211       platform_name_(std::move(platform_name)),
212       client_(client),
213       host_memory_allocator_(std::move(host_memory_allocator)),
214       owned_devices_(std::move(devices)),
215       task_id_(task_id),
216       owned_allocator_(std::move(allocator)),
217       should_stage_host_to_device_transfers_(
218           should_stage_host_to_device_transfers),
219       gpu_run_options_(std::move(gpu_run_options)),
220       thread_pool_(
221           tensorflow::Env::Default(), "pjrt_thread_pool",
222           std::max<int>(DefaultThreadPoolSize(), client->device_count())) {
223   if (owned_allocator_ != nullptr) {
224     allocator_ = owned_allocator_.get();
225   } else {
226     allocator_ = client_->backend().memory_allocator();
227   }
228 
229   if (!host_memory_allocator_) {
230     host_memory_allocator_ = std::make_unique<CpuAllocator>();
231   }
232 
233   for (const std::unique_ptr<PjRtStreamExecutorDevice>& device :
234        owned_devices_) {
235     devices_.push_back(device.get());
236     CHECK(id_to_device_.insert({device->id(), device.get()}).second)
237         << "Duplicate device id: " << device->id();
238 
239     if (device->IsAddressable()) {
240       int idx = device->local_hardware_id();
241       if (idx >= addressable_devices_.size()) {
242         addressable_devices_.resize(idx + 1);
243       }
244       CHECK(addressable_devices_[idx] == nullptr) << idx;
245       addressable_devices_[idx] = device.get();
246     }
247     device->SetClient(this);
248   }
249   for (int idx = 0; idx < addressable_devices_.size(); ++idx) {
250     CHECK(addressable_devices_[idx] != nullptr) << idx;
251   }
252 }
253 
GetDefaultDeviceAssignment(int num_replicas,int num_partitions) const254 StatusOr<DeviceAssignment> PjRtStreamExecutorClient::GetDefaultDeviceAssignment(
255     int num_replicas, int num_partitions) const {
256   return client_->backend().computation_placer()->AssignDevices(num_replicas,
257                                                                 num_partitions);
258 }
259 
260 StatusOr<std::unique_ptr<HloCostAnalysis>>
GetHloCostAnalysis()261 PjRtStreamExecutorClient::GetHloCostAnalysis() {
262   return absl::make_unique<HloCostAnalysis>(
263       client_->backend().compiler()->ShapeSizeBytesFunction());
264 }
265 
266 namespace {
267 
268 // Ensures that it is safe to deallocate any buffers that have been enqueued in
269 // an operation on stream. Called only in rare error cases that are triggered
270 // during enqueue. These cases generally correspond to resource exhaustion.
StallStreamOnError(LocalDeviceState * local_device,se::Stream * stream)271 void StallStreamOnError(LocalDeviceState* local_device, se::Stream* stream) {
272   switch (local_device->allocation_model()) {
273     case LocalDeviceState::kAsynchronous:
274       // We can safely deallocate any dangling buffers immediately. NOTE: this
275       // assumes that any buffers enqueued on stream are local to stream's
276       // executor, and manual action may be needed if that condition is not met.
277       break;
278 
279     case LocalDeviceState::kComputeSynchronized:
280       // This will stall computation but that's ok in this very rare error
281       // case.
282       if (stream != local_device->compute_stream()) {
283         local_device->compute_stream()->ThenWaitFor(stream);
284       }
285       break;
286 
287     case LocalDeviceState::kSynchronous:
288       // This will stall the calling thread but that's ok in this very rare
289       // error case. If the stall fails just crash, since we have no other
290       // way to synchronize.
291       TF_CHECK_OK(stream->BlockHostUntilDone());
292       break;
293   }
294 }
295 
296 // Does all necessary bookkeeping, after a buffer is successfully enqueued onto
297 // a stream, to ensure that the buffer will be kept alive until its use on that
298 // stream is complete.
299 //
300 //   device_buffer:              the buffer that was enqueued.
301 //   buffer_local_device:        the device the buffer was allocated on.
302 //   stream_local_device:        the device that manages usage_stream.
303 //   event:                      an event that was recorded on usage_stream
304 //                               after the usage of device_buffer was enqueued.
305 //   usage_stream:               the stream the operation using device_buffer
306 //                               was enqueued on.
307 //   prefer_to_retain_reference: relevant only for the compute synchronous
308 //                               allocation model. If true, retain a reference
309 //                               to device_buffer until after the operation
310 //                               completes. If false then the compute stream
311 //                               will have to be synchronized past event before
312 //                               device_buffer can be freed.
313 //
314 // prefer_to_retain_reference encodes a heuristic set by the caller for the
315 // compute synchronous model:
316 //
317 // Generally when a buffer is the destination of a copy to a device, it will
318 // subsequently be used on the device's compute stream before being freed. In
319 // that case, there is no need to retain a reference to the buffer. If the
320 // buffer is freed before being used on the compute stream, the free will be
321 // delayed until the host knows that event has completed, but this is expected
322 // to be uncommon.
323 //
324 // When a buffer is the source of a copy from a device, we need to either retain
325 // a reference to the buffer until the copy completes or serialize the compute
326 // stream behind the copy. It is often better to retain a reference since while
327 // that keeps memory alive longer, it avoids stalling the compute stream.
RecordUsage(PjRtStreamExecutorBuffer::ScopedHold device_buffer,LocalDeviceState * buffer_local_device,LocalDeviceState * stream_local_device,std::shared_ptr<BufferSequencingEvent> event,se::Stream * usage_stream,bool prefer_to_retain_reference)328 void RecordUsage(PjRtStreamExecutorBuffer::ScopedHold device_buffer,
329                  LocalDeviceState* buffer_local_device,
330                  LocalDeviceState* stream_local_device,
331                  std::shared_ptr<BufferSequencingEvent> event,
332                  se::Stream* usage_stream, bool prefer_to_retain_reference) {
333   bool retain_buffer_until_completion =
334       // If the buffer wasn't allocated on the same device as the stream, always
335       // retain a reference.
336       (stream_local_device != buffer_local_device) ||
337       // In the synchronous allocation model, always retain a reference.
338       (stream_local_device->allocation_model() ==
339        LocalDeviceState::kSynchronous) ||
340       // In the compute synchronous model, use the caller's heuristic.
341       (stream_local_device->allocation_model() ==
342            LocalDeviceState::kComputeSynchronized &&
343        prefer_to_retain_reference);
344   if (retain_buffer_until_completion) {
345     buffer_local_device->ThenRelease(usage_stream, device_buffer.buffer());
346   }
347   device_buffer.ConvertUsageHold(usage_stream, event,
348                                  retain_buffer_until_completion);
349 }
350 
351 // Allocates the device buffers for a buffer that will be used as the
352 // destination of a copy, either from the host or another device. copy_stream
353 // may be nullptr, e.g., when allocating a buffer for a cross-host copy. If the
354 // buffer is a tuple then the tuple tables are allocated, and all necessary
355 // synchronization for them is dealt with, before the buffer is returned.
356 //
357 // It is safe to delete the returned PjRtBuffer without further
358 // synchronization if an error occurs before the buffer is used.
359 //
360 // The caller may optionally provide a definition event to be recorded in
361 // the buffer.
362 // TODO(phawkins): replace on_host_shape here with on_device_shape.
AllocateDestinationBuffer(const Shape & on_host_shape,PjRtDevice * device,LocalDeviceState * local_device,se::Stream * copy_stream,bool is_uninitialized_create,PjRtClient * client,std::shared_ptr<BufferSequencingEvent> definition_event=nullptr)363 StatusOr<std::unique_ptr<PjRtStreamExecutorBuffer>> AllocateDestinationBuffer(
364     const Shape& on_host_shape, PjRtDevice* device,
365     LocalDeviceState* local_device, se::Stream* copy_stream,
366     bool is_uninitialized_create, PjRtClient* client,
367     std::shared_ptr<BufferSequencingEvent> definition_event = nullptr) {
368   if (on_host_shape.IsTuple() && on_host_shape.tuple_shapes_size() == 0) {
369     return InvalidArgument("Can't make a buffer from an empty tuple");
370   }
371 
372   auto* se_client = tensorflow::down_cast<PjRtStreamExecutorClient*>(client);
373   TransferManager* transfer_manager =
374       se_client->client()->backend().transfer_manager();
375   TF_ASSIGN_OR_RETURN(ScopedShapedBuffer dst_buffer,
376                       transfer_manager->AllocateScopedShapedBuffer(
377                           on_host_shape, se_client->allocator(),
378                           local_device->device_ordinal()));
379   if (local_device->allocation_model() ==
380       LocalDeviceState::kComputeSynchronized) {
381     if (copy_stream == nullptr) {
382       CHECK(is_uninitialized_create);
383     } else {
384       copy_stream->ThenWaitFor(local_device->compute_stream());
385     }
386   } else {
387     DCHECK(transfer_manager->CanShapedBufferBeAccessedNow(
388         local_device->compute_stream()->parent(), dst_buffer));
389   }
390   Shape on_device_shape = dst_buffer.on_device_shape();
391 
392   absl::InlinedVector<std::shared_ptr<BufferSequencingEvent>, 2>
393       definition_events;
394   if (is_uninitialized_create) {
395     // There is not going to be any copy into the buffer so in general we don't
396     // need a definition event.
397     if (local_device->allocation_model() ==
398         LocalDeviceState::kComputeSynchronized) {
399       // The allocation is not valid until the compute stream passes this point,
400       // so add a definition event in the compute stream.
401       definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
402       TF_ASSIGN_OR_RETURN(EventPool::Handle event,
403                           local_device->event_pool().ThenAllocateAndRecordEvent(
404                               local_device->compute_stream()));
405       definition_events.back()->SetSequencingEvent(
406           std::move(event), local_device->compute_stream());
407     }
408     // if the caller provided a definition event then we record that.
409     if (definition_event) {
410       definition_events.emplace_back(definition_event);
411     }
412   } else {
413     // We have at least one definition event, for the copy completing to
414     // the device buffers.
415     if (definition_event) {
416       definition_events.emplace_back(definition_event);
417     } else {
418       definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
419     }
420   }
421   se::Stream* tuple_table_stream = local_device->host_to_device_stream();
422   if (on_device_shape.IsTuple()) {
423     // We also need to copy the tuple tables, so we'll have an additional
424     // definition event for that copy to complete.
425     if (tuple_table_stream != copy_stream) {
426       if (local_device->allocation_model() ==
427           LocalDeviceState::kComputeSynchronized) {
428         tuple_table_stream->ThenWaitFor(local_device->compute_stream());
429       } else {
430         DCHECK(transfer_manager->CanShapedBufferBeAccessedNow(
431             local_device->compute_stream()->parent(), dst_buffer));
432       }
433     }
434 
435     TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync(
436         tuple_table_stream, dst_buffer));
437     // CAUTION: From this point onwards we need to be careful about returning
438     // from error cases because we have started a transfer and must not allow
439     // dst_buffer to be freed too soon in the non-async allocation models.
440 
441     definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
442     StatusOr<EventPool::Handle> event_or =
443         local_device->event_pool().ThenAllocateAndRecordEvent(
444             tuple_table_stream);
445     if (!event_or.ok()) {
446       StallStreamOnError(local_device, tuple_table_stream);
447       return event_or.status();
448     }
449     definition_events.back()->SetSequencingEvent(event_or.ConsumeValueOrDie(),
450                                                  tuple_table_stream);
451   }
452   std::shared_ptr<TrackedDeviceBuffer> dst_device_buffer =
453       TrackedDeviceBuffer::FromScopedShapedBuffer(&dst_buffer,
454                                                   definition_events);
455 
456   auto py_buffer = absl::make_unique<PjRtStreamExecutorBuffer>(
457       on_device_shape, std::move(dst_device_buffer), client, device);
458 
459   if (on_device_shape.IsTuple()) {
460     // Add a usage hold for the tuple table write and immediately convert it to
461     // the appropriate form of synchronization. prefer_to_retain_reference=false
462     // means don't retain a memory reference until the transfer is complete when
463     // using the ComputeSynchronized allocation model. This is a heuristic
464     // because in the common case destination buffers will be used on the
465     // compute stream and therefore don't require any synchronization before
466     // being freed. If the buffer is allocated and never used, the free will
467     // take longer and this is assumed to be ok.
468     RecordUsage(py_buffer->GetBufferWithUsageHold(), local_device, local_device,
469                 definition_events.back(), tuple_table_stream,
470                 /*prefer_to_retain_reference=*/false);
471   }
472 
473   return py_buffer;
474 }
475 
476 // Adds necessary synchronization after a copy has been enqueued to a buffer.
477 // definition_event was added when the buffer was allocated, but has not yet
478 // had an event recorded.
AddDestinationBufferSynchronization(LocalDeviceState * local_device,PjRtStreamExecutorBuffer::ScopedHold device_buffer,std::shared_ptr<BufferSequencingEvent> definition_event,se::Stream * copy_stream)479 Status AddDestinationBufferSynchronization(
480     LocalDeviceState* local_device,
481     PjRtStreamExecutorBuffer::ScopedHold device_buffer,
482     std::shared_ptr<BufferSequencingEvent> definition_event,
483     se::Stream* copy_stream) {
484   StatusOr<EventPool::Handle> event_or =
485       local_device->event_pool().ThenAllocateAndRecordEvent(copy_stream);
486   if (!event_or.ok()) {
487     StallStreamOnError(local_device, copy_stream);
488     return event_or.status();
489   }
490   definition_event->SetSequencingEvent(event_or.ConsumeValueOrDie(),
491                                        copy_stream);
492   // prefer_to_retain_reference=false means don't retain a memory reference
493   // until the transfer is complete when using the ComputeSynchronized
494   // allocation model. This is a heuristic because in the common case
495   // destination buffers will be used on the compute stream and therefore don't
496   // require any synchronization before being freed. If the buffer is allocated
497   // and never used, the free will take longer and this is assumed to be ok.
498   RecordUsage(std::move(device_buffer), local_device, local_device,
499               definition_event, copy_stream,
500               /*prefer_to_retain_reference=*/false);
501   return Status::OK();
502 }
503 
504 }  // namespace
505 
~ScopedHold()506 PjRtStreamExecutorBuffer::ScopedHold::~ScopedHold() {
507   if (ok()) {
508     parent_->DropHold(type_, buffer().get());
509   }
510 }
511 
ScopedHold(ScopedHold && other)512 PjRtStreamExecutorBuffer::ScopedHold::ScopedHold(ScopedHold&& other)
513     : parent_(other.parent_),
514       type_(other.type_),
515       state_(other.state_),
516       status_(std::move(other.status_)),
517       buffer_(std::move(other.buffer_)) {
518   // Preserve the invariant that status is invalid if buffer == nullptr.
519   other.SetState(kMoved);
520 }
521 
Acquire(StatusOr<std::shared_ptr<TrackedDeviceBuffer>> && buffer_or)522 void PjRtStreamExecutorBuffer::ScopedHold::Acquire(
523     StatusOr<std::shared_ptr<TrackedDeviceBuffer>>&& buffer_or) {
524   CHECK(!ok());
525   if (buffer_or.ok()) {
526     buffer_ = buffer_or.ValueOrDie();
527     SetState(kValid);
528   } else {
529     status_ = buffer_or.status();
530     buffer_ = nullptr;
531     SetState(kError);
532   }
533   // Check the invariant holds.
534   CHECK(!ok() || buffer_ != nullptr);
535 }
536 
537 PjRtStreamExecutorBuffer::ScopedHold::ForClosure
ToClosure()538 PjRtStreamExecutorBuffer::ScopedHold::ToClosure() {
539   CHECK(ok());
540   ForClosure for_closure(parent_, type_, state_, std::move(status_),
541                          std::move(buffer_));
542   SetState(kReleased);
543   return for_closure;
544 }
545 
ConvertUsageHold(se::Stream * usage_stream,std::shared_ptr<BufferSequencingEvent> event,bool reference_held)546 void PjRtStreamExecutorBuffer::ScopedHold::ConvertUsageHold(
547     se::Stream* usage_stream, std::shared_ptr<BufferSequencingEvent> event,
548     bool reference_held) {
549   CHECK(ok());
550   CHECK_EQ(type_, kUsage);
551   parent_->ConvertUsageHold(buffer().get(), usage_stream, std::move(event),
552                             reference_held);
553   SetState(kConverted);
554 }
555 
ConfirmDonation()556 void PjRtStreamExecutorBuffer::ScopedHold::ConfirmDonation() {
557   CHECK(ok());
558   CHECK_EQ(type_, kDonation);
559   parent_->ConfirmDonation(buffer().get());
560   SetState(kDonated);
561 }
562 
AddToInput(ShapeTree<MaybeOwningDeviceMemory>::iterator * iterator,const ShapeTree<MaybeOwningDeviceMemory>::iterator & end,ExecutionInput * execution_input,se::DeviceMemoryAllocator * allocator) const563 void PjRtStreamExecutorBuffer::ScopedHold::AddToInput(
564     ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
565     const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
566     ExecutionInput* execution_input,
567     se::DeviceMemoryAllocator* allocator) const {
568   CHECK(ok());
569   if (type_ == kDonation) {
570     buffer()->AddToInputAsDonated(iterator, end, execution_input, allocator);
571   } else {
572     CHECK_EQ(type_, kUsage);
573     buffer()->AddToInputAsImmutable(iterator, end);
574   }
575 }
576 
IsOnCpu() const577 bool PjRtStreamExecutorBuffer::IsOnCpu() const {
578   return client()->platform_id() == kCpuId;
579 }
580 
581 namespace {
582 
583 // Implements PjRtBuffer::ExternalReference as a wrapped
584 // ScopedHold::kExternalReference.
585 class ScopedHoldAsExternalReference : public PjRtBuffer::ExternalReference {
586  public:
ScopedHoldAsExternalReference(PjRtStreamExecutorBuffer::ScopedHold hold)587   explicit ScopedHoldAsExternalReference(
588       PjRtStreamExecutorBuffer::ScopedHold hold)
589       : external_reference_(std::move(hold)) {
590     CHECK(external_reference_.type() ==
591           PjRtStreamExecutorBuffer::ScopedHold::kExternalReference);
592     data_ptr_ = external_reference_->device_memory().front().opaque();
593   }
594 
595   ~ScopedHoldAsExternalReference() override = default;
596 
597  private:
598   PjRtStreamExecutorBuffer::ScopedHold external_reference_;
599 };
600 
601 }  // namespace
602 
603 StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>>
AcquireExternalReference()604 PjRtStreamExecutorBuffer::AcquireExternalReference() {
605   ScopedHold hold = GetBufferWithExternalReference();
606   Status hold_status = hold.status();
607   if (!hold_status.ok()) return hold_status;
608   return std::unique_ptr<ExternalReference>(
609       std::make_unique<ScopedHoldAsExternalReference>(std::move(hold)));
610 }
611 
612 class TrackedDeviceBufferExternalReference
613     : public PjRtBuffer::ExternalReference {
614  public:
TrackedDeviceBufferExternalReference(std::shared_ptr<TrackedDeviceBuffer> tracked_device_buffer)615   explicit TrackedDeviceBufferExternalReference(
616       std::shared_ptr<TrackedDeviceBuffer> tracked_device_buffer)
617       : tracked_device_buffer_(std::move(tracked_device_buffer)) {
618     data_ptr_ = tracked_device_buffer_->device_memory()[0].opaque();
619   }
620 
621   ~TrackedDeviceBufferExternalReference() override = default;
622 
623  private:
624   std::shared_ptr<TrackedDeviceBuffer> tracked_device_buffer_;
625 };
626 
627 StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>>
ReleaseDeviceMemoryOwnership(bool wait_for_operations_to_complete)628 PjRtStreamExecutorBuffer::ReleaseDeviceMemoryOwnership(
629     bool wait_for_operations_to_complete) {
630   if (on_device_shape_.IsTuple()) {
631     return InvalidArgument(
632         "ReleaseDeviceMemoryOwnership allowed only for non-tuple");
633   }
634   TF_ASSIGN_OR_RETURN(
635       std::shared_ptr<TrackedDeviceBuffer> tracked_device_buffer,
636       Release(wait_for_operations_to_complete));
637 
638   std::unique_ptr<PjRtBuffer::ExternalReference> ref;
639   if (tracked_device_buffer) {
640     ref = std::make_unique<TrackedDeviceBufferExternalReference>(
641         std::move(tracked_device_buffer));
642   }
643   return ref;
644 }
645 
646 StatusOr<std::unique_ptr<PjRtBuffer>>
BufferFromHostBuffer(const void * data,const Shape & shape,HostBufferSemantics host_buffer_semantics,std::function<void ()> on_done_with_host_buffer,PjRtDevice * device)647 PjRtStreamExecutorClient::BufferFromHostBuffer(
648     const void* data, const Shape& shape,
649     HostBufferSemantics host_buffer_semantics,
650     std::function<void()> on_done_with_host_buffer, PjRtDevice* device) {
651   tensorflow::profiler::TraceMe traceme(
652       "PjRtStreamExecutorClient::BufferFromHostBuffer");
653   VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostBuffer: shape: "
654           << shape.ToString() << " device: " << device->DebugString();
655   if (shape.IsTuple()) {
656     return InvalidArgument("Use BufferFromHostLiteral to transfer a tuple");
657   }
658   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
659                       tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
660                           ->GetLocalDeviceState());
661   int64 size = ShapeUtil::ByteSizeOf(shape);
662 
663   TransferManager* transfer_manager = client()->backend().transfer_manager();
664   TF_ASSIGN_OR_RETURN(Shape compact_shape,
665                       transfer_manager->ChooseCompactLayoutForShape(shape));
666 
667   // The CPU platform is special because the "host" and the "device" are in the
668   // same memory space. If the input shape is in the correct layout and we don't
669   // want to defer the copy onto a thread, we can use the following fast
670   // path.
671   bool is_cpu_platform =
672       local_device->executor()->platform()->id() == se::host::kHostPlatformId;
673   if (is_cpu_platform) {
674     // If we are on the host platform and the input buffer is sufficiently
675     // aligned, we can simply point to the input array's data without any
676     // further copies. At the time of writing we require a 16-byte alignment
677     // because XLA may generate code which requires it.
678     bool can_use_zero_copy =
679         host_buffer_semantics == HostBufferSemantics::kZeroCopy &&
680         ((absl::bit_cast<std::uintptr_t>(data) &
681           (cpu_function_runtime::kMinAlign - 1)) == 0);
682     if (shape.layout() == compact_shape.layout() &&
683         (host_buffer_semantics ==
684              HostBufferSemantics::kImmutableOnlyDuringCall ||
685          can_use_zero_copy)) {
686       std::function<void()> on_delete_callback;
687       se::DeviceMemoryBase buffer;
688       // If we are on the host platform and the input buffer is sufficiently
689       // aligned, we can simply point to the input array's data without any
690       // further copies. At the time of writing we require a 16-byte alignment
691       // because XLA may generate code which requires it.
692       if (can_use_zero_copy) {
693         on_delete_callback = std::move(on_done_with_host_buffer);
694         buffer = se::DeviceMemoryBase(const_cast<void*>(data), size);
695       } else {
696         void* staging_buffer = host_memory_allocator()->AllocateRaw(
697             cpu_function_runtime::kMinAlign, size);
698         buffer = se::DeviceMemoryBase(staging_buffer, size);
699         std::memcpy(staging_buffer, data, size);
700         if (on_done_with_host_buffer) {
701           on_done_with_host_buffer();
702         }
703         on_delete_callback = [staging_buffer, host_memory_allocator =
704                                                   host_memory_allocator()]() {
705           host_memory_allocator->DeallocateRaw(staging_buffer);
706         };
707       }
708       absl::Span<const std::shared_ptr<BufferSequencingEvent>>
709           definition_events;
710       auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
711           /*allocator=*/nullptr, local_device->device_ordinal(),
712           std::initializer_list<se::DeviceMemoryBase>{buffer},
713           definition_events, std::move(on_delete_callback));
714       return std::unique_ptr<PjRtBuffer>(
715           std::make_unique<PjRtStreamExecutorBuffer>(
716               shape, std::move(device_buffer), this, device));
717     }
718   }
719 
720   TF_ASSIGN_OR_RETURN(
721       std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
722       AllocateDestinationBuffer(compact_shape, device, local_device,
723                                 local_device->host_to_device_stream(),
724                                 /*is_uninitialized_create=*/false, this));
725 
726   PjRtStreamExecutorBuffer::ScopedHold device_buffer(
727       py_buffer->GetBufferWithUsageHold());
728   CHECK(device_buffer.ok());
729 
730   // If necessary, allocate a host-side buffer for staging host-to-device
731   // transfers. On GPU this is a buffer in pinned memory.
732   std::shared_ptr<void> staging_buffer;
733   if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall ||
734       should_stage_host_to_device_transfers()) {
735     void* ptr = host_memory_allocator()->AllocateRaw(
736         tensorflow::Allocator::kAllocatorAlignment, size);
737     staging_buffer = std::shared_ptr<void>(
738         ptr, [host_memory_allocator = host_memory_allocator()](void* ptr) {
739           host_memory_allocator->DeallocateRaw(ptr);
740         });
741   }
742 
743   // Copy the buffer into a staging buffer before returning control to the
744   // caller if the caller only guaranteed that the buffer is valid for the
745   // duration of the call. Otherwise, we stage (if necessary) on a separate
746   // thread.
747   if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall) {
748     std::memcpy(staging_buffer.get(), data, size);
749     if (on_done_with_host_buffer) {
750       on_done_with_host_buffer();
751       on_done_with_host_buffer = nullptr;
752     }
753     data = nullptr;
754   }
755 
756   // The host to device transfer is performed on a thread pool, mostly because
757   // it includes linearization that may be slow. It is OK to capture the
758   // py_buffer pointer because the py_buffer can't be deleted until all the
759   // usage holds have gone away.
760   // TODO(misard) assess if it would be preferable to introduce a heuristic to
761   // put the transfer into the calling thread for small literals.
762   auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
763                        data, size,
764                        movable_device_buffer{device_buffer.ToClosure()}, shape,
765                        py_buffer{py_buffer.get()},
766                        on_device_shape{py_buffer->on_device_shape()},
767                        staging_buffer{std::move(staging_buffer)},
768                        on_done_with_host_buffer{
769                            std::move(on_done_with_host_buffer)},
770                        host_buffer_semantics]() {
771     PjRtStreamExecutorBuffer::ScopedHold device_buffer(movable_device_buffer);
772     // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
773     // to report failures from a callback. However, the operations here are
774     // unlikely to fail and not recoverable even if we were to fail: DMAs to
775     // memory that has already been allocated, and a possible Event
776     // allocation.
777 
778     ShapedBuffer buffer = device_buffer->AsShapedBuffer(on_device_shape);
779     // If applicable on the backend, stage the transfer via host memory
780     // allocated via the host_memory_allocator. On GPU, this is pinned
781     // memory.
782     if (staging_buffer) {
783       // If we didn't already copy the input buffer into the staging buffer,
784       // do so now.
785       if (host_buffer_semantics !=
786           HostBufferSemantics::kImmutableOnlyDuringCall) {
787         std::memcpy(staging_buffer.get(), data, size);
788       }
789       BorrowingLiteral literal(static_cast<const char*>(staging_buffer.get()),
790                                shape);
791       TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
792           local_device->host_to_device_stream(), literal, buffer));
793     } else {
794       BorrowingLiteral literal(static_cast<const char*>(data), shape);
795       // Otherwise, just transfer the literal.
796       TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
797           local_device->host_to_device_stream(), literal, buffer));
798     }
799 
800     std::shared_ptr<BufferSequencingEvent> event =
801         device_buffer->definition_events()[0];
802     TF_CHECK_OK(AddDestinationBufferSynchronization(
803         local_device, std::move(device_buffer), event,
804         local_device->host_to_device_stream()));
805 
806     local_device->callback_stream()->ThenWaitFor(
807         local_device->host_to_device_stream());
808     local_device->ThenExecuteOnCallbackThread(
809         local_device->callback_stream(),
810         [staging_buffer{std::move(staging_buffer)},
811          on_done_with_host_buffer{std::move(on_done_with_host_buffer)}]() {
812           if (on_done_with_host_buffer) {
813             on_done_with_host_buffer();
814           }
815         });
816   };
817   if (is_cpu_platform) {
818     // Using the thread_pool would be a double thread hop; the code
819     // already defers its work onto a stream (= thread on CPU).
820     transfer_h2d();
821   } else {
822     thread_pool()->Schedule(transfer_h2d);
823   }
824   return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
825 }
826 
// Convenience overload of CreateUninitializedBuffer that passes a null third
// argument to the three-argument form.
StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::CreateUninitializedBuffer(const Shape& shape,
                                                    PjRtDevice* device) {
  // The null argument stands in for the definition event.
  return CreateUninitializedBuffer(shape, device, nullptr);
}
832 
833 StatusOr<std::unique_ptr<PjRtBuffer>>
CreateUninitializedBuffer(const Shape & shape,PjRtDevice * device,std::shared_ptr<BufferSequencingEvent> definition_event)834 PjRtStreamExecutorClient::CreateUninitializedBuffer(
835     const Shape& shape, PjRtDevice* device,
836     std::shared_ptr<BufferSequencingEvent> definition_event) {
837   tensorflow::profiler::TraceMe traceme(
838       "PjRtStreamExecutorClient::CreateUninitializedBuffer");
839   VLOG(2) << "PjRtStreamExecutorClient::CreateUninitializedBuffer: shape: "
840           << shape.ToString() << " device: " << device->DebugString();
841   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
842                       tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
843                           ->GetLocalDeviceState());
844 
845   TransferManager* transfer_manager = client()->backend().transfer_manager();
846   TF_ASSIGN_OR_RETURN(Shape compact_shape,
847                       transfer_manager->ChooseCompactLayoutForShape(shape));
848 
849   TF_ASSIGN_OR_RETURN(
850       std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
851       AllocateDestinationBuffer(compact_shape, device, local_device,
852                                 /*copy_stream=*/nullptr,
853                                 /*is_uninitialized_create=*/true, this,
854                                 definition_event));
855   return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
856 }
857 
858 StatusOr<std::unique_ptr<PjRtBuffer>>
BufferFromHostLiteral(const LiteralSlice & literal,PjRtDevice * device)859 PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal,
860                                                 PjRtDevice* device) {
861   tensorflow::profiler::TraceMe traceme(
862       "PjRtStreamExecutorClient::BufferFromHostLiteral");
863   VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostLiteral: shape: "
864           << literal.shape().ToString() << " device: " << device->DebugString();
865   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
866                       tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
867                           ->GetLocalDeviceState());
868 
869   TransferManager* transfer_manager = client()->backend().transfer_manager();
870   TF_ASSIGN_OR_RETURN(
871       Shape compact_shape,
872       transfer_manager->ChooseCompactLayoutForShape(literal.shape()));
873   TF_ASSIGN_OR_RETURN(
874       std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
875       AllocateDestinationBuffer(compact_shape, device, local_device,
876                                 local_device->host_to_device_stream(),
877                                 /*is_uninitialized_create=*/false, this));
878 
879   PjRtStreamExecutorBuffer::ScopedHold device_buffer(
880       py_buffer->GetBufferWithUsageHold());
881   CHECK(device_buffer.ok());
882 
883   // The host to device transfer is performed on a thread pool, mostly because
884   // it includes linearization that may be slow. It is OK to capture the
885   // py_buffer pointer because the py_buffer can't be deleted until all the
886   // usage holds have gone away.
887   // TODO(misard) assess if it would be preferable to introduce a heuristic to
888   // put the transfer into the calling thread for small literals.
889   auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
890                        movable_device_buffer{device_buffer.ToClosure()},
891                        literal, py_buffer{py_buffer.get()},
892                        on_device_shape{py_buffer->on_device_shape()}]() {
893     PjRtStreamExecutorBuffer::ScopedHold device_buffer(movable_device_buffer);
894     // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
895     // to report failures from a callback. However, the operations here are
896     // unlikely to fail and not recoverable even if we were to fail: DMAs to
897     // memory that has already been allocated, and a possible Event
898     // allocation.
899 
900     se::Stream* h2d_stream = local_device->host_to_device_stream();
901     ShapedBuffer buffer = device_buffer->AsShapedBuffer(on_device_shape);
902     TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
903         h2d_stream, literal, buffer));
904 
905     std::shared_ptr<BufferSequencingEvent> event =
906         device_buffer->definition_events()[0];
907     TF_CHECK_OK(AddDestinationBufferSynchronization(
908         local_device, std::move(device_buffer), event, h2d_stream));
909 
910     // This can sometimes catch the case where the literal memory has been
911     // freed before the H2D transfer was issued.
912     h2d_stream->RefreshStatus()
913         .IgnoreError();  // Can return error::Unimplemented
914     QCHECK(h2d_stream->ok());
915   };
916   thread_pool()->Schedule(transfer_h2d);
917   return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
918 }
919 
MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes,PjRtDevice * device,PjRtCrossHostRecvNotifier && notifier)920 void PjRtStreamExecutorClient::MakeCrossHostReceiveBuffers(
921     absl::Span<const Shape> shapes, PjRtDevice* device,
922     PjRtCrossHostRecvNotifier&& notifier) {
923   if (shapes.empty()) {
924     notifier(InvalidArgument(
925         "shapes parameter empty in MakeCrossHostReceiveBuffers"));
926     return;
927   }
928 
929   auto local_device_or =
930       tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
931           ->GetLocalDeviceState();
932   if (!local_device_or.ok()) {
933     notifier(local_device_or.status());
934     return;
935   }
936   LocalDeviceState* local_device = local_device_or.ConsumeValueOrDie();
937   std::shared_ptr<BufferSequencingEvent> definition_event =
938       std::make_shared<BufferSequencingEvent>();
939   std::vector<std::unique_ptr<PjRtBuffer>> buffers;
940   buffers.reserve(shapes.size());
941   for (const auto& shape : shapes) {
942     StatusOr<std::unique_ptr<PjRtBuffer>> buffer_or = AllocateDestinationBuffer(
943         shape, device, local_device,
944         /*copy_stream=*/nullptr,
945         /*is_uninitialized_create=*/false, this, definition_event);
946     if (!buffer_or.ok()) {
947       notifier(buffer_or.status());
948       return;
949     }
950     buffers.push_back(buffer_or.ConsumeValueOrDie());
951   }
952 
953   EnqueueCrossHostReceive(std::move(buffers), std::move(definition_event),
954                           std::move(notifier));
955 }
956 
957 StatusOr<std::unique_ptr<PjRtBuffer>>
CreateViewOfDeviceBuffer(void * device_ptr,const Shape & shape,PjRtDevice * device,std::function<void ()> on_delete_callback)958 PjRtStreamExecutorClient::CreateViewOfDeviceBuffer(
959     void* device_ptr, const Shape& shape, PjRtDevice* device,
960     std::function<void()> on_delete_callback) {
961   se::DeviceMemoryBase buffer(device_ptr, ShapeUtil::ByteSizeOf(shape));
962   absl::Span<const std::shared_ptr<BufferSequencingEvent>> definition_events;
963   auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
964       /*allocator=*/nullptr, device->local_hardware_id(),
965       std::initializer_list<se::DeviceMemoryBase>{buffer}, definition_events,
966       std::move(on_delete_callback));
967   return std::unique_ptr<PjRtBuffer>(std::make_unique<PjRtStreamExecutorBuffer>(
968       shape, std::move(device_buffer), this, device));
969 }
970 
971 // Transfer the given literal to the infeed queue of the given local device.
TransferToInfeed(const LiteralSlice & literal)972 Status PjRtStreamExecutorDevice::TransferToInfeed(const LiteralSlice& literal) {
973   // Only support infeed to local device.
974   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
975   return local_device->client()->TransferToInfeedLocal(
976       literal, local_device->device_ordinal());
977 }
978 
TransferFromOutfeed(MutableBorrowingLiteral literal)979 Status PjRtStreamExecutorDevice::TransferFromOutfeed(
980     MutableBorrowingLiteral literal) {
981   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
982   return local_device->client()->TransferFromOutfeedLocal(
983       local_device->device_ordinal(), literal);
984 }
985 
LookupAddressableDevice(int local_hardware_id) const986 StatusOr<PjRtDevice*> PjRtStreamExecutorClient::LookupAddressableDevice(
987     int local_hardware_id) const {
988   for (auto* device : addressable_devices_) {
989     if (local_hardware_id == device->local_hardware_id()) {
990       return device;
991     }
992   }
993   return InvalidArgument("No matching device found for local_hardware_id %d",
994                          local_hardware_id);
995 }
996 
PjRtStreamExecutorBuffer(Shape on_device_shape,std::shared_ptr<TrackedDeviceBuffer> device_buffer,PjRtClient * client,PjRtDevice * device)997 PjRtStreamExecutorBuffer::PjRtStreamExecutorBuffer(
998     Shape on_device_shape, std::shared_ptr<TrackedDeviceBuffer> device_buffer,
999     PjRtClient* client, PjRtDevice* device)
1000     : client_(tensorflow::down_cast<PjRtStreamExecutorClient*>(client)),
1001       on_device_shape_(std::move(on_device_shape)),
1002       device_(tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)),
1003       device_buffer_(std::move(device_buffer)),
1004       donation_semaphore_(/*capacity=*/1) {
1005   for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) {
1006     holds_[i] = 0;
1007   }
1008 }
1009 
// Releases the underlying device buffer (if still present) and then verifies
// that no hold of any type is still outstanding; a remaining hold is a fatal
// error.
PjRtStreamExecutorBuffer::~PjRtStreamExecutorBuffer() {
  // Delete() goes through Release(), which waits for outstanding donation and
  // usage holds before detaching the buffer.
  Delete();
  for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) {
    CHECK_EQ(holds_[i], 0);
  }
}
1016 
OnDeviceSizeInBytes() const1017 int64 PjRtStreamExecutorBuffer::OnDeviceSizeInBytes() const {
1018   return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
1019       ->client()
1020       ->backend()
1021       .transfer_manager()
1022       ->GetByteSizeRequirement(on_device_shape_);
1023 }
1024 
WaitForOutstandingUsageHolds()1025 void PjRtStreamExecutorBuffer::WaitForOutstandingUsageHolds() {
1026   auto not_in_usage_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1027     return holds_[ScopedHold::kUsage] == 0;
1028   };
1029   mu_.Await(absl::Condition(&not_in_usage_hold));
1030 }
1031 
WaitForOutstandingDonationHold()1032 void PjRtStreamExecutorBuffer::WaitForOutstandingDonationHold() {
1033   auto not_in_donation_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1034     return holds_[ScopedHold::kDonation] == 0;
1035   };
1036   mu_.Await(absl::Condition(&not_in_donation_hold));
1037 }
1038 
// Detaches the TrackedDeviceBuffer from this PjRtBuffer and returns it,
// leaving device_buffer_ null.
//
// Returns a null pointer if the buffer had already been released or donated.
// If `wait_for_operations_to_complete` is true, blocks the host until every
// pending usage event has completed and may return the error from
// BlockHostUntilDone(). If it is false the call cannot fail; under the
// kComputeSynchronized allocation model it instead keeps a reference to the
// buffer alive on a borrowed stream until the relevant events complete.
//
// NOTE(review): when BlockHostUntilDone() fails the buffer has already been
// detached from this object and the borrowed stream is not returned to the
// pool - confirm that this is intended.
StatusOr<std::shared_ptr<TrackedDeviceBuffer>>
PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) {
  tensorflow::profiler::TraceMe trace_me("PjRtStreamExecutorBuffer::Release");
  std::shared_ptr<TrackedDeviceBuffer> device_buffer;
  TrackedDeviceBuffer::StreamAndEventContainer events;
  {
    absl::MutexLock lock(&mu_);
    // We first wait for a donation hold to complete if there is one in
    // progress. If the donation succeeds via ConfirmDonation() then it will
    // set device_buffer_ to nullptr before returning to this thread.
    WaitForOutstandingDonationHold();
    if (device_buffer_ == nullptr) {
      return std::shared_ptr<TrackedDeviceBuffer>();
    }
    // Set device_buffer_ to null now so that no other
    // thread can add a hold while we are in WaitForOutstandingUsageHolds()
    // below.
    std::swap(device_buffer_, device_buffer);
    WaitForOutstandingUsageHolds();
    // Now that all holds have completed and no more can be added, we can get
    // the final set of usage events.
    events = device_buffer->LockUseAndTransferUsageEvents();
  }
  LocalDeviceState* local_device_state =
      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
          ->local_device_state();
  if (wait_for_operations_to_complete) {
    // Block the host until all usage events have completed. Usage events
    // dominate definition events, so this also waits for the buffer to be
    // defined.
    // A stream is borrowed lazily, only if some event is still pending.
    std::unique_ptr<se::Stream> stream;
    for (const auto& stream_and_event : events) {
      if (!stream_and_event.event->IsComplete()) {
        if (stream == nullptr) {
          stream = local_device_state->BorrowStreamFromPool();
        }
        stream_and_event.event->WaitForEventOnStream(stream.get());
      }
    }
    if (stream != nullptr) {
      TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
      local_device_state->ReturnStreamToPool(std::move(stream));
    }
  } else {
    if (local_device_state->allocation_model() ==
        LocalDeviceState::kComputeSynchronized) {
      std::unique_ptr<se::Stream> block_stream;
      for (const auto& stream_and_event : events) {
        // We only need to do something for events that didn't already acquire a
        // reference to the buffer, and also which the compute stream didn't
        // already wait for. Based on our heuristics this rare case should only
        // occur when a buffer was copied to a device and then never used there.
        // In that case we get a new stream and use it to hold onto a reference
        // to the buffer until the events are complete.
        if (!stream_and_event.reference_held &&
            !stream_and_event.event->DefinedOn(
                local_device_state->compute_stream()) &&
            !stream_and_event.event->IsComplete()) {
          if (block_stream == nullptr) {
            block_stream = local_device_state->BorrowStreamFromPool();
          }
          stream_and_event.event->WaitForEventOnStream(block_stream.get());
        }
      }
      if (block_stream != nullptr) {
        // Ownership of the stream passes to the callback as a raw pointer; the
        // callback re-wraps it and returns it to the pool. Capturing
        // device_buffer keeps the buffer alive until the callback runs.
        se::Stream* block_stream_ptr = block_stream.release();
        local_device_state->ThenExecuteOnCallbackThread(
            block_stream_ptr,
            [device_buffer, block_stream_ptr, local_device_state]() {
              local_device_state->ReturnStreamToPool(
                  std::unique_ptr<se::Stream>(block_stream_ptr));
            });
      }
    }
  }
  return device_buffer;
}
1116 
// Releases the underlying device buffer without blocking the host on
// outstanding operations. A failure from Release() is fatal.
void PjRtStreamExecutorBuffer::Delete() {
  // When wait_for_operations_to_complete is false, Release should never fail:
  // its only error path is the blocking wait taken when the flag is true.
  TF_CHECK_OK(Release(/*wait_for_operations_to_complete=*/false).status());
}
1121 
IsDeleted()1122 bool PjRtStreamExecutorBuffer::IsDeleted() {
1123   absl::MutexLock lock(&mu_);
1124   return device_buffer_ == nullptr;
1125 }
1126 
// Registers a hold of `type` on this buffer and returns the device buffer the
// hold refers to. Must be called with mu_ held; may block on mu_ while waiting
// for conflicting holds to go away.
//
// Returns InvalidArgument, without registering a hold, if the buffer has been
// deleted or donated, or if a donation is requested while an external
// reference hold exists.
StatusOr<std::shared_ptr<TrackedDeviceBuffer>>
PjRtStreamExecutorBuffer::GetBufferForHoldLocked(ScopedHold::Type type) {
  if (type == ScopedHold::kDonation) {
    if (device_buffer_ == nullptr) {
      return InvalidArgument("Donation requested for invalid buffer");
    }
    if (holds_[ScopedHold::kExternalReference] > 0) {
      return InvalidArgument(
          "Donation requested for buffer with external reference");
    }
    // donation_semaphore_ was acquired in GetBufferWithHold so that only one
    // thread at a time can attempt to get a donation hold.
    CHECK_EQ(holds_[type], 0);
    // First add the donation hold.
    ++holds_[type];
    // Then wait for any usage holds to be dropped or converted. No new usage
    // holds can be added until we drop the donation hold so this wait will
    // complete eventually.
    WaitForOutstandingUsageHolds();
    // Because we added a donation hold, nobody could release the buffer while
    // we were waiting.
    CHECK(device_buffer_ != nullptr);
  } else {
    // If there is a donation hold in progress we have to wait before
    // acquiring any other kind of hold.
    WaitForOutstandingDonationHold();
    if (device_buffer_ == nullptr) {
      return InvalidArgument("Hold requested on deleted or donated buffer");
    } else {
      ++holds_[type];
    }
  }
  return device_buffer_;
}
1161 
AcquireHoldLocked(ScopedHold * hold)1162 void PjRtStreamExecutorBuffer::AcquireHoldLocked(ScopedHold* hold) {
1163   hold->Acquire(GetBufferForHoldLocked(hold->type()));
1164 }
1165 
// Ends one usage hold by turning it into a usage event: records
// (`usage_stream`, `event`, `reference_held`) on `buffer` and decrements the
// usage-hold count, all under mu_.
//
// `buffer` must be the buffer the hold was taken on; it is a fatal error if
// this object currently owns a different device buffer or if no usage hold is
// outstanding.
void PjRtStreamExecutorBuffer::ConvertUsageHold(
    TrackedDeviceBuffer* buffer, se::Stream* usage_stream,
    std::shared_ptr<BufferSequencingEvent> event, bool reference_held) {
  absl::MutexLock lock(&mu_);
  // device_buffer_ may already be null: Release() detaches it before waiting
  // for usage holds to drain.
  CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr);
  buffer->AddUsageEvent(usage_stream, std::move(event), reference_held);
  CHECK_GT(holds_[ScopedHold::kUsage], 0);
  --holds_[ScopedHold::kUsage];
}
1175 
// Completes a donation: clears the donation hold, gives up ownership of the
// device memory held by `device_buffer`, invalidates this PjRtBuffer and
// releases the donation semaphore.
//
// It is a fatal error if any usage or external-reference hold exists, if
// there is not exactly one donation hold, or if `device_buffer` is not the
// buffer currently owned by this object.
void PjRtStreamExecutorBuffer::ConfirmDonation(
    TrackedDeviceBuffer* device_buffer) {
  {
    absl::MutexLock lock(&mu_);
    CHECK_EQ(holds_[ScopedHold::kUsage], 0);
    CHECK_EQ(holds_[ScopedHold::kExternalReference], 0);
    CHECK_EQ(holds_[ScopedHold::kDonation], 1);
    holds_[ScopedHold::kDonation] = 0;
    CHECK(device_buffer_.get() == device_buffer);
    // As a sanity check ensure no more usage events can be added to the buffer.
    device_buffer->LockUseAndTransferUsageEvents();
    // Give up ownership of the device memory so we don't free it when the last
    // reference to device_buffer_ goes away.
    device_buffer->ReleaseDeviceMemory();
    // Make *this invalid so it can't be used again. Any threads blocking in
    // Release or GetBufferWithHold will see an invalid buffer and return.
    device_buffer_.reset();
  }
  // Unblock another thread, if any, trying to get a donation hold.
  donation_semaphore_.Release(1);
}
1197 
// Drops one hold of `type` that was taken on `buffer` without recording any
// usage, under mu_. Dropping a donation hold also releases the donation
// semaphore.
//
// It is a fatal error if this object currently owns a different device
// buffer, if no hold of `type` is outstanding, or if any hold remains after a
// donation hold is dropped.
void PjRtStreamExecutorBuffer::DropHold(ScopedHold::Type type,
                                        TrackedDeviceBuffer* buffer) {
  absl::MutexLock lock(&mu_);
  // device_buffer_ may already be null: Release() detaches it before waiting
  // for usage holds to drain.
  CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr);
  CHECK_GT(holds_[type], 0);
  --holds_[type];
  if (type == ScopedHold::kDonation) {
    CHECK_EQ(holds_[ScopedHold::kDonation], 0);
    CHECK_EQ(holds_[ScopedHold::kUsage], 0);
    CHECK_EQ(holds_[ScopedHold::kExternalReference], 0);
    // Let another thread attempt a donation hold.
    donation_semaphore_.Release(1);
  }
}
1211 
ToLiteral(MutableLiteralBase * literal,std::function<void (Status)> on_ready)1212 void PjRtStreamExecutorBuffer::ToLiteral(MutableLiteralBase* literal,
1213                                          std::function<void(Status)> on_ready) {
1214   if (IsEmptyTuple()) {
1215     on_ready(InvalidArgument("ToLiteral called on empty tuple"));
1216     return;
1217   }
1218   LocalDeviceState* local_device =
1219       tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
1220           ->local_device_state();
1221   se::Stream* stream = local_device->GetDeviceToHostStream();
1222   ScopedHold device_buffer(this, ScopedHold::kUsage);
1223   {
1224     absl::MutexLock lock(&mu_);
1225     // We can't perform any other action while a donation hold is in progress.
1226     WaitForOutstandingDonationHold();
1227     if (device_buffer_ == nullptr) {
1228       on_ready(InvalidArgument(
1229           "CopyToHostAsync() called on deleted or donated buffer"));
1230       return;
1231     }
1232     AcquireHoldLocked(&device_buffer);
1233   }
1234 
1235   WaitForBufferDefinitionEventsOnStream(*device_buffer, stream);
1236   ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer(on_device_shape_);
1237   StatusOr<EventPool::Handle> event_or =
1238       local_device->event_pool().AllocateEvent(stream->parent());
1239   if (!event_or.ok()) {
1240     on_ready(event_or.status());
1241     return;
1242   }
1243   tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
1244       ->client()
1245       ->backend()
1246       .transfer_manager()
1247       ->TransferLiteralFromDevice(stream, shaped_buffer, literal,
1248                                   std::move(on_ready));
1249 
1250   auto usage_event = std::make_shared<BufferSequencingEvent>();
1251   local_device->event_pool().ThenRecordEvent(stream, event_or.ValueOrDie());
1252   usage_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream);
1253   // When using the ComputeSynchronized allocation model, retain a reference to
1254   // the device_buffer until the copy completes, to ensure that the buffer isn't
1255   // deleted or donated while it is still in use. The choice of retaining a
1256   // reference at the host is a heuristic; the alternative is to ensure, before
1257   // freeing the buffer, that the compute stream is synchronized past the
1258   // transfer, but it seems better to hold onto the buffer too long than to
1259   // stall the compute stream, particularly since the overwhelmingly common
1260   // use case of CopyToHostAsync will hold onto the reference long enough to
1261   // read the buffer in a subsequent call to ToLiteral.
1262   RecordUsage(std::move(device_buffer), local_device, local_device, usage_event,
1263               stream,
1264               /*prefer_to_retain_reference=*/true);
1265 }
1266 
AsShapedBuffer() const1267 StatusOr<ShapedBuffer> PjRtStreamExecutorBuffer::AsShapedBuffer() const {
1268   absl::MutexLock lock(&mu_);
1269   if (device_buffer_ == nullptr) {
1270     return InvalidArgument(
1271         "Attempted to fetch value of invalid/deleted buffer.");
1272   }
1273   return device_buffer_->AsShapedBuffer(on_device_shape_);
1274 }
1275 
1276 PjRtStreamExecutorBuffer::ScopedHold
GetBufferWithHold(ScopedHold::Type type)1277 PjRtStreamExecutorBuffer::GetBufferWithHold(ScopedHold::Type type) {
1278   if (type == ScopedHold::kDonation) {
1279     // Ensure that at most one donation hold can be in progress at a time.
1280     donation_semaphore_.Acquire(1);
1281   }
1282   absl::MutexLock lock(&mu_);
1283   ScopedHold hold(this, type);
1284   AcquireHoldLocked(&hold);
1285   if (type == ScopedHold::kDonation && !hold.ok()) {
1286     donation_semaphore_.Release(1);
1287   }
1288   return hold;
1289 }
1290 
// Allocates a destination buffer on `dst_device` and enqueues, on
// `transfer_stream`, a device-to-device copy of every non-empty leaf of this
// buffer into it.
//
// Args:
//   dst_device: device that will own the new buffer.
//   dst_local_device: local device state of `dst_device`.
//   transfer_local_device: local device state used to enqueue the copies.
//   transfer_stream: stream on which the copies are enqueued.
//   src_device_buffer: tracked device buffer of the source; its definition
//     events are waited for on `transfer_stream` before copying.
//
// Returns the new buffer together with the first definition event of the
// destination buffer, or an error if allocation, fetching the source
// ShapedBuffer, or enqueueing the copy fails.
StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
                   std::shared_ptr<BufferSequencingEvent>>>
PjRtStreamExecutorBuffer::CopyToDeviceHelper(
    PjRtDevice* dst_device, LocalDeviceState* dst_local_device,
    LocalDeviceState* transfer_local_device, se::Stream* transfer_stream,
    std::shared_ptr<TrackedDeviceBuffer> src_device_buffer) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
                      AllocateDestinationBuffer(
                          ShapeUtil::DeviceShapeToHostShape(on_device_shape_),
                          dst_device, dst_local_device, transfer_stream,
                          /*is_uninitialized_create=*/false, client_));

  // Fails with InvalidArgument if this buffer no longer has device memory.
  TF_ASSIGN_OR_RETURN(ShapedBuffer src_buffer, AsShapedBuffer());

  WaitForBufferDefinitionEventsOnStream(*src_device_buffer, transfer_stream);

  // The destination buffer was just created above, so taking a usage hold on
  // it is expected to succeed.
  ScopedHold dst_device_buffer(py_buffer->GetBufferWithUsageHold());
  CHECK(dst_device_buffer.ok());
  ShapedBuffer dst_buffer = dst_device_buffer->AsShapedBuffer(on_device_shape_);

  // Copy the leaf buffers.
  // The work is wrapped in a lambda so that every early error return funnels
  // into the single cleanup path below.
  StatusOr<std::shared_ptr<BufferSequencingEvent>> copy_event_or =
      [&]() -> StatusOr<std::shared_ptr<BufferSequencingEvent>> {
    for (const auto& leaf : src_buffer.buffers().leaves()) {
      const ShapeIndex& index = leaf.first;
      const se::DeviceMemoryBase& input_buffer = leaf.second;
      const se::DeviceMemoryBase& output_buffer = dst_buffer.buffer(index);
      TF_RET_CHECK(input_buffer.size() == output_buffer.size())
          << "input: " << input_buffer.size()
          << " output: " << output_buffer.size();
      // Zero-sized leaves need no copy.
      if (input_buffer.size() != 0) {
        TF_RETURN_IF_ERROR(transfer_local_device->ThenMemcpyDeviceToDevice(
            transfer_stream, dst_local_device->compute_stream(), input_buffer,
            output_buffer));
      }
    }
    // Grab the event before the hold is moved into
    // AddDestinationBufferSynchronization below.
    std::shared_ptr<BufferSequencingEvent> event =
        dst_device_buffer->definition_events()[0];
    TF_RETURN_IF_ERROR(AddDestinationBufferSynchronization(
        transfer_local_device, std::move(dst_device_buffer), event,
        transfer_stream));
    return event;
  }();
  if (!copy_event_or.ok()) {
    StallStreamOnError(transfer_local_device, transfer_stream);
    if (transfer_local_device == dst_local_device) {
      // Some copies may have been enqueued before the error was returned, and
      // StallStreamOnError only makes sure the destination device is ok, so
      // make sure that the src buffer remains valid until after any transfers
      // have completed.
      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
          ->local_device_state()
          ->ThenRelease(transfer_stream, src_device_buffer);
    }
    return copy_event_or.status();
  }

  return std::pair<std::unique_ptr<PjRtBuffer>,
                   std::shared_ptr<BufferSequencingEvent>>(
      std::unique_ptr<PjRtStreamExecutorBuffer>(std::move(py_buffer)),
      copy_event_or.ConsumeValueOrDie());
}
1353 
CopyToDevice(PjRtDevice * dst_device)1354 StatusOr<std::unique_ptr<PjRtBuffer>> PjRtStreamExecutorBuffer::CopyToDevice(
1355     PjRtDevice* dst_device) {
1356   tensorflow::profiler::TraceMe traceme(
1357       "PjRtStreamExecutorBuffer::CopyToDevice");
1358   if (dst_device == device_) {
1359     return InvalidArgument(
1360         "CopyToDevice cannot accept the same source and destination devices");
1361   }
1362 
1363   // Copying across PjRtClients involves a copy through the host.
1364   if (dst_device->client() != client_) {
1365     TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteral());
1366     // Avoid use-after-free on `literal` due to unsequenced move and use.
1367     Literal* literal_pointer = literal.get();
1368     return dst_device->client()->BufferFromHostBuffer(
1369         literal_pointer->untyped_data(), literal_pointer->shape(),
1370         PjRtStreamExecutorClient::HostBufferSemantics::kZeroCopy,
1371         [literal{std::move(literal)}]() { /* frees literal */ }, dst_device);
1372   }
1373 
1374   TF_ASSIGN_OR_RETURN(
1375       LocalDeviceState * dst_local_device,
1376       tensorflow::down_cast<PjRtStreamExecutorDevice*>(dst_device)
1377           ->GetLocalDeviceState());
1378   LocalDeviceState* transfer_local_device =
1379       tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
1380               ->EnqueueD2DTransfersOnSrcStream()
1381           ? tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
1382                 ->local_device_state()
1383           : dst_local_device;
1384   CHECK_EQ(dst_local_device->allocation_model(),
1385            transfer_local_device->allocation_model());
1386 
1387   se::Stream* transfer_stream =
1388       transfer_local_device->GetDeviceToDeviceStream();
1389 
1390   ScopedHold src_device_buffer(this, ScopedHold::kUsage);
1391   {
1392     absl::MutexLock lock(&mu_);
1393     // We can't perform any other action while a donation hold is in progress.
1394     WaitForOutstandingDonationHold();
1395     if (device_buffer_ == nullptr) {
1396       return InvalidArgument(
1397           "CopyToDevice called on deleted or donated buffer");
1398     }
1399     AcquireHoldLocked(&src_device_buffer);
1400   }
1401 
1402   StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
1403                      std::shared_ptr<BufferSequencingEvent>>>
1404       buffer_and_event_or = CopyToDeviceHelper(
1405           dst_device, dst_local_device, transfer_local_device, transfer_stream,
1406           src_device_buffer.buffer());
1407   if (!buffer_and_event_or.ok()) {
1408     return buffer_and_event_or.status();
1409   }
1410 
1411   auto& buffer_and_event = buffer_and_event_or.ValueOrDie();
1412   std::unique_ptr<PjRtBuffer>& buffer = buffer_and_event.first;
1413   std::shared_ptr<BufferSequencingEvent>& event = buffer_and_event.second;
1414 
1415   // prefer_to_retain_reference=*/true means that, when using the
1416   // ComputeSynchronized allocation model, retain a reference to the
1417   // src_device_buffer until the copy completes. This is a heuristic; the
1418   // alternative is to ensure, before freeing the buffer, that the compute
1419   // stream is synchronized past the transfer, but it seems better to hold onto
1420   // the buffer too long than to stall the compute stream.
1421   RecordUsage(std::move(src_device_buffer),
1422               tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
1423                   ->local_device_state(),
1424               transfer_local_device, event, transfer_stream,
1425               /*prefer_to_retain_reference=*/true);
1426 
1427   return std::move(buffer);
1428 }
1429 
CopyToRemoteDevice(absl::string_view serialized_descriptor)1430 Status PjRtStreamExecutorBuffer::CopyToRemoteDevice(
1431     absl::string_view serialized_descriptor) {
1432   return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
1433       ->CopyToRemoteDevice(this, serialized_descriptor);
1434 }
1435 
BlockHostUntilReady()1436 Status PjRtStreamExecutorBuffer::BlockHostUntilReady() {
1437   tensorflow::profiler::TraceMe traceme(
1438       "PjRtStreamExecutorBuffer::BlockHostUntilReady");
1439   std::shared_ptr<TrackedDeviceBuffer> device_buffer;
1440   {
1441     absl::MutexLock lock(&mu_);
1442     if (device_buffer_ == nullptr) {
1443       return InvalidArgument(
1444           "BlockHostUntilReady() called on deleted or donated buffer");
1445     }
1446     device_buffer = device_buffer_;
1447   }
1448   LocalDeviceState* local_device_state =
1449       tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
1450           ->local_device_state();
1451   std::unique_ptr<se::Stream> stream;
1452   for (auto& event : device_buffer->definition_events()) {
1453     if (!event->IsComplete()) {
1454       if (stream == nullptr) {
1455         stream = local_device_state->BorrowStreamFromPool();
1456       }
1457       event->WaitForEventOnStream(stream.get());
1458     }
1459   }
1460   if (stream != nullptr) {
1461     TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
1462     local_device_state->ReturnStreamToPool(std::move(stream));
1463   }
1464   return Status::OK();
1465 }
1466 
1467 namespace {
1468 
// Helper struct for the tuple that is transiently constructed to hold the
// arguments of an execution. Produced by MakeTupleHelper.
//
// The members are brace-initialized positionally, so their declaration order
// must not change.
struct TupleHandle {
  // The ExecutionInput describing the tuple.
  ExecutionInput execution_input;
  // A definition event that has been recorded on the host_to_device stream
  // after the tuple table transfer.
  std::shared_ptr<BufferSequencingEvent> event;
};
1478 
1479 // Makes a tuple from the arguments to an execution.
MakeTupleHelper(PjRtClient * client,LocalDeviceState * local_device,absl::Span<PjRtBuffer * const> py_buffers,absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,int device_ordinal)1480 StatusOr<TupleHandle> MakeTupleHelper(
1481     PjRtClient* client, LocalDeviceState* local_device,
1482     absl::Span<PjRtBuffer* const> py_buffers,
1483     absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
1484     int device_ordinal) {
1485   std::vector<Shape> host_shapes;
1486   std::vector<Shape> device_shapes;
1487   host_shapes.reserve(py_buffers.size());
1488   device_shapes.reserve(py_buffers.size());
1489   for (const PjRtBuffer* buffer : py_buffers) {
1490     device_shapes.push_back(buffer->on_device_shape());
1491   }
1492   Shape on_device_shape = ShapeUtil::MakeTupleShape(device_shapes);
1493 
1494   se::DeviceMemoryAllocator* allocator =
1495       tensorflow::down_cast<PjRtStreamExecutorClient*>(client)->allocator();
1496   TransferManager* transfer_manager =
1497       tensorflow::down_cast<PjRtStreamExecutorClient*>(client)
1498           ->client()
1499           ->backend()
1500           .transfer_manager();
1501   se::Stream* stream = local_device->host_to_device_stream();
1502   TF_ASSIGN_OR_RETURN(
1503       se::OwningDeviceMemory root_table_memory,
1504       allocator->Allocate(
1505           device_ordinal,
1506           transfer_manager->GetByteSizeRequirement(on_device_shape)));
1507 
1508   if (local_device->allocation_model() ==
1509       LocalDeviceState::kComputeSynchronized) {
1510     stream->ThenWaitFor(local_device->compute_stream());
1511   } else {
1512     DCHECK(transfer_manager->CanBufferBeAccessedNow(
1513         local_device->compute_stream()->parent(), root_table_memory.cref()));
1514   }
1515 
1516   ExecutionInput execution_input(on_device_shape);
1517   ShapeTree<MaybeOwningDeviceMemory>::iterator input_iterator =
1518       execution_input.MutableBuffers()->begin();
1519   ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
1520       execution_input.MutableBuffers()->end();
1521   // First set the root tuple table which is the first buffer in the ShapeTree.
1522   execution_input.SetBuffer(
1523       input_iterator->first,
1524       MaybeOwningDeviceMemory(std::move(root_table_memory)));
1525   ++input_iterator;
1526   // Then set each sub-tuple in turn from the parameters.
1527   for (const PjRtStreamExecutorBuffer::ScopedHold& device_buffer :
1528        device_buffers) {
1529     device_buffer.AddToInput(&input_iterator, iterator_end, &execution_input,
1530                              allocator);
1531   }
1532   CHECK(input_iterator == iterator_end);
1533 
1534   TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
1535       stream, execution_input.Buffers()));
1536   StatusOr<EventPool::Handle> event_or =
1537       local_device->event_pool().ThenAllocateAndRecordEvent(stream);
1538   if (!event_or.ok()) {
1539     StallStreamOnError(local_device, stream);
1540     return event_or.status();
1541   }
1542 
1543   auto transfer_event = std::make_shared<BufferSequencingEvent>();
1544   transfer_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream);
1545   return TupleHandle({std::move(execution_input), std::move(transfer_event)});
1546 }
1547 
1548 // Converts a ScopedShapedBuffer returned from an execution into a
1549 // PjRtBuffer.
OutputBufferHelper(ScopedShapedBuffer * result_buffer,std::shared_ptr<BufferSequencingEvent> definition_event,PjRtClient * client,PjRtDevice * device,LocalDeviceState * local_device)1550 std::unique_ptr<PjRtBuffer> OutputBufferHelper(
1551     ScopedShapedBuffer* result_buffer,
1552     std::shared_ptr<BufferSequencingEvent> definition_event, PjRtClient* client,
1553     PjRtDevice* device, LocalDeviceState* local_device) {
1554   std::shared_ptr<TrackedDeviceBuffer> out_buffer =
1555       TrackedDeviceBuffer::FromScopedShapedBuffer(result_buffer,
1556                                                   {definition_event});
1557   auto pjrt_buffer = absl::make_unique<PjRtStreamExecutorBuffer>(
1558       result_buffer->on_device_shape(), std::move(out_buffer), client, device);
1559   RecordUsage(pjrt_buffer->GetBufferWithUsageHold(), local_device, local_device,
1560               definition_event, local_device->compute_stream(),
1561               /*prefer_to_retain_reference=*/false);
1562   return std::unique_ptr<PjRtBuffer>(std::move(pjrt_buffer));
1563 }
1564 }  // namespace
1565 
PjRtStreamExecutorExecutable(std::vector<std::unique_ptr<LocalExecutable>> executables,bool parameter_is_tupled_arguments,std::shared_ptr<DeviceAssignment> device_assignment,std::vector<LogicalDeviceIds> addressable_device_logical_ids,std::vector<PjRtDevice * > addressable_devices,PjRtStreamExecutorClient * client)1566 PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
1567     std::vector<std::unique_ptr<LocalExecutable>> executables,
1568     bool parameter_is_tupled_arguments,
1569     std::shared_ptr<DeviceAssignment> device_assignment,
1570     std::vector<LogicalDeviceIds> addressable_device_logical_ids,
1571     std::vector<PjRtDevice*> addressable_devices,
1572     PjRtStreamExecutorClient* client)
1573     : client_(client),
1574       device_assignment_(std::move(device_assignment)),
1575       parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
1576       addressable_device_logical_ids_(
1577           std::move(addressable_device_logical_ids)),
1578       addressable_devices_(std::move(addressable_devices)) {
1579   executables_.reserve(executables.size());
1580   for (auto& executable : executables) {
1581     executables_.emplace_back(std::move(executable));
1582   }
1583 
1584   int num_partitions;
1585   if (device_assignment_ == nullptr) {
1586     // This must go after `executables_` is initialized.
1587     VLOG(1) << "PjRtStreamExecutorExecutable portable single-core";
1588     num_partitions = 1;
1589     CHECK(addressable_devices_.empty());
1590   } else {
1591     // This must go after `executables_` is initialized.
1592     VLOG(1) << "PjRtStreamExecutorExecutable device_assignment:\n"
1593             << device_assignment_->ToString();
1594     CHECK_GE(addressable_devices_.size(), 1) << device_assignment_->ToString();
1595     CHECK_LE(addressable_devices_.size(), client_->addressable_device_count())
1596         << "Inconsistent local device count.";
1597     num_partitions = device_assignment_->computation_count();
1598   }
1599 
1600   // SPMD sharding produces a single executable for multiple partitions.
1601   if (executables_.size() > 1) {
1602     CHECK_EQ(num_partitions, executables_.size())
1603         << "Number of executables " << executables_.size()
1604         << " did not match number of partitions " << num_partitions;
1605   }
1606 }
1607 
SetUpDonation(bool tuple_inputs)1608 Status PjRtStreamExecutorExecutable::SetUpDonation(bool tuple_inputs) {
1609   parameters_that_must_be_donated_.reserve(executables_.size());
1610   for (auto& executable : executables_) {
1611     TF_ASSIGN_OR_RETURN(absl::flat_hash_set<int> parameters_to_donate,
1612                         GetParametersThatMustBeDonated(
1613                             executable->executable()->module(), tuple_inputs));
1614     parameters_that_must_be_donated_.emplace_back(
1615         std::move(parameters_to_donate));
1616   }
1617   return Status::OK();
1618 }
1619 
name() const1620 absl::string_view PjRtStreamExecutorExecutable::name() const {
1621   Executable* executable = executables_[0]->executable();
1622   if (executable->has_module()) {
1623     return executable->module().name();
1624   } else {
1625     return "<unknown executable>";
1626   }
1627 }
1628 
MustDonateParameter(int executable_idx,int parameter) const1629 bool PjRtStreamExecutorExecutable::MustDonateParameter(int executable_idx,
1630                                                        int parameter) const {
1631   return parameters_that_must_be_donated_[executable_idx].contains(parameter);
1632 }
1633 
1634 StatusOr<std::vector<ExecutionInput>>
MakeExecutionInputsAndWaitForEvents(int device_ordinal,const ExecuteOptions & options,absl::Span<PjRtBuffer * const> argument_handles,absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,absl::flat_hash_set<BufferSequencingEvent * > & events) const1635 PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
1636     int device_ordinal, const ExecuteOptions& options,
1637     absl::Span<PjRtBuffer* const> argument_handles,
1638     absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
1639     absl::flat_hash_set<BufferSequencingEvent*>& events) const {
1640   std::vector<ExecutionInput> execution_inputs;
1641   LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
1642   // Lift tuple_handle outside the conditional so that the event it returns is
1643   // not destroyed until after the loop below that waits on events.
1644   absl::optional<TupleHandle> tuple_handle;
1645   if (parameter_is_tupled_arguments_ && !options.arguments_are_tupled) {
1646     TF_ASSIGN_OR_RETURN(tuple_handle,
1647                         MakeTupleHelper(client_, device_state, argument_handles,
1648                                         device_buffers, device_ordinal));
1649     events.insert(tuple_handle->event.get());
1650     execution_inputs.emplace_back(std::move(tuple_handle->execution_input));
1651   } else {
1652     execution_inputs.reserve(argument_handles.size());
1653     for (int i = 0; i < argument_handles.size(); ++i) {
1654       PjRtBuffer* handle = argument_handles[i];
1655 
1656       // Make an ExecutionInput from the device buffer.
1657       execution_inputs.emplace_back(handle->on_device_shape());
1658       ExecutionInput& execution_input = execution_inputs.back();
1659       ShapeTree<MaybeOwningDeviceMemory>::iterator input_iterator =
1660           execution_input.MutableBuffers()->begin();
1661       ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
1662           execution_input.MutableBuffers()->end();
1663       device_buffers[i].AddToInput(
1664           &input_iterator, iterator_end, &execution_input,
1665           tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
1666               ->allocator());
1667       CHECK(input_iterator == iterator_end);
1668     }
1669   }
1670 
1671   for (BufferSequencingEvent* event : events) {
1672     event->WaitForEventOnStream(device_state->compute_stream());
1673   }
1674 
1675   return execution_inputs;
1676 }
1677 
// Enqueues a computation onto the compute stream. Each buffer returned in
// device_buffers has a usage hold added that must be dropped on error or
// converted on success.
//
// Steps, in order:
//   1. Checks that every argument is on `device` and acquires a donation or
//      usage hold on it (donation when MustDonateParameter says so), gathering
//      the events the execution has to wait for.
//   2. Validates options.arguments_are_tupled against how the executable was
//      compiled and the number of arguments.
//   3. Builds the ExecutionInputs and makes the compute stream wait on the
//      gathered events.
//   4. Acquires one unit of the device's compute semaphore and launches the
//      executable with RunAsync.
//   5. Arranges for the executable, the semaphore reservation, the device
//      assignment and any donated-but-unaliased memory to be released after
//      the work enqueued on the compute stream.
//
// Returns the result buffer of the execution, or InvalidArgument / the
// RunAsync error. On error, holds already appended to `device_buffers` are
// left there for the caller.
StatusOr<ScopedShapedBuffer> PjRtStreamExecutorExecutable::EnqueueExecution(
    absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
    int executable_idx, const RunId& run_id, const ExecuteOptions& options,
    PjRtDevice* device,
    std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
    std::shared_ptr<DeviceAssignment> device_assignment) const {
  int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
                           ->local_device_state()
                           ->device_ordinal();
  LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
  tensorflow::profiler::TraceMeConsumer activity(
      "LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt,
      run_id.ToInt());
  VLOG(3) << "Replica " << replica << ", partition " << partition
          << " mapped to device ordinal for execution: " << device_ordinal;

  // Events (definition events, plus usage events for donated buffers) that the
  // compute stream must wait on before the execution runs.
  absl::flat_hash_set<BufferSequencingEvent*> events;
  device_buffers->reserve(argument_handles.size());
  for (int i = 0; i < argument_handles.size(); ++i) {
    auto* handle =
        tensorflow::down_cast<PjRtStreamExecutorBuffer*>(argument_handles[i]);
    if (handle->device() != device) {
      return InvalidArgument(
          "Buffer passed to Execute() as argument %d to replica %d is on "
          "device %s, but replica is assigned to device %s.",
          i, replica, handle->device()->DebugString(), device->DebugString());
    }
    bool must_donate = MustDonateParameter(executable_idx, i);
    device_buffers->emplace_back(handle->GetBufferWithHold(
        must_donate ? PjRtStreamExecutorBuffer::ScopedHold::kDonation
                    : PjRtStreamExecutorBuffer::ScopedHold::kUsage));
    PjRtStreamExecutorBuffer::ScopedHold& device_buffer =
        device_buffers->back();
    if (!device_buffer.ok()) {
      return InvalidArgument(
          "Invalid buffer passed to Execute() as argument %d to replica %d: "
          "%s",
          i, replica, device_buffer.status().ToString());
    }
    // If we are trying to donate the buffer wait on the usage events as well
    // as the definition events to ensure that all reads have been completed
    // before the buffer is mutated. Usage holds are excluded during a donation
    // hold so we know that the set of usage events won't be modified while we
    // are enqueueing.
    GetDeviceBufferEvents(*device_buffer, /*get_usage_events=*/must_donate,
                          &events);
  }

  // Pre-tupled arguments are only accepted by an executable compiled with a
  // single tupled parameter, and then exactly one buffer must be supplied.
  if (options.arguments_are_tupled) {
    if (!parameter_is_tupled_arguments_) {
      return InvalidArgument(
          "Arguments may only be supplied as a tuple when the executable was "
          "compiled with a single tupled parameter");
    }
    if (argument_handles.size() != 1) {
      return InvalidArgument(
          "Option arguments_are_tupled was true but %d buffers were passed to "
          "execution",
          argument_handles.size());
    }
  }

  TF_ASSIGN_OR_RETURN(
      std::vector<ExecutionInput> execution_inputs,
      MakeExecutionInputsAndWaitForEvents(
          device_ordinal, options, argument_handles, *device_buffers, events));

  ExecutableRunOptions run_options;
  run_options.set_stream(device_state->compute_stream());
  run_options.set_host_to_device_stream(device_state->host_to_device_stream());
  run_options.set_allocator(client_->allocator());
  run_options.set_intra_op_thread_pool(
      client_->client()->backend().eigen_intra_op_thread_pool_device());
  run_options.set_device_assignment(device_assignment.get());
  run_options.set_run_id(run_id);
  run_options.set_rng_seed(device_state->GetNewPrngSeed());
  run_options.set_gpu_executable_run_options(client_->gpu_run_options());
  run_options.set_launch_id(options.launch_id);
  if (run_options.launch_id() != 0) {
    VLOG(1) << "launch id for " << name() << ": " << run_options.launch_id();
  }

  // The choice of where we wait is arbitrary; the reason for the wait is
  // pacing to avoid problems such as memory fragmentation and running ahead
  // too far, not for correctness. Placing it before the executable launch
  // allows the inputs for the next executable to be fetched even if the
  // launch is delayed.
  // The reservation is held in a shared_ptr so that it can be handed to the
  // release callbacks below together with the other references.
  auto compute_reservation = std::make_shared<Semaphore::ScopedReservation>(
      device_state->compute_semaphore().ScopedAcquire(1));

  StatusOr<ExecutionOutput> result_buffer_or_status =
      executables_[executable_idx]->RunAsync(std::move(execution_inputs),
                                             run_options);

  VLOG(1) << "Replica " << replica << " partition " << partition
          << " completed; ok=" << result_buffer_or_status.ok();

  if (!result_buffer_or_status.ok()) {
    return result_buffer_or_status.status();
  }

  if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
    ExecutionOutput& execution_output = result_buffer_or_status.ValueOrDie();
    // If we used a transient tuple for the arguments we donated its root table
    // buffer. In that case, and/or if we donated any input buffers that were
    // not aliased, the donated buffers are going to be passed back to us via
    // the execution output. We need to ensure they aren't freed until after
    // execution completes. (Currently XLA does not support aliasing tuple
    // tables, so if any donated parameter is a tuple there will be donated but
    // unaliased buffers.)
    std::vector<se::OwningDeviceMemory> donated_memory =
        execution_output.ConsumeToBeReleased();
    absl::InlinedVector<se::DeviceMemoryBase, 3> donated_ptrs;
    donated_ptrs.reserve(donated_memory.size());
    for (se::OwningDeviceMemory& owning : donated_memory) {
      // Release the owning memory so we can pass it to the closure.
      donated_ptrs.push_back(owning.Release());
    }
    // The closure keeps the executable, the semaphore reservation and the
    // device assignment referenced via `references`, and deallocates the
    // donated memory when it runs.
    device_state->ThenExecuteOnCallbackThread(
        device_state->compute_stream(),
        [references{std::make_tuple(executables_[executable_idx],
                                    compute_reservation, device_assignment)},
         donated_ptrs{std::move(donated_ptrs)}, allocator{client_->allocator()},
         device_ordinal]() {
          for (const auto& ptr : donated_ptrs) {
            TF_CHECK_OK(allocator->Deallocate(device_ordinal, ptr));
          }
        });
  } else {
    // Any donated memory returned by the ExecutionOutput can be immediately
    // freed.
    // NOTE(review): ThenRelease presumably keeps the tuple of references alive
    // until the compute stream reaches this point -- confirm against
    // LocalDeviceState::ThenRelease.
    device_state->ThenRelease(
        device_state->compute_stream(),
        std::make_tuple(executables_[executable_idx], compute_reservation,
                        device_assignment));
  }

  return result_buffer_or_status.ConsumeValueOrDie().ConsumeResult();
}
1820 
1821 std::vector<std::unique_ptr<PjRtBuffer>>
MakeOutputBuffers(int device_ordinal,const ExecuteOptions & options,ScopedShapedBuffer result_buffer,std::shared_ptr<BufferSequencingEvent> definition_event,PjRtDevice * device) const1822 PjRtStreamExecutorExecutable::MakeOutputBuffers(
1823     int device_ordinal, const ExecuteOptions& options,
1824     ScopedShapedBuffer result_buffer,
1825     std::shared_ptr<BufferSequencingEvent> definition_event,
1826     PjRtDevice* device) const {
1827   std::vector<std::unique_ptr<PjRtBuffer>> outputs;
1828   LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
1829   if (options.untuple_result && result_buffer.on_device_shape().IsTuple()) {
1830     int tuple_count = result_buffer.on_device_shape().tuple_shapes_size();
1831     outputs.reserve(tuple_count);
1832     // Take ownership of each of the output values, leaving only the root table
1833     // in result_buffer.
1834     for (int i = 0; i < tuple_count; ++i) {
1835       ScopedShapedBuffer tuple_buffer = result_buffer.TakeSubTree({i});
1836       outputs.push_back(OutputBufferHelper(&tuple_buffer, definition_event,
1837                                            client_, device, device_state));
1838     }
1839     if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
1840       // Don't release the root buffer until after execution completes.
1841       ShapedBuffer root_buffer_holder = result_buffer.release();
1842       se::DeviceMemoryBase root_buffer = root_buffer_holder.root_buffer();
1843       device_state->ThenExecuteOnCallbackThread(
1844           device_state->compute_stream(),
1845           [root_buffer, allocator{client_->allocator()}, device_ordinal]() {
1846             TF_CHECK_OK(allocator->Deallocate(device_ordinal, root_buffer));
1847           });
1848     }
1849   } else {
1850     outputs.push_back(OutputBufferHelper(&result_buffer, definition_event,
1851                                          client_, device, device_state));
1852   }
1853   return outputs;
1854 }
1855 
1856 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
PjRtStreamExecutorExecutable::ExecuteHelper(
    absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
    const RunId& run_id, const ExecuteOptions& options,
    PjRtDevice* device) const {
  // Runs this executable for a single (replica, partition) pair and wraps the
  // result in output buffers.
  //
  // Two calling modes are distinguished by `device`:
  //  * device == nullptr: the executable must own a device assignment; the
  //    target device is looked up from it using (replica, partition).
  //  * device != nullptr: the executable must have no device assignment and
  //    no addressable devices, replica and partition must both be 0, and a
  //    fresh 1x1 device assignment naming `device` is created for this call.
  //
  // Returns the output buffers on success, or the failing status if the
  // device lookup, the enqueue, or the event allocation fails.
  std::shared_ptr<DeviceAssignment> device_assignment;
  if (device == nullptr) {
    CHECK(device_assignment_ != nullptr);
    const int device_id = (*device_assignment_)(replica, partition);
    TF_ASSIGN_OR_RETURN(device, client_->LookupDevice(device_id));
    device_assignment = device_assignment_;
  } else {
    CHECK(device_assignment_ == nullptr);
    CHECK_EQ(replica, 0);
    CHECK_EQ(partition, 0);
    CHECK(addressable_devices_.empty());
    device_assignment = std::make_shared<DeviceAssignment>(1, 1);
    (*device_assignment)(0, 0) = device->id();
  }

  // The chosen device must belong to the same task as the owning client.
  CHECK_EQ(device->task_id(), client_->task_id());
  int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
                           ->local_device_state()
                           ->device_ordinal();
  tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute");
  VLOG(3) << "Replica " << replica << ", partition " << partition
          << " mapped to device ordinal for execution: " << device_ordinal;

  // SPMD sharding produces a single executable for multiple partitions.
  int executable_idx = executables_.size() > 1 ? partition : 0;

  // `device_buffers` is handed to EnqueueExecution by pointer; presumably it
  // is filled with one hold per argument buffer (usage or donation) -- the
  // loops below rely on each entry being one of those two kinds.
  std::vector<PjRtStreamExecutorBuffer::ScopedHold> device_buffers;
  device_buffers.reserve(argument_handles.size());
  StatusOr<ScopedShapedBuffer> result_buffer_or_status = EnqueueExecution(
      argument_handles, replica, partition, executable_idx, run_id, options,
      device, &device_buffers, std::move(device_assignment));

  if (!result_buffer_or_status.ok()) {
    LOG(ERROR) << "Execution of replica " << replica
               << " failed: " << result_buffer_or_status.status();
    return result_buffer_or_status.status();
  }
  ScopedShapedBuffer result_buffer =
      result_buffer_or_status.ConsumeValueOrDie();

  // Allocate an event and record it on the compute stream; it becomes the
  // definition event of the output buffers created below.
  LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
  se::Stream* stream = device_state->compute_stream();
  StatusOr<EventPool::Handle> event_or =
      device_state->event_pool().ThenAllocateAndRecordEvent(stream);
  if (!event_or.ok()) {
    StallStreamOnError(device_state, stream);
    for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) {
      if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation) {
        // Even though there was an error we need to call ConfirmDonation, which
        // renders b invalid, since the computation has been enqueued and b has
        // been donated.
        b.ConfirmDonation();
      }
    }
    return event_or.status();
  }
  auto definition_event = std::make_shared<BufferSequencingEvent>();
  definition_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream);
  std::vector<std::unique_ptr<PjRtBuffer>> outputs =
      MakeOutputBuffers(device_ordinal, options, std::move(result_buffer),
                        definition_event, device);

  // Settle every argument hold: usage holds are converted into recorded usage
  // events, donation holds are confirmed.
  for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) {
    // prefer_to_retain_reference=false because when using the
    // ComputeSynchronized allocation model we don't need to retain a reference
    // to the device_buffer during execution because by definition the compute
    // stream is synchronized past the execution.
    if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kUsage) {
      RecordUsage(std::move(b), device_state, device_state, definition_event,
                  stream,
                  /*prefer_to_retain_reference=*/false);
    } else {
      CHECK(b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation);
      b.ConfirmDonation();
    }
  }

  return outputs;
}
1940 
1941 StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer * >> argument_handles,const ExecuteOptions & options)1942 PjRtStreamExecutorExecutable::Execute(
1943     absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
1944     const ExecuteOptions& options) {
1945   if (device_assignment_ == nullptr) {
1946     return InvalidArgument("Execute expects a non-null device_assignment");
1947   }
1948 
1949   RunId run_id;
1950   tensorflow::profiler::TraceMeProducer activity(
1951       "PjRtStreamExecutorExecutable::Execute",
1952       tensorflow::profiler::ContextType::kPjRt, run_id.ToInt());
1953 
1954   const int num_addressable_devices = addressable_devices_.size();
1955 
1956   if (argument_handles.size() != num_addressable_devices) {
1957     return InvalidArgument(
1958         "Attempted to execute with %d argument lists when local device "
1959         "count is %d (total replica count: %d, partition count: %d)",
1960         argument_handles.size(), num_addressable_devices, num_replicas(),
1961         num_partitions());
1962   }
1963 
1964   VLOG(1) << "Executing computation " << name()
1965           << "; num_replicas=" << num_replicas()
1966           << " num_partitions=" << num_partitions()
1967           << " num_addressable_devices=" << num_addressable_devices;
1968   std::vector<StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>> results(
1969       num_addressable_devices);
1970   if (num_addressable_devices == 1) {
1971     // Fast-path if there is only one device — run the computation on the
1972     // current thread.
1973     const int replica = addressable_device_logical_ids_[0].replica;
1974     const int partition = addressable_device_logical_ids_[0].partition;
1975     results[0] =
1976         ExecuteHelper(argument_handles[0], replica, partition, run_id, options);
1977   } else {
1978     absl::Mutex mu;
1979     int running = num_addressable_devices;
1980     int failed = 0;
1981     Status first_failure_status;
1982 
1983     for (int i = 0; i < num_addressable_devices; ++i) {
1984       const int replica = addressable_device_logical_ids_[i].replica;
1985       const int partition = addressable_device_logical_ids_[i].partition;
1986       PjRtDevice* device = addressable_devices_[i];
1987       const LocalDeviceState& device_state =
1988           *tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
1989                ->local_device_state();
1990       device_state.execute_thread()->Schedule([&, replica, partition, i] {
1991         results[i] = ExecuteHelper(argument_handles[i], replica, partition,
1992                                    run_id, options);
1993 
1994         absl::MutexLock lock(&mu);
1995         --running;
1996         if (!results[i].ok()) {
1997           if (failed == 0) {
1998             first_failure_status = results[i].status();
1999           }
2000           ++failed;
2001         }
2002       });
2003     }
2004 
2005     auto done_running_or_failed = [&]() {
2006       mu.AssertHeld();
2007       return running == 0 || failed > 0;
2008     };
2009     absl::MutexLock lock(&mu);
2010     mu.Await(absl::Condition(&done_running_or_failed));
2011     if (failed > 0) {
2012       auto done_running = [&]() {
2013         mu.AssertHeld();
2014         return running == 0;
2015       };
2016       // If execution does not terminate within a reasonable amount of time,
2017       // we may be stuck at a cross-replica barrier on-device. Terminate the
2018       // process since that's the only way we can escape this situation at the
2019       // moment (b/130629719).
2020       if (!mu.AwaitWithTimeout(absl::Condition(&done_running),
2021                                absl::Seconds(10))) {
2022         LOG(FATAL)
2023             << "Replicated computation launch failed, but not all replicas "
2024                "terminated. Aborting process to work around deadlock. "
2025                "Failure message (there may have been multiple failures, see "
2026                "the error log for all failures): \n\n"
2027             << first_failure_status.error_message();
2028       }
2029     }
2030   }
2031   VLOG(1) << "Replicated execution complete.";
2032 
2033   std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> wrapped_results(
2034       num_addressable_devices);
2035   for (int i = 0; i < num_addressable_devices; ++i) {
2036     const int replica = addressable_device_logical_ids_[i].replica;
2037     const int partition = addressable_device_logical_ids_[i].partition;
2038     auto& statusor = results[i];
2039     if (!statusor.ok()) {
2040       if (num_addressable_devices == 1) {
2041         return statusor.status();
2042       } else {
2043         return AppendStatus(
2044             statusor.status(),
2045             absl::StrFormat("while running replica %d and partition %d of a "
2046                             "replicated computation (other "
2047                             "replicas may have failed as well).",
2048                             replica, partition));
2049       }
2050     }
2051     wrapped_results[i] = std::move(statusor.ValueOrDie());
2052   }
2053   return wrapped_results;
2054 }
2055 
2056 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer * const> argument_handles,PjRtDevice * device,const ExecuteOptions & options)2057 PjRtStreamExecutorExecutable::ExecuteSharded(
2058     absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
2059     const ExecuteOptions& options) {
2060   if (device_assignment_ == nullptr) {
2061     return InvalidArgument("ExecuteShard expects a non-null device_assignment");
2062   }
2063   for (int i = 0; i < addressable_devices_.size(); ++i) {
2064     if (addressable_devices_[i] == device) {
2065       VLOG(1) << "ExecuteShard executes computation " << name()
2066               << " on assigned replica/partition on device "
2067               << device->DebugString();
2068       return ExecuteHelper(
2069           argument_handles, addressable_device_logical_ids_[i].replica,
2070           addressable_device_logical_ids_[i].partition, RunId(), options);
2071     }
2072   }
2073   return InvalidArgument(
2074       "ExecuteShard attempted to execute on device id %d which is not "
2075       "addressable by this client",
2076       device->id());
2077 }
2078 
2079 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer * const> argument_handles,PjRtDevice * device,const ExecuteOptions & options)2080 PjRtStreamExecutorExecutable::ExecutePortable(
2081     absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
2082     const ExecuteOptions& options) {
2083   if (device_assignment_ != nullptr) {
2084     return InvalidArgument("ExecutePortable gets a non-portable executable");
2085   }
2086   if (num_replicas() != 1 || num_partitions() != 1) {
2087     return InvalidArgument(
2088         "ExecutePortable expects a single-core executable but gets "
2089         "one with %d replica %d partition",
2090         num_replicas(), num_partitions());
2091   }
2092   if (device == nullptr) {
2093     return InvalidArgument("ExecutePortable expects a device to be specified");
2094   }
2095   VLOG(1) << "ExecutePortable executes single-core portable executable "
2096           << name();
2097   return ExecuteHelper(argument_handles,
2098                        /*replica=*/0,
2099                        /*partition=*/0, RunId(), options, device);
2100 }
2101 
2102 StatusOr<std::vector<std::shared_ptr<HloModule>>>
GetHloModules() const2103 PjRtStreamExecutorExecutable::GetHloModules() const {
2104   std::vector<std::shared_ptr<HloModule>> modules;
2105   modules.reserve(executables().size());
2106   for (const auto& local_exec : executables()) {
2107     if (!local_exec->executable()->has_module()) {
2108       return InvalidArgument("Executable does not have HLO modules.");
2109     }
2110     modules.push_back(local_exec->executable()->shared_module());
2111   }
2112   return std::move(modules);
2113 }
2114 
Compile(const XlaComputation & computation,CompileOptions options)2115 StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
2116     const XlaComputation& computation, CompileOptions options) {
2117   tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile");
2118 
2119   ExecutableBuildOptions& build_options = options.executable_build_options;
2120   if (!build_options.compile_thread_pool()) {
2121     build_options.set_compile_thread_pool(thread_pool());
2122   }
2123   if (!build_options.device_allocator()) {
2124     build_options.set_device_allocator(allocator());
2125   }
2126 
2127   int num_replicas;
2128   int num_partitions;
2129   std::shared_ptr<DeviceAssignment> device_assignment;
2130   TF_RETURN_IF_ERROR(ParseDeviceAssignmentCompileOptions(
2131       options.compile_portable_executable, &options.executable_build_options,
2132       [this](int num_replicas, int num_partitions) {
2133         return this->GetDefaultDeviceAssignment(num_replicas, num_partitions);
2134       },
2135       &num_replicas, &num_partitions, &device_assignment));
2136 
2137   std::vector<const Shape*> argument_layout_pointers;
2138   TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions(
2139       computation,
2140       [local_client = client()](Shape shape) {
2141         return local_client->backend()
2142             .transfer_manager()
2143             ->ChooseCompactLayoutForShape(shape);
2144       },
2145       options.argument_layouts, &options.executable_build_options,
2146       &argument_layout_pointers));
2147 
2148   // Find devices that are addressable by this client/task.
2149   std::vector<PjRtExecutable::LogicalDeviceIds> addressable_device_logical_ids;
2150   std::vector<PjRtDevice*> addressable_devices;
2151   if (device_assignment != nullptr) {
2152     addressable_device_logical_ids.reserve(num_replicas * num_partitions);
2153     addressable_devices.reserve(num_replicas * num_partitions);
2154     for (int replica = 0; replica < num_replicas; ++replica) {
2155       for (int partition = 0; partition < num_partitions; ++partition) {
2156         int device_id = (*device_assignment)(replica, partition);
2157         TF_ASSIGN_OR_RETURN(PjRtDevice * device, LookupDevice(device_id));
2158         if (device->task_id() != task_id()) {
2159           VLOG(3) << "Non-local device: " << device_id;
2160           continue;
2161         }
2162         PjRtExecutable::LogicalDeviceIds logica_device_ids;
2163         logica_device_ids.replica = replica;
2164         logica_device_ids.partition = partition;
2165         addressable_device_logical_ids.push_back(std::move(logica_device_ids));
2166         addressable_devices.push_back(device);
2167       }
2168     }
2169     if (addressable_devices.empty()) {
2170       return InvalidArgument(
2171           "Device assignment (%s) does not have any local devices.",
2172           device_assignment->ToString());
2173     }
2174 
2175     if (build_options.device_ordinal() < 0) {
2176       build_options.set_device_ordinal(
2177           addressable_devices.front()->local_hardware_id());
2178     }
2179   }
2180 
2181   TF_ASSIGN_OR_RETURN(
2182       std::vector<std::unique_ptr<LocalExecutable>> local_executables,
2183       client()->Compile(computation, argument_layout_pointers, build_options));
2184 
2185   auto executable = absl::make_unique<PjRtStreamExecutorExecutable>(
2186       std::move(local_executables), options.parameter_is_tupled_arguments,
2187       std::move(device_assignment), std::move(addressable_device_logical_ids),
2188       std::move(addressable_devices), this);
2189   TF_RETURN_IF_ERROR(
2190       executable->SetUpDonation(options.parameter_is_tupled_arguments));
2191   return std::unique_ptr<PjRtExecutable>(std::move(executable));
2192 }
2193 
2194 }  // namespace xla
2195