1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
17 
18 #include <algorithm>
19 #include <iterator>
20 
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/Sequence.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/ADT/StringSwitch.h"
27 #include "llvm/Support/Casting.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
30 #include "mlir/Dialect/Traits.h"  // from @llvm-project
31 #include "mlir/IR/Attributes.h"  // from @llvm-project
32 #include "mlir/IR/Builders.h"  // from @llvm-project
33 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
34 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
35 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
36 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
37 #include "mlir/IR/Matchers.h"  // from @llvm-project
38 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
39 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
40 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
41 #include "mlir/IR/Types.h"  // from @llvm-project
42 #include "mlir/IR/Value.h"  // from @llvm-project
43 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
44 #include "mlir/Transforms/FoldUtils.h"  // from @llvm-project
45 #include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
46 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
47 
48 namespace mlir {
49 namespace tf_executor {
50 
51 //===----------------------------------------------------------------------===//
52 // TF Executor Dialect
53 //===----------------------------------------------------------------------===//
54 
55 namespace {
56 
57 struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface {
58   using DialectInlinerInterface::DialectInlinerInterface;
59 
60   //===--------------------------------------------------------------------===//
61   // Analysis Hooks
62   //===--------------------------------------------------------------------===//
63 
64   // Allow all call operations to be inlined.
isLegalToInlinemlir::tf_executor::__anonecac450a0111::TensorFlowExecutorInlinerInterface65   bool isLegalToInline(Operation *call, Operation *callable,
66                        bool wouldBeCloned) const final {
67     return true;
68   }
69   // Override the inlining hook to determine if 'src' can be inlined into
70   // 'dest'.
isLegalToInlinemlir::tf_executor::__anonecac450a0111::TensorFlowExecutorInlinerInterface71   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
72                        BlockAndValueMapping &value_mapping) const final {
73     // Allow inlining into tf.island regions if the incoming region has a single
74     // block.
75     return llvm::isa<tf_executor::IslandOp>(dest->getParentOp()) &&
76            llvm::hasSingleElement(*src);
77   }
78 };
79 
80 struct TensorFlowExecutorDialectFoldInterface : public DialectFoldInterface {
81   using DialectFoldInterface::DialectFoldInterface;
82 
83   // Registered hook to check if the given region, which is attached to an
84   // operation that is *not* isolated from above (i.e. no internal regions
85   // reference values defined in an enclosing region), should be used when
86   // materializing constants.
87   // In the executor dialect we materialize inside an island.
shouldMaterializeIntomlir::tf_executor::__anonecac450a0111::TensorFlowExecutorDialectFoldInterface88   bool shouldMaterializeInto(Region *region) const final {
89     return isa<tf_executor::IslandOp>(region->getParentOp());
90   }
91 };
92 
93 }  // namespace
94 
// Builds the dialect named "tf_executor" for `context` and adds its
// operations, its dialect interfaces and its types.
TensorFlowExecutorDialect::TensorFlowExecutorDialect(MLIRContext *context)
    : Dialect(/*name=*/"tf_executor", context,
              TypeID::get<TensorFlowExecutorDialect>()) {
  // The template argument list is supplied by the included .inc file, with
  // GET_OP_LIST selecting the list of operations.
  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc.inc"
      >();

  // Inlining and constant-materialization hooks declared earlier in this file.
  addInterfaces<TensorFlowExecutorInlinerInterface,
                TensorFlowExecutorDialectFoldInterface>();

  // The two types that parseType()/printType() know how to handle.
  addTypes<ControlType, TokenType>();
}
108 
parseType(DialectAsmParser & parser) const109 Type TensorFlowExecutorDialect::parseType(DialectAsmParser &parser) const {
110   StringRef data_type;
111   if (parser.parseKeyword(&data_type)) return Type();
112 
113   if (data_type == "control") return ControlType::get(getContext());
114   if (data_type == "token") return TokenType::get(getContext());
115   parser.emitError(parser.getNameLoc())
116       << "unknown tf_executor type: " << data_type;
117   return nullptr;
118 }
119 
printType(Type type,DialectAsmPrinter & os) const120 void TensorFlowExecutorDialect::printType(Type type,
121                                           DialectAsmPrinter &os) const {
122   if (type.isa<ControlType>()) {
123     os << "control";
124     return;
125   }
126   if (type.isa<TokenType>()) {
127     os << "token";
128     return;
129   }
130   os << "<unknown tf_executor type>";
131 }
132 
133 //===----------------------------------------------------------------------===//
134 // Implementation for all the operations defined in ODS (op definition spec).
135 //===----------------------------------------------------------------------===//
136 
137 namespace {
138 
139 // Verifies that every control operands are at the end of the list.
140 // Used by the constraint `ControlOperandsAfterAllData` in ODS.
VerifyControlOperandsAfterAllData(Operation * op)141 LogicalResult VerifyControlOperandsAfterAllData(Operation *op) {
142   bool found_control = false;
143   for (int operand_idx : llvm::seq<int>(0, op->getNumOperands())) {
144     if (op->getOperand(operand_idx).getType().isa<ControlType>()) {
145       found_control = true;
146       continue;
147     }
148     if (found_control)
149       return op->emitOpError() << "found non-control operand #" << operand_idx
150                                << " after control operand";
151   }
152   return success();
153 }
154 
155 }  // anonymous namespace
156 
157 //===----------------------------------------------------------------------===//
158 // tf_executor.graph
159 //===----------------------------------------------------------------------===//
160 
GetFetch()161 FetchOp GraphOp::GetFetch() { return llvm::cast<FetchOp>(GetBody().back()); }
162 
163 namespace {
164 
Verify(GraphOp graph)165 LogicalResult Verify(GraphOp graph) {
166   auto *executorDialect = graph->getDialect();
167 
168   if (graph.GetBody().empty())
169     return graph.emitOpError() << "expects a non-empty body";
170 
171   // Only tf_executor dialect operations are allowed to be immediately nested
172   // in a tf_executor.graph region.
173   for (Operation &op : graph.GetBody()) {
174     if (op.getDialect() != executorDialect)
175       return op.emitOpError() << "unallowed inside a tf_executor.graph region";
176     if (isa<GraphOp>(op))
177       return op.emitOpError()
178              << "unallowed directly inside another tf_executor.graph";
179   }
180 
181   Operation &fetch = graph.GetBody().back();
182   if (!isa<FetchOp>(fetch))
183     return fetch.emitOpError()
184            << "invalid tf_executor.graph terminator, fetch expected";
185 
186   // Ensure that the fetch terminator operands matches the graph result type.
187   // All the non-control operands of the fetch operation must match the graph
188   // returned value.
189   if (fetch.getNumOperands() < graph.getNumResults())
190     return fetch.emitOpError() << "does not have enough operands to cover the "
191                                   "graph returned values";
192   for (int i : llvm::seq<int>(0, fetch.getNumOperands())) {
193     Value operand = fetch.getOperand(i);
194     // Break out of the loop at the first control operand encountered.
195     const int64_t num_results = graph.getNumResults();
196     if (operand.getType().isa<ControlType>()) {
197       if (i != num_results)
198         return fetch.emitOpError()
199                << "operand #" << i
200                << " is a control type, can't be bound to a graph result";
201       break;
202     }
203     if (i >= num_results)
204       return fetch.emitOpError()
205              << "operand #" << i << " does not have a graph results to bind";
206     if (graph.getResult(i).getType() != operand.getType())
207       return fetch.emitOpError()
208              << "operand #" << i << " type mismatch graph results";
209   }
210   return success();
211 }
212 
Print(GraphOp graph,OpAsmPrinter & p)213 void Print(GraphOp graph, OpAsmPrinter &p) {
214   p << graph.getOperationName();
215   p.printRegion(graph.getOperation()->getRegion(0));
216   p.printOptionalAttrDict(graph.getAttrs());
217 }
218 
ParseGraphOp(OpAsmParser & parser,OperationState & result)219 ParseResult ParseGraphOp(OpAsmParser &parser, OperationState &result) {
220   llvm::SMLoc loc = parser.getCurrentLocation();
221 
222   // Parse the body region.
223   Region &body = *result.addRegion();
224   if (parser.parseRegion(body, llvm::None, llvm::None)) return failure();
225 
226   // Ensure that the region is well formed: it contains at least a block with
227   // a FetchOp terminator.
228   GraphOp::ensureTerminator(body, parser.getBuilder(), result.location);
229 
230   if (!llvm::hasSingleElement(body))
231     return parser.emitError(loc) << "expects a single block region";
232 
233   // Get the results type from the terminator type inside the graph.
234   Operation &fetch = body.back().back();
235   if (!isa<FetchOp>(fetch))
236     return parser.emitError(loc) << "expects a tf_executor.fetch terminator";
237 
238   // The return value of the graph operation are the non-control operands of
239   // the fetch operation.
240   result.types.reserve(fetch.getNumOperands());
241   for (Type type : fetch.getOperandTypes()) {
242     if (type.isa<ControlType>()) break;
243     result.types.push_back(type);
244   }
245 
246   // Parse the optional attribute list.
247   if (parser.parseOptionalAttrDict(result.attributes)) return failure();
248 
249   return success();
250 }
251 
252 }  // anonymous namespace
253 
254 //===----------------------------------------------------------------------===//
255 // tf_executor.fetch
256 //===----------------------------------------------------------------------===//
257 
258 //===----------------------------------------------------------------------===//
259 // tf_executor.island
260 //===----------------------------------------------------------------------===//
261 
GetYield()262 YieldOp IslandOp::GetYield() { return llvm::cast<YieldOp>(GetBody().back()); }
263 
264 // Checks if a tf_executor.island wraps a single operation and the single
265 // operation results are perfectly forwarded to the islands yield.
WrapsSingleOp()266 bool IslandOp::WrapsSingleOp() {
267   auto body = GetBody().without_terminator();
268   if (!hasSingleElement(body)) return false;
269 
270   Operation &wrapped_op = *body.begin();
271   YieldOp yield = GetYield();
272   return wrapped_op.getNumResults() == yield.getNumOperands() &&
273          std::equal(wrapped_op.getResults().begin(),
274                     wrapped_op.getResults().end(), yield.getOperands().begin());
275 }
276 
277 namespace {
278 
Verify(IslandOp island)279 LogicalResult Verify(IslandOp island) {
280   if (!island.GetBody().args_empty())
281     return island.emitOpError() << "expects body without any arguments";
282 
283   Operation &yield = island.GetBody().back();
284   if (!isa<YieldOp>(yield))
285     return yield.emitOpError()
286            << "invalid tf_executor.island terminator, yield expected";
287 
288   // Ensure that the yield terminator operands matches the island results type.
289   int result_count = island.getNumResults() - 1;  // -1 for the control token
290   const int num_operands = yield.getNumOperands();
291   if (num_operands != result_count)
292     return yield.emitOpError()
293            << "has " << yield.getNumOperands()
294            << " operand, but island returns " << result_count;
295   for (int operand_idx : llvm::seq<int>(0, yield.getNumOperands())) {
296     if (island.getResult(operand_idx).getType() !=
297         yield.getOperand(operand_idx).getType())
298       return yield.emitOpError()
299              << "operand #" << operand_idx << " type mismatch island results";
300   }
301 
302   // Check that there aren't any control results other than the last one.
303   Type control_type = ControlType::get(island.getContext());
304   for (int operand_idx : llvm::seq<int>(0, island.getNumResults() - 1)) {
305     if (island.getResult(operand_idx).getType() == control_type)
306       return yield.emitOpError()
307              << "unexpected control type for operand #" << operand_idx;
308   }
309   return success();
310 }
311 
Print(IslandOp op,OpAsmPrinter & p)312 void Print(IslandOp op, OpAsmPrinter &p) {
313   p << op.getOperationName();
314   if (op.getNumOperands()) {
315     // These are always control operand, no explicit type needed.
316     p << '(';
317     p.printOperands(op.getOperands());
318     p << ')';
319   }
320 
321   // Check if we can print the short "wraps" form: that is if the island
322   // contains a single operation and the result of this operation are perfectly
323   // forwarded to the yield.
324   if (op.getAttrs().empty() && op.WrapsSingleOp()) {
325     Operation &wrapped_op = op.GetBody().front();
326     YieldOp yield_op = op.GetYield();
327     // The "wraps" syntax only encodes a single location.
328     // In order to correctly round-trip, we can only use this syntax when all
329     // the locations are identical.
330     if (wrapped_op.getLoc() == op.getLoc() &&
331         yield_op.getLoc() == op.getLoc()) {
332       p << " wraps ";
333       p.printGenericOp(&wrapped_op);
334       return;
335     }
336   }
337   p.printRegion(op.getOperation()->getRegion(0));
338   p.printOptionalAttrDict(op.getAttrs());
339 }
340 
ParseIslandOp(OpAsmParser & parser,OperationState & result)341 ParseResult ParseIslandOp(OpAsmParser &parser, OperationState &result) {
342   llvm::SMLoc loc = parser.getCurrentLocation();
343   Type control_type = ControlType::get(parser.getBuilder().getContext());
344 
345   // Parse optional argument list (control dependencies only).
346   SmallVector<OpAsmParser::OperandType, 4> op_infos;
347   if (parser.parseOperandList(op_infos, OpAsmParser::Delimiter::OptionalParen))
348     return failure();
349   if (!op_infos.empty()) {
350     SmallVector<Type, 2> types(op_infos.size(), control_type);
351     parser.resolveOperands(op_infos, types, loc, result.operands);
352   }
353 
354   // Parse the body region.
355   Region &body = *result.addRegion();
356 
357   if (succeeded(parser.parseOptionalKeyword("wraps"))) {
358     // If we parse the short version of the island, we have an operation in the
359     // generic form that follows the "wraps" keyword. Parse it inside the region
360     // and forward all of its results as-is to the yield operation.
361     body.push_back(new Block);
362     Block &block = body.back();
363     Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
364     if (!wrapped_op) return failure();
365     OpBuilder builder(parser.getBuilder().getContext());
366     builder.setInsertionPointToEnd(&block);
367     builder.create<YieldOp>(wrapped_op->getLoc(), wrapped_op->getResults());
368     result.location = wrapped_op->getLoc();
369   } else if (parser.parseRegion(body, llvm::None, llvm::None)) {
370     return failure();
371   }
372 
373   IslandOp::ensureTerminator(body, parser.getBuilder(), result.location);
374 
375   // Get the results type for the island from the terminator operands.
376   Operation &yield = body.back().back();
377   result.types.reserve(yield.getNumOperands() + 1);
378   result.types.append(yield.operand_type_begin(), yield.operand_type_end());
379   result.types.push_back(control_type);
380 
381   // Parse the optional attribute list.
382   if (parser.parseOptionalAttrDict(result.attributes)) return failure();
383   return success();
384 }
385 
386 }  // anonymous namespace
387 
388 //===----------------------------------------------------------------------===//
389 // tf_executor.yield
390 //===----------------------------------------------------------------------===//
391 
392 //===----------------------------------------------------------------------===//
393 // tf_executor.Switch
394 //===----------------------------------------------------------------------===//
395 
396 namespace {
397 
ParseSwitchOp(OpAsmParser & parser,OperationState & result)398 ParseResult ParseSwitchOp(OpAsmParser &parser, OperationState &result) {
399   SmallVector<OpAsmParser::OperandType, 2> op_infos;
400   SmallVector<Type, 1> types;
401   if (parser.parseOperandList(op_infos) || parser.parseColonTypeList(types))
402     return failure();
403   if (types.size() != 1)
404     return parser.emitError(parser.getNameLoc())
405            << " expects only a single data type";
406 
407   // Support parsing either a functional type (in which case all the types are
408   // fully qualified) or a short form with a single type (in which case the data
409   // input and the outputs are all using this type and predicate is tensor<i1>
410   // type).
411   if (types.front().isa<FunctionType>()) {
412     FunctionType type = types.front().cast<FunctionType>();
413     if (type.getNumInputs() < 2)
414       return parser.emitError(parser.getNameLoc())
415              << " expects a single data type and a predicate";
416     result.types.assign(type.getResults().begin(), type.getResults().end());
417     types.assign(type.getInputs().begin(), type.getInputs().end());
418   } else {
419     if (op_infos.size() < 2)
420       return parser.emitError(parser.getNameLoc())
421              << " expects a single data type and a predicate";
422     Type control_type = ControlType::get(parser.getBuilder().getContext());
423     result.types.append(2, types[0]);
424     result.types.push_back(control_type);
425     Type i1_type = parser.getBuilder().getI1Type();
426     RankedTensorType predicate_type = RankedTensorType::get({}, i1_type);
427     types.push_back(predicate_type);
428     types.append(op_infos.size() - 2, control_type);
429   }
430 
431   llvm::SMLoc loc = parser.getCurrentLocation();
432   if (parser.resolveOperands(op_infos, types, loc, result.operands))
433     return failure();
434 
435   return parser.parseOptionalAttrDict(result.attributes);
436 }
437 
Print(SwitchOp switch_op,OpAsmPrinter & p)438 void Print(SwitchOp switch_op, OpAsmPrinter &p) {
439   p << switch_op.getOperationName() << ' ';
440   p.printOperands(switch_op.getOperands());
441   Type data_operand_ty = switch_op.data().getType();
442   // If the types aren't perfectly matching, print the functional type syntax
443   // else print the shorter single type.
444   p << " : ";
445   if (switch_op.trueOutput().getType() != data_operand_ty ||
446       switch_op.falseOutput().getType() != data_operand_ty ||
447       switch_op.predicate().getType().isa<UnrankedTensorType>()) {
448     p.printFunctionalType(switch_op.getOperation());
449   } else {
450     p << switch_op.getType(0);
451   }
452   p.printOptionalAttrDict(switch_op.getAttrs());
453 }
454 
455 }  // anonymous namespace
456 
457 //===----------------------------------------------------------------------===//
458 // tf_executor.SwitchN
459 //===----------------------------------------------------------------------===//
460 
461 namespace {
462 
Verify(SwitchNOp switchn)463 LogicalResult Verify(SwitchNOp switchn) {
464   IntegerAttr num_outs = switchn->getAttrOfType<IntegerAttr>("num_outs");
465   if (!num_outs)
466     return switchn.emitOpError() << "expects a `num_outs` integer attribute";
467 
468   // Expect num_outs results + 1 control output.
469   if (switchn.getNumResults() != num_outs.getInt() + 1)
470     return switchn.emitOpError()
471            << "expect `num_outs` (" << num_outs.getInt() << ") results but got "
472            << (switchn.getNumResults() - 1);
473 
474   // Check that operand can be broadcasted to each output type.
475   auto operand0_type = switchn.getOperand(0).getType();
476   TensorType operand0_tensor_type = operand0_type.dyn_cast<TensorType>();
477   if (!operand0_tensor_type) {
478     return switchn.emitOpError()
479            << "expects data operand to have tensor type but got "
480            << operand0_type;
481   }
482   for (Type output_type : switchn.getResultTypes()) {
483     if (output_type.isa<ControlType>()) break;
484 
485     TensorType output_tensor_type = output_type.dyn_cast<TensorType>();
486     if (!output_tensor_type) {
487       return switchn.emitOpError()
488              << "expects outputs to have tensor type but got " << output_type;
489     }
490 
491     // If the output type is a ref type, then the operand type should also be of
492     // the same ref type. However, if the output type is a non-ref type T, then
493     // the operand can be tensor of type T or T_REF.
494     bool is_output_ref =
495         output_tensor_type.getElementType().isa<TF::TensorFlowRefType>();
496     if (is_output_ref &&
497         !operand0_tensor_type.getElementType().isa<TF::TensorFlowRefType>()) {
498       return switchn.emitOpError()
499              << "expects same operand and output element type but got "
500              << operand0_tensor_type << " vs " << output_tensor_type;
501     }
502     Type broadcasted_type = OpTrait::util::getBroadcastedType(
503         TF::DropRefAndSubTypes(operand0_tensor_type),
504         TF::DropRefAndSubTypes(output_tensor_type));
505     if (!broadcasted_type) {
506       return switchn.emitOpError()
507              << "expects data operand to be broadcastable with all output types"
508              << " but got " << operand0_tensor_type << " vs "
509              << output_tensor_type;
510     }
511   }
512   return success();
513 }
514 
Print(SwitchNOp switchn,OpAsmPrinter & p)515 void Print(SwitchNOp switchn, OpAsmPrinter &p) {
516   p << switchn.getOperationName() << ' ';
517   auto operands = switchn.getOperands();
518   // Print the 2 data operands.
519   p.printOperands(operands.begin(), std::next(operands.begin(), 2));
520   p << " of " << (switchn.getNumResults() - 1);
521   // print control dependencies if any
522   if (!llvm::empty(switchn.controlInputs())) {
523     p << " (";
524     p.printOperands(switchn.controlInputs());
525     p << ")";
526   }
527   p << " : " << switchn.getType(0);
528   p.printOptionalAttrDict(switchn.getAttrs(), {"num_outs"});
529 }
530 
ParseSwitchNOp(OpAsmParser & parser,OperationState & result)531 ParseResult ParseSwitchNOp(OpAsmParser &parser, OperationState &result) {
532   // Parsing:
533   //       %2:6 = tf_executor.SwitchN %0, %1 of 5 : tensor<??xf32>
534   // Where the first operand is the data to replicate, the second is an i32
535   // indicating which output to populate, followed by the keyword `of` and the
536   // number of outputs (+1 for the control token).
537   SmallVector<OpAsmParser::OperandType, 2> op_infos;
538   SmallVector<Type, 1> types;
539   llvm::SMLoc loc = parser.getCurrentLocation();
540   IntegerAttr num_outs;
541   Type i64_type = parser.getBuilder().getIntegerType(64);
542   if (parser.parseOperandList(op_infos, 2) || parser.parseKeyword("of") ||
543       parser.parseAttribute(num_outs, i64_type, "num_outs",
544                             result.attributes) ||
545       parser.parseOperandList(op_infos,
546                               OpAsmParser::Delimiter::OptionalParen) ||
547       parser.parseColonTypeList(types))
548     return failure();
549   if (types.size() != 1)
550     return parser.emitError(parser.getNameLoc())
551            << " expects only a single data type";
552 
553   if (num_outs.getInt() <= 0)
554     return parser.emitError(parser.getNameLoc())
555            << " expects a positive number of outputs";
556 
557   // `types` already contains the type for the data, add an i32 for the
558   // output_index, and then the optional control inputs.
559   auto builder = parser.getBuilder();
560   types.push_back(RankedTensorType::get({}, builder.getIntegerType(32)));
561   Type control_type = ControlType::get(builder.getContext());
562   types.append(op_infos.size() - 2, control_type);
563 
564   if (parser.resolveOperands(op_infos, types, loc, result.operands))
565     return failure();
566 
567   // Output result types is a replication `num_outs` times the data input type.
568   result.types.append(num_outs.getInt(), types[0]);
569   result.types.push_back(control_type);
570 
571   return parser.parseOptionalAttrDict(result.attributes);
572 }
573 
574 }  // anonymous namespace
575 
576 //===----------------------------------------------------------------------===//
577 // tf_executor.Merge
578 //===----------------------------------------------------------------------===//
579 
580 namespace {
581 
Verify(MergeOp merge)582 LogicalResult Verify(MergeOp merge) {
583   if (!merge.getNumOperands())
584     return merge.emitOpError() << "expects at least one operand";
585 
586   Type data_type = merge.getOperand(0).getType();
587   if (data_type.isa<ControlType>())
588     return merge.emitOpError() << "expects a non-control input";
589 
590   // Check that each operand can be individually broadcasted to the output type.
591   Type output_type = merge.output().getType();
592   TensorType output_tensor_ty = output_type.dyn_cast<TensorType>();
593   if (!output_tensor_ty) {
594     return merge.emitOpError()
595            << "expects output to have tensor type but got " << output_type;
596   }
597   bool is_output_ref =
598       output_tensor_ty.getElementType().isa<TF::TensorFlowRefType>();
599   for (Type operand_type : merge.getOperandTypes()) {
600     if (operand_type.isa<ControlType>()) break;
601 
602     // TODO(hinsu): Update ControlOperandsAfterAllData trait to verify this
603     // constraint.
604     TensorType operand_tensor_ty = operand_type.dyn_cast<TensorType>();
605     if (!operand_tensor_ty)
606       return merge.emitOpError()
607              << "expects data operands to have tensor type but got "
608              << operand_type;
609 
610     // If output type is a ref type then all operand types should also be of the
611     // same ref type. However, if the output type is a non-ref type T, operands
612     // can be tensor of type T or T_REF.
613     if (is_output_ref &&
614         !operand_tensor_ty.getElementType().isa<TF::TensorFlowRefType>()) {
615       return merge.emitOpError()
616              << "expects same operand and output element type but got "
617              << operand_tensor_ty << " vs " << output_tensor_ty;
618     }
619     Type broadcasted_type = OpTrait::util::getBroadcastedType(
620         TF::DropRefAndSubTypes(output_tensor_ty),
621         TF::DropRefAndSubTypes(operand_tensor_ty));
622     if (!broadcasted_type)
623       return merge.emitOpError()
624              << "expects all operands to be broadcastable with output type"
625              << " but got " << operand_tensor_ty << " vs " << output_tensor_ty;
626   }
627   return success();
628 }
629 
Print(MergeOp merge,OpAsmPrinter & p)630 void Print(MergeOp merge, OpAsmPrinter &p) {
631   // Use short form only when there are exactly two data operands and their
632   // type matches the output type. Otherwise, use the generic printer.
633   bool use_short_form = true;
634   int num_data_operands = 0;
635 
636   Type output_type = merge.output().getType();
637   for (Type operand_type : merge.getOperandTypes()) {
638     if (operand_type.isa<ControlType>()) break;
639     num_data_operands++;
640 
641     if (operand_type != output_type) {
642       use_short_form = false;
643       break;
644     }
645   }
646 
647   p << merge.getOperationName() << ' ';
648   p.printOperands(merge.getOperands());
649 
650   // Print the type signature of the operation.
651   p << " : ";
652   if (!use_short_form || num_data_operands != 2) {
653     p.printFunctionalType(merge.getOperation());
654   } else {
655     p << output_type;
656   }
657 
658   p.printOptionalAttrDict(merge.getAttrs());
659 }
660 
ParseMergeOp(OpAsmParser & parser,OperationState & result)661 ParseResult ParseMergeOp(OpAsmParser &parser, OperationState &result) {
662   SmallVector<OpAsmParser::OperandType, 2> op_infos;
663   SmallVector<Type, 1> types;
664   llvm::SMLoc loc = parser.getCurrentLocation();
665   if (parser.parseOperandList(op_infos) || parser.parseColonTypeList(types))
666     return failure();
667   if (types.size() != 1)
668     return parser.emitError(parser.getNameLoc())
669            << " expects only a single data type";
670 
671   // Support parsing either a functional type (in which case all the types are
672   // fully qualified) or a short form with a single type (in which case the data
673   // inputs and the output are all using this type).
674   if (FunctionType type = types.front().dyn_cast<FunctionType>()) {
675     result.types.assign(type.getResults().begin(), type.getResults().end());
676     types.assign(type.getInputs().begin(), type.getInputs().end());
677   } else {
678     // In case of the short form, use the parsed type for both the operands and
679     // the remaining operands are expected to be control inputs.
680     types.push_back(Type(types.front()));
681     Type control_type = ControlType::get(parser.getBuilder().getContext());
682     types.append(op_infos.size() - 2, control_type);
683 
684     RankedTensorType i32_tensor =
685         RankedTensorType::get({}, parser.getBuilder().getIntegerType(32));
686     result.types = {types.front(), i32_tensor, control_type};
687   }
688 
689   if (parser.resolveOperands(op_infos, types, loc, result.operands))
690     return failure();
691 
692   return parser.parseOptionalAttrDict(result.attributes);
693 }
694 
695 }  // anonymous namespace
696 
697 //===----------------------------------------------------------------------===//
698 // tf_executor.Enter
699 //===----------------------------------------------------------------------===//
700 
701 namespace {
702 
// Default value of the `parallel_iterations` attribute on Enter nodes. The
// Enter printer omits the attribute when it equals this value, and the Enter
// parser uses this value when the `parallel_iterations` keyword is absent.
constexpr int kDefaultParallelIterations = 10;
705 
Print(EnterOp enter,OpAsmPrinter & p)706 void Print(EnterOp enter, OpAsmPrinter &p) {
707   p << enter.getOperationName() << ' ';
708   p.printOperands(enter.getOperands());
709 
710   p << " frame \"";
711   printEscapedString(enter.frame_name(), p.getStream());
712   p << "\"";
713   if (enter.parallel_iterations() != kDefaultParallelIterations)
714     p << " parallel_iterations " << enter.parallel_iterations();
715   if (enter.is_constant()) p << " constant ";
716 
717   // If the types aren't perfectly matching, print the functional type syntax
718   // else print the shorter single type.
719   p << " : ";
720   if (enter.data().getType() != enter.output().getType()) {
721     p.printFunctionalType(enter.getOperation());
722   } else {
723     p << enter.getType(0);
724   }
725 
726   p.printOptionalAttrDict(enter.getAttrs(),
727                           {"frame_name", "parallel_iterations", "is_constant"});
728 }
729 
ParseEnterOp(OpAsmParser & parser,OperationState & result)730 ParseResult ParseEnterOp(OpAsmParser &parser, OperationState &result) {
731   SmallVector<OpAsmParser::OperandType, 2> op_infos;
732   llvm::SMLoc loc = parser.getCurrentLocation();
733   MLIRContext *context = parser.getBuilder().getContext();
734   if (parser.parseOperandList(op_infos)) return failure();
735   if (op_infos.empty())
736     return parser.emitError(loc) << " expects at least one data operand";
737 
738   Attribute frame;
739   if (parser.parseKeyword("frame") ||
740       parser.parseAttribute(frame, NoneType::get(context), "frame_name",
741                             result.attributes))
742     return failure();
743 
744   Type i64 = parser.getBuilder().getIntegerType(64);
745   if (parser.parseOptionalKeyword("parallel_iterations")) {
746     result.addAttribute("parallel_iterations",
747                         IntegerAttr::get(i64, kDefaultParallelIterations));
748   } else {
749     IntegerAttr parallel_iterations;
750     if (parser.parseAttribute(parallel_iterations, i64, "parallel_iterations",
751                               result.attributes))
752       return failure();
753   }
754   bool has_constant = succeeded(parser.parseOptionalKeyword("constant"));
755   result.addAttribute("is_constant", BoolAttr::get(context, has_constant));
756 
757   SmallVector<Type, 1> types;
758   if (parser.parseColonTypeList(types)) return failure();
759   if (types.size() != 1)
760     return parser.emitError(loc) << " expects only a single data type";
761 
762   // Support parsing either a functional type (in which case all the types are
763   // fully qualified) or a short form with a single type (in which case the data
764   // input and the outputs are all using this type).
765   if (FunctionType type = types.front().dyn_cast<FunctionType>()) {
766     // One data input, and any number of control inputs.
767     if (type.getNumInputs() >= 1) {
768       result.types.assign(type.getResults().begin(), type.getResults().end());
769       types.assign(type.getInputs().begin(), type.getInputs().end());
770     } else {
771       return parser.emitError(parser.getNameLoc()) << " expects a data input";
772     }
773   } else {
774     Type control_type = ControlType::get(context);
775     types.append(op_infos.size() - 1, control_type);
776     result.addTypes({types.front(), control_type});
777   }
778 
779   // Extra operands are expected to be control inputs.
780 
781   if (parser.resolveOperands(op_infos, types, loc, result.operands))
782     return failure();
783 
784   return parser.parseOptionalAttrDict(result.attributes);
785 }
786 
787 }  // anonymous namespace
788 
789 //===----------------------------------------------------------------------===//
790 // tf_executor.NextIteration.Source
791 //===----------------------------------------------------------------------===//
792 
793 namespace {
794 
Verify(NextIterationSourceOp source)795 LogicalResult Verify(NextIterationSourceOp source) {
796   Value token = source.token();
797   if (!token.hasOneUse())
798     return source.emitOpError() << "expects a single user for produced token";
799   if (!isa<NextIterationSinkOp>(*token.user_begin()))
800     return source.emitOpError() << "token should be consumed by a sink op";
801   return success();
802 }
803 
804 }  // anonymous namespace
805 
806 //===----------------------------------------------------------------------===//
807 // tf_executor.NextIteration.Sink
808 //===----------------------------------------------------------------------===//
809 
810 namespace {
811 
Verify(NextIterationSinkOp sink)812 LogicalResult Verify(NextIterationSinkOp sink) {
813   Value token = sink.token();
814   Operation *definingOp = token.getDefiningOp();
815   if (!definingOp)
816     return sink.emitOpError() << "expects a token directly produced by a "
817                                  "tf_executor.NextIteration.Source op: ";
818   auto source = dyn_cast<NextIterationSourceOp>(definingOp);
819   if (!source)
820     return sink.emitOpError() << "expects a token produced by a "
821                                  "tf_executor.NextIteration.Source op: ";
822   if (source.output().getType() != sink.input().getType())
823     return sink.emitOpError()
824            << "input type " << sink.input().getType()
825            << " mismatch the tf_executor.NextIteration.Source output type: "
826            << source.output().getType();
827   return success();
828 }
829 
830 }  // anonymous namespace
831 
GetSource()832 NextIterationSourceOp NextIterationSinkOp::GetSource() {
833   return cast<NextIterationSourceOp>(token().getDefiningOp());
834 }
835 
836 //===----------------------------------------------------------------------===//
837 // tf_executor.Exit
838 //===----------------------------------------------------------------------===//
839 
840 namespace {
841 
Print(ExitOp exit,OpAsmPrinter & p)842 void Print(ExitOp exit, OpAsmPrinter &p) {
843   p << exit.getOperationName() << ' ';
844   p.printOperands(exit.getOperands());
845   p << " : " << exit.getType(0);
846   p.printOptionalAttrDict(exit.getAttrs());
847 }
848 
ParseExitOp(OpAsmParser & parser,OperationState & result)849 ParseResult ParseExitOp(OpAsmParser &parser, OperationState &result) {
850   SmallVector<OpAsmParser::OperandType, 2> op_infos;
851   SmallVector<Type, 1> types;
852 
853   if (parser.parseOperandList(op_infos) || parser.parseColonTypeList(types))
854     return failure();
855 
856   llvm::SMLoc loc = parser.getCurrentLocation();
857   Type control_type = ControlType::get(parser.getBuilder().getContext());
858   types.append(op_infos.size() - 1, control_type);
859   if (parser.resolveOperands(op_infos, types, loc, result.operands))
860     return failure();
861 
862   result.addTypes({types.front(), control_type});
863   return parser.parseOptionalAttrDict(result.attributes);
864 }
865 
866 }  // anonymous namespace
867 
868 //===----------------------------------------------------------------------===//
869 // tf_executor.ControlTrigger
870 //===----------------------------------------------------------------------===//
871 
872 //===----------------------------------------------------------------------===//
873 // tf_executor.LoopCond
874 //===----------------------------------------------------------------------===//
875 
876 namespace {
877 
Print(LoopCondOp loop_cond,OpAsmPrinter & p)878 void Print(LoopCondOp loop_cond, OpAsmPrinter &p) {
879   p << loop_cond.getOperationName() << ' ';
880   p.printOperands(loop_cond.getOperands());
881 
882   // If the types aren't matching (broadcast), print the functional type syntax.
883   if (loop_cond.input().getType() != loop_cond.output().getType()) {
884     p << " : ";
885     p.printFunctionalType(loop_cond.getOperation());
886   } else {
887     p << " : " << loop_cond.input().getType();
888   }
889 
890   p.printOptionalAttrDict(loop_cond.getAttrs());
891 }
892 
ParseLoopCondOp(OpAsmParser & parser,OperationState & result)893 ParseResult ParseLoopCondOp(OpAsmParser &parser, OperationState &result) {
894   SmallVector<OpAsmParser::OperandType, 2> op_infos;
895 
896   if (parser.parseOperandList(op_infos)) return failure();
897   if (op_infos.empty())
898     return parser.emitError(parser.getNameLoc())
899            << "expects at least one operand";
900 
901   SmallVector<Type, 1> types;
902   if (parser.parseColonTypeList(types)) return failure();
903 
904   // Support parsing either a functional type (in which case all the types are
905   // fully qualified) or a short form with a single type (in which case the data
906   // input and the outputs are all using this type).
907   Type control_type = ControlType::get(parser.getBuilder().getContext());
908   if (FunctionType type = types.front().dyn_cast<FunctionType>()) {
909     if (llvm::count_if(type.getInputs(),
910                        [=](Type type) { return type != control_type; }) != 1)
911       return parser.emitError(parser.getNameLoc())
912              << " expects a single data type";
913     result.types.assign(type.getResults().begin(), type.getResults().end());
914     types.assign(type.getInputs().begin(), type.getInputs().end());
915   } else {
916     if (types.size() != 1)
917       return parser.emitError(parser.getNameLoc())
918              << " expects a single data type";
919     types.append(op_infos.size() - 1, control_type);
920     result.addTypes({types.front(), control_type});
921   }
922 
923   llvm::SMLoc loc = parser.getCurrentLocation();
924   if (parser.resolveOperands(op_infos, types, loc, result.operands))
925     return failure();
926 
927   return parser.parseOptionalAttrDict(result.attributes);
928 }
929 
930 }  // namespace
931 
932 //===----------------------------------------------------------------------===//
933 // Canonicalization patterns
934 //===----------------------------------------------------------------------===//
935 
// TODO(lyandy): Add canonicalization for deduplicating control inputs.
937 
938 //===----------------------------------------------------------------------===//
939 // tf_executor.graph
940 //===----------------------------------------------------------------------===//
941 
942 namespace {
943 // Finds in a block if the op of type `InnerOpT` is the first operation and
944 // optionally followed by a terminator.
945 template <typename InnerOpT>
HasSingleOpInBlock(Block * block)946 bool HasSingleOpInBlock(Block *block) {
947   if (block->empty()) return false;
948   if (!llvm::isa<InnerOpT>(block->front())) return false;
949   // Either InnerOpT is the only instruction in the block, or there is a
950   // possible terminator.
951   return std::next(block->begin()) == block->end() ||
952          std::next(block->begin(), 2) == block->end();
953 }
954 
955 // This pattern matches GraphOps with only one FetchOp (empty) and remaps the
956 // results of the GraphOp to the operands of the FetchOp.
957 struct DropEmptyGraph : public OpRewritePattern<GraphOp> {
958   using OpRewritePattern<GraphOp>::OpRewritePattern;
959 
matchAndRewritemlir::tf_executor::__anonecac450a0e11::DropEmptyGraph960   LogicalResult matchAndRewrite(GraphOp op,
961                                 PatternRewriter &rewriter) const override {
962     Block &block = op.GetBody();
963     // Check if graph only has one fetch.
964     if (&block.front() != &block.back()) return failure();
965 
966     // Map graph results to fetch operands.
967     rewriter.replaceOp(op, op.GetFetch().fetches());
968 
969     return success();
970   }
971 };
972 
973 // This pattern matches GraphOps with only one island, pulls out all inner ops
974 // of the island to the block containing the GraphOp, and then removes the
975 // GraphOp.
976 struct HoistInnerOpsSingleIslandGraph : public OpRewritePattern<GraphOp> {
977   using OpRewritePattern<GraphOp>::OpRewritePattern;
978 
matchAndRewritemlir::tf_executor::__anonecac450a0e11::HoistInnerOpsSingleIslandGraph979   LogicalResult matchAndRewrite(GraphOp op,
980                                 PatternRewriter &rewriter) const override {
981     Block &block = op.GetBody();
982     // Check if graph only has one island.
983     if (!HasSingleOpInBlock<IslandOp>(&block)) return failure();
984 
985     FetchOp fetch_op = op.GetFetch();
986     auto island_op = llvm::cast<IslandOp>(block.front());
987     YieldOp yield_op = island_op.GetYield();
988 
989     // Map graph results to inner ops results of single island.
990     llvm::SmallVector<Value, 8> new_rets;
991     for (Value operand : fetch_op.fetches()) {
992       // Control results should not be propagated out.
993       if (operand.getType().isa<ControlType>()) break;
994 
995       if (operand.getDefiningOp() != island_op) {
996         // Operand is not from island, simply propagate it out.
997         new_rets.push_back(operand);
998       } else {
999         // Lookup yield operand in island for inner op result.
1000         auto result = operand.cast<OpResult>();
1001         new_rets.push_back(yield_op.getOperand(result.getResultNumber()));
1002       }
1003     }
1004 
1005     // Move inner ops from island to block containing graph.
1006     auto &island_body = island_op.GetBody().getOperations();
1007     Operation *operation = op.getOperation();
1008     operation->getBlock()->getOperations().splice(
1009         operation->getIterator(), island_body, island_body.begin(),
1010         std::prev(island_body.end()));
1011     rewriter.replaceOp(op, new_rets);
1012 
1013     return success();
1014   }
1015 };
1016 }  // anonymous namespace
1017 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1018 void GraphOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1019                                           MLIRContext *context) {
1020   results.insert<DropEmptyGraph, HoistInnerOpsSingleIslandGraph>(context);
1021 }
1022 
1023 //===----------------------------------------------------------------------===//
1024 // tf_executor.island
1025 //===----------------------------------------------------------------------===//
1026 
1027 namespace {
1028 // This pattern matches and removes IslandOps with no inner ops, no control
1029 // operands and no data results. Control result users will have their relevant
1030 // operands removed.
1031 struct DropEmptyIslandNoOperandNoDataResult
1032     : public OpRewritePattern<IslandOp> {
1033   using OpRewritePattern<IslandOp>::OpRewritePattern;
1034 
matchAndRewritemlir::tf_executor::__anonecac450a0f11::DropEmptyIslandNoOperandNoDataResult1035   LogicalResult matchAndRewrite(IslandOp op,
1036                                 PatternRewriter &rewriter) const override {
1037     if (op.getNumOperands() != 0 || op.getNumResults() != 1 ||
1038         !HasSingleOpInBlock<YieldOp>(&op.GetBody()))
1039       return failure();
1040 
1041     for (auto &use : llvm::make_early_inc_range(op.control().getUses()))
1042       use.getOwner()->eraseOperand(use.getOperandNumber());
1043 
1044     rewriter.eraseOp(op);
1045 
1046     return success();
1047   }
1048 };
1049 
1050 // This pattern matches and removes IslandOps with no inner ops, no control
1051 // operands, one data result and no control result user. The single data result
1052 // (from YieldOps first operand) is forwarded to the IslandOp single data result
1053 // users.
1054 struct DropEmptyIslandNoOperandOneDataResult
1055     : public OpRewritePattern<IslandOp> {
1056   using OpRewritePattern<IslandOp>::OpRewritePattern;
1057 
matchAndRewritemlir::tf_executor::__anonecac450a0f11::DropEmptyIslandNoOperandOneDataResult1058   LogicalResult matchAndRewrite(IslandOp op,
1059                                 PatternRewriter &rewriter) const override {
1060     if (op.getNumOperands() != 0 || op.getNumResults() != 2 ||
1061         !op.control().use_empty() ||
1062         !HasSingleOpInBlock<YieldOp>(&op.GetBody()))
1063       return failure();
1064 
1065     rewriter.replaceOp(op, {op.GetYield().getOperand(0), nullptr});
1066 
1067     return success();
1068   }
1069 };
1070 
1071 // TODO(lyandy): Add canonicalization for empty IslandOps with more than one
1072 // control operand and no data results.
1073 
1074 }  // anonymous namespace
1075 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1076 void IslandOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1077                                            MLIRContext *context) {
1078   results.insert<DropEmptyIslandNoOperandNoDataResult,
1079                  DropEmptyIslandNoOperandOneDataResult>(context);
1080 }
1081 
1082 //===----------------------------------------------------------------------===//
1083 // tf_executor.ControlTrigger
1084 //===----------------------------------------------------------------------===//
1085 
1086 namespace {
1087 // This pattern matches and removes ControlTriggerOps with no control operands.
1088 // Control result users will have their relevant operands removed.
1089 struct DropEmptyControlTrigger : public OpRewritePattern<ControlTriggerOp> {
1090   using OpRewritePattern<ControlTriggerOp>::OpRewritePattern;
1091 
matchAndRewritemlir::tf_executor::__anonecac450a1011::DropEmptyControlTrigger1092   LogicalResult matchAndRewrite(ControlTriggerOp op,
1093                                 PatternRewriter &rewriter) const override {
1094     if (op.getNumOperands() != 0) return failure();
1095 
1096     for (auto &use : llvm::make_early_inc_range(op.control().getUses()))
1097       use.getOwner()->eraseOperand(use.getOperandNumber());
1098 
1099     rewriter.eraseOp(op);
1100 
1101     return success();
1102   }
1103 };
1104 }  // anonymous namespace
1105 
// Registers the ControlTriggerOp canonicalization patterns into `results`.
void ControlTriggerOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  // Removes ControlTriggerOps that have no operands.
  results.insert<DropEmptyControlTrigger>(context);
}
1110 
1111 //===----------------------------------------------------------------------===//
1112 // Folders
1113 //===----------------------------------------------------------------------===//
1114 
1115 //===----------------------------------------------------------------------===//
1116 // tf_executor.island
1117 //===----------------------------------------------------------------------===//
1118 
fold(llvm::ArrayRef<Attribute> operands,llvm::SmallVectorImpl<OpFoldResult> & results)1119 LogicalResult IslandOp::fold(llvm::ArrayRef<Attribute> operands,
1120                              llvm::SmallVectorImpl<OpFoldResult> &results) {
1121   // This folds IslandOps with no inner ops, one control operand and no data
1122   // results. The single control operand is forwarded to the IslandOp control
1123   // result users.
1124   if (getNumOperands() != 1 || getNumResults() != 1 ||
1125       !HasSingleOpInBlock<YieldOp>(&GetBody()))
1126     return failure();
1127 
1128   results.emplace_back(getOperand(0));
1129 
1130   return success();
1131 }
1132 
1133 }  // namespace tf_executor
1134 }  // namespace mlir
1135 
1136 //===----------------------------------------------------------------------===//
1137 // TableGen'd op method definitions
1138 //===----------------------------------------------------------------------===//
1139 
1140 #define GET_OP_CLASSES
1141 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc.inc"
1142