/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/local_computation_builder.h"

#include <memory>
#include <string>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/qr.h"
#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
#include "tensorflow/compiler/xla/client/lib/svd.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace swig {

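// Registers `fn_name` as a CPU custom-call target. `capsule` is expected to
// be a PyCapsule named "xla._CPU_CUSTOM_CALL_TARGET" whose payload is the
// function pointer to invoke; any other capsule is rejected.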
Status RegisterCpuCustomCallTarget(const string& fn_name, PyObject* capsule) {
  const char* name = "xla._CPU_CUSTOM_CALL_TARGET";
  if (!PyCapsule_IsValid(capsule, name)) {
    return InvalidArgument(
        "Argument to RegisterCpuCustomCallTarget was not an "
        "xla._CPU_CUSTOM_CALL_TARGET capsule.");
  }
  void* fn_ptr = PyCapsule_GetPointer(capsule, name);
  CHECK(fn_ptr != nullptr);
  cpu::CustomCallTargetRegistry::Global()->Register(
      std::string(fn_name.begin(), fn_name.end()), fn_ptr);
  return Status::OK();
}

LocalClient::LocalClient(xla::LocalClient* client) : client_(client) {}

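// Constructs a LocalClient for the requested platform, creating the
// underlying client via the ClientLibrary singleton if necessary. A minimal
// usage sketch (assuming the host CPU platform, registered as "Host", is
// linked in):
//
//   StatusOr<LocalClient> client = LocalClient::Get("Host");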
/* static */ StatusOr<LocalClient> LocalClient::Get(
    const string& platform_name) {
  TF_ASSIGN_OR_RETURN(se::Platform * platform,
                      PlatformUtil::GetPlatform(platform_name));
  if (platform->VisibleDeviceCount() <= 0) {
    return InvalidArgument("Platform %s has no visible devices.",
                           platform_name);
  }
  LocalClientOptions options;
  options.set_platform(platform);
  TF_ASSIGN_OR_RETURN(xla::LocalClient * client,
                      ClientLibrary::GetOrCreateLocalClient(options));
  CHECK(client != nullptr);
  return LocalClient(client);
}

// Returns the number of devices known to the XLA client.
int LocalClient::DeviceCount() const { return client_->device_count(); }

Status LocalClient::TransferToInfeed(const Literal& literal,
                                     int device_ordinal) {
  VLOG(1) << "Infeeding literal to device " << device_ordinal
          << "; shape: " << literal.shape();
  return client_->TransferToInfeed(literal, device_ordinal);
}

StatusOr<Literal> LocalClient::TransferFromOutfeed(const Shape& shape,
                                                   int device_ordinal) {
  VLOG(1) << "Outfeeding literal from device " << device_ordinal
          << "; shape: " << shape;
  return client_->TransferFromOutfeed(&shape, device_ordinal);
}

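// Copies `argument` to device memory on `device_ordinal`, optionally re-laid
// out to `shape_with_layout` first, and returns a heap-allocated
// LocalShapedBuffer that the caller takes ownership of.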
/* static */
StatusOr<LocalShapedBuffer*> LocalShapedBuffer::FromLiteral(
    const Literal& argument, const absl::optional<Shape>& shape_with_layout,
    const LocalClient& client, int device_ordinal) {
  VLOG(1) << "Creating shaped buffer from literal on device ordinal: "
          << device_ordinal;
  auto literal_to_buffer = [&](const Literal& arg) {
    return client.client()->LiteralToShapedBuffer(
        arg, device_ordinal, client.client()->backend().memory_allocator());
  };

  StatusOr<ScopedShapedBuffer> buf = [&] {
    if (shape_with_layout) {
      Literal relaid = argument.Relayout(shape_with_layout.value());
      return literal_to_buffer(relaid);
    }
    return literal_to_buffer(argument);
  }();
  TF_RETURN_IF_ERROR(buf.status());
  return new LocalShapedBuffer(std::move(buf).ValueOrDie(), client.client());
}

LocalShapedBuffer::LocalShapedBuffer(ScopedShapedBuffer shaped_buffer,
                                     xla::LocalClient* client)
    : shaped_buffer_(std::move(shaped_buffer)), client_(client) {}

const ScopedShapedBuffer* LocalShapedBuffer::shaped_buffer() const {
  return &shaped_buffer_;
}

ShapedBuffer LocalShapedBuffer::Release() { return shaped_buffer_.release(); }

const Shape& LocalShapedBuffer::shape() const {
  return shaped_buffer()->on_device_shape();
}

StatusOr<Literal> LocalShapedBuffer::ToLiteral() const {
  return client_->ShapedBufferToLiteral(*shaped_buffer());
}

LocalShapedBufferTuple::LocalShapedBufferTuple(
    std::vector<LocalShapedBuffer*> elements)
    : elements_(std::move(elements)) {
  for (auto* element : elements_) {
    CHECK(element != nullptr);
  }
}

LocalShapedBufferTuple::~LocalShapedBufferTuple() {
  for (LocalShapedBuffer* element : elements_) {
    if (element != nullptr) {
      delete element;
    }
  }
}

StatusOr<LocalShapedBuffer*> LocalShapedBufferTuple::Release(int i) {
  LocalShapedBuffer* element = elements_[i];
  if (element == nullptr) {
    return InvalidArgument("Attempted to release already-released element %d.",
                           i);
  }
  elements_[i] = nullptr;
  return element;
}

int64 LocalShapedBufferTuple::size() const { return elements_.size(); }

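// Splits a tuple-shaped buffer into per-element LocalShapedBuffers. Ownership
// of each element's device memory is transferred to the corresponding result;
// only the tuple's root index table is deallocated here, so the original
// buffer must not be used afterwards.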
StatusOr<LocalShapedBufferTuple*> LocalShapedBuffer::DestructureTuple() {
  const Shape tuple_shape = shape();

  if (!tuple_shape.IsTuple()) {
    return InvalidArgument(
        "Attempted to destructure a LocalShapedBuffer that did not have a "
        "tuple shape; shape: %s",
        ShapeUtil::HumanString(tuple_shape));
  }

  DeviceMemoryAllocator* allocator = shaped_buffer()->memory_allocator();
  ShapedBuffer tuple_buffer = Release();

  // Extract some metadata we use to construct scoped buffers.
  const se::Platform* platform = tuple_buffer.platform();
  int device_ordinal = tuple_buffer.device_ordinal();

  ShapeTree<se::DeviceMemoryBase>& shape_tree = tuple_buffer.buffers();
  std::vector<LocalShapedBuffer*> results;
  for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) {
    // Create a shaped buffer for this destructured tuple element.
    const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i});
    VLOG(3) << "Starting tuple element " << i << " subshape: " << subshape;
    ShapedBuffer shaped_buffer(subshape, subshape, platform, device_ordinal);

    ShapeUtil::ForEachSubshape(
        subshape, [&](const Shape& s, const ShapeIndex& index) {
          ShapeIndex original(index);
          original.push_front(i);
          se::DeviceMemoryBase* device_memory =
              shape_tree.mutable_element(original);
          shaped_buffer.set_buffer(*device_memory, index);
          *device_memory = se::DeviceMemoryBase();
        });

    VLOG(3) << "Completed tuple element: " << i;
    results.push_back(new LocalShapedBuffer(
        ScopedShapedBuffer(std::move(shaped_buffer), allocator), client_));
  }
  // Deallocate the root buffer.
  se::DeviceMemoryBase root_buffer = tuple_buffer.root_buffer();
  TF_RETURN_IF_ERROR(allocator->Deallocate(device_ordinal, root_buffer));
  return new LocalShapedBufferTuple(std::move(results));
}

LocalExecutable::LocalExecutable(
    std::unique_ptr<xla::LocalExecutable> executable,
    xla::DeviceAssignment device_assignment, xla::LocalClient* client)
    : executable_(std::move(executable)),
      device_assignment_(std::move(device_assignment)),
      client_(client) {}

std::vector<int> LocalExecutable::DeviceOrdinals() const {
  int num_replicas = device_assignment_.replica_count();
  std::vector<int> device_ordinals;
  device_ordinals.reserve(num_replicas);
  for (int i = 0; i < num_replicas; ++i) {
    device_ordinals.push_back(device_assignment_(i, 0));
  }
  return device_ordinals;
}

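// Runs the executable on the single device named by the device assignment.
// Computations compiled for more than one replica must go through
// ExecutePerReplica() instead.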
StatusOr<LocalShapedBuffer*> LocalExecutable::Execute(
    absl::Span<LocalShapedBuffer* const> argument_handles) {
  if (num_replicas() != 1) {
    return InvalidArgument(
        "Attempted to execute computation with %d replicas using Execute()",
        num_replicas());
  }
  StatusOr<ScopedShapedBuffer> result_buffer_status;
  const int device_ordinal = device_assignment_(0, 0);
  VLOG(3) << "Replica 0 mapped to device ordinal for execution: "
          << device_ordinal;

  std::vector<const ShapedBuffer*> argument_buffers;
  argument_buffers.reserve(argument_handles.size());
  for (auto& handle : argument_handles) {
    argument_buffers.push_back(handle->shaped_buffer());
  }

  ExecutableRunOptions options;
  options.set_device_ordinal(device_ordinal);
  options.set_allocator(client_->backend().memory_allocator());
  options.set_intra_op_thread_pool(
      client_->backend().eigen_intra_op_thread_pool_device());
  options.set_device_assignment(&device_assignment_);

  result_buffer_status = executable_->Run(argument_buffers, options);

  if (!result_buffer_status.ok()) {
    return InternalError(
        "Failed running replica 0 (other replicas may have failed as well): "
        "%s.",
        result_buffer_status.status().ToString());
  }
  return new LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie(),
                               client_);
}

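// Runs one execution per replica, feeding replica r the argument handles in
// argument_handles[r]. All replicas but the last are scheduled on a temporary
// thread pool; the last replica runs on the calling thread.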
StatusOr<LocalShapedBufferTuple*> LocalExecutable::ExecutePerReplica(
    absl::Span<const std::vector<LocalShapedBuffer*>> argument_handles) {
  const int num_devices = client_->device_count();

  if (argument_handles.size() != num_replicas()) {
    return InvalidArgument(
        "Attempted to execute with %d replicas when replica count is %d",
        argument_handles.size(), num_replicas());
  }
  if (argument_handles.size() > num_devices) {
    return InvalidArgument(
        "Attempted to execute with %d replicas when device count is %d",
        argument_handles.size(), num_devices);
  }

  VLOG(1) << "Executing with " << num_replicas() << " replicas.";

  std::vector<StatusOr<ScopedShapedBuffer>> results(num_replicas());
  auto execute = [this, &argument_handles, &results](int replica) {
    const int device_ordinal = device_assignment_(replica, 0);
    VLOG(3) << "Replica " << replica
            << " mapped to device ordinal for execution: " << device_ordinal;

    std::vector<const ShapedBuffer*> argument_buffers;
    argument_buffers.reserve(argument_handles[replica].size());
    for (auto& handle : argument_handles[replica]) {
      argument_buffers.push_back(handle->shaped_buffer());
    }

    ExecutableRunOptions options;
    options.set_device_ordinal(device_ordinal);
    options.set_allocator(client_->backend().memory_allocator());
    options.set_intra_op_thread_pool(
        client_->backend().eigen_intra_op_thread_pool_device());
    options.set_device_assignment(&device_assignment_);
    StatusOr<ScopedShapedBuffer> result_buffer_status =
        executable_->Run(argument_buffers, options);

    results[replica] = std::move(result_buffer_status);
  };

  if (num_replicas() == 1) {
    // Fast path if there is only one replica: run the computation on the
    // current thread.
    execute(0);
  } else {
    // TODO(phawkins): don't recreate the threadpool for each execution.
    tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun",
                                        num_replicas() - 1);

    for (int replica = 0; replica < num_replicas() - 1; ++replica) {
      pool.Schedule([&execute, replica] { execute(replica); });
    }
    execute(num_replicas() - 1);
  }

  std::vector<LocalShapedBuffer*> wrapped_results(num_replicas());
  for (int replica = 0; replica < num_replicas(); ++replica) {
    auto& statusor = results[replica];
    if (!statusor.ok()) {
      return InternalError(
          "Failed running replica %d (other replicas may have failed as "
          "well): %s.",
          replica, statusor.status().ToString());
    }
    wrapped_results[replica] =
        new LocalShapedBuffer(std::move(statusor).ValueOrDie(), client_);
  }

  return new LocalShapedBufferTuple(std::move(wrapped_results));
}

Computation::Computation(XlaComputation computation)
    : computation_(std::move(computation)) {}

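// JIT-compiles the wrapped XlaComputation for the given argument shapes and
// build options, and pairs the executable with a device assignment of
// options.num_replicas() replicas (one computation per replica).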
StatusOr<LocalExecutable*> Computation::Compile(
    const std::vector<Shape>& argument_shapes,
    const ExecutableBuildOptions* build_options, const LocalClient& client) {
  std::vector<const Shape*> argument_shape_pointers;
  argument_shape_pointers.reserve(argument_shapes.size());
  for (auto& argument_shape : argument_shapes) {
    argument_shape_pointers.push_back(&argument_shape);
  }

  ExecutableBuildOptions options;
  if (build_options != nullptr) {
    options = *build_options;
  }
  TF_ASSIGN_OR_RETURN(
      auto local_executable,
      client.client()->Compile(computation_, argument_shape_pointers, options));
  TF_ASSIGN_OR_RETURN(
      DeviceAssignment device_assignment,
      client.client()->backend().computation_placer()->AssignDevices(
          options.num_replicas(), /*computation_count=*/1));

  return new LocalExecutable(std::move(local_executable),
                             std::move(device_assignment), client.client());
}

const XlaComputation& Computation::computation() const { return computation_; }

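// Serializes the computation's HloModuleProto to a binary string, or returns
// an empty string (and logs an error) if serialization fails.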
string Computation::GetSerializedProto() const {
  string result;
  if (!computation_.proto().SerializeToString(&result)) {
    LOG(ERROR) << "Failed to serialize the HloModuleProto.";
    return "";
  }
  return result;
}

StatusOr<string> Computation::GetHloText() const {
  TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config,
                      HloModule::CreateModuleConfigFromProto(
                          computation_.proto(), GetDebugOptionsFromFlags()));
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModule> hlo_module,
      HloModule::CreateFromProto(computation_.proto(), module_config));
  HloPrintOptions options;
  options = HloPrintOptions::ShortParsable();
  options.set_print_large_constants(false);
  return hlo_module->ToString(options);
}

StatusOr<string> Computation::GetHloDotGraph() const {
  TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config,
                      HloModule::CreateModuleConfigFromProto(
                          computation_.proto(), GetDebugOptionsFromFlags()));
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModule> hlo_module,
      HloModule::CreateFromProto(computation_.proto(), module_config));
  return RenderGraph(*hlo_module->entry_computation(), /*label=*/"",
                     hlo_module->config().debug_options(),
                     RenderedGraphFormat::kDot);
}

StatusOr<ProgramShape> Computation::GetProgramShape() const {
  return computation_.GetProgramShape();
}

StatusOr<Shape> Computation::GetReturnValueShape() const {
  TF_ASSIGN_OR_RETURN(ProgramShape shape, computation_.GetProgramShape());
  return std::move(*shape.mutable_result());
}

LocalOp::LocalOp(const XlaOp& op) : op_(op) {}

const XlaOp& LocalOp::op() const { return op_; }

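// ComputationBuilder wraps XlaBuilder for the SWIG bindings: each method
// below forwards to the corresponding free function in the xla:: namespace.
// A minimal sketch of the intended call pattern (assuming an f32[2] shape):
//
//   ComputationBuilder b("add_one");
//   LocalOp x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x");
//   LocalOp one = b.ConstantLiteral(LiteralUtil::CreateR0<float>(1.0f));
//   LocalOp sum = b.Add(x, one, /*broadcast_dimensions=*/{});
//   StatusOr<Computation*> computation = b.BuildWithRoot(sum);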
ComputationBuilder::ComputationBuilder(const string& computation_name)
    : builder_(computation_name) {}

void ComputationBuilder::SetOpMetadata(const OpMetadata& metadata) {
  builder_.SetOpMetadata(metadata);
}

void ComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); }

StatusOr<Computation*> ComputationBuilder::Build() {
  TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build());
  return new Computation(std::move(computation));
}

LocalOp ComputationBuilder::Parameter(int64 parameter_number,
                                      const Shape& shape, const string& name) {
  return xla::Parameter(&builder_, parameter_number, shape, name);
}

StatusOr<Computation*> ComputationBuilder::BuildWithRoot(const LocalOp& root) {
  TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build(root.op()));
  return new Computation(std::move(computation));
}

StatusOr<Shape> ComputationBuilder::GetShape(const LocalOp& operand) {
  return builder_.GetShape(operand.op());
}

StatusOr<Shape> ComputationBuilder::GetReturnValueShape() {
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape, builder_.GetProgramShape());
  return program_shape.result();
}

LocalOp ComputationBuilder::ReplicaId() { return xla::ReplicaId(&builder_); }

LocalOp ComputationBuilder::Infeed(const Shape& shape) {
  return xla::Infeed(&builder_, shape);
}

void ComputationBuilder::Outfeed(const LocalOp& operand, const Shape& shape,
                                 const string& outfeed_config) {
  xla::Outfeed(operand.op(), shape, outfeed_config);
}

LocalOp ComputationBuilder::ConstantLiteral(const Literal& literal) {
  return xla::ConstantLiteral(&builder_, literal);
}

LocalOp ComputationBuilder::Iota(PrimitiveType element_type, int64 size) {
  return xla::Iota(&builder_, element_type, size);
}

LocalOp ComputationBuilder::BroadcastedIota(const Shape& shape,
                                            int64 dimension) {
  return xla::Iota(&builder_, shape, dimension);
}

LocalOp ComputationBuilder::Broadcast(const LocalOp& operand,
                                      absl::Span<const int64> broadcast_sizes) {
  return xla::Broadcast(operand.op(), broadcast_sizes);
}

LocalOp ComputationBuilder::BroadcastInDim(
    const LocalOp& operand, absl::Span<const int64> out_dim_sizes,
    absl::Span<const int64> broadcast_dimensions) {
  return xla::BroadcastInDim(operand.op(), out_dim_sizes, broadcast_dimensions);
}

LocalOp ComputationBuilder::Pad(const LocalOp& operand,
                                const LocalOp& padding_value,
                                const PaddingConfig& padding_config) {
  return xla::Pad(operand.op(), padding_value.op(), padding_config);
}

LocalOp ComputationBuilder::Reshape(const LocalOp& operand,
                                    absl::Span<const int64> dimensions,
                                    absl::Span<const int64> new_sizes) {
  return xla::Reshape(operand.op(), dimensions, new_sizes);
}

LocalOp ComputationBuilder::Collapse(const LocalOp& operand,
                                     absl::Span<const int64> dimensions) {
  return xla::Collapse(operand.op(), dimensions);
}

LocalOp ComputationBuilder::AllToAll(
    const LocalOp& operand, int64 split_dimension, int64 concat_dimension,
    int64 split_count, absl::Span<const ReplicaGroup> replica_groups) {
  std::vector<ReplicaGroup> rg;
  rg.reserve(replica_groups.size());
  for (int i = 0; i < replica_groups.size(); ++i) {
    rg.push_back(replica_groups[i]);
  }
  return xla::AllToAll(operand.op(), split_dimension, concat_dimension,
                       split_count, rg);
}

LocalOp ComputationBuilder::CrossReplicaSum(
    const LocalOp& operand, absl::Span<const ReplicaGroup> replica_groups) {
  return xla::CrossReplicaSum(operand.op(), replica_groups);
}

LocalOp ComputationBuilder::Slice(const LocalOp& operand,
                                  absl::Span<const int64> start_indices,
                                  absl::Span<const int64> limit_indices,
                                  absl::Span<const int64> strides) {
  return xla::Slice(operand.op(), start_indices, limit_indices, strides);
}

LocalOp ComputationBuilder::SliceInDim(const LocalOp& operand,
                                       int64 start_index, int64 limit_index,
                                       int64 stride, int64 dimno) {
  return xla::SliceInDim(operand.op(), start_index, limit_index, stride, dimno);
}

LocalOp ComputationBuilder::DynamicSlice(const LocalOp& operand,
                                         const LocalOp& start_indices,
                                         absl::Span<const int64> slice_sizes) {
  return xla::DynamicSlice(operand.op(), start_indices.op(), slice_sizes);
}

LocalOp ComputationBuilder::DynamicUpdateSlice(const LocalOp& operand,
                                               const LocalOp& update,
                                               const LocalOp& start_indices) {
  return xla::DynamicUpdateSlice(operand.op(), update.op(), start_indices.op());
}

LocalOp ComputationBuilder::ConcatInDim(absl::Span<const LocalOp> operands,
                                        int64 dimension) {
  std::vector<XlaOp> xla_ops;
  xla_ops.reserve(operands.size());
  for (const auto& op : operands) {
    xla_ops.push_back(op.op());
  }
  return xla::ConcatInDim(&builder_, xla_ops, dimension);
}

LocalOp ComputationBuilder::SelectAndScatterWithGeneralPadding(
    const LocalOp& operand, const Computation& select,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding, const LocalOp& source,
    const LocalOp& init_value, const Computation& scatter) {
  return xla::SelectAndScatterWithGeneralPadding(
      operand.op(), select.computation(), window_dimensions, window_strides,
      padding, source.op(), init_value.op(), scatter.computation());
}

LocalOp ComputationBuilder::Tuple(absl::Span<const LocalOp> elements) {
  std::vector<XlaOp> xla_ops;
  xla_ops.reserve(elements.size());
  for (const auto& op : elements) {
    xla_ops.push_back(op.op());
  }

  return xla::Tuple(&builder_, xla_ops);
}

LocalOp ComputationBuilder::GetTupleElement(const LocalOp& tuple_data,
                                            int64 index) {
  return xla::GetTupleElement(tuple_data.op(), index);
}

LocalOp ComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) {
  return xla::Dot(lhs.op(), rhs.op());
}

LocalOp ComputationBuilder::DotGeneral(
    const LocalOp& lhs, const LocalOp& rhs,
    const DotDimensionNumbers& dimension_numbers) {
  return xla::DotGeneral(lhs.op(), rhs.op(), dimension_numbers);
}

LocalOp ComputationBuilder::ConvGeneralDilated(
    const LocalOp& lhs, const LocalOp& rhs,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count) {
  return xla::ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, padding,
                                 lhs_dilation, rhs_dilation, dimension_numbers,
                                 feature_group_count);
}

LocalOp ComputationBuilder::ConvertElementType(const LocalOp& operand,
                                               PrimitiveType new_element_type) {
  return xla::ConvertElementType(operand.op(), new_element_type);
}

LocalOp ComputationBuilder::BitcastConvertType(const LocalOp& operand,
                                               PrimitiveType new_element_type) {
  return xla::BitcastConvertType(operand.op(), new_element_type);
}

LocalOp ComputationBuilder::Call(const Computation& local_computation,
                                 absl::Span<const LocalOp> operands) {
  std::vector<XlaOp> xla_ops;
  xla_ops.reserve(operands.size());
  for (const auto& op : operands) {
    xla_ops.push_back(op.op());
  }
  return xla::Call(&builder_, local_computation.computation(), xla_ops);
}

LocalOp ComputationBuilder::CustomCall(
    const string& call_target_name, absl::Span<const LocalOp> operands,
    const Shape& shape_with_layout,
    const std::vector<Shape>& operand_shapes_with_layout,
    const string& opaque) {
  std::vector<XlaOp> xla_ops;
  xla_ops.reserve(operands.size());
  for (const auto& op : operands) {
    xla_ops.push_back(op.op());
  }
  return xla::CustomCallWithLayout(&builder_, call_target_name, xla_ops,
                                   shape_with_layout,
                                   operand_shapes_with_layout, opaque);
}

LocalOp ComputationBuilder::Transpose(const LocalOp& operand,
                                      absl::Span<const int64> permutation) {
  return xla::Transpose(operand.op(), permutation);
}

LocalOp ComputationBuilder::Rev(const LocalOp& operand,
                                absl::Span<const int64> dimensions) {
  return xla::Rev(operand.op(), dimensions);
}

LocalOp ComputationBuilder::Map(absl::Span<const LocalOp> operands,
                                const Computation& local_computation,
                                absl::Span<const int64> dimensions) {
  std::vector<XlaOp> xla_ops;
  xla_ops.reserve(operands.size());
  for (const auto& op : operands) {
    xla_ops.push_back(op.op());
  }

  return xla::Map(&builder_, xla_ops, local_computation.computation(),
                  dimensions);
}

LocalOp ComputationBuilder::Reduce(
    const LocalOp& operand, const LocalOp& init_value,
    const Computation& local_computation,
    absl::Span<const int64> dimensions_to_reduce) {
  return xla::Reduce(operand.op(), init_value.op(),
                     local_computation.computation(), dimensions_to_reduce);
}

LocalOp ComputationBuilder::ReduceWindowWithGeneralPadding(
    const LocalOp& operand, const LocalOp& init_value,
    const Computation& local_computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding) {
  return xla::ReduceWindowWithGeneralPadding(
      operand.op(), init_value.op(), local_computation.computation(),
      window_dimensions, window_strides, base_dilations, window_dilations,
      padding);
}

LocalOp ComputationBuilder::RngNormal(const LocalOp& mu, const LocalOp& sigma,
                                      const Shape& shape) {
  return xla::RngNormal(mu.op(), sigma.op(), shape);
}

LocalOp ComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b,
                                       const Shape& shape) {
  return xla::RngUniform(a.op(), b.op(), shape);
}

LocalOp ComputationBuilder::While(const Computation& condition,
                                  const Computation& body,
                                  const LocalOp& init) {
  return xla::While(condition.computation(), body.computation(), init.op());
}

LocalOp ComputationBuilder::Conditional(const LocalOp& predicate,
                                        const LocalOp& true_operand,
                                        const Computation& true_computation,
                                        const LocalOp& false_operand,
                                        const Computation& false_computation) {
  return xla::Conditional(predicate.op(), true_operand.op(),
                          true_computation.computation(), false_operand.op(),
                          false_computation.computation());
}

StatusOr<bool> ComputationBuilder::IsConstant(const LocalOp& operand) {
  return builder_.IsConstant(operand.op());
}

LocalOp ComputationBuilder::Sort(const LocalOp& operand, int64 dimension) {
  return xla::Sort(operand.op(), {}, dimension);
}

LocalOp ComputationBuilder::SortKeyVal(const LocalOp& keys,
                                       const LocalOp& values, int64 dimension) {
  return xla::Sort(keys.op(), {values.op()}, dimension);
}

LocalOp ComputationBuilder::Cholesky(const LocalOp& a, bool lower) {
  return xla::Cholesky(a.op(), lower);
}

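// Computes a QR decomposition and packages the factors as a (q, r) tuple so
// the result can be destructured on the Python side.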
LocalOp ComputationBuilder::QR(const LocalOp& a, bool full_matrices) {
  XlaBuilder* builder = a.op().builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto qr, xla::QRDecomposition(a.op(), full_matrices));
    return xla::Tuple(builder, {qr.q, qr.r});
  });
}

LocalOp ComputationBuilder::Eigh(const LocalOp& a, bool lower) {
  XlaBuilder* builder = a.op().builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    auto eigh = xla::SelfAdjointEig(a.op(), lower);
    return xla::Tuple(builder, {eigh.v, eigh.w});
  });
}

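// Computes a singular value decomposition and packages the factors as a
// (u, d, v) tuple, mirroring the fields of xla::SVD's result struct.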
LocalOp ComputationBuilder::SVD(const LocalOp& a) {
  XlaBuilder* builder = a.op().builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    auto svd = xla::SVD(a.op());
    return xla::Tuple(builder, {svd.u, svd.d, svd.v});
  });
}

LocalOp ComputationBuilder::TriangularSolve(const LocalOp& a, const LocalOp& b,
                                            bool left_side, bool lower,
                                            bool unit_diagonal,
                                            int transpose_a) {
  return xla::TriangularSolve(
      a.op(), b.op(), left_side, lower, unit_diagonal,
      xla::TriangularSolveOptions::Transpose(transpose_a));
}

LocalOp ComputationBuilder::Gather(
    const LocalOp& input, const LocalOp& start_indices,
    const GatherDimensionNumbers& dimension_numbers,
    absl::Span<const int64> slice_sizes) {
  return xla::Gather(input.op(), start_indices.op(), dimension_numbers,
                     slice_sizes);
}

LocalOp ComputationBuilder::Scatter(
    const LocalOp& input, const LocalOp& scatter_indices,
    const LocalOp& updates, const Computation& update_computation,
    const ScatterDimensionNumbers& dimension_numbers) {
  return xla::Scatter(input.op(), scatter_indices.op(), updates.op(),
                      update_computation.computation(), dimension_numbers);
}

StatusOr<Computation*> ComputationBuilder::BuildConstantSubGraph(
    const LocalOp& operand) {
  TF_ASSIGN_OR_RETURN(XlaComputation computation,
                      builder_.BuildConstantSubGraph(operand.op()));
  return new Computation(std::move(computation));
}

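// The _FORWARD* macros below stamp out thin ComputationBuilder wrappers for
// the element-wise XLA ops. For example, _FORWARD_UNOP(Exp) expands to:
//
//   LocalOp ComputationBuilder::Exp(const LocalOp& operand) {
//     return xla::Exp(operand.op());
//   }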
#define _FORWARD(method_name, return_sig, args_sig, args) \
  return_sig ComputationBuilder::method_name args_sig {   \
    return xla::method_name args;                         \
  }

#define _FORWARD_UNOP(method_name) \
  _FORWARD(method_name, LocalOp, (const LocalOp& operand), (operand.op()))

#define _FORWARD_BINOP(method_name)                        \
  _FORWARD(method_name, LocalOp,                           \
           (const LocalOp& lhs, const LocalOp& rhs,        \
            absl::Span<const int64> broadcast_dimensions), \
           (lhs.op(), rhs.op(), broadcast_dimensions))

#define _FORWARD_TRIOP(method_name)                                      \
  _FORWARD(method_name, LocalOp,                                         \
           (const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs), \
           (lhs.op(), rhs.op(), ehs.op()))

_FORWARD_TRIOP(Select)
_FORWARD_TRIOP(Clamp)
_FORWARD_BINOP(Eq)
_FORWARD_BINOP(Ne)
_FORWARD_BINOP(Ge)
_FORWARD_BINOP(Gt)
_FORWARD_BINOP(Lt)
_FORWARD_BINOP(Le)
_FORWARD_BINOP(Add)
_FORWARD_BINOP(Sub)
_FORWARD_BINOP(Mul)
_FORWARD_BINOP(Div)
_FORWARD_BINOP(Rem)
_FORWARD_BINOP(Max)
_FORWARD_BINOP(Min)
_FORWARD_BINOP(And)
_FORWARD_BINOP(Or)
_FORWARD_BINOP(Xor)
_FORWARD_BINOP(ShiftLeft)
_FORWARD_BINOP(ShiftRightArithmetic)
_FORWARD_BINOP(ShiftRightLogical)
_FORWARD_BINOP(Atan2)
_FORWARD_BINOP(Pow)
_FORWARD_BINOP(Complex)
_FORWARD_UNOP(Not)
_FORWARD_UNOP(Clz)
_FORWARD_UNOP(Abs)
_FORWARD_UNOP(Exp)
_FORWARD_UNOP(Expm1)
_FORWARD_UNOP(Floor)
_FORWARD_UNOP(Ceil)
_FORWARD_UNOP(Round)
_FORWARD_UNOP(Log)
_FORWARD_UNOP(Log1p)
_FORWARD_UNOP(Sign)
_FORWARD_UNOP(Cos)
_FORWARD_UNOP(Sin)
_FORWARD_UNOP(Tanh)
_FORWARD_UNOP(IsFinite)
_FORWARD_UNOP(Neg)
_FORWARD_UNOP(Sqrt)
_FORWARD_UNOP(Rsqrt)
_FORWARD_UNOP(Square)
_FORWARD_UNOP(Reciprocal)
_FORWARD_UNOP(Erfc)
_FORWARD_UNOP(Erf)
_FORWARD_UNOP(ErfInv)
_FORWARD_UNOP(Lgamma)
_FORWARD_UNOP(Digamma)
_FORWARD_UNOP(Acos)
_FORWARD_UNOP(Asin)
_FORWARD_UNOP(Atan)
_FORWARD_UNOP(Tan)
_FORWARD_UNOP(Acosh)
_FORWARD_UNOP(Asinh)
_FORWARD_UNOP(Atanh)
_FORWARD_UNOP(Cosh)
_FORWARD_UNOP(Sinh)
_FORWARD_UNOP(Real)
_FORWARD_UNOP(Imag)
_FORWARD_UNOP(Conj)

#undef _FORWARD
#undef _FORWARD_UNOP
#undef _FORWARD_BINOP
#undef _FORWARD_TRIOP

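// Deleter helpers exposed through the SWIG bindings so the Python-side
// wrappers can free the heap-allocated objects returned by the factory
// methods above.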
void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer) {
  delete local_shaped_buffer;
}

void DeleteLocalExecutable(LocalExecutable* computation) { delete computation; }

void DeleteComputation(Computation* computation) { delete computation; }

}  // namespace swig
}  // namespace xla