/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/xla_compiler.h"

// NOTE(review): the names of the system headers were lost from this file; the
// list below is the set of standard headers the code visibly relies on
// (std::map, std::shared_ptr/std::unique_ptr, std::string, std::move,
// std::vector). Confirm against the upstream file.
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/hash/hash.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "pybind11/attr.h"
#include "pybind11/cast.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl_bind.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/python/py_client.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace {

namespace py = pybind11;

// Process-wide name uniquer together with the mutex that guards it.
struct Uniquer {
  absl::Mutex mu;
  NameUniquer name_uniquer TF_GUARDED_BY(mu);
};

// Returns the single Uniquer instance. It is heap-allocated on first use and
// never deleted.
Uniquer* GetUniquer() {
  static Uniquer* uniquer = new Uniquer;
  return uniquer;
}

// Returns a unique variant of `name`, obtained from the shared NameUniquer
// while holding its mutex.
static std::string UniquifyName(const std::string& name) {
  Uniquer* uniquer = GetUniquer();
  absl::MutexLock lock(&uniquer->mu);
  return uniquer->name_uniquer.GetUniqueName(name);
}

// Converts a computation to a serialized HloModuleProto.
// Returns an Unknown error if protobuf serialization fails.
StatusOr<py::bytes> GetComputationSerializedProto(
    const XlaComputation& computation) {
  std::string result;
  if (!computation.proto().SerializeToString(&result)) {
    return Unknown("Failed to serialize the HloModuleProto.");
  }
  return py::bytes(result);
}

// Builds an HloModule from the computation's proto, using a module config
// created from that proto and the debug options taken from flags. Errors from
// either creation step are propagated to the caller.
StatusOr<std::shared_ptr<HloModule>> GetHloModule(
    const XlaComputation& computation) {
  TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config,
                      HloModule::CreateModuleConfigFromProto(
                          computation.proto(), GetDebugOptionsFromFlags()));
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModule> module,
      HloModule::CreateFromProto(computation.proto(), module_config));
  // The result is handed out as a shared_ptr, so ownership is transferred
  // from the unique_ptr here.
  return std::shared_ptr<HloModule>(std::move(module));
}

// Converts a computation to textual HLO form.
// Uses the "short parsable" print options with large constants elided.
StatusOr<std::string> GetComputationHloText(const XlaComputation& computation) {
  TF_ASSIGN_OR_RETURN(std::shared_ptr<HloModule> hlo_module,
                      GetHloModule(computation));
  HloPrintOptions options = HloPrintOptions::ShortParsable();
  options.set_print_large_constants(false);
  return hlo_module->ToString(options);
}

// Converts a computation to HLO dot graph form.
// Renders the module's entry computation with an empty label.
StatusOr<std::string> GetComputationHloDotGraph(
    const XlaComputation& computation) {
  TF_ASSIGN_OR_RETURN(std::shared_ptr<HloModule> hlo_module,
                      GetHloModule(computation));
  return RenderGraph(*hlo_module->entry_computation(), /*label=*/"",
                     hlo_module->config().debug_options(),
                     RenderedGraphFormat::kDot);
}

