1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
17 
18 #include <string>
19 
20 #include "mlir/IR/AffineMap.h"  // from @llvm-project
21 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
22 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
23 #include "mlir/IR/Location.h"  // from @llvm-project
24 #include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
25 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
26 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
27 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/statusor.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/types.h"
34 
35 using mlir::IntegerType;
36 using mlir::MemRefType;
37 using mlir::RankedTensorType;
38 using mlir::VectorType;
39 using tensorflow::int64;
40 using xla::PrimitiveType;
41 using xla::ShapeUtil;
42 
43 namespace xla {
44 
TypeToPrimitiveType(mlir::Type type)45 PrimitiveType TypeToPrimitiveType(mlir::Type type) {
46   if (type.isBF16()) {
47     return PrimitiveType::BF16;
48   } else if (type.isF16()) {
49     return PrimitiveType::F16;
50   } else if (type.isF32()) {
51     return PrimitiveType::F32;
52   } else if (type.isF64()) {
53     return PrimitiveType::F64;
54   } else if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
55     mlir::Type element_ty = complex_type.getElementType();
56     if (element_ty.isF32()) {
57       return PrimitiveType::C64;
58 
59     } else if (element_ty.isF64()) {
60       return PrimitiveType::C128;
61     }
62     return PrimitiveType::PRIMITIVE_TYPE_INVALID;
63   } else if (auto integer_type = type.dyn_cast<mlir::IntegerType>()) {
64     bool is_unsigned = integer_type.isUnsigned();
65     switch (integer_type.getWidth()) {
66       case 1:
67         return PrimitiveType::PRED;
68       case 8:
69         return is_unsigned ? PrimitiveType::U8 : PrimitiveType::S8;
70       case 16:
71         return is_unsigned ? PrimitiveType::U16 : PrimitiveType::S16;
72       case 32:
73         return is_unsigned ? PrimitiveType::U32 : PrimitiveType::S32;
74       case 64:
75         return is_unsigned ? PrimitiveType::U64 : PrimitiveType::S64;
76       default:
77         return PrimitiveType::PRIMITIVE_TYPE_INVALID;
78     }
79   }
80   return PrimitiveType::PRIMITIVE_TYPE_INVALID;
81 }
82 
TypeToShape(mlir::Type type,CustomShapeRepresentationFn shape_representation_fn)83 StatusOr<Shape> TypeToShape(
84     mlir::Type type, CustomShapeRepresentationFn shape_representation_fn) {
85   tensorflow::PartialTensorShape partial_tensor_shape =
86       tensorflow::ConvertTypeToTensorShape(type);
87 
88   tensorflow::TensorShape fully_defined_tensor_shape;
89   if (!partial_tensor_shape.AsTensorShape(&fully_defined_tensor_shape)) {
90     return tensorflow::errors::InvalidArgument(
91         "XLA HLO only allows fully-defined shape");
92   }
93 
94   tensorflow::DataType dtype;
95   TF_RETURN_IF_ERROR(tensorflow::ConvertToDataType(type, &dtype));
96 
97   return shape_representation_fn(fully_defined_tensor_shape, dtype);
98 }
99 
TypeToShape(mlir::Type type)100 Shape TypeToShape(mlir::Type type) {
101   PrimitiveType ptype = TypeToPrimitiveType(type);
102   if (ptype != PrimitiveType::PRIMITIVE_TYPE_INVALID)
103     return ShapeUtil::MakeShape(ptype, {});
104 
105   if (type.isIntOrFloat()) {
106     auto* context = type.getContext();
107     mlir::emitError(mlir::UnknownLoc::get(context))
108         << "lowering should have been handled by primitive type lowering for "
109         << debugString(type);
110   } else if (auto v = type.dyn_cast<mlir::VectorType>()) {
111     llvm::SmallVector<int64, 4> span(v.getShape().begin(), v.getShape().end());
112     mlir::Type element_type = v.getElementType();
113     PrimitiveType primitive_type = TypeToPrimitiveType(element_type);
114     if (primitive_type != PrimitiveType::PRIMITIVE_TYPE_INVALID)
115       return ShapeUtil::MakeShape(primitive_type, span);
116   } else if (auto m = type.dyn_cast<mlir::MemRefType>()) {
117     llvm::SmallVector<int64, 6> span(m.getShape().begin(), m.getShape().end());
118     mlir::Type element_type = m.getElementType();
119     // Treat a memref of a vector as if it was a memref of primitive type with
120     // the vector dimensions at the end.
121     if (auto v = element_type.dyn_cast<mlir::VectorType>()) {
122       element_type = v.getElementType();
123       span.insert(span.end(), v.getShape().begin(), v.getShape().end());
124     }
125     PrimitiveType primitive_type = TypeToPrimitiveType(element_type);
126     if (primitive_type == PrimitiveType::PRIMITIVE_TYPE_INVALID) return {};
127     // For the primitive type case, the shape of the memref is similar to the
128     // vector type case (i.e., it is, modulo the layout, the same dimensions
129     // and primitive type).
130     if (m.getAffineMaps().empty())
131       return ShapeUtil::MakeShape(primitive_type, span);
132 
133     if (m.getAffineMaps().size() == 1) {
134       llvm::SmallVector<int64_t, 4> strides;
135       int64_t offset;
136       if (failed(mlir::getStridesAndOffset(m, strides, offset))) return {};
137 
138       llvm::SmallVector<std::pair<int64_t, int>, 4> strides_with_indices;
139       for (const auto& e : llvm::enumerate(strides)) {
140         strides_with_indices.push_back({e.value(), e.index()});
141       }
142       std::stable_sort(strides_with_indices.begin(),
143                        strides_with_indices.end());
144 
145       llvm::SmallVector<int64, 4> minor_to_major;
146       int64_t stride = 1;
147       for (const auto& pr : strides_with_indices) {
148         minor_to_major.push_back(pr.second);
149 
150         // Either the affine map is not perfectly strided, or the dimensions
151         // recovered from strides don't match the actual dimensions in shapes.
152         if (stride != pr.first && m.getShape()[pr.second] != 1) return {};
153 
154         stride *= m.getShape()[pr.second];
155       }
156 
157       llvm::SmallVector<int64, 4> dimensions(m.getShape().begin(),
158                                              m.getShape().end());
159       return ::xla::ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions,
160                                                    minor_to_major);
161     }
162   } else if (auto t = type.dyn_cast<mlir::RankedTensorType>()) {
163     // TODO(jpienaar): This is only handling the base case with primitive
164     // element type.
165     llvm::SmallVector<int64, 4> span(t.getShape().begin(), t.getShape().end());
166     // Only fully static shapes are supported.
167     // TODO(b/115638799): Update once xla::Shape can support dynamic shapes.
168     if (std::find(t.getShape().begin(), t.getShape().end(), -1) !=
169         t.getShape().end())
170       return {};
171     mlir::Type element_type = t.getElementType();
172     PrimitiveType primitive_type = TypeToPrimitiveType(element_type);
173     // Only primitive element type supported.
174     if (primitive_type != PrimitiveType::PRIMITIVE_TYPE_INVALID)
175       return ShapeUtil::MakeShape(primitive_type, span);
176   } else if (auto tuple_type = type.dyn_cast<mlir::TupleType>()) {
177     llvm::SmallVector<Shape, 4> shapes;
178     shapes.reserve(tuple_type.size());
179     for (mlir::Type sub_type : tuple_type.getTypes()) {
180       shapes.push_back(TypeToShape(sub_type));
181     }
182     return ShapeUtil::MakeTupleShape(shapes);
183 
184   } else if (type.isa<mlir::mhlo::TokenType>()) {
185     return ShapeUtil::MakeTokenShape();
186   }
187 
188   // Return empty XLA shape to signify error. No MLIR Type maps to a empty
189   // Shape.
190   return {};
191 }
192 
193 }  // namespace xla
194