1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PYTREE_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_PYTREE_H_
18 
// See https://jax.readthedocs.io/en/latest/pytrees.html for documentation
// about pytrees.
21 
22 // Caution: this code uses exceptions. The exception use is local to the
23 // binding code and the idiomatic way to emit Python exceptions.
24 
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/hash/hash.h"
#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
36 
37 namespace xla {
38 
39 // Registry of custom node types.
40 class CustomNodeRegistry {
41  public:
42   struct Registration {
43     // The Python type object, used to identify the type.
44     pybind11::object type;
45     // A function with signature: object -> (iterable, aux_data)
46     pybind11::function to_iterable;
47     // A function with signature: (aux_data, iterable) -> object
48     pybind11::function from_iterable;
49   };
50 
51   // Registers a new custom type. Objects of `type` will be treated as container
52   // node types in PyTrees.
53   static void Register(pybind11::object type, pybind11::function to_iterable,
54                        pybind11::function from_iterable);
55 
56   // Finds the custom type registration for `type`. Returns nullptr if none
57   // exists.
58   static const Registration* Lookup(pybind11::handle type);
59 
60  private:
61   static CustomNodeRegistry* Singleton();
62 
63   struct TypeHash {
operatorTypeHash64     size_t operator()(const pybind11::object& t) const {
65       return pybind11::hash(t);
66     }
67   };
68   struct TypeEq {
operatorTypeEq69     bool operator()(const pybind11::object& a,
70                     const pybind11::object& b) const {
71       return a.equal(b);
72     }
73   };
74   absl::flat_hash_map<pybind11::object, std::unique_ptr<Registration>, TypeHash,
75                       TypeEq>
76       registrations_;
77 };
78 
79 // A PyTreeDef describes the tree structure of a PyTree. A PyTree is a tree of
80 // Python values, where the interior nodes are tuples, lists, dictionaries, or
81 // user-defined containers, and the leaves are other objects.
82 class PyTreeDef {
83  public:
84   PyTreeDef() = default;
85 
86   // Flattens a Pytree into a list of leaves and a PyTreeDef.
87   static std::pair<std::vector<pybind11::object>, std::unique_ptr<PyTreeDef>>
88   Flatten(pybind11::handle x,
89           absl::optional<pybind11::function> leaf_predicate = absl::nullopt);
90 
91   // Recursive helper used to implement Flatten().
92   void FlattenInto(
93       pybind11::handle handle, std::vector<pybind11::object>& leaves,
94       absl::optional<pybind11::function> leaf_predicate = absl::nullopt);
95 
96   // Tests whether the given list is a flat list of leaves.
97   static bool AllLeaves(const pybind11::iterable& x);
98 
99   // Flattens a Pytree up to this PyTreeDef. 'this' must be a tree prefix of
100   // the tree-structure of 'x'. For example, if we flatten a value
101   // [(1, (2, 3)), {"foo": 4}] with a treedef [(*, *), *], the result is the
102   // list of leaves [1, (2, 3), {"foo": 4}].
103   pybind11::list FlattenUpTo(pybind11::handle x) const;
104 
105   // Returns an unflattened PyTree given an iterable of leaves and a PyTreeDef.
106   pybind11::object Unflatten(pybind11::iterable leaves) const;
107 
108   // Composes two PyTreeDefs, replacing the leaves of this tree with copies of
109   // `inner`.
110   std::unique_ptr<PyTreeDef> Compose(const PyTreeDef& inner) const;
111 
112   // Makes a Tuple PyTreeDef out of a vector of PyTreeDefs.
113   static std::unique_ptr<PyTreeDef> Tuple(const std::vector<PyTreeDef>& defs);
114 
115   std::vector<std::unique_ptr<PyTreeDef>> Children() const;
116 
117   // Maps a function over a PyTree structure, applying f_leaf to each leaf, and
118   // f_node to each container node.
119   // TODO(phawkins): use flattening everywhere instead and delete this method.
120   pybind11::object Walk(const pybind11::function& f_node,
121                         pybind11::handle f_leaf,
122                         pybind11::iterable leaves) const;
123 
124   // Given a tree of iterables with the same node/leaf structure as this PyTree,
125   // build the corresponding PyTree.
126   // TODO(phawkins): use flattening everywhere instead and delete this method.
127   pybind11::object FromIterableTree(pybind11::handle xs) const;
128 
num_leaves()129   int num_leaves() const {
130     if (traversal_.empty()) {
131       return 0;
132     }
133     return traversal_.back().num_leaves;
134   }
135 
num_nodes()136   int num_nodes() const { return traversal_.size(); }
137 
138   size_t Hash() const;
139 
140   bool operator==(const PyTreeDef& other) const;
141   bool operator!=(const PyTreeDef& other) const { return !(*this == other); }
142 
143   std::string ToString() const;
144 
145  private:
146   enum class Kind {
147     kLeaf,        // An opaque leaf node
148     kNone,        // None.
149     kTuple,       // A tuple
150     kNamedTuple,  // A collections.namedtuple
151     kList,        // A list
152     kDict,        // A dict
153     kCustom,      // A custom type.
154   };
155 
156   struct Node {
157     Kind kind = Kind::kLeaf;
158 
159     // Arity for non-kLeaf types.
160     int arity = 0;
161 
162     // Kind-specific auxiliary data. For a kNamedTuple, contains the tuple type
163     // object. For a kDict, contains a sorted list of keys. For a kCustom type,
164     // contains the auxiliary data returned by the `to_iterable` function.
165     pybind11::object node_data;
166 
167     const CustomNodeRegistry::Registration* custom = nullptr;
168 
169     // Number of leaf nodes in the subtree rooted at this node.
170     int num_leaves = 0;
171 
172     // Number of leaf and interior nodes in the subtree rooted at this node.
173     int num_nodes = 0;
174   };
175   template <typename H>
176   friend H AbslHashValue(H h, const Node& n);
177 
178   template <typename H>
179   friend H AbslHashValue(H h, const PyTreeDef& t);
180 
181   // Helper that manufactures an instance of a node given its children.
182   static pybind11::object MakeNode(const Node& node,
183                                    absl::Span<pybind11::object> children);
184 
185   // Recursive helper used to implement FromIterableTree()
186   pybind11::object FromIterableTreeHelper(
187       pybind11::handle xs,
188       std::vector<PyTreeDef::Node>::const_reverse_iterator* it) const;
189 
190   // Computes the node kind of a given Python object.
191   static Kind GetKind(const pybind11::handle& obj,
192                       CustomNodeRegistry::Registration const** custom);
193 
194   // Nodes, in a post-order traversal. We use an ordered traversal to minimize
195   // allocations, and post-order corresponds to the order we need to rebuild the
196   // tree structure.
197   std::vector<Node> traversal_;
198 };
199 
200 template <typename H>
AbslHashValue(H h,const PyTreeDef::Node & n)201 H AbslHashValue(H h, const PyTreeDef::Node& n) {
202   h = H::combine(std::move(h), n.kind, n.arity, n.custom);
203   return h;
204 }
205 
206 template <typename H>
AbslHashValue(H h,const PyTreeDef & t)207 H AbslHashValue(H h, const PyTreeDef& t) {
208   return H::combine_contiguous(std::move(h), t.traversal_.data(),
209                                t.traversal_.size());
210 }
211 
// Declared here and defined elsewhere.
// NOTE(review): presumably adds the pytree Python bindings to the pybind11
// module `m` -- confirm against the definition.
void BuildPytreeSubmodule(pybind11::module& m);
213 
214 }  // namespace xla
215 
216 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_PYTREE_H_
217