/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines helpers useful when creating or manipulating lhlo/hlo.

// NOTE(review): this copy of the file looks like it went through a lossy text
// extraction: all line breaks were collapsed and every `<...>` template
// parameter/argument list except `StatusOr<::xla::HloOpcode>` is missing
// (e.g. `template ::mlir::...`, `StatusOr>`, `isa(op)`). The code tokens below
// are kept exactly as found; only line structure and comments were restored.
// The missing template arguments must be recovered from version control
// before this file can compile.

#include "tensorflow/compiler/mlir/xla/hlo_utils.h"

#include "mlir/IR/AffineMap.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace {

using mlir::AffineMap;
using mlir::Builder;
using mlir::DenseElementsAttr;
using mlir::ShapedType;
using xla::LiteralBase;
using xla::StatusOr;

// Builds a DenseElementsAttr of the given shaped `type` from the element data
// exposed by `literal.data()`. The data span's pointer and size are wrapped in
// an llvm ArrayRef and handed to DenseElementsAttr::get.
//
// NOTE(review): the template header and the argument of `literal.data` are
// missing here — presumably `template <typename CppType>` and
// `literal.data<CppType>()`; confirm against version control.
template
::mlir::DenseElementsAttr CreateDenseAttrFromLiteral(
    const ShapedType& type, const LiteralBase& literal) {
  auto data_span = literal.data();
  return ::mlir::DenseElementsAttr::get(
      type, llvm::makeArrayRef(data_span.data(), data_span.size()));
}

// Computes the affine layout map(s) describing the layout of `shape`.
//
// Returns an empty vector when the shape has no layout, when the layout is
// monotonic with dim 0 major, or when the accumulated stride ends up 0 (which
// happens when some dimension has extent 0). Returns an Internal error for
// shapes that are not static. Otherwise returns a single strided linear layout
// map with offset 0 built from the computed strides.
//
// NOTE(review): the template arguments of `StatusOr` / `llvm::SmallVector` are
// missing throughout this function — presumably a small vector of AffineMap
// for the result and of int64_t for `strides`; confirm.
StatusOr> GetPermutationIfAvailable(
    const Shape& shape, mlir::Builder builder) {
  if (!shape.has_layout() ||
      LayoutUtil::IsMonotonicWithDim0Major(shape.layout())) {
    return llvm::SmallVector{};
  }
  if (!shape.is_static()) {
    return tensorflow::errors::Internal(
        "Permutations for dynamic shapes are not yet supported");
  }
  // Walk the dimensions from minor to major; each dimension's stride is the
  // product of the extents of all dimensions visited before it.
  int64_t accumulated_stride = 1;
  llvm::SmallVector strides(shape.rank(), 1);
  for (int64 dim : LayoutUtil::MinorToMajor(shape)) {
    strides[dim] = accumulated_stride;
    accumulated_stride *= shape.dimensions(dim);
  }
  // A zero product means at least one visited dimension has extent 0.
  if (accumulated_stride == 0) {
    return llvm::SmallVector{};
  }
  return llvm::SmallVector{
      makeStridedLinearLayoutMap(strides, /*offset=*/0, builder.getContext())};
}

// Copies every element of `data` into `*output` as raw bytes: `*output` is
// resized to getNumElements() * sizeof(T) and each element is memcpy'd at
// consecutive sizeof(T)-byte offsets.
//
// NOTE(review): the template header and the `std::vector` / `getValues`
// template arguments are missing — presumably `template <typename T>`, a byte
// vector, and `getValues<T>()`; confirm.
// NOTE(review): the byte offset `i` is an `int`, so it would overflow for
// payloads larger than INT_MAX bytes; consider a wider unsigned type.
template
void CopyDenseElementsBy(mlir::DenseElementsAttr data,
                         std::vector* output) {
  output->resize(data.getNumElements() * sizeof(T));
  int i = 0;
  for (T element : data.getValues()) {
    std::memcpy(&(*output)[i], &element, sizeof(T));
    i += sizeof(T);
  }
}

}  // namespace

