1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/comparison_util.h"
17 #include "absl/container/flat_hash_map.h"
18 #include "tensorflow/compiler/xla/util.h"
19 
20 namespace xla {
21 
ComparisonDirectionToString(Comparison::Direction direction)22 std::string ComparisonDirectionToString(Comparison::Direction direction) {
23   switch (direction) {
24     case Comparison::Direction::kEq:
25       return "EQ";
26     case Comparison::Direction::kNe:
27       return "NE";
28     case Comparison::Direction::kGe:
29       return "GE";
30     case Comparison::Direction::kGt:
31       return "GT";
32     case Comparison::Direction::kLe:
33       return "LE";
34     case Comparison::Direction::kLt:
35       return "LT";
36     default:
37       LOG(FATAL) << "Attempted to print uninitialized comparison direction";
38   }
39 }
40 
StringToComparisonDirection(absl::string_view direction_name)41 StatusOr<Comparison::Direction> StringToComparisonDirection(
42     absl::string_view direction_name) {
43   static auto* direction_map =
44       new absl::flat_hash_map<string, Comparison::Direction>({
45           {"EQ", Comparison::Direction::kEq},
46           {"NE", Comparison::Direction::kNe},
47           {"GE", Comparison::Direction::kGe},
48           {"GT", Comparison::Direction::kGt},
49           {"LE", Comparison::Direction::kLe},
50           {"LT", Comparison::Direction::kLt},
51       });
52   auto it = direction_map->find(direction_name);
53   if (it == direction_map->end()) {
54     return InvalidArgument("Unknown comparison direction: %s", direction_name);
55   }
56   return it->second;
57 }
58 
StringToComparisonType(absl::string_view compare_type_name)59 StatusOr<Comparison::Type> StringToComparisonType(
60     absl::string_view compare_type_name) {
61   static auto* type_map = new absl::flat_hash_map<string, Comparison::Type>({
62       {"FLOAT", Comparison::Type::kFloat},
63       {"TOTALORDER", Comparison::Type::kFloatTotalOrder},
64       {"SIGNED", Comparison::Type::kSigned},
65       {"UNSIGNED", Comparison::Type::kUnsigned},
66   });
67   auto it = type_map->find(compare_type_name);
68   if (it == type_map->end()) {
69     return InvalidArgument("Unknown comparison type: %s", compare_type_name);
70   }
71   return it->second;
72 }
73 
ComparisonTypeToString(Comparison::Type type)74 std::string ComparisonTypeToString(Comparison::Type type) {
75   switch (type) {
76     case Comparison::Type::kFloat:
77       return "FLOAT";
78     case Comparison::Type::kFloatTotalOrder:
79       return "TOTALORDER";
80     case Comparison::Type::kSigned:
81       return "SIGNED";
82     case Comparison::Type::kUnsigned:
83       return "UNSIGNED";
84     default:
85       LOG(FATAL) << "Attempted to print incomplete comparison type";
86   }
87 }
88 
Comparison(Direction dir,PrimitiveType type)89 Comparison::Comparison(Direction dir, PrimitiveType type)
90     : dir_(dir), type_(DefaultComparisonType(type)) {}
91 
DefaultComparisonType(PrimitiveType type)92 Comparison::Type Comparison::DefaultComparisonType(PrimitiveType type) {
93   switch (type) {
94     case S8:
95     case S16:
96     case S32:
97     case S64:
98       return Type::kSigned;
99     case PRED:
100     case U8:
101     case U16:
102     case U32:
103     case U64:
104       return Type::kUnsigned;
105     case F16:
106     case F32:
107     case BF16:
108     case F64:
109     case C64:
110     case C128:
111       return Type::kFloat;
112     default:
113       LOG(FATAL) << "Unsupported comparison mode."
114                  << PrimitiveType_Name(type) << "\n";
115   }
116 }
117 
Converse() const118 Comparison Comparison::Converse() const {
119   return Comparison(Converse(dir_), type_);
120 }
121 
Inverse() const122 absl::optional<Comparison> Comparison::Inverse() const {
123   switch (type_) {
124     case Type::kFloat:
125       // Floating-point comparisons don't have inverses unless total order is
126       // supported (e.g. comparison can return true if one operand is NaN).
127       return absl::nullopt;
128     case Type::kFloatTotalOrder:
129     case Type::kSigned:
130     case Type::kUnsigned:
131       return Comparison(Inverse(dir_), type_);
132   }
133 }
134 
IsReflexive() const135 bool Comparison::IsReflexive() const {
136   switch (dir_) {
137     case Direction::kEq:
138     case Direction::kGe:
139     case Direction::kLe:
140       return IsSigned() || IsUnsigned() || IsFloatTotalOrder();
141     case Direction::kNe:
142     case Direction::kGt:
143     case Direction::kLt:
144       return false;
145   }
146 }
147 
IsAntireflexive() const148 bool Comparison::IsAntireflexive() const {
149   switch (dir_) {
150     case Direction::kNe:
151       return IsSigned() || IsUnsigned() || IsFloatTotalOrder();
152     case Direction::kGt:
153     case Direction::kLt:
154       return true;
155     case Direction::kEq:
156     case Direction::kGe:
157     case Direction::kLe:
158       return false;
159   }
160 }
161 
Converse(Comparison::Direction dir)162 /* static */ Comparison::Direction Comparison::Converse(
163     Comparison::Direction dir) {
164   switch (dir) {
165     case Comparison::Direction::kEq:
166       return Comparison::Direction::kEq;
167     case Comparison::Direction::kNe:
168       return Comparison::Direction::kNe;
169     case Comparison::Direction::kGe:
170       return Comparison::Direction::kLe;
171     case Comparison::Direction::kGt:
172       return Comparison::Direction::kLt;
173     case Comparison::Direction::kLe:
174       return Comparison::Direction::kGe;
175     case Comparison::Direction::kLt:
176       return Comparison::Direction::kGt;
177   }
178 }
179 
Inverse(Comparison::Direction dir)180 /* static */ Comparison::Direction Comparison::Inverse(
181     Comparison::Direction dir) {
182   switch (dir) {
183     case Direction::kEq:
184       return Direction::kNe;
185     case Direction::kNe:
186       return Direction::kEq;
187     case Direction::kGe:
188       return Direction::kLt;
189     case Direction::kGt:
190       return Direction::kLe;
191     case Direction::kLe:
192       return Direction::kGt;
193     case Direction::kLt:
194       return Direction::kGe;
195   }
196 }
197 
ToString(std::string prefix1,std::string prefix2) const198 std::string Comparison::ToString(std::string prefix1,
199                                  std::string prefix2) const {
200   return prefix1 + std::string(ComparisonDirectionToString(dir_)) + prefix2 +
201          std::string(ComparisonTypeToString(type_));
202 }
203 }  // namespace xla
204