1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
17
18 #include <algorithm>
19 #include <cstddef>
20 #include <cstdint>
21 #include <iterator>
22 #include <utility>
23
24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/Optional.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "llvm/Support/SMLoc.h"
30 #include "mlir/IR/Attributes.h" // from @llvm-project
31 #include "mlir/IR/Builders.h" // from @llvm-project
32 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
33 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
34 #include "mlir/IR/MLIRContext.h" // from @llvm-project
35 #include "mlir/IR/OpDefinition.h" // from @llvm-project
36 #include "mlir/IR/OpImplementation.h" // from @llvm-project
37 #include "mlir/IR/OperationSupport.h" // from @llvm-project
38 #include "mlir/IR/PatternMatch.h" // from @llvm-project
39 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
40 #include "mlir/IR/Types.h" // from @llvm-project
41 #include "mlir/IR/UseDefLists.h" // from @llvm-project
42 #include "mlir/IR/Value.h" // from @llvm-project
43 #include "mlir/Support/LLVM.h" // from @llvm-project
44 #include "mlir/Support/LogicalResult.h" // from @llvm-project
45 #include "mlir/Transforms/InliningUtils.h" // from @llvm-project
46 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
47 #include "tensorflow/core/platform/logging.h"
48
49 namespace mlir {
50 namespace tf_device {
51
52 //===----------------------------------------------------------------------===//
53 // TF Device Dialect Interfaces
54 //===----------------------------------------------------------------------===//
55
56 namespace {
57 struct TFInlinerInterface : public DialectInlinerInterface {
58 using DialectInlinerInterface::DialectInlinerInterface;
59
60 //===--------------------------------------------------------------------===//
61 // Analysis Hooks
62 //===--------------------------------------------------------------------===//
63
64 // Allow all call operations to be inlined.
isLegalToInlinemlir::tf_device::__anon57a3a4cb0111::TFInlinerInterface65 bool isLegalToInline(Operation* call, Operation* callable,
66 bool wouldBeCloned) const final {
67 return true;
68 }
69 // Defines the legality of inlining TF Device operations.
isLegalToInlinemlir::tf_device::__anon57a3a4cb0111::TFInlinerInterface70 bool isLegalToInline(Operation*, Region*, bool,
71 BlockAndValueMapping&) const final {
72 // For now, enable inlining all operations.
73 return true;
74 }
75
76 //===--------------------------------------------------------------------===//
77 // Transformation Hooks
78 //===--------------------------------------------------------------------===//
79
80 // Attempts to materialize a conversion for a type mismatch between a call
81 // from this dialect, and a callable region. This method should generate an
82 // operation that takes 'input' as the only operand, and produces a single
83 // result of 'resultType'. If a conversion can not be generated, nullptr
84 // should be returned.
85 // This is just re-using the same logic as the TensorFlow dialect right now.
materializeCallConversionmlir::tf_device::__anon57a3a4cb0111::TFInlinerInterface86 Operation* materializeCallConversion(OpBuilder& builder, Value input,
87 Type result_type,
88 Location conversion_loc) const final {
89 if (!result_type.isa<TensorType>() || !input.getType().isa<TensorType>())
90 return nullptr;
91 return builder.create<TF::CastOp>(conversion_loc, result_type, input,
92 /*truncate=*/builder.getBoolAttr(false));
93 }
94 };
95
96 // Checks if a block wraps a single operation and the single operation results
97 // are perfectly forwarded to the block's terminator.
BlockWrapsSingleOp(Block * block)98 bool BlockWrapsSingleOp(Block* block) {
99 auto body = block->without_terminator();
100 if (!hasSingleElement(body)) return false;
101
102 Operation& wrapped_op = *body.begin();
103 Operation* terminator = block->getTerminator();
104 return wrapped_op.getNumResults() == terminator->getNumOperands() &&
105 std::equal(wrapped_op.getResults().begin(),
106 wrapped_op.getResults().end(),
107 terminator->getOperands().begin());
108 }
109 } // end anonymous namespace
110
TensorFlowDeviceDialect(MLIRContext * context)111 TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
112 : Dialect(/*name=*/"tf_device", context,
113 TypeID::get<TensorFlowDeviceDialect>()) {
114 addOperations<
115 #define GET_OP_LIST
116 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc.inc"
117 >();
118
119 addInterfaces<TFInlinerInterface>();
120 }
121
122 //===----------------------------------------------------------------------===//
123 // tf_device.launch
124 //===----------------------------------------------------------------------===//
125
126 // Checks if a tf_device.launch wraps a single operation and the single
127 // operation results are perfectly forwarded to the launch return.
WrapsSingleOp()128 bool LaunchOp::WrapsSingleOp() { return BlockWrapsSingleOp(&GetBody()); }
129
130 //===----------------------------------------------------------------------===//
131 // tf_device.parallel_execute
132 //===----------------------------------------------------------------------===//
133
134 namespace {
135
Verify(ParallelExecuteOp op)136 LogicalResult Verify(ParallelExecuteOp op) {
137 const auto& regions = op.getOperation()->getRegions();
138 if (regions.size() < 2) {
139 return op.emitOpError() << "must have at least two regions.";
140 }
141
142 int output_index = 0;
143 for (auto& region_and_index : llvm::enumerate(regions)) {
144 auto& region = region_and_index.value();
145 auto* region_terminator = region.front().getTerminator();
146
147 // Check that output types of regions match return operand types.
148 for (auto result_type : region_terminator->getOperandTypes()) {
149 if (result_type !=
150 op.getOperation()->getResult(output_index++).getType()) {
151 return op.emitOpError() << "output types must be a concatenated "
152 << "list of output types for each regions.";
153 }
154 }
155 }
156
157 // Check that total number of outputs from regions match the output types of
158 // the parallel_execute op.
159 const int num_output_types = op.getOperation()->getNumResults();
160 if (num_output_types != output_index) {
161 return op.emitOpError()
162 << "number of output types (" << num_output_types << ") "
163 << "must match the total number of outputs from all "
164 << "regions (" << output_index << ").";
165 }
166
167 return success();
168 }
169
170 } // namespace
171
172 // static
build(OpBuilder & builder,OperationState & state,int num_regions,llvm::ArrayRef<Type> output_types)173 void ParallelExecuteOp::build(OpBuilder& builder, OperationState& state,
174 int num_regions,
175 llvm::ArrayRef<Type> output_types) {
176 DCHECK_GE(num_regions, 2);
177 for (int i = 0; i < num_regions; ++i) {
178 Region* region = state.addRegion();
179 region->push_back(new Block);
180 }
181 state.addTypes(output_types);
182 }
183
GetRegionBlockWithIndex(unsigned index)184 Block& ParallelExecuteOp::GetRegionBlockWithIndex(unsigned index) {
185 return getOperation()->getRegion(index).front();
186 }
187
GetRegionOutputs(unsigned region_index)188 Operation::result_range ParallelExecuteOp::GetRegionOutputs(
189 unsigned region_index) {
190 int num_region_results =
191 GetRegionBlockWithIndex(region_index).getTerminator()->getNumOperands();
192
193 int return_value_offset = 0;
194 for (int region_id = 0; region_id < region_index; ++region_id)
195 return_value_offset +=
196 GetRegionBlockWithIndex(region_id).getTerminator()->getNumOperands();
197
198 Operation::result_range region_results(getOperation(),
199 /*startIndex=*/return_value_offset,
200 /*count=*/num_region_results);
201 return region_results;
202 }
203
RegionWrapsSingleOp(unsigned index)204 bool ParallelExecuteOp::RegionWrapsSingleOp(unsigned index) {
205 return BlockWrapsSingleOp(&GetRegionBlockWithIndex(index));
206 }
207
208 //===----------------------------------------------------------------------===//
209 // tf_device.replicate
210 //===----------------------------------------------------------------------===//
211
212 namespace {
ParseReplicateOpOperands(OpAsmParser * parser,OperationState * state,llvm::SmallVectorImpl<llvm::SmallVector<OpAsmParser::OperandType,8>> * replicated_inputs,llvm::SmallVectorImpl<OpAsmParser::OperandType> * packed_inputs,llvm::SmallVectorImpl<OpAsmParser::OperandType> * region_args,llvm::SmallVectorImpl<Type> * region_arg_types)213 ParseResult ParseReplicateOpOperands(
214 OpAsmParser* parser, OperationState* state,
215 llvm::SmallVectorImpl<llvm::SmallVector<OpAsmParser::OperandType, 8>>*
216 replicated_inputs,
217 llvm::SmallVectorImpl<OpAsmParser::OperandType>* packed_inputs,
218 llvm::SmallVectorImpl<OpAsmParser::OperandType>* region_args,
219 llvm::SmallVectorImpl<Type>* region_arg_types) {
220 // No operands or empty operand list.
221 bool parsed_l_paren = succeeded(parser->parseOptionalLParen());
222 if (!parsed_l_paren || succeeded(parser->parseOptionalRParen()))
223 return success();
224
225 // Parse comma separated operands of the following format:
226 // replicated_input
227 // [%a, ...] as %block_arg0: type
228 // packed_input
229 // %b as %block_arg1: type
230 //
231 // Replicated inputs are placed before packed inputs when forming the op.
232 llvm::SmallVector<OpAsmParser::OperandType, 8> replicated_region_args;
233 llvm::SmallVector<OpAsmParser::OperandType, 8> packed_region_args;
234 llvm::SmallVector<Type, 8> replicated_region_arg_types;
235 llvm::SmallVector<Type, 8> packed_region_arg_types;
236 do {
237 OpAsmParser::OperandType operand_type;
238 if (parser->parseOptionalOperand(operand_type).hasValue()) {
239 packed_inputs->emplace_back(operand_type);
240 if (parser->parseKeyword("as",
241 " between packed input and block argument") ||
242 parser->parseRegionArgument(packed_region_args.emplace_back()) ||
243 parser->parseColonType(packed_region_arg_types.emplace_back()))
244 return failure();
245 } else if (parser->parseOperandList(replicated_inputs->emplace_back(),
246 OpAsmParser::Delimiter::Square) ||
247 parser->parseKeyword(
248 "as", " between replicated inputs and block argument") ||
249 parser->parseRegionArgument(
250 replicated_region_args.emplace_back()) ||
251 parser->parseColonType(
252 replicated_region_arg_types.emplace_back())) {
253 return failure();
254 }
255 } while (succeeded(parser->parseOptionalComma()));
256
257 region_args->reserve(replicated_region_args.size() +
258 packed_region_args.size());
259 region_args->append(replicated_region_args.begin(),
260 replicated_region_args.end());
261 region_args->append(packed_region_args.begin(), packed_region_args.end());
262
263 region_arg_types->reserve(replicated_region_arg_types.size() +
264 packed_region_arg_types.size());
265 region_arg_types->append(replicated_region_arg_types.begin(),
266 replicated_region_arg_types.end());
267 region_arg_types->append(packed_region_arg_types.begin(),
268 packed_region_arg_types.end());
269
270 // Parse remaining `)` surrounding operands.
271 return parser->parseRParen();
272 }
273
SetReplicateOpOperands(llvm::SMLoc loc,OpAsmParser * parser,OperationState * state,llvm::ArrayRef<llvm::SmallVector<OpAsmParser::OperandType,8>> replicated_inputs,llvm::ArrayRef<OpAsmParser::OperandType> packed_inputs,llvm::ArrayRef<Type> region_arg_types,int32_t * n)274 ParseResult SetReplicateOpOperands(
275 llvm::SMLoc loc, OpAsmParser* parser, OperationState* state,
276 llvm::ArrayRef<llvm::SmallVector<OpAsmParser::OperandType, 8>>
277 replicated_inputs,
278 llvm::ArrayRef<OpAsmParser::OperandType> packed_inputs,
279 llvm::ArrayRef<Type> region_arg_types, int32_t* n) {
280 for (const auto& attr : state->attributes)
281 if (attr.first.strref() == "n")
282 if (auto n_attr = attr.second.dyn_cast<IntegerAttr>())
283 *n = n_attr.getInt();
284
285 if (*n < 2)
286 return parser->emitError(loc) << "expects 'n' to be at least 2, got " << *n;
287
288 if (replicated_inputs.empty() && packed_inputs.empty()) return success();
289
290 for (auto replicated_input_and_idx : llvm::enumerate(replicated_inputs)) {
291 const int32_t idx = replicated_input_and_idx.index();
292 const auto& replicated_input = replicated_input_and_idx.value();
293 // Check if replicated input matches `n`.
294 if (replicated_input.size() != *n)
295 return parser->emitError(loc)
296 << "expects number of operands for replicated input " << idx
297 << " to be 'n' (" << *n << "), got " << replicated_input.size();
298
299 // Resolve replicated input and block argument type.
300 if (parser->resolveOperands(replicated_input, region_arg_types[idx],
301 state->operands))
302 return failure();
303 }
304
305 const int32_t num_replicated_block_args = replicated_inputs.size();
306 for (auto packed_input_and_idx : llvm::enumerate(packed_inputs)) {
307 const int32_t idx = packed_input_and_idx.index();
308 const auto& packed_input = packed_input_and_idx.value();
309
310 // Resolve packed input and block argument type.
311 if (parser->resolveOperand(
312 packed_input, region_arg_types[idx + num_replicated_block_args],
313 state->operands))
314 return failure();
315 }
316
317 return success();
318 }
319
320 constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
321
ParseReplicateOp(OpAsmParser * parser,OperationState * state)322 ParseResult ParseReplicateOp(OpAsmParser* parser, OperationState* state) {
323 llvm::SMLoc loc = parser->getCurrentLocation();
324
325 // Parse operands, attributes, and region of op.
326 llvm::SmallVector<llvm::SmallVector<OpAsmParser::OperandType, 8>, 8>
327 replicated_inputs;
328 llvm::SmallVector<OpAsmParser::OperandType, 8> packed_inputs;
329 llvm::SmallVector<OpAsmParser::OperandType, 8> region_args;
330 llvm::SmallVector<Type, 8> region_arg_types;
331 int32_t n = 0;
332 Region& body = *state->addRegion();
333 if (ParseReplicateOpOperands(parser, state, &replicated_inputs,
334 &packed_inputs, ®ion_args,
335 ®ion_arg_types) ||
336 parser->parseOptionalAttrDict(state->attributes) ||
337 SetReplicateOpOperands(loc, parser, state, replicated_inputs,
338 packed_inputs, region_arg_types, &n) ||
339 parser->parseRegion(body, region_args, region_arg_types))
340 return failure();
341
342 // Add derived `operand_segment_sizes` attribute based on parsed operands.
343 if (!state->attributes.get(kOperandSegmentSizesAttr)) {
344 int32_t num_replicated_inputs = replicated_inputs.size() * n;
345 int32_t num_packed_inputs = packed_inputs.size();
346 auto attr = DenseIntElementsAttr::get(
347 VectorType::get({2}, parser->getBuilder().getI32Type()),
348 {num_replicated_inputs, num_packed_inputs});
349 state->addAttribute(kOperandSegmentSizesAttr, attr);
350 }
351
352 // Ensure that the region is well formed: it contains at least a block with
353 // a ReturnOp terminator.
354 ReplicateOp::ensureTerminator(body, parser->getBuilder(), state->location);
355
356 if (!llvm::hasSingleElement(body))
357 return parser->emitError(loc) << "expects a single block region";
358
359 Operation& terminator = body.front().back();
360 if (!isa<ReturnOp>(terminator))
361 return parser->emitError(loc) << "expects a tf_device.return terminator";
362
363 // Get the results type from the terminator type inside the replicate,
364 // replicated each by `n`.
365 state->types.reserve(terminator.getNumOperands() * n);
366 for (const auto& type : terminator.getOperandTypes())
367 state->types.append(n, type);
368
369 return success();
370 }
371
Print(ReplicateOp op,OpAsmPrinter * p)372 void Print(ReplicateOp op, OpAsmPrinter* p) {
373 *p << op.getOperationName();
374
375 // Print comma separated operands of the following format:
376 // replicated_input
377 // [%a, ...] as %block_arg0: type
378 // packed_input
379 // %b as %block_arg1: type
380 const int32_t n = op.n();
381 const int32_t num_replicated_inputs =
382 (*op.operand_segment_sizes().int_value_begin()).getSExtValue();
383 const int32_t num_replicated_block_args = num_replicated_inputs / n;
384
385 if (op.getNumOperands()) {
386 *p << '(';
387 Block& block = op.body().front();
388 interleaveComma(block.getArguments(), *p, [&](BlockArgument arg) {
389 const int block_arg_num = arg.getArgNumber();
390 if (block_arg_num < num_replicated_block_args) {
391 *p << '[';
392 p->printOperands(
393 std::next(op.replicated_inputs().begin(), block_arg_num * n),
394 std::next(op.replicated_inputs().begin(), (block_arg_num + 1) * n));
395 *p << "]";
396 } else {
397 p->printOperand(*std::next(op.packed_inputs().begin(),
398 block_arg_num - num_replicated_block_args));
399 }
400 *p << " as " << arg << ": " << arg.getType();
401 });
402 *p << ')';
403 }
404
405 // Skip derived `operand_segment_sizes` attribute as custom print format of
406 // operands holds enough information to calculate these variadic operand list
407 // lengths.
408 p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/ArrayRef<StringRef>{
409 kOperandSegmentSizesAttr});
410 p->printRegion(op.body(), /*printEntryBlockArgs=*/false);
411 }
412
413 // Checks if two types are compatible (compatible shapes and same elemental
414 // type).
VerifyCompatibleTypes(Type a,Type b)415 LogicalResult VerifyCompatibleTypes(Type a, Type b) {
416 if (failed(verifyCompatibleShape(a, b)) ||
417 getElementTypeOrSelf(a) != getElementTypeOrSelf(b))
418 return failure();
419
420 return success();
421 }
422
Verify(ReplicateOp op)423 LogicalResult Verify(ReplicateOp op) {
424 int32_t n = op.n();
425
426 // Check number of devices, if set, matches `n`.
427 if (op.devices().hasValue()) {
428 for (auto device_attr : op.devices().getValue().getValue()) {
429 auto device_list = device_attr.second.dyn_cast_or_null<ArrayAttr>();
430 if (!device_list)
431 return op.emitError()
432 << "expects 'devices' to be a map alias and device name list.";
433
434 bool is_device_string = llvm::all_of(device_list, [](Attribute attr) {
435 return attr.dyn_cast_or_null<StringAttr>();
436 });
437 if (!is_device_string)
438 return op.emitOpError() << "expects 'devices' to be a consists of "
439 "string list as values.";
440
441 if (device_list.size() != n)
442 return op.emitOpError()
443 << "expects number of devices (" << device_list.size()
444 << ") to be equal to 'n' (" << n << ")";
445 }
446 }
447
448 Block& block = op.body().front();
449
450 auto operand_segment_sizes = op.operand_segment_sizes();
451 const int32_t num_replicated_inputs =
452 operand_segment_sizes.getValue<IntegerAttr>({0}).getInt();
453 const int32_t num_packed_inputs =
454 operand_segment_sizes.getValue<IntegerAttr>({1}).getInt();
455
456 if (num_replicated_inputs % n != 0)
457 return op.emitOpError()
458 << "expects number of replicated inputs (" << num_replicated_inputs
459 << ") to be evenly divisible by 'n' (" << n << ")";
460
461 const int32_t num_replicated_block_args = num_replicated_inputs / n;
462 if (num_replicated_block_args + num_packed_inputs != block.getNumArguments())
463 return op.emitOpError()
464 << "expects number of block arguments (" << block.getNumArguments()
465 << ") to be equal to number of replicated inputs ("
466 << num_replicated_inputs << ") / 'n' (" << n
467 << ") + number of packed inputs (" << num_packed_inputs << ")";
468
469 // Check input types match block argument types.
470 auto verify_operand_types = [&](BlockArgument block_arg,
471 int32_t op_operand_idx) -> LogicalResult {
472 Type op_operand_type = op.getOperand(op_operand_idx).getType();
473 if (failed(VerifyCompatibleTypes(block_arg.getType(), op_operand_type)))
474 return op.emitOpError()
475 << "expects operand " << op_operand_idx << " (" << op_operand_type
476 << ") and block argument " << block_arg.getArgNumber() << " ("
477 << block_arg.getType() << ") to have compatible types";
478
479 return success();
480 };
481 for (auto block_arg : block.getArguments()) {
482 if (block_arg.getArgNumber() < num_replicated_block_args) {
483 for (int32_t i = n * block_arg.getArgNumber(), e = i + n; i < e; ++i)
484 if (failed(verify_operand_types(block_arg, i))) return failure();
485 } else {
486 const int32_t idx = block_arg.getArgNumber() - num_replicated_block_args +
487 num_replicated_inputs;
488 if (failed(verify_operand_types(block_arg, idx))) return failure();
489 }
490 }
491
492 Operation& terminator = block.back();
493
494 // Check number of results matches `n` * number of return operands.
495 if (op.getNumResults() != n * terminator.getNumOperands())
496 return op.emitOpError()
497 << "expects number of results (" << op.getNumResults()
498 << ") to be equal to 'n' * number of terminator operands (" << n
499 << " * " << terminator.getNumOperands() << ")";
500
501 // Check replicated output types match return operand types.
502 for (auto operand_type_and_idx :
503 llvm::enumerate(terminator.getOperandTypes())) {
504 Type operand_type = operand_type_and_idx.value();
505 int32_t operand_idx = operand_type_and_idx.index();
506 for (int32_t i = n * operand_idx, e = i + n; i < e; ++i)
507 if (failed(VerifyCompatibleTypes(operand_type, op.getType(i))))
508 return op.emitOpError() << "incompatible types for result " << i
509 << " and terminator operand " << operand_idx;
510 }
511
512 return success();
513 }
514
BuildReplicateOp(Builder * builder,OperationState * state,int n,llvm::Optional<DictionaryAttr> devices,llvm::ArrayRef<std::pair<ValueRange,Type>> replicated_inputs,ValueRange packed_inputs,TypeRange replica_output_types)515 void BuildReplicateOp(
516 Builder* builder, OperationState* state, int n,
517 llvm::Optional<DictionaryAttr> devices,
518 llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs,
519 ValueRange packed_inputs, TypeRange replica_output_types) {
520 DCHECK_GE(n, 2);
521 state->addAttribute("n", builder->getI32IntegerAttr(n));
522
523 if (devices.hasValue()) state->addAttribute("devices", devices.getValue());
524
525 Region* region = state->addRegion();
526 region->push_back(new Block);
527 Block& block = region->front();
528
529 for (auto& replicated_input : replicated_inputs) {
530 DCHECK_EQ(llvm::size(replicated_input.first), n);
531 for (auto input : replicated_input.first) {
532 DCHECK(succeeded(
533 VerifyCompatibleTypes(input.getType(), replicated_input.second)));
534 state->addOperands(input);
535 }
536 block.addArgument(replicated_input.second);
537 }
538
539 for (auto packed_input : packed_inputs) {
540 state->addOperands(packed_input);
541 block.addArgument(packed_input.getType());
542 }
543
544 // Add derived `operand_segment_sizes` attribute.
545 int32_t num_replicated_inputs = replicated_inputs.size() * n;
546 int32_t num_packed_inputs = packed_inputs.size();
547 auto operand_segment_sizes =
548 DenseIntElementsAttr::get(VectorType::get({2}, builder->getI32Type()),
549 {num_replicated_inputs, num_packed_inputs});
550 state->addAttribute(kOperandSegmentSizesAttr, operand_segment_sizes);
551
552 for (const auto& output_type : replica_output_types)
553 state->addTypes(llvm::SmallVector<Type, 8>(n, output_type));
554 }
555 } // anonymous namespace
556
build(OpBuilder & builder,OperationState & state,int n,const llvm::SmallDenseMap<StringRef,llvm::SmallVector<StringRef,4>> & devices,llvm::ArrayRef<std::pair<ValueRange,Type>> replicated_inputs,ValueRange packed_inputs,TypeRange replica_output_types)557 void ReplicateOp::build(
558 OpBuilder& builder, OperationState& state, int n,
559 const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>&
560 devices,
561 llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs,
562 ValueRange packed_inputs, TypeRange replica_output_types) {
563 llvm::Optional<DictionaryAttr> devices_attr;
564 if (!devices.empty()) {
565 llvm::SmallVector<mlir::NamedAttribute, 1> device_list;
566 device_list.reserve(devices.size());
567 for (auto alias_and_devices : devices) {
568 NamedAttribute device_name_attr = builder.getNamedAttr(
569 alias_and_devices.getFirst(),
570 builder.getStrArrayAttr(alias_and_devices.getSecond()));
571 device_list.emplace_back(device_name_attr);
572 }
573 devices_attr.emplace(builder.getDictionaryAttr(device_list));
574 }
575
576 BuildReplicateOp(&builder, &state, n, devices_attr, replicated_inputs,
577 packed_inputs, replica_output_types);
578 }
579
build(OpBuilder & builder,OperationState & state,int n,llvm::Optional<DictionaryAttr> devices,llvm::ArrayRef<std::pair<ValueRange,Type>> replicated_inputs,ValueRange packed_inputs,TypeRange replica_output_types)580 void ReplicateOp::build(
581 OpBuilder& builder, OperationState& state, int n,
582 llvm::Optional<DictionaryAttr> devices,
583 llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs,
584 ValueRange packed_inputs, TypeRange replica_output_types) {
585 BuildReplicateOp(&builder, &state, n, devices, replicated_inputs,
586 packed_inputs, replica_output_types);
587 }
588
589 // Returns the number of packed block arguments.
GetNumPackedBlockArguments()590 unsigned ReplicateOp::GetNumPackedBlockArguments() {
591 return packed_inputs().size();
592 }
593
594 // Returns the number of replicated block arguments.
GetNumReplicatedBlockArguments()595 unsigned ReplicateOp::GetNumReplicatedBlockArguments() {
596 return GetBody().getNumArguments() - GetNumPackedBlockArguments();
597 }
598
599 // Returns the replicated block arguments. A copy should be made if the
600 // replicate op is being modified.
GetReplicatedBlockArguments()601 llvm::ArrayRef<BlockArgument> ReplicateOp::GetReplicatedBlockArguments() {
602 return GetBody().getArguments().drop_back(GetNumPackedBlockArguments());
603 }
604
605 // Returns the packed block arguments. A copy should be made if the replicate op
606 // is being modified.
GetPackedBlockArguments()607 llvm::ArrayRef<BlockArgument> ReplicateOp::GetPackedBlockArguments() {
608 return GetBody().getArguments().take_back(GetNumPackedBlockArguments());
609 }
610
611 // Checks if a block argument is replicated (forwarding replicated inputs).
IsReplicatedBlockArgument(BlockArgument block_arg)612 bool ReplicateOp::IsReplicatedBlockArgument(BlockArgument block_arg) {
613 assert(block_arg.getOwner() == &GetBody());
614 return block_arg.getArgNumber() < GetNumReplicatedBlockArguments();
615 }
616
617 // Checks if a block argument is packed (forwarding a packed input).
IsPackedBlockArgument(BlockArgument block_arg)618 bool ReplicateOp::IsPackedBlockArgument(BlockArgument block_arg) {
619 return !IsReplicatedBlockArgument(block_arg);
620 }
621
622 // Returns the operand index of the operand being forwarded as a
623 // replicated/packed block argument for a given replica. This assumes a valid
624 // block argument (of the replicate op) and a valid replica is provided.
GetReplicaOperandIndexForBlockArgument(BlockArgument block_arg,unsigned replica)625 unsigned ReplicateOp::GetReplicaOperandIndexForBlockArgument(
626 BlockArgument block_arg, unsigned replica) {
627 MutableArrayRef<OpOperand> operands = GetOperandsForBlockArgument(block_arg);
628 if (operands.size() == 1) return operands.front().getOperandNumber();
629
630 return operands[replica].getOperandNumber();
631 }
632
633 // Returns the operand being forwarded as a replicated/packed block argument for
634 // a given replica. This assumes a valid block argument (of the replicate op)
635 // and a valid replica is provided.
GetReplicaOperandForBlockArgument(BlockArgument block_arg,unsigned replica)636 Value ReplicateOp::GetReplicaOperandForBlockArgument(BlockArgument block_arg,
637 unsigned replica) {
638 MutableArrayRef<OpOperand> operands = GetOperandsForBlockArgument(block_arg);
639 if (operands.size() == 1) return operands.front().get();
640
641 return operands[replica].get();
642 }
643
644 // Returns the list of replica op operands that maps to the given block
645 // argument. Returns list with num_replicas elements for replicated operands
646 // and list with a single element for packed operands.
647 //
648 // Requires that block argument is of this replicate op.
GetOperandsForBlockArgument(BlockArgument block_arg)649 MutableArrayRef<OpOperand> ReplicateOp::GetOperandsForBlockArgument(
650 BlockArgument block_arg) {
651 assert(block_arg.getOwner() == &GetBody());
652
653 unsigned arg_number = block_arg.getArgNumber();
654 unsigned num_replicated_args = GetNumReplicatedBlockArguments();
655 int32_t num_replicas = nAttr().getInt();
656 MutableArrayRef<OpOperand> operands = getOperation()->getOpOperands();
657
658 // All replicated arguments are before packed arguments so return replicated
659 // operands if the given argument is one of the replicated arguments.
660 if (arg_number < num_replicated_args)
661 return operands.slice(arg_number * num_replicas, num_replicas);
662
663 operands = operands.drop_front(num_replicated_args * num_replicas);
664 arg_number -= num_replicated_args;
665 return operands.slice(arg_number, 1);
666 }
667
668 // Checks if a tf_device.replicate wraps a single operation and the single
669 // operation results are perfectly forwarded to the replicate return.
WrapsSingleOp()670 bool ReplicateOp::WrapsSingleOp() { return BlockWrapsSingleOp(&GetBody()); }
671
672 //===----------------------------------------------------------------------===//
673 // Canonicalization patterns
674 //===----------------------------------------------------------------------===//
675
676 //===----------------------------------------------------------------------===//
677 // tf_device.launch
678 //===----------------------------------------------------------------------===//
679
680 namespace {
681 // This pattern matches LaunchOps with only one ReturnOp (empty) and remaps the
682 // results of the LaunchOp to the operands of the ReturnOp.
683 struct DropEmptyLaunch : public OpRewritePattern<LaunchOp> {
684 using OpRewritePattern<LaunchOp>::OpRewritePattern;
685
matchAndRewritemlir::tf_device::__anon57a3a4cb0711::DropEmptyLaunch686 LogicalResult matchAndRewrite(LaunchOp op,
687 PatternRewriter& rewriter) const override {
688 Block& block = op.GetBody();
689 // Check if launch only has a return.
690 if (&block.front() != &block.back()) return failure();
691
692 // Map launch results to return operands.
693 rewriter.replaceOp(op, block.front().getOperands());
694
695 return success();
696 }
697 };
698 } // anonymous namespace
699
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)700 void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
701 MLIRContext* context) {
702 results.insert<DropEmptyLaunch>(context);
703 }
704
705 } // namespace tf_device
706 } // namespace mlir
707
708 //===----------------------------------------------------------------------===//
709 // TableGen'd op method definitions
710 //===----------------------------------------------------------------------===//
711
712 #define GET_OP_CLASSES
713 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc.inc"
714