1 //===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the definitions of the infer op interfaces defined in
10 // `InferTypeOpInterface.td`.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Interfaces/InferTypeOpInterface.h"
15
16 #include "mlir/IR/BuiltinTypes.h"
17
18 using namespace mlir;
19
20 namespace mlir {
21 #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
22 } // namespace mlir
23
inferReturnTensorTypes(function_ref<LogicalResult (MLIRContext *,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & retComponents)> componentTypeFn,MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)24 LogicalResult mlir::detail::inferReturnTensorTypes(
25 function_ref<LogicalResult(
26 MLIRContext *, Optional<Location> location, ValueRange operands,
27 DictionaryAttr attributes, RegionRange regions,
28 SmallVectorImpl<ShapedTypeComponents> &retComponents)>
29 componentTypeFn,
30 MLIRContext *context, Optional<Location> location, ValueRange operands,
31 DictionaryAttr attributes, RegionRange regions,
32 SmallVectorImpl<Type> &inferredReturnTypes) {
33 SmallVector<ShapedTypeComponents, 2> retComponents;
34 if (failed(componentTypeFn(context, location, operands, attributes, regions,
35 retComponents)))
36 return failure();
37 for (auto shapeAndType : retComponents) {
38 assert(shapeAndType.getAttribute() == nullptr && "attribute not supported");
39 if (shapeAndType.hasRank())
40 inferredReturnTypes.push_back(RankedTensorType::get(
41 shapeAndType.getDims(), shapeAndType.getElementType()));
42 else
43 inferredReturnTypes.push_back(
44 UnrankedTensorType::get(shapeAndType.getElementType()));
45 }
46 return success();
47 }
48
verifyInferredResultTypes(Operation * op)49 LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
50 SmallVector<Type, 4> inferredReturnTypes;
51 auto retTypeFn = cast<InferTypeOpInterface>(op);
52 if (failed(retTypeFn.inferReturnTypes(
53 op->getContext(), op->getLoc(), op->getOperands(),
54 op->getAttrDictionary(), op->getRegions(), inferredReturnTypes)))
55 return failure();
56 if (!retTypeFn.isCompatibleReturnTypes(inferredReturnTypes,
57 op->getResultTypes()))
58 return op->emitOpError(
59 "inferred type incompatible with return type of operation");
60 return success();
61 }
62