1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_H_
17 #define TENSORFLOW_COMPILER_XLA_SHAPE_H_
18 
19 #include <string>
20 #include <vector>
21 
22 #include "absl/types/optional.h"
23 #include "tensorflow/compiler/xla/layout.h"
24 #include "tensorflow/compiler/xla/primitive_util.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27 #include "tensorflow/core/platform/types.h"
28 
29 namespace xla {
30 
31 // A shape describes the number of dimensions in a array, the bounds of each
32 // dimension, and the primitive component type. For tuples, shape describes the
33 // structure (number of elements and nesting).
34 class Shape {
35  public:
36   Shape() = default;
37 
38   // Construct a shape from a ShapeProto.
39   explicit Shape(const ShapeProto& shape_proto);
40 
41   // Returns a ShapeProto representation of the Shape.
42   ShapeProto ToProto() const;
43 
44   // Returns a human-readable string that represents the given shape, with or
45   // without layout. e.g. "F32[42,12] {0, 1}" or "F32[64]".
46   string ToString(bool print_layout = false) const;
47 
48   // Returns the rank (number of dimensions) of the given shape. Shape must be
49   // an array.
rank()50   int64 rank() const {
51     CHECK(IsArray()) << "Non-arrays do not have a rank, shape: " << ToString();
52     return dimensions_.size();
53   }
54 
55   // Returns whether the shape is of the specified type (array, tuple, etc).
IsArray()56   bool IsArray() const { return primitive_util::IsArrayType(element_type()); }
IsTuple()57   bool IsTuple() const { return element_type() == TUPLE; }
IsToken()58   bool IsToken() const { return element_type() == TOKEN; }
IsOpaque()59   bool IsOpaque() const { return element_type() == OPAQUE; }
60 
61   // Returns true if no array dimension in the shape is dynamically sized. Tuple
62   // shapes are traversed recursively.
63   bool is_static() const;
64 
65   // Returns true if the given dimension is dynamically-sized.
is_dynamic_dimension(int dimension)66   bool is_dynamic_dimension(int dimension) const {
67     return dynamic_dimensions_.at(dimension);
68   }
69 
70   // Sets whether or not the given dimension is dynamically-sized.
set_dynamic_dimension(int dimension,bool is_dynamic)71   void set_dynamic_dimension(int dimension, bool is_dynamic) {
72     dynamic_dimensions_[dimension] = is_dynamic;
73   }
74 
dynamic_dimensions()75   const std::vector<bool>& dynamic_dimensions() const {
76     return dynamic_dimensions_;
77   }
78 
79   // Add dimension_upper_bound().
80 
81   // Removes the given dimension form the shape. Layout, if it exists, is
82   // adjusted to match the modified shape.
83   void DeleteDimension(int64 dim_to_delete);
84 
85   // The following methods mirror the protobuf generated code interface for the
86   // message ShapeProto. This enabled easy migration of this data structure
87   // from a proto to a proper C++ class.
88   // TODO(b/29771030): Replace or augment these methods with a more ergonomic
89   // interface.
90 
91   // Methods for accessing the primitive type.
element_type()92   PrimitiveType element_type() const { return element_type_; }
set_element_type(PrimitiveType value)93   void set_element_type(PrimitiveType value) { element_type_ = value; }
94 
95   // Methods for accessing the dimensions array.
dimensions_size()96   int dimensions_size() const { return dimensions_.size(); }
dimensions(int index)97   int64 dimensions(int index) const { return dimensions_.at(index); }
set_dimensions(int index,int64 value)98   void set_dimensions(int index, int64 value) { dimensions_.at(index) = value; }
add_dimensions(int64 value)99   void add_dimensions(int64 value) {
100     dimensions_.push_back(value);
101     dynamic_dimensions_.push_back(false);
102   }
clear_dimensions()103   void clear_dimensions() {
104     dimensions_.clear();
105     dynamic_dimensions_.clear();
106   }
dimensions()107   const std::vector<int64>& dimensions() const { return dimensions_; }
mutable_dimensions()108   absl::Span<int64> mutable_dimensions() { return absl::MakeSpan(dimensions_); }
109 
110   // Methods for accessing the tuple subshapes. This field only non-empty for
111   // tuple shapes.
tuple_shapes_size()112   int tuple_shapes_size() const { return tuple_shapes_.size(); }
tuple_shapes(int index)113   const Shape& tuple_shapes(int index) const { return tuple_shapes_.at(index); }
mutable_tuple_shapes(int index)114   Shape* mutable_tuple_shapes(int index) { return &tuple_shapes_.at(index); }
add_tuple_shapes()115   Shape* add_tuple_shapes() {
116     tuple_shapes_.push_back(Shape());
117     return &tuple_shapes_.back();
118   }
clear_tuple_shapes()119   void clear_tuple_shapes() { tuple_shapes_.clear(); }
tuple_shapes()120   const std::vector<Shape>& tuple_shapes() const { return tuple_shapes_; }
mutable_tuple_shapes()121   std::vector<Shape>* mutable_tuple_shapes() { return &tuple_shapes_; }
122 
123   // Methods for accessing the layout field.
has_layout()124   bool has_layout() const { return layout_.format() != INVALID_FORMAT; }
layout()125   const Layout& layout() const { return layout_; }
mutable_layout()126   Layout* mutable_layout() { return &layout_; }
clear_layout()127   void clear_layout() { layout_.Clear(); }
128 
Swap(Shape * other)129   void Swap(Shape* other) {
130     using std::swap;
131     swap(*this, *other);
132   }
133 
Clear()134   void Clear() {
135     element_type_ = PRIMITIVE_TYPE_INVALID;
136     dimensions_.clear();
137     tuple_shapes_.clear();
138     clear_layout();
139   }
140 
SerializeAsString()141   string SerializeAsString() const { return ToProto().SerializeAsString(); }
ShortDebugString()142   string ShortDebugString() const { return ToProto().ShortDebugString(); }
DebugString()143   string DebugString() const { return ToProto().DebugString(); }
144 
145   // Equal is a configurable functor to check the equality of two shapes.
146   //
147   // Examples:
148   //
149   // - Comparing two shapes ignoring their layout difference:
150   //   Equal().IgnoreLayout()(shape1, shape2);
151   //
152   // - Comparing two shapes ignoring their layout and element type difference:
153   //   Equal().IgnoreLayout().IgnoreElementType()(shape1, shape2);
154   class Equal {
155    public:
156     Equal() = default;
157 
158     bool operator()(const Shape& lhs, const Shape& rhs);
159 
IgnoreLayout()160     Equal& IgnoreLayout() {
161       ignore_layout_ = true;
162       return *this;
163     }
IgnoreTilesInLayout()164     Equal& IgnoreTilesInLayout() {
165       ignore_tiles_in_layout_ = true;
166       return *this;
167     }
IgnoreElementSizeInLayout()168     Equal& IgnoreElementSizeInLayout() {
169       ignore_element_size_in_layout_ = true;
170       return *this;
171     }
IgnoreElementType()172     Equal& IgnoreElementType() {
173       ignore_element_type_ = true;
174       return *this;
175     }
IgnoreFpPrecision()176     Equal& IgnoreFpPrecision() {
177       ignore_fp_precision_ = true;
178       return *this;
179     }
IgnoreDynamicDimension()180     Equal& IgnoreDynamicDimension() {
181       ignore_dynamic_dimension_ = true;
182       return *this;
183     }
184 
185    private:
186     bool ignore_layout_ = false;
187     bool ignore_tiles_in_layout_ = false;
188     bool ignore_element_size_in_layout_ = false;
189     bool ignore_element_type_ = false;
190     bool ignore_fp_precision_ = false;
191     bool ignore_dynamic_dimension_ = false;
192   };
193 
194   // Test that all fields of the shape are the same, equivalent to Equal().
195   bool operator==(const Shape& other) const { return Equal()(*this, other); }
196   bool operator!=(const Shape& other) const { return !(*this == other); }
197 
198  private:
199   // The element type of this shape (tuple, array, etc).
200   PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID;
201 
202   // The array bounds of the dimensions. This is nonempty only for array
203   // shapes. For a dynamically-sized dimension, the respective value in this
204   // vector is an inclusive upper limit of the array bound.
205   std::vector<int64> dimensions_;
206 
207   // This vector is the same size as 'dimensions_' and indicates whether the
208   // respective dimension is dynamically sized.
209   std::vector<bool> dynamic_dimensions_;
210 
211   // The tuple element subshapes. This is nonempty only for tuple shapes.
212   std::vector<Shape> tuple_shapes_;
213 
214   // The layout of the shape. Only relevant for arrays.
215   Layout layout_;
216 };
217 
218 // Shape of the parameters and output of an XLA computation. This is analogous
219 // to a traditional function signature.
220 class ProgramShape {
221  public:
222   ProgramShape() = default;
223 
224   // Creates a ProgramShape from a ProgramShapeProto protobuf.
225   explicit ProgramShape(const ProgramShapeProto& program_shape_proto);
226 
227   // Returns a proto representation of the object.
228   ProgramShapeProto ToProto() const;
229 
230   string ToString() const;
231 
232   // The following methods mirror the protobuf generated code interface for the
233   // message ProgramShapeProto. This enabled easy migration of this data
234   // structure from a proto to a proper C++ class.
235   // TODO(b/29771030): Replace or augment these methods with a more ergonomic
236   // interface.
237 
238   // Methods for accessing and manipulating the Shape of the parameters.
parameters_size()239   int parameters_size() const { return parameters_.size(); }
parameters(int index)240   const Shape& parameters(int index) const { return parameters_.at(index); }
mutable_parameters(int index)241   Shape* mutable_parameters(int index) { return &parameters_.at(index); }
add_parameters()242   Shape* add_parameters() {
243     parameters_.emplace_back();
244     return &parameters_.back();
245   }
clear_parameters()246   void clear_parameters() { parameters_.clear(); }
parameters()247   const std::vector<Shape>& parameters() const { return parameters_; }
mutable_parameters()248   std::vector<Shape>* mutable_parameters() { return &parameters_; }
249 
250   // Methods for accessing and manipulating the Shape of the result.
result()251   const Shape& result() const { return result_; }
mutable_result()252   Shape* mutable_result() { return &result_; }
253 
254   // Methods for accessing and manipulating the names of the parameters.
parameter_names_size()255   int parameter_names_size() const { return parameter_names_.size(); }
parameter_names(int index)256   const string& parameter_names(int index) const {
257     return parameter_names_.at(index);
258   }
set_parameter_names(int index,const string & value)259   void set_parameter_names(int index, const string& value) {
260     parameter_names_.at(index) = value;
261   }
mutable_parameter_names(int index)262   string* mutable_parameter_names(int index) {
263     return &parameter_names_.at(index);
264   }
add_parameter_names(const string & value)265   void add_parameter_names(const string& value) {
266     parameter_names_.push_back(value);
267   }
add_parameter_names()268   string* add_parameter_names() {
269     parameter_names_.push_back("");
270     return &parameter_names_.back();
271   }
clear_parameter_names()272   void clear_parameter_names() { parameter_names_.clear(); }
parameter_names()273   const std::vector<string>& parameter_names() const {
274     return parameter_names_;
275   }
mutable_parameter_names()276   std::vector<string>* mutable_parameter_names() { return &parameter_names_; }
277 
ShortDebugString()278   string ShortDebugString() const { return ToProto().ShortDebugString(); }
DebugString()279   string DebugString() const { return ToProto().DebugString(); }
280 
281  private:
282   // The shapes of the parameters of the computation represented by this object.
283   std::vector<Shape> parameters_;
284 
285   // The names of the parameters of the computation represented by this object.
286   std::vector<string> parameter_names_;
287 
288   // The shape of the result of the computation represented by this object.
289   Shape result_;
290 };
291 
// Stream-insertion operators for Shape and ProgramShape. They are only declared
// here; their definitions are not visible in this header. NOTE(review):
// presumably they write the same text as ToString() and return `out` for
// chaining — confirm against the out-of-line definitions.
std::ostream& operator<<(std::ostream& out, const Shape& shape);
std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape);
294 
295 }  // namespace xla
296 
297 #endif  // TENSORFLOW_COMPILER_XLA_SHAPE_H_
298