1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/local_client.h"
17
18 #include <utility>
19
20 #include "absl/memory/memory.h"
21 #include "llvm/ADT/Triple.h"
22 #include "tensorflow/compiler/xla/client/xla_computation.h"
23 #include "tensorflow/compiler/xla/service/backend.h"
24 #include "tensorflow/compiler/xla/service/dump.h"
25 #include "tensorflow/compiler/xla/service/service_executable_run_options.h"
26 #include "tensorflow/compiler/xla/service/source_map_util.h"
27 #include "tensorflow/compiler/xla/service/stream_pool.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29
30 using xla::source_map_util::InvalidParameterArgument;
31
32 namespace xla {
33
34 namespace {
BorrowStreamForDevice(int device_ordinal,Backend * backend)35 StatusOr<StreamPool::Ptr> BorrowStreamForDevice(int device_ordinal,
36 Backend* backend) {
37 if (device_ordinal < 0) {
38 device_ordinal = backend->default_device_ordinal();
39 }
40 return backend->BorrowStream(device_ordinal);
41 }
42 } // namespace
43
LocalExecutable(std::unique_ptr<Executable> executable,Backend * backend,ExecutableBuildOptions build_options)44 LocalExecutable::LocalExecutable(std::unique_ptr<Executable> executable,
45 Backend* backend,
46 ExecutableBuildOptions build_options)
47 : executable_(std::move(executable)),
48 backend_(backend),
49 build_options_(std::move(build_options)) {
50 CHECK_GE(build_options_.device_ordinal(), 0)
51 << "Must have a valid device ordinal that the executable was built for.";
52 }
53
ValidateExecutionOptions(const absl::Span<const ShapedBuffer * const> arguments,const ExecutableRunOptions & run_options,const Backend & backend)54 Status LocalExecutable::ValidateExecutionOptions(
55 const absl::Span<const ShapedBuffer* const> arguments,
56 const ExecutableRunOptions& run_options, const Backend& backend) {
57 const ComputationLayout& computation_layout =
58 executable_->module_config().entry_computation_layout();
59
60 // Check argument number, shapes, and layouts.
61 if (arguments.size() != computation_layout.parameter_count()) {
62 return InvalidArgument(
63 "invalid number of arguments for computation: expected %d, got %u",
64 computation_layout.parameter_count(), arguments.size());
65 }
66 for (int i = 0; i < arguments.size(); ++i) {
67 if (!computation_layout.parameter_layout(i).MatchesLayoutInShape(
68 arguments[i]->on_host_shape())) {
69 return InvalidParameterArgument(
70 executable_.get(), i,
71 "Argument does not match host shape or layout of computation "
72 "parameter "
73 "%d: want %s, got %s",
74 i,
75 ShapeUtil::HumanStringWithLayout(
76 computation_layout.parameter_layout(i).shape()),
77 ShapeUtil::HumanStringWithLayout(arguments[i]->on_host_shape()));
78 }
79 }
80
81 if (run_options.stream() != nullptr) {
82 if (!run_options.stream()->ok()) {
83 return InvalidArgument("stream is uninitialized or in an error state");
84 }
85
86 // Check stream matches service platform.
87 const se::Platform* stream_platform =
88 run_options.stream()->parent()->platform();
89 if (stream_platform != backend_->platform()) {
90 return InvalidArgument(
91 "stream is for platform %s, but service targets platform %s",
92 stream_platform->Name(), backend_->platform()->Name());
93 }
94
95 // Cannot specify device_ordinal with a stream. The stream determines these
96 // values.
97 if (run_options.device_ordinal() != -1) {
98 return InvalidArgument(
99 "cannot set both device ordinal and stream options in "
100 "ExecutableRunOptions; the stream determines the device ordinal");
101 }
102 }
103
104 // Verify that the device the executable was built for is equivalent
105 // to the device it will run on.
106 int run_device_ordinal = run_options.device_ordinal();
107 if (run_device_ordinal == -1) {
108 run_device_ordinal = run_options.stream() != nullptr
109 ? run_options.stream()->parent()->device_ordinal()
110 : backend_->default_device_ordinal();
111 }
112 TF_ASSIGN_OR_RETURN(bool devices_equivalent,
113 backend_->devices_equivalent(
114 run_device_ordinal, build_options_.device_ordinal()));
115 if (!devices_equivalent) {
116 TF_ASSIGN_OR_RETURN(se::StreamExecutor * run_executor,
117 backend_->stream_executor(run_device_ordinal));
118 TF_ASSIGN_OR_RETURN(se::StreamExecutor * build_executor,
119 backend_->stream_executor(build_device_ordinal()));
120 return InvalidArgument(
121 "executable is built for device %s of type \"%s\"; cannot run it on "
122 "device %s of type \"%s\"",
123 backend_->device_name(build_device_ordinal()),
124 build_executor->GetDeviceDescription().name(),
125 backend_->device_name(run_device_ordinal),
126 run_executor->GetDeviceDescription().name());
127 }
128
129 if (!run_options.allocator()) {
130 return InvalidArgument("an allocator must be provided to ExecuteLocally");
131 }
132
133 if (run_options.allocator()->platform() != backend.platform()) {
134 return InvalidArgument(
135 "allocator platform (%s) does not match service platform (%s)",
136 run_options.allocator()->platform()->Name(),
137 backend.platform()->Name());
138 }
139
140 return Status::OK();
141 }
142
Run(const absl::Span<const ShapedBuffer * const> arguments,ExecutableRunOptions run_options)143 StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
144 const absl::Span<const ShapedBuffer* const> arguments,
145 ExecutableRunOptions run_options) {
146 TF_RETURN_IF_ERROR(
147 ValidateExecutionOptions(arguments, run_options, *backend_));
148
149 StreamPool::Ptr stream;
150 if (run_options.stream() == nullptr) {
151 // NB! The lifetime of `stream` needs to match the lifetime of
152 // `actual_options` (otherwise we will end up using a returned stream in
153 // ExecuteOnStreamWrapper), which is why it isn't declared in the inner "if"
154 // scope.
155 TF_ASSIGN_OR_RETURN(
156 stream, BorrowStreamForDevice(run_options.device_ordinal(), backend_));
157 run_options.set_stream(stream.get());
158 }
159 if (run_options.allocator() == nullptr) {
160 run_options.set_allocator(backend_->memory_allocator());
161 }
162
163 // For local client execution on CPU backends:
164 // *) The thread pool used for eigen CPU ops is from
165 // ExecutableRunOptions.eigen_intra_op_thread_pool.
166 // *) The thread pool used for XLA CPU ops is from
167 // backend_->eigen_intra_op_thread_pool().
168 ServiceExecutableRunOptions service_options(run_options,
169 backend_->StreamBorrower());
170
171 if (executable_->dumping_snapshot()) {
172 return ExecuteAndDump(&service_options, arguments);
173 }
174 return executable_->ExecuteOnStreamWrapper(
175 &service_options, run_options.execution_profile(), arguments);
176 }
177
ExecuteAndDump(const ServiceExecutableRunOptions * run_options,const absl::Span<const ShapedBuffer * const> arguments)178 StatusOr<ScopedShapedBuffer> LocalExecutable::ExecuteAndDump(
179 const ServiceExecutableRunOptions* run_options,
180 const absl::Span<const ShapedBuffer* const> arguments) {
181 executable_->hlo_snapshot()->set_execution_platform(
182 backend_->platform()->Name());
183 TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->hlo_snapshot()));
184 TF_ASSIGN_OR_RETURN(
185 ScopedShapedBuffer result,
186 executable_->ExecuteOnStream(run_options, arguments,
187 /*hlo_execution_profile=*/nullptr));
188 TF_RETURN_IF_ERROR(RecordResult(&result, executable_->hlo_snapshot()));
189 DumpHloSnapshotIfEnabled(executable_->module(), *executable_->hlo_snapshot());
190 return std::move(result);
191 }
192
RecordArguments(const absl::Span<const ShapedBuffer * const> arguments,HloSnapshot * hlo_snapshot)193 Status LocalExecutable::RecordArguments(
194 const absl::Span<const ShapedBuffer* const> arguments,
195 HloSnapshot* hlo_snapshot) {
196 hlo_snapshot->clear_arguments();
197 for (const ShapedBuffer* argument : arguments) {
198 TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*argument));
199 *hlo_snapshot->add_arguments() = literal.ToProto();
200 }
201 return Status::OK();
202 }
203
RecordResult(const ShapedBuffer * result,HloSnapshot * hlo_snapshot)204 Status LocalExecutable::RecordResult(const ShapedBuffer* result,
205 HloSnapshot* hlo_snapshot) {
206 hlo_snapshot->clear_result();
207 TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*result));
208 *hlo_snapshot->mutable_result() = literal.ToProto();
209 return Status::OK();
210 }
211
LiteralFromShapedBuffer(const ShapedBuffer & shaped_buffer)212 StatusOr<Literal> LocalExecutable::LiteralFromShapedBuffer(
213 const ShapedBuffer& shaped_buffer) {
214 TF_ASSIGN_OR_RETURN(auto stream,
215 backend_->BorrowStream(shaped_buffer.device_ordinal()));
216 return backend_->transfer_manager()->TransferLiteralFromDevice(stream.get(),
217 shaped_buffer);
218 }
219
platform() const220 se::Platform* LocalClient::platform() const {
221 return local_service_->backend().platform();
222 }
223
device_count() const224 int LocalClient::device_count() const {
225 return local_service_->backend().device_count();
226 }
227
device_ordinal_supported(int device_ordinal) const228 bool LocalClient::device_ordinal_supported(int device_ordinal) const {
229 return local_service_->backend().device_ordinal_supported(device_ordinal);
230 }
231
default_device_ordinal() const232 int LocalClient::default_device_ordinal() const {
233 return local_service_->backend().default_device_ordinal();
234 }
235
backend() const236 const Backend& LocalClient::backend() const {
237 return local_service_->backend();
238 }
239
mutable_backend()240 Backend* LocalClient::mutable_backend() {
241 return local_service_->mutable_backend();
242 }
243
Compile(const XlaComputation & computation,const absl::Span<const Shape * const> argument_layouts,const ExecutableBuildOptions & options)244 StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
245 const XlaComputation& computation,
246 const absl::Span<const Shape* const> argument_layouts,
247 const ExecutableBuildOptions& options) {
248 ExecutableBuildOptions updated_options = options;
249 if (options.device_ordinal() == -1) {
250 updated_options.set_device_ordinal(default_device_ordinal());
251 VLOG(3) << "Set device ordinal to default value of: "
252 << updated_options.device_ordinal();
253 }
254 TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
255 local_service_->CompileExecutable(
256 computation, argument_layouts, updated_options));
257 return absl::WrapUnique(new LocalExecutable(std::move(executable),
258 local_service_->mutable_backend(),
259 updated_options));
260 }
261
LiteralToShapedBuffer(const Literal & literal,int device_ordinal,DeviceMemoryAllocator * allocator)262 StatusOr<ScopedShapedBuffer> LocalClient::LiteralToShapedBuffer(
263 const Literal& literal, int device_ordinal,
264 DeviceMemoryAllocator* allocator) {
265 if (allocator == nullptr) {
266 allocator = backend().memory_allocator();
267 }
268 TF_ASSIGN_OR_RETURN(auto scoped_buffer,
269 backend().transfer_manager()->AllocateScopedShapedBuffer(
270 literal.shape(), allocator, device_ordinal));
271 TF_ASSIGN_OR_RETURN(auto stream,
272 mutable_backend()->BorrowStream(device_ordinal));
273 TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
274 stream.get(), literal, scoped_buffer));
275 return std::move(scoped_buffer);
276 }
277
ShapedBufferToLiteral(const ShapedBuffer & shaped_buffer)278 StatusOr<Literal> LocalClient::ShapedBufferToLiteral(
279 const ShapedBuffer& shaped_buffer) {
280 TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream(
281 shaped_buffer.device_ordinal()));
282 return backend().transfer_manager()->TransferLiteralFromDevice(stream.get(),
283 shaped_buffer);
284 }
285
GlobalDataToShapedBuffer(const GlobalDataHandle & data,int replica_number)286 StatusOr<const ShapedBuffer*> LocalClient::GlobalDataToShapedBuffer(
287 const GlobalDataHandle& data, int replica_number) {
288 return local_service_->GlobalDataToShapedBuffer(data, replica_number);
289 }
290
TransferToInfeedLocal(const Literal & literal,int device_ordinal)291 Status LocalClient::TransferToInfeedLocal(const Literal& literal,
292 int device_ordinal) {
293 TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
294 backend().stream_executor(device_ordinal));
295 return backend().transfer_manager()->TransferLiteralToInfeed(executor,
296 literal);
297 }
298
TransferFromOutfeedLocal(const Shape & shape,int device_ordinal)299 StatusOr<Literal> LocalClient::TransferFromOutfeedLocal(const Shape& shape,
300 int device_ordinal) {
301 TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
302 backend().stream_executor(device_ordinal));
303 auto literal = Literal::CreateFromShape(shape);
304 TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed(
305 executor, shape, &literal));
306 return std::move(literal);
307 }
308
ReplicaNumberToDeviceOrdinal(int replica_number)309 StatusOr<int> LocalClient::ReplicaNumberToDeviceOrdinal(int replica_number) {
310 return local_service_->ReplicaNumberToDeviceOrdinal(replica_number);
311 }
312
TransferToLocalServer(const::xla::BorrowingLiteral & literal,int device_oridinal)313 StatusOr<TransferToServerResponse> LocalClient::TransferToLocalServer(
314 const ::xla::BorrowingLiteral& literal, int device_oridinal) {
315 const ::xla::Shape& shape = literal.shape();
316
317 TF_ASSIGN_OR_RETURN(
318 ::xla::ScopedShapedBuffer shaped_buffer,
319 backend().transfer_manager()->AllocateScopedShapedBuffer(
320 shape, backend().memory_allocator(), device_oridinal));
321 TF_ASSIGN_OR_RETURN(auto stream,
322 mutable_backend()->BorrowStream(device_oridinal));
323 TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
324 stream.get(), literal, shaped_buffer));
325 std::vector<::xla::ScopedShapedBuffer> replicated_buffer;
326 replicated_buffer.emplace_back(std::move(shaped_buffer));
327 ::xla::TransferToServerResponse result;
328 TF_ASSIGN_OR_RETURN(*result.mutable_data(),
329 local_service_->RegisterReplicatedBuffers(
330 std::move(replicated_buffer),
331 absl::StrCat("TransferToServer literal of shape ",
332 ::xla::ShapeUtil::HumanString(shape))));
333
334 return result;
335 }
336
337 } // namespace xla
338