1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/local_client.h"
17 
18 #include <utility>
19 
20 #include "absl/memory/memory.h"
21 #include "llvm/ADT/Triple.h"
22 #include "tensorflow/compiler/xla/client/xla_computation.h"
23 #include "tensorflow/compiler/xla/service/backend.h"
24 #include "tensorflow/compiler/xla/service/dump.h"
25 #include "tensorflow/compiler/xla/service/service_executable_run_options.h"
26 #include "tensorflow/compiler/xla/service/source_map_util.h"
27 #include "tensorflow/compiler/xla/service/stream_pool.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 
30 using xla::source_map_util::InvalidParameterArgument;
31 
32 namespace xla {
33 
34 namespace {
BorrowStreamForDevice(int device_ordinal,Backend * backend)35 StatusOr<StreamPool::Ptr> BorrowStreamForDevice(int device_ordinal,
36                                                 Backend* backend) {
37   if (device_ordinal < 0) {
38     device_ordinal = backend->default_device_ordinal();
39   }
40   return backend->BorrowStream(device_ordinal);
41 }
42 }  // namespace
43 
LocalExecutable(std::unique_ptr<Executable> executable,Backend * backend,ExecutableBuildOptions build_options)44 LocalExecutable::LocalExecutable(std::unique_ptr<Executable> executable,
45                                  Backend* backend,
46                                  ExecutableBuildOptions build_options)
47     : executable_(std::move(executable)),
48       backend_(backend),
49       build_options_(std::move(build_options)) {
50   CHECK_GE(build_options_.device_ordinal(), 0)
51       << "Must have a valid device ordinal that the executable was built for.";
52 }
53 
ValidateExecutionOptions(const absl::Span<const ShapedBuffer * const> arguments,const ExecutableRunOptions & run_options,const Backend & backend)54 Status LocalExecutable::ValidateExecutionOptions(
55     const absl::Span<const ShapedBuffer* const> arguments,
56     const ExecutableRunOptions& run_options, const Backend& backend) {
57   const ComputationLayout& computation_layout =
58       executable_->module_config().entry_computation_layout();
59 
60   // Check argument number, shapes, and layouts.
61   if (arguments.size() != computation_layout.parameter_count()) {
62     return InvalidArgument(
63         "invalid number of arguments for computation: expected %d, got %u",
64         computation_layout.parameter_count(), arguments.size());
65   }
66   for (int i = 0; i < arguments.size(); ++i) {
67     if (!computation_layout.parameter_layout(i).MatchesLayoutInShape(
68             arguments[i]->on_host_shape())) {
69       return InvalidParameterArgument(
70           executable_.get(), i,
71           "Argument does not match host shape or layout of computation "
72           "parameter "
73           "%d: want %s, got %s",
74           i,
75           ShapeUtil::HumanStringWithLayout(
76               computation_layout.parameter_layout(i).shape()),
77           ShapeUtil::HumanStringWithLayout(arguments[i]->on_host_shape()));
78     }
79   }
80 
81   if (run_options.stream() != nullptr) {
82     if (!run_options.stream()->ok()) {
83       return InvalidArgument("stream is uninitialized or in an error state");
84     }
85 
86     // Check stream matches service platform.
87     const se::Platform* stream_platform =
88         run_options.stream()->parent()->platform();
89     if (stream_platform != backend_->platform()) {
90       return InvalidArgument(
91           "stream is for platform %s, but service targets platform %s",
92           stream_platform->Name(), backend_->platform()->Name());
93     }
94 
95     // Cannot specify device_ordinal with a stream. The stream determines these
96     // values.
97     if (run_options.device_ordinal() != -1) {
98       return InvalidArgument(
99           "cannot set both device ordinal and stream options in "
100           "ExecutableRunOptions; the stream determines the device ordinal");
101     }
102   }
103 
104   // Verify that the device the executable was built for is equivalent
105   // to the device it will run on.
106   int run_device_ordinal = run_options.device_ordinal();
107   if (run_device_ordinal == -1) {
108     run_device_ordinal = run_options.stream() != nullptr
109                              ? run_options.stream()->parent()->device_ordinal()
110                              : backend_->default_device_ordinal();
111   }
112   TF_ASSIGN_OR_RETURN(bool devices_equivalent,
113                       backend_->devices_equivalent(
114                           run_device_ordinal, build_options_.device_ordinal()));
115   if (!devices_equivalent) {
116     TF_ASSIGN_OR_RETURN(se::StreamExecutor * run_executor,
117                         backend_->stream_executor(run_device_ordinal));
118     TF_ASSIGN_OR_RETURN(se::StreamExecutor * build_executor,
119                         backend_->stream_executor(build_device_ordinal()));
120     return InvalidArgument(
121         "executable is built for device %s of type \"%s\"; cannot run it on "
122         "device %s of type \"%s\"",
123         backend_->device_name(build_device_ordinal()),
124         build_executor->GetDeviceDescription().name(),
125         backend_->device_name(run_device_ordinal),
126         run_executor->GetDeviceDescription().name());
127   }
128 
129   if (!run_options.allocator()) {
130     return InvalidArgument("an allocator must be provided to ExecuteLocally");
131   }
132 
133   if (run_options.allocator()->platform() != backend.platform()) {
134     return InvalidArgument(
135         "allocator platform (%s) does not match service platform (%s)",
136         run_options.allocator()->platform()->Name(),
137         backend.platform()->Name());
138   }
139 
140   return Status::OK();
141 }
142 
Run(const absl::Span<const ShapedBuffer * const> arguments,ExecutableRunOptions run_options)143 StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
144     const absl::Span<const ShapedBuffer* const> arguments,
145     ExecutableRunOptions run_options) {
146   TF_RETURN_IF_ERROR(
147       ValidateExecutionOptions(arguments, run_options, *backend_));
148 
149   StreamPool::Ptr stream;
150   if (run_options.stream() == nullptr) {
151     // NB!  The lifetime of `stream` needs to match the lifetime of
152     // `actual_options` (otherwise we will end up using a returned stream in
153     // ExecuteOnStreamWrapper), which is why it isn't declared in the inner "if"
154     // scope.
155     TF_ASSIGN_OR_RETURN(
156         stream, BorrowStreamForDevice(run_options.device_ordinal(), backend_));
157     run_options.set_stream(stream.get());
158   }
159   if (run_options.allocator() == nullptr) {
160     run_options.set_allocator(backend_->memory_allocator());
161   }
162 
163   // For local client execution on CPU backends:
164   // *) The thread pool used for eigen CPU ops is from
165   //    ExecutableRunOptions.eigen_intra_op_thread_pool.
166   // *) The thread pool used for XLA CPU ops is from
167   //    backend_->eigen_intra_op_thread_pool().
168   ServiceExecutableRunOptions service_options(run_options,
169                                               backend_->StreamBorrower());
170 
171   if (executable_->dumping_snapshot()) {
172     return ExecuteAndDump(&service_options, arguments);
173   }
174   return executable_->ExecuteOnStreamWrapper(
175       &service_options, run_options.execution_profile(), arguments);
176 }
177 
ExecuteAndDump(const ServiceExecutableRunOptions * run_options,const absl::Span<const ShapedBuffer * const> arguments)178 StatusOr<ScopedShapedBuffer> LocalExecutable::ExecuteAndDump(
179     const ServiceExecutableRunOptions* run_options,
180     const absl::Span<const ShapedBuffer* const> arguments) {
181   executable_->hlo_snapshot()->set_execution_platform(
182       backend_->platform()->Name());
183   TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->hlo_snapshot()));
184   TF_ASSIGN_OR_RETURN(
185       ScopedShapedBuffer result,
186       executable_->ExecuteOnStream(run_options, arguments,
187                                    /*hlo_execution_profile=*/nullptr));
188   TF_RETURN_IF_ERROR(RecordResult(&result, executable_->hlo_snapshot()));
189   DumpHloSnapshotIfEnabled(executable_->module(), *executable_->hlo_snapshot());
190   return std::move(result);
191 }
192 
RecordArguments(const absl::Span<const ShapedBuffer * const> arguments,HloSnapshot * hlo_snapshot)193 Status LocalExecutable::RecordArguments(
194     const absl::Span<const ShapedBuffer* const> arguments,
195     HloSnapshot* hlo_snapshot) {
196   hlo_snapshot->clear_arguments();
197   for (const ShapedBuffer* argument : arguments) {
198     TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*argument));
199     *hlo_snapshot->add_arguments() = literal.ToProto();
200   }
201   return Status::OK();
202 }
203 
RecordResult(const ShapedBuffer * result,HloSnapshot * hlo_snapshot)204 Status LocalExecutable::RecordResult(const ShapedBuffer* result,
205                                      HloSnapshot* hlo_snapshot) {
206   hlo_snapshot->clear_result();
207   TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*result));
208   *hlo_snapshot->mutable_result() = literal.ToProto();
209   return Status::OK();
210 }
211 
LiteralFromShapedBuffer(const ShapedBuffer & shaped_buffer)212 StatusOr<Literal> LocalExecutable::LiteralFromShapedBuffer(
213     const ShapedBuffer& shaped_buffer) {
214   TF_ASSIGN_OR_RETURN(auto stream,
215                       backend_->BorrowStream(shaped_buffer.device_ordinal()));
216   return backend_->transfer_manager()->TransferLiteralFromDevice(stream.get(),
217                                                                  shaped_buffer);
218 }
219 
platform() const220 se::Platform* LocalClient::platform() const {
221   return local_service_->backend().platform();
222 }
223 
device_count() const224 int LocalClient::device_count() const {
225   return local_service_->backend().device_count();
226 }
227 
device_ordinal_supported(int device_ordinal) const228 bool LocalClient::device_ordinal_supported(int device_ordinal) const {
229   return local_service_->backend().device_ordinal_supported(device_ordinal);
230 }
231 
default_device_ordinal() const232 int LocalClient::default_device_ordinal() const {
233   return local_service_->backend().default_device_ordinal();
234 }
235 
backend() const236 const Backend& LocalClient::backend() const {
237   return local_service_->backend();
238 }
239 
mutable_backend()240 Backend* LocalClient::mutable_backend() {
241   return local_service_->mutable_backend();
242 }
243 
Compile(const XlaComputation & computation,const absl::Span<const Shape * const> argument_layouts,const ExecutableBuildOptions & options)244 StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
245     const XlaComputation& computation,
246     const absl::Span<const Shape* const> argument_layouts,
247     const ExecutableBuildOptions& options) {
248   ExecutableBuildOptions updated_options = options;
249   if (options.device_ordinal() == -1) {
250     updated_options.set_device_ordinal(default_device_ordinal());
251     VLOG(3) << "Set device ordinal to default value of: "
252             << updated_options.device_ordinal();
253   }
254   TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
255                       local_service_->CompileExecutable(
256                           computation, argument_layouts, updated_options));
257   return absl::WrapUnique(new LocalExecutable(std::move(executable),
258                                               local_service_->mutable_backend(),
259                                               updated_options));
260 }
261 
LiteralToShapedBuffer(const Literal & literal,int device_ordinal,DeviceMemoryAllocator * allocator)262 StatusOr<ScopedShapedBuffer> LocalClient::LiteralToShapedBuffer(
263     const Literal& literal, int device_ordinal,
264     DeviceMemoryAllocator* allocator) {
265   if (allocator == nullptr) {
266     allocator = backend().memory_allocator();
267   }
268   TF_ASSIGN_OR_RETURN(auto scoped_buffer,
269                       backend().transfer_manager()->AllocateScopedShapedBuffer(
270                           literal.shape(), allocator, device_ordinal));
271   TF_ASSIGN_OR_RETURN(auto stream,
272                       mutable_backend()->BorrowStream(device_ordinal));
273   TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
274       stream.get(), literal, scoped_buffer));
275   return std::move(scoped_buffer);
276 }
277 
ShapedBufferToLiteral(const ShapedBuffer & shaped_buffer)278 StatusOr<Literal> LocalClient::ShapedBufferToLiteral(
279     const ShapedBuffer& shaped_buffer) {
280   TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream(
281                                        shaped_buffer.device_ordinal()));
282   return backend().transfer_manager()->TransferLiteralFromDevice(stream.get(),
283                                                                  shaped_buffer);
284 }
285 
GlobalDataToShapedBuffer(const GlobalDataHandle & data,int replica_number)286 StatusOr<const ShapedBuffer*> LocalClient::GlobalDataToShapedBuffer(
287     const GlobalDataHandle& data, int replica_number) {
288   return local_service_->GlobalDataToShapedBuffer(data, replica_number);
289 }
290 
TransferToInfeedLocal(const Literal & literal,int device_ordinal)291 Status LocalClient::TransferToInfeedLocal(const Literal& literal,
292                                           int device_ordinal) {
293   TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
294                       backend().stream_executor(device_ordinal));
295   return backend().transfer_manager()->TransferLiteralToInfeed(executor,
296                                                                literal);
297 }
298 
TransferFromOutfeedLocal(const Shape & shape,int device_ordinal)299 StatusOr<Literal> LocalClient::TransferFromOutfeedLocal(const Shape& shape,
300                                                         int device_ordinal) {
301   TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
302                       backend().stream_executor(device_ordinal));
303   auto literal = Literal::CreateFromShape(shape);
304   TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed(
305       executor, shape, &literal));
306   return std::move(literal);
307 }
308 
ReplicaNumberToDeviceOrdinal(int replica_number)309 StatusOr<int> LocalClient::ReplicaNumberToDeviceOrdinal(int replica_number) {
310   return local_service_->ReplicaNumberToDeviceOrdinal(replica_number);
311 }
312 
TransferToLocalServer(const::xla::BorrowingLiteral & literal,int device_oridinal)313 StatusOr<TransferToServerResponse> LocalClient::TransferToLocalServer(
314     const ::xla::BorrowingLiteral& literal, int device_oridinal) {
315   const ::xla::Shape& shape = literal.shape();
316 
317   TF_ASSIGN_OR_RETURN(
318       ::xla::ScopedShapedBuffer shaped_buffer,
319       backend().transfer_manager()->AllocateScopedShapedBuffer(
320           shape, backend().memory_allocator(), device_oridinal));
321   TF_ASSIGN_OR_RETURN(auto stream,
322                       mutable_backend()->BorrowStream(device_oridinal));
323   TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
324       stream.get(), literal, shaped_buffer));
325   std::vector<::xla::ScopedShapedBuffer> replicated_buffer;
326   replicated_buffer.emplace_back(std::move(shaped_buffer));
327   ::xla::TransferToServerResponse result;
328   TF_ASSIGN_OR_RETURN(*result.mutable_data(),
329                       local_service_->RegisterReplicatedBuffers(
330                           std::move(replicated_buffer),
331                           absl::StrCat("TransferToServer literal of shape ",
332                                        ::xla::ShapeUtil::HumanString(shape))));
333 
334   return result;
335 }
336 
337 }  // namespace xla
338