/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"

namespace mlir {
namespace lmhlo {
namespace impl {

// A struct that maps an LhloBinaryOpTy to the corresponding floating-point,
// integer, and (where supported) complex scalar operation types.
template <typename LhloBinaryOpTy>
struct LhloToScalarOp;

template <>
struct LhloToScalarOp<lmhlo::AddOp> {
  using FOp = ::mlir::AddFOp;
  using IOp = ::mlir::AddIOp;
  using COp = ::mlir::complex::AddOp;
};
template <>
struct LhloToScalarOp<lmhlo::CompareOp> {
  using FOp = ::mlir::CmpFOp;
  using IOp = ::mlir::CmpIOp;
};
template <>
struct LhloToScalarOp<lmhlo::DivOp> {
  using FOp = ::mlir::DivFOp;
  using IOp = ::mlir::SignedDivIOp;
};
template <>
struct LhloToScalarOp<lmhlo::MulOp> {
  using FOp = ::mlir::MulFOp;
  using IOp = ::mlir::MulIOp;
};
template <>
struct LhloToScalarOp<lmhlo::RemOp> {
  using FOp = ::mlir::RemFOp;
  using IOp = ::mlir::SignedRemIOp;
};
template <>
struct LhloToScalarOp<lmhlo::SubOp> {
  using FOp = ::mlir::SubFOp;
  using IOp = ::mlir::SubIOp;
  using COp = ::mlir::complex::SubOp;
};

// Alias for the map from LHLO binary op type to STD floating-point op type.
template <typename LhloOp>
using ScalarFOp = typename LhloToScalarOp<LhloOp>::FOp;
// Alias for the map from LHLO binary op type to STD integer op type.
template <typename LhloOp>
using ScalarIOp = typename LhloToScalarOp<LhloOp>::IOp;
// Alias for the map from LHLO binary op type to STD complex op type.
template <typename LhloOp>
using ScalarCOp = typename LhloToScalarOp<LhloOp>::COp;

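// Primary template: selected only when the type-pack dispatch below exhausts
// its (SupportedType, StdScalarOp) pairs without a match; returns a null
// Value so callers can detect unsupported element types.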
template <typename... Args>
struct MapLhloOpToStdScalarOpImpl {
  Value operator()(Location loc, ArrayRef<Type> result_types,
                   ArrayRef<Value> args, OpBuilder* b) {
    return nullptr;
  }
};

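// Specialization for a single op type: unconditionally creates StdScalarOp
// with the given result types and arguments.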
template <typename StdScalarOp>
struct MapLhloOpToStdScalarOpImpl<StdScalarOp> {
  Value operator()(Location loc, ArrayRef<Type> result_types,
                   ArrayRef<Value> args, OpBuilder* b) {
    return b->template create<StdScalarOp>(loc, result_types, args, mlir::None);
  }
};

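// Tries each (SupportedType, StdScalarOp) pair in order: if the element type
// of the first argument matches SupportedType, StdScalarOp is created;
// otherwise the remaining pairs are tried recursively.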
template <typename SupportedType, typename StdScalarOp, typename... Args>
struct MapLhloOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
  Value operator()(Location loc, ArrayRef<Type> result_types,
                   ArrayRef<Value> args, OpBuilder* b) {
    Type element_type = getElementTypeOrSelf(args.front().getType());
    if (element_type.isa<SupportedType>()) {
      return b->template create<StdScalarOp>(loc, result_types, args,
                                             mlir::None);
    }
    return MapLhloOpToStdScalarOpImpl<Args...>{}(loc, result_types, args, b);
  }
};

// Inserts the computation that corresponds to the body of the loop for a
// lowered LHLO unary/binary op. Returns the value for the result.
template <typename LhloOpTy>
inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef<Type> result_types,
                                    ArrayRef<Value> args, OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<LhloOpTy>, FloatType,
                                    ScalarFOp<LhloOpTy>>{}(loc, result_types,
                                                           args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
                                                  ArrayRef<Type> result_types,
                                                  ArrayRef<Value> args,
                                                  OpBuilder* b) {
  Type element_type = getElementTypeOrSelf(args.front().getType());
  if (element_type.isa<FloatType>()) {
    return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
        loc, result_types, args, b);
  }
  if (element_type.isa<IntegerType>()) {
    // lmhlo.abs(x, result) -> result = select((x >= 0), x, sub(0, x))
    Value lhs = args[0];
    auto integer_type = element_type.dyn_cast<IntegerType>();

    Value zero_intval =
        b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
    if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
      zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval);
    }
    auto lhs_gt_zero = b->create<ScalarIOp<CompareOp>>(loc, CmpIPredicate::sge,
                                                       lhs, zero_intval);
    auto neg_val = b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs);
    return b->create<::mlir::SelectOp>(loc, lhs_gt_zero, lhs, neg_val);
  }
  return nullptr;
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::AddOp>(Location loc,
                                                  ArrayRef<Type> result_types,
                                                  ArrayRef<Value> args,
                                                  OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<lmhlo::AddOp>,
                                    FloatType, ScalarFOp<lmhlo::AddOp>,
                                    ComplexType, ScalarCOp<lmhlo::AddOp>>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
                                                  ArrayRef<Type> result_types,
                                                  ArrayRef<Value> args,
                                                  OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::Atan2Op>(Location loc,
                                                    ArrayRef<Type> result_types,
                                                    ArrayRef<Value> args,
                                                    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::Atan2Op>{}(
      loc, result_types, args, b);
}

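// Maps an (L)HLO comparison direction string ("EQ", "NE", "GE", "GT", "LE",
// "LT") to the corresponding standard-dialect predicate. The primary template
// returns None; the specializations below handle CmpFPredicate (ordered
// predicates, except for NE) and CmpIPredicate (signed predicates).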
template <typename PredicateType>
inline Optional<PredicateType> getCmpPredicate(StringRef comparison_direction) {
  return llvm::None;
}

template <>
inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>(
    StringRef comparison_direction) {
  return llvm::StringSwitch<Optional<CmpFPredicate>>(comparison_direction)
      .Case("EQ", CmpFPredicate::OEQ)
      .Case("NE", CmpFPredicate::UNE)
      .Case("GE", CmpFPredicate::OGE)
      .Case("GT", CmpFPredicate::OGT)
      .Case("LE", CmpFPredicate::OLE)
      .Case("LT", CmpFPredicate::OLT)
      .Default(llvm::None);
}

template <>
inline Optional<CmpIPredicate> getCmpPredicate<CmpIPredicate>(
    StringRef comparison_direction) {
  return llvm::StringSwitch<Optional<CmpIPredicate>>(comparison_direction)
      .Case("EQ", CmpIPredicate::eq)
      .Case("NE", CmpIPredicate::ne)
      .Case("GE", CmpIPredicate::sge)
      .Case("GT", CmpIPredicate::sgt)
      .Case("LE", CmpIPredicate::sle)
      .Case("LT", CmpIPredicate::slt)
      .Default(llvm::None);
}

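// Lowers an (L)HLO compare op to std.cmpi (signless integers) or std.cmpf
// (floats) using the predicate derived from `comparison_direction`. Returns
// nullptr for unsupported element types.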
template <typename CompareOpTy>
inline Value MapCompareOpToStdScalarOp(Location loc,
                                       StringRef comparison_direction,
                                       ArrayRef<Type> result_types,
                                       ArrayRef<Value> args, OpBuilder* b) {
  const auto& lhs = args[0];
  const auto& rhs = args[1];
  Type element_type = getElementTypeOrSelf(lhs.getType());
  if (element_type.isSignlessInteger()) {
    Optional<CmpIPredicate> predicate =
        getCmpPredicate<CmpIPredicate>(comparison_direction);
    assert(predicate.hasValue() && "expected valid comparison direction");
    return b->create<ScalarIOp<CompareOpTy>>(loc, predicate.getValue(), lhs,
                                             rhs);
  }
  if (element_type.isa<FloatType>()) {
    Optional<CmpFPredicate> predicate =
        getCmpPredicate<CmpFPredicate>(comparison_direction);
    assert(predicate.hasValue() && "expected valid comparison direction");
    return b->create<ScalarFOp<CompareOpTy>>(loc, predicate.getValue(), lhs,
                                             rhs);
  }
  return nullptr;
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::CopyOp>(Location loc,
                                                   ArrayRef<Type> result_types,
                                                   ArrayRef<Value> args,
                                                   OpBuilder* b) {
  return args.front();
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ExpOp>(Location loc,
                                                  ArrayRef<Type> result_types,
                                                  ArrayRef<Value> args,
                                                  OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::ExpOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::Expm1Op>(Location loc,
                                                    ArrayRef<Type> result_types,
                                                    ArrayRef<Value> args,
                                                    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::ExpM1Op>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::CeilOp>(Location loc,
                                                   ArrayRef<Type> result_types,
                                                   ArrayRef<Value> args,
                                                   OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ComplexOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<complex::CreateOp>{}(loc, result_types,
                                                         args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::RealOp>(Location loc,
                                                   ArrayRef<Type> result_types,
                                                   ArrayRef<Value> args,
                                                   OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<complex::ReOp>{}(loc, result_types, args,
                                                     b);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ImagOp>(Location loc,
                                                   ArrayRef<Type> result_types,
                                                   ArrayRef<Value> args,
                                                   OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<complex::ImOp>{}(loc, result_types, args,
                                                     b);
}

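// Lowers lmhlo::ConvertOp by selecting a standard cast op from the source and
// target element types: i1 -> float uses an unsigned cast, other int -> float
// uses a signed cast, float <-> float truncates or extends, int -> i1 and
// float -> i1 compare against zero, int -> int truncates, zero-extends (from
// i1), or sign-extends, and float -> int uses fptosi. Unhandled combinations
// return nullptr.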
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  Type sourceType = getElementTypeOrSelf(args.front().getType());
  Type targetType = getElementTypeOrSelf(result_types.front());

  // A boolean value is treated as unsigned when converting to floating-point;
  // otherwise `true` would convert to `-1`.
  if (sourceType.isInteger(/*width=*/1) &&
      mlir::UIToFPOp::areCastCompatible(sourceType, targetType)) {
    return b->create<mlir::UIToFPOp>(loc, result_types, args, mlir::None);
  } else if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) {
    return b->create<mlir::SIToFPOp>(loc, result_types, args, mlir::None);
  } else if (sourceType.isa<FloatType>() && targetType.isa<FloatType>()) {
    FloatType src = sourceType.cast<FloatType>();
    FloatType res = targetType.cast<FloatType>();
    if (src.getWidth() > res.getWidth()) {
      return b->create<mlir::FPTruncOp>(loc, result_types, args, mlir::None);
    } else if (src.getWidth() < res.getWidth()) {
      return b->create<mlir::FPExtOp>(loc, result_types, args, mlir::None);
    }
    // No conversion is needed for floats of the same width.
    return args.front();
  }
  if (targetType.isInteger(/*width=*/1)) {
    // When casting to bool, compare the value against zero: the result is
    // true iff the value is nonzero.
    if (sourceType.isSignlessInteger()) {
      Value zero_intval = b->create<::mlir::ConstantIntOp>(
          loc, 0, sourceType.cast<IntegerType>().getWidth());
      if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
        zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval);
      }
      return b->create<mlir::CmpIOp>(loc, CmpIPredicate::ne, args.front(),
                                     zero_intval);
    } else if (sourceType.isa<FloatType>()) {
      Value zero = b->create<ConstantOp>(loc, b->getFloatAttr(sourceType, 0.0));
      if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
        zero = b->create<::mlir::SplatOp>(loc, vec_type, zero);
      }
      return b->create<mlir::CmpFOp>(loc, CmpFPredicate::UNE, args.front(),
                                     zero);
    }
  }
  if (sourceType.isSignlessInteger() && targetType.isSignlessInteger()) {
    IntegerType src = sourceType.cast<IntegerType>();
    IntegerType res = targetType.cast<IntegerType>();
    if (src.getWidth() > res.getWidth()) {
      return b->create<mlir::TruncateIOp>(loc, result_types, args, mlir::None);
    } else if (src.getWidth() == 1) {
      // Special-case boolean values so they are cast to `1` instead of `-1`.
      return b->create<mlir::ZeroExtendIOp>(loc, result_types, args,
                                            mlir::None);
    } else if (src.getWidth() < res.getWidth()) {
      return b->create<mlir::SignExtendIOp>(loc, result_types, args,
                                            mlir::None);
    }
    // No conversion is needed for integers of the same width.
    return args.front();
  }
  if (mlir::FPToSIOp::areCastCompatible(sourceType, targetType)) {
    return b->create<mlir::FPToSIOp>(loc, result_types, args, mlir::None);
  }
  return nullptr;
}

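// The scalar body of a lowered dot is a multiply-accumulate:
// result = lhs * rhs + result.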
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::DotOp>(Location loc,
                                                  ArrayRef<Type> result_types,
                                                  ArrayRef<Value> args,
                                                  OpBuilder* b) {
  // The lhlo-to-affine Dot op converter only accepts float and integer types.
  const auto& lhs = args[0];
  const auto& rhs = args[1];
  const auto& result = args[2];
  Type element_type = lhs.getType();
  if (element_type.isa<FloatType>()) {
    Value float_mul = MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::MulFOp>{}(
        loc, result_types, {lhs, rhs}, b);
    return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AddFOp>{}(
        loc, result_types, {float_mul, result}, b);
  }
  if (element_type.isa<IntegerType>()) {
    Value int_mul = MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::MulIOp>{}(
        loc, result_types, {lhs, rhs}, b);
    return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AddIOp>{}(
        loc, result_types, {int_mul, result}, b);
  }
  return nullptr;
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::CosOp>(Location loc,
                                                  ArrayRef<Type> result_types,
                                                  ArrayRef<Value> args,
                                                  OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::CosOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::SinOp>(Location loc,
                                                  ArrayRef<Type> result_types,
                                                  ArrayRef<Value> args,
                                                  OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::SinOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::FloorOp>(Location loc,
                                                    ArrayRef<Type> result_types,
                                                    ArrayRef<Value> args,
                                                    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::FloorFOp>{}(
      loc, result_types, args, b);
}

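// isfinite(x) is lowered as |x| != +inf with an ordered compare, so both
// infinities and NaN map to false.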
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::IsFiniteOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  if (args[0].getType().isa<FloatType>()) {
    auto pos_inf = APFloat::getInf(
        args[0].getType().cast<FloatType>().getFloatSemantics());
    auto const_pos_inf =
        b->create<ConstantOp>(loc, b->getFloatAttr(args[0].getType(), pos_inf));
    Value abs_x = b->create<::mlir::AbsFOp>(loc, args[0]);
    return b->create<::mlir::CmpFOp>(loc, CmpFPredicate::ONE, abs_x,
                                     const_pos_inf);
  }
  return nullptr;
}

/// Implements the conversion of an HLO op to a scalar op (to use within the
/// region of a linalg.generic op) for compare-select style operations like
/// min/max.
template <typename... Args>
struct CompareSelectOpToStdScalarOp {
  static Value map(Location loc, StringRef comparison_direction,
                   ArrayRef<Type> result_types, ArrayRef<Value> args,
                   OpBuilder* b) {
    return nullptr;
  }
};

/// Specialization which converts to a comparison operation in the standard
/// dialect with a given predicate based on the element type of the operand.
template <typename SupportedType, typename StdCompareOp, typename Predicate,
          typename... Args>
struct CompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
                                    Args...> {
  static Value map(Location loc, StringRef comparison_direction,
                   ArrayRef<Type> result_types, ArrayRef<Value> args,
                   OpBuilder* b) {
    Type element_type = getElementTypeOrSelf(args.front().getType());
    if (element_type.isa<SupportedType>()) {
      auto predicate = getCmpPredicate<Predicate>(comparison_direction);
      assert(predicate.hasValue() && "expected valid comparison direction");
      auto cmp = b->template create<StdCompareOp>(loc, predicate.getValue(),
                                                  args[0], args[1]);
      return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]);
    }
    return CompareSelectOpToStdScalarOp<Args...>::map(loc, comparison_direction,
                                                      result_types, args, b);
  }
};

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::LogOp>(Location loc,
                                                  ArrayRef<Type> result_types,
                                                  ArrayRef<Value> args,
                                                  OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::LogOp>{}(
      loc, result_types, args, b);
}

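// Wraps an already-computed value `v`: for floating-point element types,
// replaces `v` with quiet NaN whenever either operand is NaN (unordered
// compare), so that min/max always propagate NaN.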
inline Value LhloAlwaysPropagateNaN(Value v, ArrayRef<Value> args, Location loc,
                                    OpBuilder* b) {
  Type element_type = getElementTypeOrSelf(args.front().getType());
  if (auto float_type = element_type.dyn_cast<FloatType>()) {
    Value isnan =
        b->create<mlir::CmpFOp>(loc, CmpFPredicate::UNO, args[0], args[1]);

    auto nan_apfloat = APFloat::getQNaN(float_type.getFloatSemantics());
    Value nan = b->create<mlir::ConstantFloatOp>(loc, nan_apfloat, float_type);
    if (VectorType vec_type = args[0].getType().dyn_cast<VectorType>()) {
      nan = b->create<::mlir::SplatOp>(loc, vec_type, nan);
    }
    v = b->create<mlir::SelectOp>(loc, isnan, nan, v);
  }
  return v;
}

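// logistic(x) = 1 / (1 + exp(-x))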
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::LogisticOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  auto ty = result_types.front().cast<FloatType>();
  Value one = b->create<ConstantOp>(loc, b->getFloatAttr(ty, 1.0));
  Value x = args.front();
  Value neg_x = b->create<NegFOp>(loc, x);
  Value exp_neg_x = b->create<::mlir::math::ExpOp>(loc, neg_x);
  Value one_add_exp_neg_x = b->create<AddFOp>(loc, one, exp_neg_x);
  return b->create<DivFOp>(loc, one, one_add_exp_neg_x);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::Log1pOp>(Location loc,
                                                    ArrayRef<Type> result_types,
                                                    ArrayRef<Value> args,
                                                    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::Log1pOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::MaxOp>(Location loc,
                                                  ArrayRef<Type> result_types,
                                                  ArrayRef<Value> args,
                                                  OpBuilder* b) {
  return LhloAlwaysPropagateNaN(
      CompareSelectOpToStdScalarOp<
          IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
          ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "GT",
                                                           result_types, args,
                                                           b),
      args, loc, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc,
                                                  ArrayRef<Type> result_types,
                                                  ArrayRef<Value> args,
                                                  OpBuilder* b) {
  return LhloAlwaysPropagateNaN(
      CompareSelectOpToStdScalarOp<
          IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
          ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "LT",
                                                           result_types, args,
                                                           b),
      args, loc, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ClampOp>(Location loc,
                                                    ArrayRef<Type> result_types,
                                                    ArrayRef<Value> args,
                                                    OpBuilder* b) {
  assert(args.size() == 3 && "expected 3 arguments");
  Value lb = args[0];
  Value x = args[1];
  Value ub = args[2];

  // clamp(lb, x, ub) = max(min(x, ub), lb)
  Value min_x_ub =
      MapLhloOpToStdScalarOp<lmhlo::MinOp>(loc, result_types, {x, ub}, b);
  return MapLhloOpToStdScalarOp<lmhlo::MaxOp>(loc, result_types, {min_x_ub, lb},
                                              b);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
                                                  ArrayRef<Type> result_types,
                                                  ArrayRef<Value> args,
                                                  OpBuilder* b) {
  Type element_type = getElementTypeOrSelf(args.front().getType());
  if (element_type.isa<FloatType>()) {
    return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
        loc, result_types, args, b);
  }
  if (element_type.isa<IntegerType>()) {
    // lmhlo.neg(x, result) -> result = sub(0, x)
    Value lhs = args[0];
    auto integer_type = element_type.dyn_cast<IntegerType>();

    Value zero_intval =
        b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
    if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
      zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval);
    }
    return b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs);
  }
  return nullptr;
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc,
                                                  ArrayRef<Type> result_types,
                                                  ArrayRef<Value> args,
                                                  OpBuilder* b) {
  Type element_type = getElementTypeOrSelf(args.front().getType());
  if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
    // lmhlo.not(x) -> x ^ -1
    Value all_ones =
        b->create<::mlir::ConstantIntOp>(loc, -1, integer_type.getWidth());
    if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
      all_ones = b->create<::mlir::SplatOp>(loc, vec_type, all_ones);
    }
    return b->create<::mlir::XOrOp>(loc, all_ones, args[0]);
  }
  return nullptr;
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::OrOp>(Location loc,
                                                 ArrayRef<Type> result_types,
                                                 ArrayRef<Value> args,
                                                 OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::OrOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
                                                    ArrayRef<Type> result_types,
                                                    ArrayRef<Value> args,
                                                    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::RsqrtOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc,
                                                  ArrayRef<Type> result_types,
                                                  ArrayRef<Value> args,
                                                  OpBuilder* b) {
  lmhlo::PowOp::Adaptor adaptor(args);
  auto lb = ImplicitLocOpBuilder(loc, *b);
  // Floating-point bases lower directly to math::PowFOp.
  auto result_type = result_types.front();
  if (result_type.isa<::mlir::FloatType>())
    return MapLhloOpToStdScalarOpImpl<::mlir::math::PowFOp>{}(loc, result_types,
                                                              args, b);

  assert(result_type.isa<::mlir::IntegerType>() &&
         "only float and integer `pow` is supported right now");

  // Exponentiation by squaring:
  // https://en.wikipedia.org/wiki/Exponentiation_by_squaring
  Value neg_one = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, -1));
  Value zero = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, 0));
  Value one = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, 1));
  Value two = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, 2));
  Value step = lb.create<ConstantIndexOp>(1);
  Value lowerBound = lb.create<ConstantIndexOp>(0);
  // Six iterations are enough: for any base other than 0, 1, and -1, the
  // result overflows a 64-bit integer once the exponent reaches 64, and
  // 64 == 1 << 6.
  Value upperBound = lb.create<ConstantIndexOp>(6);
  auto original_base = adaptor.lhs();
  auto original_exponent = adaptor.rhs();

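  // Each iteration consumes one bit of the exponent (low bit first): if the
  // bit is set, multiply the accumulator by the current base; then square the
  // base and shift the exponent right by one.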
  Value accum =
      lb.create<scf::ForOp>(
            lowerBound, upperBound, step,
            SmallVector<Value>({one, original_base, original_exponent}),
            [&](OpBuilder& b, Location, Value v, ValueRange iters) {
              Value accum = iters[0];
              Value base = iters[1];
              Value exponent = iters[2];

              Value condition = b.create<CmpIOp>(
                  loc, CmpIPredicate::eq,
                  b.create<::mlir::AndOp>(loc, exponent, one), one);
              Value multiplied = b.create<::mlir::MulIOp>(loc, accum, base);
              accum =
                  b.create<::mlir::SelectOp>(loc, condition, multiplied, accum);
              base = b.create<::mlir::MulIOp>(loc, base, base);
              exponent =
                  b.create<::mlir::UnsignedShiftRightOp>(loc, exponent, one);
              b.create<scf::YieldOp>(
                  loc, SmallVector<Value>({accum, base, exponent}));
            })
          .getResult(0);

  Value rhs_is_even = lb.create<CmpIOp>(
      CmpIPredicate::eq, lb.create<SignedRemIOp>(adaptor.rhs(), two), zero);
  Value rhs_is_negative =
      lb.create<CmpIOp>(CmpIPredicate::slt, adaptor.rhs(), zero);
  Value lhs_is_one = lb.create<CmpIOp>(CmpIPredicate::eq, adaptor.lhs(), one);
  Value lhs_is_neg_one =
      lb.create<CmpIOp>(CmpIPredicate::eq, adaptor.lhs(), neg_one);

  // The accumulator is correct when the rhs is non-negative. When the rhs is
  // negative, we return 0 for integer types, with the exception of lhs values
  // of 1 and -1, which have integer results for negative exponents.
  // Specifically, the calculation is the following:
  //
  // - Return accum if the rhs is not negative.
  // - Return 1 or -1 depending on the parity of rhs when the lhs is -1.
  // - Return 1 if lhs is 1.
  // - Else return 0.
  Value if_lhs_is_one = lb.create<::mlir::SelectOp>(lhs_is_one, one, zero);
  Value if_lhs_is_neg_one = lb.create<::mlir::SelectOp>(
      lhs_is_neg_one, lb.create<::mlir::SelectOp>(rhs_is_even, one, neg_one),
      if_lhs_is_one);
  return lb.create<::mlir::SelectOp>(rhs_is_negative, if_lhs_is_neg_one, accum);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args,
                                                        b);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftLeftOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::ShiftLeftOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightArithmeticOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::SignedShiftRightOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightLogicalOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::UnsignedShiftRightOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
                                                   ArrayRef<Type> result_types,
                                                   ArrayRef<Value> args,
                                                   OpBuilder* b) {
  Type element_type = getElementTypeOrSelf(args.front().getType());
  if (auto float_type = element_type.dyn_cast<FloatType>()) {
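    // For floats: sign(x) is NaN if x is NaN, and otherwise
    // copysign(x != 0.0 ? 1.0 : 0.0, x).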
    bool ignored;
    APFloat zero_apfloat(0.0f);
    zero_apfloat.convert(float_type.getFloatSemantics(),
                         APFloat::rmNearestTiesToEven, &ignored);
    Value zero =
        b->create<mlir::ConstantFloatOp>(loc, zero_apfloat, float_type);
    if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
      zero = b->create<::mlir::SplatOp>(loc, vec_type, zero);
    }
    Value ne0_i1 =
        b->create<::mlir::CmpFOp>(loc, CmpFPredicate::ONE, args[0], zero);
    Value ne0_float = b->create<::mlir::UIToFPOp>(loc, ne0_i1, zero.getType());
    Value copy_sign =
        b->create<::mlir::CopySignOp>(loc, result_types, ne0_float, args[0]);
    auto is_nan =
        b->create<::mlir::CmpFOp>(loc, CmpFPredicate::UNO, args[0], args[0]);
    return b->create<::mlir::SelectOp>(loc, is_nan, args[0], copy_sign);
  } else if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
    // sign(x) = x == 0 ? 0 : ((x s>> (bitwidth - 1)) | 1)
    Value zero =
        b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
    Value bitwidth_minus_one = b->create<::mlir::ConstantIntOp>(
        loc, integer_type.getWidth() - 1, integer_type.getWidth());
    Value one =
        b->create<::mlir::ConstantIntOp>(loc, 1, integer_type.getWidth());
    if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
      zero = b->create<::mlir::SplatOp>(loc, vec_type, zero);
      bitwidth_minus_one =
          b->create<::mlir::SplatOp>(loc, vec_type, bitwidth_minus_one);
      one = b->create<::mlir::SplatOp>(loc, vec_type, one);
    }
    Value cmp =
        b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero);
    Value ashr =
        b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one);
    Value or_op = b->create<::mlir::OrOp>(loc, ashr, one);
    return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op);
  }
  return nullptr;
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::SqrtOp>(Location loc,
                                                   ArrayRef<Type> result_types,
                                                   ArrayRef<Value> args,
                                                   OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::SqrtOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::SubOp>(Location loc,
                                                  ArrayRef<Type> result_types,
                                                  ArrayRef<Value> args,
                                                  OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<lmhlo::SubOp>,
                                    FloatType, ScalarFOp<lmhlo::SubOp>,
                                    ComplexType, ScalarCOp<lmhlo::SubOp>>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
                                                   ArrayRef<Type> result_types,
                                                   ArrayRef<Value> args,
                                                   OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::TanhOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::XorOp>(Location loc,
                                                  ArrayRef<Type> result_types,
                                                  ArrayRef<Value> args,
                                                  OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::XOrOp>{}(
      loc, result_types, args, b);
}

}  // namespace impl

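// Dispatches to the impl:: helpers above for HLO (mhlo) and LHLO (lmhlo) ops,
// with dedicated overloads for the compare ops. A typical call site inside a
// lowering pattern (illustrative only; the surrounding variables such as
// `rewriter`, `result_types`, and `block_args` are hypothetical) would be:
//
//   Value scalar = HloOpToStdScalarOp::map<mhlo::AddOp>(
//       add_op, result_types, block_args, &rewriter);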
struct HloOpToStdScalarOp {
  // Implementation for LHLO ops except lmhlo::CompareOp.
  template <typename HloOpTy, typename LhloOpTy = HloOpTy,
            typename = std::enable_if_t<
                !std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
                std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>,
                             std::false_type>::value>>
  static Value map(HloOpTy op, ArrayRef<Type> result_types,
                   ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
    return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
                                                  args, b);
  }

  // Implementation for HLO ops except mhlo::CompareOp.
  template <typename HloOpTy, typename LhloOpTy = mhlo::HloToLhloOp<HloOpTy>,
            typename = std::enable_if_t<
                !std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
                !std::is_same<LhloOpTy, std::false_type>::value>>
  static Value map(HloOpTy op, ArrayRef<Type> result_types,
                   ArrayRef<Value> args, OpBuilder* b, int i = 0) {
    return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
                                                  args, b);
  }

  // Implementation for lmhlo::CompareOp.
  template <typename LhloOpTy, typename = std::enable_if_t<std::is_same<
                                   LhloOpTy, lmhlo::CompareOp>::value>>
  static Value map(lmhlo::CompareOp op, ArrayRef<Type> result_types,
                   ArrayRef<Value> args, OpBuilder* b) {
    auto comparison_direction = op.comparison_direction();
    return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
        op.getLoc(), comparison_direction, result_types, args, b);
  }

  // Implementation for mhlo::CompareOp.
  template <typename HloOpTy,
            typename =
                std::enable_if_t<std::is_same<HloOpTy, mhlo::CompareOp>::value>>
  static Value map(mhlo::CompareOp op, ArrayRef<Type> result_types,
                   ArrayRef<Value> args, OpBuilder* b) {
    auto comparison_direction = op.comparison_direction();
    return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
        op.getLoc(), comparison_direction, result_types, args, b);
  }

  // Implementation for LHLO ops except lmhlo::CompareOp.
  template <typename LhloOpTy,
            typename = std::enable_if_t<
                !std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
                std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>,
                             std::false_type>::value>>
  static Value map(Location loc, ArrayRef<Type> result_types,
                   ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
    return impl::MapLhloOpToStdScalarOp<LhloOpTy>(loc, result_types, args, b);
  }

  // Implementation for lmhlo::CompareOp.
  template <typename LhloOpTy, typename = std::enable_if_t<std::is_same<
                                   LhloOpTy, lmhlo::CompareOp>::value>>
  static Value map(Location loc, StringRef comparison_direction,
                   ArrayRef<Type> result_types, ArrayRef<Value> args,
                   OpBuilder* b) {
    return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
        loc, comparison_direction, result_types, args, b);
  }
};

}  // namespace lmhlo
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_