1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_
18 
#include <memory>
#include <string>
#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/tf2xla/xla_argument.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"
33 
34 namespace tensorflow {
35 
36 // Populates the supplied passmanager with the passes required to run the
37 // TF MLIR to XLA HLO MLIR conversion/legalization. Custom legalization passes
38 // can be populated in `custom_legalization_passes`.
39 void CreateConvertMlirToXlaHloPipeline(
40     mlir::OpPassManager& pm, llvm::StringRef device_type,
41     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
42         custom_legalization_passes);
43 
44 // Lowers MLIR module to XLA HLO inside an XlaComputation. The input module
45 // should only contain operations in tf dialect. If the input module contains
46 // operation in the tf_executor dialect, for example, returns an error.
47 // Exception to this are tf_executor dialect ops that are optimized away through
48 // canonicalization.
49 //
50 // Operations in tf dialect are lowered to XLA HLO through the following steps:
51 //   . Legalizes control flow operations.
52 //   . Decomposes compound resource operations so that the only remaining
53 //     operations on resource variables are resource reads/writes..
54 //   . Replaces resource reads/writes with function inputs/outputs and
55 //     eliminates the use of resource variables.
56 //   . Legalizes the operations to XLA HLO operations.
57 //   . Canonicalizes the XLA HLO operations.
58 //
59 // device_type: XLA JIT device to use for compilation such as "XLA_CPU_JIT",
60 //   "XLA_GPU_JIT" or "XLA_TPU_JIT".
61 // use_tuple_args: when this is true, always create a tuple argument for the
62 //   entry computation.
63 // return_tuple: when this is true, always create a tuple result for the
64 //   entry computation.
65 // shape_representation_fn: when this is set, this shape representation function
66 //   will be used to determine argument and result shapes. Otherwise the
67 //   original shape will be used as is.
68 // custom_legalization_passes: passes to run before the default TF legalization
69 //   passes for backend-specific ops.
70 Status ConvertMLIRToXlaComputation(
71     mlir::ModuleOp module_op, llvm::StringRef device_type,
72     xla::XlaComputation* xla_computation, bool use_tuple_args,
73     bool return_tuple,
74     const XlaHelpers::ShapeRepresentationFn shape_representation_fn = nullptr,
75     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
76         custom_legalization_passes = {});
77 
78 // Helper struct representing argument tensor or resource handle shapes.
79 struct TensorOrResourceShape {
80   TensorShape shape;
81   bool is_resource = false;
82 };
83 
84 // Refine MLIR types based on new shape information.
85 Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
86                     mlir::ModuleOp module);
87 
88 // Lower TF to MHLO and insert HLO into the XlaBuilder. xla_params are HLO-level
89 // inputs to module_op that have already been added to the XlaBuilder. returns
90 // are the returned XlaOps.
91 Status BuildHloFromTf(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
92                       llvm::ArrayRef<xla::XlaOp> xla_params,
93                       std::vector<xla::XlaOp>& returns,
94                       llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
95                       llvm::StringRef device_type,
96                       llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
97                           custom_legalization_passes);
98 
99 // Apply shape, description, and resource information to inputs and outputs
100 // in the XlaCompilationResult. This should be called after
101 // compilation_result->computation was set.
102 Status PopulateResultIOInfo(
103     mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
104     bool use_tuple_args, bool use_resource_updates_for_aliases,
105     XlaHelpers::ShapeRepresentationFn shape_representation_fn,
106     XlaCompilationResult* compilation_result);
107 
108 // Compiles a MLIR module into XLA HLO, generates all accompanying metadata and
109 // stores them in CompilationResult.
110 // TODO(hinsu): Migrate options to separate struct.
111 Status CompileMlirToXlaHlo(
112     mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
113     llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple,
114     bool use_resource_updates_for_aliases,
115     XlaHelpers::ShapeRepresentationFn shape_representation_fn,
116     XlaCompilationResult* compilation_result,
117     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
118         custom_legalization_passes);
119 
120 // Compiles a serialized MLIR module into XLA HLO, generates all accompanying
121 // metadata and stores them in CompilationResult.
122 Status CompileSerializedMlirToXlaHlo(
123     llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
124     llvm::StringRef device_type, bool use_tuple_args,
125     const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
126     XlaCompilationResult* compilation_result,
127     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
128         custom_legalization_passes = {});
129 
130 // Compiles a TensorFlow Graph (already converted to MLIR, imported with
131 // tf_executor dialect still present) into XLA HLO, generates all accompanying
132 // metadata and stores them in CompilationResult. This will rewrite arguments
133 // and run the TensorFlow standard pipeline prior to invoking
134 // `CompileMlirToXlaHlo`.
135 Status CompileGraphToXlaHlo(
136     mlir::ModuleOp module_op, llvm::ArrayRef<XlaArgument> args,
137     llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple,
138     const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
139     XlaCompilationResult* compilation_result,
140     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
141         custom_legalization_passes);
142 
143 // Compiles a TensorFlow Graph into XLA HLO, generates all accompanying metadata
144 // and stores them in CompilationResult.
145 Status CompileGraphToXlaHlo(
146     const Graph& graph, llvm::ArrayRef<XlaArgument> args,
147     llvm::ArrayRef<std::string> control_rets, llvm::StringRef device_type,
148     bool use_tuple_args, const FunctionLibraryDefinition& flib_def,
149     const GraphDebugInfo& debug_info,
150     const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
151     XlaCompilationResult* compilation_result,
152     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
153         custom_legalization_passes = {});
154 
155 // Compiles a Graph from TF to HLO and adds the resulting HLO to the
156 // XlaBuilder. This function adds HLO to a larger HLO computation, so
157 // HLO-level inputs are supplied, and HLO-level outputs are produced.
158 // xla_params is the HLO-level inputs and returns is the HLO-level outputs.
159 Status BuildHloFromGraph(const Graph& graph, xla::XlaBuilder& builder,
160                          llvm::ArrayRef<xla::XlaOp> xla_params,
161                          std::vector<xla::XlaOp>& returns,
162                          llvm::ArrayRef<XlaArgument> args,
163                          llvm::ArrayRef<std::string> control_rets,
164                          llvm::StringRef device_type,
165                          const FunctionLibraryDefinition& flib_def,
166                          const GraphDebugInfo& debug_info,
167                          llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
168                              custom_legalization_passes = {});
169 
170 }  // namespace tensorflow
171 
172 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_
173