1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file defines helpers useful when creating or manipulating lhlo/hlo.
17 
18 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
19 
20 #include "mlir/IR/AffineMap.h"  // from @llvm-project
21 #include "mlir/IR/Attributes.h"  // from @llvm-project
22 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
23 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
24 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/core/platform/bfloat16.h"
27 #include "tensorflow/core/platform/logging.h"
28 
29 namespace xla {
30 namespace {
31 
32 using mlir::AffineMap;
33 using mlir::Builder;
34 using mlir::DenseElementsAttr;
35 using mlir::ShapedType;
36 using xla::LiteralBase;
37 using xla::StatusOr;
38 
39 template <typename CppType>
CreateDenseAttrFromLiteral(const ShapedType & type,const LiteralBase & literal)40 ::mlir::DenseElementsAttr CreateDenseAttrFromLiteral(
41     const ShapedType& type, const LiteralBase& literal) {
42   auto data_span = literal.data<CppType>();
43   return ::mlir::DenseElementsAttr::get(
44       type, llvm::makeArrayRef(data_span.data(), data_span.size()));
45 }
46 
GetPermutationIfAvailable(const Shape & shape,mlir::Builder builder)47 StatusOr<llvm::SmallVector<AffineMap, 1>> GetPermutationIfAvailable(
48     const Shape& shape, mlir::Builder builder) {
49   if (!shape.has_layout() ||
50       LayoutUtil::IsMonotonicWithDim0Major(shape.layout())) {
51     return llvm::SmallVector<AffineMap, 1>{};
52   }
53   if (!shape.is_static()) {
54     return tensorflow::errors::Internal(
55         "Permutations for dynamic shapes are not yet supported");
56   }
57   int64_t accumulated_stride = 1;
58   llvm::SmallVector<int64_t, 4> strides(shape.rank(), 1);
59   for (int64 dim : LayoutUtil::MinorToMajor(shape)) {
60     strides[dim] = accumulated_stride;
61     accumulated_stride *= shape.dimensions(dim);
62   }
63   if (accumulated_stride == 0) {
64     return llvm::SmallVector<AffineMap, 1>{};
65   }
66   return llvm::SmallVector<AffineMap, 1>{
67       makeStridedLinearLayoutMap(strides, /*offset=*/0, builder.getContext())};
68 }
69 
70 template <typename T>
CopyDenseElementsBy(mlir::DenseElementsAttr data,std::vector<uint8> * output)71 void CopyDenseElementsBy(mlir::DenseElementsAttr data,
72                          std::vector<uint8>* output) {
73   output->resize(data.getNumElements() * sizeof(T));
74   int i = 0;
75   for (T element : data.getValues<T>()) {
76     std::memcpy(&(*output)[i], &element, sizeof(T));
77     i += sizeof(T);
78   }
79 }
80 
81 }  // namespace
82 
ConvertTensorShapeToMemRefType(const Shape & shape,mlir::Builder builder)83 StatusOr<mlir::MemRefType> ConvertTensorShapeToMemRefType(
84     const Shape& shape, mlir::Builder builder) {
85   auto element_type_or =
86       ConvertPrimitiveTypeToMLIRType(shape.element_type(), builder);
87   if (!element_type_or.ok()) return element_type_or.status();
88 
89   using mlir::MemRefType;
90   auto dimensions = shape.dimensions();
91   llvm::SmallVector<int64_t, 4> array(dimensions.begin(), dimensions.end());
92   auto permutation_or = GetPermutationIfAvailable(shape, builder);
93   if (!permutation_or.ok()) return permutation_or.status();
94   return MemRefType::get(array, element_type_or.ValueOrDie(),
95                          permutation_or.ValueOrDie());
96 }
97 
CreateDenseElementsAttrFromLiteral(const LiteralBase & literal,Builder builder)98 StatusOr<mlir::DenseElementsAttr> CreateDenseElementsAttrFromLiteral(
99     const LiteralBase& literal, Builder builder) {
100   TF_ASSIGN_OR_RETURN(auto type,
101                       ConvertTensorShapeToType<mlir::RankedTensorType>(
102                           literal.shape(), builder));
103 
104   // TODO(hinsu): Support remaining XLA primitive types.
105   auto element_type = literal.shape().element_type();
106   switch (element_type) {
107     case PrimitiveType::PRED:
108       return CreateDenseAttrFromLiteral<bool>(type, literal);
109     case PrimitiveType::F16:
110       return CreateDenseAttrFromLiteral<half>(type, literal);
111     case PrimitiveType::BF16:
112       return CreateDenseAttrFromLiteral<bfloat16>(type, literal);
113     case PrimitiveType::F32:
114       return CreateDenseAttrFromLiteral<float>(type, literal);
115     case PrimitiveType::F64:
116       return CreateDenseAttrFromLiteral<double>(type, literal);
117     case PrimitiveType::S8:
118       return CreateDenseAttrFromLiteral<int8>(type, literal);
119     case PrimitiveType::S16:
120       return CreateDenseAttrFromLiteral<int16>(type, literal);
121     case PrimitiveType::S32:
122       return CreateDenseAttrFromLiteral<int32>(type, literal);
123     case PrimitiveType::S64:
124       return CreateDenseAttrFromLiteral<int64>(type, literal);
125     case PrimitiveType::U8:
126       return CreateDenseAttrFromLiteral<uint8>(type, literal);
127     case PrimitiveType::U16:
128       return CreateDenseAttrFromLiteral<uint16>(type, literal);
129     case PrimitiveType::U32:
130       return CreateDenseAttrFromLiteral<uint32>(type, literal);
131     case PrimitiveType::U64:
132       return CreateDenseAttrFromLiteral<uint64>(type, literal);
133     case PrimitiveType::C64:
134       return CreateDenseAttrFromLiteral<complex64>(type, literal);
135     case PrimitiveType::C128:
136       return CreateDenseAttrFromLiteral<complex128>(type, literal);
137     default:
138       return tensorflow::errors::Internal(
139           absl::StrCat("Unsupported type: ", PrimitiveType_Name(element_type)));
140   }
141 }
142 
CopyDenseElementsDataToXlaFormat(mlir::DenseElementsAttr data,std::vector<uint8> * output)143 Status CopyDenseElementsDataToXlaFormat(mlir::DenseElementsAttr data,
144                                         std::vector<uint8>* output) {
145   mlir::Type element_type = data.getType().getElementType();
146 
147   // TODO(hinsu): Support remaining XLA primitive types.
148   if (element_type.isInteger(1)) {
149     CopyDenseElementsBy<bool>(data, output);
150     return Status::OK();
151   }
152   if (element_type.isInteger(8)) {
153     CopyDenseElementsBy<uint8>(data, output);
154     return Status::OK();
155   }
156   if (element_type.isInteger(16)) {
157     CopyDenseElementsBy<uint16>(data, output);
158     return Status::OK();
159   }
160   if (element_type.isInteger(32)) {
161     CopyDenseElementsBy<uint32>(data, output);
162     return Status::OK();
163   }
164   if (element_type.isInteger(64)) {
165     CopyDenseElementsBy<uint64>(data, output);
166     return Status::OK();
167   }
168   if (element_type.isBF16()) {
169     CopyDenseElementsBy<bfloat16>(data, output);
170     return Status::OK();
171   }
172   if (element_type.isF16()) {
173     CopyDenseElementsBy<half>(data, output);
174     return Status::OK();
175   }
176   if (element_type.isF32()) {
177     CopyDenseElementsBy<float>(data, output);
178     return Status::OK();
179   }
180   if (element_type.isF64()) {
181     CopyDenseElementsBy<double>(data, output);
182     return Status::OK();
183   }
184   if (auto complex_type = element_type.dyn_cast<mlir::ComplexType>()) {
185     if (complex_type.getElementType().isF32()) {
186       CopyDenseElementsBy<complex64>(data, output);
187       return Status::OK();
188     }
189     if (complex_type.getElementType().isF64()) {
190       CopyDenseElementsBy<complex128>(data, output);
191       return Status::OK();
192     }
193   }
194   return tensorflow::errors::Internal(
195       "Unsupported type in CopyDenseElementsDataToXlaFormat");
196 }
197 
GetElementTypeBytes(mlir::Type type)198 StatusOr<int> GetElementTypeBytes(mlir::Type type) {
199   if (type.isInteger(1)) {
200     return 1;
201   }
202   if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
203     TF_ASSIGN_OR_RETURN(int bytes,
204                         GetElementTypeBytes(complex_type.getElementType()));
205     return bytes * 2;
206   }
207   int width = type.getIntOrFloatBitWidth();
208   TF_RET_CHECK(width % 8 == 0);
209   return width / 8;
210 }
211 
CreateDenseIntElementsAttrFromVector(const llvm::ArrayRef<int64> vector,mlir::Builder builder,llvm::ArrayRef<int64_t> shape)212 mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector(
213     const llvm::ArrayRef<int64> vector, mlir::Builder builder,
214     llvm::ArrayRef<int64_t> shape) {
215   return mlir::DenseIntElementsAttr::get(
216       mlir::RankedTensorType::get(shape.empty() ? vector.size() : shape,
217                                   builder.getIntegerType(64)),
218       vector);
219 }
220 
ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type,mlir::Builder builder)221 StatusOr<mlir::Type> ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type,
222                                                     mlir::Builder builder) {
223   switch (element_type) {
224     case PrimitiveType::PRED:
225       return builder.getI1Type();
226     case PrimitiveType::F16:
227       return builder.getF16Type();
228     case PrimitiveType::BF16:
229       return builder.getBF16Type();
230     case PrimitiveType::F32:
231       return builder.getF32Type();
232     case PrimitiveType::F64:
233       return builder.getF64Type();
234     case PrimitiveType::S8:
235       return builder.getIntegerType(8);
236     case PrimitiveType::S16:
237       return builder.getIntegerType(16);
238     case PrimitiveType::S32:
239       return builder.getIntegerType(32);
240     case PrimitiveType::S64:
241       return builder.getIntegerType(64);
242     case PrimitiveType::U8:
243       return builder.getIntegerType(8, /*isSigned=*/false);
244     case PrimitiveType::U16:
245       return builder.getIntegerType(16, /*isSigned=*/false);
246     case PrimitiveType::U32:
247       return builder.getIntegerType(32, /*isSigned=*/false);
248     case PrimitiveType::U64:
249       return builder.getIntegerType(64, /*isSigned=*/false);
250     case PrimitiveType::C64:
251       return mlir::ComplexType::get(builder.getF32Type());
252     case PrimitiveType::C128:
253       return mlir::ComplexType::get(builder.getF64Type());
254     // TODO(b/130356985): Support unsigned primitive types.
255     default:
256       return tensorflow::errors::Internal(
257           absl::StrCat("Unsupported type: ", PrimitiveType_Name(element_type)));
258   }
259 }
260 
CreateGatherDimensionNumbers(const GatherDimensionNumbers & input,mlir::Builder builder)261 mlir::mhlo::GatherDimensionNumbers CreateGatherDimensionNumbers(
262     const GatherDimensionNumbers& input, mlir::Builder builder) {
263   auto offset_dims = CreateDenseIntElementsAttrFromVector(
264       llvm::SmallVector<int64, 4>{input.offset_dims().begin(),
265                                   input.offset_dims().end()},
266       builder);
267   auto collapsed_slice_dims = CreateDenseIntElementsAttrFromVector(
268       llvm::SmallVector<int64, 4>{input.collapsed_slice_dims().begin(),
269                                   input.collapsed_slice_dims().end()},
270       builder);
271   auto start_index_map = CreateDenseIntElementsAttrFromVector(
272       llvm::SmallVector<int64, 4>{input.start_index_map().begin(),
273                                   input.start_index_map().end()},
274       builder);
275 
276   mlir::IntegerAttr index_vector_dim =
277       builder.getI64IntegerAttr(input.index_vector_dim());
278 
279   return mlir::mhlo::GatherDimensionNumbers::get(
280       offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim,
281       builder.getContext());
282 }
283 
MhloToHloOpcode(mlir::Operation * op)284 StatusOr<::xla::HloOpcode> MhloToHloOpcode(mlir::Operation* op) {
285   using mlir::isa;
286 
287   if (isa<mlir::mhlo::ConstOp, mlir::lmhlo::ConstOp>(op)) {
288     return xla::HloOpcode::kConstant;
289   } else if (isa<mlir::mhlo::IotaOp, mlir::lmhlo::IotaOp>(op)) {
290     return xla::HloOpcode::kIota;
291   } else if (isa<mlir::mhlo::ConvertOp, mlir::lmhlo::ConvertOp>(op)) {
292     return xla::HloOpcode::kConvert;
293   } else if (isa<mlir::mhlo::AddOp, mlir::lmhlo::AddOp>(op)) {
294     return xla::HloOpcode::kAdd;
295   } else if (isa<mlir::mhlo::Atan2Op, mlir::lmhlo::Atan2Op>(op)) {
296     return xla::HloOpcode::kAtan2;
297   } else if (isa<mlir::mhlo::DivOp, mlir::lmhlo::DivOp>(op)) {
298     return xla::HloOpcode::kDivide;
299   } else if (isa<mlir::mhlo::MaxOp, mlir::lmhlo::MaxOp>(op)) {
300     return xla::HloOpcode::kMaximum;
301   } else if (isa<mlir::mhlo::MinOp, mlir::lmhlo::MinOp>(op)) {
302     return xla::HloOpcode::kMinimum;
303   } else if (isa<mlir::mhlo::MulOp, mlir::lmhlo::MulOp>(op)) {
304     return xla::HloOpcode::kMultiply;
305   } else if (isa<mlir::mhlo::PowOp, mlir::lmhlo::PowOp>(op)) {
306     return xla::HloOpcode::kPower;
307   } else if (isa<mlir::mhlo::RemOp, mlir::lmhlo::RemOp>(op)) {
308     return xla::HloOpcode::kRemainder;
309   } else if (isa<mlir::mhlo::ShiftLeftOp, mlir::lmhlo::ShiftLeftOp>(op)) {
310     return xla::HloOpcode::kShiftLeft;
311   } else if (isa<mlir::mhlo::ShiftRightArithmeticOp,
312                  mlir::lmhlo::ShiftRightArithmeticOp>(op)) {
313     return xla::HloOpcode::kShiftRightArithmetic;
314   } else if (isa<mlir::mhlo::ShiftRightLogicalOp,
315                  mlir::lmhlo::ShiftRightLogicalOp>(op)) {
316     return xla::HloOpcode::kShiftRightLogical;
317   } else if (isa<mlir::mhlo::SubOp, mlir::lmhlo::SubOp>(op)) {
318     return xla::HloOpcode::kSubtract;
319   } else if (isa<mlir::mhlo::XorOp, mlir::lmhlo::XorOp>(op)) {
320     return xla::HloOpcode::kXor;
321   } else if (isa<mlir::mhlo::InfeedOp, mlir::lmhlo::InfeedOp>(op)) {
322     return xla::HloOpcode::kInfeed;
323   } else if (isa<mlir::mhlo::OutfeedOp, mlir::lmhlo::OutfeedOp>(op)) {
324     return xla::HloOpcode::kOutfeed;
325   } else if (isa<mlir::mhlo::SendOp>(op)) {
326     return xla::HloOpcode::kSend;
327   } else if (isa<mlir::mhlo::RecvOp>(op)) {
328     return xla::HloOpcode::kRecv;
329   } else if (isa<mlir::mhlo::ReplicaIdOp, mlir::lmhlo::ReplicaIdOp>(op)) {
330     return xla::HloOpcode::kReplicaId;
331   } else if (isa<mlir::mhlo::AfterAllOp>(op)) {
332     return xla::HloOpcode::kAfterAll;
333   } else if (isa<mlir::mhlo::AllReduceOp, mlir::lmhlo::AllReduceOp>(op)) {
334     return xla::HloOpcode::kAllReduce;
335   } else if (isa<mlir::mhlo::AllToAllOp>(op)) {
336     return xla::HloOpcode::kAllToAll;
337   } else if (isa<mlir::mhlo::TupleOp>(op)) {
338     return xla::HloOpcode::kTuple;
339   } else if (isa<mlir::mhlo::BatchNormGradOp, mlir::lmhlo::BatchNormGradOp>(
340                  op)) {
341     return xla::HloOpcode::kBatchNormGrad;
342   } else if (isa<mlir::mhlo::BatchNormInferenceOp,
343                  mlir::lmhlo::BatchNormInferenceOp>(op)) {
344     return xla::HloOpcode::kBatchNormInference;
345   } else if (isa<mlir::mhlo::BatchNormTrainingOp,
346                  mlir::lmhlo::BatchNormTrainingOp>(op)) {
347     return xla::HloOpcode::kBatchNormTraining;
348   } else if (isa<mlir::mhlo::BitcastConvertOp, mlir::lmhlo::BitcastConvertOp>(
349                  op)) {
350     return xla::HloOpcode::kBitcastConvert;
351   } else if (isa<mlir::mhlo::BroadcastOp, mlir::lmhlo::BroadcastOp>(op)) {
352     return xla::HloOpcode::kBroadcast;
353   } else if (isa<mlir::mhlo::CholeskyOp, mlir::lmhlo::CholeskyOp>(op)) {
354     return xla::HloOpcode::kCholesky;
355   } else if (isa<mlir::mhlo::ClampOp, mlir::lmhlo::ClampOp>(op)) {
356     return xla::HloOpcode::kClamp;
357   } else if (isa<mlir::mhlo::ConcatenateOp, mlir::lmhlo::ConcatenateOp>(op)) {
358     return xla::HloOpcode::kConcatenate;
359   } else if (isa<mlir::mhlo::ConvOp, mlir::lmhlo::ConvOp>(op)) {
360     return xla::HloOpcode::kConvolution;
361   } else if (isa<mlir::mhlo::SortOp, mlir::lmhlo::SortOp>(op)) {
362     return xla::HloOpcode::kSort;
363   } else if (isa<mlir::mhlo::RngBitGeneratorOp>(op)) {
364     return xla::HloOpcode::kRngBitGenerator;
365   } else if (isa<mlir::mhlo::FusionOp, mlir::lmhlo::FusionOp>(op)) {
366     return xla::HloOpcode::kFusion;
367   } else if (isa<mlir::mhlo::BitcastOp>(op)) {
368     return xla::HloOpcode::kBitcast;
369   } else if (isa<mlir::mhlo::AbsOp, mlir::lmhlo::AbsOp>(op)) {
370     return xla::HloOpcode::kAbs;
371   } else if (isa<mlir::mhlo::CbrtOp, mlir::lmhlo::CbrtOp>(op)) {
372     return xla::HloOpcode::kCbrt;
373   } else if (isa<mlir::mhlo::CeilOp, mlir::lmhlo::CeilOp>(op)) {
374     return xla::HloOpcode::kCeil;
375   } else if (isa<mlir::mhlo::ClzOp, mlir::lmhlo::ClzOp>(op)) {
376     return xla::HloOpcode::kClz;
377   } else if (isa<mlir::mhlo::CosOp, mlir::lmhlo::CosOp>(op)) {
378     return xla::HloOpcode::kCos;
379   } else if (isa<mlir::mhlo::ExpOp, mlir::lmhlo::ExpOp>(op)) {
380     return xla::HloOpcode::kExp;
381   } else if (isa<mlir::mhlo::Expm1Op, mlir::lmhlo::Expm1Op>(op)) {
382     return xla::HloOpcode::kExpm1;
383   } else if (isa<mlir::mhlo::FloorOp, mlir::lmhlo::FloorOp>(op)) {
384     return xla::HloOpcode::kFloor;
385   } else if (isa<mlir::mhlo::ImagOp, mlir::lmhlo::ImagOp>(op)) {
386     return xla::HloOpcode::kImag;
387   } else if (isa<mlir::mhlo::IsFiniteOp, mlir::lmhlo::IsFiniteOp>(op)) {
388     return xla::HloOpcode::kIsFinite;
389   } else if (isa<mlir::mhlo::LogOp, mlir::lmhlo::LogOp>(op)) {
390     return xla::HloOpcode::kLog;
391   } else if (isa<mlir::mhlo::Log1pOp, mlir::lmhlo::Log1pOp>(op)) {
392     return xla::HloOpcode::kLog1p;
393   } else if (isa<mlir::mhlo::LogisticOp>(op)) {
394     return xla::HloOpcode::kLogistic;
395   } else if (isa<mlir::mhlo::NotOp, mlir::lmhlo::NotOp>(op)) {
396     return xla::HloOpcode::kNot;
397   } else if (isa<mlir::mhlo::NegOp, mlir::lmhlo::NegOp>(op)) {
398     return xla::HloOpcode::kNegate;
399   } else if (isa<mlir::mhlo::PopulationCountOp, mlir::lmhlo::PopulationCountOp>(
400                  op)) {
401     return xla::HloOpcode::kPopulationCount;
402   } else if (isa<mlir::mhlo::RealOp, mlir::lmhlo::RealOp>(op)) {
403     return xla::HloOpcode::kReal;
404   } else if (isa<mlir::mhlo::RoundOp, mlir::lmhlo::RoundOp>(op)) {
405     return xla::HloOpcode::kRoundNearestAfz;
406   } else if (isa<mlir::mhlo::RsqrtOp, mlir::lmhlo::RsqrtOp>(op)) {
407     return xla::HloOpcode::kRsqrt;
408   } else if (isa<mlir::mhlo::SignOp, mlir::lmhlo::SignOp>(op)) {
409     return xla::HloOpcode::kSign;
410   } else if (isa<mlir::mhlo::SinOp, mlir::lmhlo::SinOp>(op)) {
411     return xla::HloOpcode::kSin;
412   } else if (isa<mlir::mhlo::SqrtOp, mlir::lmhlo::SqrtOp>(op)) {
413     return xla::HloOpcode::kSqrt;
414   } else if (isa<mlir::mhlo::TanhOp, mlir::lmhlo::TanhOp>(op)) {
415     return xla::HloOpcode::kTanh;
416   } else if (isa<mlir::mhlo::ComplexOp, mlir::lmhlo::ComplexOp>(op)) {
417     return xla::HloOpcode::kComplex;
418   } else if (isa<mlir::mhlo::AndOp, mlir::lmhlo::AndOp>(op)) {
419     return xla::HloOpcode::kAnd;
420   } else if (isa<mlir::mhlo::OrOp, mlir::lmhlo::OrOp>(op)) {
421     return xla::HloOpcode::kOr;
422   } else if (isa<mlir::mhlo::WhileOp, mlir::lmhlo::WhileOp>(op)) {
423     return xla::HloOpcode::kWhile;
424   } else if (isa<mlir::mhlo::ReduceOp, mlir::lmhlo::ReduceOp>(op)) {
425     return xla::HloOpcode::kReduce;
426   } else if (isa<mlir::mhlo::GetTupleElementOp>(op)) {
427     return xla::HloOpcode::kGetTupleElement;
428   } else if (isa<mlir::mhlo::CompareOp, mlir::lmhlo::CompareOp>(op)) {
429     return xla::HloOpcode::kCompare;
430   } else if (isa<mlir::mhlo::SliceOp, mlir::lmhlo::SliceOp>(op)) {
431     return xla::HloOpcode::kSlice;
432   } else if (isa<mlir::mhlo::DynamicSliceOp, mlir::lmhlo::DynamicSliceOp>(op)) {
433     return xla::HloOpcode::kDynamicSlice;
434   } else if (isa<mlir::mhlo::DynamicUpdateSliceOp,
435                  mlir::lmhlo::DynamicUpdateSliceOp>(op)) {
436     return xla::HloOpcode::kDynamicUpdateSlice;
437   } else if (isa<mlir::mhlo::CollectivePermuteOp,
438                  mlir::lmhlo::CollectivePermuteOp>(op)) {
439     return xla::HloOpcode::kCollectivePermute;
440   } else if (isa<mlir::mhlo::CopyOp, mlir::lmhlo::CopyOp>(op)) {
441     return xla::HloOpcode::kCopy;
442   } else if (isa<mlir::mhlo::CustomCallOp, mlir::lmhlo::CustomCallOp>(op)) {
443     return xla::HloOpcode::kCustomCall;
444   } else if (isa<mlir::mhlo::DotOp, mlir::lmhlo::DotOp>(op)) {
445     return xla::HloOpcode::kDot;
446   } else if (isa<mlir::mhlo::FftOp, mlir::lmhlo::FftOp>(op)) {
447     return xla::HloOpcode::kFft;
448   } else if (isa<mlir::mhlo::GatherOp, mlir::lmhlo::GatherOp>(op)) {
449     return xla::HloOpcode::kGather;
450   } else if (isa<mlir::mhlo::GetDimensionSizeOp>(op)) {
451     return xla::HloOpcode::kGetDimensionSize;
452   } else if (isa<mlir::mhlo::MapOp, mlir::lmhlo::MapOp>(op)) {
453     return xla::HloOpcode::kMap;
454   } else if (isa<mlir::mhlo::ReshapeOp, mlir::lmhlo::ReshapeOp>(op)) {
455     return xla::HloOpcode::kReshape;
456   } else if (isa<mlir::mhlo::DynamicReshapeOp>(op)) {
457     return xla::HloOpcode::kDynamicReshape;
458   } else if (isa<mlir::mhlo::ScatterOp, mlir::lmhlo::ScatterOp>(op)) {
459     return xla::HloOpcode::kScatter;
460   } else if (isa<mlir::mhlo::SelectOp, mlir::lmhlo::SelectOp>(op)) {
461     return xla::HloOpcode::kSelect;
462   } else if (isa<mlir::mhlo::SelectAndScatterOp,
463                  mlir::lmhlo::SelectAndScatterOp>(op)) {
464     return xla::HloOpcode::kSelectAndScatter;
465   } else if (isa<mlir::mhlo::SetDimensionSizeOp>(op)) {
466     return xla::HloOpcode::kSetDimensionSize;
467   } else if (isa<mlir::mhlo::ReverseOp, mlir::lmhlo::ReverseOp>(op)) {
468     return xla::HloOpcode::kReverse;
469   } else if (isa<mlir::mhlo::PadOp, mlir::lmhlo::PadOp>(op)) {
470     return xla::HloOpcode::kPad;
471   } else if (isa<mlir::mhlo::TraceOp>(op)) {
472     return xla::HloOpcode::kTrace;
473   } else if (isa<mlir::mhlo::TransposeOp, mlir::lmhlo::TransposeOp>(op)) {
474     return xla::HloOpcode::kTranspose;
475   } else if (isa<mlir::mhlo::TriangularSolveOp, mlir::lmhlo::TriangularSolveOp>(
476                  op)) {
477     return xla::HloOpcode::kTriangularSolve;
478   } else if (isa<mlir::mhlo::ReduceWindowOp, mlir::lmhlo::ReduceWindowOp>(op)) {
479     return xla::HloOpcode::kReduceWindow;
480   } else if (isa<mlir::mhlo::ReducePrecisionOp, mlir::lmhlo::ReducePrecisionOp>(
481                  op)) {
482     return xla::HloOpcode::kReducePrecision;
483   } else if (isa<mlir::mhlo::DotGeneralOp>(op)) {
484     return xla::HloOpcode::kDot;
485   } else if (isa<mlir::mhlo::BroadcastInDimOp, mlir::lmhlo::BroadcastInDimOp>(
486                  op)) {
487     return xla::HloOpcode::kBroadcast;
488   } else {
489     std::string s;
490     {
491       llvm::raw_string_ostream os(s);
492       op->print(os);
493     }
494     return tensorflow::errors::Unimplemented(
495         "Unimplemented MHLO -> HloOpcode: ", s);
496   }
497 }
498 
499 }  // namespace xla
500