1 //===- TypeRange.h ----------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the TypeRange and ValueTypeRange classes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_TYPERANGE_H
14 #define MLIR_IR_TYPERANGE_H
15 
16 #include "mlir/IR/Types.h"
17 #include "mlir/IR/Value.h"
18 #include "llvm/ADT/PointerUnion.h"
19 
20 namespace mlir {
// Forward declarations of the classes referenced by TypeRange below. Note that
// ValueTypeRange is only defined later in this file, but a TypeRange
// constructor template refers to it before that definition.
class OperandRange;
class ResultRange;
class Type;
class Value;
class ValueRange;
template <typename ValueRangeT>
class ValueTypeRange;
28 
29 //===----------------------------------------------------------------------===//
30 // TypeRange
31 
32 /// This class provides an abstraction over the various different ranges of
33 /// value types. In many cases, this prevents the need to explicitly materialize
34 /// a SmallVector/std::vector. This class should be used in places that are not
35 /// suitable for a more derived type (e.g. ArrayRef) or a template range
36 /// parameter.
37 class TypeRange
38     : public llvm::detail::indexed_accessor_range_base<
39           TypeRange,
40           llvm::PointerUnion<const Value *, const Type *, OpOperand *>, Type,
41           Type, Type> {
42 public:
43   using RangeBaseT::RangeBaseT;
44   TypeRange(ArrayRef<Type> types = llvm::None);
45   explicit TypeRange(OperandRange values);
46   explicit TypeRange(ResultRange values);
47   explicit TypeRange(ValueRange values);
48   explicit TypeRange(ArrayRef<Value> values);
TypeRange(ArrayRef<BlockArgument> values)49   explicit TypeRange(ArrayRef<BlockArgument> values)
50       : TypeRange(ArrayRef<Value>(values.data(), values.size())) {}
51   template <typename ValueRangeT>
TypeRange(ValueTypeRange<ValueRangeT> values)52   TypeRange(ValueTypeRange<ValueRangeT> values)
53       : TypeRange(ValueRangeT(values.begin().getCurrent(),
54                               values.end().getCurrent())) {}
55   template <typename Arg,
56             typename = typename std::enable_if_t<
57                 std::is_constructible<ArrayRef<Type>, Arg>::value>>
TypeRange(Arg && arg)58   TypeRange(Arg &&arg) : TypeRange(ArrayRef<Type>(std::forward<Arg>(arg))) {}
TypeRange(std::initializer_list<Type> types)59   TypeRange(std::initializer_list<Type> types)
60       : TypeRange(ArrayRef<Type>(types)) {}
61 
62 private:
63   /// The owner of the range is either:
64   /// * A pointer to the first element of an array of values.
65   /// * A pointer to the first element of an array of types.
66   /// * A pointer to the first element of an array of operands.
67   using OwnerT = llvm::PointerUnion<const Value *, const Type *, OpOperand *>;
68 
69   /// See `llvm::detail::indexed_accessor_range_base` for details.
70   static OwnerT offset_base(OwnerT object, ptrdiff_t index);
71   /// See `llvm::detail::indexed_accessor_range_base` for details.
72   static Type dereference_iterator(OwnerT object, ptrdiff_t index);
73 
74   /// Allow access to `offset_base` and `dereference_iterator`.
75   friend RangeBaseT;
76 };
77 
78 /// Make TypeRange hashable.
hash_value(TypeRange arg)79 inline ::llvm::hash_code hash_value(TypeRange arg) {
80   return ::llvm::hash_combine_range(arg.begin(), arg.end());
81 }
82 
83 //===----------------------------------------------------------------------===//
84 // ValueTypeRange
85 
86 /// This class implements iteration on the types of a given range of values.
87 template <typename ValueIteratorT>
88 class ValueTypeIterator final
89     : public llvm::mapped_iterator<ValueIteratorT, Type (*)(Value)> {
unwrap(Value value)90   static Type unwrap(Value value) { return value.getType(); }
91 
92 public:
93   using reference = Type;
94 
95   /// Provide a const dereference method.
96   Type operator*() const { return unwrap(*this->I); }
97 
98   /// Initializes the type iterator to the specified value iterator.
ValueTypeIterator(ValueIteratorT it)99   ValueTypeIterator(ValueIteratorT it)
100       : llvm::mapped_iterator<ValueIteratorT, Type (*)(Value)>(it, &unwrap) {}
101 };
102 
103 /// This class implements iteration on the types of a given range of values.
104 template <typename ValueRangeT>
105 class ValueTypeRange final
106     : public llvm::iterator_range<
107           ValueTypeIterator<typename ValueRangeT::iterator>> {
108 public:
109   using llvm::iterator_range<
110       ValueTypeIterator<typename ValueRangeT::iterator>>::iterator_range;
111   template <typename Container>
ValueTypeRange(Container && c)112   ValueTypeRange(Container &&c) : ValueTypeRange(c.begin(), c.end()) {}
113 
114   /// Compare this range with another.
115   template <typename OtherT>
116   bool operator==(const OtherT &other) const {
117     return llvm::size(*this) == llvm::size(other) &&
118            std::equal(this->begin(), this->end(), other.begin());
119   }
120   template <typename OtherT>
121   bool operator!=(const OtherT &other) const {
122     return !(*this == other);
123   }
124 };
125 
126 template <typename RangeT>
127 inline bool operator==(ArrayRef<Type> lhs, const ValueTypeRange<RangeT> &rhs) {
128   return lhs.size() == static_cast<size_t>(llvm::size(rhs)) &&
129          std::equal(lhs.begin(), lhs.end(), rhs.begin());
130 }
131 
132 } // namespace mlir
133 
134 namespace llvm {
135 
136 // Provide DenseMapInfo for TypeRange.
137 template <>
138 struct DenseMapInfo<mlir::TypeRange> {
139   static mlir::TypeRange getEmptyKey() {
140     return mlir::TypeRange(getEmptyKeyPointer(), 0);
141   }
142 
143   static mlir::TypeRange getTombstoneKey() {
144     return mlir::TypeRange(getTombstoneKeyPointer(), 0);
145   }
146 
147   static unsigned getHashValue(mlir::TypeRange val) { return hash_value(val); }
148 
149   static bool isEqual(mlir::TypeRange lhs, mlir::TypeRange rhs) {
150     if (isEmptyKey(rhs))
151       return isEmptyKey(lhs);
152     if (isTombstoneKey(rhs))
153       return isTombstoneKey(lhs);
154     return lhs == rhs;
155   }
156 
157 private:
158   static const mlir::Type *getEmptyKeyPointer() {
159     return DenseMapInfo<mlir::Type *>::getEmptyKey();
160   }
161 
162   static const mlir::Type *getTombstoneKeyPointer() {
163     return DenseMapInfo<mlir::Type *>::getTombstoneKey();
164   }
165 
166   static bool isEmptyKey(mlir::TypeRange range) {
167     if (const auto *type = range.getBase().dyn_cast<const mlir::Type *>())
168       return type == getEmptyKeyPointer();
169     return false;
170   }
171 
172   static bool isTombstoneKey(mlir::TypeRange range) {
173     if (const auto *type = range.getBase().dyn_cast<const mlir::Type *>())
174       return type == getTombstoneKeyPointer();
175     return false;
176   }
177 };
178 
179 } // namespace llvm
180 
181 #endif // MLIR_IR_TYPERANGE_H
182