1 //===-- include/flang/Common/enum-set.h -------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef FORTRAN_COMMON_ENUM_SET_H_
10 #define FORTRAN_COMMON_ENUM_SET_H_
11 
// Implements a set of enums as a bitset: a common::BitSet<> when it fits in
// 64 bits, otherwise a std::bitset<>.  APIs from bitset<> and set<> can be
// used on these sets, whichever might be more clear to the user.
// This class template facilitates the use of the more type-safe C++ "enum
// class" feature without loss of convenience.
16 
17 #include "constexpr-bitset.h"
18 #include "idioms.h"
19 #include <bitset>
20 #include <cstddef>
21 #include <initializer_list>
22 #include <optional>
23 #include <string>
24 #include <type_traits>
25 
26 namespace Fortran::common {
27 
28 template <typename ENUM, std::size_t BITS> class EnumSet {
29   static_assert(BITS > 0);
30 
31 public:
32   // When the bitset fits in a word, use a custom local bitset class that is
33   // more amenable to constexpr evaluation than the current std::bitset<>.
34   using bitsetType =
35       std::conditional_t<(BITS <= 64), common::BitSet<BITS>, std::bitset<BITS>>;
36   using enumerationType = ENUM;
37 
EnumSet()38   constexpr EnumSet() {}
EnumSet(const std::initializer_list<enumerationType> & enums)39   constexpr EnumSet(const std::initializer_list<enumerationType> &enums) {
40     for (auto it{enums.begin()}; it != enums.end(); ++it) {
41       set(*it);
42     }
43   }
44   constexpr EnumSet(const EnumSet &) = default;
45   constexpr EnumSet(EnumSet &&) = default;
46 
47   constexpr EnumSet &operator=(const EnumSet &) = default;
48   constexpr EnumSet &operator=(EnumSet &&) = default;
49 
bitset()50   const bitsetType &bitset() const { return bitset_; }
51 
52   constexpr EnumSet &operator&=(const EnumSet &that) {
53     bitset_ &= that.bitset_;
54     return *this;
55   }
56   constexpr EnumSet &operator&=(EnumSet &&that) {
57     bitset_ &= that.bitset_;
58     return *this;
59   }
60   constexpr EnumSet &operator|=(const EnumSet &that) {
61     bitset_ |= that.bitset_;
62     return *this;
63   }
64   constexpr EnumSet &operator|=(EnumSet &&that) {
65     bitset_ |= that.bitset_;
66     return *this;
67   }
68   constexpr EnumSet &operator^=(const EnumSet &that) {
69     bitset_ ^= that.bitset_;
70     return *this;
71   }
72   constexpr EnumSet &operator^=(EnumSet &&that) {
73     bitset_ ^= that.bitset_;
74     return *this;
75   }
76 
77   constexpr EnumSet operator~() const {
78     EnumSet result;
79     result.bitset_ = ~bitset_;
80     return result;
81   }
82   constexpr EnumSet operator&(const EnumSet &that) const {
83     EnumSet result{*this};
84     result.bitset_ &= that.bitset_;
85     return result;
86   }
87   constexpr EnumSet operator&(EnumSet &&that) const {
88     EnumSet result{*this};
89     result.bitset_ &= that.bitset_;
90     return result;
91   }
92   constexpr EnumSet operator|(const EnumSet &that) const {
93     EnumSet result{*this};
94     result.bitset_ |= that.bitset_;
95     return result;
96   }
97   constexpr EnumSet operator|(EnumSet &&that) const {
98     EnumSet result{*this};
99     result.bitset_ |= that.bitset_;
100     return result;
101   }
102   constexpr EnumSet operator^(const EnumSet &that) const {
103     EnumSet result{*this};
104     result.bitset_ ^= that.bitset_;
105     return result;
106   }
107   constexpr EnumSet operator^(EnumSet &&that) const {
108     EnumSet result{*this};
109     result.bitset_ ^= that.bitset_;
110     return result;
111   }
112 
113   constexpr EnumSet operator+(enumerationType v) const {
114     return {*this | EnumSet{v}};
115   }
116   constexpr EnumSet operator-(enumerationType v) const {
117     return {*this & ~EnumSet{v}};
118   }
119 
120   constexpr bool operator==(const EnumSet &that) const {
121     return bitset_ == that.bitset_;
122   }
123   constexpr bool operator==(EnumSet &&that) const {
124     return bitset_ == that.bitset_;
125   }
126   constexpr bool operator!=(const EnumSet &that) const {
127     return bitset_ != that.bitset_;
128   }
129   constexpr bool operator!=(EnumSet &&that) const {
130     return bitset_ != that.bitset_;
131   }
132 
133   // N.B. std::bitset<> has size() for max_size(), but that's not the same
134   // thing as std::set<>::size(), which is an element count.
max_size()135   static constexpr std::size_t max_size() { return BITS; }
test(enumerationType x)136   constexpr bool test(enumerationType x) const {
137     return bitset_.test(static_cast<std::size_t>(x));
138   }
all()139   constexpr bool all() const { return bitset_.all(); }
any()140   constexpr bool any() const { return bitset_.any(); }
none()141   constexpr bool none() const { return bitset_.none(); }
142 
143   // N.B. std::bitset<> has count() as an element count, while
144   // std::set<>::count(x) returns 0 or 1 to indicate presence.
count()145   constexpr std::size_t count() const { return bitset_.count(); }
count(enumerationType x)146   constexpr std::size_t count(enumerationType x) const {
147     return test(x) ? 1 : 0;
148   }
149 
set()150   constexpr EnumSet &set() {
151     bitset_.set();
152     return *this;
153   }
154   constexpr EnumSet &set(enumerationType x, bool value = true) {
155     bitset_.set(static_cast<std::size_t>(x), value);
156     return *this;
157   }
reset()158   constexpr EnumSet &reset() {
159     bitset_.reset();
160     return *this;
161   }
reset(enumerationType x)162   constexpr EnumSet &reset(enumerationType x) {
163     bitset_.reset(static_cast<std::size_t>(x));
164     return *this;
165   }
flip()166   constexpr EnumSet &flip() {
167     bitset_.flip();
168     return *this;
169   }
flip(enumerationType x)170   constexpr EnumSet &flip(enumerationType x) {
171     bitset_.flip(static_cast<std::size_t>(x));
172     return *this;
173   }
174 
empty()175   constexpr bool empty() const { return none(); }
clear()176   void clear() { reset(); }
insert(enumerationType x)177   void insert(enumerationType x) { set(x); }
insert(enumerationType && x)178   void insert(enumerationType &&x) { set(x); }
emplace(enumerationType && x)179   void emplace(enumerationType &&x) { set(x); }
erase(enumerationType x)180   void erase(enumerationType x) { reset(x); }
erase(enumerationType && x)181   void erase(enumerationType &&x) { reset(x); }
182 
LeastElement()183   constexpr std::optional<enumerationType> LeastElement() const {
184     if (empty()) {
185       return std::nullopt;
186     } else if constexpr (std::is_same_v<bitsetType, common::BitSet<BITS>>) {
187       return {static_cast<enumerationType>(bitset_.LeastElement().value())};
188     } else {
189       // std::bitset: just iterate
190       for (std::size_t j{0}; j < BITS; ++j) {
191         auto enumerator{static_cast<enumerationType>(j)};
192         if (bitset_.test(j)) {
193           return {enumerator};
194         }
195       }
196       die("EnumSet::LeastElement(): no bit found in non-empty std::bitset");
197     }
198   }
199 
IterateOverMembers(const FUNC & f)200   template <typename FUNC> void IterateOverMembers(const FUNC &f) const {
201     EnumSet copy{*this};
202     while (auto least{copy.LeastElement()}) {
203       f(*least);
204       copy.erase(*least);
205     }
206   }
207 
208   template <typename STREAM>
Dump(STREAM & o,std::string EnumToString (enumerationType))209   STREAM &Dump(STREAM &o, std::string EnumToString(enumerationType)) const {
210     char sep{'{'};
211     IterateOverMembers([&](auto e) {
212       o << sep << EnumToString(e);
213       sep = ',';
214     });
215     return o << (sep == '{' ? "{}" : "}");
216   }
217 
218 private:
219   bitsetType bitset_{};
220 };
221 } // namespace Fortran::common
222 
223 template <typename ENUM, std::size_t values>
224 struct std::hash<Fortran::common::EnumSet<ENUM, values>> {
225   std::size_t operator()(
226       const Fortran::common::EnumSet<ENUM, values> &x) const {
227     return std::hash(x.bitset());
228   }
229 };
230 #endif // FORTRAN_COMMON_ENUM_SET_H_
231