1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Classes for allocating XLA literals in device memory and managing handles
17 // that refer to them.
18 
19 #include "tensorflow/compiler/xrt/xrt_state.h"
20 
21 #include <stdint.h>
22 #include <map>
23 #include <memory>
24 #include <string>
25 #include <utility>
26 
27 #include "absl/memory/memory.h"
28 #include "absl/strings/str_cat.h"
29 #include "tensorflow/compiler/xla/literal.h"
30 #include "tensorflow/compiler/xla/service/backend.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/status_macros.h"
33 #include "tensorflow/compiler/xla/statusor.h"
34 #include "tensorflow/core/framework/resource_mgr.h"
35 #include "tensorflow/core/lib/core/status.h"
36 #include "tensorflow/core/lib/random/random.h"
37 #include "tensorflow/core/platform/mutex.h"
38 #include "tensorflow/core/platform/types.h"
39 #include "tensorflow/stream_executor/stream_executor.h"
40 
41 namespace tensorflow {
42 
43 namespace {
44 
45 class BufferAllocStats {
46  public:
47   struct Stats {
48     int64 count = 0;
49     int64 size = 0;
50   };
51 
ReportAlloc(int64 device,int64 msize)52   Stats ReportAlloc(int64 device, int64 msize) {
53     mutex_lock lock(lock_);
54     Stats* device_stats = &stats_[device];
55     device_stats->count += 1;
56     device_stats->size += msize;
57     return *device_stats;
58   }
59 
ReportFree(int64 device,int64 msize)60   Stats ReportFree(int64 device, int64 msize) {
61     mutex_lock lock(lock_);
62     Stats* device_stats = &stats_[device];
63     device_stats->count -= 1;
64     device_stats->size -= msize;
65     return *device_stats;
66   }
67 
68  private:
69   mutable mutex lock_;
70   std::map<int64, Stats> stats_;
71 };
72 
// Name of the ResourceMgr container in which tuple allocations are stored.
// Declared as a constexpr array (rather than a reassignable pointer) so the
// name cannot be changed at runtime.
constexpr char kTupleContainer[] = "tuples";
74 
get_uid()75 int64 get_uid() {
76   uint64 unsigned_rand = random::New64() & INT64_MAX;
77   return static_cast<int64>(unsigned_rand);
78 }
79 
GetAllocStats()80 BufferAllocStats* GetAllocStats() {
81   static BufferAllocStats* stats = new BufferAllocStats();
82   return stats;
83 }
84 
AllocateScopedShapedBuffer(xla::Backend * backend,int device_ordinal,const xla::Shape & shape,std::unique_ptr<xla::ScopedShapedBuffer> * buffer)85 Status AllocateScopedShapedBuffer(
86     xla::Backend* backend, int device_ordinal, const xla::Shape& shape,
87     std::unique_ptr<xla::ScopedShapedBuffer>* buffer) {
88   auto transfer_manager = backend->transfer_manager();
89   auto allocator = backend->memory_allocator();
90   TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
91 
92   // XLA may use a different representation on device than the representation on
93   // the host. XLA does not document any contract for the relationship between
94   // these representations :/ Right now, the device shape is always a superset
95   // of the host shape, meaning that for any valid ShapeIndex in the host shape
96   // that ShapeIndex is also valid in the device shape, but not vice versa. In
97   // particular, some host-side types are rewritten to be tuples. We rely on
98   // this property when making sub-buffers, because we assume that if the client
99   // requests the host-shape sub-buffer at index i, that will correspond to the
100   // right device-shape sub-buffer at the same index.
101   xla::Shape on_device_shape = transfer_manager->HostShapeToDeviceShape(shape);
102   VLOG(3) << "Allocating literal buffer: host_shape="
103           << xla::ShapeUtil::HumanStringWithLayout(shape) << " device_shape="
104           << xla::ShapeUtil::HumanStringWithLayout(on_device_shape);
105 
106   // The ScopedShapedBuffer frees the buffers that have so far been allocated if
107   // it goes out of scope. That's useful if we return early as the result of an
108   // error allocating one of the later buffers.
109   *buffer = absl::make_unique<xla::ScopedShapedBuffer>(
110       shape, on_device_shape, allocator, device_ordinal);
111   for (auto& index_to_buffer : (*buffer)->buffers()) {
112     xla::Shape subshape =
113         xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first);
114     uint64 size = transfer_manager->GetByteSizeRequirement(subshape);
115     TF_ASSIGN_OR_RETURN(
116         xla::OwningDeviceMemory buffer,
117         allocator->Allocate(device_ordinal, size, /*retry_on_failure=*/false));
118     // Move our buffer into shaped_buffer, which takes ownership of it.
119     index_to_buffer.second = buffer.Forget();
120     VLOG(2) << "Allocated buffer at " << index_to_buffer.second.opaque()
121             << " index " << index_to_buffer.first.ToString();
122   }
123 
124   TF_RETURN_IF_ERROR(
125       transfer_manager->WriteTupleIndexTables(stream.get(), *(buffer->get())));
126 
127   return Status::OK();
128 }
129 
130 }  // namespace
131 
XRTBufferAllocation(const se::DeviceMemoryBase & allocation,int device_ordinal,xla::DeviceMemoryAllocator * allocator)132 XRTBufferAllocation::XRTBufferAllocation(const se::DeviceMemoryBase& allocation,
133                                          int device_ordinal,
134                                          xla::DeviceMemoryAllocator* allocator)
135     : size_(allocation.size()),
136       allocation_(allocation),
137       device_ordinal_(device_ordinal),
138       allocator_(allocator) {
139   if (VLOG_IS_ON(2)) {
140     auto stats =
141         GetAllocStats()->ReportAlloc(device_ordinal_, allocation_.size());
142     LOG(INFO) << "XRT Allocation Stats: device=" << device_ordinal_
143               << " count=" << stats.count << " size=" << stats.size;
144   }
145 }
146 
~XRTBufferAllocation()147 XRTBufferAllocation::~XRTBufferAllocation() {
148   if (VLOG_IS_ON(2)) {
149     GetAllocStats()->ReportFree(device_ordinal_, allocation_.size());
150   }
151   // Deallocate explicitly allows allocation_ to be null.
152   Status s = allocator_->Deallocate(device_ordinal_, allocation_);
153   // Nothing to do but check fail here if memory datastructures are corrupted.
154   CHECK(s.ok());
155   VLOG(2) << "Freed buffer at " << allocation_.opaque();
156 }
157 
allocation()158 const se::DeviceMemoryBase& XRTBufferAllocation::allocation() {
159   return allocation_;
160 }
161 
DiscardAllocation()162 void XRTBufferAllocation::DiscardAllocation() {
163   // Replace the allocation with a null.
164   allocation_ = se::DeviceMemoryBase();
165 }
166 
XRTTupleAllocation(int device_ordinal,xla::DeviceMemoryAllocator * allocator,const xla::Shape & on_host_shape,const xla::Shape & on_device_shape)167 XRTTupleAllocation::XRTTupleAllocation(int device_ordinal,
168                                        xla::DeviceMemoryAllocator* allocator,
169                                        const xla::Shape& on_host_shape,
170                                        const xla::Shape& on_device_shape)
171     : device_ordinal_(device_ordinal),
172       allocator_(allocator),
173       on_host_shape_(on_host_shape),
174       on_device_shape_(on_device_shape),
175       buffers_(&on_device_shape_) {}
176 
~XRTTupleAllocation()177 XRTTupleAllocation::~XRTTupleAllocation() {
178   for (auto& buffer : buffers_) {
179     buffer.second->Unref();
180   }
181 }
182 
CreateAndTransfer(const xla::LiteralBase & literal,xla::Backend * backend,int device_ordinal,XRTTupleAllocation ** allocation)183 /*static*/ Status XRTTupleAllocation::CreateAndTransfer(
184     const xla::LiteralBase& literal, xla::Backend* backend, int device_ordinal,
185     XRTTupleAllocation** allocation) {
186   auto transfer_manager = backend->transfer_manager();
187   auto allocator = backend->memory_allocator();
188 
189   std::unique_ptr<xla::ScopedShapedBuffer> scoped_buffer;
190   TF_RETURN_IF_ERROR(AllocateScopedShapedBuffer(
191       backend, device_ordinal, literal.shape(), &scoped_buffer));
192   TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
193   TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
194       stream.get(), literal, *scoped_buffer));
195 
196   // By releasing the ScopedShapedBuffer we ensure that the underlying storage
197   // won't be freed when the buffer goes out of scope at the end of this
198   // call. To avoid a leak, there must be no error-case returns from here until
199   // the end of the method.
200   auto shaped_buffer = scoped_buffer->release();
201   *allocation = new XRTTupleAllocation(device_ordinal, allocator,
202                                        shaped_buffer.on_host_shape(),
203                                        shaped_buffer.on_device_shape());
204   (*allocation)
205       ->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal);
206   return Status::OK();
207 }
208 
CreateFromBuffer(const xla::ShapedBuffer & shaped_buffer,xla::Backend * backend,int device_ordinal,XRTTupleAllocation ** allocation)209 /*static*/ Status XRTTupleAllocation::CreateFromBuffer(
210     const xla::ShapedBuffer& shaped_buffer, xla::Backend* backend,
211     int device_ordinal, XRTTupleAllocation** allocation) {
212   auto allocator = backend->memory_allocator();
213 
214   *allocation = new XRTTupleAllocation(device_ordinal, allocator,
215                                        shaped_buffer.on_host_shape(),
216                                        shaped_buffer.on_device_shape());
217   (*allocation)
218       ->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal);
219   return Status::OK();
220 }
221 
ToLiteral(xla::Backend * backend,int device_ordinal,xla::MutableLiteralBase * literal)222 Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal,
223                                      xla::MutableLiteralBase* literal) {
224   auto transfer_manager = backend->transfer_manager();
225   TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
226 
227   // Validate the allocation buffers as if nulls gets to
228   // TransferLiteralFromDevice() a CHECK is issued.
229   xla::ShapedBuffer shaped_buffer = ToShapedBuffer();
230   for (auto& index_buffer : shaped_buffer.buffers()) {
231     if (index_buffer.second.is_null()) {
232       return errors::InvalidArgument("Literal buffer at index ",
233                                      index_buffer.first.ToString(),
234                                      " has been released");
235     }
236   }
237   return transfer_manager->TransferLiteralFromDevice(stream.get(),
238                                                      shaped_buffer, *literal);
239 }
240 
WriteLiteral(xla::Backend * backend,const xla::Literal & literal)241 Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend,
242                                         const xla::Literal& literal) {
243   if (!xla::ShapeUtil::Equal(literal.shape(), on_host_shape())) {
244     return errors::InvalidArgument(
245         "New literal shape not matching the existing one: literal=",
246         xla::ShapeUtil::HumanStringWithLayout(literal.shape()),
247         " device=", xla::ShapeUtil::HumanStringWithLayout(on_host_shape()));
248   }
249   auto transfer_manager = backend->transfer_manager();
250   TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal()));
251   return transfer_manager->TransferLiteralToDevice(stream.get(), literal,
252                                                    ToShapedBuffer());
253 }
254 
DiscardAllocation(const xla::ShapeIndex & buffer_index)255 void XRTTupleAllocation::DiscardAllocation(
256     const xla::ShapeIndex& buffer_index) {
257   buffers_.element(buffer_index)->DiscardAllocation();
258 }
259 
on_host_shape()260 const xla::Shape& XRTTupleAllocation::on_host_shape() { return on_host_shape_; }
261 
on_device_shape()262 const xla::Shape& XRTTupleAllocation::on_device_shape() {
263   return on_device_shape_;
264 }
265 
device_ordinal()266 int XRTTupleAllocation::device_ordinal() { return device_ordinal_; }
267 
root_allocation()268 const se::DeviceMemoryBase& XRTTupleAllocation::root_allocation() {
269   return buffers_.element({})->allocation();
270 }
271 
Lookup(ResourceMgr * rm,int64 key,XRTTupleAllocation ** allocation)272 /*static*/ Status XRTTupleAllocation::Lookup(ResourceMgr* rm, int64 key,
273                                              XRTTupleAllocation** allocation) {
274   string key_string = absl::StrCat(key);
275   TF_RETURN_IF_ERROR(rm->Lookup(kTupleContainer, key_string, allocation));
276   return Status::OK();
277 }
278 
DeleteFromResourceManager(ResourceMgr * rm,int64 key)279 /*static*/ Status XRTTupleAllocation::DeleteFromResourceManager(ResourceMgr* rm,
280                                                                 int64 key) {
281   string key_string = absl::StrCat(key);
282   return rm->Delete<XRTTupleAllocation>(kTupleContainer, key_string);
283 }
284 
ReleaseAllAllocations(ResourceMgr * rm)285 /* static */ Status XRTTupleAllocation::ReleaseAllAllocations(ResourceMgr* rm) {
286   VLOG(1) << "Releasing all XRT held device memory";
287   return rm->Cleanup(kTupleContainer);
288 }
289 
290 // Helper typedef to make ShapeTree ForEach helper lambda signatures more
291 // readable. They need a type of const T& where in this case T is the
292 // following pointer.
293 typedef XRTBufferAllocation* XRTBufferAllocationPtr;
294 
MakeSubBuffer(XRTTupleAllocation * parent,const xla::ShapeIndex & subshape,XRTTupleAllocation ** allocation,bool alias_parent_allocation)295 /*static*/ Status XRTTupleAllocation::MakeSubBuffer(
296     XRTTupleAllocation* parent, const xla::ShapeIndex& subshape,
297     XRTTupleAllocation** allocation, bool alias_parent_allocation) {
298   TF_ASSIGN_OR_RETURN(
299       const xla::Shape* host_sub_shape,
300       xla::ShapeUtil::TryGetSubshape(parent->on_host_shape(), subshape));
301   TF_ASSIGN_OR_RETURN(
302       const xla::Shape* device_sub_shape,
303       xla::ShapeUtil::TryGetSubshape(parent->on_device_shape(), subshape));
304 
305   *allocation =
306       new XRTTupleAllocation(parent->device_ordinal(), parent->allocator_,
307                              *host_sub_shape, *device_sub_shape);
308   if (alias_parent_allocation) {
309     // Copy the subtree of allocations from the parent allocation.
310     (*allocation)->buffers_.CopySubtreeFrom(parent->buffers_, subshape, {});
311     // Increment the refcount on each aliased buffer.
312     (*allocation)
313         ->buffers_.ForEachElement(
314             [](const xla::ShapeIndex& index,
315                const XRTBufferAllocationPtr& buffer) { buffer->Ref(); });
316   } else {
317     // Find the buffers in the parent allocation that match the subtree, and
318     // move the parent allocation's buffer over to the new allocation.
319     (*allocation)
320         ->buffers_.ForEachMutableElement(
321             [&](const xla::ShapeIndex& index, XRTBufferAllocationPtr* buffer) {
322               // Extend the allocation's index to the parent's frame by adding
323               // subshape as a prefix.
324               xla::ShapeIndex parent_index = subshape;
325               for (int i = 0; i < index.size(); ++i) {
326                 parent_index.push_back(index[i]);
327               }
328               *buffer = parent->buffers_.element(parent_index);
329               *parent->buffers_.mutable_element(parent_index) =
330                   new XRTBufferAllocation(se::DeviceMemoryBase(),
331                                           parent->device_ordinal(),
332                                           parent->allocator_);
333             });
334   }
335 
336   return Status::OK();
337 }
338 
ExpandTreeOfTuples(const xla::ShapeTree<ExpandedTupleInput> & elements,int device_ordinal,xla::DeviceMemoryAllocator * allocator,xla::Shape * host_shape,xla::Shape * device_shape)339 /* static */ Status XRTTupleAllocation::ExpandTreeOfTuples(
340     const xla::ShapeTree<ExpandedTupleInput>& elements, int device_ordinal,
341     xla::DeviceMemoryAllocator* allocator, xla::Shape* host_shape,
342     xla::Shape* device_shape) {
343   // Initialize both host and device shape to be the 'spine' of the new tuple
344   // shape, given by the shape of the tree of tuples.
345   *host_shape = elements.shape();
346   *device_shape = elements.shape();
347   // Now go over the leaves of the tree of tuples, and 'graft' the host/device
348   // shapes of the allocation at that leaf onto the expanded host/device shapes
349   // at the leaf position.
350   TF_RETURN_IF_ERROR(elements.ForEachElementWithStatus(
351       [&](const xla::ShapeIndex& index, const ExpandedTupleInput& element) {
352         if (elements.IsLeaf(index)) {
353           if (element.allocation == nullptr) {
354             return errors::InvalidArgument(
355                 "MakeTuple elements has a null internal node at index ",
356                 index.ToString());
357           }
358           if (device_ordinal != element.allocation->device_ordinal() ||
359               allocator != element.allocation->allocator_) {
360             return errors::InvalidArgument(
361                 "MakeTuple elements must all be allocated on the same device "
362                 "as the destination.");
363           }
364           *xla::ShapeUtil::GetMutableSubshape(host_shape, index) =
365               element.allocation->on_host_shape();
366           *xla::ShapeUtil::GetMutableSubshape(device_shape, index) =
367               element.allocation->on_device_shape();
368         } else {
369           if (element.allocation != nullptr) {
370             return errors::InvalidArgument(
371                 "MakeTuple elements has a non-null internal node at index ",
372                 index.ToString());
373           }
374         }
375         return Status::OK();
376       }));
377   return Status::OK();
378 }
379 
MakeTuple(xla::Backend * backend,int device_ordinal,const xla::ShapeTree<ExpandedTupleInput> & elements,XRTTupleAllocation ** allocation)380 /*static*/ Status XRTTupleAllocation::MakeTuple(
381     xla::Backend* backend, int device_ordinal,
382     const xla::ShapeTree<ExpandedTupleInput>& elements,
383     XRTTupleAllocation** allocation) {
384   auto transfer_manager = backend->transfer_manager();
385   auto allocator = backend->memory_allocator();
386   TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
387 
388   xla::Shape host_shape;
389   xla::Shape device_shape;
390   TF_RETURN_IF_ERROR(ExpandTreeOfTuples(elements, device_ordinal, allocator,
391                                         &host_shape, &device_shape));
392 
393   // The aliasing is determined below based on whether or not all the inputs are
394   // released while being transferred. allocation_tmp is a local pointer that is
395   // copied to *allocation at the end only if the method succeeds.
396   auto allocation_tmp = new XRTTupleAllocation(device_ordinal, allocator,
397                                                host_shape, device_shape);
398   core::ScopedUnref allocation_unref(allocation_tmp);
399   // First allocate device memory for the new tuple index tables, one at each
400   // internal node of the elements tree. Do this in a separate pass into a
401   // ScopedShapedBuffer so that it's easy to free the newly-allocated memory if
402   // an allocation fails. Make sure the shape has layout so that the code that
403   // writes index tables will be happy lower down.
404   xla::Shape spine_shape = elements.shape();
405   xla::LayoutUtil::SetToDefaultLayout(&spine_shape);
406   auto new_tuple_buffers = absl::make_unique<xla::ScopedShapedBuffer>(
407       spine_shape, spine_shape, allocator, device_ordinal);
408   TF_RETURN_IF_ERROR(elements.ForEachElementWithStatus(
409       [&](const xla::ShapeIndex& index, const ExpandedTupleInput& element) {
410         if (!elements.IsLeaf(index)) {
411           xla::Shape subshape =
412               xla::ShapeUtil::GetSubshape(device_shape, index);
413           uint64 size = transfer_manager->GetByteSizeRequirement(subshape);
414           TF_ASSIGN_OR_RETURN(xla::OwningDeviceMemory buffer,
415                               allocator->Allocate(device_ordinal, size,
416                                                   /*retry_on_failure=*/false));
417           VLOG(2) << "Allocated buffer at " << buffer.opaque() << " index "
418                   << index.ToString();
419           // Move the new buffer into new_tuple_buffers, which takes ownership
420           // of it.
421           new_tuple_buffers->set_buffer(std::move(buffer), index);
422         }
423         return Status::OK();
424       }));
425   // Transfer from the ScopedShapedBuffer to a ShapedBuffer, which does not own
426   // the newly-allocated index tables. Right now there's no owner for the new
427   // index tables, so next we will transfer ownership to the new allocation,
428   // taking care not to return early on any errors in the meantime.
429   xla::ShapedBuffer tuple_buffers = new_tuple_buffers->release();
430   // Now fill in the remaining datastructures. After this ForEachElement
431   // completes:
432   //   1) Every leaf element of tuple_buffers will be the root buffer of
433   //      an existing allocation, and every internal element of tuple_buffers
434   //      will be a newly-allocated index table. tuple_buffers does not own any
435   //      of these.
436   //   2) Every element of allocation_tmp->buffers_ will be a correctly
437   //   constructed
438   //      XRTBufferAllocation wrapping the necessary allocations. For buffers in
439   //      existing allocations there will be a new reference owned by the new
440   //      allocation, and for newly-allocated index tables there will be a
441   //      single reference owned by the new allocation.
442   elements.ForEachElement([&](const xla::ShapeIndex& index,
443                               const ExpandedTupleInput& element) {
444     if (elements.IsLeaf(index)) {
445       allocation_tmp->buffers_.CopySubtreeFrom(element.allocation->buffers_, {},
446                                                index);
447       tuple_buffers.set_buffer(element.allocation->root_allocation(), index);
448       if (element.release_allocation_after_use) {
449         // Transfer the references from element's buffers to the new allocation
450         // rather than incrementing the refcount. The caller should have
451         // validated that release_allocation_after_use is false if
452         // element.allocation appears in more than one leaf.
453         element.allocation->buffers_.ForEachMutableElement(
454             [&](const xla::ShapeIndex& index, XRTBufferAllocationPtr* buffer) {
455               *buffer = new XRTBufferAllocation(
456                   se::DeviceMemoryBase(), element.allocation->device_ordinal(),
457                   element.allocation->allocator_);
458             });
459       } else {
460         // Increment the refcount on each newly-aliased buffer.
461         element.allocation->buffers_.ForEachElement(
462             [](const xla::ShapeIndex& index,
463                const XRTBufferAllocationPtr& buffer) { buffer->Ref(); });
464       }
465     } else {
466       // This is an internal node of the tuple tree so take ownership of the
467       // newly-created index table.
468       *allocation_tmp->buffers_.mutable_element(index) =
469           new XRTBufferAllocation(tuple_buffers.buffer(index), device_ordinal,
470                                   allocator);
471     }
472   });
473   // Because the internal nodes of tuple_buffers are exactly the new index
474   // tables, WriteTupleIndexTables will write only the new index tables and not
475   // rewrite the index tables for the existing allocations.
476   TF_RETURN_IF_ERROR(
477       transfer_manager->WriteTupleIndexTables(stream.get(), tuple_buffers));
478 
479   *allocation = allocation_tmp;
480   // Get another reference since allocation_tmp will be Unrefed automatically on
481   // exit.
482   (*allocation)->Ref();
483   return Status::OK();
484 }
485 
Intern(ResourceMgr * rm,int64 * key)486 Status XRTTupleAllocation::Intern(ResourceMgr* rm, int64* key) {
487   *key = get_uid();
488   string key_string = absl::StrCat(*key);
489   return rm->Create(kTupleContainer, key_string, this);
490 }
491 
IsExclusiveOwner()492 bool XRTTupleAllocation::IsExclusiveOwner() {
493   for (const auto& buffer : buffers_) {
494     if (!buffer.second->RefCountIsOne()) return false;
495   }
496   return true;
497 }
498 
InitializeFromShapedBuffer(const xla::ShapedBuffer & shaped_buffer,xla::DeviceMemoryAllocator * allocator,int device_ordinal)499 void XRTTupleAllocation::InitializeFromShapedBuffer(
500     const xla::ShapedBuffer& shaped_buffer,
501     xla::DeviceMemoryAllocator* allocator, int device_ordinal) {
502   for (auto& buffer : buffers_) {
503     // Make a reference-counted version of the allocated buffer.
504     buffer.second = new XRTBufferAllocation(shaped_buffer.buffer(buffer.first),
505                                             device_ordinal, allocator);
506   }
507 }
508 
ToShapedBuffer()509 xla::ShapedBuffer XRTTupleAllocation::ToShapedBuffer() {
510   xla::ShapedBuffer shaped_buffer(on_host_shape(), on_device_shape(),
511                                   allocator_->platform(), device_ordinal_);
512   for (const auto& buffer : buffers_) {
513     shaped_buffer.set_buffer(buffer.second->allocation(), buffer.first);
514   }
515   return shaped_buffer;
516 }
517 
AliasBufferFrom(const XRTTupleAllocation & source,const xla::ShapeIndex & source_index,const xla::ShapeIndex & dest_index)518 Status XRTTupleAllocation::AliasBufferFrom(const XRTTupleAllocation& source,
519                                            const xla::ShapeIndex& source_index,
520                                            const xla::ShapeIndex& dest_index) {
521   XRTBufferAllocation* source_buffer = source.buffers_.element(source_index);
522   XRTBufferAllocation* dest_buffer = buffers_.element(dest_index);
523   // We allow the destination size being zero, because there are cases where we
524   // are coming in later filling in null/uninitialized device buffers.
525   // In all other cases, the size of the new buffer must match.
526   if (source_buffer->size() != dest_buffer->size() &&
527       dest_buffer->size() != 0) {
528     return errors::InvalidArgument(
529         "Source buffer at index ", source_index.ToString(),
530         " does not match the size of destination buffer at index ",
531         dest_index.ToString(), ": ", source_buffer->size(), " vs ",
532         dest_buffer->size());
533   }
534   *buffers_.mutable_element(dest_index) = source_buffer;
535   source_buffer->Ref();
536   dest_buffer->Unref();
537   return Status::OK();
538 }
539 
540 xla::ShapeTree<xla::MaybeOwningDeviceMemory>
ToDeviceMemoryTree(const std::function<bool (const xla::ShapeIndex &)> & release_checker)541 XRTTupleAllocation::ToDeviceMemoryTree(
542     const std::function<bool(const xla::ShapeIndex&)>& release_checker) {
543   xla::ShapeTree<xla::MaybeOwningDeviceMemory> shaped_tree(on_device_shape());
544   for (const auto& buffer : buffers_) {
545     if (!release_checker(buffer.first)) {
546       *shaped_tree.mutable_element(buffer.first) = buffer.second->allocation();
547     } else {
548       *shaped_tree.mutable_element(buffer.first) = xla::OwningDeviceMemory(
549           buffer.second->allocation(), device_ordinal_, allocator_);
550       DiscardAllocation(buffer.first);
551     }
552   }
553   return shaped_tree;
554 }
555 
556 }  // namespace tensorflow
557