1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/tf2xla/shape_util.h"

#include <algorithm>
#include <functional>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/status.h"
24
25 namespace tensorflow {
26 namespace {
27
PopulateInfeedLayoutVector(const xla::Shape & shape,std::vector<int> * layouts)28 Status PopulateInfeedLayoutVector(const xla::Shape& shape,
29 std::vector<int>* layouts) {
30 if (shape.IsTuple()) {
31 int64 tuple_elements = xla::ShapeUtil::TupleElementCount(shape);
32 for (int64 i = 0; i < tuple_elements; ++i) {
33 const xla::Shape& subshape =
34 xla::ShapeUtil::GetTupleElementShape(shape, i);
35 TF_RETURN_IF_ERROR(PopulateInfeedLayoutVector(subshape, layouts));
36 }
37 } else if (xla::LayoutUtil::HasLayout(shape)) {
38 for (auto dim : xla::LayoutUtil::MinorToMajor(shape)) {
39 layouts->push_back(dim);
40 }
41 } else {
42 layouts->insert(layouts->end(), shape.rank(), -1);
43 }
44 return Status::OK();
45 }
46
47 // Populate the output layout unless the minor_to_major array contains all -1
48 // value, in which case the layout is considered missing and the API returns
49 // false.
MakeLayout(absl::Span<const int64> minor_to_major,xla::Layout * layout)50 xla::StatusOr<bool> MakeLayout(absl::Span<const int64> minor_to_major,
51 xla::Layout* layout) {
52 if (std::all_of(minor_to_major.begin(), minor_to_major.end(),
53 [](int64 dim) { return dim == -1; })) {
54 return false;
55 }
56 std::vector<bool> dim_present(minor_to_major.size(), false);
57 for (auto dim : minor_to_major) {
58 if (dim < 0 || dim >= minor_to_major.size()) {
59 return errors::InvalidArgument("Layout dimension out of range: dim=", dim,
60 " rank=", minor_to_major.size());
61 }
62 if (dim_present[dim]) {
63 return errors::InvalidArgument("Repeated layout dimension: dim=", dim);
64 }
65 dim_present[dim] = true;
66 }
67 *layout = xla::LayoutUtil::MakeLayout(minor_to_major);
68 return true;
69 }
70
AssignLayout(absl::Span<const int64> minor_to_major,const std::function<xla::Layout (const xla::Shape &)> & layout_func,xla::Shape * shape)71 Status AssignLayout(
72 absl::Span<const int64> minor_to_major,
73 const std::function<xla::Layout(const xla::Shape&)>& layout_func,
74 xla::Shape* shape) {
75 xla::Layout layout;
76 TF_ASSIGN_OR_RETURN(bool has_layout, MakeLayout(minor_to_major, &layout));
77 if (!has_layout && layout_func) {
78 layout = layout_func(*shape);
79 }
80 *shape->mutable_layout() = layout;
81 return Status::OK();
82 }
83
84 } // namespace
85
86 // Convert an XLA Shape into the equivalent TensorFlow shape.
XLAShapeToTensorShape(const xla::Shape & shape,TensorShape * tensor_shape)87 Status XLAShapeToTensorShape(const xla::Shape& shape,
88 TensorShape* tensor_shape) {
89 if (shape.IsTuple()) {
90 return errors::InvalidArgument("XLA shape ",
91 xla::ShapeUtil::HumanString(shape),
92 " cannot be converted to a TensorShape");
93 }
94 *tensor_shape = TensorShape();
95 for (int i = 0; i < shape.rank(); ++i) {
96 tensor_shape->AddDim(shape.dimensions(i));
97 }
98 return Status::OK();
99 }
100
101 // Convert a TensorShape into the equivalent XLA Shape proto.
TensorShapeToXLAShape(DataType dtype,const TensorShape & tensor_shape,xla::Shape * shape)102 Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
103 xla::Shape* shape) {
104 xla::PrimitiveType type;
105 TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
106 *shape = TensorShapeToXLAShape(type, tensor_shape);
107 return Status::OK();
108 }
109
TensorShapeToXLAShape(xla::PrimitiveType type,const TensorShape & tensor_shape)110 xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
111 const TensorShape& tensor_shape) {
112 int rank = tensor_shape.dims();
113 std::vector<int64> dimensions(rank);
114 std::vector<int64> layout(rank);
115 for (int d = 0; d < rank; ++d) {
116 dimensions[d] = tensor_shape.dim_size(d);
117 }
118 // XLA uses minor-to-major; Tensorflow uses major-to-minor.
119 std::iota(layout.rbegin(), layout.rend(), 0);
120
121 return xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);
122 }
123
// Returns the flattened minor-to-major layout of `shape`: tuple elements are
// concatenated depth-first in element order, and array shapes lacking a
// layout contribute one -1 per dimension.
xla::StatusOr<std::vector<int>> GetShapeLayoutVector(const xla::Shape& shape) {
  std::vector<int> flattened_layout;
  TF_RETURN_IF_ERROR(PopulateInfeedLayoutVector(shape, &flattened_layout));
  return flattened_layout;
}
129
GetShapeWithLayout(const xla::Shape & input_shape,absl::Span<const int64> minor_to_major,const std::function<xla::Layout (const xla::Shape &)> & layout_func,xla::Shape * output_shape)130 Status GetShapeWithLayout(
131 const xla::Shape& input_shape, absl::Span<const int64> minor_to_major,
132 const std::function<xla::Layout(const xla::Shape&)>& layout_func,
133 xla::Shape* output_shape) {
134 if (input_shape.IsTuple()) {
135 int64 tuple_elements = xla::ShapeUtil::TupleElementCount(input_shape);
136 std::vector<xla::Shape> shapes;
137 shapes.reserve(tuple_elements);
138 size_t position = 0;
139 for (int64 i = 0; i < tuple_elements; ++i) {
140 const xla::Shape& shape =
141 xla::ShapeUtil::GetTupleElementShape(input_shape, i);
142 if (shape.IsTuple()) {
143 return errors::InvalidArgument(
144 "Nested tuples not supported: ",
145 xla::ShapeUtil::HumanString(input_shape));
146 }
147 int64 rank = shape.rank();
148 if (position + rank > minor_to_major.size()) {
149 return errors::InvalidArgument(
150 "Not enough layout attribute elements: position=", position,
151 " rank=", rank, " elements=", minor_to_major.size());
152 }
153 shapes.push_back(shape);
154 TF_RETURN_IF_ERROR(AssignLayout(
155 absl::Span<const int64>(minor_to_major).subspan(position, rank),
156 layout_func, &shapes.back()));
157 position += rank;
158
159 VLOG(4) << "Shape[" << i
160 << "] = " << xla::ShapeUtil::HumanStringWithLayout(shapes.back());
161 }
162 if (position != minor_to_major.size()) {
163 return errors::InvalidArgument(
164 "Too many elements passed in the layout attribute: position=",
165 position, " size=", minor_to_major.size());
166 }
167 *output_shape = xla::ShapeUtil::MakeTupleShape(shapes);
168 } else {
169 int64 rank = input_shape.rank();
170 if (rank != minor_to_major.size()) {
171 return errors::InvalidArgument(
172 "Wrong number of layout attribute elements: rank=", rank,
173 " elements=", minor_to_major.size());
174 }
175 *output_shape = input_shape;
176 TF_RETURN_IF_ERROR(AssignLayout(minor_to_major, layout_func, output_shape));
177
178 VLOG(4) << "Shape[] = "
179 << xla::ShapeUtil::HumanStringWithLayout(*output_shape);
180 }
181 return Status::OK();
182 }
183
184 } // namespace tensorflow
185