// Converts an XLA `shape` into an MLIR MemRefType.
//
// The element type comes from ConvertPrimitiveTypeToMLIRType, the dimensions
// are copied from `shape.dimensions()`, and the layout maps come from
// GetPermutationIfAvailable. A failure status from either helper is returned
// unchanged.
//
// NOTE(review): `StatusOr` and `llvm::SmallVector` have lost their template
// arguments here; confirm against version control.
StatusOr ConvertTensorShapeToMemRefType(
    const Shape& shape, mlir::Builder builder) {
  auto element_type_or =
      ConvertPrimitiveTypeToMLIRType(shape.element_type(), builder);
  if (!element_type_or.ok()) return element_type_or.status();

  using mlir::MemRefType;
  auto dimensions = shape.dimensions();
  llvm::SmallVector array(dimensions.begin(), dimensions.end());
  auto permutation_or = GetPermutationIfAvailable(shape, builder);
  if (!permutation_or.ok()) return permutation_or.status();
  return MemRefType::get(array, element_type_or.ValueOrDie(),
                         permutation_or.ValueOrDie());
}

// Creates a DenseElementsAttr holding the contents of `literal`.
//
// The attribute type is derived from the literal's shape through
// ConvertTensorShapeToType, and the element data is copied by
// CreateDenseAttrFromLiteral, chosen by the literal's primitive element type.
// Primitive types without a case below yield an Internal
// "Unsupported type: <name>" error.
//
// NOTE(review): the template arguments of `StatusOr`, of
// `ConvertTensorShapeToType`, and of every `CreateDenseAttrFromLiteral` call
// are missing — each case presumably names the C++ type matching its
// PrimitiveType (bool, half, bfloat16, float, ...); confirm.
StatusOr CreateDenseElementsAttrFromLiteral(
    const LiteralBase& literal, Builder builder) {
  TF_ASSIGN_OR_RETURN(auto type,
                      ConvertTensorShapeToType(
                          literal.shape(), builder));

  // TODO(hinsu): Support remaining XLA primitive types.
  auto element_type = literal.shape().element_type();
  switch (element_type) {
    case PrimitiveType::PRED:
      return CreateDenseAttrFromLiteral(type, literal);
    case PrimitiveType::F16:
      return CreateDenseAttrFromLiteral(type, literal);
    case PrimitiveType::BF16:
      return CreateDenseAttrFromLiteral(type, literal);
    case PrimitiveType::F32:
      return CreateDenseAttrFromLiteral(type, literal);
    case PrimitiveType::F64:
      return CreateDenseAttrFromLiteral(type, literal);
    case PrimitiveType::S8:
      return CreateDenseAttrFromLiteral(type, literal);
    case PrimitiveType::S16:
      return CreateDenseAttrFromLiteral(type, literal);
    case PrimitiveType::S32:
      return CreateDenseAttrFromLiteral(type, literal);
    case PrimitiveType::S64:
      return CreateDenseAttrFromLiteral(type, literal);
    case PrimitiveType::U8:
      return CreateDenseAttrFromLiteral(type, literal);
    case PrimitiveType::U16:
      return CreateDenseAttrFromLiteral(type, literal);
    case PrimitiveType::U32:
      return CreateDenseAttrFromLiteral(type, literal);
    case PrimitiveType::U64:
      return CreateDenseAttrFromLiteral(type, literal);
    case PrimitiveType::C64:
      return CreateDenseAttrFromLiteral(type, literal);
    case PrimitiveType::C128:
      return CreateDenseAttrFromLiteral(type, literal);
    default:
      return tensorflow::errors::Internal(
          absl::StrCat("Unsupported type: ", PrimitiveType_Name(element_type)));
  }
}

