1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/xla_device_context.h"
17 
18 #include <memory>
19 
20 #include "tensorflow/compiler/jit/xla_device.h"
21 #include "tensorflow/compiler/jit/xla_launch_util.h"
22 #include "tensorflow/compiler/tf2xla/literal_util.h"
23 #include "tensorflow/compiler/tf2xla/shape_util.h"
24 #include "tensorflow/compiler/xla/util.h"
25 #include "tensorflow/core/common_runtime/device.h"
26 #include "tensorflow/core/common_runtime/dma_helper.h"
27 #include "tensorflow/core/platform/mem.h"
28 #include "tensorflow/stream_executor/platform/port.h"
29 
30 namespace tensorflow {
31 
32 // The allocator used for Tensors assigned to the XLA device.
XlaDeviceAllocator(stream_executor::StreamExecutor * stream_executor)33 XlaDeviceAllocator::XlaDeviceAllocator(
34     stream_executor::StreamExecutor* stream_executor)
35     : stream_executor_(stream_executor) {}
36 
37 XlaDeviceAllocator::~XlaDeviceAllocator() = default;
38 
Name()39 string XlaDeviceAllocator::Name() { return "xla"; }
40 
AllocateRaw(size_t alignment,size_t num_bytes)41 void* XlaDeviceAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
42   // We always return an empty XlaTensor object, encoded as an opaque tagged
43   // pointer. We can return an empty object and ignore num_bytes here because we
44   // have control over all of the uses of this device tensor, and can lazily
45   // allocate memory when used. This allows us to also know the shape of the
46   // allocated Tensor, which is useful if the device's tensor representation
47   // differs from the host.
48   return XlaTensor::ToOpaquePointer(new XlaTensor());
49 }
50 
DeallocateRaw(void * ptr)51 void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
52   delete XlaTensor::FromOpaquePointer(ptr);
53 }
54 
GetStats()55 absl::optional<AllocatorStats> XlaDeviceAllocator::GetStats() {
56   absl::optional<stream_executor::AllocatorStats> se_stats =
57       stream_executor_->GetAllocatorStats();
58   if (!se_stats) {
59     return absl::nullopt;
60   }
61 
62   tensorflow::AllocatorStats tf_stats;
63   tf_stats.num_allocs = se_stats->num_allocs;
64   tf_stats.bytes_in_use = se_stats->bytes_in_use;
65   tf_stats.peak_bytes_in_use = se_stats->peak_bytes_in_use;
66   tf_stats.largest_alloc_size = se_stats->largest_alloc_size;
67   tf_stats.bytes_limit = se_stats->bytes_limit;
68   return tf_stats;
69 }
70 
XlaDeviceContext(std::shared_ptr<se::Stream> compute_stream,std::shared_ptr<se::Stream> host_to_device_stream,std::shared_ptr<se::Stream> device_to_host_stream,std::vector<std::shared_ptr<se::Stream>> device_to_device_streams,xla::LocalClient * client,XlaCompiler::ShapeRepresentationFn shape_representation_fn,thread::ThreadPool * thread_pool)71 XlaDeviceContext::XlaDeviceContext(
72     std::shared_ptr<se::Stream> compute_stream,
73     std::shared_ptr<se::Stream> host_to_device_stream,
74     std::shared_ptr<se::Stream> device_to_host_stream,
75     std::vector<std::shared_ptr<se::Stream>> device_to_device_streams,
76     xla::LocalClient* client,
77     XlaCompiler::ShapeRepresentationFn shape_representation_fn,
78     thread::ThreadPool* thread_pool)
79     : stream_(std::move(compute_stream)),
80       host_to_device_stream_(std::move(host_to_device_stream)),
81       device_to_host_stream_(std::move(device_to_host_stream)),
82       device_to_device_streams_(std::move(device_to_device_streams)),
83       client_(client),
84       transfer_manager_(client->backend().transfer_manager()),
85       shape_representation_fn_(std::move(shape_representation_fn)),
86       thread_pool_(thread_pool) {
87   CHECK(host_to_device_stream_ != nullptr);
88   CHECK(stream_ != nullptr);
89   if (!shape_representation_fn_) {
90     shape_representation_fn_ = [](const TensorShape& shape,
91                                   DataType dtype) -> xla::StatusOr<xla::Shape> {
92       xla::Shape xla_shape;
93       TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
94       return xla_shape;
95     };
96   }
97 }
98 
CopyTensorInSameDevice(const Tensor * input_tensor,Device * device,Tensor * output_tensor,StatusCallback done) const99 void XlaDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor,
100                                               Device* device,
101                                               Tensor* output_tensor,
102                                               StatusCallback done) const {
103   done(errors::Unimplemented("XLA->XLA same-device copies not implemented."));
104 }
105 
CopyCPUTensorToDevice(const Tensor * cpu_tensor,Device * device,Tensor * device_tensor,StatusCallback done) const106 void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
107                                              Device* device,
108                                              Tensor* device_tensor,
109                                              StatusCallback done) const {
110   if (cpu_tensor->NumElements() == 0) {
111     VLOG(2) << "CopyCPUTensorToDevice empty tensor";
112     done(Status::OK());
113     return;
114   }
115 
116   VLOG(2) << "CopyCPUTensorToDevice "
117           << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
118           << " "
119           << reinterpret_cast<const void*>(device_tensor->tensor_data().data())
120           << " " << cpu_tensor->NumElements() << " "
121           << cpu_tensor->shape().DebugString() << " "
122           << device_tensor->shape().DebugString();
123 
124 
125   XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
126   CHECK(xla_tensor);
127 
128   Status status = [&]() -> Status {
129     TF_ASSIGN_OR_RETURN(xla::Shape shape,
130                         shape_representation_fn_(device_tensor->shape(),
131                                                  device_tensor->dtype()));
132 
133     // The device tensor should always be fresh.
134     TF_RET_CHECK(!xla_tensor->has_shaped_buffer());
135 
136     xla_tensor->set_host_tensor(*cpu_tensor);
137     TF_RETURN_IF_ERROR(
138         xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
139                                          stream_->parent()->device_ordinal()));
140 
141     // The cpu_tensor and literal that we created here hold the data of host
142     // tensor in descending layout. The layout could be different from layout in
143     // device_tensor (but the logical shape has to be the same). The
144     // transfer_manager is responsible to do corresponding transposing when
145     // transferring the data to device.
146     xla::BorrowingLiteral literal(
147         static_cast<const char*>(DMAHelper::base(cpu_tensor)),
148         xla::ShapeUtil::MakeShape(shape.element_type(),
149                                   xla::AsInt64Slice(shape.dimensions())));
150 
151     VLOG(2) << "Transfer to device as literal: " << literal.ToString() << " "
152             << xla_tensor->shaped_buffer().ToString();
153     if (UseMultipleStreams() &&
154         !transfer_manager_->CanShapedBufferBeAccessedNow(
155             stream_->parent(), xla_tensor->shaped_buffer())) {
156       // Initially wait for the compute stream so that memory allocations are
157       // synchronized.
158       host_to_device_stream_->ThenWaitFor(stream_.get());
159     }
160 
161     TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
162         host_to_device_stream_.get(), literal, xla_tensor->shaped_buffer()));
163 
164     if (UseMultipleStreams()) {
165       auto event = std::make_shared<se::Event>(stream_->parent());
166       TF_RET_CHECK(event->Init()) << "Event failed to initialize!";
167       host_to_device_stream_->ThenRecordEvent(event.get());
168       xla_tensor->ResetDefinitionEvent(std::move(event),
169                                        host_to_device_stream_.get());
170     }
171 
172     return Status::OK();
173   }();
174   if (!status.ok()) {
175     done(status);
176     return;
177   }
178 
179   // Create a reference to hold onto cpu_tensor until after the literal has
180   // been transferred
181   TensorReference ref(*cpu_tensor);
182   if (UseMultipleStreams()) {
183     // Unref the host tensor when the transfer completes.
184     // We don't defer the call to done() onto the stream here, and the reasons
185     // why this is correct are subtle. We assume that:
186     // a) all consumers of the device tensor will wait for its definition event.
187     // b) if the tensor is destroyed, then the memory allocator will not hand
188     //    out the same buffers until the transfer has completed.
189     host_to_device_stream_->ThenDoHostCallback([ref]() { ref.Unref(); });
190     done(status);
191   } else {
192     host_to_device_stream_->ThenDoHostCallback([ref, done]() {
193       ref.Unref();
194       done(Status::OK());
195     });
196   }
197 }
198 
CopyDeviceTensorToCPU(const Tensor * device_tensor,absl::string_view tensor_name,Device * device,Tensor * cpu_tensor,StatusCallback done)199 void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
200                                              absl::string_view tensor_name,
201                                              Device* device, Tensor* cpu_tensor,
202                                              StatusCallback done) {
203   if (device_tensor->NumElements() == 0) {
204     VLOG(2) << "CopyDeviceTensorToCPU empty tensor";
205     done(Status::OK());
206     return;
207   }
208   VLOG(2) << "CopyDeviceTensorToCPU "
209           << reinterpret_cast<const void*>(device_tensor->tensor_data().data())
210           << " "
211           << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
212           << " " << device_tensor->NumElements() << " "
213           << cpu_tensor->shape().DebugString() << " "
214           << device_tensor->shape().DebugString();
215 
216   std::shared_ptr<se::Stream> device_to_host_stream;
217   if (device_to_host_stream_) {
218     device_to_host_stream = device_to_host_stream_;
219   } else {
220     stream_executor::port::StatusOr<xla::StreamPool::Ptr> ptr_or_status =
221         client_->mutable_backend()->BorrowStream(
222             stream_->parent()->device_ordinal());
223     if (!ptr_or_status.status().ok()) {
224       done(ptr_or_status.status());
225       return;
226     }
227     device_to_host_stream =
228         std::shared_ptr<se::Stream>(std::move(ptr_or_status.ValueOrDie()));
229   }
230 
231   XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
232   xla_tensor->WaitForDefinitionEventOnStream(device_to_host_stream.get());
233 
234   // Transfer manager requires the shape of the shaped buffer to be the same as
235   // literal shape except for the layout.  Set the literal to use xla_tensor's
236   // shape as it is derived from the cpu_tensor's shape using
237   // shape_representation_fn_.
238   xla::MutableBorrowingLiteral literal;
239   TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(
240       xla::LayoutUtil::GetWithDefaultLayout(
241           xla_tensor->shaped_buffer().on_host_shape()),
242       cpu_tensor, &literal));
243 
244   TensorReference ref(*device_tensor);
245   // Explicitly capture device_to_host_stream to make sure the stream is alive
246   // before the transfer finishes.
247   transfer_manager_->TransferLiteralFromDevice(
248       device_to_host_stream.get(), xla_tensor->shaped_buffer(), literal,
249       [ref, xla_tensor, done, device_to_host_stream](xla::Status status) {
250         done([&]() -> Status {
251           VLOG(2) << "Transfer from device as literal: "
252                   << xla_tensor->shaped_buffer().ToString();
253           return status;
254         }());
255         ref.Unref();
256       });
257 }
258 
GetDeviceToDeviceStream()259 se::Stream* XlaDeviceContext::GetDeviceToDeviceStream() {
260   DCHECK_GT(device_to_device_streams_.size(), 0);
261   absl::MutexLock lock(&mu_);
262   int stream = next_stream_;
263   next_stream_ = (next_stream_ + 1) % device_to_device_streams_.size();
264   return device_to_device_stream(stream);
265 }
266 
267 }  // namespace tensorflow
268