/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_

#include <memory>
#include <string>
#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/tf2xla/xla_argument.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {

// Populates the supplied passmanager with the passes required to run the
// TF MLIR to XLA HLO MLIR conversion/legalization. Custom legalization passes
// can be populated in `custom_legalization_passes`.
39 void CreateConvertMlirToXlaHloPipeline( 40 mlir::OpPassManager& pm, llvm::StringRef device_type, 41 llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>> 42 custom_legalization_passes); 43 44 // Lowers MLIR module to XLA HLO inside an XlaComputation. The input module 45 // should only contain operations in tf dialect. If the input module contains 46 // operation in the tf_executor dialect, for example, returns an error. 47 // Exception to this are tf_executor dialect ops that are optimized away through 48 // canonicalization. 49 // 50 // Operations in tf dialect are lowered to XLA HLO through the following steps: 51 // . Legalizes control flow operations. 52 // . Decomposes compound resource operations so that the only remaining 53 // operations on resource variables are resource reads/writes.. 54 // . Replaces resource reads/writes with function inputs/outputs and 55 // eliminates the use of resource variables. 56 // . Legalizes the operations to XLA HLO operations. 57 // . Canonicalizes the XLA HLO operations. 58 // 59 // device_type: XLA JIT device to use for compilation such as "XLA_CPU_JIT", 60 // "XLA_GPU_JIT" or "XLA_TPU_JIT". 61 // use_tuple_args: when this is true, always create a tuple argument for the 62 // entry computation. 63 // return_tuple: when this is true, always create a tuple result for the 64 // entry computation. 65 // shape_representation_fn: when this is set, this shape representation function 66 // will be used to determine argument and result shapes. Otherwise the 67 // original shape will be used as is. 68 // custom_legalization_passes: passes to run before the default TF legalization 69 // passes for backend-specific ops. 
70 Status ConvertMLIRToXlaComputation( 71 mlir::ModuleOp module_op, llvm::StringRef device_type, 72 xla::XlaComputation* xla_computation, bool use_tuple_args, 73 bool return_tuple, 74 const XlaHelpers::ShapeRepresentationFn shape_representation_fn = nullptr, 75 llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>> 76 custom_legalization_passes = {}); 77 78 // Helper struct representing argument tensor or resource handle shapes. 79 struct TensorOrResourceShape { 80 TensorShape shape; 81 bool is_resource = false; 82 }; 83 84 // Refine MLIR types based on new shape information. 85 Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes, 86 mlir::ModuleOp module); 87 88 // Lower TF to MHLO and insert HLO into the XlaBuilder. xla_params are HLO-level 89 // inputs to module_op that have already been added to the XlaBuilder. returns 90 // are the returned XlaOps. 91 Status BuildHloFromTf(mlir::ModuleOp module_op, xla::XlaBuilder& builder, 92 llvm::ArrayRef<xla::XlaOp> xla_params, 93 std::vector<xla::XlaOp>& returns, 94 llvm::ArrayRef<TensorOrResourceShape> arg_shapes, 95 llvm::StringRef device_type, 96 llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>> 97 custom_legalization_passes); 98 99 // Apply shape, description, and resource information to inputs and outputs 100 // in the XlaCompilationResult. This should be called after 101 // compilation_result->computation was set. 102 Status PopulateResultIOInfo( 103 mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes, 104 bool use_tuple_args, bool use_resource_updates_for_aliases, 105 XlaHelpers::ShapeRepresentationFn shape_representation_fn, 106 XlaCompilationResult* compilation_result); 107 108 // Compiles a MLIR module into XLA HLO, generates all accompanying metadata and 109 // stores them in CompilationResult. 110 // TODO(hinsu): Migrate options to separate struct. 
111 Status CompileMlirToXlaHlo( 112 mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes, 113 llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple, 114 bool use_resource_updates_for_aliases, 115 XlaHelpers::ShapeRepresentationFn shape_representation_fn, 116 XlaCompilationResult* compilation_result, 117 llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>> 118 custom_legalization_passes); 119 120 // Compiles a serialized MLIR module into XLA HLO, generates all accompanying 121 // metadata and stores them in CompilationResult. 122 Status CompileSerializedMlirToXlaHlo( 123 llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes, 124 llvm::StringRef device_type, bool use_tuple_args, 125 const XlaHelpers::ShapeRepresentationFn shape_representation_fn, 126 XlaCompilationResult* compilation_result, 127 llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>> 128 custom_legalization_passes = {}); 129 130 // Compiles a TensorFlow Graph (already converted to MLIR, imported with 131 // tf_executor dialect still present) into XLA HLO, generates all accompanying 132 // metadata and stores them in CompilationResult. This will rewrite arguments 133 // and run the TensorFlow standard pipeline prior to invoking 134 // `CompileMlirToXlaHlo`. 135 Status CompileGraphToXlaHlo( 136 mlir::ModuleOp module_op, llvm::ArrayRef<XlaArgument> args, 137 llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple, 138 const XlaHelpers::ShapeRepresentationFn shape_representation_fn, 139 XlaCompilationResult* compilation_result, 140 llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>> 141 custom_legalization_passes); 142 143 // Compiles a TensorFlow Graph into XLA HLO, generates all accompanying metadata 144 // and stores them in CompilationResult. 
145 Status CompileGraphToXlaHlo( 146 const Graph& graph, llvm::ArrayRef<XlaArgument> args, 147 llvm::ArrayRef<std::string> control_rets, llvm::StringRef device_type, 148 bool use_tuple_args, const FunctionLibraryDefinition& flib_def, 149 const GraphDebugInfo& debug_info, 150 const XlaHelpers::ShapeRepresentationFn shape_representation_fn, 151 XlaCompilationResult* compilation_result, 152 llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>> 153 custom_legalization_passes = {}); 154 155 // Compiles a Graph from TF to HLO and adds the resulting HLO to the 156 // XlaBuilder. This function adds HLO to a larger HLO computation, so 157 // HLO-level inputs are supplied, and HLO-level outputs are produced. 158 // xla_params is the HLO-level inputs and returns is the HLO-level outputs. 159 Status BuildHloFromGraph(const Graph& graph, xla::XlaBuilder& builder, 160 llvm::ArrayRef<xla::XlaOp> xla_params, 161 std::vector<xla::XlaOp>& returns, 162 llvm::ArrayRef<XlaArgument> args, 163 llvm::ArrayRef<std::string> control_rets, 164 llvm::StringRef device_type, 165 const FunctionLibraryDefinition& flib_def, 166 const GraphDebugInfo& debug_info, 167 llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>> 168 custom_legalization_passes = {}); 169 170 } // namespace tensorflow 171 172 #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_ 173