// Copies the elements of `data` into `*output` as raw bytes, choosing the
// CopyDenseElementsBy instantiation from the attribute's element type.
//
// Handled element types: integers of width 1, 8, 16, 32 and 64; BF16, F16,
// F32 and F64; and complex types whose element type is F32 or F64. Anything
// else returns an Internal error.
//
// NOTE(review): the `std::vector` template argument, the C++ type argument of
// every `CopyDenseElementsBy` call, and the argument of `dyn_cast` are missing
// — presumably a byte vector, the C++ type matching each branch, and
// `mlir::ComplexType`; confirm.
Status CopyDenseElementsDataToXlaFormat(mlir::DenseElementsAttr data,
                                        std::vector* output) {
  mlir::Type element_type = data.getType().getElementType();

  // TODO(hinsu): Support remaining XLA primitive types.
  if (element_type.isInteger(1)) {
    CopyDenseElementsBy(data, output);
    return Status::OK();
  }
  if (element_type.isInteger(8)) {
    CopyDenseElementsBy(data, output);
    return Status::OK();
  }
  if (element_type.isInteger(16)) {
    CopyDenseElementsBy(data, output);
    return Status::OK();
  }
  if (element_type.isInteger(32)) {
    CopyDenseElementsBy(data, output);
    return Status::OK();
  }
  if (element_type.isInteger(64)) {
    CopyDenseElementsBy(data, output);
    return Status::OK();
  }
  if (element_type.isBF16()) {
    CopyDenseElementsBy(data, output);
    return Status::OK();
  }
  if (element_type.isF16()) {
    CopyDenseElementsBy(data, output);
    return Status::OK();
  }
  if (element_type.isF32()) {
    CopyDenseElementsBy(data, output);
    return Status::OK();
  }
  if (element_type.isF64()) {
    CopyDenseElementsBy(data, output);
    return Status::OK();
  }
  if (auto complex_type = element_type.dyn_cast()) {
    if (complex_type.getElementType().isF32()) {
      CopyDenseElementsBy(data, output);
      return Status::OK();
    }
    if (complex_type.getElementType().isF64()) {
      CopyDenseElementsBy(data, output);
      return Status::OK();
    }
    // A complex type with any other element type falls through to the error.
  }
  return tensorflow::errors::Internal(
      "Unsupported type in CopyDenseElementsDataToXlaFormat");
}

// Returns the storage size in bytes of one element of `type`.
//
// A 1-bit integer counts as 1 byte; a complex type counts as twice the size of
// its element type; every other type is its int-or-float bit width divided by
// 8, and TF_RET_CHECK rejects widths that are not a multiple of 8.
//
// NOTE(review): the template arguments of `StatusOr` and `dyn_cast` are
// missing — presumably `int` and `mlir::ComplexType`; confirm.
StatusOr GetElementTypeBytes(mlir::Type type) {
  if (type.isInteger(1)) {
    return 1;
  }
  if (auto complex_type = type.dyn_cast()) {
    TF_ASSIGN_OR_RETURN(int bytes,
                        GetElementTypeBytes(complex_type.getElementType()));
    return bytes * 2;
  }
  int width = type.getIntOrFloatBitWidth();
  TF_RET_CHECK(width % 8 == 0);
  return width / 8;
}

// Creates a DenseIntElementsAttr with 64-bit integer elements from `vector`.
// When `shape` is empty the attribute is one-dimensional with `vector.size()`
// elements; otherwise `shape` is used as the tensor shape.
//
// NOTE(review): both `llvm::ArrayRef` parameters have lost their element type
// — presumably 64-bit integers; confirm.
mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector(
    const llvm::ArrayRef vector, mlir::Builder builder,
    llvm::ArrayRef shape) {
  return mlir::DenseIntElementsAttr::get(
      mlir::RankedTensorType::get(shape.empty() ? vector.size() : shape,
                                  builder.getIntegerType(64)),
      vector);
}

