1 //===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the definitions of the infer op interfaces defined in
10 // `InferTypeOpInterface.td`.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Interfaces/InferTypeOpInterface.h"
15 
16 #include "mlir/IR/BuiltinTypes.h"
17 
18 using namespace mlir;
19 
20 namespace mlir {
21 #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
22 } // namespace mlir
23 
inferReturnTensorTypes(function_ref<LogicalResult (MLIRContext *,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & retComponents)> componentTypeFn,MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)24 LogicalResult mlir::detail::inferReturnTensorTypes(
25     function_ref<LogicalResult(
26         MLIRContext *, Optional<Location> location, ValueRange operands,
27         DictionaryAttr attributes, RegionRange regions,
28         SmallVectorImpl<ShapedTypeComponents> &retComponents)>
29         componentTypeFn,
30     MLIRContext *context, Optional<Location> location, ValueRange operands,
31     DictionaryAttr attributes, RegionRange regions,
32     SmallVectorImpl<Type> &inferredReturnTypes) {
33   SmallVector<ShapedTypeComponents, 2> retComponents;
34   if (failed(componentTypeFn(context, location, operands, attributes, regions,
35                              retComponents)))
36     return failure();
37   for (auto shapeAndType : retComponents) {
38     assert(shapeAndType.getAttribute() == nullptr && "attribute not supported");
39     if (shapeAndType.hasRank())
40       inferredReturnTypes.push_back(RankedTensorType::get(
41           shapeAndType.getDims(), shapeAndType.getElementType()));
42     else
43       inferredReturnTypes.push_back(
44           UnrankedTensorType::get(shapeAndType.getElementType()));
45   }
46   return success();
47 }
48 
verifyInferredResultTypes(Operation * op)49 LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
50   SmallVector<Type, 4> inferredReturnTypes;
51   auto retTypeFn = cast<InferTypeOpInterface>(op);
52   if (failed(retTypeFn.inferReturnTypes(
53           op->getContext(), op->getLoc(), op->getOperands(),
54           op->getAttrDictionary(), op->getRegions(), inferredReturnTypes)))
55     return failure();
56   if (!retTypeFn.isCompatibleReturnTypes(inferredReturnTypes,
57                                          op->getResultTypes()))
58     return op->emitOpError(
59         "inferred type incompatible with return type of operation");
60   return success();
61 }
62