1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/client.h"
17 
18 #include <string>
19 #include <utility>
20 
21 #include "absl/memory/memory.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/types/optional.h"
24 #include "tensorflow/compiler/xla/client/xla_computation.h"
25 #include "tensorflow/compiler/xla/debug_options_flags.h"
26 #include "tensorflow/compiler/xla/execution_options_util.h"
27 #include "tensorflow/compiler/xla/literal.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/protobuf.h"
33 #include "tensorflow/core/platform/types.h"
34 
35 namespace xla {
36 
Client(ServiceInterface * stub)37 Client::Client(ServiceInterface* stub) : stub_(stub) {}
38 
39 Client::~Client() = default;
40 
Transfer(const GlobalData & data,const Shape * shape_with_layout)41 StatusOr<Literal> Client::Transfer(const GlobalData& data,
42                                    const Shape* shape_with_layout) {
43   TransferToClientRequest request;
44   *request.mutable_data() = data.handle();
45   if (shape_with_layout != nullptr) {
46     *request.mutable_shape_with_layout() = shape_with_layout->ToProto();
47   }
48   TransferToClientResponse response;
49 
50   VLOG(1) << "making transfer request";
51   VLOG(3) << "TransferToClientRequest: {" << request.DebugString() << "}";
52   Status s = stub_->TransferToClient(&request, &response);
53   VLOG(1) << "done with request";
54 
55   if (!s.ok()) {
56     return s;
57   }
58   VLOG(3) << "TransferToClientResponse: {" << response.DebugString() << "}";
59 
60   if (!response.has_literal()) {
61     return FailedPrecondition(
62         "server provided response without a literal in "
63         "TransferToClient request");
64   }
65   return Literal::CreateFromProto(*response.mutable_literal());
66 }
67 
TransferToServer(const LiteralSlice & literal,const DeviceHandle * device_handle)68 StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
69     const LiteralSlice& literal, const DeviceHandle* device_handle) {
70   TransferToServerRequest request;
71   *request.mutable_literal() = literal.ToProto();
72   if (device_handle) {
73     *request.mutable_device_handle() = *device_handle;
74   }
75   TransferToServerResponse response;
76 
77   VLOG(1) << "making transfer to server request";
78   VLOG(3) << "TransferToServerRequest: {" << request.DebugString() << "}";
79   Status s = stub_->TransferToServer(&request, &response);
80   VLOG(1) << "done with request";
81 
82   if (!s.ok()) {
83     return s;
84   }
85   VLOG(3) << "TransferToServerResponse: {" << response.DebugString() << "}";
86 
87   if (!response.has_data()) {
88     return FailedPrecondition(
89         "server provided response without a data handle in "
90         "TransferToServer request");
91   }
92 
93   return absl::make_unique<GlobalData>(stub_, response.data());
94 }
95 
TransferToInfeed(const LiteralSlice & literal,int64 replica_id,const DeviceHandle * device_handle)96 Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id,
97                                 const DeviceHandle* device_handle) {
98   TransferToInfeedRequest request;
99   *request.mutable_literal() = literal.ToProto();
100   if (device_handle) {
101     *request.mutable_device_handle() = *device_handle;
102   }
103   request.set_replica_id(replica_id);
104   TransferToInfeedResponse response;
105 
106   VLOG(1) << "making transfer to infeed request";
107   VLOG(3) << "TransferToInfeedRequest: {" << request.DebugString() << "}";
108   Status s = stub_->TransferToInfeed(&request, &response);
109   VLOG(1) << "done with request";
110 
111   if (!s.ok()) {
112     return s;
113   }
114   VLOG(3) << "TransferToInfeedResponse: {" << response.DebugString() << "}";
115   return Status::OK();
116 }
117 
TransferFromOutfeed(const Shape * shape_with_layout,int64 replica_id,const DeviceHandle * device_handle)118 StatusOr<Literal> Client::TransferFromOutfeed(
119     const Shape* shape_with_layout, int64 replica_id,
120     const DeviceHandle* device_handle) {
121   TransferFromOutfeedRequest request;
122   if (device_handle) {
123     *request.mutable_device_handle() = *device_handle;
124   }
125   request.set_replica_id(replica_id);
126   if (shape_with_layout != nullptr) {
127     *request.mutable_shape_with_layout() = shape_with_layout->ToProto();
128   }
129   TransferFromOutfeedResponse response;
130 
131   VLOG(1) << "making transfer from outfeed request";
132   VLOG(3) << "TransferFromOutfeedRequest: {" << request.DebugString() << "}";
133   Status s = stub_->TransferFromOutfeed(&request, &response);
134   VLOG(1) << "done with request";
135 
136   if (!s.ok()) {
137     return s;
138   }
139   VLOG(3) << "TransferFromOutfeedResponse: {" << response.DebugString() << "}";
140 
141   if (!response.has_literal()) {
142     return FailedPrecondition(
143         "server provided response without a literal in "
144         "TransferToClient request");
145   }
146 
147   return Literal::CreateFromProto(response.literal());
148 }
149 
ResetDevice()150 Status Client::ResetDevice() {
151   ResetDeviceRequest request;
152   ResetDeviceResponse response;
153 
154   VLOG(1) << "making reset device request";
155   VLOG(3) << "ResetDeviceRequest: {" << request.DebugString() << "}";
156   Status s = stub_->ResetDevice(&request, &response);
157   VLOG(1) << "done with request";
158 
159   if (!s.ok()) {
160     return s;
161   }
162   VLOG(3) << "ResetDeviceResponse: {" << response.DebugString() << "}";
163   return Status::OK();
164 }
165 
ExecuteAndTransfer(const XlaComputation & computation,absl::Span<GlobalData * const> arguments,const ExecutionOptions * execution_options,ExecutionProfile * execution_profile)166 StatusOr<Literal> Client::ExecuteAndTransfer(
167     const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
168     const ExecutionOptions* execution_options,
169     ExecutionProfile* execution_profile) {
170   TF_ASSIGN_OR_RETURN(
171       std::unique_ptr<GlobalData> data,
172       Execute(computation, arguments, execution_options, execution_profile));
173 
174   absl::optional<Shape> shape_with_output_layout;
175   if (execution_options && execution_options->has_shape_with_output_layout()) {
176     shape_with_output_layout =
177         Shape(execution_options->shape_with_output_layout());
178   }
179   return Transfer(*data, shape_with_output_layout.has_value()
180                              ? &(*shape_with_output_layout)
181                              : nullptr);
182 }
183 
ComputeConstant(const XlaComputation & computation,const Layout * output_layout) const184 StatusOr<Literal> Client::ComputeConstant(const XlaComputation& computation,
185                                           const Layout* output_layout) const {
186   ComputeConstantGraphRequest request;
187   *request.mutable_computation() = computation.proto();
188   if (output_layout != nullptr) {
189     *request.mutable_output_layout() = output_layout->ToProto();
190   }
191 
192   ComputeConstantResponse response;
193 
194   VLOG(2) << "making compute-constant-graph request";
195   Status s = stub_->ComputeConstantGraph(&request, &response);
196   VLOG(2) << "done with request";
197 
198   if (!s.ok()) {
199     return s;
200   }
201 
202   VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}";
203 
204   if (!response.has_literal()) {
205     return InternalError(
206         "no computed literal in the provided response in ComputeConstantGraph "
207         "request");
208   }
209   return Literal::CreateFromProto(response.literal());
210 }
211 
LoadSnapshot(const HloSnapshot & module)212 StatusOr<XlaComputation> Client::LoadSnapshot(const HloSnapshot& module) {
213   TF_RET_CHECK(module.has_hlo() && module.hlo().has_hlo_module());
214   return XlaComputation(module.hlo().hlo_module());
215 }
216 
Compile(const XlaComputation & computation,absl::Span<const Shape> argument_shapes,const ExecutionOptions * execution_options)217 StatusOr<ExecutionHandle> Client::Compile(
218     const XlaComputation& computation, absl::Span<const Shape> argument_shapes,
219     const ExecutionOptions* execution_options) {
220   CompileRequest request;
221   *request.mutable_computation() = computation.proto();
222 
223   if (execution_options == nullptr) {
224     *request.mutable_execution_options() = CreateDefaultExecutionOptions();
225   } else {
226     *request.mutable_execution_options() = *execution_options;
227   }
228   if (request.execution_options().device_handles_size() > 1) {
229     return InvalidArgument(
230         "Compiling with multiple device handles is not supported. Use "
231         "'Execute' instead.");
232   }
233 
234   // The argument shapes affect how the computation is compiled.
235   for (const auto& arg_shape : argument_shapes) {
236     *request.add_input_shape_with_layout() = arg_shape.ToProto();
237   }
238 
239   CompileResponse response;
240   VLOG(1) << "making compile request: " << request.ShortDebugString();
241   Status s = stub_->Compile(&request, &response);
242   VLOG(1) << "done with request";
243 
244   if (!s.ok()) {
245     return s;
246   }
247   TF_RET_CHECK(response.has_handle());
248   return response.handle();
249 }
250 
Execute(const ExecutionHandle & handle,absl::Span<GlobalData * const> arguments,ExecutionProfile * execution_profile)251 StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
252     const ExecutionHandle& handle, absl::Span<GlobalData* const> arguments,
253     ExecutionProfile* execution_profile) {
254   ExecuteRequest request;
255   *request.mutable_handle() = handle;
256   for (GlobalData* argument : arguments) {
257     CHECK(argument != nullptr) << "Argument pointers must not be null.";
258     *request.add_arguments() = argument->handle();
259   }
260 
261   ExecuteResponse response;
262   VLOG(1) << "making execute request: " << request.ShortDebugString();
263   Status s = stub_->Execute(&request, &response);
264   VLOG(1) << "done with request";
265 
266   if (!s.ok()) {
267     return s;
268   }
269 
270   if (execution_profile != nullptr) {
271     *execution_profile = response.profile();
272   }
273 
274   return absl::make_unique<GlobalData>(stub_, response.output());
275 }
276 
Execute(const XlaComputation & computation,absl::Span<GlobalData * const> arguments,const ExecutionOptions * execution_options,ExecutionProfile * execution_profile)277 StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
278     const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
279     const ExecutionOptions* execution_options,
280     ExecutionProfile* execution_profile) {
281   // Create an ExecutionOptions if necessary, or set its DeviceHandles.
282   absl::optional<ExecutionOptions> options_storage;
283   if (!execution_options || execution_options->device_handles().empty()) {
284     if (execution_options) {
285       options_storage.emplace(*execution_options);
286     } else {
287       options_storage.emplace(CreateDefaultExecutionOptions());
288     }
289     execution_options = &*options_storage;
290 
291     TF_ASSIGN_OR_RETURN(auto device_handles,
292                         GetDeviceHandles(/*device_count=*/1));
293     TF_RET_CHECK(!device_handles.empty());
294     *options_storage->add_device_handles() = std::move(device_handles[0]);
295   }
296 
297   std::vector<XlaComputationInstance> computation_instances = {
298       XlaComputationInstance{
299           computation,
300           std::vector<GlobalData*>(arguments.begin(), arguments.end()),
301           *execution_options, execution_profile}};
302 
303   // Instead of invoking Compile() and Execute(), invoke
304   // Service::ExecuteParallel() to execute our one computation.  Compile()
305   // caches the executable forever, which isn't what we want.
306   VLOG(1) << "Making ExecuteParallel request: "
307           << execution_options->DebugString();
308   TF_ASSIGN_OR_RETURN(auto results, ExecuteParallel(computation_instances));
309   VLOG(1) << "ExecuteParallel request done.";
310 
311   // The result selection is a bit hacky, but better than assuming it is
312   // device 0.
313   //
314   // TODO(b/118493728): Allow Execute to return one result per computation.
315   for (int64 i = 0; i < results.size(); i++) {
316     TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(*results[i]));
317     if (!ShapeUtil::IsEmptyTuple(shape)) {
318       VLOG(3) << "Fetching result from device " << i << ": "
319               << ShapeUtil::HumanString(shape);
320       return std::move(results[i]);
321     }
322   }
323   TF_RET_CHECK(!results.empty());
324   VLOG(1) << "Defaulting to device 0 result";
325   return std::move(results[0]);
326 }
327 
ExecuteParallel(absl::Span<const XlaComputationInstance> computations)328 StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
329     absl::Span<const XlaComputationInstance> computations) {
330   ExecuteGraphParallelRequest request;
331 
332   for (const XlaComputationInstance& computation : computations) {
333     ExecuteGraphRequest single_request;
334     *single_request.mutable_computation() = computation.computation.proto();
335     for (GlobalData* argument : computation.arguments) {
336       *single_request.add_arguments() = argument->handle();
337     }
338     *single_request.mutable_execution_options() = computation.execution_options;
339     *request.add_requests() = single_request;
340   }
341 
342   ExecuteParallelResponse response;
343   VLOG(1) << "making execute-graph-parallel request: "
344           << request.ShortDebugString();
345   Status s = stub_->ExecuteGraphParallel(&request, &response);
346   VLOG(1) << "done with request";
347 
348   if (!s.ok()) {
349     return s;
350   }
351 
352   std::vector<std::unique_ptr<GlobalData>> outputs;
353   for (size_t i = 0; i < response.responses_size(); ++i) {
354     outputs.push_back(
355         absl::make_unique<GlobalData>(stub_, response.responses(i).output()));
356     if (i < computations.size() &&
357         computations[i].execution_profile != nullptr) {
358       *computations[i].execution_profile = response.responses(i).profile();
359     }
360   }
361 
362   return std::move(outputs);
363 }
364 
GetDeviceHandles(int64 device_count)365 StatusOr<std::vector<DeviceHandle>> Client::GetDeviceHandles(
366     int64 device_count) {
367   if (device_count < 1) {
368     return InvalidArgument("device_count must be greater than 0");
369   }
370   GetDeviceHandlesRequest request;
371   request.set_device_count(device_count);
372 
373   GetDeviceHandlesResponse response;
374   VLOG(1) << "making get device request: " << request.ShortDebugString();
375   Status s = stub_->GetDeviceHandles(&request, &response);
376   VLOG(1) << "done with request";
377 
378   if (!s.ok()) {
379     return s;
380   }
381 
382   std::vector<DeviceHandle> device_handles;
383   for (const DeviceHandle& device_handle : response.device_handles()) {
384     device_handles.push_back(device_handle);
385   }
386 
387   return device_handles;
388 }
389 
Unregister(const GlobalData & data)390 Status Client::Unregister(const GlobalData& data) {
391   UnregisterRequest request;
392   *request.add_data() = data.handle();
393   UnregisterResponse response;
394 
395   VLOG(1) << "making unregister request";
396   Status s = stub_->Unregister(&request, &response);
397   VLOG(1) << "done with request";
398 
399   return s;
400 }
401 
DeconstructTuple(const GlobalData & data)402 StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::DeconstructTuple(
403     const GlobalData& data) {
404   DeconstructTupleRequest request;
405   *request.mutable_tuple_handle() = data.handle();
406   DeconstructTupleResponse response;
407 
408   VLOG(1) << "making DestructTuple request";
409   Status s = stub_->DeconstructTuple(&request, &response);
410   VLOG(1) << "done with request";
411 
412   if (!s.ok()) {
413     return s;
414   }
415 
416   std::vector<std::unique_ptr<GlobalData>> handles;
417   for (auto& handle : response.element_handles()) {
418     handles.push_back(absl::make_unique<GlobalData>(stub_, handle));
419   }
420   return std::move(handles);
421 }
422 
GetComputationStats(const XlaComputation & computation,const DebugOptions & debug_options) const423 StatusOr<ComputationStats> Client::GetComputationStats(
424     const XlaComputation& computation,
425     const DebugOptions& debug_options) const {
426   ComputationGraphStatsRequest request;
427 
428   // TODO(b/74197823): Find a way to avoid the copy of the hlo proto.
429   *request.mutable_computation() = computation.proto();
430   *request.mutable_debug_options() = debug_options;
431   ComputationStatsResponse response;
432 
433   VLOG(1) << "making computation graph stats request";
434   Status s = stub_->GetComputationGraphStats(&request, &response);
435   VLOG(1) << "done with request";
436 
437   if (!s.ok()) {
438     return s;
439   }
440   CHECK(response.has_stats());
441   return response.stats();
442 }
443 
GetComputationShape(const XlaComputation & computation)444 StatusOr<std::unique_ptr<ProgramShape>> Client::GetComputationShape(
445     const XlaComputation& computation) {
446   TF_ASSIGN_OR_RETURN(const auto& result, computation.GetProgramShape());
447   return absl::make_unique<ProgramShape>(result);
448 }
449 
GetShape(const GlobalData & data)450 StatusOr<Shape> Client::GetShape(const GlobalData& data) {
451   GetShapeRequest request;
452   *request.mutable_data() = data.handle();
453   GetShapeResponse response;
454 
455   VLOG(1) << "making get shape request";
456   Status s = stub_->GetShape(&request, &response);
457   VLOG(1) << "done with request";
458 
459   if (!s.ok()) {
460     return s;
461   }
462 
463   return Shape(response.shape());
464 }
465 
ExecutionStatsAsString(const XlaComputation & computation,const ExecutionProfile & profile)466 StatusOr<string> Client::ExecutionStatsAsString(
467     const XlaComputation& computation, const ExecutionProfile& profile) {
468   TF_ASSIGN_OR_RETURN(
469       auto computation_stats,
470       GetComputationStats(computation, GetDebugOptionsFromFlags()));
471   int64 total_flops =
472       computation_stats.flop_count() + computation_stats.transcendental_count();
473   if (profile.compute_time_ns() > 0) {
474     int64 nanoseconds = profile.compute_time_ns();
475     int64 cycle_count = profile.compute_cycle_count();
476     double gflops = total_flops / nanoseconds;
477     return absl::StrCat(
478         "[Execution Statistics] flop count: ", computation_stats.flop_count(),
479         ", transcendental count: ", computation_stats.transcendental_count(),
480         ", compute execution time: ", nanoseconds, " nsec",
481         ", compute cycles: ", cycle_count, ", performance: ", gflops,
482         "gflop/s");
483   }
484   return string("[Execution Statistics] not available.");
485 }
486 
CreateChannelHandleByType(ChannelHandle::ChannelType type)487 StatusOr<ChannelHandle> Client::CreateChannelHandleByType(
488     ChannelHandle::ChannelType type) {
489   CreateChannelHandleRequest request;
490   request.set_channel_type(type);
491   CreateChannelHandleResponse response;
492 
493   VLOG(1) << "making create channel handle request";
494   Status s = stub_->CreateChannelHandle(&request, &response);
495   VLOG(1) << "done with request";
496 
497   if (!s.ok()) {
498     return s;
499   }
500 
501   return response.channel();
502 }
503 
CreateChannelHandle()504 StatusOr<ChannelHandle> Client::CreateChannelHandle() {
505   return CreateChannelHandleByType(ChannelHandle::DEVICE_TO_DEVICE);
506 }
507 
CreateHostToDeviceChannelHandle()508 StatusOr<ChannelHandle> Client::CreateHostToDeviceChannelHandle() {
509   return CreateChannelHandleByType(ChannelHandle::HOST_TO_DEVICE);
510 }
511 
CreateDeviceToHostChannelHandle()512 StatusOr<ChannelHandle> Client::CreateDeviceToHostChannelHandle() {
513   return CreateChannelHandleByType(ChannelHandle::DEVICE_TO_HOST);
514 }
515 
516 }  // namespace xla
517