/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/transfer_manager.h"

#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/notification.h"

using absl::StrCat;

namespace xla {

/* static */ tensorflow::mutex
    TransferManager::platform_transfer_manager_mutex_(
        tensorflow::LINKER_INITIALIZED);

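// Returns the lazily-initialized global registry mapping platform id to the
// per-platform transfer manager state.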
/* static */ std::map<se::Platform::Id, TransferManager::State>*
TransferManager::GetPlatformTransferManagers() {
  static auto* r = new std::map<se::Platform::Id, TransferManager::State>;
  return r;
}

TransferManager::TransferMetadata::~TransferMetadata() {}

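// Synchronous wrapper around the asynchronous, callback-based
// TransferLiteralFromDevice: runs the transfer on a substream and blocks
// until the done-callback has fired.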
StatusOr<Literal> TransferManager::TransferLiteralFromDevice(
    se::Stream* stream, const ShapedBuffer& device_buffer,
    const TransferMetadata* transfer_metadata) {
  se::Stream* substream = stream->GetOrCreateSubStream();
  substream->ThenWaitFor(stream);
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&]() { stream->ReturnSubStream(substream); });

  tensorflow::Notification n;
  Status s;
  Literal literal(device_buffer.on_host_shape());
  TransferLiteralFromDevice(
      substream, device_buffer, &literal,
      [&](Status status) {
        s = status;
        n.Notify();
      },
      transfer_metadata);
  n.WaitForNotification();
  if (!s.ok()) {
    return s;
  }
  return std::move(literal);
}

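// Blocking overload that transfers into a caller-provided
// MutableBorrowingLiteral instead of allocating a new Literal.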
Status TransferManager::TransferLiteralFromDevice(
    se::Stream* stream, const ShapedBuffer& device_buffer,
    const MutableBorrowingLiteral& literal,
    const TransferMetadata* transfer_metadata) {
  se::Stream* substream = stream->GetOrCreateSubStream();
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&]() { stream->ReturnSubStream(substream); });

  Status ret;
  tensorflow::Notification n;
  TransferLiteralFromDevice(
      substream, device_buffer, literal,
      [&](Status status) {
        ret = status;
        n.Notify();
      },
      transfer_metadata);
  n.WaitForNotification();
  return ret;
}

Status TransferManager::TransferLiteralToDevice(
    se::Stream* stream, const LiteralSlice& literal,
    const ShapedBuffer& device_buffer,
    const TransferMetadata* transfer_metadata) {
  // Implement the synchronous version by waiting on the asynchronous version.
  // Use a substream so that if we are called from a HostCallback we don't
  // deadlock.
  se::Stream* substream = stream->GetOrCreateSubStream();
  substream->ThenWaitFor(stream);
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&]() { stream->ReturnSubStream(substream); });
  TF_RETURN_IF_ERROR(TransferLiteralToDeviceAsync(
      substream, literal, device_buffer, transfer_metadata));
  return substream->BlockHostUntilDone();
}

StatusOr<Literal> TransferManager::TransferArrayFromDevice(
    se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source,
    const TransferMetadata* transfer_metadata) {
  // Implement the synchronous version by waiting on the asynchronous version.
  // Use a substream so that if we are called from a HostCallback we don't
  // deadlock.
  se::Stream* substream = stream->GetOrCreateSubStream();
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&]() { stream->ReturnSubStream(substream); });

  tensorflow::Notification n;
  Literal literal(shape);
  Status s;
  TransferArrayFromDevice(
      substream, shape, source, &literal,
      [&](Status status) {
        s = status;
        n.Notify();
      },
      transfer_metadata);
  n.WaitForNotification();
  if (!s.ok()) {
    return s;
  }
  return std::move(literal);
}

Status TransferManager::TransferArrayToDevice(
    se::Stream* stream, const LiteralSlice& literal,
    const se::DeviceMemoryBase& dest,
    const TransferMetadata* transfer_metadata) {
  // Implement the synchronous version by waiting on the asynchronous version.
  // Use a substream so that if we are called from a HostCallback we don't
  // deadlock.
  se::Stream* substream = stream->GetOrCreateSubStream();
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&]() { stream->ReturnSubStream(substream); });
  TF_RETURN_IF_ERROR(
      TransferArrayToDeviceAsync(substream, literal, dest, transfer_metadata));
  return substream->BlockHostUntilDone();
}

Status TransferManager::TransferArrayToDeviceAsync(
    se::Stream* stream, const LiteralSlice& literal,
    const se::DeviceMemoryBase& dest,
    const TransferMetadata* transfer_metadata) {
  const Shape on_device_shape = HostShapeToDeviceShape(literal.shape());
  TF_RET_CHECK(on_device_shape.IsArray())
      << "On-device representation of "
      << ShapeUtil::HumanString(literal.shape())
      << " is not an array: " << ShapeUtil::HumanString(on_device_shape);
  if (dest.size() < GetByteSizeRequirement(on_device_shape)) {
    return FailedPrecondition(
        "Allocation on device not large enough for array: "
        "%d < %d",
        dest.size(), GetByteSizeRequirement(on_device_shape));
  }
  ShapedBuffer shaped_buffer(on_device_shape,
                             stream->parent()->device_ordinal());
  shaped_buffer.set_buffer(dest, /*index=*/{});
  return TransferLiteralToDevice(stream, literal, shaped_buffer,
                                 transfer_metadata);
}

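// Asynchronous array transfer: wraps the single device allocation in a
// ShapedBuffer and forwards to TransferLiteralFromDevice, reporting the
// result through `done`.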
void TransferManager::TransferArrayFromDevice(
    se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source,
    const MutableBorrowingLiteral& literal, std::function<void(Status)> done,
    const TransferMetadata* transfer_metadata) {
  if (!Shape::Equal().MinorToMajorOnlyInLayout()(HostShapeToDeviceShape(shape),
                                                 shape)) {
    auto error = StrCat("Shape ", ShapeUtil::HumanString(shape),
                        " has a differently shaped representation on-device: ",
                        ShapeUtil::HumanString(HostShapeToDeviceShape(shape)));
    return done(FailedPrecondition("%s", error));
  }
  if (source.size() < GetByteSizeRequirement(shape)) {
    return done(
        FailedPrecondition("Allocation on device not large enough for array: "
                           "%d < %d",
                           source.size(), GetByteSizeRequirement(shape)));
  }
  ShapedBuffer shaped_buffer(shape, stream->parent()->device_ordinal());
  shaped_buffer.set_buffer(source, /*index=*/{});
  return TransferLiteralFromDevice(stream, shaped_buffer, literal,
                                   std::move(done), transfer_metadata);
}

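// Reads back the dynamic-dimension metadata stored after the static portion
// of each dynamically shaped buffer and patches the corresponding dimensions
// of `device_shape`, which becomes fully static afterwards.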
Status TransferManager::ReadDynamicShapes(se::Stream* stream,
                                          ShapedBuffer* device_buffer,
                                          Shape* device_shape) {
  DCHECK(device_shape->is_dynamic());
  Shape original_device_shape = *device_shape;
  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());

  TF_ASSIGN_OR_RETURN(auto compiler,
                      Compiler::GetForPlatform(stream->parent()->platform()));
  TF_RETURN_IF_ERROR(device_buffer->buffers().ForEachMutableElementWithStatus(
      [&](const ShapeIndex& index, se::DeviceMemoryBase* buffer) {
        const Shape& buffer_shape =
            ShapeUtil::GetSubshape(*device_shape, index);
        if (buffer_shape.IsTuple()) {
          return Status::OK();
        }
        Shape& device_sub_shape =
            *ShapeUtil::GetMutableSubshape(device_shape, index);
        if (device_sub_shape.is_static()) {
          return Status::OK();
        }

        // Read the dynamic shape metadata from the device stream.
        auto shape_size_fn = compiler->ShapeSizeBytesFunction();
        Shape buffer_shape_static = ShapeUtil::MakeStaticShape(buffer_shape);
        const int64 offset = shape_size_fn(buffer_shape_static);
        int64 metadata_size = shape_size_fn(buffer_shape) - offset;
        if (metadata_size == 0) {
          return InvalidArgument("Dynamic shape metadata size should not be 0");
        }
        auto buffer_8 = se::DeviceMemory<uint8>(*buffer);
        auto metadata_buffer =
            stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size);
        TF_ASSIGN_OR_RETURN(
            auto metadata,
            TransferArrayFromDevice(
                stream,
                ShapeUtil::MakeShape(S32, {buffer_shape.dimensions_size()}),
                metadata_buffer));

        // Update shape size from metadata.
        for (int64 i = 0; i < metadata.element_count(); ++i) {
          device_sub_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
        }
        return Status::OK();
      }));
  device_shape->clear_dynamic_dimensions();

  TF_RET_CHECK(ShapeUtil::DynamicShapeIsCompatible(*device_shape,
                                                   original_device_shape));
  return Status::OK();
}

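// Registers a creation function for the given platform. Each platform may be
// registered at most once.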
/* static */ void TransferManager::RegisterTransferManager(
    se::Platform::Id platform_id,
    TransferManagerCreationFunction creation_function) {
  tensorflow::mutex_lock lock(
      TransferManager::platform_transfer_manager_mutex_);
  auto* managers = GetPlatformTransferManagers();
  CHECK(managers->find(platform_id) == managers->end());
  (*managers)[platform_id].creation_function = creation_function;
}

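// Returns the transfer manager registered for `platform`, creating it lazily
// on first use.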
/* static */ StatusOr<TransferManager*> TransferManager::GetForPlatform(
    const se::Platform* platform) {
  tensorflow::mutex_lock lock(
      TransferManager::platform_transfer_manager_mutex_);
  auto* managers = GetPlatformTransferManagers();

  auto it = managers->find(platform->id());
  if (it == managers->end()) {
    return NotFound(
        "could not find registered transfer manager for platform %s -- check "
        "target linkage",
        platform->Name());
  }

  if (it->second.manager == nullptr) {
    // Lazily create the transfer manager the first time it is needed.
    it->second.manager = (*it->second.creation_function)();
  }

  return it->second.manager.get();
}

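// Synchronous wrapper around WriteTupleIndexTablesAsync.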
Status TransferManager::WriteTupleIndexTables(
    se::Stream* stream, const ShapedBuffer& device_buffer) {
  TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer));
  return stream->BlockHostUntilDone();
}

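// Writes an index table (the array of element buffer addresses) for every
// non-empty tuple subshape of `device_buffer`.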
Status TransferManager::WriteTupleIndexTablesAsync(
    se::Stream* stream, const ShapedBuffer& device_buffer) {
  VLOG(2) << "Writing tuple index tables for " << device_buffer;

  return ShapeUtil::ForEachSubshapeWithStatus(
      device_buffer.on_device_shape(),
      [&](const Shape& device_subshape, const ShapeIndex& index) -> Status {
        if (device_subshape.IsTuple() &&
            ShapeUtil::TupleElementCount(device_subshape) > 0) {
          se::DeviceMemoryBase device_memory = device_buffer.buffer(index);
          TF_RET_CHECK(GetByteSizeRequirement(device_subshape) ==
                       device_memory.size());

          std::vector<se::DeviceMemoryBase> elements;
          ShapeIndex element_index = index;
          for (int64 i = 0; i < ShapeUtil::TupleElementCount(device_subshape);
               ++i) {
            element_index.push_back(i);
            elements.push_back(device_buffer.buffer(element_index));
            element_index.pop_back();
          }
          return WriteSingleTupleIndexTable(stream, elements, device_subshape,
                                            &device_memory);
        }

        return Status::OK();
      });
}

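// Writes the index table for the root tuple buffer only; nested tuples are
// left untouched.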
Status TransferManager::WriteRootTupleIndexTable(
    se::Stream* stream, const ShapedBuffer& device_buffer) {
  TF_RET_CHECK(device_buffer.on_device_shape().IsTuple());
  if (ShapeUtil::TupleElementCount(device_buffer.on_device_shape()) == 0) {
    return Status::OK();
  }
  se::DeviceMemoryBase device_memory = device_buffer.buffer({});
  TF_RET_CHECK(GetByteSizeRequirement(device_buffer.on_device_shape()) ==
               device_memory.size());

  std::vector<se::DeviceMemoryBase> elements;
  for (int64 i = 0;
       i < ShapeUtil::TupleElementCount(device_buffer.on_device_shape()); ++i) {
    elements.push_back(device_buffer.buffer({i}));
  }
  return WriteSingleTupleIndexTable(
      stream, elements, device_buffer.on_device_shape(), &device_memory);
}

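// Same as above, but for buffers held in a ShapeTree<MaybeOwningDeviceMemory>.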
Status TransferManager::WriteRootTupleIndexTable(
    se::Stream* stream, const ShapeTree<MaybeOwningDeviceMemory>& buffer_tree) {
  TF_RET_CHECK(buffer_tree.shape().IsTuple());
  if (ShapeUtil::TupleElementCount(buffer_tree.shape()) == 0) {
    return Status::OK();
  }
  se::DeviceMemoryBase device_memory =
      buffer_tree.element({}).AsDeviceMemoryBase();
  TF_RET_CHECK(GetByteSizeRequirement(buffer_tree.shape()) ==
               device_memory.size());

  std::vector<se::DeviceMemoryBase> elements;
  for (int64 i = 0; i < ShapeUtil::TupleElementCount(buffer_tree.shape());
       ++i) {
    elements.push_back(buffer_tree.element({i}).AsDeviceMemoryBase());
  }
  return WriteSingleTupleIndexTable(stream, elements, buffer_tree.shape(),
                                    &device_memory);
}

Status TransferManager::TransferBufferFromDevice(
    se::Stream* stream, const se::DeviceMemoryBase& source, int64 size,
    void* destination) {
  if (source.size() < size) {
    return FailedPrecondition(
        "Source allocation on device not large enough for data transfer: "
        "%d < %d",
        source.size(), size);
  }
  stream->ThenMemcpy(destination, source, size);
  return Status::OK();
}

Status TransferManager::TransferBufferToDevice(
    se::Stream* stream, int64 size, const void* source,
    se::DeviceMemoryBase* destination) {
  if (destination->size() < size) {
    return FailedPrecondition(
        "Destination allocation on device not large enough for data transfer: "
        "%d < %d",
        destination->size(), size);
  }
  stream->ThenMemcpy(destination, source, size);
  return Status::OK();
}

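// Allocates one device buffer per element of the on-device shape (leaf arrays
// as well as tuple index tables) and returns them as a ScopedShapedBuffer
// that owns the allocations.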
StatusOr<ScopedShapedBuffer> TransferManager::AllocateScopedShapedBuffer(
    const Shape& on_host_shape, se::DeviceMemoryAllocator* allocator,
    int device_ordinal) {
  if (!LayoutUtil::HasLayout(on_host_shape)) {
    return InvalidArgument("Shape must have a layout: %s",
                           ShapeUtil::HumanStringWithLayout(on_host_shape));
  }
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape));
  Shape on_device_shape = HostShapeToDeviceShape(on_host_shape);
  TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape));

  ScopedShapedBuffer shaped_buffer(std::move(on_device_shape), allocator,
                                   device_ordinal);

  // Allocate an appropriately sized buffer for each element in the shape,
  // including the tuple pointer arrays.
  for (auto& pair : shaped_buffer.buffers()) {
    const ShapeIndex& index = pair.first;
    se::DeviceMemoryBase& memory_base = pair.second;
    const Shape& subshape =
        ShapeUtil::GetSubshape(shaped_buffer.on_device_shape(), index);
    TF_ASSIGN_OR_RETURN(auto memory,
                        allocator->Allocate(shaped_buffer.device_ordinal(),
                                            GetByteSizeRequirement(subshape),
                                            /*retry_on_failure=*/true,
                                            subshape.layout().memory_space()));
    // Move the allocated buffer into the ScopedShapedBuffer, which owns it.
    memory_base = memory.Release();
  }

  return std::move(shaped_buffer);
}

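// Default implementation: the compact layout for a shape is simply its
// default layout.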
StatusOr<Shape> TransferManager::ChooseCompactLayoutForShape(
    const Shape& host_shape) const {
  return LayoutUtil::GetWithDefaultLayout(host_shape);
}

}  // namespace xla