1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_
18 
19 #include "absl/strings/str_cat.h"
20 #include "absl/strings/str_join.h"
21 #include "pybind11/pybind11.h"
22 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
23 #include "tensorflow/compiler/xla/python/py_client.h"
24 #include "tensorflow/compiler/xla/python/pytree.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27 
28 namespace jax {
29 
30 // Describes the abstract shape and dtype of an argument.
31 struct ArgSignature {
ArgSignatureArgSignature32   ArgSignature(xla::PrimitiveType dtype, absl::Span<const xla::int64> shape,
33                bool weak_type)
34       : dtype(dtype), shape(shape.begin(), shape.end()), weak_type(weak_type) {}
35   // This is the XLA dtype of the object.
36   const xla::PrimitiveType dtype;
37   const absl::InlinedVector<xla::int64, 4> shape;
38   // JAX arguments can be of weak type, if and only if they are Python scalars
39   // or `DeviceArray` values such that `aval.weak_type` is true.
40   const bool weak_type;
41   bool operator==(const ArgSignature& other) const {
42     return std::tie(dtype, weak_type, shape) ==
43            std::tie(other.dtype, other.weak_type, other.shape);
44   }
45   bool operator!=(const ArgSignature& other) const { return !(*this == other); }
46   std::string DebugString() const;
47 };
48 
49 template <typename H>
AbslHashValue(H h,const ArgSignature & s)50 H AbslHashValue(H h, const ArgSignature& s) {
51   h = H::combine(std::move(h), s.dtype);
52   h = H::combine_contiguous(std::move(h), s.shape.data(), s.shape.size());
53   return h;
54 }
55 
56 // The signature of Python jitted function call, partitioned into:
57 // - dynamic positional arguments (i.e. positional args which are not static)
58 // - static positional arguments (i.e. the args associated to static_argnums)
59 // - keyword arguments
60 // The CallSignature should unambiguously identify a function call, thus,
61 // equality is based on:
62 // (a) Same PyTree for all dynamic positional arguments and keyword arguments
63 // (a) equality of the arguments and keyword arguments ArgSignature
64 // (a) equality (delegated to Python) of the static arguments.
65 struct CallSignature {
66   struct KwargEntry {
67     // To avoid comparing strings, we intern the kwargs strings.
68     // The compilation cache holds a reference to all the keys.
69     pybind11::handle key;
70     xla::PyTreeDef value_treedef;
71     bool operator==(const KwargEntry& other) const {
72       return key.ptr() == other.key.ptr() &&
73              value_treedef == other.value_treedef;
74     }
75     bool operator!=(const KwargEntry& other) const { return !(*this == other); }
76   };
77 
78   // Only contains the arguments associated to `static_argnums`, sorted in the
79   // order of their argnum index.
80   std::vector<pybind11::object> static_args;
81   // A PyTreeDef for each positional dynamic (i.e. not static) argument.
82   std::vector<xla::PyTreeDef> dynamic_positional_args_treedef;
83   // Keyword arguments. Sorted by the keyword name.
84   std::vector<KwargEntry> keyword_args;
85   // Shape and dtype for both the dynamic positional arguments and the keyword
86   // arguments (sorted by keyword name).
87   std::vector<ArgSignature> dynamic_args_signatures;
88   xla::PjRtDevice* device;
89 
90   bool operator==(const CallSignature& other) const;
91   bool operator!=(const CallSignature& other) const {
92     return !(*this == other);
93   }
94 
95   // To be used when we want to keep ownership of Python values referenced by
96   // the `CallSignature` (i.e. when we insert an entry).
97   void IncRef() const;
98   // The destructor of the cache should call this on all entries.
99   void DecRef() const;
100 
101   std::string DebugString() const;
102 };
103 
104 template <typename H>
AbslHashValue(H h,const CallSignature::KwargEntry & kw)105 H AbslHashValue(H h, const CallSignature::KwargEntry& kw) {
106   h = H::combine(std::move(h), kw.key.ptr(), kw.value_treedef);
107   return h;
108 }
109 
// Abseil hashing support for `CallSignature`. Only declared in this header;
// the definition is not visible here (presumably it lives in the
// corresponding .cc file -- confirm before using it from other translation
// units).
template <typename H>
H AbslHashValue(H h, const CallSignature& s);
112 
// The resulting information of the parsing and conversion of the arguments.
// Bundles the call signature, the flattened Python arguments and the device
// buffers derived from them.
struct ParsedArgumentsAsBuffers {
  // The call signature will be filled during 2 steps:
  // - `ParseArguments` will fill the static arguments and the pytree
  //    structures
  // - the shapes and dtypes are filled later, by `ParseAndTransferArguments`.
  CallSignature signature;
  // The concatenation of the dynamic positional arguments and the sorted
  // keyword arguments.
  std::vector<pybind11::object> flat_dynamic_args;
  // NOTE(review): undocumented in the original; presumably Python objects
  // whose lifetime must be extended while the arguments are in use -- confirm
  // against the implementation.
  std::vector<pybind11::object> keep_alive_objects;

  // The following is only valid if the parsing succeeds.
  // Non-owning pointers to the buffers of the arguments.
  std::vector<xla::PjRtBuffer*> arg_buffers;
  // We may need to keep these objects around, because:
  // (a) we need to extend the lifetime of objects created within
  //    `ConvertArgsToBuffers`
  // (b) `arg_buffers` do not maintain ownership
  std::vector<std::unique_ptr<xla::PjRtBuffer>> keep_alive;
};
133 
// Filter out static arguments, flatten and concatenate other arguments (i.e.
// dynamic positional and keyword arguments), filling `arguments` in place.
//
// - `args` / `py_kwargs`: the positional and keyword arguments of the Python
//   call.
// - `static_argnums`: indices of the positional arguments to treat as static.
// - `arguments`: output; see `ParsedArgumentsAsBuffers` for which fields this
//   step fills.
// Returns a non-OK status on failure (the exact failure conditions are
// defined by the implementation, which is not visible in this header).
xla::Status ParseArguments(const pybind11::args& args,
                           const pybind11::kwargs& py_kwargs,
                           absl::Span<int const> static_argnums,
                           ParsedArgumentsAsBuffers& arguments);
140 
141 struct DevicePutResult {
DevicePutResultDevicePutResult142   explicit DevicePutResult(xla::PjRtBuffer* b, bool weak_type)
143       : buffer(b), weak_type(weak_type), owned_buffer(nullptr) {}
DevicePutResultDevicePutResult144   DevicePutResult(std::unique_ptr<xla::PjRtBuffer> new_buffer, bool weak_type)
145       : buffer(new_buffer.get()),
146         weak_type(weak_type),
147         owned_buffer(std::move(new_buffer)) {}
148 
149   xla::PjRtBuffer* buffer;
150   bool weak_type;
151   std::unique_ptr<xla::PjRtBuffer> owned_buffer;
152 };
153 
// Returns the ArgSignature associated with an argument. Returns an error if
// the argument is not supported.
//
// - `arg`: the Python value to describe.
// - `jax_enable_x64`: NOTE(review): presumably mirrors the JAX `enable_x64`
//   setting and influences the resulting dtype -- confirm against the
//   implementation.
xla::StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
                                                bool jax_enable_x64);
158 
// Moves a device-like object to be on device.
// - If the object is already on device, `owned_buffer` will be nullptr.
// - If it's not, a new buffer will be created and returned using
//   `owned_buffer`.
// In all cases, `buffer` will point to the already existing or newly created
// buffer.
// If `arg` is not convertible to a `xla::PjRtBuffer` from C++, an error will be
// returned; float0 dtype and `_DeviceArray` with non-trivial LazyExpr are not
// supported yet.
//
// - `arg`: the Python value to place on device.
// - `to_device`: the target device.
// - `jax_enable_x64`, `pyclient`: NOTE(review): their exact role is defined by
//   the implementation, which is not visible in this header.
xla::StatusOr<DevicePutResult> DevicePut(pybind11::handle arg,
                                         xla::PjRtDevice* to_device,
                                         bool jax_enable_x64,
                                         xla::PyClient& pyclient);
172 
// The function to call in `xla.cc` to add the bindings for this module.
// `m` is the pybind11 module that receives the bindings.
void BuildJaxjitSubmodule(pybind11::module& m);
175 
176 }  // namespace jax
177 
178 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_
179