/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"

#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {
namespace {

// Extracts the shape from an XlaArgument as a TensorShape. If the shape is an
// xla::Shape, it is converted to a TensorShape first.
StatusOr<TensorShape> GetTensorShapeFromXlaArgument(const XlaArgument& arg) {
  if (absl::holds_alternative<xla::Shape>(arg.shape)) {
    TensorShape arg_shape;
    TF_RETURN_IF_ERROR(
        XLAShapeToTensorShape(absl::get<xla::Shape>(arg.shape), &arg_shape));
    return arg_shape;
  } else {
    return absl::get<TensorShape>(arg.shape);
  }
}
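
// Illustrative example (not from the original source): for an XlaArgument
// whose `shape` variant holds the xla::Shape f32[2,3], the helper above
// returns TensorShape({2, 3}); a TensorShape held directly in the variant is
// returned as-is.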

// Converts arg_shapes to xla::Shapes and stores them in xla_input_shapes.
Status GetXlaInputShapes(
    mlir::ModuleOp module, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
    bool use_tuple_args,
    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    std::vector<xla::Shape>* xla_input_shapes) {
  xla_input_shapes->clear();

  mlir::FuncOp main_func = module.lookupSymbol<mlir::FuncOp>("main");
  TF_RET_CHECK(main_func != nullptr) << "No main function found";
  mlir::FunctionType func_type = main_func.getType();

  int num_args = func_type.getNumInputs();
  xla_input_shapes->reserve(num_args);

  std::vector<xla::Shape> individual_arg_shapes;
  individual_arg_shapes.reserve(num_args);
  for (int i = 0; i < num_args; ++i) {
    individual_arg_shapes.emplace_back();
    xla::Shape& xla_shape = individual_arg_shapes.back();

    DataType dtype;
    TF_RETURN_IF_ERROR(ConvertToDataType(func_type.getInput(i), &dtype));
    TF_ASSIGN_OR_RETURN(xla_shape,
                        shape_representation_fn(arg_shapes[i].shape, dtype,
                                                /*use_fast_memory=*/false));

    // Rewrite the layout with sharding, if sharding is set.
    auto sharding =
        main_func.getArgAttrOfType<mlir::StringAttr>(i, "mhlo.sharding");
    if (!sharding) continue;

    absl::optional<xla::HloSharding> arg_sharding;
    xla::OpSharding op_sharding;
    if (!op_sharding.ParseFromString(sharding.getValue().str()))
      return errors::InvalidArgument("failed to parse argument sharding ", i,
                                     " '", sharding.getValue().str(), "'");

    TF_ASSIGN_OR_RETURN(arg_sharding, xla::HloSharding::FromProto(op_sharding));
    TF_RETURN_IF_ERROR(
        RewriteLayoutWithShardedShape(arg_sharding, /*use_fast_memory=*/false,
                                      shape_representation_fn, &xla_shape));
  }
  if (use_tuple_args) {
    xla_input_shapes->push_back(
        xla::ShapeUtil::MakeTupleShape(individual_arg_shapes));
  } else {
    *xla_input_shapes = individual_arg_shapes;
  }
  return Status::OK();
}
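
// Illustrative example (not from the original source): for two non-resource
// arguments of type tensor<2x3xf32> and tensor<f32>, and assuming the shape
// representation function preserves shapes, use_tuple_args=true yields a
// single entry (f32[2,3], f32[]) in xla_input_shapes, while
// use_tuple_args=false yields the two shapes f32[2,3] and f32[] individually.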

// Calculates the computation output shape and builds an OutputDescription for
// each output, based on static shapes in the MLIR module. If an output is a
// resource write, `resource_updates` is populated instead of `outputs` for
// that output.
Status GetOutputInfo(
    mlir::ModuleOp module, bool use_resource_updates_for_aliases,
    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    xla::Shape* xla_output_shape, std::vector<XlaOutputDescription>* outputs,
    std::vector<XlaResourceUpdate>* resource_updates) {
  auto shape_representation_fn_no_fast_memory =
      [shape_representation_fn](const TensorShape& shape, DataType dtype) {
        return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false);
      };

  mlir::FuncOp main_func = module.lookupSymbol<mlir::FuncOp>("main");
  mlir::FunctionType func_type = main_func.getType();

  outputs->clear();
  outputs->reserve(func_type.getNumResults());
  resource_updates->reserve(func_type.getNumResults());

  std::vector<xla::Shape> shapes;
  shapes.reserve(func_type.getNumResults());

  llvm::SmallDenseMap<unsigned, unsigned> output_to_input_alias;
  for (unsigned i = 0; i < main_func.getNumArguments(); ++i)
    if (auto aliasing_output = main_func.getArgAttrOfType<mlir::IntegerAttr>(
            i, "tf.aliasing_output"))
      output_to_input_alias[aliasing_output.getInt()] = i;

  for (auto type_and_idx : llvm::enumerate(func_type.getResults())) {
    TF_ASSIGN_OR_RETURN(
        xla::Shape shape,
        xla::TypeToShape(type_and_idx.value(),
                         shape_representation_fn_no_fast_memory));
    auto tensor_type = type_and_idx.value().dyn_cast<mlir::RankedTensorType>();
    shapes.push_back(shape);

    auto it = output_to_input_alias.find(type_and_idx.index());
    if (it != output_to_input_alias.end() &&
        use_resource_updates_for_aliases) {
      // Add resource write.
      resource_updates->emplace_back();
      XlaResourceUpdate& resource_update = resource_updates->back();
      resource_update.input_index = it->getSecond();
      resource_update.modified = true;
      TF_RETURN_IF_ERROR(
          ConvertToDataType(tensor_type, &resource_update.type));
      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &resource_update.shape));
      continue;
    }
    // Construct an OutputDescription for the result.
    outputs->emplace_back();
    XlaOutputDescription& out_desc = outputs->back();
    TF_RETURN_IF_ERROR(ConvertToDataType(tensor_type, &out_desc.type));
    // TODO(ycao): Support constant output.
    out_desc.is_constant = false;
    TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &out_desc.shape));
    // input_index is only meaningful for resource outputs. Set it to the
    // sentinel value -1 for non-resource outputs.
    out_desc.input_index =
        it != output_to_input_alias.end() ? it->getSecond() : -1;
    // The MLIR-based TF-Compiler bridge doesn't support tensorlist output yet.
    // TODO(ycao): Support tensorlist-type output.
    out_desc.is_tensor_list = false;
  }

  // The XLA computation always uses a tuple shape for its output.
  *xla_output_shape = xla::ShapeUtil::MakeTupleShape(shapes);
  return Status::OK();
}
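
// Illustrative example (not from the original source): if argument #2 of
// `main` carries the attribute `tf.aliasing_output = 0`, then result #0
// aliases that argument. With use_resource_updates_for_aliases=true the
// result is reported as a resource update with input_index 2 rather than as a
// regular output; otherwise it becomes a regular output whose input_index
// is 2.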

// Creates a vector that maps from the parameters of the XLA computation to
// their original argument positions.
// The MLIR-based TF-Compiler bridge doesn't have constant analysis yet, so no
// inputs are known constants. Therefore, the mapping from inputs to
// computation arguments is a trivial in-order 1-1 mapping.
// TODO(ycao): Support computations with compile-time constants, which require
// a non-trivial input mapping, unlike the trivial one implemented now.
void GetInputMappingForMlir(int num_inputs, std::vector<int>* input_mapping) {
  input_mapping->resize(num_inputs, 0);
  std::iota(input_mapping->begin(), input_mapping->end(), 0);
}
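
// Illustrative example (not from the original source):
//
//   std::vector<int> input_mapping;
//   GetInputMappingForMlir(/*num_inputs=*/3, &input_mapping);
//   // input_mapping is now {0, 1, 2}.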

static void RegisterDialects(mlir::DialectRegistry& registry) {
  mlir::RegisterAllTensorFlowDialects(registry);
  mlir::mhlo::registerAllMhloDialects(registry);
}

// Checks if functions can be inlined after TF -> HLO legalization. Currently
// only TPUs are supported, to match the behavior of inlining functions via
// the Graph-based bridge in the TPUCompile op kernel.
bool CanInlineFunctionsPostLegalization(llvm::StringRef device_type) {
  return device_type == DEVICE_TPU_XLA_JIT;
}

}  // namespace

Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
                    mlir::ModuleOp module) {
  auto producer_or = GetTfGraphProducerVersion(module);
  if (!producer_or.ok()) return producer_or.status();
  int64_t producer_version = producer_or.ValueOrDie();

  llvm::SmallVector<int64_t, 16> shape_backing;
  llvm::SmallVector<llvm::ArrayRef<int64_t>, 4> arg_shapes_copy;
  {
    // Convert arg_shapes to an MLIR-friendly format.
    size_t count = 0;
    for (const TensorOrResourceShape& tensor_resource_shape : arg_shapes) {
      if (tensor_resource_shape.is_resource) continue;
      count += tensor_resource_shape.shape.dims();
    }
    shape_backing.resize(count);
    arg_shapes_copy.reserve(arg_shapes.size());
    size_t offset = 0;
    for (const TensorOrResourceShape& tensor_resource_shape : arg_shapes) {
      if (tensor_resource_shape.is_resource) {
        arg_shapes_copy.push_back(llvm::ArrayRef<int64_t>());
        continue;
      }
      size_t start = offset;
      for (tensorflow::TensorShapeDim dim : tensor_resource_shape.shape) {
        shape_backing[offset] = dim.size;
        ++offset;
      }
      if (offset == start) {
        arg_shapes_copy.push_back(llvm::ArrayRef<int64_t>());
      } else {
        arg_shapes_copy.push_back(
            llvm::ArrayRef<int64_t>(&shape_backing[start], offset - start));
      }
    }
  }

  auto main_func = module.lookupSymbol<mlir::FuncOp>("main");

  mlir::StatusScopedDiagnosticHandler error_handler(module.getContext());
  mlir::LogicalResult result = mlir::TF::InferShapeForFunction(
      main_func, arg_shapes_copy, producer_version);

  if (failed(result)) {
    return error_handler.Combine(
        errors::Internal("MLIR Shape refinement failed"));
  }
  return Status::OK();
}
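
// Illustrative example (not from the original source): for arg_shapes holding
// a 2x3 tensor, a resource, and a scalar, shape_backing is packed as {2, 3}
// and arg_shapes_copy becomes {{2, 3}, {}, {}} -- resources and zero-rank
// shapes both map to empty ArrayRefs, and only non-resource dimensions occupy
// the backing storage.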

void CreateConvertMlirToXlaHloPipeline(
    mlir::OpPassManager& pm, llvm::StringRef device_type,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());
  pm.addNestedPass<mlir::FuncOp>(mlir::TF::CreateDropWhileShapeInvariantPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  // The SCCP pass performs constant propagation across the IR, which, for
  // example, propagates constant arguments into callee functions.
  pm.addPass(mlir::createSCCPPass());
  // Guarantee all functions have one use, which enables shape inference.
  pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass());
  // Run the shape inference pass before tensorlist decomposition to get the
  // buffer shape of uninitialized TensorLists.
  pm.addPass(mlir::TF::CreateTFShapeInferencePass());
  pm.addPass(mlir::TF::CreateTensorListOpsDecompositionPass());
  pm.addPass(mlir::TF::CreateStackOpsDecompositionPass());
  pm.addPass(mlir::TF::CreateTensorArrayOpsDecompositionPass());
  pm.addNestedPass<mlir::FuncOp>(
      mlir::TFDevice::CreateDecomposeResourceOpsPass());
  pm.addPass(mlir::TF::CreatePromoteResourcesToArgsPass());
  pm.addPass(mlir::createSymbolDCEPass());
  pm.addPass(mlir::TF::CreateTFShapeInferencePass());
  // TODO(b/171426148): We cannot completely remove region-to-functional
  // control flow conversion from this pipeline yet, as it causes some unit
  // tests to fail.
  pm.addPass(mlir::TF::CreateTFRegionControlFlowToFunctional());
  // LegalizeTFControlFlow encapsulates arguments for control flow operations
  // with a tuple argument, which breaks the assumption of resource lifting
  // inside PromoteResourcesToArgs.
  pm.addPass(mlir::mhlo::createLegalizeTFControlFlowPass());

  pm.addPass(mlir::mhlo::CreateLegalizeTfTypesPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
      /*allow_partial_conversion=*/true, /*legalize_chlo=*/true,
      /*tf2xla_fallback_device_type=*/device_type));
  for (auto& target_pass : custom_legalization_passes) {
    pm.addNestedPass<mlir::FuncOp>(std::move(target_pass));
  }
  pm.addPass(mlir::mhlo::CreateLegalizeTFCommunicationPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  // Run the shape inference pass to propagate shapes through tensor_cast
  // operations from static to dynamic shapes. Such casts can be generated when
  // shape information was originally missing for a TF op but the corresponding
  // HLO op has a static shape after lowering.
  pm.addPass(mlir::TF::CreateTFShapeInferencePass());
  // Run LegalizeTFPass again because the previous legalization passes can
  // expose more graph pruning and canonicalization opportunities that are
  // necessary for the second LegalizeTFPass(allow_partial_conversion=false)
  // invocation.
  pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
      /*allow_partial_conversion=*/false, /*legalize_chlo=*/true,
      /*tf2xla_fallback_device_type=*/device_type));

  if (CanInlineFunctionsPostLegalization(device_type))
    pm.addPass(mlir::createInlinerPass());

  // In order to export to XLA, we must sink constants to control flow regions,
  // since XLA uses functional control flow.
  pm.addNestedPass<mlir::FuncOp>(
      mlir::mhlo::createSinkConstantsToControlFlowPass());
}
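
// Minimal usage sketch (not part of the original source; assumes `module` is
// an mlir::ModuleOp loaded in a context with the TF and MHLO dialects
// registered):
//
//   mlir::PassManager pm(module.getContext());
//   CreateConvertMlirToXlaHloPipeline(pm, /*device_type=*/DEVICE_TPU_XLA_JIT,
//                                     /*custom_legalization_passes=*/{});
//   if (mlir::failed(pm.run(module))) {
//     // Handle legalization failure.
//   }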

Status LegalizeToHlo(mlir::ModuleOp module_op, llvm::StringRef device_type,
                     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
                         custom_legalization_passes) {
  mlir::PassManager tf2xla(module_op.getContext());
  applyTensorflowAndCLOptions(tf2xla);
  CreateConvertMlirToXlaHloPipeline(tf2xla, device_type,
                                    custom_legalization_passes);

  if (VLOG_IS_ON(1))
    tensorflow::DumpMlirOpToFile("legalize_hlo_before", module_op);
  if (VLOG_IS_ON(2)) {
    // Print the whole module after each pass, which requires disabling
    // multi-threading as well.
    module_op.getContext()->disableMultithreading();
    tf2xla.enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>(
        /*print_module_scope=*/true));
  }

  // Make sure we catch any error reported by MLIR and forward it to the TF
  // error reporting system. Report a generic error if the pass manager failed
  // without emitting a diagnostic.
  mlir::StatusScopedDiagnosticHandler error_handler(module_op.getContext());

  if (failed(tf2xla.run(module_op))) {
    return error_handler.Combine(
        errors::InvalidArgument("TF to XLA legalization failed"));
  }

  if (VLOG_IS_ON(1))
    tensorflow::DumpMlirOpToFile("legalize_hlo_after", module_op);

  return Status::OK();
}

Status BuildHloFromTfInner(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
                           llvm::ArrayRef<xla::XlaOp> xla_params,
                           std::vector<xla::XlaOp>& returns,
                           llvm::StringRef device_type,
                           llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
                               custom_legalization_passes) {
  TF_RETURN_IF_ERROR(
      LegalizeToHlo(module_op, device_type, custom_legalization_passes));

  mlir::Block& block = module_op.lookupSymbol<mlir::FuncOp>("main").front();
  return mlir::BuildHloFromMlirHlo(block, builder, xla_params, returns);
}

Status ConvertMLIRToXlaComputation(
    mlir::ModuleOp module_op, llvm::StringRef device_type,
    xla::XlaComputation* xla_computation, bool use_tuple_args,
    bool return_tuple,
    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  TF_RETURN_IF_ERROR(
      LegalizeToHlo(module_op, device_type, custom_legalization_passes));

  xla::HloProto hlo_proto;
  TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(module_op, &hlo_proto,
                                               use_tuple_args, return_tuple,
                                               shape_representation_fn));
  *xla_computation = xla::XlaComputation(hlo_proto.hlo_module());
  return Status::OK();
}

Status CompileMlirSetup(
    mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
    XlaHelpers::ShapeRepresentationFn* shape_representation_fn) {
  // Use arg_shapes to improve the MLIR type information of `main` in
  // module_op.
  TF_RETURN_IF_ERROR(RefineShapes(arg_shapes, module_op));

  if (VLOG_IS_ON(2))
    tensorflow::DumpMlirOpToFile("compile_mlir_shape_refiner", module_op);

  if (!*shape_representation_fn)
    *shape_representation_fn = IdentityShapeRepresentationFn();

  return Status::OK();
}

Status BuildHloFromTf(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
                      llvm::ArrayRef<xla::XlaOp> xla_params,
                      std::vector<xla::XlaOp>& returns,
                      llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
                      llvm::StringRef device_type,
                      llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
                          custom_legalization_passes) {
  if (VLOG_IS_ON(2))
    tensorflow::DumpMlirOpToFile("build_hlo_tf_before", module_op);

  XlaHelpers::ShapeRepresentationFn shape_representation_fn;
  TF_RETURN_IF_ERROR(
      CompileMlirSetup(module_op, arg_shapes, &shape_representation_fn));

  // Convert the MLIR module to HLO, emitting ops into the provided XLA
  // builder.
  TF_RETURN_IF_ERROR(BuildHloFromTfInner(module_op, builder, xla_params,
                                         returns, device_type,
                                         custom_legalization_passes));

  if (VLOG_IS_ON(2))
    tensorflow::DumpMlirOpToFile("build_hlo_tf_after", module_op);

  return Status::OK();
}
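
// Minimal usage sketch (not part of the original source; `module` and
// `arg_shapes` are assumed to be in scope, with `module` an mlir::ModuleOp
// whose `main` takes a single tensor<2xf32> argument):
//
//   xla::XlaBuilder builder("tf_module");
//   std::vector<xla::XlaOp> xla_params = {xla::Parameter(
//       &builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {2}), "arg0")};
//   std::vector<xla::XlaOp> returns;
//   TF_RETURN_IF_ERROR(BuildHloFromTf(module, builder, xla_params, returns,
//                                     arg_shapes, DEVICE_TPU_XLA_JIT,
//                                     /*custom_legalization_passes=*/{}));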

Status PopulateResultIOInfo(
    mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
    bool use_tuple_args, bool use_resource_updates_for_aliases,
    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    XlaCompilationResult* compilation_result) {
  // Construct the mapping from the XlaComputation's args to the input edges of
  // the execute node.
  GetInputMappingForMlir(arg_shapes.size(), &compilation_result->input_mapping);

  // Compute all input shapes.
  TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes, use_tuple_args,
                                       shape_representation_fn,
                                       &compilation_result->xla_input_shapes));

  // Compute all output descriptions and resource writes.
  return GetOutputInfo(
      module_op, use_resource_updates_for_aliases, shape_representation_fn,
      &compilation_result->xla_output_shape, &compilation_result->outputs,
      &compilation_result->resource_updates);
}

Status CompileMlirToXlaHlo(
    mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
    llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple,
    bool use_resource_updates_for_aliases,
    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    XlaCompilationResult* compilation_result,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  TF_RETURN_IF_ERROR(
      CompileMlirSetup(module_op, arg_shapes, &shape_representation_fn));

  // Convert the MLIR module to an XLA HLO proto contained in an
  // XlaComputation.
  compilation_result->computation = std::make_shared<xla::XlaComputation>();
  TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
      module_op, device_type, compilation_result->computation.get(),
      use_tuple_args, use_return_tuple, shape_representation_fn,
      custom_legalization_passes));

  return PopulateResultIOInfo(module_op, arg_shapes, use_tuple_args,
                              use_resource_updates_for_aliases,
                              shape_representation_fn, compilation_result);
}

Status CompileSerializedMlirToXlaHlo(
    llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
    llvm::StringRef device_type, bool use_tuple_args,
    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    XlaCompilationResult* compilation_result,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  mlir::DialectRegistry mlir_registry;
  RegisterDialects(mlir_registry);
  mlir::MLIRContext mlir_context(mlir_registry);
  mlir::OwningModuleRef mlir_module;

  TF_RETURN_IF_ERROR(
      DeserializeMlirModule(mlir_module_string, &mlir_context, &mlir_module));
  llvm::SmallVector<TensorOrResourceShape, 4> tensor_or_resource_shapes;
  tensor_or_resource_shapes.reserve(arg_shapes.size());
  for (const auto& arg_shape : arg_shapes)
    tensor_or_resource_shapes.push_back({arg_shape});
  return CompileMlirToXlaHlo(
      mlir_module.get(), tensor_or_resource_shapes, device_type,
      use_tuple_args,
      /*use_return_tuple=*/true, /*use_resource_updates_for_aliases=*/false,
      shape_representation_fn, compilation_result, custom_legalization_passes);
}
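
// Minimal usage sketch (not part of the original source; `serialized` is
// assumed to hold a textual MLIR module whose `main` takes one tensor<2xf32>
// argument):
//
//   XlaCompilationResult result;
//   TF_RETURN_IF_ERROR(CompileSerializedMlirToXlaHlo(
//       serialized, /*arg_shapes=*/{TensorShape({2})},
//       /*device_type=*/DEVICE_TPU_XLA_JIT, /*use_tuple_args=*/true,
//       /*shape_representation_fn=*/{}, &result,
//       /*custom_legalization_passes=*/{}));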

// Rewrites the given module with the specified args. Each constant arg is
// inlined into the `main` function and the corresponding argument is removed
// from the signature. For resource args, their subtypes are populated.
// Returns the original indices of the remaining arguments on success.
static StatusOr<std::vector<int>> RewriteWithArgs(
    mlir::ModuleOp module_op, llvm::ArrayRef<XlaArgument> args) {
  mlir::FuncOp main_fn = module_op.lookupSymbol<mlir::FuncOp>("main");
  std::vector<int> params;

  bool has_resource_args = false;
  auto builder = mlir::OpBuilder(main_fn.getBody());
  std::vector<int> args_to_erase;
  for (int idx = 0; idx < args.size(); idx++) {
    const XlaArgument& xla_arg = args[idx];
    mlir::BlockArgument mlir_arg = main_fn.getArgument(idx);
    if (xla_arg.kind == XlaArgument::kResource) {
      mlir::Type element_type;
      if (xla_arg.type == DT_INVALID) {
        return errors::Unimplemented(absl::StrCat(
            "Argument ", idx,
            " is an uninitialized resource variable which is currently"
            " unsupported in the MLIR-based TPU bridge"));
      }
      TF_RETURN_IF_ERROR(
          ConvertDataType(xla_arg.type, builder, &element_type));
      TF_ASSIGN_OR_RETURN(TensorShape arg_shape,
                          GetTensorShapeFromXlaArgument(xla_arg));
      auto resource_shape = arg_shape.dim_sizes();
      llvm::SmallVector<int64_t, 4> resource_subtype_shape(
          resource_shape.begin(), resource_shape.end());
      auto resource_subtype =
          mlir::RankedTensorType::get(resource_subtype_shape, element_type);
      auto resource_type = mlir::TF::ResourceType::get({resource_subtype},
                                                       builder.getContext());

      auto tensor_type = mlir_arg.getType().cast<mlir::TensorType>();
      if (tensor_type.hasRank()) {
        mlir_arg.setType(mlir::RankedTensorType::get(tensor_type.getShape(),
                                                     resource_type));
      } else {
        mlir_arg.setType(mlir::UnrankedTensorType::get(resource_type));
      }
      has_resource_args = true;
    }
    if (xla_arg.kind != XlaArgument::kConstant) {
      params.push_back(idx);
      continue;
    }

    TF_ASSIGN_OR_RETURN(auto value_attr,
                        ConvertTensor(xla_arg.constant_value, &builder));
    // TODO(hinsu): Use the actual location of the constant.
    auto constant = builder.create<mlir::TF::ConstOp>(
        mlir::UnknownLoc::get(module_op.getContext()), value_attr);
    mlir_arg.replaceAllUsesWith(constant);
    args_to_erase.push_back(idx);
  }

  if (has_resource_args) {
    llvm::SmallVector<mlir::Type, 4> updated_argument_types;
    updated_argument_types.reserve(main_fn.getNumArguments());
    for (mlir::BlockArgument& arg : main_fn.getArguments())
      updated_argument_types.push_back(arg.getType());

    main_fn.setType(mlir::FunctionType::get(main_fn.getContext(),
                                            updated_argument_types,
                                            main_fn.getType().getResults()));
  }

  for (int idx : llvm::reverse(args_to_erase)) main_fn.eraseArgument(idx);

  return params;
}
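
// Illustrative example (not from the original source): given
//
//   func @main(%arg0: tensor<f32>, %arg1: tensor<f32>)
//
// where args[0] is a compile-time constant and args[1] is a plain parameter,
// RewriteWithArgs materializes args[0] as a tf.Const, erases %arg0 from the
// signature, and returns {1}, the original position of the surviving
// parameter.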

Status CompileGraphSetup(
    mlir::ModuleOp module_op, llvm::ArrayRef<XlaArgument> args,
    std::vector<int>* remaining_params,
    llvm::SmallVector<TensorOrResourceShape, 4>& arg_shapes) {
  TF_ASSIGN_OR_RETURN(*remaining_params, RewriteWithArgs(module_op, args));
  arg_shapes.reserve(remaining_params->size());
  for (unsigned idx : *remaining_params) {
    const auto& arg = args[idx];
    TF_ASSIGN_OR_RETURN(TensorShape arg_shape,
                        GetTensorShapeFromXlaArgument(arg));
    arg_shapes.push_back({arg_shape,
                          /*is_resource=*/arg.kind == XlaArgument::kResource});
  }

  mlir::PassManager pm(module_op.getContext());
  applyTensorflowAndCLOptions(pm);
  mlir::TF::StandardPipelineOptions tf_options;
  mlir::TF::CreateTFStandardPipeline(pm, tf_options);

  if (VLOG_IS_ON(1))
    tensorflow::DumpMlirOpToFile("compile_graph_setup_before", module_op);
  mlir::StatusScopedDiagnosticHandler diag_handler(module_op.getContext());
  if (failed(pm.run(module_op))) return diag_handler.ConsumeStatus();
  if (VLOG_IS_ON(1))
    tensorflow::DumpMlirOpToFile("compile_graph_setup_after", module_op);

  return Status::OK();
}

Status BuildHloFromModule(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
                          llvm::ArrayRef<xla::XlaOp> xla_params,
                          std::vector<xla::XlaOp>& returns,
                          llvm::ArrayRef<XlaArgument> args,
                          llvm::StringRef device_type,
                          llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
                              custom_legalization_passes) {
  std::vector<int> remaining_params;
  llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
  TF_RETURN_IF_ERROR(
      CompileGraphSetup(module_op, args, &remaining_params, arg_shapes));
  return BuildHloFromTf(module_op, builder, xla_params, returns, arg_shapes,
                        device_type, custom_legalization_passes);
}

Status CompileGraphToXlaHlo(
    mlir::ModuleOp module_op, llvm::ArrayRef<XlaArgument> args,
    llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple,
    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    XlaCompilationResult* compilation_result,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  std::vector<int> remaining_params;
  llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
  TF_RETURN_IF_ERROR(
      CompileGraphSetup(module_op, args, &remaining_params, arg_shapes));

  auto status = CompileMlirToXlaHlo(
      module_op, arg_shapes, device_type, use_tuple_args, use_return_tuple,
      /*use_resource_updates_for_aliases=*/true, shape_representation_fn,
      compilation_result, custom_legalization_passes);
  compilation_result->input_mapping = remaining_params;
  return status;
}

Status GraphToModule(const Graph& graph,
                     llvm::ArrayRef<std::string> control_rets,
                     const FunctionLibraryDefinition& flib_def,
                     const GraphDebugInfo& debug_info,
                     mlir::MLIRContext* context,
                     mlir::OwningModuleRef* module) {
  mlir::DialectRegistry registry;
  RegisterDialects(registry);
  context->appendDialectRegistry(registry);
  GraphImportConfig config;
  config.graph_as_function = true;
  config.control_outputs = control_rets;
  // Disable shape inference during import, as some TensorFlow ops fail during
  // shape inference with dynamically shaped operands, which in turn causes the
  // import to fail. Shape inference during import is slated for removal, and
  // since the shape inference pass runs early in the pass pipeline, shape
  // inference during import is not necessary.
  config.enable_shape_inference = false;
  auto module_or =
      ConvertGraphToMlir(graph, debug_info, flib_def, config, context);
  if (!module_or.ok()) return module_or.status();

  *module = std::move(module_or.ValueOrDie());

  return Status::OK();
}

Status BuildHloFromGraph(const Graph& graph, xla::XlaBuilder& builder,
                         llvm::ArrayRef<xla::XlaOp> xla_params,
                         std::vector<xla::XlaOp>& returns,
                         llvm::ArrayRef<XlaArgument> args,
                         llvm::ArrayRef<std::string> control_rets,
                         llvm::StringRef device_type,
                         const FunctionLibraryDefinition& flib_def,
                         const GraphDebugInfo& debug_info,
                         llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
                             custom_legalization_passes) {
  mlir::MLIRContext context;
  mlir::OwningModuleRef module;
  TF_RETURN_IF_ERROR(GraphToModule(graph, control_rets, flib_def, debug_info,
                                   &context, &module));
  return BuildHloFromModule(module.get(), builder, xla_params, returns, args,
                            device_type, custom_legalization_passes);
}

Status CompileGraphToXlaHlo(
    const Graph& graph, llvm::ArrayRef<XlaArgument> args,
    llvm::ArrayRef<std::string> control_rets, llvm::StringRef device_type,
    bool use_tuple_args, const FunctionLibraryDefinition& flib_def,
    const GraphDebugInfo& debug_info,
    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    XlaCompilationResult* compilation_result,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  mlir::MLIRContext context;
  mlir::OwningModuleRef module;
  TF_RETURN_IF_ERROR(GraphToModule(graph, control_rets, flib_def, debug_info,
                                   &context, &module));
  return CompileGraphToXlaHlo(module.get(), args, device_type, use_tuple_args,
                              /*use_return_tuple=*/true,
                              shape_representation_fn, compilation_result,
                              custom_legalization_passes);
}

}  // namespace tensorflow