1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/python/xla_compiler.h"
17 
18 #include <cstdint>
19 #include <string>
20 #include <vector>
21 
22 #include "absl/hash/hash.h"
23 #include "absl/synchronization/mutex.h"
24 #include "absl/types/optional.h"
25 #include "absl/types/span.h"
26 #include "pybind11/attr.h"
27 #include "pybind11/cast.h"
28 #include "pybind11/numpy.h"
29 #include "pybind11/pybind11.h"
30 #include "pybind11/pytypes.h"
31 #include "pybind11/stl_bind.h"
32 #include "tensorflow/compiler/xla/client/executable_build_options.h"
33 #include "tensorflow/compiler/xla/client/xla_builder.h"
34 #include "tensorflow/compiler/xla/client/xla_computation.h"
35 #include "tensorflow/compiler/xla/debug_options_flags.h"
36 #include "tensorflow/compiler/xla/layout_util.h"
37 #include "tensorflow/compiler/xla/python/py_client.h"
38 #include "tensorflow/compiler/xla/python/types.h"
39 #include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
40 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
41 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
42 #include "tensorflow/compiler/xla/service/hlo_module.h"
43 #include "tensorflow/compiler/xla/service/hlo_parser.h"
44 #include "tensorflow/compiler/xla/service/name_uniquer.h"
45 #include "tensorflow/compiler/xla/service/platform_util.h"
46 #include "tensorflow/compiler/xla/shape.h"
47 #include "tensorflow/compiler/xla/shape_util.h"
48 #include "tensorflow/compiler/xla/statusor.h"
49 #include "tensorflow/compiler/xla/util.h"
50 #include "tensorflow/compiler/xla/xla_data.pb.h"
51 
52 namespace xla {
53 namespace {
54 
55 namespace py = pybind11;
56 
// Bundles a NameUniquer with the mutex that protects it. `name_uniquer` is
// annotated as guarded by `mu`, so `mu` must be held while it is used.
struct Uniquer {
  // Serializes access to `name_uniquer`.
  absl::Mutex mu;
  // Produces the unique names; only touch this with `mu` held.
  NameUniquer name_uniquer TF_GUARDED_BY(mu);
};
61 
GetUniquer()62 Uniquer* GetUniquer() {
63   static Uniquer* uniquer = new Uniquer;
64   return uniquer;
65 }
66 
UniquifyName(const std::string & name)67 static std::string UniquifyName(const std::string& name) {
68   Uniquer* uniquer = GetUniquer();
69   absl::MutexLock lock(&uniquer->mu);
70   return uniquer->name_uniquer.GetUniqueName(name);
71 }
72 
73 // Converts a computation to a serialized HloModuleProto.
GetComputationSerializedProto(const XlaComputation & computation)74 StatusOr<py::bytes> GetComputationSerializedProto(
75     const XlaComputation& computation) {
76   std::string result;
77   if (!computation.proto().SerializeToString(&result)) {
78     return Unknown("Failed to serialize the HloModuleProto.");
79   }
80   return py::bytes(result);
81 }
82 
GetHloModule(const XlaComputation & computation)83 StatusOr<std::shared_ptr<HloModule>> GetHloModule(
84     const XlaComputation& computation) {
85   TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config,
86                       HloModule::CreateModuleConfigFromProto(
87                           computation.proto(), GetDebugOptionsFromFlags()));
88   TF_ASSIGN_OR_RETURN(
89       std::unique_ptr<HloModule> module,
90       HloModule::CreateFromProto(computation.proto(), module_config));
91   return std::shared_ptr<HloModule>(std::move(module));
92 }
93 
94 // Converts a computation to textual HLO form.
GetComputationHloText(const XlaComputation & computation)95 StatusOr<std::string> GetComputationHloText(const XlaComputation& computation) {
96   TF_ASSIGN_OR_RETURN(std::shared_ptr<HloModule> hlo_module,
97                       GetHloModule(computation));
98   HloPrintOptions options;
99   options = HloPrintOptions::ShortParsable();
100   options.set_print_large_constants(false);
101   return hlo_module->ToString(options);
102 }
103 
104 // Converts a computation to HLO dot graph form.
GetComputationHloDotGraph(const XlaComputation & computation)105 StatusOr<std::string> GetComputationHloDotGraph(
106     const XlaComputation& computation) {
107   TF_ASSIGN_OR_RETURN(std::shared_ptr<HloModule> hlo_module,
108                       GetHloModule(computation));
109   return RenderGraph(*hlo_module->entry_computation(), /*label=*/"",
110                      hlo_module->config().debug_options(),
111                      RenderedGraphFormat::kDot);
112 }
113 
114 // Hashes the HLO module.
HashComputation(const XlaComputation & computation)115 StatusOr<uint64> HashComputation(const XlaComputation& computation) {
116   TF_ASSIGN_OR_RETURN(std::shared_ptr<HloModule> hlo_module,
117                       GetHloModule(computation));
118   return hlo_module->Hash();
119 }
120 // Safe version of ShapeUtil::MakeShapeWithLayout that fails gracefully on
121 // invalid input.
MakeShapeWithLayout(PrimitiveType element_type,absl::Span<const int64> dims,absl::optional<absl::Span<const int64>> minor_to_major)122 StatusOr<Shape> MakeShapeWithLayout(
123     PrimitiveType element_type, absl::Span<const int64> dims,
124     absl::optional<absl::Span<const int64>> minor_to_major) {
125   TF_ASSIGN_OR_RETURN(Shape shape,
126                       ShapeUtil::MakeValidatedShape(element_type, dims));
127   if (minor_to_major) {
128     *shape.mutable_layout() = LayoutUtil::MakeLayout(*minor_to_major);
129     TF_RETURN_IF_ERROR(
130         LayoutUtil::ValidateLayoutForShape(shape.layout(), shape));
131   } else {
132     shape.clear_layout();
133   }
134   return shape;
135 }
136 
137 // Registers a 'fn_capsule' as a CPU custom call target.
138 // 'fn_capsule' must be a void* pointer encapsulated in a PyCapsule object,
139 // with name "xla._CUSTOM_CALL_TARGET".
140 // 'platform' is an XLA platform name, e.g., "Host" or "CUDA".
PyRegisterCustomCallTarget(const std::string & fn_name,py::capsule capsule,const std::string & platform)141 Status PyRegisterCustomCallTarget(const std::string& fn_name,
142                                   py::capsule capsule,
143                                   const std::string& platform) {
144   static const char* const kName = "xla._CUSTOM_CALL_TARGET";
145   // TODO(phawkins): remove old name after fixing users.
146   static const char* const kOldCpuName = "xla._CPU_CUSTOM_CALL_TARGET";
147   if (absl::string_view(capsule.name()) != kName &&
148       absl::string_view(capsule.name()) != kOldCpuName) {
149     return InvalidArgument(
150         "Argument to RegisterCustomCallTargetRegistry was not a "
151         "xla._CUSTOM_CALL_TARGET capsule.");
152   }
153   CustomCallTargetRegistry::Global()->Register(
154       fn_name, static_cast<void*>(capsule), platform);
155   return Status::OK();
156 }
157 
158 }  // namespace
159 
BuildXlaCompilerSubmodule(py::module & m)160 void BuildXlaCompilerSubmodule(py::module& m) {
161   // Shapes
162   py::class_<Shape> shape_class(m, "Shape");
163   shape_class
164       .def(py::init([](const string& s) {
165         return absl::make_unique<Shape>(ValueOrThrow(ParseShape(s)));
166       }))
167       .def_static(
168           "tuple_shape",
169           [](std::vector<Shape> shapes) -> Shape {
170             return ShapeUtil::MakeTupleShape(shapes);
171           },
172           "Constructs a tuple shape.")
173       .def_static(
174           "array_shape",
175           [](PrimitiveType type, py::object dims_seq,
176              absl::optional<py::object> layout_seq) -> StatusOr<Shape> {
177             std::vector<int64> dims = IntSequenceToVector(dims_seq);
178             if (layout_seq) {
179               std::vector<int64> layout = IntSequenceToVector(*layout_seq);
180               return MakeShapeWithLayout(type, dims, layout);
181             } else {
182               return MakeShapeWithLayout(type, dims, absl::nullopt);
183             }
184           },
185           "Constructs an array shape.", py::arg("type"), py::arg("dims"),
186           py::arg("layout") = absl::nullopt)
187       .def_static(
188           "array_shape",
189           [](py::dtype dtype, py::object dims_seq,
190              absl::optional<py::object> layout_seq) -> StatusOr<Shape> {
191             PrimitiveType type = ValueOrThrow(DtypeToPrimitiveType(dtype));
192             std::vector<int64> dims = IntSequenceToVector(dims_seq);
193             if (layout_seq) {
194               std::vector<int64> layout = IntSequenceToVector(*layout_seq);
195               return MakeShapeWithLayout(type, dims, layout);
196             } else {
197               return MakeShapeWithLayout(type, dims, absl::nullopt);
198             }
199           },
200           "Constructs an array shape.", py::arg("type"), py::arg("dims"),
201           py::arg("layout") = absl::nullopt)
202       .def_static("token_shape", []() { return ShapeUtil::MakeTokenShape(); })
203       .def("dimensions",
204            [](const Shape& shape) -> py::tuple {
205              return IntSpanToTuple(shape.dimensions());
206            })
207       .def("xla_element_type", &Shape::element_type)
208       .def("element_type",
209            [](const Shape& shape) {
210              return ValueOrThrow(PrimitiveTypeToDtype(shape.element_type()));
211            })
212       .def("numpy_dtype",
213            [](const Shape& shape) {
214              if (shape.IsTuple()) {
215                return py::dtype("O");
216              }
217              return ValueOrThrow(PrimitiveTypeToDtype(shape.element_type()));
218            })
219       .def("is_tuple", &Shape::IsTuple)
220       .def("is_array", &Shape::IsArray)
221       .def("rank", &Shape::rank)
222       .def("to_serialized_proto",
223            [](const Shape& shape) {
224              ShapeProto proto = shape.ToProto();
225              return py::bytes(proto.SerializeAsString());
226            })
227       .def("tuple_shapes",
228            [](const Shape& shape) {
229              return std::vector<Shape>(shape.tuple_shapes());
230            })
231       .def("leaf_count",
232            [](const Shape& shape) { return ShapeUtil::GetLeafCount(shape); })
233       .def(
234           "with_major_to_minor_layout_if_absent",
235           [](const Shape& shape) {
236             Shape out = shape;
237             ShapeUtil::ForEachMutableSubshape(
238                 &out, [](Shape* subshape, const ShapeIndex&) {
239                   if (!subshape->has_layout()) {
240                     LayoutUtil::SetToDefaultLayout(subshape);
241                   }
242                 });
243             return out;
244           },
245           "Returns a copy of a shape with missing layouts set to "
246           "major-to-minor.")
247       .def("__eq__", [](const Shape& shape,
248                         const Shape& other) { return shape == other; })
249       .def("__ne__", [](const Shape& shape,
250                         const Shape& other) { return shape != other; })
251       .def("__hash__",
252            [](const Shape& shape) { return absl::Hash<Shape>()(shape); })
253       .def("__repr__", [](const Shape& shape) {
254         return shape.ToString(/*print_layout=*/true);
255       });
256 
257   py::class_<ProgramShape>(m, "ProgramShape")
258       .def(py::init(
259           [](absl::Span<const Shape> params, Shape result) -> ProgramShape {
260             ProgramShape program_shape;
261             for (const Shape& param : params) {
262               *program_shape.add_parameters() = param;
263             }
264             *program_shape.mutable_result() = result;
265             return program_shape;
266           }))
267       .def("parameter_shapes",
268            static_cast<const std::vector<Shape>& (ProgramShape::*)() const>(
269                &ProgramShape::parameters))
270       .def("result_shape", &ProgramShape::result)
271       .def("__repr__", &ProgramShape::ToString);
272 
273   // Literals
274   py::class_<Literal, std::shared_ptr<Literal>>(m, "Literal")
275       .def("__repr__", &Literal::ToString);
276 
277   py::class_<XlaComputation>(m, "XlaComputation")
278       .def(py::init([](const py::bytes& serialized_hlo_module_proto)
279                         -> std::unique_ptr<XlaComputation> {
280         HloModuleProto proto;
281         proto.ParseFromString(std::string(serialized_hlo_module_proto));
282         return absl::make_unique<XlaComputation>(proto);
283       }))
284       .def("get_hlo_module", &GetHloModule)
285       .def("program_shape", &XlaComputation::GetProgramShape)
286       .def("as_serialized_hlo_module_proto", &GetComputationSerializedProto)
287       .def("as_hlo_text", &GetComputationHloText)
288       .def("as_hlo_dot_graph", &GetComputationHloDotGraph)
289       .def("hash", &HashComputation)
290       .def("as_hlo_module", &GetHloModule);
291 
292   py::class_<HloPrintOptions> hlo_print_options_class(m, "HloPrintOptions");
293   hlo_print_options_class.def(py::init<>())
294       .def_static("short_parsable", &HloPrintOptions::ShortParsable)
295       .def_static("canonical", &HloPrintOptions::Canonical)
296       .def_static("fingerprint", &HloPrintOptions::Fingerprint)
297       .def_property("print_large_constants",
298                     &HloPrintOptions::print_large_constants,
299                     &HloPrintOptions::set_print_large_constants)
300       .def_property("print_metadata", &HloPrintOptions::print_metadata,
301                     &HloPrintOptions::set_print_metadata)
302       .def_property("print_backend_config",
303                     &HloPrintOptions::print_backend_config,
304                     &HloPrintOptions::set_print_backend_config)
305       .def_property("print_result_shape", &HloPrintOptions::print_result_shape,
306                     &HloPrintOptions::set_print_result_shape)
307       .def_property("print_operand_shape",
308                     &HloPrintOptions::print_operand_shape,
309                     &HloPrintOptions::set_print_operand_shape)
310       .def_property("print_operand_names",
311                     &HloPrintOptions::print_operand_names,
312                     &HloPrintOptions::set_print_operand_names)
313       .def_property("print_ids", &HloPrintOptions::print_ids,
314                     &HloPrintOptions::set_print_ids)
315       .def_property("print_extra_attributes",
316                     &HloPrintOptions::print_extra_attributes,
317                     &HloPrintOptions::set_print_extra_attributes)
318       .def_property("print_program_shape",
319                     &HloPrintOptions::print_program_shape,
320                     &HloPrintOptions::set_print_program_shape)
321       .def_property("print_percent", &HloPrintOptions::print_percent,
322                     &HloPrintOptions::set_print_percent)
323       .def_property("print_control_dependencies",
324                     &HloPrintOptions::print_control_dependencies,
325                     &HloPrintOptions::set_print_control_dependencies)
326       .def_property("compact_operands", &HloPrintOptions::compact_operands,
327                     &HloPrintOptions::set_compact_operands)
328       .def_property("include_layout_in_shapes",
329                     &HloPrintOptions::include_layout_in_shapes,
330                     &HloPrintOptions::set_include_layout_in_shapes)
331       .def_property("canonicalize_instruction_names",
332                     &HloPrintOptions::canonicalize_instruction_names,
333                     &HloPrintOptions::set_canonicalize_instruction_names)
334       .def_property("canonicalize_computations",
335                     &HloPrintOptions::canonicalize_computations,
336                     &HloPrintOptions::set_canonicalize_computations)
337       .def_property("indent_amount", &HloPrintOptions::indent_amount,
338                     &HloPrintOptions::set_indent_amount)
339       .def_property("is_in_nested_computation",
340                     &HloPrintOptions::is_in_nested_computation,
341                     &HloPrintOptions::set_is_in_nested_computation)
342       .def_property(
343           "leading_and_trailing_instructions_number",
344           &HloPrintOptions::leading_and_trailing_instructions_number,
345           &HloPrintOptions::set_leading_and_trailing_instructions_number);
346 
347   py::class_<HloModule, std::shared_ptr<HloModule>> hlo_module_class(
348       m, "HloModule");
349   hlo_module_class.def(
350       "to_string",
351       static_cast<std::string (HloModule::*)(const HloPrintOptions&) const>(
352           &HloModule::ToString),
353       py::arg("options") = HloPrintOptions());
354 
355   m.def("hlo_module_to_dot_graph",
356         [](const HloModule& hlo_module) -> StatusOr<std::string> {
357           return RenderGraph(*hlo_module.entry_computation(), /*label=*/"",
358                              hlo_module.config().debug_options(),
359                              RenderedGraphFormat::kDot);
360         });
361   m.def(
362       "hlo_module_cost_analysis",
363       [](PyClient* client,
364          const HloModule& module) -> StatusOr<std::map<string, float>> {
365         TF_ASSIGN_OR_RETURN(auto analysis,
366                             client->pjrt_client()->GetHloCostAnalysis());
367         TF_RETURN_IF_ERROR(module.entry_computation()->Accept(analysis.get()));
368         return analysis->properties();
369       });
370 
371   py::class_<XlaOp> xla_op_class(m, "XlaOp");
372 
373   py::class_<XlaBuilder>(m, "XlaBuilder")
374       .def(py::init([](const std::string& name) -> std::unique_ptr<XlaBuilder> {
375         return absl::make_unique<XlaBuilder>(UniquifyName(name));
376       }))
377       // TODO(phawkins): delete capitalized names after updating callers.
378       .def(
379           "Build",
380           [](XlaBuilder& builder, absl::optional<XlaOp> root) {
381             return root ? builder.Build(*root) : builder.Build();
382           },
383           "Builds a computation from the contents of the builder.",
384           py::arg("root") = absl::nullopt)
385       .def("GetShape", &XlaBuilder::GetShape)
386       .def(
387           "build",
388           [](XlaBuilder& builder, absl::optional<XlaOp> root) {
389             return root ? builder.Build(*root) : builder.Build();
390           },
391           "Builds a computation from the contents of the builder.",
392           py::arg("root") = absl::nullopt)
393       .def("clear_op_metadata", &XlaBuilder::ClearOpMetadata)
394       .def("get_shape", &XlaBuilder::GetShape)
395       .def(
396           "get_program_shape",
397           [](const XlaBuilder& builder,
398              absl::optional<XlaOp> root) -> StatusOr<ProgramShape> {
399             return root ? builder.GetProgramShape(*root)
400                         : builder.GetProgramShape();
401           },
402           py::arg("root") = absl::nullopt)
403       .def("is_constant", &XlaBuilder::IsConstant)
404       .def("set_op_metadata", &XlaBuilder::SetOpMetadata)
405       .def("set_sharding", &XlaBuilder::SetSharding)
406       .def("clear_sharding", &XlaBuilder::ClearSharding)
407       .def("setup_alias",
408            [](XlaBuilder& builder, const std::vector<int64>& output_index,
409               int64 param_number, const std::vector<int64>& param_index) {
410              builder.SetUpAlias(
411                  ShapeIndex(output_index.begin(), output_index.end()),
412                  param_number,
413                  ShapeIndex(param_index.begin(), param_index.end()));
414            });
415 
416   // Device assignments
417   py::class_<DeviceAssignment>(m, "DeviceAssignment")
418       .def_static("create",
419                   [](py::array_t<int> array) -> StatusOr<DeviceAssignment> {
420                     if (array.ndim() != 2) {
421                       return InvalidArgument(
422                           "Argument to DeviceAssignment constructor must be a "
423                           "2D array, received an %dD array.",
424                           array.ndim());
425                     }
426                     DeviceAssignment result(array.shape(0), array.shape(1));
427                     for (int i = 0; i < array.shape(0); ++i) {
428                       for (int j = 0; j < array.shape(1); ++j) {
429                         result(i, j) = array.at(i, j);
430                       }
431                     }
432                     return result;
433                   })
434       .def("replica_count", &DeviceAssignment::replica_count)
435       .def("computation_count", &DeviceAssignment::computation_count)
436       .def("__repr__", &DeviceAssignment::ToString);
437 
438   py::class_<CompileOptions> compile_options(m, "CompileOptions");
439   compile_options
440       .def(py::init([]() -> CompileOptions {
441         CompileOptions options;
442         DebugOptions* debug_options =
443             options.executable_build_options.mutable_debug_options();
444         // Sets fast-math-disabling default options expected by JAX.
445         debug_options->set_xla_cpu_enable_fast_min_max(false);
446         debug_options->set_xla_gpu_enable_fast_min_max(false);
447         return options;
448       }))
449       .def_readwrite("argument_layouts", &CompileOptions::argument_layouts)
450       .def_readwrite("parameter_is_tupled_arguments",
451                      &CompileOptions::parameter_is_tupled_arguments)
452       .def_readonly("executable_build_options",
453                     &CompileOptions::executable_build_options)
454       // TODO(phawkins): the following fields exist for backward compatibility.
455       // Remove them after JAX has been updated not to use them.
456       .def_readwrite("tuple_arguments",
457                      &CompileOptions::parameter_is_tupled_arguments)
458       .def_property(
459           "num_replicas",
460           [](const CompileOptions& options) {
461             return options.executable_build_options.num_replicas();
462           },
463           [](CompileOptions& options, int num_replicas) {
464             options.executable_build_options.set_num_replicas(num_replicas);
465           })
466       .def_property(
467           "num_partitions",
468           [](const CompileOptions& options) {
469             return options.executable_build_options.num_partitions();
470           },
471           [](CompileOptions& options, int num_partitions) {
472             options.executable_build_options.set_num_partitions(num_partitions);
473           })
474       .def_property(
475           "device_assignment",
476           [](const CompileOptions& options)
477               -> absl::optional<DeviceAssignment> {
478             return options.executable_build_options.has_device_assignment()
479                        ? absl::optional<DeviceAssignment>(
480                              options.executable_build_options
481                                  .device_assignment())
482                        : absl::nullopt;
483           },
484           [](CompileOptions& options,
485              const DeviceAssignment& device_assignment) {
486             options.executable_build_options.set_device_assignment(
487                 device_assignment);
488           });
489 
490   // Custom-call targets.
491   m.def("register_custom_call_target", &PyRegisterCustomCallTarget);
492 
493   py::class_<DebugOptions>(m, "DebugOptions")
494       .def("__repr__", &DebugOptions::DebugString)
495       .def_property("xla_cpu_enable_fast_math",
496                     &DebugOptions::xla_cpu_enable_fast_math,
497                     &DebugOptions::set_xla_cpu_enable_fast_math)
498       .def_property("xla_cpu_fast_math_honor_infs",
499                     &DebugOptions::xla_cpu_fast_math_honor_infs,
500                     &DebugOptions::set_xla_cpu_fast_math_honor_infs)
501       .def_property("xla_cpu_fast_math_honor_nans",
502                     &DebugOptions::xla_cpu_fast_math_honor_nans,
503                     &DebugOptions::set_xla_cpu_fast_math_honor_nans)
504       .def_property("xla_cpu_fast_math_honor_division",
505                     &DebugOptions::xla_cpu_fast_math_honor_division,
506                     &DebugOptions::set_xla_cpu_fast_math_honor_division)
507       .def_property("xla_cpu_fast_math_honor_functions",
508                     &DebugOptions::xla_cpu_fast_math_honor_functions,
509                     &DebugOptions::set_xla_cpu_fast_math_honor_functions)
510       .def_property("xla_gpu_enable_fast_min_max",
511                     &DebugOptions::xla_gpu_enable_fast_min_max,
512                     &DebugOptions::set_xla_gpu_enable_fast_min_max)
513       .def_property("xla_backend_optimization_level",
514                     &DebugOptions::xla_backend_optimization_level,
515                     &DebugOptions::set_xla_backend_optimization_level)
516       .def_property("xla_cpu_enable_xprof_traceme",
517                     &DebugOptions::xla_cpu_enable_xprof_traceme,
518                     &DebugOptions::set_xla_cpu_enable_xprof_traceme)
519       .def_property("xla_llvm_disable_expensive_passes",
520                     &DebugOptions::xla_llvm_disable_expensive_passes,
521                     &DebugOptions::set_xla_llvm_disable_expensive_passes)
522       .def_property("xla_test_all_input_layouts",
523                     &DebugOptions::xla_test_all_input_layouts,
524                     &DebugOptions::set_xla_test_all_input_layouts);
525 
526   py::class_<ExecutableBuildOptions>(m, "ExecutableBuildOptions")
527       .def(py::init<>())
528       .def("__repr__", &ExecutableBuildOptions::ToString)
529       .def_property(
530           "result_layout",
531           [](const ExecutableBuildOptions& options) -> absl::optional<Shape> {
532             return options.result_layout()
533                        ? absl::optional<Shape>(*options.result_layout())
534                        : absl::nullopt;
535           },
536           &ExecutableBuildOptions::set_result_layout)
537       .def_property("num_replicas", &ExecutableBuildOptions::num_replicas,
538                     &ExecutableBuildOptions::set_num_replicas)
539       .def_property("num_partitions", &ExecutableBuildOptions::num_partitions,
540                     &ExecutableBuildOptions::set_num_partitions)
541       .def_property_readonly(
542           "debug_options", &ExecutableBuildOptions::mutable_debug_options,
543           py::return_value_policy::reference, py::keep_alive<1, 0>())
544       .def_property(
545           "device_assignment",
546           [](const ExecutableBuildOptions& options)
547               -> absl::optional<DeviceAssignment> {
548             return options.has_device_assignment()
549                        ? absl::optional<DeviceAssignment>(
550                              options.device_assignment())
551                        : absl::nullopt;
552           },
553           &ExecutableBuildOptions::set_device_assignment)
554       .def_property("use_spmd_partitioning",
555                     &ExecutableBuildOptions::use_spmd_partitioning,
556                     &ExecutableBuildOptions::set_use_spmd_partitioning);
557 
558   py::enum_<PrecisionConfig::Precision>(m, "PrecisionConfig_Precision")
559       .value("DEFAULT", PrecisionConfig::DEFAULT)
560       .value("HIGH", PrecisionConfig::HIGH)
561       .value("HIGHEST", PrecisionConfig::HIGHEST);
562 
563   py::enum_<OpSharding::Type>(m, "OpSharding_Type")
564       .value("REPLICATED", OpSharding::REPLICATED)
565       .value("MAXIMAL", OpSharding::MAXIMAL)
566       .value("TUPLE", OpSharding::TUPLE)
567       .value("OTHER", OpSharding::OTHER);
568 
569   py::enum_<ChannelHandle::ChannelType>(m, "ChannelHandle_ChannelType")
570       .value("CHANNEL_TYPE_INVALID", ChannelHandle::CHANNEL_TYPE_INVALID)
571       .value("DEVICE_TO_DEVICE", ChannelHandle::DEVICE_TO_DEVICE)
572       .value("DEVICE_TO_HOST", ChannelHandle::DEVICE_TO_HOST)
573       .value("HOST_TO_DEVICE", ChannelHandle::HOST_TO_DEVICE);
574 
575   py::class_<ChannelHandle>(m, "ChannelHandle")
576       .def_property_readonly("type", &ChannelHandle::type)
577       .def_property_readonly("handle", &ChannelHandle::handle)
578       .def("__repr__", [](ChannelHandle* h) { return h->DebugString(); });
579 
580   py::enum_<FftType>(m, "FftType")
581       .value("FFT", FftType::FFT)
582       .value("IFFT", FftType::IFFT)
583       .value("RFFT", FftType::RFFT)
584       .value("IRFFT", FftType::IRFFT);
585 }
586 }  // namespace xla
587