1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <cstdint>
21 #include <iterator>
22 #include <utility>
23 
24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/Optional.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "llvm/Support/SMLoc.h"
30 #include "mlir/IR/Attributes.h"  // from @llvm-project
31 #include "mlir/IR/Builders.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
33 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
34 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
35 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
36 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
37 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
38 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
39 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
40 #include "mlir/IR/Types.h"  // from @llvm-project
41 #include "mlir/IR/UseDefLists.h"  // from @llvm-project
42 #include "mlir/IR/Value.h"  // from @llvm-project
43 #include "mlir/Support/LLVM.h"  // from @llvm-project
44 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
45 #include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
46 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
47 #include "tensorflow/core/platform/logging.h"
48 
49 namespace mlir {
50 namespace tf_device {
51 
52 //===----------------------------------------------------------------------===//
53 // TF Device Dialect Interfaces
54 //===----------------------------------------------------------------------===//
55 
56 namespace {
57 struct TFInlinerInterface : public DialectInlinerInterface {
58   using DialectInlinerInterface::DialectInlinerInterface;
59 
60   //===--------------------------------------------------------------------===//
61   // Analysis Hooks
62   //===--------------------------------------------------------------------===//
63 
64   // Allow all call operations to be inlined.
isLegalToInlinemlir::tf_device::__anon57a3a4cb0111::TFInlinerInterface65   bool isLegalToInline(Operation* call, Operation* callable,
66                        bool wouldBeCloned) const final {
67     return true;
68   }
69   // Defines the legality of inlining TF Device operations.
isLegalToInlinemlir::tf_device::__anon57a3a4cb0111::TFInlinerInterface70   bool isLegalToInline(Operation*, Region*, bool,
71                        BlockAndValueMapping&) const final {
72     // For now, enable inlining all operations.
73     return true;
74   }
75 
76   //===--------------------------------------------------------------------===//
77   // Transformation Hooks
78   //===--------------------------------------------------------------------===//
79 
80   // Attempts to materialize a conversion for a type mismatch between a call
81   // from this dialect, and a callable region. This method should generate an
82   // operation that takes 'input' as the only operand, and produces a single
83   // result of 'resultType'. If a conversion can not be generated, nullptr
84   // should be returned.
85   // This is just re-using the same logic as the TensorFlow dialect right now.
materializeCallConversionmlir::tf_device::__anon57a3a4cb0111::TFInlinerInterface86   Operation* materializeCallConversion(OpBuilder& builder, Value input,
87                                        Type result_type,
88                                        Location conversion_loc) const final {
89     if (!result_type.isa<TensorType>() || !input.getType().isa<TensorType>())
90       return nullptr;
91     return builder.create<TF::CastOp>(conversion_loc, result_type, input,
92                                       /*truncate=*/builder.getBoolAttr(false));
93   }
94 };
95 
96 // Checks if a block wraps a single operation and the single operation results
97 // are perfectly forwarded to the block's terminator.
BlockWrapsSingleOp(Block * block)98 bool BlockWrapsSingleOp(Block* block) {
99   auto body = block->without_terminator();
100   if (!hasSingleElement(body)) return false;
101 
102   Operation& wrapped_op = *body.begin();
103   Operation* terminator = block->getTerminator();
104   return wrapped_op.getNumResults() == terminator->getNumOperands() &&
105          std::equal(wrapped_op.getResults().begin(),
106                     wrapped_op.getResults().end(),
107                     terminator->getOperands().begin());
108 }
109 }  // end anonymous namespace
110 
// Constructs the `tf_device` dialect. It registers every op listed in the
// TableGen-generated tf_device.cc.inc (the list is selected with GET_OP_LIST)
// and attaches TFInlinerInterface so tf_device ops and calls can be inlined.
TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
    : Dialect(/*name=*/"tf_device", context,
              TypeID::get<TensorFlowDeviceDialect>()) {
  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc.inc"
      >();

  addInterfaces<TFInlinerInterface>();
}
121 
122 //===----------------------------------------------------------------------===//
123 // tf_device.launch
124 //===----------------------------------------------------------------------===//
125 
126 // Checks if a tf_device.launch wraps a single operation and the single
127 // operation results are perfectly forwarded to the launch return.
WrapsSingleOp()128 bool LaunchOp::WrapsSingleOp() { return BlockWrapsSingleOp(&GetBody()); }
129 
130 //===----------------------------------------------------------------------===//
131 // tf_device.parallel_execute
132 //===----------------------------------------------------------------------===//
133 
134 namespace {
135 
Verify(ParallelExecuteOp op)136 LogicalResult Verify(ParallelExecuteOp op) {
137   const auto& regions = op.getOperation()->getRegions();
138   if (regions.size() < 2) {
139     return op.emitOpError() << "must have at least two regions.";
140   }
141 
142   int output_index = 0;
143   for (auto& region_and_index : llvm::enumerate(regions)) {
144     auto& region = region_and_index.value();
145     auto* region_terminator = region.front().getTerminator();
146 
147     // Check that output types of regions match return operand types.
148     for (auto result_type : region_terminator->getOperandTypes()) {
149       if (result_type !=
150           op.getOperation()->getResult(output_index++).getType()) {
151         return op.emitOpError() << "output types must be a concatenated "
152                                 << "list of output types for each regions.";
153       }
154     }
155   }
156 
157   // Check that total number of outputs from regions match the output types of
158   // the parallel_execute op.
159   const int num_output_types = op.getOperation()->getNumResults();
160   if (num_output_types != output_index) {
161     return op.emitOpError()
162            << "number of output types (" << num_output_types << ") "
163            << "must match the total number of outputs from all "
164            << "regions (" << output_index << ").";
165   }
166 
167   return success();
168 }
169 
170 }  // namespace
171 
172 // static
build(OpBuilder & builder,OperationState & state,int num_regions,llvm::ArrayRef<Type> output_types)173 void ParallelExecuteOp::build(OpBuilder& builder, OperationState& state,
174                               int num_regions,
175                               llvm::ArrayRef<Type> output_types) {
176   DCHECK_GE(num_regions, 2);
177   for (int i = 0; i < num_regions; ++i) {
178     Region* region = state.addRegion();
179     region->push_back(new Block);
180   }
181   state.addTypes(output_types);
182 }
183 
GetRegionBlockWithIndex(unsigned index)184 Block& ParallelExecuteOp::GetRegionBlockWithIndex(unsigned index) {
185   return getOperation()->getRegion(index).front();
186 }
187 
GetRegionOutputs(unsigned region_index)188 Operation::result_range ParallelExecuteOp::GetRegionOutputs(
189     unsigned region_index) {
190   int num_region_results =
191       GetRegionBlockWithIndex(region_index).getTerminator()->getNumOperands();
192 
193   int return_value_offset = 0;
194   for (int region_id = 0; region_id < region_index; ++region_id)
195     return_value_offset +=
196         GetRegionBlockWithIndex(region_id).getTerminator()->getNumOperands();
197 
198   Operation::result_range region_results(getOperation(),
199                                          /*startIndex=*/return_value_offset,
200                                          /*count=*/num_region_results);
201   return region_results;
202 }
203 
RegionWrapsSingleOp(unsigned index)204 bool ParallelExecuteOp::RegionWrapsSingleOp(unsigned index) {
205   return BlockWrapsSingleOp(&GetRegionBlockWithIndex(index));
206 }
207 
208 //===----------------------------------------------------------------------===//
209 // tf_device.replicate
210 //===----------------------------------------------------------------------===//
211 
212 namespace {
ParseReplicateOpOperands(OpAsmParser * parser,OperationState * state,llvm::SmallVectorImpl<llvm::SmallVector<OpAsmParser::OperandType,8>> * replicated_inputs,llvm::SmallVectorImpl<OpAsmParser::OperandType> * packed_inputs,llvm::SmallVectorImpl<OpAsmParser::OperandType> * region_args,llvm::SmallVectorImpl<Type> * region_arg_types)213 ParseResult ParseReplicateOpOperands(
214     OpAsmParser* parser, OperationState* state,
215     llvm::SmallVectorImpl<llvm::SmallVector<OpAsmParser::OperandType, 8>>*
216         replicated_inputs,
217     llvm::SmallVectorImpl<OpAsmParser::OperandType>* packed_inputs,
218     llvm::SmallVectorImpl<OpAsmParser::OperandType>* region_args,
219     llvm::SmallVectorImpl<Type>* region_arg_types) {
220   // No operands or empty operand list.
221   bool parsed_l_paren = succeeded(parser->parseOptionalLParen());
222   if (!parsed_l_paren || succeeded(parser->parseOptionalRParen()))
223     return success();
224 
225   // Parse comma separated operands of the following format:
226   //   replicated_input
227   //     [%a, ...] as %block_arg0: type
228   //   packed_input
229   //     %b as %block_arg1: type
230   //
231   // Replicated inputs are placed before packed inputs when forming the op.
232   llvm::SmallVector<OpAsmParser::OperandType, 8> replicated_region_args;
233   llvm::SmallVector<OpAsmParser::OperandType, 8> packed_region_args;
234   llvm::SmallVector<Type, 8> replicated_region_arg_types;
235   llvm::SmallVector<Type, 8> packed_region_arg_types;
236   do {
237     OpAsmParser::OperandType operand_type;
238     if (parser->parseOptionalOperand(operand_type).hasValue()) {
239       packed_inputs->emplace_back(operand_type);
240       if (parser->parseKeyword("as",
241                                " between packed input and block argument") ||
242           parser->parseRegionArgument(packed_region_args.emplace_back()) ||
243           parser->parseColonType(packed_region_arg_types.emplace_back()))
244         return failure();
245     } else if (parser->parseOperandList(replicated_inputs->emplace_back(),
246                                         OpAsmParser::Delimiter::Square) ||
247                parser->parseKeyword(
248                    "as", " between replicated inputs and block argument") ||
249                parser->parseRegionArgument(
250                    replicated_region_args.emplace_back()) ||
251                parser->parseColonType(
252                    replicated_region_arg_types.emplace_back())) {
253       return failure();
254     }
255   } while (succeeded(parser->parseOptionalComma()));
256 
257   region_args->reserve(replicated_region_args.size() +
258                        packed_region_args.size());
259   region_args->append(replicated_region_args.begin(),
260                       replicated_region_args.end());
261   region_args->append(packed_region_args.begin(), packed_region_args.end());
262 
263   region_arg_types->reserve(replicated_region_arg_types.size() +
264                             packed_region_arg_types.size());
265   region_arg_types->append(replicated_region_arg_types.begin(),
266                            replicated_region_arg_types.end());
267   region_arg_types->append(packed_region_arg_types.begin(),
268                            packed_region_arg_types.end());
269 
270   // Parse remaining `)` surrounding operands.
271   return parser->parseRParen();
272 }
273 
SetReplicateOpOperands(llvm::SMLoc loc,OpAsmParser * parser,OperationState * state,llvm::ArrayRef<llvm::SmallVector<OpAsmParser::OperandType,8>> replicated_inputs,llvm::ArrayRef<OpAsmParser::OperandType> packed_inputs,llvm::ArrayRef<Type> region_arg_types,int32_t * n)274 ParseResult SetReplicateOpOperands(
275     llvm::SMLoc loc, OpAsmParser* parser, OperationState* state,
276     llvm::ArrayRef<llvm::SmallVector<OpAsmParser::OperandType, 8>>
277         replicated_inputs,
278     llvm::ArrayRef<OpAsmParser::OperandType> packed_inputs,
279     llvm::ArrayRef<Type> region_arg_types, int32_t* n) {
280   for (const auto& attr : state->attributes)
281     if (attr.first.strref() == "n")
282       if (auto n_attr = attr.second.dyn_cast<IntegerAttr>())
283         *n = n_attr.getInt();
284 
285   if (*n < 2)
286     return parser->emitError(loc) << "expects 'n' to be at least 2, got " << *n;
287 
288   if (replicated_inputs.empty() && packed_inputs.empty()) return success();
289 
290   for (auto replicated_input_and_idx : llvm::enumerate(replicated_inputs)) {
291     const int32_t idx = replicated_input_and_idx.index();
292     const auto& replicated_input = replicated_input_and_idx.value();
293     // Check if replicated input matches `n`.
294     if (replicated_input.size() != *n)
295       return parser->emitError(loc)
296              << "expects number of operands for replicated input " << idx
297              << " to be 'n' (" << *n << "), got " << replicated_input.size();
298 
299     // Resolve replicated input and block argument type.
300     if (parser->resolveOperands(replicated_input, region_arg_types[idx],
301                                 state->operands))
302       return failure();
303   }
304 
305   const int32_t num_replicated_block_args = replicated_inputs.size();
306   for (auto packed_input_and_idx : llvm::enumerate(packed_inputs)) {
307     const int32_t idx = packed_input_and_idx.index();
308     const auto& packed_input = packed_input_and_idx.value();
309 
310     // Resolve packed input and block argument type.
311     if (parser->resolveOperand(
312             packed_input, region_arg_types[idx + num_replicated_block_args],
313             state->operands))
314       return failure();
315   }
316 
317   return success();
318 }
319 
// Name of the attribute that holds tf_device.replicate's operand segment
// sizes: [number of replicated operands, number of packed operands]. The
// custom printer elides it, and the custom parser and the builder derive it
// from the operand list.
constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
321 
ParseReplicateOp(OpAsmParser * parser,OperationState * state)322 ParseResult ParseReplicateOp(OpAsmParser* parser, OperationState* state) {
323   llvm::SMLoc loc = parser->getCurrentLocation();
324 
325   // Parse operands, attributes, and region of op.
326   llvm::SmallVector<llvm::SmallVector<OpAsmParser::OperandType, 8>, 8>
327       replicated_inputs;
328   llvm::SmallVector<OpAsmParser::OperandType, 8> packed_inputs;
329   llvm::SmallVector<OpAsmParser::OperandType, 8> region_args;
330   llvm::SmallVector<Type, 8> region_arg_types;
331   int32_t n = 0;
332   Region& body = *state->addRegion();
333   if (ParseReplicateOpOperands(parser, state, &replicated_inputs,
334                                &packed_inputs, &region_args,
335                                &region_arg_types) ||
336       parser->parseOptionalAttrDict(state->attributes) ||
337       SetReplicateOpOperands(loc, parser, state, replicated_inputs,
338                              packed_inputs, region_arg_types, &n) ||
339       parser->parseRegion(body, region_args, region_arg_types))
340     return failure();
341 
342   // Add derived `operand_segment_sizes` attribute based on parsed operands.
343   if (!state->attributes.get(kOperandSegmentSizesAttr)) {
344     int32_t num_replicated_inputs = replicated_inputs.size() * n;
345     int32_t num_packed_inputs = packed_inputs.size();
346     auto attr = DenseIntElementsAttr::get(
347         VectorType::get({2}, parser->getBuilder().getI32Type()),
348         {num_replicated_inputs, num_packed_inputs});
349     state->addAttribute(kOperandSegmentSizesAttr, attr);
350   }
351 
352   // Ensure that the region is well formed: it contains at least a block with
353   // a ReturnOp terminator.
354   ReplicateOp::ensureTerminator(body, parser->getBuilder(), state->location);
355 
356   if (!llvm::hasSingleElement(body))
357     return parser->emitError(loc) << "expects a single block region";
358 
359   Operation& terminator = body.front().back();
360   if (!isa<ReturnOp>(terminator))
361     return parser->emitError(loc) << "expects a tf_device.return terminator";
362 
363   // Get the results type from the terminator type inside the replicate,
364   // replicated each by `n`.
365   state->types.reserve(terminator.getNumOperands() * n);
366   for (const auto& type : terminator.getOperandTypes())
367     state->types.append(n, type);
368 
369   return success();
370 }
371 
Print(ReplicateOp op,OpAsmPrinter * p)372 void Print(ReplicateOp op, OpAsmPrinter* p) {
373   *p << op.getOperationName();
374 
375   // Print comma separated operands of the following format:
376   //   replicated_input
377   //     [%a, ...] as %block_arg0: type
378   //   packed_input
379   //     %b as %block_arg1: type
380   const int32_t n = op.n();
381   const int32_t num_replicated_inputs =
382       (*op.operand_segment_sizes().int_value_begin()).getSExtValue();
383   const int32_t num_replicated_block_args = num_replicated_inputs / n;
384 
385   if (op.getNumOperands()) {
386     *p << '(';
387     Block& block = op.body().front();
388     interleaveComma(block.getArguments(), *p, [&](BlockArgument arg) {
389       const int block_arg_num = arg.getArgNumber();
390       if (block_arg_num < num_replicated_block_args) {
391         *p << '[';
392         p->printOperands(
393             std::next(op.replicated_inputs().begin(), block_arg_num * n),
394             std::next(op.replicated_inputs().begin(), (block_arg_num + 1) * n));
395         *p << "]";
396       } else {
397         p->printOperand(*std::next(op.packed_inputs().begin(),
398                                    block_arg_num - num_replicated_block_args));
399       }
400       *p << " as " << arg << ": " << arg.getType();
401     });
402     *p << ')';
403   }
404 
405   // Skip derived `operand_segment_sizes` attribute as custom print format of
406   // operands holds enough information to calculate these variadic operand list
407   // lengths.
408   p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/ArrayRef<StringRef>{
409                                kOperandSegmentSizesAttr});
410   p->printRegion(op.body(), /*printEntryBlockArgs=*/false);
411 }
412 
413 // Checks if two types are compatible (compatible shapes and same elemental
414 // type).
VerifyCompatibleTypes(Type a,Type b)415 LogicalResult VerifyCompatibleTypes(Type a, Type b) {
416   if (failed(verifyCompatibleShape(a, b)) ||
417       getElementTypeOrSelf(a) != getElementTypeOrSelf(b))
418     return failure();
419 
420   return success();
421 }
422 
Verify(ReplicateOp op)423 LogicalResult Verify(ReplicateOp op) {
424   int32_t n = op.n();
425 
426   // Check number of devices, if set, matches `n`.
427   if (op.devices().hasValue()) {
428     for (auto device_attr : op.devices().getValue().getValue()) {
429       auto device_list = device_attr.second.dyn_cast_or_null<ArrayAttr>();
430       if (!device_list)
431         return op.emitError()
432                << "expects 'devices' to be a map alias and device name list.";
433 
434       bool is_device_string = llvm::all_of(device_list, [](Attribute attr) {
435         return attr.dyn_cast_or_null<StringAttr>();
436       });
437       if (!is_device_string)
438         return op.emitOpError() << "expects 'devices' to be a consists of "
439                                    "string list as values.";
440 
441       if (device_list.size() != n)
442         return op.emitOpError()
443                << "expects number of devices (" << device_list.size()
444                << ") to be equal to 'n' (" << n << ")";
445     }
446   }
447 
448   Block& block = op.body().front();
449 
450   auto operand_segment_sizes = op.operand_segment_sizes();
451   const int32_t num_replicated_inputs =
452       operand_segment_sizes.getValue<IntegerAttr>({0}).getInt();
453   const int32_t num_packed_inputs =
454       operand_segment_sizes.getValue<IntegerAttr>({1}).getInt();
455 
456   if (num_replicated_inputs % n != 0)
457     return op.emitOpError()
458            << "expects number of replicated inputs (" << num_replicated_inputs
459            << ") to be evenly divisible by 'n' (" << n << ")";
460 
461   const int32_t num_replicated_block_args = num_replicated_inputs / n;
462   if (num_replicated_block_args + num_packed_inputs != block.getNumArguments())
463     return op.emitOpError()
464            << "expects number of block arguments (" << block.getNumArguments()
465            << ") to be equal to number of replicated inputs ("
466            << num_replicated_inputs << ") / 'n' (" << n
467            << ") + number of packed inputs (" << num_packed_inputs << ")";
468 
469   // Check input types match block argument types.
470   auto verify_operand_types = [&](BlockArgument block_arg,
471                                   int32_t op_operand_idx) -> LogicalResult {
472     Type op_operand_type = op.getOperand(op_operand_idx).getType();
473     if (failed(VerifyCompatibleTypes(block_arg.getType(), op_operand_type)))
474       return op.emitOpError()
475              << "expects operand " << op_operand_idx << " (" << op_operand_type
476              << ") and block argument " << block_arg.getArgNumber() << " ("
477              << block_arg.getType() << ") to have compatible types";
478 
479     return success();
480   };
481   for (auto block_arg : block.getArguments()) {
482     if (block_arg.getArgNumber() < num_replicated_block_args) {
483       for (int32_t i = n * block_arg.getArgNumber(), e = i + n; i < e; ++i)
484         if (failed(verify_operand_types(block_arg, i))) return failure();
485     } else {
486       const int32_t idx = block_arg.getArgNumber() - num_replicated_block_args +
487                           num_replicated_inputs;
488       if (failed(verify_operand_types(block_arg, idx))) return failure();
489     }
490   }
491 
492   Operation& terminator = block.back();
493 
494   // Check number of results matches `n` * number of return operands.
495   if (op.getNumResults() != n * terminator.getNumOperands())
496     return op.emitOpError()
497            << "expects number of results (" << op.getNumResults()
498            << ") to be equal to 'n' * number of terminator operands (" << n
499            << " * " << terminator.getNumOperands() << ")";
500 
501   // Check replicated output types match return operand types.
502   for (auto operand_type_and_idx :
503        llvm::enumerate(terminator.getOperandTypes())) {
504     Type operand_type = operand_type_and_idx.value();
505     int32_t operand_idx = operand_type_and_idx.index();
506     for (int32_t i = n * operand_idx, e = i + n; i < e; ++i)
507       if (failed(VerifyCompatibleTypes(operand_type, op.getType(i))))
508         return op.emitOpError() << "incompatible types for result " << i
509                                 << " and terminator operand " << operand_idx;
510   }
511 
512   return success();
513 }
514 
BuildReplicateOp(Builder * builder,OperationState * state,int n,llvm::Optional<DictionaryAttr> devices,llvm::ArrayRef<std::pair<ValueRange,Type>> replicated_inputs,ValueRange packed_inputs,TypeRange replica_output_types)515 void BuildReplicateOp(
516     Builder* builder, OperationState* state, int n,
517     llvm::Optional<DictionaryAttr> devices,
518     llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs,
519     ValueRange packed_inputs, TypeRange replica_output_types) {
520   DCHECK_GE(n, 2);
521   state->addAttribute("n", builder->getI32IntegerAttr(n));
522 
523   if (devices.hasValue()) state->addAttribute("devices", devices.getValue());
524 
525   Region* region = state->addRegion();
526   region->push_back(new Block);
527   Block& block = region->front();
528 
529   for (auto& replicated_input : replicated_inputs) {
530     DCHECK_EQ(llvm::size(replicated_input.first), n);
531     for (auto input : replicated_input.first) {
532       DCHECK(succeeded(
533           VerifyCompatibleTypes(input.getType(), replicated_input.second)));
534       state->addOperands(input);
535     }
536     block.addArgument(replicated_input.second);
537   }
538 
539   for (auto packed_input : packed_inputs) {
540     state->addOperands(packed_input);
541     block.addArgument(packed_input.getType());
542   }
543 
544   // Add derived `operand_segment_sizes` attribute.
545   int32_t num_replicated_inputs = replicated_inputs.size() * n;
546   int32_t num_packed_inputs = packed_inputs.size();
547   auto operand_segment_sizes =
548       DenseIntElementsAttr::get(VectorType::get({2}, builder->getI32Type()),
549                                 {num_replicated_inputs, num_packed_inputs});
550   state->addAttribute(kOperandSegmentSizesAttr, operand_segment_sizes);
551 
552   for (const auto& output_type : replica_output_types)
553     state->addTypes(llvm::SmallVector<Type, 8>(n, output_type));
554 }
555 }  // anonymous namespace
556 
build(OpBuilder & builder,OperationState & state,int n,const llvm::SmallDenseMap<StringRef,llvm::SmallVector<StringRef,4>> & devices,llvm::ArrayRef<std::pair<ValueRange,Type>> replicated_inputs,ValueRange packed_inputs,TypeRange replica_output_types)557 void ReplicateOp::build(
558     OpBuilder& builder, OperationState& state, int n,
559     const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>&
560         devices,
561     llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs,
562     ValueRange packed_inputs, TypeRange replica_output_types) {
563   llvm::Optional<DictionaryAttr> devices_attr;
564   if (!devices.empty()) {
565     llvm::SmallVector<mlir::NamedAttribute, 1> device_list;
566     device_list.reserve(devices.size());
567     for (auto alias_and_devices : devices) {
568       NamedAttribute device_name_attr = builder.getNamedAttr(
569           alias_and_devices.getFirst(),
570           builder.getStrArrayAttr(alias_and_devices.getSecond()));
571       device_list.emplace_back(device_name_attr);
572     }
573     devices_attr.emplace(builder.getDictionaryAttr(device_list));
574   }
575 
576   BuildReplicateOp(&builder, &state, n, devices_attr, replicated_inputs,
577                    packed_inputs, replica_output_types);
578 }
579 
build(OpBuilder & builder,OperationState & state,int n,llvm::Optional<DictionaryAttr> devices,llvm::ArrayRef<std::pair<ValueRange,Type>> replicated_inputs,ValueRange packed_inputs,TypeRange replica_output_types)580 void ReplicateOp::build(
581     OpBuilder& builder, OperationState& state, int n,
582     llvm::Optional<DictionaryAttr> devices,
583     llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs,
584     ValueRange packed_inputs, TypeRange replica_output_types) {
585   BuildReplicateOp(&builder, &state, n, devices, replicated_inputs,
586                    packed_inputs, replica_output_types);
587 }
588 
589 // Returns the number of packed block arguments.
GetNumPackedBlockArguments()590 unsigned ReplicateOp::GetNumPackedBlockArguments() {
591   return packed_inputs().size();
592 }
593 
594 // Returns the number of replicated block arguments.
GetNumReplicatedBlockArguments()595 unsigned ReplicateOp::GetNumReplicatedBlockArguments() {
596   return GetBody().getNumArguments() - GetNumPackedBlockArguments();
597 }
598 
599 // Returns the replicated block arguments. A copy should be made if the
600 // replicate op is being modified.
GetReplicatedBlockArguments()601 llvm::ArrayRef<BlockArgument> ReplicateOp::GetReplicatedBlockArguments() {
602   return GetBody().getArguments().drop_back(GetNumPackedBlockArguments());
603 }
604 
605 // Returns the packed block arguments. A copy should be made if the replicate op
606 // is being modified.
GetPackedBlockArguments()607 llvm::ArrayRef<BlockArgument> ReplicateOp::GetPackedBlockArguments() {
608   return GetBody().getArguments().take_back(GetNumPackedBlockArguments());
609 }
610 
611 // Checks if a block argument is replicated (forwarding replicated inputs).
IsReplicatedBlockArgument(BlockArgument block_arg)612 bool ReplicateOp::IsReplicatedBlockArgument(BlockArgument block_arg) {
613   assert(block_arg.getOwner() == &GetBody());
614   return block_arg.getArgNumber() < GetNumReplicatedBlockArguments();
615 }
616 
617 // Checks if a block argument is packed (forwarding a packed input).
IsPackedBlockArgument(BlockArgument block_arg)618 bool ReplicateOp::IsPackedBlockArgument(BlockArgument block_arg) {
619   return !IsReplicatedBlockArgument(block_arg);
620 }
621 
622 // Returns the operand index of the operand being forwarded as a
623 // replicated/packed block argument for a given replica. This assumes a valid
624 // block argument (of the replicate op) and a valid replica is provided.
GetReplicaOperandIndexForBlockArgument(BlockArgument block_arg,unsigned replica)625 unsigned ReplicateOp::GetReplicaOperandIndexForBlockArgument(
626     BlockArgument block_arg, unsigned replica) {
627   MutableArrayRef<OpOperand> operands = GetOperandsForBlockArgument(block_arg);
628   if (operands.size() == 1) return operands.front().getOperandNumber();
629 
630   return operands[replica].getOperandNumber();
631 }
632 
633 // Returns the operand being forwarded as a replicated/packed block argument for
634 // a given replica. This assumes a valid block argument (of the replicate op)
635 // and a valid replica is provided.
GetReplicaOperandForBlockArgument(BlockArgument block_arg,unsigned replica)636 Value ReplicateOp::GetReplicaOperandForBlockArgument(BlockArgument block_arg,
637                                                      unsigned replica) {
638   MutableArrayRef<OpOperand> operands = GetOperandsForBlockArgument(block_arg);
639   if (operands.size() == 1) return operands.front().get();
640 
641   return operands[replica].get();
642 }
643 
644 // Returns the list of replica op operands that maps to the given block
645 // argument. Returns list with num_replicas elements for replicated operands
646 // and list with a single element for packed operands.
647 //
648 // Requires that block argument is of this replicate op.
GetOperandsForBlockArgument(BlockArgument block_arg)649 MutableArrayRef<OpOperand> ReplicateOp::GetOperandsForBlockArgument(
650     BlockArgument block_arg) {
651   assert(block_arg.getOwner() == &GetBody());
652 
653   unsigned arg_number = block_arg.getArgNumber();
654   unsigned num_replicated_args = GetNumReplicatedBlockArguments();
655   int32_t num_replicas = nAttr().getInt();
656   MutableArrayRef<OpOperand> operands = getOperation()->getOpOperands();
657 
658   // All replicated arguments are before packed arguments so return replicated
659   // operands if the given argument is one of the replicated arguments.
660   if (arg_number < num_replicated_args)
661     return operands.slice(arg_number * num_replicas, num_replicas);
662 
663   operands = operands.drop_front(num_replicated_args * num_replicas);
664   arg_number -= num_replicated_args;
665   return operands.slice(arg_number, 1);
666 }
667 
668 // Checks if a tf_device.replicate wraps a single operation and the single
669 // operation results are perfectly forwarded to the replicate return.
WrapsSingleOp()670 bool ReplicateOp::WrapsSingleOp() { return BlockWrapsSingleOp(&GetBody()); }
671 
672 //===----------------------------------------------------------------------===//
673 // Canonicalization patterns
674 //===----------------------------------------------------------------------===//
675 
676 //===----------------------------------------------------------------------===//
677 // tf_device.launch
678 //===----------------------------------------------------------------------===//
679 
680 namespace {
681 // This pattern matches LaunchOps with only one ReturnOp (empty) and remaps the
682 // results of the LaunchOp to the operands of the ReturnOp.
683 struct DropEmptyLaunch : public OpRewritePattern<LaunchOp> {
684   using OpRewritePattern<LaunchOp>::OpRewritePattern;
685 
matchAndRewritemlir::tf_device::__anon57a3a4cb0711::DropEmptyLaunch686   LogicalResult matchAndRewrite(LaunchOp op,
687                                 PatternRewriter& rewriter) const override {
688     Block& block = op.GetBody();
689     // Check if launch only has a return.
690     if (&block.front() != &block.back()) return failure();
691 
692     // Map launch results to return operands.
693     rewriter.replaceOp(op, block.front().getOperands());
694 
695     return success();
696   }
697 };
698 }  // anonymous namespace
699 
// Registers the canonicalization patterns for tf_device.launch (currently
// only DropEmptyLaunch).
void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                           MLIRContext* context) {
  results.insert<DropEmptyLaunch>(context);
}
704 
705 }  // namespace tf_device
706 }  // namespace mlir
707 
708 //===----------------------------------------------------------------------===//
709 // TableGen'd op method definitions
710 //===----------------------------------------------------------------------===//
711 
712 #define GET_OP_CLASSES
713 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc.inc"
714