// Maps an XLA PrimitiveType to the MLIR type created through `builder`.
//
// PRED maps to i1, the floating-point types to the matching builder float
// type, S8..S64 to getIntegerType(width), U8..U64 to
// getIntegerType(width, /*isSigned=*/false), and C64/C128 to a ComplexType
// over F32/F64. Any other primitive type returns an Internal
// "Unsupported type: <name>" error.
//
// NOTE(review): the `StatusOr` template argument is missing — presumably
// `mlir::Type`; confirm.
StatusOr ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type,
                                        mlir::Builder builder) {
  switch (element_type) {
    case PrimitiveType::PRED:
      return builder.getI1Type();
    case PrimitiveType::F16:
      return builder.getF16Type();
    case PrimitiveType::BF16:
      return builder.getBF16Type();
    case PrimitiveType::F32:
      return builder.getF32Type();
    case PrimitiveType::F64:
      return builder.getF64Type();
    case PrimitiveType::S8:
      return builder.getIntegerType(8);
    case PrimitiveType::S16:
      return builder.getIntegerType(16);
    case PrimitiveType::S32:
      return builder.getIntegerType(32);
    case PrimitiveType::S64:
      return builder.getIntegerType(64);
    case PrimitiveType::U8:
      return builder.getIntegerType(8, /*isSigned=*/false);
    case PrimitiveType::U16:
      return builder.getIntegerType(16, /*isSigned=*/false);
    case PrimitiveType::U32:
      return builder.getIntegerType(32, /*isSigned=*/false);
    case PrimitiveType::U64:
      return builder.getIntegerType(64, /*isSigned=*/false);
    case PrimitiveType::C64:
      return mlir::ComplexType::get(builder.getF32Type());
    case PrimitiveType::C128:
      return mlir::ComplexType::get(builder.getF64Type());
    // TODO(b/130356985): Support unsigned primitive types.
    // NOTE(review): U8..U64 are handled above, so the TODO looks stale —
    // confirm before removing it.
    default:
      return tensorflow::errors::Internal(
          absl::StrCat("Unsupported type: ", PrimitiveType_Name(element_type)));
  }
}

// Converts XLA GatherDimensionNumbers into the mhlo GatherDimensionNumbers
// attribute. `offset_dims`, `collapsed_slice_dims` and `start_index_map`
// become dense integer attributes via CreateDenseIntElementsAttrFromVector;
// `index_vector_dim` becomes a 64-bit IntegerAttr.
//
// NOTE(review): the three `llvm::SmallVector` temporaries have lost their
// template arguments; confirm against version control.
mlir::mhlo::GatherDimensionNumbers CreateGatherDimensionNumbers(
    const GatherDimensionNumbers& input, mlir::Builder builder) {
  auto offset_dims = CreateDenseIntElementsAttrFromVector(
      llvm::SmallVector{input.offset_dims().begin(),
                        input.offset_dims().end()},
      builder);
  auto collapsed_slice_dims = CreateDenseIntElementsAttrFromVector(
      llvm::SmallVector{input.collapsed_slice_dims().begin(),
                        input.collapsed_slice_dims().end()},
      builder);
  auto start_index_map = CreateDenseIntElementsAttrFromVector(
      llvm::SmallVector{input.start_index_map().begin(),
                        input.start_index_map().end()},
      builder);

  mlir::IntegerAttr index_vector_dim =
      builder.getI64IntegerAttr(input.index_vector_dim());

  return mlir::mhlo::GatherDimensionNumbers::get(
      offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim,
      builder.getContext());
}

