1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
17 
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/ADT/DenseSet.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/StringRef.h"
22 #include "llvm/ADT/Twine.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/raw_ostream.h"
25 #include "mlir/IR/Attributes.h"  // from @llvm-project
26 #include "mlir/IR/Builders.h"  // from @llvm-project
27 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
28 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
29 #include "mlir/IR/Identifier.h"  // from @llvm-project
30 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
31 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
32 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
33 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
34 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
36 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
37 
38 namespace mlir {
39 namespace tf_saved_model {
40 
41 //===----------------------------------------------------------------------===//
42 // Utilities
43 //===----------------------------------------------------------------------===//
44 
IsStrArrayAttr(Attribute attr)45 static bool IsStrArrayAttr(Attribute attr) {
46   auto array = attr.dyn_cast<ArrayAttr>();
47   if (!array) return false;
48 
49   return llvm::all_of(array,
50                       [](Attribute attr) { return attr.isa<StringAttr>(); });
51 }
52 
53 //===----------------------------------------------------------------------===//
54 // TensorFlowSavedModelDialect Op's
55 //===----------------------------------------------------------------------===//
56 
VerifyTensorTypesCompatible(Type t1,Type t2)57 LogicalResult VerifyTensorTypesCompatible(Type t1, Type t2) {
58   if (!t1.isa<TensorType>() || !t2.isa<TensorType>()) {
59     return failure();
60   }
61   return verifyCompatibleShape(t1.cast<TensorType>(), t2.cast<TensorType>());
62 }
63 
Verify(GlobalTensorOp global_tensor)64 static LogicalResult Verify(GlobalTensorOp global_tensor) {
65   if (failed(VerifyTensorTypesCompatible(
66           global_tensor.type(), global_tensor.value().Attribute::getType()))) {
67     return global_tensor.emitError() << "'type' and 'value' attributes should "
68                                         "have compatible tensor types";
69   }
70   if (!global_tensor.is_mutable()) {
71     if (!global_tensor.type().cast<TensorType>().hasStaticShape()) {
72       return global_tensor.emitError()
73              << "'type' attribute for immutable 'tf_saved_model.global_tensor' "
74                 "should have a static shape";
75     }
76   }
77   return success();
78 }
79 
Verify(SessionInitializerOp session_initializer)80 static LogicalResult Verify(SessionInitializerOp session_initializer) {
81   mlir::SymbolTable symbol_table(
82       session_initializer->getParentOfType<ModuleOp>());
83 
84   for (auto sym_ref : session_initializer.initializers()) {
85     auto init_func_op = symbol_table.lookup<mlir::FuncOp>(
86         sym_ref.cast<FlatSymbolRefAttr>().getValue());
87 
88     if (!init_func_op)
89       return session_initializer.emitOpError()
90              << "the initializer function does not exist";
91 
92     if (!init_func_op.getType().getResults().empty())
93       return session_initializer.emitOpError()
94              << "the initializer function should have no output";
95 
96     auto exported_names = GetExportedNames(init_func_op);
97 
98     if (exported_names.empty())
99       return session_initializer.emitOpError()
100              << "the initializer function should be exported";
101 
102     if (exported_names.size() != 1)
103       return session_initializer.emitOpError()
104              << "the initializer function should have only one exported names";
105   }
106 
107   return success();
108 }
109 
110 }  // namespace tf_saved_model
111 }  // namespace mlir
112 
113 #define GET_OP_CLASSES
114 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc"
115 
116 namespace mlir {
117 namespace tf_saved_model {
118 
119 //===----------------------------------------------------------------------===//
120 // TensorFlowSavedModelDialect Dialect
121 //===----------------------------------------------------------------------===//
122 
// Constructs the dialect under the "tf_saved_model" namespace, makes sure the
// TensorFlow dialect is loaded into `context`, and registers this dialect's
// operations.
TensorFlowSavedModelDialect::TensorFlowSavedModelDialect(MLIRContext *context)
    : Dialect(/*name=*/"tf_saved_model", context,
              TypeID::get<TensorFlowSavedModelDialect>()) {
  // The TensorFlow Dialect is needed in the verifier and other routines
  // associated to this dialect. It makes little sense anyway to use the
  // SavedModel dialect without the TensorFlow Dialect.
  context->loadDialect<TF::TensorFlowDialect>();

  // With GET_OP_LIST defined, the included .inc file supplies the template
  // argument list for addOperations.
  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc"
      >();
}
136 
VerifyIndexPath(Operation * op,NamedAttribute named_attr)137 static LogicalResult VerifyIndexPath(Operation *op, NamedAttribute named_attr) {
138   auto attr = named_attr.second.dyn_cast<ArrayAttr>();
139   if (!attr) {
140     return op->emitError()
141            << "'tf_saved_model.index_path' attribute should be an ArrayAttr";
142   }
143   for (auto element : attr) {
144     if (element.isa<StringAttr>()) {
145       continue;
146     }
147     if (auto integer = element.dyn_cast<IntegerAttr>()) {
148       if (integer.getValue().getBitWidth() == 64) {
149         continue;
150       }
151     }
152     return op->emitError() << "'tf_saved_model.index_path' elements should "
153                               "be strings or 64-bit integers";
154   }
155   return mlir::success();
156 }
157 
GetBoundInputArgTypeFor(mlir::Operation * op)158 Type GetBoundInputArgTypeFor(mlir::Operation *op) {
159   if (auto global_tensor = llvm::dyn_cast<GlobalTensorOp>(op)) {
160     auto type = global_tensor.type().cast<TensorType>();
161     return RankedTensorType::get(
162         {}, TF::ResourceType::get({type}, type.getContext()));
163   }
164 
165   if (auto asset = llvm::dyn_cast<AssetOp>(op)) {
166     return RankedTensorType::get({}, TF::StringType::get(asset.getContext()));
167   }
168 
169   op->emitError() << "unknown symbol operation";
170   return {};
171 }
172 
VerifyBoundInputArgType(Operation * op_for_diagnostics,Type arg_type,mlir::Operation * symbol_op)173 static LogicalResult VerifyBoundInputArgType(Operation *op_for_diagnostics,
174                                              Type arg_type,
175                                              mlir::Operation *symbol_op) {
176   auto expected_type = GetBoundInputArgTypeFor(symbol_op);
177   if (!expected_type) return failure();
178 
179   if (arg_type != expected_type) {
180     return op_for_diagnostics->emitError()
181            << "bound input with type " << arg_type << " expected to have type "
182            << expected_type;
183   }
184   return success();
185 }
186 
verifyRegionArgAttribute(Operation * op,unsigned region_index,unsigned arg_index,NamedAttribute named_attr)187 LogicalResult TensorFlowSavedModelDialect::verifyRegionArgAttribute(
188     Operation *op, unsigned region_index, unsigned arg_index,
189     NamedAttribute named_attr) {
190   if (named_attr.first == "tf_saved_model.bound_input") {
191     if (!named_attr.second.isa<FlatSymbolRefAttr>()) {
192       return op->emitError() << "'tf_saved_model.bound_input' attribute should "
193                                 "be a FlatSymbolRefAttr";
194     }
195     auto symbol_name = named_attr.second.cast<FlatSymbolRefAttr>().getValue();
196     auto module = op->getParentOfType<ModuleOp>();
197     mlir::Operation *symbol_op = module.lookupSymbol(symbol_name);
198     if (!symbol_op) {
199       return op->emitError() << "'tf_saved_model.bound_input' attribute must "
200                                 "reference a valid symbol, got invalid symbol '"
201                              << symbol_name << "'";
202     }
203     auto arg_type = cast<FuncOp>(op).getArgument(arg_index).getType();
204     return VerifyBoundInputArgType(op, arg_type, symbol_op);
205   }
206   if (named_attr.first == "tf_saved_model.index_path") {
207     return VerifyIndexPath(op, named_attr);
208   }
209 
210   return op->emitError() << "unknown tf_saved_model dialect arg attribute '"
211                          << named_attr.first << "'";
212 }
213 
verifyRegionResultAttribute(Operation * op,unsigned region_index,unsigned result_index,NamedAttribute named_attr)214 LogicalResult TensorFlowSavedModelDialect::verifyRegionResultAttribute(
215     Operation *op, unsigned region_index, unsigned result_index,
216     NamedAttribute named_attr) {
217   if (named_attr.first == "tf_saved_model.index_path") {
218     return VerifyIndexPath(op, named_attr);
219   }
220 
221   return op->emitError() << "unknown tf_saved_model dialect result attribute '"
222                          << named_attr.first << "'";
223 }
224 
HasAnyTfSavedModelArgAttr(FuncOp func)225 static bool HasAnyTfSavedModelArgAttr(FuncOp func) {
226   for (int i = 0, e = func.getNumArguments(); i < e; i++) {
227     if (func.getArgAttr(i, "tf_saved_model.index_path") ||
228         func.getArgAttr(i, "tf_saved_model.bound_input")) {
229       return true;
230     }
231   }
232   for (int i = 0, e = func.getNumResults(); i < e; i++) {
233     if (func.getResultAttr(i, "tf_saved_model.index_path") ||
234         func.getResultAttr(i, "tf_saved_model.bound_input")) {
235       return true;
236     }
237   }
238   return false;
239 }
240 
VerifySavedModelModule(ModuleOp module,TensorFlowSavedModelDialect * dialect)241 static LogicalResult VerifySavedModelModule(
242     ModuleOp module, TensorFlowSavedModelDialect *dialect) {
243   auto exported_names_ident =
244       Identifier::get("tf_saved_model.exported_names", dialect->getContext());
245   // Check that there are no duplicated exported_names.
246   DenseMap<StringRef, Operation *> exported_name_to_op;
247   for (auto &op : module) {
248     auto attr = op.getAttr(exported_names_ident);
249     if (!attr) continue;
250     // If this verifier is called before we verify the
251     // 'tf_saved_model.exported_names' attribute, then it might be invalid.
252     // Forward to the dialect's verification to establish that precondition.
253     if (failed(dialect->verifyOperationAttribute(
254             &op, {exported_names_ident, attr}))) {
255       return failure();
256     }
257     for (auto str : attr.cast<ArrayAttr>()) {
258       auto exported_name = str.cast<StringAttr>().getValue();
259       auto p = exported_name_to_op.insert({exported_name, &op});
260       if (!p.second) {
261         return op.emitError()
262             .append("duplicate exported name '", exported_name, "'")
263             .attachNote(p.first->getSecond()->getLoc())
264             .append("previously seen here");
265       }
266     }
267   }
268   for (auto func : module.getOps<FuncOp>()) {
269     const bool is_exported = IsExported(func);
270 
271     if (is_exported && !func.isPublic()) {
272       return func.emitError()
273              << "exported function @" << func.getName() << " should be public";
274     }
275 
276     if (!is_exported && func.isPublic()) {
277       return func.emitError() << "non-exported function @" << func.getName()
278                               << " should be private";
279     }
280 
281     if (!is_exported && HasAnyTfSavedModelArgAttr(func)) {
282       return func.emitError() << "can only apply 'tf_saved_model' argument "
283                                  "attributes to exported functions";
284     }
285   }
286 
287   auto session_initializers = module.getOps<SessionInitializerOp>();
288   if (!session_initializers.empty() &&
289       !llvm::hasSingleElement(session_initializers)) {
290     return (*++session_initializers.begin()).emitError()
291            << "there must be no more than one session_initializer op";
292   }
293 
294   auto is_init = [&session_initializers](mlir::FuncOp func) {
295     if (session_initializers.empty()) return false;
296     auto init_syms = (*session_initializers.begin()).initializers();
297     return std::any_of(
298         init_syms.begin(), init_syms.end(), [&](Attribute sym_ref) {
299           return sym_ref.cast<FlatSymbolRefAttr>().getValue() == func.getName();
300         });
301   };
302 
303   SymbolTable symbol_table(module);
304   auto symbol_uses = SymbolTable::getSymbolUses(&module.getBodyRegion());
305   if (!symbol_uses.hasValue()) {
306     return module.emitError() << "modules with 'tf_saved_model.semantics' must "
307                                  "have analyzable symbol uses";
308   }
309   for (auto symbol_use : *symbol_uses) {
310     auto func = symbol_table.lookup<FuncOp>(
311         symbol_use.getSymbolRef().cast<FlatSymbolRefAttr>().getValue());
312     if (func && IsExported(func)) {
313       // If it is an init function, then it can be used by the unique
314       // session_initializer op.
315       if (is_init(func) &&
316           llvm::isa<SessionInitializerOp>(symbol_use.getUser()))
317         continue;
318 
319       return symbol_use.getUser()
320           ->emitError("exported function cannot be internally referenced")
321           .attachNote(func.getLoc())
322           .append("references this exported function");
323     }
324   }
325   return success();
326 }
327 
VerifyExportedFunc(FuncOp func)328 LogicalResult VerifyExportedFunc(FuncOp func) {
329   bool reached_bound_inputs = false;
330   auto module = func->getParentOfType<ModuleOp>();
331   for (int i = 0, e = func.getNumArguments(); i < e; i++) {
332     if (func.getArgAttr(i, "tf_saved_model.bound_input")) {
333       reached_bound_inputs = true;
334       continue;
335     }
336     if (func.getArgAttr(i, "tf_saved_model.index_path")) {
337       if (reached_bound_inputs) {
338         return func.emitError()
339                << "all 'tf_saved_model.index_path' arg attributes should "
340                   "precede all 'tf_saved_model.bound_input' arg attributes";
341       }
342       continue;
343     }
344     if (func.getArgAttr(i, "tf.resource_name")) {
345       if (module->getAttr("tf_saved_model.under_construction")) continue;
346       return func.emitError() << "'tf.resource_name' attribute is not allowed "
347                                  "unless it is being under construction";
348     }
349     return func.emitError()
350            << "all arguments should have 'tf_saved_model.index_path', "
351               "'tf_saved_model.bound_input' or 'tf.resource_name' attributes";
352   }
353   llvm::SmallDenseSet<StringRef, 8> unique_bound_inputs;
354   for (int i = 0, e = func.getNumArguments(); i < e; i++) {
355     if (auto attr = func.getArgAttrOfType<FlatSymbolRefAttr>(
356             i, "tf_saved_model.bound_input")) {
357       if (!unique_bound_inputs.insert(attr.getValue()).second) {
358         if (module->getAttr("tf_saved_model.under_construction")) continue;
359         return func.emitError()
360                << "duplicate 'tf_saved_model.bound_input' binding";
361       }
362     }
363   }
364 
365   for (int i = 0, e = func.getNumResults(); i < e; i++) {
366     if (!func.getResultAttr(i, "tf_saved_model.index_path")) {
367       return func.emitError() << "all results should have "
368                                  "'tf_saved_model.index_path' attributes";
369     }
370   }
371 
372   return success();
373 }
374 
verifyOperationAttribute(Operation * op,NamedAttribute named_attr)375 LogicalResult TensorFlowSavedModelDialect::verifyOperationAttribute(
376     Operation *op, NamedAttribute named_attr) {
377   if (named_attr.first == "tf_saved_model.exported_names") {
378     if (!isa<FuncOp, GlobalTensorOp>(op)) {
379       return op->emitError() << "'tf_saved_model.exported_names' must be on a "
380                                 "'func' or 'tf_saved_model.global_tensor' op";
381     }
382     if (!IsStrArrayAttr(named_attr.second)) {
383       return op->emitError()
384              << "'tf_saved_model.exported_names' must be an array of strings";
385     }
386     if (!op->getParentOp()->getAttr("tf_saved_model.semantics")) {
387       return op->emitError()
388              << "'tf_saved_model.exported_names' must be on an op "
389                 "whose immediate parent has attribute "
390                 "'tf_saved_model.semantics'";
391     }
392     if (auto func = dyn_cast<FuncOp>(op)) {
393       if (failed(VerifyExportedFunc(func))) {
394         return failure();
395       }
396     }
397     return success();
398   }
399   if (named_attr.first == "tf_saved_model.semantics") {
400     auto module = dyn_cast<ModuleOp>(op);
401     if (!module) {
402       return op->emitError() << "'tf_saved_model.semantics' must "
403                                 "be on a module op";
404     }
405     return VerifySavedModelModule(module, this);
406   }
407   if (named_attr.first == "tf_saved_model.under_construction") {
408     return success();
409   }
410 
411   return op->emitError() << "unknown tf_saved_model dialect attribute '"
412                          << named_attr.first << "'";
413 }
414 
GetExportedNames(Operation * op)415 SmallVector<StringRef, 2> GetExportedNames(Operation *op) {
416   SmallVector<StringRef, 2> ret;
417   auto exported_names =
418       op->getAttrOfType<ArrayAttr>("tf_saved_model.exported_names");
419   if (exported_names) {
420     for (auto name : exported_names) {
421       ret.push_back(name.cast<StringAttr>().getValue());
422     }
423   }
424   return ret;
425 }
426 
IsExported(Operation * op)427 bool IsExported(Operation *op) {
428   auto exported_names =
429       op->getAttrOfType<ArrayAttr>("tf_saved_model.exported_names");
430   return exported_names && !exported_names.empty();
431 }
432 
HasTfSavedModelSemantics(ModuleOp module)433 bool HasTfSavedModelSemantics(ModuleOp module) {
434   return module->getAttr("tf_saved_model.semantics") != nullptr;
435 }
436 
LookupBoundInput(FuncOp func,int arg_index,const SymbolTable & symbol_table)437 Operation *LookupBoundInput(FuncOp func, int arg_index,
438                             const SymbolTable &symbol_table) {
439   auto attr = func.getArgAttrOfType<FlatSymbolRefAttr>(
440       arg_index, "tf_saved_model.bound_input");
441   if (!attr) return nullptr;
442   return symbol_table.lookup(attr.getValue());
443 }
444 
GetSessionInitializerOp(mlir::ModuleOp op)445 SessionInitializerOp GetSessionInitializerOp(mlir::ModuleOp op) {
446   auto initializers = op.getOps<SessionInitializerOp>();
447   if (initializers.empty()) return {};
448   return *initializers.begin();
449 }
450 
451 class OptimizeSessionInitializerPattern
452     : public OpRewritePattern<SessionInitializerOp> {
453  public:
454   using OpRewritePattern::OpRewritePattern;
455 
matchAndRewrite(SessionInitializerOp op,PatternRewriter & rewriter) const456   LogicalResult matchAndRewrite(SessionInitializerOp op,
457                                 PatternRewriter &rewriter) const override {
458     SymbolTable symbol_table(op->getParentOfType<ModuleOp>());
459 
460     SmallVector<FuncOp, 2> to_remove;
461     SmallVector<mlir::Attribute, 2> to_keep;
462     for (auto sym_ref : op.initializers()) {
463       auto init_func_op = symbol_table.lookup<mlir::FuncOp>(
464           sym_ref.cast<FlatSymbolRefAttr>().getValue());
465 
466       // The init function can only be referenced from the SessionInitializerOp.
467       // And there is at most one SessionInitializerOp in the module. So if both
468       // ops have no other uses or have one NoOp only, they can be simply
469       // erased.
470       auto &operations = init_func_op.front().getOperations();
471       if ((operations.size() == 1 &&
472            operations.front().hasTrait<OpTrait::IsTerminator>()) ||
473           (operations.size() == 2 &&
474            dyn_cast<mlir::TF::NoOp>(operations.front()) &&
475            operations.back().hasTrait<OpTrait::IsTerminator>())) {
476         to_remove.push_back(init_func_op);
477       } else {
478         to_keep.push_back(sym_ref);
479       }
480     }
481 
482     for (auto func_op : to_remove) rewriter.eraseOp(func_op);
483 
484     if (to_keep.empty())
485       rewriter.eraseOp(op);
486     else
487       op->setAttr("initializers", rewriter.getArrayAttr(to_keep));
488 
489     return success();
490   }
491 };
492 
// Registers the canonicalization patterns of SessionInitializerOp: adds
// OptimizeSessionInitializerPattern to `results`.
void SessionInitializerOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<OptimizeSessionInitializerPattern>(context);
}
497 
GetSessionInitializerExportedName(ModuleOp op)498 SmallVector<StringRef, 2> GetSessionInitializerExportedName(ModuleOp op) {
499   auto session_initializer_op = GetSessionInitializerOp(op);
500   if (!session_initializer_op) return {};
501 
502   SymbolTable symbol_table(op);
503 
504   SmallVector<StringRef, 2> results;
505   for (auto sym_ref : session_initializer_op.initializers()) {
506     auto init_func_op = symbol_table.lookup<mlir::FuncOp>(
507         sym_ref.cast<FlatSymbolRefAttr>().getValue());
508     auto exported_names = GetExportedNames(init_func_op);
509     assert(exported_names.size() == 1);
510     results.push_back(exported_names[0]);
511   }
512 
513   return results;
514 }
515 
516 }  // namespace tf_saved_model
517 }  // namespace mlir
518