1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_ 17 #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_ 18 19 #include "llvm/ADT/ArrayRef.h" 20 #include "llvm/ADT/StringRef.h" 21 #include "llvm/ADT/StringSwitch.h" 22 #include "llvm/ADT/iterator_range.h" 23 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" 24 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" 25 #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" 26 #include "mlir/Dialect/Complex/IR/Complex.h" 27 #include "mlir/Dialect/Math/IR/Math.h" 28 #include "mlir/Dialect/SCF/SCF.h" 29 #include "mlir/Dialect/StandardOps/IR/Ops.h" 30 #include "mlir/IR/BuiltinTypes.h" 31 #include "mlir/IR/ImplicitLocOpBuilder.h" 32 #include "mlir/IR/TypeUtilities.h" 33 34 namespace mlir { 35 namespace lmhlo { 36 namespace impl { 37 38 // A struct to map LhloBinaryOpTy type to the corresponding floating-point and 39 // integer scalar operation types. 40 template <typename LhloBinaryOpTy> 41 struct LhloToScalarOp; 42 43 template <> 44 struct LhloToScalarOp<lmhlo::AddOp> { 45 using FOp = ::mlir::AddFOp; 46 using IOp = ::mlir::AddIOp; 47 using COp = ::mlir::complex::AddOp; 48 }; 49 template <> 50 struct LhloToScalarOp<lmhlo::CompareOp> { 51 using FOp = ::mlir::CmpFOp; 52 using IOp = ::mlir::CmpIOp; 53 }; 54 template <> 55 struct LhloToScalarOp<lmhlo::DivOp> { 56 using FOp = ::mlir::DivFOp; 57 using IOp = ::mlir::SignedDivIOp; 58 }; 59 template <> 60 struct LhloToScalarOp<lmhlo::MulOp> { 61 using FOp = ::mlir::MulFOp; 62 using IOp = ::mlir::MulIOp; 63 }; 64 template <> 65 struct LhloToScalarOp<lmhlo::RemOp> { 66 using FOp = ::mlir::RemFOp; 67 using IOp = ::mlir::SignedRemIOp; 68 }; 69 template <> 70 struct LhloToScalarOp<lmhlo::SubOp> { 71 using FOp = ::mlir::SubFOp; 72 using IOp = ::mlir::SubIOp; 73 using COp = ::mlir::complex::SubOp; 74 }; 75 76 // Alias for the map from LHLO binary op type to STD floating-point op type. 77 template <typename LhloOp> 78 using ScalarFOp = typename LhloToScalarOp<LhloOp>::FOp; 79 // Alias for the map from LHLO binary op type to STD integer op type. 80 template <typename LhloOp> 81 using ScalarIOp = typename LhloToScalarOp<LhloOp>::IOp; 82 // Alias for the map from LHLO binary op type to STD complex op type. 83 template <typename LhloOp> 84 using ScalarCOp = typename LhloToScalarOp<LhloOp>::COp; 85 86 template <typename... Args> 87 struct MapLhloOpToStdScalarOpImpl { 88 Value operator()(Location loc, ArrayRef<Type> result_types, 89 ArrayRef<Value> args, OpBuilder* b) { 90 return nullptr; 91 } 92 }; 93 94 template <typename StdScalarOp> 95 struct MapLhloOpToStdScalarOpImpl<StdScalarOp> { 96 Value operator()(Location loc, ArrayRef<Type> result_types, 97 ArrayRef<Value> args, OpBuilder* b) { 98 return b->template create<StdScalarOp>(loc, result_types, args, mlir::None); 99 } 100 }; 101 102 template <typename SupportedType, typename StdScalarOp, typename... Args> 103 struct MapLhloOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> { 104 Value operator()(Location loc, ArrayRef<Type> result_types, 105 ArrayRef<Value> args, OpBuilder* b) { 106 Type element_type = getElementTypeOrSelf(args.front().getType()); 107 if (element_type.isa<SupportedType>()) { 108 return b->template create<StdScalarOp>(loc, result_types, args, 109 mlir::None); 110 } 111 return MapLhloOpToStdScalarOpImpl<Args...>{}(loc, result_types, args, b); 112 } 113 }; 114 115 // Inserts the computation that corresponds to the body of the loop for lowered 116 // LHLO unary/binary op. Returns the value for the result. 117 template <typename LhloOpTy> 118 inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef<Type> result_types, 119 ArrayRef<Value> args, OpBuilder* b) { 120 return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<LhloOpTy>, FloatType, 121 ScalarFOp<LhloOpTy>>{}(loc, result_types, 122 args, b); 123 } 124 125 template <> 126 inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc, 127 ArrayRef<Type> result_types, 128 ArrayRef<Value> args, 129 OpBuilder* b) { 130 Type element_type = getElementTypeOrSelf(args.front().getType()); 131 if (element_type.isa<FloatType>()) { 132 return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}( 133 loc, result_types, args, b); 134 } 135 if (element_type.isa<IntegerType>()) { 136 // lmhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x)) 137 Value lhs = args[0]; 138 auto integer_type = element_type.dyn_cast<IntegerType>(); 139 140 Value zero_intval = 141 b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); 142 if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) { 143 zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval); 144 } 145 auto lhs_gt_zero = b->create<ScalarIOp<CompareOp>>(loc, CmpIPredicate::sge, 146 lhs, zero_intval); 147 auto neg_val = b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs); 148 return b->create<::mlir::SelectOp>(loc, lhs_gt_zero, lhs, neg_val); 149 } 150 return nullptr; 151 } 152 template <> 153 inline Value MapLhloOpToStdScalarOp<lmhlo::AddOp>(Location loc, 154 ArrayRef<Type> result_types, 155 ArrayRef<Value> args, 156 OpBuilder* b) { 157 return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<lmhlo::AddOp>, 158 FloatType, ScalarFOp<lmhlo::AddOp>, 159 ComplexType, ScalarCOp<lmhlo::AddOp>>{}( 160 loc, result_types, args, b); 161 } 162 163 template <> 164 inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc, 165 ArrayRef<Type> result_types, 166 ArrayRef<Value> args, 167 OpBuilder* b) { 168 return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}( 169 loc, result_types, args, b); 170 } 171 172 template <> 173 inline Value MapLhloOpToStdScalarOp<lmhlo::Atan2Op>(Location loc, 174 ArrayRef<Type> result_types, 175 ArrayRef<Value> args, 176 OpBuilder* b) { 177 return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::Atan2Op>{}( 178 loc, result_types, args, b); 179 } 180 181 template <typename PredicateType> 182 inline Optional<PredicateType> getCmpPredicate(StringRef comparison_direction) { 183 return llvm::None; 184 } 185 186 template <> 187 inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>( 188 StringRef comparison_direction) { 189 return llvm::StringSwitch<Optional<CmpFPredicate>>(comparison_direction) 190 .Case("EQ", CmpFPredicate::OEQ) 191 .Case("NE", CmpFPredicate::UNE) 192 .Case("GE", CmpFPredicate::OGE) 193 .Case("GT", CmpFPredicate::OGT) 194 .Case("LE", CmpFPredicate::OLE) 195 .Case("LT", CmpFPredicate::OLT) 196 .Default(llvm::None); 197 } 198 199 template <> 200 inline Optional<CmpIPredicate> getCmpPredicate<CmpIPredicate>( 201 StringRef comparison_direction) { 202 return llvm::StringSwitch<Optional<CmpIPredicate>>(comparison_direction) 203 .Case("EQ", CmpIPredicate::eq) 204 .Case("NE", CmpIPredicate::ne) 205 .Case("GE", CmpIPredicate::sge) 206 .Case("GT", CmpIPredicate::sgt) 207 .Case("LE", CmpIPredicate::sle) 208 .Case("LT", CmpIPredicate::slt) 209 .Default(llvm::None); 210 } 211 212 template <typename CompareOpTy> 213 inline Value MapCompareOpToStdScalarOp(Location loc, 214 StringRef comparison_direction, 215 ArrayRef<Type> result_types, 216 ArrayRef<Value> args, OpBuilder* b) { 217 const auto& lhs = args[0]; 218 const auto& rhs = args[1]; 219 Type element_type = getElementTypeOrSelf(lhs.getType()); 220 if (element_type.isSignlessInteger()) { 221 Optional<CmpIPredicate> predicate = 222 getCmpPredicate<CmpIPredicate>(comparison_direction); 223 assert(predicate.hasValue() && "expected valid comparison direction"); 224 return b->create<ScalarIOp<CompareOpTy>>(loc, predicate.getValue(), lhs, 225 rhs); 226 } 227 if (element_type.isa<FloatType>()) { 228 Optional<CmpFPredicate> predicate = 229 getCmpPredicate<CmpFPredicate>(comparison_direction); 230 assert(predicate.hasValue() && "expected valid comparison direction"); 231 return b->create<ScalarFOp<CompareOpTy>>(loc, predicate.getValue(), lhs, 232 rhs); 233 } 234 return nullptr; 235 } 236 237 template <> 238 inline Value MapLhloOpToStdScalarOp<lmhlo::CopyOp>(Location loc, 239 ArrayRef<Type> result_types, 240 ArrayRef<Value> args, 241 OpBuilder* b) { 242 return args.front(); 243 } 244 245 template <> 246 inline Value MapLhloOpToStdScalarOp<lmhlo::ExpOp>(Location loc, 247 ArrayRef<Type> result_types, 248 ArrayRef<Value> args, 249 OpBuilder* b) { 250 return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::ExpOp>{}( 251 loc, result_types, args, b); 252 } 253 254 template <> 255 inline Value MapLhloOpToStdScalarOp<lmhlo::Expm1Op>(Location loc, 256 ArrayRef<Type> result_types, 257 ArrayRef<Value> args, 258 OpBuilder* b) { 259 return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::ExpM1Op>{}( 260 loc, result_types, args, b); 261 } 262 263 template <> 264 inline Value MapLhloOpToStdScalarOp<lmhlo::CeilOp>(Location loc, 265 ArrayRef<Type> result_types, 266 ArrayRef<Value> args, 267 OpBuilder* b) { 268 return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}( 269 loc, result_types, args, b); 270 } 271 272 template <> 273 inline Value MapLhloOpToStdScalarOp<lmhlo::ComplexOp>( 274 Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, 275 OpBuilder* b) { 276 return MapLhloOpToStdScalarOpImpl<complex::CreateOp>{}(loc, result_types, 277 args, b); 278 } 279 280 template <> 281 inline Value MapLhloOpToStdScalarOp<lmhlo::RealOp>(Location loc, 282 ArrayRef<Type> result_types, 283 ArrayRef<Value> args, 284 OpBuilder* b) { 285 return MapLhloOpToStdScalarOpImpl<complex::ReOp>{}(loc, result_types, args, 286 b); 287 } 288 289 template <> 290 inline Value MapLhloOpToStdScalarOp<lmhlo::ImagOp>(Location loc, 291 ArrayRef<Type> result_types, 292 ArrayRef<Value> args, 293 OpBuilder* b) { 294 return MapLhloOpToStdScalarOpImpl<complex::ImOp>{}(loc, result_types, args, 295 b); 296 } 297 298 template <> 299 inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>( 300 Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, 301 OpBuilder* b) { 302 Type sourceType = getElementTypeOrSelf(args.front().getType()); 303 Type targetType = getElementTypeOrSelf(result_types.front()); 304 305 // A boolean value is considered to be unsigned when converting to 306 // floating-point. Otherwise, it will become `-1`. 307 if (sourceType.isInteger(/*width=*/1) && 308 mlir::UIToFPOp::areCastCompatible(sourceType, targetType)) { 309 return b->create<mlir::UIToFPOp>(loc, result_types, args, mlir::None); 310 } else if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) { 311 return b->create<mlir::SIToFPOp>(loc, result_types, args, mlir::None); 312 } else if (sourceType.isa<FloatType>() && targetType.isa<FloatType>()) { 313 FloatType src = sourceType.cast<FloatType>(); 314 FloatType res = targetType.cast<FloatType>(); 315 if (src.getWidth() > res.getWidth()) { 316 return b->create<mlir::FPTruncOp>(loc, result_types, args, mlir::None); 317 } else if (src.getWidth() < res.getWidth()) { 318 return b->create<mlir::FPExtOp>(loc, result_types, args, mlir::None); 319 } 320 // No conversion is needed for the same width floats 321 return args.front(); 322 } 323 if (targetType.isInteger(/*width=*/1)) { 324 // When casting to bool, we need to compare whether the value is equal to 325 // zero. 326 if (sourceType.isSignlessInteger()) { 327 Value zero_intval = b->create<::mlir::ConstantIntOp>( 328 loc, 0, sourceType.cast<IntegerType>().getWidth()); 329 if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) { 330 zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval); 331 } 332 return b->create<mlir::CmpIOp>(loc, CmpIPredicate::ne, args.front(), 333 zero_intval); 334 } else if (sourceType.isa<FloatType>()) { 335 Value zero = b->create<ConstantOp>(loc, b->getFloatAttr(sourceType, 0.0)); 336 if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) { 337 zero = b->create<::mlir::SplatOp>(loc, vec_type, zero); 338 } 339 return b->create<mlir::CmpFOp>(loc, CmpFPredicate::UNE, args.front(), 340 zero); 341 } 342 } 343 if (sourceType.isSignlessInteger() && targetType.isSignlessInteger()) { 344 IntegerType src = sourceType.cast<IntegerType>(); 345 IntegerType res = targetType.cast<IntegerType>(); 346 if (src.getWidth() > res.getWidth()) { 347 return b->create<mlir::TruncateIOp>(loc, result_types, args, mlir::None); 348 } else if (src.getWidth() == 1) { 349 // Special case boolean values, so they get casted to `1` instead of `-1`. 350 return b->create<mlir::ZeroExtendIOp>(loc, result_types, args, 351 mlir::None); 352 } else if (src.getWidth() < res.getWidth()) { 353 return b->create<mlir::SignExtendIOp>(loc, result_types, args, 354 mlir::None); 355 } 356 // No conversion is needed for the same width integers 357 return args.front(); 358 } 359 if (mlir::FPToSIOp::areCastCompatible(sourceType, targetType)) { 360 return b->create<mlir::FPToSIOp>(loc, result_types, args, mlir::None); 361 } 362 return nullptr; 363 } 364 365 template <> 366 inline Value MapLhloOpToStdScalarOp<lmhlo::DotOp>(Location loc, 367 ArrayRef<Type> result_types, 368 ArrayRef<Value> args, 369 OpBuilder* b) { 370 // Dot Op converter from lhlo to affine only accepts float and integer types. 371 const auto& lhs = args[0]; 372 const auto& rhs = args[1]; 373 const auto& result = args[2]; 374 Type element_type = lhs.getType(); 375 if (element_type.isa<FloatType>()) { 376 Value float_mul = MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::MulFOp>{}( 377 loc, result_types, {lhs, rhs}, b); 378 return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AddFOp>{}( 379 loc, result_types, {float_mul, result}, b); 380 } 381 if (element_type.isa<IntegerType>()) { 382 Value int_mul = MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::MulIOp>{}( 383 loc, result_types, {lhs, rhs}, b); 384 return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AddIOp>{}( 385 loc, result_types, {int_mul, result}, b); 386 } 387 return nullptr; 388 } 389 390 template <> 391 inline Value MapLhloOpToStdScalarOp<lmhlo::CosOp>(Location loc, 392 ArrayRef<Type> result_types, 393 ArrayRef<Value> args, 394 OpBuilder* b) { 395 return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::CosOp>{}( 396 loc, result_types, args, b); 397 } 398 399 template <> 400 inline Value MapLhloOpToStdScalarOp<lmhlo::SinOp>(Location loc, 401 ArrayRef<Type> result_types, 402 ArrayRef<Value> args, 403 OpBuilder* b) { 404 return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::SinOp>{}( 405 loc, result_types, args, b); 406 } 407 408 template <> 409 inline Value MapLhloOpToStdScalarOp<lmhlo::FloorOp>(Location loc, 410 ArrayRef<Type> result_types, 411 ArrayRef<Value> args, 412 OpBuilder* b) { 413 return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::FloorFOp>{}( 414 loc, result_types, args, b); 415 } 416 417 template <> 418 inline Value MapLhloOpToStdScalarOp<lmhlo::IsFiniteOp>( 419 Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, 420 OpBuilder* b) { 421 if (args[0].getType().isa<FloatType>()) { 422 auto pos_inf = APFloat::getInf( 423 args[0].getType().cast<FloatType>().getFloatSemantics()); 424 auto const_pos_inf = 425 b->create<ConstantOp>(loc, b->getFloatAttr(args[0].getType(), pos_inf)); 426 Value abs_x = b->create<::mlir::AbsFOp>(loc, args[0]); 427 return b->create<::mlir::CmpFOp>(loc, CmpFPredicate::ONE, abs_x, 428 const_pos_inf); 429 } 430 return nullptr; 431 } 432 433 /// Implements the conversion of HLO op to scalar op (to use within region of a 434 /// linalg.generic op) for compare-select style operations like min/max. 435 template <typename... Args> 436 struct CompareSelectOpToStdScalarOp { 437 static Value map(Location loc, StringRef comparison_direction, 438 ArrayRef<Type> result_types, ArrayRef<Value> args, 439 OpBuilder* b) { 440 return nullptr; 441 } 442 }; 443 444 /// Specialization which allows converting to a comparison operation in standard 445 /// dialect with a given predicate based on the element type of the operand. 446 template <typename SupportedType, typename StdCompareOp, typename Predicate, 447 typename... Args> 448 struct CompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate, 449 Args...> { 450 static Value map(Location loc, StringRef comparison_direction, 451 ArrayRef<Type> result_types, ArrayRef<Value> args, 452 OpBuilder* b) { 453 Type element_type = getElementTypeOrSelf(args.front().getType()); 454 if (element_type.isa<SupportedType>()) { 455 auto predicate = getCmpPredicate<Predicate>(comparison_direction); 456 assert(predicate.hasValue() && "expected valid comparison direction"); 457 auto cmp = b->template create<StdCompareOp>(loc, predicate.getValue(), 458 args[0], args[1]); 459 return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]); 460 } 461 return CompareSelectOpToStdScalarOp<Args...>::map(loc, comparison_direction, 462 result_types, args, b); 463 } 464 }; 465 466 template <> 467 inline Value MapLhloOpToStdScalarOp<lmhlo::LogOp>(Location loc, 468 ArrayRef<Type> result_types, 469 ArrayRef<Value> args, 470 OpBuilder* b) { 471 return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::LogOp>{}( 472 loc, result_types, args, b); 473 } 474 475 inline Value LhloAlwaysPropagateNaN(Value v, ArrayRef<Value> args, Location loc, 476 OpBuilder* b) { 477 Type element_type = getElementTypeOrSelf(args.front().getType()); 478 if (auto float_type = element_type.dyn_cast<FloatType>()) { 479 Value isnan = 480 b->create<mlir::CmpFOp>(loc, CmpFPredicate::UNO, args[0], args[1]); 481 482 auto nan_apfloat = APFloat::getQNaN(float_type.getFloatSemantics()); 483 Value nan = b->create<mlir::ConstantFloatOp>(loc, nan_apfloat, float_type); 484 if (VectorType vec_type = args[0].getType().dyn_cast<VectorType>()) { 485 nan = b->create<::mlir::SplatOp>(loc, vec_type, nan); 486 } 487 v = b->create<mlir::SelectOp>(loc, isnan, nan, v); 488 } 489 return v; 490 } 491 492 template <> 493 inline Value MapLhloOpToStdScalarOp<lmhlo::LogisticOp>( 494 Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, 495 OpBuilder* b) { 496 auto ty = result_types.front().cast<FloatType>(); 497 Value one = b->create<ConstantOp>(loc, b->getFloatAttr(ty, 1.0)); 498 Value x = args.front(); 499 Value neg_x = b->create<NegFOp>(loc, x); 500 Value exp_neg_x = b->create<::mlir::math::ExpOp>(loc, neg_x); 501 Value one_add_exp_neg_x = b->create<AddFOp>(loc, one, exp_neg_x); 502 return b->create<DivFOp>(loc, one, one_add_exp_neg_x); 503 } 504 505 template <> 506 inline Value MapLhloOpToStdScalarOp<lmhlo::Log1pOp>(Location loc, 507 ArrayRef<Type> result_types, 508 ArrayRef<Value> args, 509 OpBuilder* b) { 510 return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::Log1pOp>{}( 511 loc, result_types, args, b); 512 } 513 514 template <> 515 inline Value MapLhloOpToStdScalarOp<lmhlo::MaxOp>(Location loc, 516 ArrayRef<Type> result_types, 517 ArrayRef<Value> args, 518 OpBuilder* b) { 519 return LhloAlwaysPropagateNaN( 520 CompareSelectOpToStdScalarOp< 521 IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType, 522 ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "GT", 523 result_types, args, 524 b), 525 args, loc, b); 526 } 527 528 template <> 529 inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc, 530 ArrayRef<Type> result_types, 531 ArrayRef<Value> args, 532 OpBuilder* b) { 533 return LhloAlwaysPropagateNaN( 534 CompareSelectOpToStdScalarOp< 535 IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType, 536 ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "LT", 537 result_types, args, 538 b), 539 args, loc, b); 540 } 541 542 template <> 543 inline Value MapLhloOpToStdScalarOp<lmhlo::ClampOp>(Location loc, 544 ArrayRef<Type> result_types, 545 ArrayRef<Value> args, 546 OpBuilder* b) { 547 assert(args.size() == 3 && "expected 3 arguments"); 548 Value lb = args[0]; 549 Value x = args[1]; 550 Value ub = args[2]; 551 552 // clamp(lb, x, ub) = max(min(x, ub), lb) 553 Value min_x_ub = 554 MapLhloOpToStdScalarOp<lmhlo::MinOp>(loc, result_types, {x, ub}, b); 555 return MapLhloOpToStdScalarOp<lmhlo::MaxOp>(loc, result_types, {min_x_ub, lb}, 556 b); 557 } 558 559 template <> 560 inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc, 561 ArrayRef<Type> result_types, 562 ArrayRef<Value> args, 563 OpBuilder* b) { 564 Type element_type = getElementTypeOrSelf(args.front().getType()); 565 if (element_type.isa<FloatType>()) { 566 return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}( 567 loc, result_types, args, b); 568 } 569 if (element_type.isa<IntegerType>()) { 570 // lmhlo.neg(x, result) -> result = sub(0, x) 571 Value lhs = args[0]; 572 auto integer_type = element_type.dyn_cast<IntegerType>(); 573 574 Value zero_intval = 575 b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); 576 if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) { 577 zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval); 578 } 579 return b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs); 580 } 581 return nullptr; 582 } 583 584 template <> 585 inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc, 586 ArrayRef<Type> result_types, 587 ArrayRef<Value> args, 588 OpBuilder* b) { 589 Type element_type = getElementTypeOrSelf(args.front().getType()); 590 if (auto integer_type = element_type.dyn_cast<IntegerType>()) { 591 // lmhlo.not(x) -> x ^ -1 592 Value all_ones = 593 b->create<::mlir::ConstantIntOp>(loc, -1, integer_type.getWidth()); 594 if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) { 595 all_ones = b->create<::mlir::SplatOp>(loc, vec_type, all_ones); 596 } 597 return b->create<::mlir::XOrOp>(loc, all_ones, args[0]); 598 } 599 return nullptr; 600 } 601 602 template <> 603 inline Value MapLhloOpToStdScalarOp<lmhlo::OrOp>(Location loc, 604 ArrayRef<Type> result_types, 605 ArrayRef<Value> args, 606 OpBuilder* b) { 607 return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::OrOp>{}( 608 loc, result_types, args, b); 609 } 610 611 template <> 612 inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc, 613 ArrayRef<Type> result_types, 614 ArrayRef<Value> args, 615 OpBuilder* b) { 616 return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::RsqrtOp>{}( 617 loc, result_types, args, b); 618 } 619 620 template <> 621 inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc, 622 ArrayRef<Type> result_types, 623 ArrayRef<Value> args, 624 OpBuilder* b) { 625 lmhlo::PowOp::Adaptor adaptor(args); 626 auto lb = ImplicitLocOpBuilder(loc, *b); 627 // Floating point can use std::powf 628 auto result_type = result_types.front(); 629 if (result_type.isa<::mlir::FloatType>()) 630 return MapLhloOpToStdScalarOpImpl<::mlir::math::PowFOp>{}(loc, result_types, 631 args, b); 632 633 assert(result_type.isa<::mlir::IntegerType>() && 634 "only float and integer `pow` is supported right now"); 635 636 // Exponentiation by squaring: 637 // https://en.wikipedia.org/wiki/Exponentiation_by_squaring; 638 Value neg_one = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, -1)); 639 Value zero = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, 0)); 640 Value one = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, 1)); 641 Value two = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, 2)); 642 Value step = lb.create<ConstantIndexOp>(1); 643 Value lowerBound = lb.create<ConstantIndexOp>(0); 644 // Everything else would overflow for any exponent > 1, as 2^64 645 // is the larget possible exponent for a 64-bit integer, and 646 // that's 1 << 6. 647 Value upperBound = lb.create<ConstantIndexOp>(6); 648 auto original_base = adaptor.lhs(); 649 auto original_exponent = adaptor.rhs(); 650 651 Value accum = 652 lb.create<scf::ForOp>( 653 lowerBound, upperBound, step, 654 SmallVector<Value>({one, original_base, original_exponent}), 655 [&](OpBuilder& b, Location, Value v, ValueRange iters) { 656 Value accum = iters[0]; 657 Value base = iters[1]; 658 Value exponent = iters[2]; 659 660 Value condition = b.create<CmpIOp>( 661 loc, CmpIPredicate::eq, 662 b.create<::mlir::AndOp>(loc, exponent, one), one); 663 Value multiplied = b.create<::mlir::MulIOp>(loc, accum, base); 664 accum = 665 b.create<::mlir::SelectOp>(loc, condition, multiplied, accum); 666 base = b.create<::mlir::MulIOp>(loc, base, base); 667 exponent = 668 b.create<::mlir::UnsignedShiftRightOp>(loc, exponent, one); 669 b.create<scf::YieldOp>( 670 loc, SmallVector<Value>({accum, base, exponent})); 671 }) 672 .getResult(0); 673 674 Value rhs_is_even = lb.create<CmpIOp>( 675 CmpIPredicate::eq, lb.create<SignedRemIOp>(adaptor.rhs(), two), zero); 676 Value rhs_is_negative = 677 lb.create<CmpIOp>(CmpIPredicate::slt, adaptor.rhs(), zero); 678 Value lhs_is_one = lb.create<CmpIOp>(CmpIPredicate::eq, adaptor.lhs(), one); 679 Value lhs_is_neg_one = 680 lb.create<CmpIOp>(CmpIPredicate::eq, adaptor.lhs(), neg_one); 681 682 // The accum is correct when the rhs is non-negative. When rhs is 683 // negative, we return 0 for integer, with the exception of lhs values of 1 684 // and -1 which have integer results for negative exponents. Specifically, the 685 // calulation is the following: 686 // 687 // - Return accum if the rhs is not negative. 688 // - Return 1 or -1 depending on the parity of rhs when the lhs is -1. 689 // - Return 1 if lhs is 1. 690 // - Else return 0. 691 Value if_lhs_is_one = lb.create<::mlir::SelectOp>(lhs_is_one, one, zero); 692 Value if_lhs_is_neg_one = lb.create<::mlir::SelectOp>( 693 lhs_is_neg_one, lb.create<::mlir::SelectOp>(rhs_is_even, one, neg_one), 694 if_lhs_is_one); 695 return lb.create<::mlir::SelectOp>(rhs_is_negative, if_lhs_is_neg_one, accum); 696 } 697 698 template <> 699 inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>( 700 Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, 701 OpBuilder* b) { 702 return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args, 703 b); 704 } 705 706 template <> 707 inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftLeftOp>( 708 Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, 709 OpBuilder* b) { 710 return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::ShiftLeftOp>{}( 711 loc, result_types, args, b); 712 } 713 714 template <> 715 inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightArithmeticOp>( 716 Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, 717 OpBuilder* b) { 718 return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::SignedShiftRightOp>{}( 719 loc, result_types, args, b); 720 } 721 722 template <> 723 inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightLogicalOp>( 724 Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, 725 OpBuilder* b) { 726 return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::UnsignedShiftRightOp>{}( 727 loc, result_types, args, b); 728 } 729 730 template <> 731 inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc, 732 ArrayRef<Type> result_types, 733 ArrayRef<Value> args, 734 OpBuilder* b) { 735 Type element_type = getElementTypeOrSelf(args.front().getType()); 736 if (auto float_type = element_type.dyn_cast<FloatType>()) { 737 bool ignored; 738 APFloat zero_apfloat(0.0f); 739 zero_apfloat.convert(float_type.getFloatSemantics(), 740 APFloat::rmNearestTiesToEven, &ignored); 741 Value zero = 742 b->create<mlir::ConstantFloatOp>(loc, zero_apfloat, float_type); 743 if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) { 744 zero = b->create<::mlir::SplatOp>(loc, vec_type, zero); 745 } 746 Value ne0_i1 = 747 b->create<::mlir::CmpFOp>(loc, CmpFPredicate::ONE, args[0], zero); 748 Value ne0_float = b->create<::mlir::UIToFPOp>(loc, ne0_i1, zero.getType()); 749 Value copy_sign = 750 b->create<::mlir::CopySignOp>(loc, result_types, ne0_float, args[0]); 751 auto is_nan = 752 b->create<::mlir::CmpFOp>(loc, CmpFPredicate::UNO, args[0], args[0]); 753 return b->create<::mlir::SelectOp>(loc, is_nan, args[0], copy_sign); 754 } else if (auto integer_type = element_type.dyn_cast<IntegerType>()) { 755 // sign(x) = x == 0 ? 0 : ((x s>> 31) | 1) 756 Value zero = 757 b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); 758 Value bitwidth_minus_one = b->create<::mlir::ConstantIntOp>( 759 loc, integer_type.getWidth() - 1, integer_type.getWidth()); 760 Value one = 761 b->create<::mlir::ConstantIntOp>(loc, 1, integer_type.getWidth()); 762 if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) { 763 zero = b->create<::mlir::SplatOp>(loc, vec_type, zero); 764 bitwidth_minus_one = 765 b->create<::mlir::SplatOp>(loc, vec_type, bitwidth_minus_one); 766 one = b->create<::mlir::SplatOp>(loc, vec_type, one); 767 } 768 Value cmp = 769 b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero); 770 Value ashr = 771 b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one); 772 Value or_op = b->create<::mlir::OrOp>(loc, ashr, one); 773 return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op); 774 } 775 return nullptr; 776 } 777 778 template <> 779 inline Value MapLhloOpToStdScalarOp<lmhlo::SqrtOp>(Location loc, 780 ArrayRef<Type> result_types, 781 ArrayRef<Value> args, 782 OpBuilder* b) { 783 return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::SqrtOp>{}( 784 loc, result_types, args, b); 785 } 786 787 template <> 788 inline Value MapLhloOpToStdScalarOp<lmhlo::SubOp>(Location loc, 789 ArrayRef<Type> result_types, 790 ArrayRef<Value> args, 791 OpBuilder* b) { 792 return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<lmhlo::SubOp>, 793 FloatType, ScalarFOp<lmhlo::SubOp>, 794 ComplexType, ScalarCOp<lmhlo::SubOp>>{}( 795 loc, result_types, args, b); 796 } 797 798 template <> 799 inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc, 800 ArrayRef<Type> result_types, 801 ArrayRef<Value> args, 802 OpBuilder* b) { 803 return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::TanhOp>{}( 804 loc, result_types, args, b); 805 } 806 807 template <> 808 inline Value MapLhloOpToStdScalarOp<lmhlo::XorOp>(Location loc, 809 ArrayRef<Type> result_types, 810 ArrayRef<Value> args, 811 OpBuilder* b) { 812 return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::XOrOp>{}( 813 loc, result_types, args, b); 814 } 815 816 } // namespace impl 817 818 struct HloOpToStdScalarOp { 819 // Implementation for LHLO ops except lmhlo::CompareOp. 820 template <typename HloOpTy, typename LhloOpTy = HloOpTy, 821 typename = std::enable_if_t< 822 !std::is_same<LhloOpTy, lmhlo::CompareOp>::value && 823 std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>, 824 std::false_type>::value>> 825 static Value map(HloOpTy op, ArrayRef<Type> result_types, 826 ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) { 827 return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types, 828 args, b); 829 } 830 831 // Implementation for HLO ops except mhlo::CompareOp. 832 template <typename HloOpTy, typename LhloOpTy = mhlo::HloToLhloOp<HloOpTy>, 833 typename = std::enable_if_t< 834 !std::is_same<LhloOpTy, lmhlo::CompareOp>::value && 835 !std::is_same<LhloOpTy, std::false_type>::value>> 836 static Value map(HloOpTy op, ArrayRef<Type> result_types, 837 ArrayRef<Value> args, OpBuilder* b, int i = 0) { 838 return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types, 839 args, b); 840 } 841 842 // Implementation for lmhlo::CompareOp. 843 template <typename LhloOpTy, typename = std::enable_if_t<std::is_same< 844 LhloOpTy, lmhlo::CompareOp>::value>> 845 static Value map(lmhlo::CompareOp op, ArrayRef<Type> result_types, 846 ArrayRef<Value> args, OpBuilder* b) { 847 auto comparison_direction = op.comparison_direction(); 848 return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>( 849 op.getLoc(), comparison_direction, result_types, args, b); 850 } 851 852 // Implementation for mhlo::CompareOp. 853 template <typename HloOpTy, 854 typename = 855 std::enable_if_t<std::is_same<HloOpTy, mhlo::CompareOp>::value>> 856 static Value map(mhlo::CompareOp op, ArrayRef<Type> result_types, 857 ArrayRef<Value> args, OpBuilder* b) { 858 auto comparison_direction = op.comparison_direction(); 859 return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>( 860 op.getLoc(), comparison_direction, result_types, args, b); 861 } 862 863 // Implementation for LHLO ops except lmhlo::CompareOp. 864 template <typename LhloOpTy, 865 typename = std::enable_if_t< 866 !std::is_same<LhloOpTy, lmhlo::CompareOp>::value && 867 std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>, 868 std::false_type>::value>> 869 static Value map(Location loc, ArrayRef<Type> result_types, 870 ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) { 871 return impl::MapLhloOpToStdScalarOp<LhloOpTy>(loc, result_types, args, b); 872 } 873 874 // Implementation for lmhlo::CompareOp. 875 template <typename LhloOpTy, typename = std::enable_if_t<std::is_same< 876 LhloOpTy, lmhlo::CompareOp>::value>> 877 static Value map(Location loc, StringRef comparison_direction, 878 ArrayRef<Type> result_types, ArrayRef<Value> args, 879 OpBuilder* b) { 880 return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>( 881 loc, comparison_direction, result_types, args, b); 882 } 883 }; 884 885 } // namespace lmhlo 886 } // namespace mlir 887 888 #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_ 889