/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/local_computation_builder.h"

#include <memory>
#include <string>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/qr.h"
#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
#include "tensorflow/compiler/xla/client/lib/svd.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace swig {

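// Registers the CPU custom-call target held by `capsule` (a PyCapsule created
// on the Python side) under `fn_name` in the global CPU custom-call registry.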
Status RegisterCpuCustomCallTarget(const string& fn_name, PyObject* capsule) {
  const char* name = "xla._CPU_CUSTOM_CALL_TARGET";
  if (!PyCapsule_IsValid(capsule, name)) {
    return InvalidArgument(
        "Argument to RegisterCpuCustomCallTarget was not an "
        "xla._CPU_CUSTOM_CALL_TARGET capsule.");
  }
  void* fn_ptr = PyCapsule_GetPointer(capsule, name);
  CHECK(fn_ptr != nullptr);
  cpu::CustomCallTargetRegistry::Global()->Register(
      std::string(fn_name.begin(), fn_name.end()), fn_ptr);
  return Status::OK();
}

LocalClient::LocalClient(xla::LocalClient* client) : client_(client) {}

/* static */ StatusOr<LocalClient> LocalClient::Get(
    const string& platform_name) {
  TF_ASSIGN_OR_RETURN(se::Platform * platform,
                      PlatformUtil::GetPlatform(platform_name));
  if (platform->VisibleDeviceCount() <= 0) {
    return InvalidArgument("Platform %s has no visible devices.",
                           platform_name);
  }
  LocalClientOptions options;
  options.set_platform(platform);
  TF_ASSIGN_OR_RETURN(xla::LocalClient * client,
                      ClientLibrary::GetOrCreateLocalClient(options));
  CHECK(client != nullptr);
  return LocalClient(client);
}

// Returns the number of devices known to the XLA client.
int LocalClient::DeviceCount() const { return client_->device_count(); }

Status LocalClient::TransferToInfeed(const Literal& literal,
                                     int device_ordinal) {
  VLOG(1) << "Infeeding literal to device " << device_ordinal
          << "; shape: " << literal.shape();
  return client_->TransferToInfeed(literal, device_ordinal);
}

StatusOr<Literal> LocalClient::TransferFromOutfeed(const Shape& shape,
                                                   int device_ordinal) {
  VLOG(1) << "Outfeeding literal from device " << device_ordinal
          << "; shape: " << shape;
  return client_->TransferFromOutfeed(&shape, device_ordinal);
}

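// Copies a host literal into newly allocated device memory on the given
// device ordinal, relaying it out first if `shape_with_layout` is provided.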
/* static */
StatusOr<LocalShapedBuffer*> LocalShapedBuffer::FromLiteral(
    const Literal& argument, const absl::optional<Shape>& shape_with_layout,
    const LocalClient& client, int device_ordinal) {
  VLOG(1) << "Creating shaped buffer from literal on device ordinal: "
          << device_ordinal;
  auto literal_to_buffer = [&](const Literal& arg) {
    return client.client()->LiteralToShapedBuffer(
        arg, device_ordinal, client.client()->backend().memory_allocator());
  };

  StatusOr<ScopedShapedBuffer> buf = [&] {
    if (shape_with_layout) {
      Literal relaid = argument.Relayout(shape_with_layout.value());
      return literal_to_buffer(relaid);
    }
    return literal_to_buffer(argument);
  }();
  TF_RETURN_IF_ERROR(buf.status());
  return new LocalShapedBuffer(std::move(buf).ValueOrDie(), client.client());
}

LocalShapedBuffer::LocalShapedBuffer(ScopedShapedBuffer shaped_buffer,
                                     xla::LocalClient* client)
    : shaped_buffer_(std::move(shaped_buffer)), client_(client) {}

const ScopedShapedBuffer* LocalShapedBuffer::shaped_buffer() const {
  return &shaped_buffer_;
}

ShapedBuffer LocalShapedBuffer::Release() { return shaped_buffer_.release(); }

const Shape& LocalShapedBuffer::shape() const {
  return shaped_buffer()->on_device_shape();
}

StatusOr<Literal> LocalShapedBuffer::ToLiteral() const {
  return client_->ShapedBufferToLiteral(*shaped_buffer());
}

LocalShapedBufferTuple::LocalShapedBufferTuple(
    std::vector<LocalShapedBuffer*> elements)
    : elements_(std::move(elements)) {
  for (auto* element : elements_) {
    CHECK(element != nullptr);
  }
}

LocalShapedBufferTuple::~LocalShapedBufferTuple() {
  for (LocalShapedBuffer* element : elements_) {
    if (element != nullptr) {
      delete element;
    }
  }
}

StatusOr<LocalShapedBuffer*> LocalShapedBufferTuple::Release(int i) {
  LocalShapedBuffer* element = elements_[i];
  if (element == nullptr) {
    return InvalidArgument("Attempted to release already-released element %d.",
                           i);
  }
  elements_[i] = nullptr;
  return element;
}

int64 LocalShapedBufferTuple::size() const { return elements_.size(); }

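// Splits a tuple-shaped buffer into one LocalShapedBuffer per top-level tuple
// element, transferring ownership of the element buffers to the caller and
// deallocating the tuple's root buffer.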
StatusOr<LocalShapedBufferTuple*> LocalShapedBuffer::DestructureTuple() {
  const Shape tuple_shape = shape();

  if (!tuple_shape.IsTuple()) {
    return InvalidArgument(
        "Attempted to destructure a LocalShapedBuffer that did not have a "
        "tuple shape; shape: %s",
        ShapeUtil::HumanString(tuple_shape));
  }

  DeviceMemoryAllocator* allocator = shaped_buffer()->memory_allocator();
  ShapedBuffer tuple_buffer = Release();

  // Extract some metadata we use to construct scoped buffers.
  const se::Platform* platform = tuple_buffer.platform();
  int device_ordinal = tuple_buffer.device_ordinal();

  ShapeTree<se::DeviceMemoryBase>& shape_tree = tuple_buffer.buffers();
  std::vector<LocalShapedBuffer*> results;
  for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) {
    // Create a shaped buffer for this destructured tuple element.
    const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i});
    VLOG(3) << "Starting tuple element " << i << " subshape: " << subshape;
    ShapedBuffer shaped_buffer(subshape, subshape, platform, device_ordinal);

    ShapeUtil::ForEachSubshape(
        subshape, [&](const Shape& s, const ShapeIndex& index) {
          ShapeIndex original(index);
          original.push_front(i);
          se::DeviceMemoryBase* device_memory =
              shape_tree.mutable_element(original);
          shaped_buffer.set_buffer(*device_memory, index);
          *device_memory = se::DeviceMemoryBase();
        });

    VLOG(3) << "Completed tuple element: " << i;
    results.push_back(new LocalShapedBuffer(
        ScopedShapedBuffer(std::move(shaped_buffer), allocator), client_));
  }
  // Deallocate the root buffer.
  se::DeviceMemoryBase root_buffer = tuple_buffer.root_buffer();
  TF_RETURN_IF_ERROR(allocator->Deallocate(device_ordinal, root_buffer));
  return new LocalShapedBufferTuple(std::move(results));
}

LocalExecutable::LocalExecutable(
    std::unique_ptr<xla::LocalExecutable> executable,
    xla::DeviceAssignment device_assignment, xla::LocalClient* client)
    : executable_(std::move(executable)),
      device_assignment_(std::move(device_assignment)),
      client_(client) {}

std::vector<int> LocalExecutable::DeviceOrdinals() const {
  int num_replicas = device_assignment_.replica_count();
  std::vector<int> device_ordinals;
  device_ordinals.reserve(num_replicas);
  for (int i = 0; i < num_replicas; ++i) {
    device_ordinals.push_back(device_assignment_(i, 0));
  }
  return device_ordinals;
}

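// Runs the executable on a single replica; computations compiled for more
// than one replica must use ExecutePerReplica() instead.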
StatusOr<LocalShapedBuffer*> LocalExecutable::Execute(
    absl::Span<LocalShapedBuffer* const> argument_handles) {
  if (num_replicas() != 1) {
    return InvalidArgument(
        "Attempted to execute computation with %d replicas using Execute()",
        num_replicas());
  }
  StatusOr<ScopedShapedBuffer> result_buffer_status;
  const int device_ordinal = device_assignment_(0, 0);
  VLOG(3) << "Replica 0 mapped to device ordinal for execution: "
          << device_ordinal;

  std::vector<const ShapedBuffer*> argument_buffers;
  argument_buffers.reserve(argument_handles.size());
  for (auto& handle : argument_handles) {
    argument_buffers.push_back(handle->shaped_buffer());
  }

  ExecutableRunOptions options;
  options.set_device_ordinal(device_ordinal);
  options.set_allocator(client_->backend().memory_allocator());
  options.set_intra_op_thread_pool(
      client_->backend().eigen_intra_op_thread_pool_device());
  options.set_device_assignment(&device_assignment_);

  result_buffer_status = executable_->Run(argument_buffers, options);

  if (!result_buffer_status.ok()) {
    return InternalError(
        "Failed running replica 0 (other replicas may have failed as well): "
        "%s.",
        result_buffer_status.status().ToString());
  }
  return new LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie(),
                               client_);
}

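// Runs one execution per replica, mapping each replica to a device according
// to the device assignment; all but the last replica are scheduled on a
// temporary thread pool.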
StatusOr<LocalShapedBufferTuple*> LocalExecutable::ExecutePerReplica(
    absl::Span<const std::vector<LocalShapedBuffer*>> argument_handles) {
  const int num_devices = client_->device_count();

  if (argument_handles.size() != num_replicas()) {
    return InvalidArgument(
        "Attempted to execute with %d replicas when replica count is %d",
        argument_handles.size(), num_replicas());
  }
  if (argument_handles.size() > num_devices) {
    return InvalidArgument(
        "Attempted to execute with %d replicas when device count is %d",
        argument_handles.size(), num_devices);
  }

  VLOG(1) << "Executing with " << num_replicas() << " replicas.";

  std::vector<StatusOr<ScopedShapedBuffer>> results(num_replicas());
  auto execute = [this, &argument_handles, &results](int replica) {
    const int device_ordinal = device_assignment_(replica, 0);
    VLOG(3) << "Replica " << replica
            << " mapped to device ordinal for execution: " << device_ordinal;

    std::vector<const ShapedBuffer*> argument_buffers;
    argument_buffers.reserve(argument_handles[replica].size());
    for (auto& handle : argument_handles[replica]) {
      argument_buffers.push_back(handle->shaped_buffer());
    }

    ExecutableRunOptions options;
    options.set_device_ordinal(device_ordinal);
    options.set_allocator(client_->backend().memory_allocator());
    options.set_intra_op_thread_pool(
        client_->backend().eigen_intra_op_thread_pool_device());
    options.set_device_assignment(&device_assignment_);
    StatusOr<ScopedShapedBuffer> result_buffer_status =
        executable_->Run(argument_buffers, options);

    results[replica] = std::move(result_buffer_status);
  };

  if (num_replicas() == 1) {
    // Fast-path if there is only one replica; run the computation on the
    // current thread.
    execute(0);
  } else {
    // TODO(phawkins): don't recreate the threadpool for each execution.
    tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun",
                                        num_replicas() - 1);

    for (int replica = 0; replica < num_replicas() - 1; ++replica) {
      pool.Schedule([&execute, replica] { execute(replica); });
    }
    execute(num_replicas() - 1);
  }

  std::vector<LocalShapedBuffer*> wrapped_results(num_replicas());
  for (int replica = 0; replica < num_replicas(); ++replica) {
    auto& statusor = results[replica];
    if (!statusor.ok()) {
      return InternalError(
          "Failed running replica %d (other replicas may have failed as well): "
          "%s.",
          replica, statusor.status().ToString());
    }
    wrapped_results[replica] =
        new LocalShapedBuffer(std::move(statusor).ValueOrDie(), client_);
  }

  return new LocalShapedBufferTuple(std::move(wrapped_results));
}

Computation::Computation(XlaComputation computation)
    : computation_(std::move(computation)) {}

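// Compiles the wrapped XlaComputation for the given argument shapes and build
// options, and pairs the resulting executable with a device assignment for
// its replicas.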
StatusOr<LocalExecutable*> Computation::Compile(
    const std::vector<Shape>& argument_shapes,
    const ExecutableBuildOptions* build_options, const LocalClient& client) {
  std::vector<const Shape*> argument_shape_pointers;
  argument_shape_pointers.reserve(argument_shapes.size());
  for (auto& argument_shape : argument_shapes) {
    argument_shape_pointers.push_back(&argument_shape);
  }

  ExecutableBuildOptions options;
  if (build_options != nullptr) {
    options = *build_options;
  }
  TF_ASSIGN_OR_RETURN(
      auto local_executable,
      client.client()->Compile(computation_, argument_shape_pointers, options));
  TF_ASSIGN_OR_RETURN(
      DeviceAssignment device_assignment,
      client.client()->backend().computation_placer()->AssignDevices(
          options.num_replicas(), /*computation_count=*/1));

  return new LocalExecutable(std::move(local_executable),
                             std::move(device_assignment), client.client());
}

const XlaComputation& Computation::computation() const { return computation_; }

string Computation::GetSerializedProto() const {
  string result;
  if (!computation_.proto().SerializeToString(&result)) {
    LOG(ERROR) << "Failed to serialize the HloModuleProto.";
    return "";
  }
  return result;
}

StatusOr<string> Computation::GetHloText() const {
  TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config,
                      HloModule::CreateModuleConfigFromProto(
                          computation_.proto(), GetDebugOptionsFromFlags()));
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModule> hlo_module,
      HloModule::CreateFromProto(computation_.proto(), module_config));
  HloPrintOptions options = HloPrintOptions::ShortParsable();
  options.set_print_large_constants(false);
  return hlo_module->ToString(options);
}

StatusOr<string> Computation::GetHloDotGraph() const {
  TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config,
                      HloModule::CreateModuleConfigFromProto(
                          computation_.proto(), GetDebugOptionsFromFlags()));
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModule> hlo_module,
      HloModule::CreateFromProto(computation_.proto(), module_config));
  return RenderGraph(*hlo_module->entry_computation(), /*label=*/"",
                     hlo_module->config().debug_options(),
                     RenderedGraphFormat::kDot);
}

StatusOr<ProgramShape> Computation::GetProgramShape() const {
  return computation_.GetProgramShape();
}

StatusOr<Shape> Computation::GetReturnValueShape() const {
  TF_ASSIGN_OR_RETURN(ProgramShape shape, computation_.GetProgramShape());
  return std::move(*shape.mutable_result());
}

LocalOp::LocalOp(const XlaOp& op) : op_(op) {}

const XlaOp& LocalOp::op() const { return op_; }

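// ComputationBuilder is a thin wrapper around XlaBuilder that exposes builder
// operations on LocalOp handles to the SWIG bindings.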
ComputationBuilder::ComputationBuilder(const string& computation_name)
    : builder_(computation_name) {}

void ComputationBuilder::SetOpMetadata(const OpMetadata& metadata) {
  builder_.SetOpMetadata(metadata);
}

void ComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); }

StatusOr<Computation*> ComputationBuilder::Build() {
  TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build());
  return new Computation(std::move(computation));
}

LocalOp ComputationBuilder::Parameter(int64 parameter_number,
                                      const Shape& shape, const string& name) {
  return xla::Parameter(&builder_, parameter_number, shape, name);
}

StatusOr<Computation*> ComputationBuilder::BuildWithRoot(const LocalOp& root) {
  TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build(root.op()));
  return new Computation(std::move(computation));
}

StatusOr<Shape> ComputationBuilder::GetShape(const LocalOp& operand) {
  return builder_.GetShape(operand.op());
}

StatusOr<Shape> ComputationBuilder::GetReturnValueShape() {
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape, builder_.GetProgramShape());
  return program_shape.result();
}

LocalOp ComputationBuilder::ReplicaId() { return xla::ReplicaId(&builder_); }

LocalOp ComputationBuilder::Infeed(const Shape& shape) {
  return xla::Infeed(&builder_, shape);
}

void ComputationBuilder::Outfeed(const LocalOp& operand, const Shape& shape,
                                 const string& outfeed_config) {
  xla::Outfeed(operand.op(), shape, outfeed_config);
}

LocalOp ComputationBuilder::ConstantLiteral(const Literal& literal) {
  return xla::ConstantLiteral(&builder_, literal);
}

LocalOp ComputationBuilder::Iota(PrimitiveType element_type, int64 size) {
  return xla::Iota(&builder_, element_type, size);
}

LocalOp ComputationBuilder::BroadcastedIota(const Shape& shape,
                                            int64 dimension) {
  return xla::Iota(&builder_, shape, dimension);
}

LocalOp ComputationBuilder::Broadcast(const LocalOp& operand,
                                      absl::Span<const int64> broadcast_sizes) {
  return xla::Broadcast(operand.op(), broadcast_sizes);
}

LocalOp ComputationBuilder::BroadcastInDim(
    const LocalOp& operand, absl::Span<const int64> out_dim_sizes,
    absl::Span<const int64> broadcast_dimensions) {
  return xla::BroadcastInDim(operand.op(), out_dim_sizes, broadcast_dimensions);
}

LocalOp ComputationBuilder::Pad(const LocalOp& operand,
                                const LocalOp& padding_value,
                                const PaddingConfig& padding_config) {
  return xla::Pad(operand.op(), padding_value.op(), padding_config);
}

LocalOp ComputationBuilder::Reshape(const LocalOp& operand,
                                    absl::Span<const int64> dimensions,
                                    absl::Span<const int64> new_sizes) {
  return xla::Reshape(operand.op(), dimensions, new_sizes);
}

LocalOp ComputationBuilder::Collapse(const LocalOp& operand,
                                     absl::Span<const int64> dimensions) {
  return xla::Collapse(operand.op(), dimensions);
}

LocalOp ComputationBuilder::AllToAll(
    const LocalOp& operand, int64 split_dimension, int64 concat_dimension,
    int64 split_count, absl::Span<const ReplicaGroup> replica_groups) {
  std::vector<ReplicaGroup> rg;
  rg.reserve(replica_groups.size());
  for (int i = 0; i < replica_groups.size(); ++i) {
    rg.push_back(replica_groups[i]);
  }
  return xla::AllToAll(operand.op(), split_dimension, concat_dimension,
                       split_count, rg);
}

LocalOp ComputationBuilder::CrossReplicaSum(
    const LocalOp& operand, absl::Span<const ReplicaGroup> replica_groups) {
  return xla::CrossReplicaSum(operand.op(), replica_groups);
}

LocalOp ComputationBuilder::Slice(const LocalOp& operand,
                                  absl::Span<const int64> start_indices,
                                  absl::Span<const int64> limit_indices,
                                  absl::Span<const int64> strides) {
  return xla::Slice(operand.op(), start_indices, limit_indices, strides);
}

LocalOp ComputationBuilder::SliceInDim(const LocalOp& operand,
                                       int64 start_index, int64 limit_index,
                                       int64 stride, int64 dimno) {
  return xla::SliceInDim(operand.op(), start_index, limit_index, stride, dimno);
}

LocalOp ComputationBuilder::DynamicSlice(const LocalOp& operand,
                                         const LocalOp& start_indices,
                                         absl::Span<const int64> slice_sizes) {
  return xla::DynamicSlice(operand.op(), start_indices.op(), slice_sizes);
}

LocalOp ComputationBuilder::DynamicUpdateSlice(const LocalOp& operand,
                                               const LocalOp& update,
                                               const LocalOp& start_indices) {
  return xla::DynamicUpdateSlice(operand.op(), update.op(), start_indices.op());
}

LocalOp ComputationBuilder::ConcatInDim(absl::Span<const LocalOp> operands,
                                        int64 dimension) {
  std::vector<XlaOp> xla_ops;
  xla_ops.reserve(operands.size());
  for (const auto& op : operands) {
    xla_ops.push_back(op.op());
  }
  return xla::ConcatInDim(&builder_, xla_ops, dimension);
}

LocalOp ComputationBuilder::SelectAndScatterWithGeneralPadding(
    const LocalOp& operand, const Computation& select,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding, const LocalOp& source,
    const LocalOp& init_value, const Computation& scatter) {
  return xla::SelectAndScatterWithGeneralPadding(
      operand.op(), select.computation(), window_dimensions, window_strides,
      padding, source.op(), init_value.op(), scatter.computation());
}

LocalOp ComputationBuilder::Tuple(absl::Span<const LocalOp> elements) {
  std::vector<XlaOp> xla_ops;
  xla_ops.reserve(elements.size());
  for (const auto& op : elements) {
    xla_ops.push_back(op.op());
  }

  return xla::Tuple(&builder_, xla_ops);
}

LocalOp ComputationBuilder::GetTupleElement(const LocalOp& tuple_data,
                                            int64 index) {
  return xla::GetTupleElement(tuple_data.op(), index);
}

LocalOp ComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) {
  return xla::Dot(lhs.op(), rhs.op());
}

LocalOp ComputationBuilder::DotGeneral(
    const LocalOp& lhs, const LocalOp& rhs,
    const DotDimensionNumbers& dimension_numbers) {
  return xla::DotGeneral(lhs.op(), rhs.op(), dimension_numbers);
}

LocalOp ComputationBuilder::ConvGeneralDilated(
    const LocalOp& lhs, const LocalOp& rhs,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count) {
  return xla::ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, padding,
                                 lhs_dilation, rhs_dilation, dimension_numbers,
                                 feature_group_count);
}

LocalOp ComputationBuilder::ConvertElementType(const LocalOp& operand,
                                               PrimitiveType new_element_type) {
  return xla::ConvertElementType(operand.op(), new_element_type);
}

LocalOp ComputationBuilder::BitcastConvertType(const LocalOp& operand,
                                               PrimitiveType new_element_type) {
  return xla::BitcastConvertType(operand.op(), new_element_type);
}

LocalOp ComputationBuilder::Call(const Computation& local_computation,
                                 absl::Span<const LocalOp> operands) {
  std::vector<XlaOp> xla_ops;
  xla_ops.reserve(operands.size());
  for (const auto& op : operands) {
    xla_ops.push_back(op.op());
  }
  return xla::Call(&builder_, local_computation.computation(), xla_ops);
}

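// Emits a custom-call instruction targeting `call_target_name`, with explicit
// layouts for the result and for each operand.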
LocalOp ComputationBuilder::CustomCall(
    const string& call_target_name, absl::Span<const LocalOp> operands,
    const Shape& shape_with_layout,
    const std::vector<Shape>& operand_shapes_with_layout,
    const string& opaque) {
  std::vector<XlaOp> xla_ops;
  xla_ops.reserve(operands.size());
  for (const auto& op : operands) {
    xla_ops.push_back(op.op());
  }
  return xla::CustomCallWithLayout(&builder_, call_target_name, xla_ops,
                                   shape_with_layout,
                                   operand_shapes_with_layout, opaque);
}

LocalOp ComputationBuilder::Transpose(const LocalOp& operand,
                                      absl::Span<const int64> permutation) {
  return xla::Transpose(operand.op(), permutation);
}

LocalOp ComputationBuilder::Rev(const LocalOp& operand,
                                absl::Span<const int64> dimensions) {
  return xla::Rev(operand.op(), dimensions);
}

LocalOp ComputationBuilder::Map(absl::Span<const LocalOp> operands,
                                const Computation& local_computation,
                                absl::Span<const int64> dimensions) {
  std::vector<XlaOp> xla_ops;
  xla_ops.reserve(operands.size());
  for (const auto& op : operands) {
    xla_ops.push_back(op.op());
  }

  return xla::Map(&builder_, xla_ops, local_computation.computation(),
                  dimensions);
}

LocalOp ComputationBuilder::Reduce(
    const LocalOp& operand, const LocalOp& init_value,
    const Computation& local_computation,
    absl::Span<const int64> dimensions_to_reduce) {
  return xla::Reduce(operand.op(), init_value.op(),
                     local_computation.computation(), dimensions_to_reduce);
}

LocalOp ComputationBuilder::ReduceWindowWithGeneralPadding(
    const LocalOp& operand, const LocalOp& init_value,
    const Computation& local_computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding) {
  return xla::ReduceWindowWithGeneralPadding(
      operand.op(), init_value.op(), local_computation.computation(),
      window_dimensions, window_strides, base_dilations, window_dilations,
      padding);
}

LocalOp ComputationBuilder::RngNormal(const LocalOp& mu, const LocalOp& sigma,
                                      const Shape& shape) {
  return xla::RngNormal(mu.op(), sigma.op(), shape);
}

LocalOp ComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b,
                                       const Shape& shape) {
  return xla::RngUniform(a.op(), b.op(), shape);
}

LocalOp ComputationBuilder::While(const Computation& condition,
                                  const Computation& body,
                                  const LocalOp& init) {
  return xla::While(condition.computation(), body.computation(), init.op());
}

LocalOp ComputationBuilder::Conditional(const LocalOp& predicate,
                                        const LocalOp& true_operand,
                                        const Computation& true_computation,
                                        const LocalOp& false_operand,
                                        const Computation& false_computation) {
  return xla::Conditional(predicate.op(), true_operand.op(),
                          true_computation.computation(), false_operand.op(),
                          false_computation.computation());
}

StatusOr<bool> ComputationBuilder::IsConstant(const LocalOp& operand) {
  return builder_.IsConstant(operand.op());
}

LocalOp ComputationBuilder::Sort(const LocalOp& operand, int64 dimension) {
  return xla::Sort(operand.op(), {}, dimension);
}

LocalOp ComputationBuilder::SortKeyVal(const LocalOp& keys,
                                       const LocalOp& values, int64 dimension) {
  return xla::Sort(keys.op(), {values.op()}, dimension);
}

LocalOp ComputationBuilder::Cholesky(const LocalOp& a, bool lower) {
  return xla::Cholesky(a.op(), lower);
}

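// The decomposition helpers below return their factors packed into a tuple:
// QR -> (q, r), Eigh -> (v, w), SVD -> (u, d, v).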
LocalOp ComputationBuilder::QR(const LocalOp& a, bool full_matrices) {
  XlaBuilder* builder = a.op().builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto qr, xla::QRDecomposition(a.op(), full_matrices));
    return xla::Tuple(builder, {qr.q, qr.r});
  });
}

LocalOp ComputationBuilder::Eigh(const LocalOp& a, bool lower) {
  XlaBuilder* builder = a.op().builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    auto eigh = xla::SelfAdjointEig(a.op(), lower);
    return xla::Tuple(builder, {eigh.v, eigh.w});
  });
}

LocalOp ComputationBuilder::SVD(const LocalOp& a) {
  XlaBuilder* builder = a.op().builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    auto svd = xla::SVD(a.op());
    return xla::Tuple(builder, {svd.u, svd.d, svd.v});
  });
}

LocalOp ComputationBuilder::TriangularSolve(const LocalOp& a, const LocalOp& b,
                                            bool left_side, bool lower,
                                            bool unit_diagonal,
                                            int transpose_a) {
  return xla::TriangularSolve(
      a.op(), b.op(), left_side, lower, unit_diagonal,
      xla::TriangularSolveOptions::Transpose(transpose_a));
}

LocalOp ComputationBuilder::Gather(
    const LocalOp& input, const LocalOp& start_indices,
    const GatherDimensionNumbers& dimension_numbers,
    absl::Span<const int64> slice_sizes) {
  return xla::Gather(input.op(), start_indices.op(), dimension_numbers,
                     slice_sizes);
}

LocalOp ComputationBuilder::Scatter(
    const LocalOp& input, const LocalOp& scatter_indices,
    const LocalOp& updates, const Computation& update_computation,
    const ScatterDimensionNumbers& dimension_numbers) {
  return xla::Scatter(input.op(), scatter_indices.op(), updates.op(),
                      update_computation.computation(), dimension_numbers);
}

StatusOr<Computation*> ComputationBuilder::BuildConstantSubGraph(
    const LocalOp& operand) {
  TF_ASSIGN_OR_RETURN(XlaComputation computation,
                      builder_.BuildConstantSubGraph(operand.op()));
  return new Computation(std::move(computation));
}

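// The macros below forward element-wise and comparison operations to the
// corresponding free functions in the xla:: namespace, unwrapping the LocalOp
// operands as needed.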
#define _FORWARD(method_name, return_sig, args_sig, args) \
  return_sig ComputationBuilder::method_name args_sig {   \
    return xla::method_name args;                         \
  }

#define _FORWARD_UNOP(method_name) \
  _FORWARD(method_name, LocalOp, (const LocalOp& operand), (operand.op()))

#define _FORWARD_BINOP(method_name)                        \
  _FORWARD(method_name, LocalOp,                           \
           (const LocalOp& lhs, const LocalOp& rhs,        \
            absl::Span<const int64> broadcast_dimensions), \
           (lhs.op(), rhs.op(), broadcast_dimensions))

#define _FORWARD_TRIOP(method_name)                                      \
  _FORWARD(method_name, LocalOp,                                         \
           (const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs), \
           (lhs.op(), rhs.op(), ehs.op()))

_FORWARD_TRIOP(Select)
_FORWARD_TRIOP(Clamp)
_FORWARD_BINOP(Eq)
_FORWARD_BINOP(Ne)
_FORWARD_BINOP(Ge)
_FORWARD_BINOP(Gt)
_FORWARD_BINOP(Lt)
_FORWARD_BINOP(Le)
_FORWARD_BINOP(Add)
_FORWARD_BINOP(Sub)
_FORWARD_BINOP(Mul)
_FORWARD_BINOP(Div)
_FORWARD_BINOP(Rem)
_FORWARD_BINOP(Max)
_FORWARD_BINOP(Min)
_FORWARD_BINOP(And)
_FORWARD_BINOP(Or)
_FORWARD_BINOP(Xor)
_FORWARD_BINOP(ShiftLeft)
_FORWARD_BINOP(ShiftRightArithmetic)
_FORWARD_BINOP(ShiftRightLogical)
_FORWARD_BINOP(Atan2)
_FORWARD_BINOP(Pow)
_FORWARD_BINOP(Complex)
_FORWARD_UNOP(Not)
_FORWARD_UNOP(Clz)
_FORWARD_UNOP(Abs)
_FORWARD_UNOP(Exp)
_FORWARD_UNOP(Expm1)
_FORWARD_UNOP(Floor)
_FORWARD_UNOP(Ceil)
_FORWARD_UNOP(Round)
_FORWARD_UNOP(Log)
_FORWARD_UNOP(Log1p)
_FORWARD_UNOP(Sign)
_FORWARD_UNOP(Cos)
_FORWARD_UNOP(Sin)
_FORWARD_UNOP(Tanh)
_FORWARD_UNOP(IsFinite)
_FORWARD_UNOP(Neg)
_FORWARD_UNOP(Sqrt)
_FORWARD_UNOP(Rsqrt)
_FORWARD_UNOP(Square)
_FORWARD_UNOP(Reciprocal)
_FORWARD_UNOP(Erfc)
_FORWARD_UNOP(Erf)
_FORWARD_UNOP(ErfInv)
_FORWARD_UNOP(Lgamma)
_FORWARD_UNOP(Digamma)
_FORWARD_UNOP(Acos)
_FORWARD_UNOP(Asin)
_FORWARD_UNOP(Atan)
_FORWARD_UNOP(Tan)
_FORWARD_UNOP(Acosh)
_FORWARD_UNOP(Asinh)
_FORWARD_UNOP(Atanh)
_FORWARD_UNOP(Cosh)
_FORWARD_UNOP(Sinh)
_FORWARD_UNOP(Real)
_FORWARD_UNOP(Imag)
_FORWARD_UNOP(Conj)

#undef _FORWARD
#undef _FORWARD_UNOP
#undef _FORWARD_BINOP
#undef _FORWARD_TRIOP

void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer) {
  delete local_shaped_buffer;
}

void DeleteLocalExecutable(LocalExecutable* computation) { delete computation; }

void DeleteComputation(Computation* computation) { delete computation; }

}  // namespace swig
}  // namespace xla