1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
17 #define TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
18 
19 #include <functional>
20 #include <iterator>
21 #include <memory>
22 #include <vector>
23 
24 #include "absl/memory/memory.h"
25 #include "absl/types/optional.h"
26 #include "absl/types/span.h"
27 #include "tensorflow/compiler/xla/layout_util.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/status_macros.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/core/status.h"
33 #include "tensorflow/core/lib/gtl/iterator_range.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/types.h"
36 
37 namespace xla {
38 
39 namespace internal {
40 
41 // Internal representation of each node in a ShapeTree.
42 template <typename T>
43 struct ShapeTreeNode {
44   // Data corresponding to this node.
45   std::pair<ShapeIndex, T> data;
46 
47   bool is_leaf = true;
48 
ShapeTreeNodeShapeTreeNode49   explicit ShapeTreeNode(ShapeIndex index)
50       : ShapeTreeNode(std::move(index), T()) {}
ShapeTreeNodeShapeTreeNode51   ShapeTreeNode(ShapeIndex index, T data)
52       : data(std::move(index), std::move(data)) {}
53 };
54 
55 // Internal representation of an index table entry.
56 struct IndexTableEntry {
57   // Index of the node in the ShapeTreeNode vector.
58   uint32 index;
59   // Index of the first child in a IndexTableEntry vector. In the index
60   // table all children entries for a given node will be placed next to each
61   // other. This allows us to use a single field to index them.
62   uint32 children_start;
63 #ifndef NDEBUG
64   // Number of children, used for bounds checking.
65   uint32 children_count;
66 #endif
67 };
68 
69 }  // namespace internal
70 
71 template <typename ContainerType, typename IteratorType, typename ValueType>
72 class ShapeTreeIterator;
73 template <typename ContainerType, typename IteratorType, typename ValueType>
74 class ShapeTreeLeafIterator;
75 
76 // A ShapeTree<T> is a recursive data structure which mirrors the structure of a
77 // XLA shape and holds a value of type T for each subshape (i.e. tuple or array)
78 // in the shape. For array shapes, a ShapeTree trivially holds a single value of
79 // type T.
80 //
81 // For tuple shapes which can be an arbitrary tree with arrays at the leaves, a
82 // ShapeTree is an identically structured tree with data elements of type T at
83 // every node. I.e. the root is a tuple by definition, all interior nodes are
84 // also tuples, and all leaves are arrays.
85 //
86 // Like the Shape data structure, this is a tree and tuple elements cannot be
87 // duplicated. That is, every distinct ShapeIndex in the Shape has a unique T
88 // object.
89 //
90 // Normally a ShapeTree owns its Shape, but for efficiency reasons, sometimes
91 // it's helpful not to copy a Shape just to make a ShapeTree.  In these cases,
92 // you can pass a Shape* instead of a Shape& to the ShapeTree constructor.  It's
93 // then up to you to ensure that the pointed-to Shape doesn't die or mutate
94 // before its ShapeTree goes away.
95 template <typename T>
96 class ShapeTree {
97  public:
98   using Node = internal::ShapeTreeNode<T>;
99   using Index = internal::IndexTableEntry;
100 
101   // Default constructor creates a tree with a nil shape (i.e. an empty tuple).
ShapeTree()102   ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {}
103 
104   // Create ShapeTree with the given shape, and default-constructed T values for
105   // all nodes.
106   //
107   // The version that takes a pointer may be cheaper because it doesn't require
108   // any Shape copies, but then it's up to you to ensure that the pointer stays
109   // alive longer than this ShapeTree.
110   explicit ShapeTree(Shape shape);
111   explicit ShapeTree(const Shape* shape);
112   explicit ShapeTree(const std::shared_ptr<Shape>& shape);
113 
114   // Create ShapeTree with the given shape, and init_value for all nodes.
115   ShapeTree(Shape shape, const T& init_value);
116   ShapeTree(const Shape* shape, const T& init_value);
117   ShapeTree(const std::shared_ptr<Shape>& shape, const T& init_value);
118 
119   // Returns the data element associated with the array in the shape at the
120   // given index (see ShapeUtil::GetSubshape for how indexes are defined).
121   const T& element(ShapeIndexView index) const;
122   T* mutable_element(ShapeIndexView index);
123 
124   // Return the shape represented with this ShapeTree.
shape()125   const Shape& shape() const { return *shape_; }
126 
127   // A ShapeTree object can own the underlying Shape pointer (via the
128   // shape_storage_ member), or can point to a Shape object owned by the caller.
129   // This API replaces the underlying Shape object to the one supplied by the
130   // caller, whom must ensure the object remain valid for the whole lifetime of
131   // this ShapeTree object, and also that the Shape is consistent with it.
replace_shape_ptr(const Shape * shape)132   void replace_shape_ptr(const Shape* shape) {
133     if (shape_storage_ != nullptr) {
134       DCHECK_EQ(*shape, *shape_storage_);
135       shape_storage_ = nullptr;
136     }
137     shape_ = shape;
138   }
139 
140   // Returns true if the node at the given index is a leaf node (an array
141   // shape).
IsLeaf(ShapeIndexView index)142   bool IsLeaf(ShapeIndexView index) const { return Lookup(index)->is_leaf; }
143 
144   ShapeTree(const ShapeTree&) = default;
145   ShapeTree& operator=(const ShapeTree&) = default;
146   ShapeTree(ShapeTree&&) = default;
147   ShapeTree& operator=(ShapeTree&& other) = default;
148 
149   // iterator implements a bidirectional_iterator with
150   //  value_type = std::pair<ShapeIndex, T>.
151   //
152   // The iteration order is guaranteed to be a pre-order walk of the ShapeTree.
153   using iterator =
154       ShapeTreeIterator<std::vector<Node>, typename std::vector<Node>::iterator,
155                         std::pair<ShapeIndex, T>>;
156   using const_iterator =
157       ShapeTreeIterator<const std::vector<Node>,
158                         typename std::vector<Node>::const_iterator,
159                         const std::pair<ShapeIndex, T>>;
160   using reverse_iterator = std::reverse_iterator<iterator>;
161   using const_reverse_iterator = std::reverse_iterator<const_iterator>;
162 
163   using leaf_iterator =
164       ShapeTreeLeafIterator<std::vector<Node>,
165                             typename std::vector<Node>::iterator,
166                             std::pair<ShapeIndex, T>>;
167   using const_leaf_iterator =
168       ShapeTreeLeafIterator<const std::vector<Node>,
169                             typename std::vector<Node>::const_iterator,
170                             const std::pair<ShapeIndex, T>>;
171   using reverse_leaf_iterator = std::reverse_iterator<leaf_iterator>;
172   using const_reverse_leaf_iterator =
173       std::reverse_iterator<const_leaf_iterator>;
174 
175   // begin/end for iterating over all nodes.
begin()176   iterator begin() { return iterator(&nodes_, nodes_.begin()); }
end()177   iterator end() { return iterator(&nodes_, nodes_.end()); }
begin()178   const_iterator begin() const {
179     return const_iterator(&nodes_, nodes_.begin());
180   }
end()181   const_iterator end() const { return const_iterator(&nodes_, nodes_.end()); }
182 
183   // rbegin/rend for iterating over all nodes in reverse.
rbegin()184   reverse_iterator rbegin() { return reverse_iterator(end()); }
rend()185   reverse_iterator rend() { return reverse_iterator(begin()); }
rbegin()186   const_reverse_iterator rbegin() const {
187     return const_reverse_iterator(end());
188   }
rend()189   const_reverse_iterator rend() const {
190     return const_reverse_iterator(begin());
191   }
192 
193   // leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no
194   // children).
leaf_begin()195   leaf_iterator leaf_begin() { return leaf_iterator(&nodes_, nodes_.begin()); }
leaf_end()196   leaf_iterator leaf_end() { return leaf_iterator(&nodes_, nodes_.end()); }
leaf_begin()197   const_leaf_iterator leaf_begin() const {
198     return const_leaf_iterator(&nodes_, nodes_.begin());
199   }
leaf_end()200   const_leaf_iterator leaf_end() const {
201     return const_leaf_iterator(&nodes_, nodes_.end());
202   }
203   // range-based iterator for leaf_begin()/leaf_end().
leaves()204   tensorflow::gtl::iterator_range<leaf_iterator> leaves() {
205     return tensorflow::gtl::make_range(leaf_begin(), leaf_end());
206   }
leaves()207   tensorflow::gtl::iterator_range<const_leaf_iterator> leaves() const {
208     return tensorflow::gtl::make_range(leaf_begin(), leaf_end());
209   }
210 
leaf_rbegin()211   reverse_leaf_iterator leaf_rbegin() {
212     return reverse_leaf_iterator(leaf_end());
213   }
leaf_rend()214   reverse_leaf_iterator leaf_rend() {
215     return reverse_leaf_iterator(leaf_begin());
216   }
leaf_rbegin()217   const_reverse_leaf_iterator leaf_rbegin() const {
218     return const_reverse_leaf_iterator(leaf_end());
219   }
leaf_rend()220   const_reverse_leaf_iterator leaf_rend() const {
221     return const_reverse_leaf_iterator(leaf_begin());
222   }
223 
224   // Returns an iterator pointing to the given ShapeIndex.
225   // REQUIRES: index must exist in the ShapeTree.
find(ShapeIndexView index)226   iterator find(ShapeIndexView index) {
227     Node* element = Lookup(index);
228     auto element_iter = nodes_.begin() + (element - &nodes_[0]);
229     return iterator(&nodes_, element_iter);
230   }
find(ShapeIndexView index)231   const_iterator find(ShapeIndexView index) const {
232     const Node* element = Lookup(index);
233     auto element_iter = nodes_.cbegin() + (element - &nodes_[0]);
234     return const_iterator(&nodes_, element_iter);
235   }
236 
237   // Returns the number of leaf nodes in the tree.
leaf_count()238   int64 leaf_count() const { return std::distance(leaf_begin(), leaf_end()); }
239 
240   // Recursively traverses the shape and calls the given function at each
241   // element. The function has the following arguments:
242   //
243   //   Fn :    A callable of type void(const ShapeIndex& index, const T& data)
244   //           (or compatible).
245   //   index : the index of the element in the shape. See ShapeUtil::GetSubshape
246   //           for definition of index.
247   //   data : The data value at this element.
248   template <typename Fn>
249   void ForEachElement(const Fn& func) const;
250 
251   // Like ForEachElement, but the callable has type
252   //
253   //   void (const ShapeIndex& index, T* data).
254   //
255   template <typename Fn>
256   void ForEachMutableElement(const Fn& func);
257 
258   // Like ForEach(Mutable)Element, but the callable returns a Status instead of
259   // void.  The first non-OK return value is returned by the ForEach* function.
260   template <typename Fn>
261   Status ForEachElementWithStatus(const Fn& func) const;
262   template <typename Fn>
263   Status ForEachMutableElementWithStatus(const Fn& func);
264 
265   // Maps each element to generate a new tree with the same shape.
266   template <typename U>
Map(const std::function<U (const T &)> & func)267   ShapeTree<U> Map(const std::function<U(const T&)>& func) {
268     ShapeTree<U> result(shape_storage_);
269     ForEachElement([&](const ShapeIndex& index, const T& t) {
270       *result.mutable_element(index) = func(t);
271     });
272     return result;
273   }
274 
275   template <typename U>
Map(const std::function<U (T *)> & func)276   ShapeTree<U> Map(const std::function<U(T*)>& func) {
277     ShapeTree<U> result(shape_storage_);
278     ForEachMutableElement([&](const ShapeIndex& index, T* t) {
279       *result.mutable_element(index) = func(t);
280     });
281     return result;
282   }
283 
284   // Copy the subtree of values from 'other' rooted at ShapeIndex
285   // 'source_base_index' into the subtree of value in this ShapeTree rooted at
286   // 'target_base_index'.
287   //
288   // Precondition: The subshape of other.shape() at index source_base_index must
289   // be compatible with the subshape of shape() at index target_base_index.
290   void CopySubtreeFrom(const ShapeTree<T>& other,
291                        const ShapeIndex& source_base_index,
292                        const ShapeIndex& target_base_index);
293 
294   StatusOr<ShapeTree<T>> SubShapeTree(const ShapeIndex& index) const;
295 
296   bool operator==(const ShapeTree<T>& other) const;
297   bool operator!=(const ShapeTree<T>& other) const { return !(*this == other); }
298 
299  private:
300   // Initialize node->children based on 'shape'. All children are assigned the
301   // the given 'init_value'.
302   void InitChildren(const Shape& shape, const T& init_value, Node* node,
303                     Index* index);
304 
305   // Initialize node->children based on 'shape'. All children have
306   // default-constructed data values.
307   void InitChildren(const Shape& shape, Node* node, Index* index);
308 
309   // Returns the number of subshapes, including interior nodes, in shape.
310   int64 CountSubshapes(const Shape& shape);
311 
312   // Helpers for traversing the shape via ForEachElement. The helpers
313   // recursively traverse the subtree rooted at "index" (defined as in
314   // ShapeUtil::GetSubshape).
315   template <typename Fn>
316   static Status ForEachHelper(const Fn& func, const std::vector<Node>& nodes);
317   template <typename Fn>
318   static Status ForEachMutableHelper(const Fn& func, std::vector<Node>* nodes);
319 
320   // Return the tree node at the given index.
321   Node* Lookup(ShapeIndexView index);
322   const Node* Lookup(ShapeIndexView index) const;
323 
324   // The nodes in this shape tree.
325   std::vector<Node> nodes_;
326 
327   // Index table for node lookups.
328   std::vector<Index> index_table_;
329 
330   // If we own our Shape, this field contains it, and shape_ is a pointer into
331   // here.  Otherwise if we don't own our shape, this is nullptr.
332   std::shared_ptr<Shape> shape_storage_;
333 
334   // The XLA shape mirrored in this ShapeTree.  This is either
335   // shape_storage_.get() or the Shape pointer passed to our constructor.
336   const Shape* shape_;
337 };
338 
// Internal iterator that performs a pre-order walk over every node of a
// ShapeTree. This is cheap to copy: it only wraps an iterator into the tree's
// node vector. As with std::map, value_type is the std::pair<ShapeIndex, T>
// stored in each node, and dereferencing yields a reference to that pair.
template <typename ContainerType, typename IteratorType, typename ValueType>
class ShapeTreeIterator {
 public:
  // Standard iterator member types, declared directly instead of inheriting
  // from std::iterator (deprecated since C++17). They match what
  // std::iterator<std::bidirectional_iterator_tag, ValueType> provided;
  // difference_type comes from the wrapped iterator (std::ptrdiff_t for
  // std::vector iterators, the old default).
  using iterator_category = std::bidirectional_iterator_tag;
  using value_type = ValueType;
  using difference_type =
      typename std::iterator_traits<IteratorType>::difference_type;
  using pointer = ValueType*;
  using reference = ValueType&;

  ShapeTreeIterator(ContainerType* nodes, IteratorType node)
      : nodes_(nodes), node_(std::move(node)) {}

  ShapeTreeIterator& operator++() {
    ++node_;
    return *this;
  }
  ShapeTreeIterator operator++(int) {
    auto i = *this;
    ++(*this);
    return i;
  }

  ShapeTreeIterator& operator--() {
    --node_;
    return *this;
  }
  ShapeTreeIterator operator--(int) {
    auto i = *this;
    --(*this);
    return i;
  }

  bool operator==(const ShapeTreeIterator& other) const {
    return node_ == other.node_;
  }
  bool operator!=(const ShapeTreeIterator& other) const {
    return node_ != other.node_;
  }
  reference operator*() const { return node_->data; }
  pointer operator->() const { return &node_->data; }

 private:
  // Not read by this iterator; kept so the constructor mirrors
  // ShapeTreeLeafIterator's.
  ContainerType* nodes_;
  IteratorType node_;
};
382 
// Internal iterator that performs a pre-order walk over only the leaf nodes
// (nodes whose is_leaf flag is set) of a ShapeTree. This is cheap to copy. As
// with std::map, value_type is the std::pair<ShapeIndex, T> stored in each
// node, and dereferencing yields a reference to that pair.
template <typename ContainerType, typename IteratorType, typename ValueType>
class ShapeTreeLeafIterator {
 public:
  // Standard iterator member types, declared directly instead of inheriting
  // from std::iterator (deprecated since C++17). They match what
  // std::iterator<std::bidirectional_iterator_tag, ValueType> provided;
  // difference_type comes from the wrapped iterator (std::ptrdiff_t for
  // std::vector iterators, the old default).
  using iterator_category = std::bidirectional_iterator_tag;
  using value_type = ValueType;
  using difference_type =
      typename std::iterator_traits<IteratorType>::difference_type;
  using pointer = ValueType*;
  using reference = ValueType&;

  // Positions the iterator at `node` if it is a leaf, otherwise at the next
  // leaf after it, or at nodes->end() if there is none.
  ShapeTreeLeafIterator(ContainerType* nodes, IteratorType node)
      : nodes_(nodes), node_(std::move(node)) {
    SkipForwardToLeaf();
  }

  // Advances to the next leaf, or to the end of the container.
  ShapeTreeLeafIterator& operator++() {
    ++node_;
    SkipForwardToLeaf();
    return *this;
  }
  ShapeTreeLeafIterator operator++(int) {
    auto i = *this;
    ++(*this);
    return i;
  }

  // Steps back to the previous leaf. The backward scan stops at the
  // container's first element; as with any bidirectional iterator, the first
  // leaf must not be decremented.
  ShapeTreeLeafIterator& operator--() {
    --node_;
    while (node_ > nodes_->begin() && !node_->is_leaf) {
      --node_;
    }
    return *this;
  }
  ShapeTreeLeafIterator operator--(int) {
    auto i = *this;
    --(*this);
    return i;
  }

  bool operator==(const ShapeTreeLeafIterator& other) const {
    return node_ == other.node_;
  }
  bool operator!=(const ShapeTreeLeafIterator& other) const {
    return node_ != other.node_;
  }
  reference operator*() const { return node_->data; }
  pointer operator->() const { return &node_->data; }

 private:
  // Moves node_ forward past non-leaf nodes, stopping at a leaf or at the end
  // of the container.
  void SkipForwardToLeaf() {
    while (node_ != nodes_->end() && !node_->is_leaf) {
      ++node_;
    }
  }

  ContainerType* nodes_;
  IteratorType node_;
};
436 
437 template <typename T>
CountSubshapes(const Shape & shape)438 int64 ShapeTree<T>::CountSubshapes(const Shape& shape) {
439   int64 current_count = 1;
440   if (shape.IsTuple()) {
441     int64 count = ShapeUtil::TupleElementCount(shape);
442     for (int i = 0; i < count; ++i) {
443       current_count += CountSubshapes(shape.tuple_shapes(i));
444     }
445   }
446   return current_count;
447 }
448 
449 template <typename T>
InitChildren(const Shape & shape,const T & init_value,Node * node,Index * index)450 void ShapeTree<T>::InitChildren(const Shape& shape, const T& init_value,
451                                 Node* node, Index* index) {
452   if (shape.IsTuple()) {
453     const int64 size = ShapeUtil::TupleElementCount(shape);
454 #ifndef NDEBUG
455     index->children_count = size;
456 #endif
457     node->is_leaf = false;
458     ShapeIndex shape_index = node->data.first;
459     shape_index.push_back(0);
460 
461     // At the end of the index_table, reserve a continuous space to hold the
462     // children of current node. In order to enforce the invariant that all
463     // children of a given node are placed together, we need to do the
464     // reservation before we recurse into any of its children.
465     int64 children_start_position = index_table_.size();
466     index_table_.resize(index_table_.size() + size);
467 
468     for (int i = 0; i < size; ++i) {
469       shape_index[shape_index.size() - 1] = i;
470       index_table_[children_start_position + i].index = nodes_.size();
471       // The first child of the node in the index table is placed at the end of
472       // the table.
473       index_table_[children_start_position + i].children_start =
474           index_table_.size();
475       nodes_.emplace_back(shape_index, init_value);
476       InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back(),
477                    &index_table_[children_start_position + i]);
478     }
479   } else {
480 #ifndef NDEBUG
481     index->children_count = 0;
482 #endif
483   }
484 }
485 
486 template <typename T>
InitChildren(const Shape & shape,Node * node,Index * index)487 void ShapeTree<T>::InitChildren(const Shape& shape, Node* node, Index* index) {
488   if (shape.IsTuple()) {
489     const int64 size = ShapeUtil::TupleElementCount(shape);
490 #ifndef NDEBUG
491     index->children_count = size;
492 #endif
493     node->is_leaf = false;
494     ShapeIndex shape_index = node->data.first;
495     shape_index.push_back(0);
496 
497     // At the end of the index_table, reserve a continuous space to hold the
498     // children of current node. In order to enforce the invariant that all
499     // children of a given node are placed together, we need to do the
500     // reservation before we recurse into any of its children.
501     int64 children_start_position = index_table_.size();
502     index_table_.resize(index_table_.size() + size);
503 
504     for (int i = 0; i < size; ++i) {
505       shape_index[shape_index.size() - 1] = i;
506       index_table_[children_start_position + i].index = nodes_.size();
507       // The first child of the node in the index table is placed at the end of
508       // the table.
509       index_table_[children_start_position + i].children_start =
510           index_table_.size();
511       nodes_.emplace_back(shape_index);
512       InitChildren(shape.tuple_shapes(i), &nodes_.back(),
513                    &index_table_[children_start_position + i]);
514     }
515   } else {
516 #ifndef NDEBUG
517     index->children_count = 0;
518 #endif
519   }
520 }
521 
522 template <typename T>
ShapeTree(Shape shape)523 ShapeTree<T>::ShapeTree(Shape shape)
524     : shape_storage_(std::make_shared<Shape>(std::move(shape))),
525       shape_(shape_storage_.get()) {
526   const int64 count = CountSubshapes(*shape_);
527   nodes_.reserve(count);
528   nodes_.emplace_back(ShapeIndex{});
529 
530   index_table_.reserve(count);
531   index_table_.emplace_back(Index{0, 1});
532   InitChildren(*shape_, &nodes_[0], &index_table_[0]);
533 }
534 
535 template <typename T>
ShapeTree(const Shape * shape)536 ShapeTree<T>::ShapeTree(const Shape* shape) : shape_(shape) {
537   const int64 count = CountSubshapes(*shape_);
538   nodes_.reserve(count);
539   nodes_.emplace_back(ShapeIndex{});
540 
541   index_table_.reserve(count);
542   index_table_.emplace_back(Index{0, 1});
543   InitChildren(*shape_, &nodes_[0], &index_table_[0]);
544 }
545 
546 template <typename T>
ShapeTree(const std::shared_ptr<Shape> & shape)547 ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape)
548     : shape_storage_(shape), shape_(shape_storage_.get()) {
549   const int64 count = CountSubshapes(*shape_);
550   nodes_.reserve(count);
551   nodes_.emplace_back(ShapeIndex{});
552 
553   index_table_.reserve(count);
554   index_table_.emplace_back(Index{0, 1});
555   InitChildren(*shape_, &nodes_[0], &index_table_[0]);
556 }
557 
558 template <typename T>
ShapeTree(Shape shape,const T & init_value)559 ShapeTree<T>::ShapeTree(Shape shape, const T& init_value)
560     : shape_storage_(std::make_shared<Shape>(std::move(shape))),
561       shape_(shape_storage_.get()) {
562   const int64 count = CountSubshapes(*shape_);
563   nodes_.reserve(count);
564   nodes_.emplace_back(ShapeIndex{}, init_value);
565 
566   index_table_.reserve(count);
567   index_table_.emplace_back(Index{0, 1});
568   InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
569 }
570 
571 template <typename T>
ShapeTree(const Shape * shape,const T & init_value)572 ShapeTree<T>::ShapeTree(const Shape* shape, const T& init_value)
573     : shape_(shape) {
574   const int64 count = CountSubshapes(*shape_);
575   nodes_.reserve(count);
576   nodes_.emplace_back(ShapeIndex{}, init_value);
577 
578   index_table_.reserve(count);
579   index_table_.emplace_back(Index{0, 1});
580   InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
581 }
582 
583 template <typename T>
ShapeTree(const std::shared_ptr<Shape> & shape,const T & init_value)584 ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape,
585                         const T& init_value)
586     : shape_storage_(shape), shape_(shape_storage_.get()) {
587   const int64 count = CountSubshapes(*shape_);
588   nodes_.reserve(count);
589   nodes_.emplace_back(ShapeIndex{}, init_value);
590 
591   index_table_.reserve(count);
592   index_table_.emplace_back(Index{0, 1});
593   InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
594 }
595 
596 template <typename T>
element(ShapeIndexView index)597 const T& ShapeTree<T>::element(ShapeIndexView index) const {
598   return Lookup(index)->data.second;
599 }
600 
601 template <typename T>
mutable_element(ShapeIndexView index)602 T* ShapeTree<T>::mutable_element(ShapeIndexView index) {
603   return &Lookup(index)->data.second;
604 }
605 
606 template <typename T>
Lookup(ShapeIndexView index)607 internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(ShapeIndexView index) {
608   Index* iter = &index_table_[0];
609   for (const int64 i : index) {
610     CHECK_GE(i, 0);
611 #ifndef NDEBUG
612     CHECK_LT(i, iter->children_count);
613 #endif
614     iter = &index_table_[iter->children_start + i];
615   }
616 
617   return &nodes_[iter->index];
618 }
619 
620 template <typename T>
Lookup(ShapeIndexView index)621 const internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(
622     ShapeIndexView index) const {
623   return const_cast<ShapeTree*>(this)->Lookup(index);
624 }
625 
626 /* static */
627 template <typename T>
628 template <typename Fn>
ForEachHelper(const Fn & func,const std::vector<Node> & nodes)629 Status ShapeTree<T>::ForEachHelper(const Fn& func,
630                                    const std::vector<Node>& nodes) {
631   for (const auto& node : nodes) {
632     TF_RETURN_IF_ERROR(func(node.data.first, node.data.second));
633   }
634   return Status::OK();
635 }
636 
637 /* static */
638 template <typename T>
639 template <typename Fn>
ForEachMutableHelper(const Fn & func,std::vector<Node> * nodes)640 Status ShapeTree<T>::ForEachMutableHelper(const Fn& func,
641                                           std::vector<Node>* nodes) {
642   for (auto& node : *nodes) {
643     TF_RETURN_IF_ERROR(func(node.data.first, &node.data.second));
644   }
645   return Status::OK();
646 }
647 
648 template <typename T>
649 template <typename Fn>
ForEachElementWithStatus(const Fn & func)650 Status ShapeTree<T>::ForEachElementWithStatus(const Fn& func) const {
651   return ForEachHelper(func, nodes_);
652 }
653 
654 template <typename T>
655 template <typename Fn>
ForEachMutableElementWithStatus(const Fn & func)656 Status ShapeTree<T>::ForEachMutableElementWithStatus(const Fn& func) {
657   return ForEachMutableHelper(func, &nodes_);
658 }
659 
660 template <typename T>
661 template <typename Fn>
ForEachElement(const Fn & func)662 void ShapeTree<T>::ForEachElement(const Fn& func) const {
663   return ForEachHelper(
664              [&func](const ShapeIndex& index, const T& data) {
665                func(index, data);
666                return Status::OK();
667              },
668              nodes_)
669       .IgnoreError();
670 }
671 
672 template <typename T>
673 template <typename Fn>
ForEachMutableElement(const Fn & func)674 void ShapeTree<T>::ForEachMutableElement(const Fn& func) {
675   return ForEachMutableHelper(
676              [&func](const ShapeIndex& index, T* data) {
677                func(index, data);
678                return Status::OK();
679              },
680              &nodes_)
681       .IgnoreError();
682 }
683 
684 template <typename T>
CopySubtreeFrom(const ShapeTree<T> & other,const ShapeIndex & source_base_index,const ShapeIndex & target_base_index)685 void ShapeTree<T>::CopySubtreeFrom(const ShapeTree<T>& other,
686                                    const ShapeIndex& source_base_index,
687                                    const ShapeIndex& target_base_index) {
688   CHECK(ShapeUtil::Compatible(
689       ShapeUtil::GetSubshape(shape(), target_base_index),
690       ShapeUtil::GetSubshape(other.shape(), source_base_index)))
691       << ShapeUtil::GetSubshape(shape(), target_base_index) << " vs "
692       << ShapeUtil::GetSubshape(other.shape(), source_base_index);
693   ForEachMutableElement([this, &other, &source_base_index, &target_base_index](
694                             const ShapeIndex& index, T* data) {
695     // Copy the data element only if index is in the
696     // subtree rooted at target_base_index.
697     for (int i = 0; i < target_base_index.size(); ++i) {
698       if (i >= index.size() || index[i] != target_base_index[i]) {
699         return;
700       }
701     }
702     // Construct source element index to copy from.
703     ShapeIndex source_index = source_base_index;
704     for (int i = target_base_index.size(); i < index.size(); ++i) {
705       source_index.push_back(index[i]);
706     }
707     *data = other.element(source_index);
708   });
709 }
710 
711 template <typename T>
SubShapeTree(const ShapeIndex & index)712 StatusOr<ShapeTree<T>> ShapeTree<T>::SubShapeTree(
713     const ShapeIndex& index) const {
714   TF_ASSIGN_OR_RETURN(const Shape* sub_shape,
715                       ShapeUtil::TryGetSubshape(shape(), index));
716   ShapeTree<T> sub_shape_tree(*sub_shape);
717   sub_shape_tree.CopySubtreeFrom(*this, index, {});
718   return std::move(sub_shape_tree);
719 }
720 
721 template <typename T>
722 bool ShapeTree<T>::operator==(const ShapeTree<T>& other) const {
723   bool equal = true;
724   ForEachElement([&other, &equal](const ShapeIndex& index, const T& data) {
725     if (data != other.element(index)) {
726       equal = false;
727     }
728   });
729   return equal;
730 }
731 
732 }  // namespace xla
733 
734 #endif  // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
735