// Returns the xla::HloOpcode corresponding to the MLIR operation `op`.
//
// The op is tested against a chain of `isa` checks and the first match decides
// the opcode. kDot and kBroadcast are each returned from two separate
// branches. If nothing matches, the op is printed into a string and an
// Unimplemented error carrying that text is returned.
//
// NOTE(review): every `isa(op)` below has lost its `<...>` op-class list, so
// the branches are indistinguishable in this copy. Presumably each one names
// the mlir::mhlo (and, where one exists, mlir::lmhlo) op class matching the
// opcode it returns — restore the exact lists from version control rather
// than guessing.
StatusOr<::xla::HloOpcode> MhloToHloOpcode(mlir::Operation* op) {
  using mlir::isa;

  if (isa(op)) {
    return xla::HloOpcode::kConstant;
  } else if (isa(op)) {
    return xla::HloOpcode::kIota;
  } else if (isa(op)) {
    return xla::HloOpcode::kConvert;
  } else if (isa(op)) {
    return xla::HloOpcode::kAdd;
  } else if (isa(op)) {
    return xla::HloOpcode::kAtan2;
  } else if (isa(op)) {
    return xla::HloOpcode::kDivide;
  } else if (isa(op)) {
    return xla::HloOpcode::kMaximum;
  } else if (isa(op)) {
    return xla::HloOpcode::kMinimum;
  } else if (isa(op)) {
    return xla::HloOpcode::kMultiply;
  } else if (isa(op)) {
    return xla::HloOpcode::kPower;
  } else if (isa(op)) {
    return xla::HloOpcode::kRemainder;
  } else if (isa(op)) {
    return xla::HloOpcode::kShiftLeft;
  } else if (isa(op)) {
    return xla::HloOpcode::kShiftRightArithmetic;
  } else if (isa(op)) {
    return xla::HloOpcode::kShiftRightLogical;
  } else if (isa(op)) {
    return xla::HloOpcode::kSubtract;
  } else if (isa(op)) {
    return xla::HloOpcode::kXor;
  } else if (isa(op)) {
    return xla::HloOpcode::kInfeed;
  } else if (isa(op)) {
    return xla::HloOpcode::kOutfeed;
  } else if (isa(op)) {
    return xla::HloOpcode::kSend;
  } else if (isa(op)) {
    return xla::HloOpcode::kRecv;
  } else if (isa(op)) {
    return xla::HloOpcode::kReplicaId;
  } else if (isa(op)) {
    return xla::HloOpcode::kAfterAll;
  } else if (isa(op)) {
    return xla::HloOpcode::kAllReduce;
  } else if (isa(op)) {
    return xla::HloOpcode::kAllToAll;
  } else if (isa(op)) {
    return xla::HloOpcode::kTuple;
  } else if (isa(
                 op)) {
    return xla::HloOpcode::kBatchNormGrad;
  } else if (isa(op)) {
    return xla::HloOpcode::kBatchNormInference;
  } else if (isa(op)) {
    return xla::HloOpcode::kBatchNormTraining;
  } else if (isa(
                 op)) {
    return xla::HloOpcode::kBitcastConvert;
  } else if (isa(op)) {
    return xla::HloOpcode::kBroadcast;
  } else if (isa(op)) {
    return xla::HloOpcode::kCholesky;
  } else if (isa(op)) {
    return xla::HloOpcode::kClamp;
  } else if (isa(op)) {
    return xla::HloOpcode::kConcatenate;
  } else if (isa(op)) {
    return xla::HloOpcode::kConvolution;
  } else if (isa(op)) {
    return xla::HloOpcode::kSort;
  } else if (isa(op)) {
    return xla::HloOpcode::kRngBitGenerator;
  } else if (isa(op)) {
    return xla::HloOpcode::kFusion;
  } else if (isa(op)) {
    return xla::HloOpcode::kBitcast;
  } else if (isa(op)) {
    return xla::HloOpcode::kAbs;
  } else if (isa(op)) {
    return xla::HloOpcode::kCbrt;
  } else if (isa(op)) {
    return xla::HloOpcode::kCeil;
  } else if (isa(op)) {
    return xla::HloOpcode::kClz;
  } else if (isa(op)) {
    return xla::HloOpcode::kCos;
  } else if (isa(op)) {
    return xla::HloOpcode::kExp;
  } else if (isa(op)) {
    return xla::HloOpcode::kExpm1;
  } else if (isa(op)) {
    return xla::HloOpcode::kFloor;
  } else if (isa(op)) {
    return xla::HloOpcode::kImag;
  } else if (isa(op)) {
    return xla::HloOpcode::kIsFinite;
  } else if (isa(op)) {
    return xla::HloOpcode::kLog;
  } else if (isa(op)) {
    return xla::HloOpcode::kLog1p;
  } else if (isa(op)) {
    return xla::HloOpcode::kLogistic;
  } else if (isa(op)) {
    return xla::HloOpcode::kNot;
  } else if (isa(op)) {
    return xla::HloOpcode::kNegate;
  } else if (isa(
                 op)) {
    return xla::HloOpcode::kPopulationCount;
  } else if (isa(op)) {
    return xla::HloOpcode::kReal;
  } else if (isa(op)) {
    return xla::HloOpcode::kRoundNearestAfz;
  } else if (isa(op)) {
    return xla::HloOpcode::kRsqrt;
  } else if (isa(op)) {
    return xla::HloOpcode::kSign;
  } else if (isa(op)) {
    return xla::HloOpcode::kSin;
  } else if (isa(op)) {
    return xla::HloOpcode::kSqrt;
  } else if (isa(op)) {
    return xla::HloOpcode::kTanh;
  } else if (isa(op)) {
    return xla::HloOpcode::kComplex;
  } else if (isa(op)) {
    return xla::HloOpcode::kAnd;
  } else if (isa(op)) {
    return xla::HloOpcode::kOr;
  } else if (isa(op)) {
    return xla::HloOpcode::kWhile;
  } else if (isa(op)) {
    return xla::HloOpcode::kReduce;
  } else if (isa(op)) {
    return xla::HloOpcode::kGetTupleElement;
  } else if (isa(op)) {
    return xla::HloOpcode::kCompare;
  } else if (isa(op)) {
    return xla::HloOpcode::kSlice;
  } else if (isa(op)) {
    return xla::HloOpcode::kDynamicSlice;
  } else if (isa(op)) {
    return xla::HloOpcode::kDynamicUpdateSlice;
  } else if (isa(op)) {
    return xla::HloOpcode::kCollectivePermute;
  } else if (isa(op)) {
    return xla::HloOpcode::kCopy;
  } else if (isa(op)) {
    return xla::HloOpcode::kCustomCall;
  } else if (isa(op)) {
    return xla::HloOpcode::kDot;
  } else if (isa(op)) {
    return xla::HloOpcode::kFft;
  } else if (isa(op)) {
    return xla::HloOpcode::kGather;
  } else if (isa(op)) {
    return xla::HloOpcode::kGetDimensionSize;
  } else if (isa(op)) {
    return xla::HloOpcode::kMap;
  } else if (isa(op)) {
    return xla::HloOpcode::kReshape;
  } else if (isa(op)) {
    return xla::HloOpcode::kDynamicReshape;
  } else if (isa(op)) {
    return xla::HloOpcode::kScatter;
  } else if (isa(op)) {
    return xla::HloOpcode::kSelect;
  } else if (isa(op)) {
    return xla::HloOpcode::kSelectAndScatter;
  } else if (isa(op)) {
    return xla::HloOpcode::kSetDimensionSize;
  } else if (isa(op)) {
    return xla::HloOpcode::kReverse;
  } else if (isa(op)) {
    return xla::HloOpcode::kPad;
  } else if (isa(op)) {
    return xla::HloOpcode::kTrace;
  } else if (isa(op)) {
    return xla::HloOpcode::kTranspose;
  } else if (isa(
                 op)) {
    return xla::HloOpcode::kTriangularSolve;
  } else if (isa(op)) {
    return xla::HloOpcode::kReduceWindow;
  } else if (isa(
                 op)) {
    return xla::HloOpcode::kReducePrecision;
  } else if (isa(op)) {
    // Second branch that maps to kDot.
    return xla::HloOpcode::kDot;
  } else if (isa(
                 op)) {
    // Second branch that maps to kBroadcast.
    return xla::HloOpcode::kBroadcast;
  } else {
    // No match: include the printed form of the op in the error message.
    std::string s;
    {
      llvm::raw_string_ostream os(s);
      op->print(os);
    }
    return tensorflow::errors::Unimplemented(
        "Unimplemented MHLO -> HloOpcode: ", s);
  }
}

}  // namespace xla