/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"

#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {
namespace {

// Extracts the shape from an XlaArgument as a TensorShape. If the shape is an
// xla::Shape, it is converted to a TensorShape.
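// For example, an argument carrying xla::ShapeUtil::MakeShape(xla::F32,
// {2, 3}) and one carrying TensorShape({2, 3}) both yield TensorShape({2, 3}).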
StatusOr<TensorShape> GetTensorShapeFromXlaArgument(const XlaArgument& arg) {
  if (absl::holds_alternative<xla::Shape>(arg.shape)) {
    TensorShape arg_shape;
    TF_RETURN_IF_ERROR(
        XLAShapeToTensorShape(absl::get<xla::Shape>(arg.shape), &arg_shape));
    return arg_shape;
  } else {
    return absl::get<TensorShape>(arg.shape);
  }
}

// Converts arg_shapes to xla::Shapes and stores them in xla_input_shapes.
Status GetXlaInputShapes(
    mlir::ModuleOp module, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
    bool use_tuple_args,
    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    std::vector<xla::Shape>* xla_input_shapes) {
  xla_input_shapes->clear();

  mlir::FuncOp main_func = module.lookupSymbol<mlir::FuncOp>("main");
  TF_RET_CHECK(main_func != nullptr) << "No main function found";
  mlir::FunctionType func_type = main_func.getType();

  int num_args = func_type.getNumInputs();
  xla_input_shapes->reserve(num_args);

  std::vector<xla::Shape> individual_arg_shapes;
  individual_arg_shapes.reserve(num_args);
  for (int i = 0; i < num_args; ++i) {
    individual_arg_shapes.emplace_back();
    xla::Shape& xla_shape = individual_arg_shapes.back();

    DataType dtype;
    TF_RETURN_IF_ERROR(ConvertToDataType(func_type.getInput(i), &dtype));
    TF_ASSIGN_OR_RETURN(xla_shape,
                        shape_representation_fn(arg_shapes[i].shape, dtype,
                                                /*use_fast_memory=*/false));

    // Rewrite layout with sharding, if sharding is set.
    auto sharding =
        main_func.getArgAttrOfType<mlir::StringAttr>(i, "mhlo.sharding");
    if (!sharding) continue;

    absl::optional<xla::HloSharding> arg_sharding;
    xla::OpSharding op_sharding;
    if (!op_sharding.ParseFromString(sharding.getValue().str()))
      return errors::InvalidArgument("failed to parse argument sharding ", i,
                                     " '", sharding.getValue().str(), "'");

    TF_ASSIGN_OR_RETURN(arg_sharding, xla::HloSharding::FromProto(op_sharding));
    TF_RETURN_IF_ERROR(
        RewriteLayoutWithShardedShape(arg_sharding, /*use_fast_memory=*/false,
                                      shape_representation_fn, &xla_shape));
  }
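  // With use_tuple_args set, all argument shapes are packed into a single
  // tuple shape, e.g. (f32[2,3], f32[4]); otherwise each argument keeps its
  // own entry in xla_input_shapes.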
  if (use_tuple_args) {
    xla_input_shapes->push_back(
        xla::ShapeUtil::MakeTupleShape(individual_arg_shapes));
  } else {
    *xla_input_shapes = individual_arg_shapes;
  }
  return Status::OK();
}

// Calculates the computation output shape and builds an OutputDescription for
// each output based on static shapes in the MLIR module. If an output is a
// resource write, `resource_updates` is populated instead of `outputs` for
// that output.
Status GetOutputInfo(
    mlir::ModuleOp module, bool use_resource_updates_for_aliases,
    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    xla::Shape* xla_output_shape, std::vector<XlaOutputDescription>* outputs,
    std::vector<XlaResourceUpdate>* resource_updates) {
  auto shape_representation_fn_no_fast_memory =
      [shape_representation_fn](const TensorShape& shape, DataType dtype) {
        return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false);
      };

  mlir::FuncOp main_func = module.lookupSymbol<mlir::FuncOp>("main");
  mlir::FunctionType func_type = main_func.getType();

  outputs->clear();
  outputs->reserve(func_type.getNumResults());
  resource_updates->reserve(func_type.getNumResults());

  std::vector<xla::Shape> shapes;
  shapes.reserve(func_type.getNumResults());

  llvm::SmallDenseMap<unsigned, unsigned> output_to_input_alias;
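  // Each argument of `main` annotated with tf.aliasing_output = j contributes
  // an entry j -> argument index; e.g. argument 0 carrying
  // tf.aliasing_output = 2 records the alias 2 -> 0.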
  for (unsigned i = 0; i < main_func.getNumArguments(); ++i)
    if (auto aliasing_output = main_func.getArgAttrOfType<mlir::IntegerAttr>(
            i, "tf.aliasing_output"))
      output_to_input_alias[aliasing_output.getInt()] = i;

  for (auto type_and_idx : llvm::enumerate(func_type.getResults())) {
    TF_ASSIGN_OR_RETURN(
        xla::Shape shape,
        xla::TypeToShape(type_and_idx.value(),
                         shape_representation_fn_no_fast_memory));
    auto tensor_type = type_and_idx.value().dyn_cast<mlir::RankedTensorType>();
    shapes.push_back(shape);

    auto it = output_to_input_alias.find(type_and_idx.index());
    if (it != output_to_input_alias.end() && use_resource_updates_for_aliases) {
      // Add resource write.
      resource_updates->emplace_back();
      XlaResourceUpdate& resource_update = resource_updates->back();
      resource_update.input_index = it->getSecond();
      resource_update.modified = true;
      TF_RETURN_IF_ERROR(ConvertToDataType(tensor_type, &resource_update.type));
      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &resource_update.shape));
      continue;
    }
    // Construct OutputDescription for result.
    outputs->emplace_back();
    XlaOutputDescription& out_desc = outputs->back();
    TF_RETURN_IF_ERROR(ConvertToDataType(tensor_type, &out_desc.type));
    // TODO(ycao): Support constant output.
    out_desc.is_constant = false;
    TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &out_desc.shape));
    // input_index is only meaningful for resource outputs; set it to the
    // sentinel value -1 for non-resource outputs.
    out_desc.input_index =
        it != output_to_input_alias.end() ? it->getSecond() : -1;
    // MLIR-based TF-Compiler bridge doesn't support tensorlist output yet.
    // TODO(ycao): Support tensorlist-type output.
    out_desc.is_tensor_list = false;
  }

  // XLA computation always uses Tuple shape.
  *xla_output_shape = xla::ShapeUtil::MakeTupleShape(shapes);
  return Status::OK();
}

// Creates a vector that maps from the parameters of the XLA computation to
// their original argument positions.
// The MLIR-based TF-Compiler bridge doesn't have constant analysis yet, so no
// inputs are known constants. Therefore, the mapping from inputs to
// computation arguments is a trivial in-order 1-1 mapping.
// TODO(ycao): Support computations with compile-time constants, which would
// require a non-trivial input mapping, unlike the one implemented now.
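// For example, num_inputs = 3 yields input_mapping = {0, 1, 2}.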
void GetInputMappingForMlir(int num_inputs, std::vector<int>* input_mapping) {
  input_mapping->resize(num_inputs, 0);
  std::iota(input_mapping->begin(), input_mapping->end(), 0);
}

static void RegisterDialects(mlir::DialectRegistry& registry) {
  mlir::RegisterAllTensorFlowDialects(registry);
  mlir::mhlo::registerAllMhloDialects(registry);
}

// Checks if functions can be inlined after TF -> HLO legalization. Currently
// only TPUs are supported, to match the behavior of inlining functions via
// the Graph-based bridge in the TPUCompile op kernel.
bool CanInlineFunctionsPostLegalization(llvm::StringRef device_type) {
  return device_type == DEVICE_TPU_XLA_JIT;
}

}  // namespace

Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
                    mlir::ModuleOp module) {
  auto producer_or = GetTfGraphProducerVersion(module);
  if (!producer_or.ok()) return producer_or.status();
  int64_t producer_version = producer_or.ValueOrDie();

  llvm::SmallVector<int64_t, 16> shape_backing;
  llvm::SmallVector<llvm::ArrayRef<int64_t>, 4> arg_shapes_copy;
  {
    // Convert arg_shapes to an MLIR-friendly format.
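    // For example, for argument shapes {[2,3], resource, scalar},
    // shape_backing ends up as {2, 3} and arg_shapes_copy as
    // {[2,3], [], []}: resources and rank-0 shapes both map to empty
    // ArrayRefs.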
    size_t count = 0;
    for (const TensorOrResourceShape& tensor_resource_shape : arg_shapes) {
      if (tensor_resource_shape.is_resource) continue;
      count += tensor_resource_shape.shape.dims();
    }
    shape_backing.resize(count);
    arg_shapes_copy.reserve(arg_shapes.size());
    size_t offset = 0;
    for (const TensorOrResourceShape& tensor_resource_shape : arg_shapes) {
      if (tensor_resource_shape.is_resource) {
        arg_shapes_copy.push_back(llvm::ArrayRef<int64_t>());
        continue;
      }
      size_t start = offset;
      for (tensorflow::TensorShapeDim dim : tensor_resource_shape.shape) {
        shape_backing[offset] = dim.size;
        ++offset;
      }
      if (offset == start) {
        arg_shapes_copy.push_back(llvm::ArrayRef<int64_t>());
      } else {
        arg_shapes_copy.push_back(
            llvm::ArrayRef<int64_t>(&shape_backing[start], offset - start));
      }
    }
  }

  auto main_func = module.lookupSymbol<mlir::FuncOp>("main");

  mlir::StatusScopedDiagnosticHandler error_handler(module.getContext());
  mlir::LogicalResult result = mlir::TF::InferShapeForFunction(
      main_func, arg_shapes_copy, producer_version);

  if (failed(result)) {
    return error_handler.Combine(
        errors::Internal("MLIR Shape refinement failed"));
  }
  return Status::OK();
}

void CreateConvertMlirToXlaHloPipeline(
    mlir::OpPassManager& pm, llvm::StringRef device_type,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());
  pm.addNestedPass<mlir::FuncOp>(mlir::TF::CreateDropWhileShapeInvariantPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  // The SCCP pass performs constant propagation across the IR, which, for
  // example, propagates constant arguments into callee functions.
  pm.addPass(mlir::createSCCPPass());
  // Guarantee all functions have one use, which enables shape inference.
  pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass());
  // Run shape inference pass before tensorlist decomposition to get buffer
  // shape of uninitialized TensorLists.
  pm.addPass(mlir::TF::CreateTFShapeInferencePass());
  pm.addPass(mlir::TF::CreateTensorListOpsDecompositionPass());
  pm.addPass(mlir::TF::CreateStackOpsDecompositionPass());
  pm.addPass(mlir::TF::CreateTensorArrayOpsDecompositionPass());
  pm.addNestedPass<mlir::FuncOp>(
      mlir::TFDevice::CreateDecomposeResourceOpsPass());
  pm.addPass(mlir::TF::CreatePromoteResourcesToArgsPass());
  pm.addPass(mlir::createSymbolDCEPass());
  pm.addPass(mlir::TF::CreateTFShapeInferencePass());
  // TODO(b/171426148): We cannot completely remove region to functional control
  // flow conversion from this pipeline yet as it causes some unit tests to
  // fail.
  pm.addPass(mlir::TF::CreateTFRegionControlFlowToFunctional());
  // LegalizeTFControlFlow encapsulates arguments for control flow operations
  // in a tuple argument, which breaks the assumption of resource lifting
  // inside PromoteResourcesToArgs.
  pm.addPass(mlir::mhlo::createLegalizeTFControlFlowPass());

  pm.addPass(mlir::mhlo::CreateLegalizeTfTypesPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
      /*allow_partial_conversion=*/true, /*legalize_chlo=*/true,
      /*tf2xla_fallback_device_type=*/device_type));
  for (auto& target_pass : custom_legalization_passes) {
    pm.addNestedPass<mlir::FuncOp>(std::move(target_pass));
  }
  pm.addPass(mlir::mhlo::CreateLegalizeTFCommunicationPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  // Run shape inference pass to propagate shapes through tensor_cast operations
  // from static to dynamic shapes. This could be generated if the shape
  // inference was originally missing in a TF op but the corresponding HLO op
  // had static shape after lowering.
  pm.addPass(mlir::TF::CreateTFShapeInferencePass());
  // Run LegalizeTFPass again because the previous legalization passes can
  // expose more graph pruning and canonicalization opportunities that are
  // necessary for the second LegalizeTFPass(allow_partial_conversion=false)
  // invocation.
  pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
      /*allow_partial_conversion=*/false, /*legalize_chlo=*/true,
      /*tf2xla_fallback_device_type=*/device_type));

  if (CanInlineFunctionsPostLegalization(device_type))
    pm.addPass(mlir::createInlinerPass());

  // In order to export to XLA, we must sink constants to control flow regions,
  // since XLA uses functional control flow.
  pm.addNestedPass<mlir::FuncOp>(
      mlir::mhlo::createSinkConstantsToControlFlowPass());
}
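
// A minimal usage sketch of the pipeline above (mirroring LegalizeToHlo
// below), assuming module_op and device_type are in scope:
//
//   mlir::PassManager pm(module_op.getContext());
//   CreateConvertMlirToXlaHloPipeline(pm, device_type,
//                                     /*custom_legalization_passes=*/{});
//   if (failed(pm.run(module_op))) {
//     // Handle failure, e.g. via mlir::StatusScopedDiagnosticHandler.
//   }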

Status LegalizeToHlo(mlir::ModuleOp module_op, llvm::StringRef device_type,
                     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
                         custom_legalization_passes) {
  mlir::PassManager tf2xla(module_op.getContext());
  applyTensorflowAndCLOptions(tf2xla);
  CreateConvertMlirToXlaHloPipeline(tf2xla, device_type,
                                    custom_legalization_passes);

  if (VLOG_IS_ON(1))
    tensorflow::DumpMlirOpToFile("legalize_hlo_before", module_op);
  if (VLOG_IS_ON(2)) {
    // Printing the whole module after each pass requires disabling
    // multi-threading as well.
    module_op.getContext()->disableMultithreading();
    tf2xla.enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>(
        /*print_module_scope=*/true));
  }

  // Make sure we catch any error reported by MLIR and forward it to the TF
  // error reporting system. Report a generic error if the pass manager fails
  // without emitting a diagnostic.
  mlir::StatusScopedDiagnosticHandler error_handler(module_op.getContext());

  if (failed(tf2xla.run(module_op))) {
    return error_handler.Combine(
        errors::InvalidArgument("TF to XLA legalization failed"));
  }

  if (VLOG_IS_ON(1))
    tensorflow::DumpMlirOpToFile("legalize_hlo_after", module_op);

  return Status::OK();
}

Status BuildHloFromTfInner(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
                           llvm::ArrayRef<xla::XlaOp> xla_params,
                           std::vector<xla::XlaOp>& returns,
                           llvm::StringRef device_type,
                           llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
                               custom_legalization_passes) {
  TF_RETURN_IF_ERROR(
      LegalizeToHlo(module_op, device_type, custom_legalization_passes));

  mlir::Block& block = module_op.lookupSymbol<mlir::FuncOp>("main").front();
  return mlir::BuildHloFromMlirHlo(block, builder, xla_params, returns);
}

Status ConvertMLIRToXlaComputation(
    mlir::ModuleOp module_op, llvm::StringRef device_type,
    xla::XlaComputation* xla_computation, bool use_tuple_args,
    bool return_tuple,
    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  TF_RETURN_IF_ERROR(
      LegalizeToHlo(module_op, device_type, custom_legalization_passes));

  xla::HloProto hlo_proto;
  TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(module_op, &hlo_proto,
                                               use_tuple_args, return_tuple,
                                               shape_representation_fn));
  *xla_computation = xla::XlaComputation(hlo_proto.hlo_module());
  return Status::OK();
}

Status CompileMlirSetup(
    mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
    XlaHelpers::ShapeRepresentationFn* shape_representation_fn) {
  // Use arg_shapes to improve the MLIR type information of `main` in
  // module_op.
  TF_RETURN_IF_ERROR(RefineShapes(arg_shapes, module_op));

  if (VLOG_IS_ON(2))
    tensorflow::DumpMlirOpToFile("compile_mlir_shape_refiner", module_op);

  if (!*shape_representation_fn)
    *shape_representation_fn = IdentityShapeRepresentationFn();

  return Status::OK();
}

Status BuildHloFromTf(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
                      llvm::ArrayRef<xla::XlaOp> xla_params,
                      std::vector<xla::XlaOp>& returns,
                      llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
                      llvm::StringRef device_type,
                      llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
                          custom_legalization_passes) {
  if (VLOG_IS_ON(2))
    tensorflow::DumpMlirOpToFile("build_hlo_tf_before", module_op);

  XlaHelpers::ShapeRepresentationFn shape_representation_fn;
  TF_RETURN_IF_ERROR(
      CompileMlirSetup(module_op, arg_shapes, &shape_representation_fn));

  // Lower the MLIR module to HLO and emit the resulting ops into `builder`.
  TF_RETURN_IF_ERROR(BuildHloFromTfInner(module_op, builder, xla_params,
                                         returns, device_type,
                                         custom_legalization_passes));

  if (VLOG_IS_ON(2))
    tensorflow::DumpMlirOpToFile("build_hlo_tf_after", module_op);

  return Status::OK();
}

Status PopulateResultIOInfo(
    mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
    bool use_tuple_args, bool use_resource_updates_for_aliases,
    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    XlaCompilationResult* compilation_result) {
  // Construct mapping from XlaComputation's arg to input edges of execute
  // node.
  GetInputMappingForMlir(arg_shapes.size(), &compilation_result->input_mapping);

  // Compute all input shapes.
  TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes, use_tuple_args,
                                       shape_representation_fn,
                                       &compilation_result->xla_input_shapes));

  // Compute all output descriptions and resource writes.
  return GetOutputInfo(
      module_op, use_resource_updates_for_aliases, shape_representation_fn,
      &compilation_result->xla_output_shape, &compilation_result->outputs,
      &compilation_result->resource_updates);
}

Status CompileMlirToXlaHlo(
    mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
    llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple,
    bool use_resource_updates_for_aliases,
    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    XlaCompilationResult* compilation_result,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  TF_RETURN_IF_ERROR(
      CompileMlirSetup(module_op, arg_shapes, &shape_representation_fn));

  // Convert MLIR module to XLA HLO proto contained in XlaComputation.
  compilation_result->computation = std::make_shared<xla::XlaComputation>();
  TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
      module_op, device_type, compilation_result->computation.get(),
      use_tuple_args, use_return_tuple, shape_representation_fn,
      custom_legalization_passes));

  return PopulateResultIOInfo(module_op, arg_shapes, use_tuple_args,
                              use_resource_updates_for_aliases,
                              shape_representation_fn, compilation_result);
}

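// A hedged usage sketch for the serialized entry point below; the shapes and
// the "XLA_TPU_JIT" device type string are illustrative only:
//
//   XlaCompilationResult result;
//   TF_RETURN_IF_ERROR(CompileSerializedMlirToXlaHlo(
//       serialized_module, /*arg_shapes=*/{TensorShape({2, 3})},
//       /*device_type=*/"XLA_TPU_JIT", /*use_tuple_args=*/true,
//       /*shape_representation_fn=*/{}, &result,
//       /*custom_legalization_passes=*/{}));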
Status CompileSerializedMlirToXlaHlo(
    llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
    llvm::StringRef device_type, bool use_tuple_args,
    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    XlaCompilationResult* compilation_result,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  mlir::DialectRegistry mlir_registry;
  RegisterDialects(mlir_registry);
  mlir::MLIRContext mlir_context(mlir_registry);
  mlir::OwningModuleRef mlir_module;

  TF_RETURN_IF_ERROR(
      DeserializeMlirModule(mlir_module_string, &mlir_context, &mlir_module));
  llvm::SmallVector<TensorOrResourceShape, 4> tensor_or_resource_shapes;
  tensor_or_resource_shapes.reserve(arg_shapes.size());
  for (const auto& arg_shape : arg_shapes)
    tensor_or_resource_shapes.push_back({arg_shape});
  return CompileMlirToXlaHlo(
      mlir_module.get(), tensor_or_resource_shapes, device_type, use_tuple_args,
      /*use_return_tuple=*/true, /*use_resource_updates_for_aliases=*/false,
      shape_representation_fn, compilation_result, custom_legalization_passes);
}

// Rewrites the given module with the specified args. Each constant arg is
// inlined into the `main` function and the corresponding argument is removed
// from the signature. For resource args, their subtypes are populated.
// Returns the original indices of the remaining arguments on success.
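// For example, given args {constant, resource, parameter}, the constant is
// inlined, its argument is erased, and {1, 2} is returned.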
static StatusOr<std::vector<int>> RewriteWithArgs(
    mlir::ModuleOp module_op, llvm::ArrayRef<XlaArgument> args) {
  mlir::FuncOp main_fn = module_op.lookupSymbol<mlir::FuncOp>("main");
  std::vector<int> params;

  bool has_resource_args = false;
  auto builder = mlir::OpBuilder(main_fn.getBody());
  std::vector<int> args_to_erase;
  for (int idx = 0; idx < args.size(); idx++) {
    const XlaArgument& xla_arg = args[idx];
    mlir::BlockArgument mlir_arg = main_fn.getArgument(idx);
    if (xla_arg.kind == XlaArgument::kResource) {
      mlir::Type element_type;
      if (xla_arg.type == DT_INVALID) {
        return errors::Unimplemented(absl::StrCat(
            "Argument ", idx,
            " is an uninitialized resource variable which is currently"
            " unsupported in the MLIR-based TPU bridge"));
      }
      TF_RETURN_IF_ERROR(ConvertDataType(xla_arg.type, builder, &element_type));
      TF_ASSIGN_OR_RETURN(TensorShape arg_shape,
                          GetTensorShapeFromXlaArgument(xla_arg));
      auto resource_shape = arg_shape.dim_sizes();
      llvm::SmallVector<int64_t, 4> resource_subtype_shape(
          resource_shape.begin(), resource_shape.end());
      auto resource_subtype =
          mlir::RankedTensorType::get(resource_subtype_shape, element_type);
      auto resource_type =
          mlir::TF::ResourceType::get({resource_subtype}, builder.getContext());

      auto tensor_type = mlir_arg.getType().cast<mlir::TensorType>();
      if (tensor_type.hasRank()) {
        mlir_arg.setType(
            mlir::RankedTensorType::get(tensor_type.getShape(), resource_type));
      } else {
        mlir_arg.setType(mlir::UnrankedTensorType::get(resource_type));
      }
      has_resource_args = true;
    }
    if (xla_arg.kind != XlaArgument::kConstant) {
      params.push_back(idx);
      continue;
    }

    TF_ASSIGN_OR_RETURN(auto value_attr,
                        ConvertTensor(xla_arg.constant_value, &builder));
    // TODO(hinsu): Use the actual location of the constant.
    auto constant = builder.create<mlir::TF::ConstOp>(
        mlir::UnknownLoc::get(module_op.getContext()), value_attr);
    mlir_arg.replaceAllUsesWith(constant);
    args_to_erase.push_back(idx);
  }

  if (has_resource_args) {
    llvm::SmallVector<mlir::Type, 4> updated_argument_types;
    updated_argument_types.reserve(main_fn.getNumArguments());
    for (mlir::BlockArgument& arg : main_fn.getArguments())
      updated_argument_types.push_back(arg.getType());

    main_fn.setType(mlir::FunctionType::get(main_fn.getContext(),
                                            updated_argument_types,
                                            main_fn.getType().getResults()));
  }

  for (int idx : llvm::reverse(args_to_erase)) main_fn.eraseArgument(idx);

  return params;
}

Status CompileGraphSetup(
    mlir::ModuleOp module_op, llvm::ArrayRef<XlaArgument> args,
    std::vector<int>* remaining_params,
    llvm::SmallVector<TensorOrResourceShape, 4>& arg_shapes) {
  TF_ASSIGN_OR_RETURN(*remaining_params, RewriteWithArgs(module_op, args));
  arg_shapes.reserve(remaining_params->size());
  for (unsigned idx : *remaining_params) {
    const auto& arg = args[idx];
    TF_ASSIGN_OR_RETURN(TensorShape arg_shape,
                        GetTensorShapeFromXlaArgument(arg));
    arg_shapes.push_back({arg_shape,
                          /*is_resource=*/arg.kind == XlaArgument::kResource});
  }

  mlir::PassManager pm(module_op.getContext());
  applyTensorflowAndCLOptions(pm);
  mlir::TF::StandardPipelineOptions tf_options;
  mlir::TF::CreateTFStandardPipeline(pm, tf_options);

  if (VLOG_IS_ON(1))
    tensorflow::DumpMlirOpToFile("compile_graph_setup_before", module_op);
  mlir::StatusScopedDiagnosticHandler diag_handler(module_op.getContext());
  if (failed(pm.run(module_op))) return diag_handler.ConsumeStatus();
  if (VLOG_IS_ON(1))
    tensorflow::DumpMlirOpToFile("compile_graph_setup_after", module_op);

  return Status::OK();
}

Status BuildHloFromModule(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
                          llvm::ArrayRef<xla::XlaOp> xla_params,
                          std::vector<xla::XlaOp>& returns,
                          llvm::ArrayRef<XlaArgument> args,
                          llvm::StringRef device_type,
                          llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
                              custom_legalization_passes) {
  std::vector<int> remaining_params;
  llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
  TF_RETURN_IF_ERROR(
      CompileGraphSetup(module_op, args, &remaining_params, arg_shapes));
  return BuildHloFromTf(module_op, builder, xla_params, returns, arg_shapes,
                        device_type, custom_legalization_passes);
}

Status CompileGraphToXlaHlo(
    mlir::ModuleOp module_op, llvm::ArrayRef<XlaArgument> args,
    llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple,
    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    XlaCompilationResult* compilation_result,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  std::vector<int> remaining_params;
  llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
  TF_RETURN_IF_ERROR(
      CompileGraphSetup(module_op, args, &remaining_params, arg_shapes));

  auto status = CompileMlirToXlaHlo(
      module_op, arg_shapes, device_type, use_tuple_args, use_return_tuple,
      /*use_resource_updates_for_aliases=*/true, shape_representation_fn,
      compilation_result, custom_legalization_passes);
  compilation_result->input_mapping = remaining_params;
  return status;
}

Status GraphToModule(const Graph& graph,
                     llvm::ArrayRef<std::string> control_rets,
                     const FunctionLibraryDefinition& flib_def,
                     const GraphDebugInfo& debug_info,
                     mlir::MLIRContext* context,
                     mlir::OwningModuleRef* module) {
  mlir::DialectRegistry registry;
  RegisterDialects(registry);
  context->appendDialectRegistry(registry);
  GraphImportConfig config;
  config.graph_as_function = true;
  config.control_outputs = control_rets;
  // Disable shape inference during import as some TensorFlow ops fail during
  // shape inference with dynamic shaped operands, which in turn causes the
  // import to fail. Shape inference during import is going to be removed;
  // since the shape inference pass runs early in the pass pipeline, shape
  // inference during import is not necessary.
  config.enable_shape_inference = false;
  auto module_or =
      ConvertGraphToMlir(graph, debug_info, flib_def, config, context);
  if (!module_or.ok()) return module_or.status();

  *module = std::move(module_or.ValueOrDie());

  return Status::OK();
}

Status BuildHloFromGraph(const Graph& graph, xla::XlaBuilder& builder,
                         llvm::ArrayRef<xla::XlaOp> xla_params,
                         std::vector<xla::XlaOp>& returns,
                         llvm::ArrayRef<XlaArgument> args,
                         llvm::ArrayRef<std::string> control_rets,
                         llvm::StringRef device_type,
                         const FunctionLibraryDefinition& flib_def,
                         const GraphDebugInfo& debug_info,
                         llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
                             custom_legalization_passes) {
  mlir::MLIRContext context;
  mlir::OwningModuleRef module;
  TF_RETURN_IF_ERROR(GraphToModule(graph, control_rets, flib_def, debug_info,
                                   &context, &module));
  return BuildHloFromModule(module.get(), builder, xla_params, returns, args,
                            device_type, custom_legalization_passes);
}

Status CompileGraphToXlaHlo(
    const Graph& graph, llvm::ArrayRef<XlaArgument> args,
    llvm::ArrayRef<std::string> control_rets, llvm::StringRef device_type,
    bool use_tuple_args, const FunctionLibraryDefinition& flib_def,
    const GraphDebugInfo& debug_info,
    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    XlaCompilationResult* compilation_result,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  mlir::MLIRContext context;
  mlir::OwningModuleRef module;
  TF_RETURN_IF_ERROR(GraphToModule(graph, control_rets, flib_def, debug_info,
                                   &context, &module));
  return CompileGraphToXlaHlo(module.get(), args, device_type, use_tuple_args,
                              /*use_return_tuple=*/true,
                              shape_representation_fn, compilation_result,
                              custom_legalization_passes);
}

}  // namespace tensorflow