1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/local_client.h"
17 
18 #include <utility>
19 
20 #include "absl/memory/memory.h"
21 #include "llvm/ADT/Triple.h"
22 #include "tensorflow/compiler/xla/client/xla_computation.h"
23 #include "tensorflow/compiler/xla/service/backend.h"
24 #include "tensorflow/compiler/xla/service/dump.h"
25 #include "tensorflow/compiler/xla/service/service_executable_run_options.h"
26 #include "tensorflow/compiler/xla/service/source_map_util.h"
27 #include "tensorflow/compiler/xla/service/stream_pool.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 
30 using xla::source_map_util::InvalidParameterArgument;
31 
32 namespace xla {
33 
34 namespace {
BorrowStreamForDevice(int device_ordinal,Backend * backend)35 StatusOr<StreamPool::Ptr> BorrowStreamForDevice(int device_ordinal,
36                                                 Backend* backend) {
37   if (device_ordinal < 0) {
38     device_ordinal = backend->default_device_ordinal();
39   }
40   return backend->BorrowStream(device_ordinal);
41 }
42 }  // namespace
43 
LocalExecutable(std::unique_ptr<Executable> executable,Backend * backend,ExecutableBuildOptions build_options)44 LocalExecutable::LocalExecutable(std::unique_ptr<Executable> executable,
45                                  Backend* backend,
46                                  ExecutableBuildOptions build_options)
47     : executable_(std::move(executable)),
48       backend_(backend),
49       build_options_(std::move(build_options)) {
50   CHECK_GE(build_options_.device_ordinal(), 0)
51       << "Must have a valid device ordinal that the executable was built for.";
52 }
53 
ValidateExecutionOptions(const ExecutableRunOptions & run_options,const Backend & backend)54 Status LocalExecutable::ValidateExecutionOptions(
55     const ExecutableRunOptions& run_options, const Backend& backend) {
56   if (run_options.stream() != nullptr) {
57     if (!run_options.stream()->ok()) {
58       return InvalidArgument("stream is uninitialized or in an error state");
59     }
60 
61     // Check stream matches service platform.
62     const se::Platform* stream_platform =
63         run_options.stream()->parent()->platform();
64     if (stream_platform != backend_->platform()) {
65       return InvalidArgument(
66           "stream is for platform %s, but service targets platform %s",
67           stream_platform->Name(), backend_->platform()->Name());
68     }
69 
70     // Cannot specify device_ordinal with a stream. The stream determines these
71     // values.
72     if (run_options.device_ordinal() != -1) {
73       return InvalidArgument(
74           "cannot set both device ordinal and stream options in "
75           "ExecutableRunOptions; the stream determines the device ordinal");
76     }
77   }
78 
79   // Verify that the device the executable was built for is equivalent
80   // to the device it will run on.
81   int run_device_ordinal = run_options.device_ordinal();
82   if (run_device_ordinal == -1) {
83     run_device_ordinal = run_options.stream() != nullptr
84                              ? run_options.stream()->parent()->device_ordinal()
85                              : backend_->default_device_ordinal();
86   }
87   TF_ASSIGN_OR_RETURN(bool devices_equivalent,
88                       backend_->devices_equivalent(
89                           run_device_ordinal, build_options_.device_ordinal()));
90   if (!devices_equivalent) {
91     TF_ASSIGN_OR_RETURN(se::StreamExecutor * run_executor,
92                         backend_->stream_executor(run_device_ordinal));
93     TF_ASSIGN_OR_RETURN(se::StreamExecutor * build_executor,
94                         backend_->stream_executor(build_device_ordinal()));
95     return InvalidArgument(
96         "executable is built for device %s of type \"%s\"; cannot run it on "
97         "device %s of type \"%s\"",
98         backend_->device_name(build_device_ordinal()),
99         build_executor->GetDeviceDescription().name(),
100         backend_->device_name(run_device_ordinal),
101         run_executor->GetDeviceDescription().name());
102   }
103 
104   if (!run_options.allocator()) {
105     return InvalidArgument("an allocator must be provided to ExecuteLocally");
106   }
107 
108   if (run_options.allocator()->platform() != backend.platform()) {
109     return InvalidArgument(
110         "allocator platform (%s) does not match service platform (%s)",
111         run_options.allocator()->platform()->Name(),
112         backend.platform()->Name());
113   }
114 
115   return Status::OK();
116 }
117 
118 StatusOr<std::pair<ServiceExecutableRunOptions, StreamPool::Ptr>>
RunHelper(const absl::Span<const Shape * const> argument_shapes,ExecutableRunOptions run_options)119 LocalExecutable::RunHelper(const absl::Span<const Shape* const> argument_shapes,
120                            ExecutableRunOptions run_options) {
121   const ComputationLayout& computation_layout =
122       executable_->module_config().entry_computation_layout();
123 
124   // Check argument number, shapes, and layouts.
125   const int argument_shapes_size = argument_shapes.size();
126   if (argument_shapes_size != computation_layout.parameter_count()) {
127     return InvalidArgument(
128         "invalid number of arguments for computation: expected %d, got %u",
129         computation_layout.parameter_count(), argument_shapes.size());
130   }
131   for (int i = 0, end = argument_shapes.size(); i < end; ++i) {
132     if (!computation_layout.parameter_layout(i).MatchesLayoutInShape(
133             *argument_shapes[i], /*minor_to_major_only=*/true)) {
134       return InvalidParameterArgument(
135           executable_.get(), i,
136           "Argument does not match host shape or layout of computation "
137           "parameter "
138           "%d: want %s, got %s",
139           i,
140           ShapeUtil::HumanStringWithLayout(
141               computation_layout.parameter_layout(i).shape()),
142           ShapeUtil::HumanStringWithLayout(*argument_shapes[i]));
143     }
144   }
145 
146   TF_RETURN_IF_ERROR(ValidateExecutionOptions(run_options, *backend_));
147 
148   StreamPool::Ptr stream;
149   if (run_options.stream() == nullptr) {
150     // NB!  The lifetime of `stream` needs to match the lifetime of
151     // `service_options` (otherwise we will end up using a returned stream in
152     // ExecuteOnStreamWrapper), which is why it isn't declared in the inner "if"
153     // scope.
154     TF_ASSIGN_OR_RETURN(
155         stream, BorrowStreamForDevice(run_options.device_ordinal(), backend_));
156     run_options.set_stream(stream.get());
157   }
158   if (run_options.allocator() == nullptr) {
159     run_options.set_allocator(backend_->memory_allocator());
160   }
161 
162   // For local client execution on CPU backends:
163   // *) The thread pool used for eigen CPU ops is from
164   //    ExecutableRunOptions.eigen_intra_op_thread_pool.
165   // *) The thread pool used for XLA CPU ops is from
166   //    backend_->eigen_intra_op_thread_pool().
167   ServiceExecutableRunOptions service_options(run_options,
168                                               backend_->StreamBorrower());
169   return std::make_pair(service_options, std::move(stream));
170 }
171 
Run(const absl::Span<const ShapedBuffer * const> arguments,ExecutableRunOptions run_options)172 StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
173     const absl::Span<const ShapedBuffer* const> arguments,
174     ExecutableRunOptions run_options) {
175   std::vector<const Shape*> argument_shapes;
176   argument_shapes.reserve(arguments.size());
177   for (const ShapedBuffer* const arg : arguments) {
178     argument_shapes.push_back(&arg->on_device_shape());
179   }
180   return AsyncCallAndBlockHostUntilDone<xla::ScopedShapedBuffer>(
181       argument_shapes, run_options, [&](const ExecutableRunOptions& options) {
182         return RunAsync(arguments, options);
183       });
184 }
185 
Run(std::vector<ExecutionInput> arguments,ExecutableRunOptions run_options)186 StatusOr<ExecutionOutput> LocalExecutable::Run(
187     std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options) {
188   std::vector<const Shape*> argument_shapes;
189   argument_shapes.reserve(arguments.size());
190   for (const ExecutionInput& arg : arguments) {
191     argument_shapes.push_back(&arg.shape());
192   }
193   return AsyncCallAndBlockHostUntilDone<ExecutionOutput>(
194       argument_shapes, run_options, [&](const ExecutableRunOptions& options) {
195         return RunAsync(argument_shapes, std::move(arguments), options);
196       });
197 }
198 
DumpArguments(const Backend * backend,const Executable * executable,const absl::Span<const ShapedBuffer * const> arguments,se::Stream * stream)199 static std::shared_ptr<HloSnapshot> DumpArguments(
200     const Backend* backend, const Executable* executable,
201     const absl::Span<const ShapedBuffer* const> arguments, se::Stream* stream) {
202   auto snapshot = std::make_shared<HloSnapshot>();
203   snapshot->set_execution_platform(backend->platform()->Name());
204   *snapshot->mutable_hlo() = *executable->hlo_proto();
205   for (const ShapedBuffer* arg : arguments) {
206     auto literal = std::make_shared<Literal>(arg->on_host_shape());
207     backend->transfer_manager()->TransferLiteralFromDevice(
208         stream, *arg, literal.get(), [snapshot, literal](Status status) {
209           if (!status.ok()) {
210             LOG(ERROR) << "TransferLiteralFromDevice for HLO snapshot inputs "
211                           "failed: "
212                        << status;
213             return;
214           }
215           *snapshot->add_arguments() = literal->ToProto();
216         });
217   }
218   return snapshot;
219 }
220 
DumpOutputsAndSaveSnapshot(const Backend * backend,const ShapedBuffer & outputs,std::shared_ptr<HloSnapshot> snapshot,se::Stream * stream)221 static void DumpOutputsAndSaveSnapshot(const Backend* backend,
222                                        const ShapedBuffer& outputs,
223                                        std::shared_ptr<HloSnapshot> snapshot,
224                                        se::Stream* stream) {
225   auto literal = std::make_shared<Literal>(outputs.on_host_shape());
226   backend->transfer_manager()->TransferLiteralFromDevice(
227       stream, outputs, literal.get(),
228       [snapshot{std::move(snapshot)}, literal](Status status) {
229         if (status.ok()) {
230           *snapshot->mutable_result() = literal->ToProto();
231         } else {
232           LOG(ERROR)
233               << "TransferLiteralFromDevice for HLO snapshot outputs failed: "
234               << status;
235         }
236         DumpHloSnapshotIfEnabled(*snapshot, GetDebugOptionsFromFlags());
237       });
238 }
239 
RunAsync(const absl::Span<const ShapedBuffer * const> arguments,ExecutableRunOptions run_options)240 StatusOr<ScopedShapedBuffer> LocalExecutable::RunAsync(
241     const absl::Span<const ShapedBuffer* const> arguments,
242     ExecutableRunOptions run_options) {
243   std::vector<const Shape*> argument_shapes;
244   argument_shapes.reserve(arguments.size());
245   for (const ShapedBuffer* const arg : arguments) {
246     argument_shapes.push_back(&arg->on_device_shape());
247   }
248   TF_ASSIGN_OR_RETURN(auto options_and_stream,
249                       RunHelper(argument_shapes, run_options));
250   se::Stream* stream = run_options.stream();
251 
252   std::shared_ptr<HloSnapshot> snapshot;
253   if (executable_->dumping_snapshot()) {
254     snapshot = DumpArguments(backend_, executable_.get(), arguments, stream);
255   }
256 
257   TF_ASSIGN_OR_RETURN(ScopedShapedBuffer outputs,
258                       executable_->ExecuteAsyncOnStreamWrapper(
259                           &options_and_stream.first, arguments));
260 
261   // Transfer the outputs and save the snapshot to disk.
262   if (snapshot) {
263     DumpOutputsAndSaveSnapshot(backend_, outputs, std::move(snapshot), stream);
264   }
265 
266   return std::move(outputs);
267 }
268 
MaybeOwningShapeTreeToShapedBuffer(const ShapeTree<MaybeOwningDeviceMemory> & tree,int device_ordinal)269 static ShapedBuffer MaybeOwningShapeTreeToShapedBuffer(
270     const ShapeTree<MaybeOwningDeviceMemory>& tree, int device_ordinal) {
271   ShapedBuffer result(tree.shape(), device_ordinal);
272   auto it = tree.begin();
273   auto out_it = result.buffers().begin();
274   for (; it != tree.end(); ++it, ++out_it) {
275     out_it->second = it->second.AsDeviceMemoryBase();
276   }
277   return result;
278 }
279 
RunAsync(absl::Span<Shape const * const> argument_host_shapes,std::vector<ExecutionInput> arguments,ExecutableRunOptions run_options)280 StatusOr<ExecutionOutput> LocalExecutable::RunAsync(
281     absl::Span<Shape const* const> argument_host_shapes,
282     std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options) {
283   if (argument_host_shapes.size() != arguments.size()) {
284     return InvalidArgument(
285         "Number of argument host shapes not equal to number of arguments (%d "
286         "vs %d)",
287         argument_host_shapes.size(), arguments.size());
288   }
289   TF_ASSIGN_OR_RETURN(auto options_and_stream,
290                       RunHelper(argument_host_shapes, run_options));
291   se::Stream* stream = run_options.stream();
292 
293   std::shared_ptr<HloSnapshot> snapshot;
294   if (executable_->dumping_snapshot()) {
295     std::vector<ShapedBuffer> shaped_buffers;
296     std::vector<const ShapedBuffer*> shaped_buffer_ptrs;
297     shaped_buffers.reserve(arguments.size());
298     shaped_buffer_ptrs.reserve(arguments.size());
299     for (size_t i = 0; i < arguments.size(); ++i) {
300       shaped_buffers.push_back(MaybeOwningShapeTreeToShapedBuffer(
301           arguments[i].Buffers(), stream->parent()->device_ordinal()));
302       shaped_buffer_ptrs.push_back(&shaped_buffers.back());
303     }
304 
305     snapshot =
306         DumpArguments(backend_, executable_.get(), shaped_buffer_ptrs, stream);
307   }
308 
309   TF_ASSIGN_OR_RETURN(ExecutionOutput outputs,
310                       executable_->ExecuteAsyncOnStreamWrapper(
311                           &options_and_stream.first, std::move(arguments)));
312 
313   // Transfer the outputs and save the snapshot to disk.
314   if (snapshot) {
315     DumpOutputsAndSaveSnapshot(backend_, outputs.Result(), std::move(snapshot),
316                                stream);
317   }
318 
319   return std::move(outputs);
320 }
321 
RunAsync(std::vector<ExecutionInput> arguments,ExecutableRunOptions run_options)322 StatusOr<ExecutionOutput> LocalExecutable::RunAsync(
323     std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options) {
324   std::vector<const Shape*> argument_shapes;
325   argument_shapes.reserve(arguments.size());
326   for (const ExecutionInput& arg : arguments) {
327     argument_shapes.push_back(&arg.shape());
328   }
329   return RunAsync(argument_shapes, std::move(arguments), run_options);
330 }
331 
platform() const332 se::Platform* LocalClient::platform() const {
333   return local_service_->backend().platform();
334 }
335 
device_count() const336 int LocalClient::device_count() const {
337   return local_service_->backend().device_count();
338 }
339 
device_ordinal_supported(int device_ordinal) const340 bool LocalClient::device_ordinal_supported(int device_ordinal) const {
341   return local_service_->backend().device_ordinal_supported(device_ordinal);
342 }
343 
default_device_ordinal() const344 int LocalClient::default_device_ordinal() const {
345   return local_service_->backend().default_device_ordinal();
346 }
347 
backend() const348 const Backend& LocalClient::backend() const {
349   return local_service_->backend();
350 }
351 
mutable_backend()352 Backend* LocalClient::mutable_backend() {
353   return local_service_->mutable_backend();
354 }
355 
Compile(const XlaComputation & computation,const absl::Span<const Shape * const> argument_layouts,const ExecutableBuildOptions & options)356 StatusOr<std::vector<std::unique_ptr<LocalExecutable>>> LocalClient::Compile(
357     const XlaComputation& computation,
358     const absl::Span<const Shape* const> argument_layouts,
359     const ExecutableBuildOptions& options) {
360   ExecutableBuildOptions updated_options = options;
361   if (options.device_ordinal() == -1) {
362     updated_options.set_device_ordinal(default_device_ordinal());
363     VLOG(3) << "Set device ordinal to default value of: "
364             << updated_options.device_ordinal();
365   }
366   if (options.has_device_assignment()) {
367     if (options.device_assignment().replica_count() != options.num_replicas()) {
368       return InvalidArgument(
369           "Mismatched number of replicas for device "
370           "assignment and computation (%d vs %d).\n%s",
371           options.device_assignment().replica_count(), options.num_replicas(),
372           options.device_assignment().ToString());
373     }
374     if (options.device_assignment().computation_count() !=
375         options.num_partitions()) {
376       return InvalidArgument(
377           "Mismatched number of partitions for device "
378           "assignment and computation (%d vs %d).\n%s",
379           options.device_assignment().computation_count(),
380           options.num_partitions(), options.device_assignment().ToString());
381     }
382   }
383   TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
384                       local_service_->CompileExecutables(
385                           computation, argument_layouts, updated_options));
386 
387   std::vector<std::unique_ptr<LocalExecutable>> local_executables;
388   local_executables.reserve(executables.size());
389 
390   for (auto& executable : executables) {
391     local_executables.push_back(absl::make_unique<LocalExecutable>(
392         std::move(executable), local_service_->mutable_backend(),
393         updated_options));
394   }
395 
396   return std::move(local_executables);
397 }
398 
LiteralToShapedBuffer(const LiteralSlice & literal,int device_ordinal,se::DeviceMemoryAllocator * allocator)399 StatusOr<ScopedShapedBuffer> LocalClient::LiteralToShapedBuffer(
400     const LiteralSlice& literal, int device_ordinal,
401     se::DeviceMemoryAllocator* allocator) {
402   if (allocator == nullptr) {
403     allocator = backend().memory_allocator();
404   }
405   TF_ASSIGN_OR_RETURN(auto scoped_buffer,
406                       backend().transfer_manager()->AllocateScopedShapedBuffer(
407                           literal.shape(), allocator, device_ordinal));
408   TF_ASSIGN_OR_RETURN(auto stream,
409                       mutable_backend()->BorrowStream(device_ordinal));
410   TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
411       stream.get(), literal, scoped_buffer));
412   return std::move(scoped_buffer);
413 }
414 
ShapedBufferToLiteral(const ShapedBuffer & shaped_buffer)415 StatusOr<Literal> LocalClient::ShapedBufferToLiteral(
416     const ShapedBuffer& shaped_buffer) {
417   TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream(
418                                        shaped_buffer.device_ordinal()));
419   return backend().transfer_manager()->TransferLiteralFromDevice(stream.get(),
420                                                                  shaped_buffer);
421 }
422 
GlobalDataToShapedBuffer(const GlobalDataHandle & data,int replica_number)423 StatusOr<const ShapedBuffer*> LocalClient::GlobalDataToShapedBuffer(
424     const GlobalDataHandle& data, int replica_number) {
425   return local_service_->GlobalDataToShapedBuffer(data, replica_number);
426 }
427 
TransferToInfeedLocal(const LiteralSlice & literal,int device_ordinal)428 Status LocalClient::TransferToInfeedLocal(const LiteralSlice& literal,
429                                           int device_ordinal) {
430   TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
431                       backend().stream_executor(device_ordinal));
432   return backend().transfer_manager()->TransferLiteralToInfeed(executor,
433                                                                literal);
434 }
435 
TransferFromOutfeedLocal(int device_ordinal,MutableBorrowingLiteral literal)436 Status LocalClient::TransferFromOutfeedLocal(int device_ordinal,
437                                              MutableBorrowingLiteral literal) {
438   TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
439                       backend().stream_executor(device_ordinal));
440   return backend().transfer_manager()->TransferLiteralFromOutfeed(executor,
441                                                                   literal);
442 }
443 
ReplicaNumberToDeviceOrdinal(int replica_number)444 StatusOr<int> LocalClient::ReplicaNumberToDeviceOrdinal(int replica_number) {
445   return local_service_->ReplicaNumberToDeviceOrdinal(replica_number);
446 }
447 
TransferToLocalServer(const::xla::BorrowingLiteral & literal,int device_ordinal)448 StatusOr<TransferToServerResponse> LocalClient::TransferToLocalServer(
449     const ::xla::BorrowingLiteral& literal, int device_ordinal) {
450   const ::xla::Shape& shape = literal.shape();
451 
452   TF_ASSIGN_OR_RETURN(::xla::ScopedShapedBuffer shaped_buffer,
453                       backend().transfer_manager()->AllocateScopedShapedBuffer(
454                           shape, backend().memory_allocator(), device_ordinal));
455   TF_ASSIGN_OR_RETURN(auto stream,
456                       mutable_backend()->BorrowStream(device_ordinal));
457   TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
458       stream.get(), literal, shaped_buffer));
459   std::vector<::xla::ScopedShapedBuffer> replicated_buffer;
460   replicated_buffer.emplace_back(std::move(shaped_buffer));
461   ::xla::TransferToServerResponse result;
462   TF_ASSIGN_OR_RETURN(*result.mutable_data(),
463                       local_service_->RegisterReplicatedBuffers(
464                           std::move(replicated_buffer),
465                           absl::StrCat("TransferToServer literal of shape ",
466                                        ::xla::ShapeUtil::HumanString(shape))));
467 
468   return result;
469 }
470 
471 }  // namespace xla
472