/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/transfer_manager.h"

#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/notification.h"

using absl::StrCat;

namespace xla {

/* static */ tensorflow::mutex
    TransferManager::platform_transfer_manager_mutex_(
        tensorflow::LINKER_INITIALIZED);

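// Returns the lazily-constructed map from platform id to the registered
// TransferManager state for that platform. Access is guarded by
// platform_transfer_manager_mutex_.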
/* static */ std::map<se::Platform::Id, TransferManager::State>*
TransferManager::GetPlatformTransferManagers() {
  static auto* r = new std::map<se::Platform::Id, TransferManager::State>;
  return r;
}

TransferManager::TransferMetadata::~TransferMetadata() {}

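// Implements the synchronous literal download in terms of the asynchronous,
// callback-based overload below: the transfer is enqueued on a substream and
// the calling thread blocks until the done-callback signals completion.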
StatusOr<Literal> TransferManager::TransferLiteralFromDevice(
    se::Stream* stream, const ShapedBuffer& device_buffer,
    const TransferMetadata* transfer_metadata) {
  StatusOr<Literal> ret;

  se::Stream* substream = stream->GetOrCreateSubStream();
  substream->ThenWaitFor(stream);
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&]() { stream->ReturnSubStream(substream); });

  tensorflow::Notification n;
  Status s;
  Literal literal(device_buffer.on_host_shape());
  TransferLiteralFromDevice(
      substream, device_buffer, &literal,
      [&](Status status) {
        s = status;
        n.Notify();
      },
      transfer_metadata);
  n.WaitForNotification();
  if (!s.ok()) {
    return s;
  }
  return std::move(literal);
}

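// Same as above, but fills a caller-provided MutableBorrowingLiteral instead
// of allocating a new Literal for the result.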
Status TransferManager::TransferLiteralFromDevice(
    se::Stream* stream, const ShapedBuffer& device_buffer,
    const MutableBorrowingLiteral& literal,
    const TransferMetadata* transfer_metadata) {
  se::Stream* substream = stream->GetOrCreateSubStream();
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&]() { stream->ReturnSubStream(substream); });

  Status ret;
  tensorflow::Notification n;
  TransferLiteralFromDevice(
      substream, device_buffer, literal,
      [&](Status status) {
        ret = status;
        n.Notify();
      },
      transfer_metadata);
  n.WaitForNotification();
  return ret;
}

Status TransferManager::TransferLiteralToDevice(
    se::Stream* stream, const LiteralSlice& literal,
    const ShapedBuffer& device_buffer,
    const TransferMetadata* transfer_metadata) {
  // Implement the synchronous version by waiting on the asynchronous version.
  // Use a substream so that if we are called from a HostCallback we don't
  // deadlock.
  se::Stream* substream = stream->GetOrCreateSubStream();
  substream->ThenWaitFor(stream);
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&]() { stream->ReturnSubStream(substream); });
  TF_RETURN_IF_ERROR(TransferLiteralToDeviceAsync(
      substream, literal, device_buffer, transfer_metadata));
  return substream->BlockHostUntilDone();
}

StatusOr<Literal> TransferManager::TransferArrayFromDevice(
    se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source,
    const TransferMetadata* transfer_metadata) {
  StatusOr<Literal> ret;
  // Implement the synchronous version by waiting on the asynchronous version.
  // Use a substream so that if we are called from a HostCallback we don't
  // deadlock.
  se::Stream* substream = stream->GetOrCreateSubStream();
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&]() { stream->ReturnSubStream(substream); });

  tensorflow::Notification n;
  Literal literal(shape);
  Status s;
  TransferArrayFromDevice(
      substream, shape, source, &literal,
      [&](Status status) {
        s = status;
        n.Notify();
      },
      transfer_metadata);
  n.WaitForNotification();
  if (!s.ok()) {
    return s;
  }
  return std::move(literal);
}

Status TransferManager::TransferArrayToDevice(
    se::Stream* stream, const LiteralSlice& literal,
    const se::DeviceMemoryBase& dest,
    const TransferMetadata* transfer_metadata) {
  // Implement the synchronous version by waiting on the asynchronous version.
  // Use a substream so that if we are called from a HostCallback we don't
  // deadlock.
  se::Stream* substream = stream->GetOrCreateSubStream();
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&]() { stream->ReturnSubStream(substream); });
  TF_RETURN_IF_ERROR(
      TransferArrayToDeviceAsync(substream, literal, dest, transfer_metadata));
  return substream->BlockHostUntilDone();
}

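// Validates that `literal` has an array representation on the device and that
// `dest` is large enough to hold it, then wraps `dest` in a single-buffer
// ShapedBuffer and forwards to TransferLiteralToDevice.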
Status TransferManager::TransferArrayToDeviceAsync(
    se::Stream* stream, const LiteralSlice& literal,
    const se::DeviceMemoryBase& dest,
    const TransferMetadata* transfer_metadata) {
  const Shape on_device_shape = HostShapeToDeviceShape(literal.shape());
  TF_RET_CHECK(on_device_shape.IsArray())
      << "On-device representation of "
      << ShapeUtil::HumanString(literal.shape())
      << " is not an array: " << ShapeUtil::HumanString(on_device_shape);
  if (dest.size() < GetByteSizeRequirement(on_device_shape)) {
    return FailedPrecondition(
        "Allocation on device not large enough for array: "
        "%d < %d",
        dest.size(), GetByteSizeRequirement(on_device_shape));
  }
  ShapedBuffer shaped_buffer(on_device_shape,
                             stream->parent()->device_ordinal());
  shaped_buffer.set_buffer(dest, /*index=*/{});
  return TransferLiteralToDevice(stream, literal, shaped_buffer,
                                 transfer_metadata);
}

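// Asynchronously transfers a single array from device memory into `literal`.
// Fails fast via `done` if the on-device representation of `shape` differs
// from the host shape (layouts compared by minor-to-major order only) or if
// `source` is too small; otherwise wraps `source` in a ShapedBuffer and
// forwards to TransferLiteralFromDevice.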
void TransferManager::TransferArrayFromDevice(
    se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source,
    const MutableBorrowingLiteral& literal, std::function<void(Status)> done,
    const TransferMetadata* transfer_metadata) {
  if (!Shape::Equal().MinorToMajorOnlyInLayout()(HostShapeToDeviceShape(shape),
                                                 shape)) {
    auto error = StrCat("Shape ", ShapeUtil::HumanString(shape),
                        " has a differently shaped representation on-device: ",
                        ShapeUtil::HumanString(HostShapeToDeviceShape(shape)));
    return done(FailedPrecondition("%s", error));
  }
  if (source.size() < GetByteSizeRequirement(shape)) {
    return done(
        FailedPrecondition("Allocation on device not large enough for array: "
                           "%d < %d",
                           source.size(), GetByteSizeRequirement(shape)));
  }
  ShapedBuffer shaped_buffer(shape, stream->parent()->device_ordinal());
  shaped_buffer.set_buffer(source, /*index=*/{});
  return TransferLiteralFromDevice(stream, shaped_buffer, literal,
                                   std::move(done), transfer_metadata);
}

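// Reads back the concrete dimension sizes of a dynamically shaped buffer and
// patches `device_shape` in place. For every non-static array subshape, the
// dimension metadata (one int32 per dimension, stored after the static
// portion of the buffer) is copied to the host and used to overwrite the
// corresponding dimensions; the result is then marked fully static and
// checked for compatibility with the original dynamic shape.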
Status TransferManager::ReadDynamicShapes(se::Stream* stream,
                                          ShapedBuffer* device_buffer,
                                          Shape* device_shape) {
  DCHECK(device_shape->is_dynamic());
  Shape original_device_shape = *device_shape;
  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());

  TF_ASSIGN_OR_RETURN(auto compiler,
                      Compiler::GetForPlatform(stream->parent()->platform()));
  TF_RETURN_IF_ERROR(device_buffer->buffers().ForEachMutableElementWithStatus(
      [&](const ShapeIndex& index, se::DeviceMemoryBase* buffer) {
        const Shape& buffer_shape =
            ShapeUtil::GetSubshape(*device_shape, index);
        if (buffer_shape.IsTuple()) {
          return Status::OK();
        }
        Shape& device_sub_shape =
            *ShapeUtil::GetMutableSubshape(device_shape, index);
        if (device_sub_shape.is_static()) {
          return Status::OK();
        }

        // Read the dynamic shape metadata from the device stream.
        auto shape_size_fn = compiler->ShapeSizeBytesFunction();
        Shape buffer_shape_static = ShapeUtil::MakeStaticShape(buffer_shape);
        const int64 offset = shape_size_fn(buffer_shape_static);
        int64 metadata_size = shape_size_fn(buffer_shape) - offset;
        if (metadata_size == 0) {
          return InvalidArgument("Dynamic shape metadata size should not be 0");
        }
        auto buffer_8 = se::DeviceMemory<uint8>(*buffer);
        auto metadata_buffer =
            stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size);
        TF_ASSIGN_OR_RETURN(
            auto metadata,
            TransferArrayFromDevice(
                stream,
                ShapeUtil::MakeShape(S32, {buffer_shape.dimensions_size()}),
                metadata_buffer));

        // Update shape size from metadata.
        for (int64 i = 0; i < metadata.element_count(); ++i) {
          device_sub_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
        }
        return Status::OK();
      }));
  device_shape->clear_dynamic_dimensions();

  TF_RET_CHECK(ShapeUtil::DynamicShapeIsCompatible(*device_shape,
                                                   original_device_shape));
  return Status::OK();
}

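// Registers the creation function for a platform's TransferManager. At most
// one registration per platform id is allowed; the manager itself is created
// lazily in GetForPlatform. A typical call site is a static initializer in a
// backend's translation unit, sketched here with a hypothetical
// MyTransferManager:
//
//   static bool module_initialized = [] {
//     xla::TransferManager::RegisterTransferManager(
//         se::host::kHostPlatformId,
//         [] { return absl::make_unique<MyTransferManager>(); });
//     return true;
//   }();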
/* static */ void TransferManager::RegisterTransferManager(
    se::Platform::Id platform_id,
    TransferManagerCreationFunction creation_function) {
  tensorflow::mutex_lock lock(
      TransferManager::platform_transfer_manager_mutex_);
  auto* managers = GetPlatformTransferManagers();
  CHECK(managers->find(platform_id) == managers->end());
  (*managers)[platform_id].creation_function = creation_function;
}

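// Returns the TransferManager registered for `platform`, constructing it on
// first use. Fails with NotFound if no manager was linked in and registered
// for that platform.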
/* static */ StatusOr<TransferManager*> TransferManager::GetForPlatform(
    const se::Platform* platform) {
  tensorflow::mutex_lock lock(
      TransferManager::platform_transfer_manager_mutex_);
  auto* managers = GetPlatformTransferManagers();

  auto it = managers->find(platform->id());
  if (it == managers->end()) {
    return NotFound(
        "could not find registered transfer manager for platform %s -- check "
        "target linkage",
        platform->Name());
  }

  if (it->second.manager == nullptr) {
    // Lazily create the transfer manager the first time it is needed.
    it->second.manager = (*it->second.creation_function)();
  }

  return it->second.manager.get();
}

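// Synchronously writes the tuple index tables for `device_buffer` by invoking
// the asynchronous version and blocking until the stream has drained.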
Status TransferManager::WriteTupleIndexTables(
    se::Stream* stream, const ShapedBuffer& device_buffer) {
  TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer));
  return stream->BlockHostUntilDone();
}

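// Enqueues writes of the index tables for every non-empty tuple subshape of
// `device_buffer`: for each such tuple, the device addresses of its element
// buffers are gathered and written into the tuple's own allocation via
// WriteSingleTupleIndexTable.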
Status TransferManager::WriteTupleIndexTablesAsync(
    se::Stream* stream, const ShapedBuffer& device_buffer) {
  VLOG(2) << "Writing tuple index tables for " << device_buffer;

  return ShapeUtil::ForEachSubshapeWithStatus(
      device_buffer.on_device_shape(),
      [&](const Shape& device_subshape, const ShapeIndex& index) -> Status {
        if (device_subshape.IsTuple() &&
            ShapeUtil::TupleElementCount(device_subshape) > 0) {
          se::DeviceMemoryBase device_memory = device_buffer.buffer(index);
          TF_RET_CHECK(GetByteSizeRequirement(device_subshape) ==
                       device_memory.size());

          std::vector<se::DeviceMemoryBase> elements;
          ShapeIndex element_index = index;
          for (int64 i = 0; i < ShapeUtil::TupleElementCount(device_subshape);
               ++i) {
            element_index.push_back(i);
            elements.push_back(device_buffer.buffer(element_index));
            element_index.pop_back();
          }
          return WriteSingleTupleIndexTable(stream, elements, device_subshape,
                                            &device_memory);
        }

        return Status::OK();
      });
}

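// Writes only the root tuple index table of `device_buffer` (the table for
// its top-level elements); a no-op for empty tuples.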
Status TransferManager::WriteRootTupleIndexTable(
    se::Stream* stream, const ShapedBuffer& device_buffer) {
  TF_RET_CHECK(device_buffer.on_device_shape().IsTuple());
  if (ShapeUtil::TupleElementCount(device_buffer.on_device_shape()) == 0) {
    return Status::OK();
  }
  se::DeviceMemoryBase device_memory = device_buffer.buffer({});
  TF_RET_CHECK(GetByteSizeRequirement(device_buffer.on_device_shape()) ==
               device_memory.size());

  std::vector<se::DeviceMemoryBase> elements;
  for (int64 i = 0;
       i < ShapeUtil::TupleElementCount(device_buffer.on_device_shape()); ++i) {
    elements.push_back(device_buffer.buffer({i}));
  }
  return WriteSingleTupleIndexTable(
      stream, elements, device_buffer.on_device_shape(), &device_memory);
}

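// As above, but for a ShapeTree of (possibly owning) device memory rather
// than a ShapedBuffer.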
Status TransferManager::WriteRootTupleIndexTable(
    se::Stream* stream, const ShapeTree<MaybeOwningDeviceMemory>& buffer_tree) {
  TF_RET_CHECK(buffer_tree.shape().IsTuple());
  if (ShapeUtil::TupleElementCount(buffer_tree.shape()) == 0) {
    return Status::OK();
  }
  se::DeviceMemoryBase device_memory =
      buffer_tree.element({}).AsDeviceMemoryBase();
  TF_RET_CHECK(GetByteSizeRequirement(buffer_tree.shape()) ==
               device_memory.size());

  std::vector<se::DeviceMemoryBase> elements;
  for (int64 i = 0; i < ShapeUtil::TupleElementCount(buffer_tree.shape());
       ++i) {
    elements.push_back(buffer_tree.element({i}).AsDeviceMemoryBase());
  }
  return WriteSingleTupleIndexTable(stream, elements, buffer_tree.shape(),
                                    &device_memory);
}

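// Enqueues a raw device-to-host copy of `size` bytes onto `stream` after
// checking that the source allocation is large enough.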
Status TransferManager::TransferBufferFromDevice(
    se::Stream* stream, const se::DeviceMemoryBase& source, int64 size,
    void* destination) {
  if (source.size() < size) {
    return FailedPrecondition(
        "Source allocation on device not large enough for data transfer: "
        "%d < %d",
        source.size(), size);
  }
  stream->ThenMemcpy(destination, source, size);
  return Status::OK();
}

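// Enqueues a raw host-to-device copy of `size` bytes onto `stream` after
// checking that the destination allocation is large enough.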
Status TransferManager::TransferBufferToDevice(
    se::Stream* stream, int64 size, const void* source,
    se::DeviceMemoryBase* destination) {
  if (destination->size() < size) {
    return FailedPrecondition(
        "Destination allocation on device not large enough for data transfer: "
        "%d < %d",
        destination->size(), size);
  }
  stream->ThenMemcpy(destination, source, size);
  return Status::OK();
}

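// Allocates a ScopedShapedBuffer holding the device representation of
// `on_host_shape`: one allocation per leaf array and per tuple index table,
// each sized by GetByteSizeRequirement and placed in the memory space
// recorded in its subshape's layout.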
StatusOr<ScopedShapedBuffer> TransferManager::AllocateScopedShapedBuffer(
    const Shape& on_host_shape, se::DeviceMemoryAllocator* allocator,
    int device_ordinal) {
  if (!LayoutUtil::HasLayout(on_host_shape)) {
    return InvalidArgument("Shape must have a layout: %s",
                           ShapeUtil::HumanStringWithLayout(on_host_shape));
  }
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape));
  Shape on_device_shape = HostShapeToDeviceShape(on_host_shape);
  TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape));

  ScopedShapedBuffer shaped_buffer(std::move(on_device_shape), allocator,
                                   device_ordinal);

  // Allocate an appropriately sized buffer for each element in the shape
  // including the tuple pointer arrays.
  for (auto& pair : shaped_buffer.buffers()) {
    const ShapeIndex& index = pair.first;
    se::DeviceMemoryBase& memory_base = pair.second;
    const Shape& subshape =
        ShapeUtil::GetSubshape(shaped_buffer.on_device_shape(), index);
    TF_ASSIGN_OR_RETURN(auto memory,
                        allocator->Allocate(shaped_buffer.device_ordinal(),
                                            GetByteSizeRequirement(subshape),
                                            /*retry_on_failure=*/true,
                                            subshape.layout().memory_space()));
    // Move the allocated buffer into the ScopedShapedBuffer, which owns it.
    memory_base = memory.Release();
  }

  return std::move(shaped_buffer);
}

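// The base implementation simply applies the default layout to `host_shape`;
// backends may override this to pick a more compact on-device layout.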
StatusOr<Shape> TransferManager::ChooseCompactLayoutForShape(
    const Shape& host_shape) const {
  return LayoutUtil::GetWithDefaultLayout(host_shape);
}

}  // namespace xla