// Hashes the HLO module.
StatusOr HashComputation(const XlaComputation& computation) { TF_ASSIGN_OR_RETURN(std::shared_ptr hlo_module, GetHloModule(computation)); return hlo_module->Hash(); } // Safe version of ShapeUtil::MakeShapeWithLayout that fails gracefully on // invalid input. StatusOr MakeShapeWithLayout( PrimitiveType element_type, absl::Span dims, absl::optional> minor_to_major) { TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::MakeValidatedShape(element_type, dims)); if (minor_to_major) { *shape.mutable_layout() = LayoutUtil::MakeLayout(*minor_to_major); TF_RETURN_IF_ERROR( LayoutUtil::ValidateLayoutForShape(shape.layout(), shape)); } else { shape.clear_layout(); } return shape; } // Registers a 'fn_capsule' as a CPU custom call target. // 'fn_capsule' must be a void* pointer encapsulated in a PyCapsule object, // with name "xla._CUSTOM_CALL_TARGET". // 'platform' is an XLA platform name, e.g., "Host" or "CUDA". Status PyRegisterCustomCallTarget(const std::string& fn_name, py::capsule capsule, const std::string& platform) { static const char* const kName = "xla._CUSTOM_CALL_TARGET"; // TODO(phawkins): remove old name after fixing users. 
static const char* const kOldCpuName = "xla._CPU_CUSTOM_CALL_TARGET"; if (absl::string_view(capsule.name()) != kName && absl::string_view(capsule.name()) != kOldCpuName) { return InvalidArgument( "Argument to RegisterCustomCallTargetRegistry was not a " "xla._CUSTOM_CALL_TARGET capsule."); } CustomCallTargetRegistry::Global()->Register( fn_name, static_cast(capsule), platform); return Status::OK(); } } // namespace void BuildXlaCompilerSubmodule(py::module& m) { // Shapes py::class_ shape_class(m, "Shape"); shape_class .def(py::init([](const string& s) { return absl::make_unique(ValueOrThrow(ParseShape(s))); })) .def_static( "tuple_shape", [](std::vector shapes) -> Shape { return ShapeUtil::MakeTupleShape(shapes); }, "Constructs a tuple shape.") .def_static( "array_shape", [](PrimitiveType type, py::object dims_seq, absl::optional layout_seq) -> StatusOr { std::vector dims = IntSequenceToVector(dims_seq); if (layout_seq) { std::vector layout = IntSequenceToVector(*layout_seq); return MakeShapeWithLayout(type, dims, layout); } else { return MakeShapeWithLayout(type, dims, absl::nullopt); } }, "Constructs an array shape.", py::arg("type"), py::arg("dims"), py::arg("layout") = absl::nullopt) .def_static( "array_shape", [](py::dtype dtype, py::object dims_seq, absl::optional layout_seq) -> StatusOr { PrimitiveType type = ValueOrThrow(DtypeToPrimitiveType(dtype)); std::vector dims = IntSequenceToVector(dims_seq); if (layout_seq) { std::vector layout = IntSequenceToVector(*layout_seq); return MakeShapeWithLayout(type, dims, layout); } else { return MakeShapeWithLayout(type, dims, absl::nullopt); } }, "Constructs an array shape.", py::arg("type"), py::arg("dims"), py::arg("layout") = absl::nullopt) .def_static("token_shape", []() { return ShapeUtil::MakeTokenShape(); }) .def("dimensions", [](const Shape& shape) -> py::tuple { return IntSpanToTuple(shape.dimensions()); }) .def("xla_element_type", &Shape::element_type) .def("element_type", [](const Shape& shape) { return 
ValueOrThrow(PrimitiveTypeToDtype(shape.element_type())); }) .def("numpy_dtype", [](const Shape& shape) { if (shape.IsTuple()) { return py::dtype("O"); } return ValueOrThrow(PrimitiveTypeToDtype(shape.element_type())); }) .def("is_tuple", &Shape::IsTuple) .def("is_array", &Shape::IsArray) .def("rank", &Shape::rank) .def("to_serialized_proto", [](const Shape& shape) { ShapeProto proto = shape.ToProto(); return py::bytes(proto.SerializeAsString()); }) .def("tuple_shapes", [](const Shape& shape) { return std::vector(shape.tuple_shapes()); }) .def("leaf_count", [](const Shape& shape) { return ShapeUtil::GetLeafCount(shape); }) .def( "with_major_to_minor_layout_if_absent", [](const Shape& shape) { Shape out = shape; ShapeUtil::ForEachMutableSubshape( &out, [](Shape* subshape, const ShapeIndex&) { if (!subshape->has_layout()) { LayoutUtil::SetToDefaultLayout(subshape); } }); return out; }, "Returns a copy of a shape with missing layouts set to " "major-to-minor.") .def("__eq__", [](const Shape& shape, const Shape& other) { return shape == other; }) .def("__ne__", [](const Shape& shape, const Shape& other) { return shape != other; }) .def("__hash__", [](const Shape& shape) { return absl::Hash()(shape); }) .def("__repr__", [](const Shape& shape) { return shape.ToString(/*print_layout=*/true); }); py::class_(m, "ProgramShape") .def(py::init( [](absl::Span params, Shape result) -> ProgramShape { ProgramShape program_shape; for (const Shape& param : params) { *program_shape.add_parameters() = param; } *program_shape.mutable_result() = result; return program_shape; })) .def("parameter_shapes", static_cast& (ProgramShape::*)() const>( &ProgramShape::parameters)) .def("result_shape", &ProgramShape::result) .def("__repr__", &ProgramShape::ToString); // Literals py::class_>(m, "Literal") .def("__repr__", &Literal::ToString); py::class_(m, "XlaComputation") .def(py::init([](const py::bytes& serialized_hlo_module_proto) -> std::unique_ptr { HloModuleProto proto; 
proto.ParseFromString(std::string(serialized_hlo_module_proto)); return absl::make_unique(proto); })) .def("get_hlo_module", &GetHloModule) .def("program_shape", &XlaComputation::GetProgramShape) .def("as_serialized_hlo_module_proto", &GetComputationSerializedProto) .def("as_hlo_text", &GetComputationHloText) .def("as_hlo_dot_graph", &GetComputationHloDotGraph) .def("hash", &HashComputation) .def("as_hlo_module", &GetHloModule); py::class_ hlo_print_options_class(m, "HloPrintOptions"); hlo_print_options_class.def(py::init<>()) .def_static("short_parsable", &HloPrintOptions::ShortParsable) .def_static("canonical", &HloPrintOptions::Canonical) .def_static("fingerprint", &HloPrintOptions::Fingerprint) .def_property("print_large_constants", &HloPrintOptions::print_large_constants, &HloPrintOptions::set_print_large_constants) .def_property("print_metadata", &HloPrintOptions::print_metadata, &HloPrintOptions::set_print_metadata) .def_property("print_backend_config", &HloPrintOptions::print_backend_config, &HloPrintOptions::set_print_backend_config) .def_property("print_result_shape", &HloPrintOptions::print_result_shape, &HloPrintOptions::set_print_result_shape) .def_property("print_operand_shape", &HloPrintOptions::print_operand_shape, &HloPrintOptions::set_print_operand_shape) .def_property("print_operand_names", &HloPrintOptions::print_operand_names, &HloPrintOptions::set_print_operand_names) .def_property("print_ids", &HloPrintOptions::print_ids, &HloPrintOptions::set_print_ids) .def_property("print_extra_attributes", &HloPrintOptions::print_extra_attributes, &HloPrintOptions::set_print_extra_attributes) .def_property("print_program_shape", &HloPrintOptions::print_program_shape, &HloPrintOptions::set_print_program_shape) .def_property("print_percent", &HloPrintOptions::print_percent, &HloPrintOptions::set_print_percent) .def_property("print_control_dependencies", &HloPrintOptions::print_control_dependencies, &HloPrintOptions::set_print_control_dependencies) 
.def_property("compact_operands", &HloPrintOptions::compact_operands, &HloPrintOptions::set_compact_operands) .def_property("include_layout_in_shapes", &HloPrintOptions::include_layout_in_shapes, &HloPrintOptions::set_include_layout_in_shapes) .def_property("canonicalize_instruction_names", &HloPrintOptions::canonicalize_instruction_names, &HloPrintOptions::set_canonicalize_instruction_names) .def_property("canonicalize_computations", &HloPrintOptions::canonicalize_computations, &HloPrintOptions::set_canonicalize_computations) .def_property("indent_amount", &HloPrintOptions::indent_amount, &HloPrintOptions::set_indent_amount) .def_property("is_in_nested_computation", &HloPrintOptions::is_in_nested_computation, &HloPrintOptions::set_is_in_nested_computation) .def_property( "leading_and_trailing_instructions_number", &HloPrintOptions::leading_and_trailing_instructions_number, &HloPrintOptions::set_leading_and_trailing_instructions_number); py::class_> hlo_module_class( m, "HloModule"); hlo_module_class.def( "to_string", static_cast( &HloModule::ToString), py::arg("options") = HloPrintOptions()); m.def("hlo_module_to_dot_graph", [](const HloModule& hlo_module) -> StatusOr { return RenderGraph(*hlo_module.entry_computation(), /*label=*/"", hlo_module.config().debug_options(), RenderedGraphFormat::kDot); }); m.def( "hlo_module_cost_analysis", [](PyClient* client, const HloModule& module) -> StatusOr> { TF_ASSIGN_OR_RETURN(auto analysis, client->pjrt_client()->GetHloCostAnalysis()); TF_RETURN_IF_ERROR(module.entry_computation()->Accept(analysis.get())); return analysis->properties(); }); py::class_ xla_op_class(m, "XlaOp"); py::class_(m, "XlaBuilder") .def(py::init([](const std::string& name) -> std::unique_ptr { return absl::make_unique(UniquifyName(name)); })) // TODO(phawkins): delete capitalized names after updating callers. .def( "Build", [](XlaBuilder& builder, absl::optional root) { return root ? 
builder.Build(*root) : builder.Build(); }, "Builds a computation from the contents of the builder.", py::arg("root") = absl::nullopt) .def("GetShape", &XlaBuilder::GetShape) .def( "build", [](XlaBuilder& builder, absl::optional root) { return root ? builder.Build(*root) : builder.Build(); }, "Builds a computation from the contents of the builder.", py::arg("root") = absl::nullopt) .def("clear_op_metadata", &XlaBuilder::ClearOpMetadata) .def("get_shape", &XlaBuilder::GetShape) .def( "get_program_shape", [](const XlaBuilder& builder, absl::optional root) -> StatusOr { return root ? builder.GetProgramShape(*root) : builder.GetProgramShape(); }, py::arg("root") = absl::nullopt) .def("is_constant", &XlaBuilder::IsConstant) .def("set_op_metadata", &XlaBuilder::SetOpMetadata) .def("set_sharding", &XlaBuilder::SetSharding) .def("clear_sharding", &XlaBuilder::ClearSharding) .def("setup_alias", [](XlaBuilder& builder, const std::vector& output_index, int64 param_number, const std::vector& param_index) { builder.SetUpAlias( ShapeIndex(output_index.begin(), output_index.end()), param_number, ShapeIndex(param_index.begin(), param_index.end())); }); // Device assignments py::class_(m, "DeviceAssignment") .def_static("create", [](py::array_t array) -> StatusOr { if (array.ndim() != 2) { return InvalidArgument( "Argument to DeviceAssignment constructor must be a " "2D array, received an %dD array.", array.ndim()); } DeviceAssignment result(array.shape(0), array.shape(1)); for (int i = 0; i < array.shape(0); ++i) { for (int j = 0; j < array.shape(1); ++j) { result(i, j) = array.at(i, j); } } return result; }) .def("replica_count", &DeviceAssignment::replica_count) .def("computation_count", &DeviceAssignment::computation_count) .def("__repr__", &DeviceAssignment::ToString); py::class_ compile_options(m, "CompileOptions"); compile_options .def(py::init([]() -> CompileOptions { CompileOptions options; DebugOptions* debug_options = 
options.executable_build_options.mutable_debug_options(); // Sets fast-math-disabling default options expected by JAX. debug_options->set_xla_cpu_enable_fast_min_max(false); debug_options->set_xla_gpu_enable_fast_min_max(false); return options; })) .def_readwrite("argument_layouts", &CompileOptions::argument_layouts) .def_readwrite("parameter_is_tupled_arguments", &CompileOptions::parameter_is_tupled_arguments) .def_readonly("executable_build_options", &CompileOptions::executable_build_options) // TODO(phawkins): the following fields exist for backward compatibility. // Remove them after JAX has been updated not to use them. .def_readwrite("tuple_arguments", &CompileOptions::parameter_is_tupled_arguments) .def_property( "num_replicas", [](const CompileOptions& options) { return options.executable_build_options.num_replicas(); }, [](CompileOptions& options, int num_replicas) { options.executable_build_options.set_num_replicas(num_replicas); }) .def_property( "num_partitions", [](const CompileOptions& options) { return options.executable_build_options.num_partitions(); }, [](CompileOptions& options, int num_partitions) { options.executable_build_options.set_num_partitions(num_partitions); }) .def_property( "device_assignment", [](const CompileOptions& options) -> absl::optional { return options.executable_build_options.has_device_assignment() ? absl::optional( options.executable_build_options .device_assignment()) : absl::nullopt; }, [](CompileOptions& options, const DeviceAssignment& device_assignment) { options.executable_build_options.set_device_assignment( device_assignment); }); // Custom-call targets. 
m.def("register_custom_call_target", &PyRegisterCustomCallTarget); py::class_(m, "DebugOptions") .def("__repr__", &DebugOptions::DebugString) .def_property("xla_cpu_enable_fast_math", &DebugOptions::xla_cpu_enable_fast_math, &DebugOptions::set_xla_cpu_enable_fast_math) .def_property("xla_cpu_fast_math_honor_infs", &DebugOptions::xla_cpu_fast_math_honor_infs, &DebugOptions::set_xla_cpu_fast_math_honor_infs) .def_property("xla_cpu_fast_math_honor_nans", &DebugOptions::xla_cpu_fast_math_honor_nans, &DebugOptions::set_xla_cpu_fast_math_honor_nans) .def_property("xla_cpu_fast_math_honor_division", &DebugOptions::xla_cpu_fast_math_honor_division, &DebugOptions::set_xla_cpu_fast_math_honor_division) .def_property("xla_cpu_fast_math_honor_functions", &DebugOptions::xla_cpu_fast_math_honor_functions, &DebugOptions::set_xla_cpu_fast_math_honor_functions) .def_property("xla_gpu_enable_fast_min_max", &DebugOptions::xla_gpu_enable_fast_min_max, &DebugOptions::set_xla_gpu_enable_fast_min_max) .def_property("xla_backend_optimization_level", &DebugOptions::xla_backend_optimization_level, &DebugOptions::set_xla_backend_optimization_level) .def_property("xla_cpu_enable_xprof_traceme", &DebugOptions::xla_cpu_enable_xprof_traceme, &DebugOptions::set_xla_cpu_enable_xprof_traceme) .def_property("xla_llvm_disable_expensive_passes", &DebugOptions::xla_llvm_disable_expensive_passes, &DebugOptions::set_xla_llvm_disable_expensive_passes) .def_property("xla_test_all_input_layouts", &DebugOptions::xla_test_all_input_layouts, &DebugOptions::set_xla_test_all_input_layouts); py::class_(m, "ExecutableBuildOptions") .def(py::init<>()) .def("__repr__", &ExecutableBuildOptions::ToString) .def_property( "result_layout", [](const ExecutableBuildOptions& options) -> absl::optional { return options.result_layout() ? 
absl::optional(*options.result_layout()) : absl::nullopt; }, &ExecutableBuildOptions::set_result_layout) .def_property("num_replicas", &ExecutableBuildOptions::num_replicas, &ExecutableBuildOptions::set_num_replicas) .def_property("num_partitions", &ExecutableBuildOptions::num_partitions, &ExecutableBuildOptions::set_num_partitions) .def_property_readonly( "debug_options", &ExecutableBuildOptions::mutable_debug_options, py::return_value_policy::reference, py::keep_alive<1, 0>()) .def_property( "device_assignment", [](const ExecutableBuildOptions& options) -> absl::optional { return options.has_device_assignment() ? absl::optional( options.device_assignment()) : absl::nullopt; }, &ExecutableBuildOptions::set_device_assignment) .def_property("use_spmd_partitioning", &ExecutableBuildOptions::use_spmd_partitioning, &ExecutableBuildOptions::set_use_spmd_partitioning); py::enum_(m, "PrecisionConfig_Precision") .value("DEFAULT", PrecisionConfig::DEFAULT) .value("HIGH", PrecisionConfig::HIGH) .value("HIGHEST", PrecisionConfig::HIGHEST); py::enum_(m, "OpSharding_Type") .value("REPLICATED", OpSharding::REPLICATED) .value("MAXIMAL", OpSharding::MAXIMAL) .value("TUPLE", OpSharding::TUPLE) .value("OTHER", OpSharding::OTHER); py::enum_(m, "ChannelHandle_ChannelType") .value("CHANNEL_TYPE_INVALID", ChannelHandle::CHANNEL_TYPE_INVALID) .value("DEVICE_TO_DEVICE", ChannelHandle::DEVICE_TO_DEVICE) .value("DEVICE_TO_HOST", ChannelHandle::DEVICE_TO_HOST) .value("HOST_TO_DEVICE", ChannelHandle::HOST_TO_DEVICE); py::class_(m, "ChannelHandle") .def_property_readonly("type", &ChannelHandle::type) .def_property_readonly("handle", &ChannelHandle::handle) .def("__repr__", [](ChannelHandle* h) { return h->DebugString(); }); py::enum_(m, "FftType") .value("FFT", FftType::FFT) .value("IFFT", FftType::IFFT) .value("RFFT", FftType::RFFT) .value("IRFFT", FftType::IRFFT); } } // namespace xla