1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This file implements the `jax.jit` dispatch and just-in-time feature.
17 //
18 // In a nutshell, `Jit(f)` returns a callable that will dispatch (i.e. forward
19 // based on passed arguments dtypes/shapes/identity) the execution to a
20 // just-in-time compiled XLA Executable. All of that is done in C++ for
21 // performance reasons.
22 //
23 // This file contains the utilities to:
24 // (a) inspect arguments and describe their structure, dtype/shapes, etc.
25 // (b) keep a mapping from function signatures to compiled XLA Executables.
26 
27 #include "tensorflow/compiler/xla/python/jax_jit.h"
28 
29 #include <Python.h>
30 
31 #include <exception>
32 #include <memory>
33 #include <stdexcept>
34 #include <utility>
35 
36 #include "absl/container/flat_hash_map.h"
37 #include "absl/container/inlined_vector.h"
38 #include "absl/strings/str_cat.h"
39 #include "absl/synchronization/notification.h"
40 #include "absl/types/optional.h"
41 #include "pybind11/cast.h"
42 #include "pybind11/numpy.h"
43 #include "pybind11/pybind11.h"
44 #include "pybind11/pytypes.h"
45 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
46 #include "tensorflow/compiler/xla/python/py_buffer.h"
47 #include "tensorflow/compiler/xla/python/py_executable.h"
48 #include "tensorflow/compiler/xla/python/pytree.h"
49 #include "tensorflow/compiler/xla/python/types.h"
50 #include "tensorflow/compiler/xla/shape_util.h"
51 #include "tensorflow/compiler/xla/statusor.h"
52 #include "tensorflow/compiler/xla/types.h"
53 #include "tensorflow/compiler/xla/util.h"
54 #include "tensorflow/compiler/xla/xla_data.pb.h"
55 #include "tensorflow/core/platform/status.h"
56 
57 namespace jax {
58 
59 namespace py = pybind11;
60 
61 // TODO(phawkins): Add support for Tracers.
62 // TODO(jblespiau): Use absl Status.
63 // TODO(jblespiau): Remove the "xla::" prefixes when not needed.
64 
DebugString() const65 std::string ArgSignature::DebugString() const {
66   std::string result = "";
67   if (weak_type) {
68     absl::StrAppend(&result, "weak_");
69   }
70   absl::StrAppend(&result, xla::PrimitiveType_Name(dtype));
71   absl::StrAppend(&result, "[", absl::StrJoin(shape, ","), "]");
72   return result;
73 }
74 
operator ==(const CallSignature & other) const75 bool CallSignature::operator==(const CallSignature& other) const {
76   return std::tie(dynamic_positional_args_treedef, keyword_args,
77                   dynamic_args_signatures, device) ==
78              std::tie(other.dynamic_positional_args_treedef, other.keyword_args,
79                       other.dynamic_args_signatures, other.device) &&
80          // `==` on py:objects is the Python `is`. We need equal.
81          std::equal(
82              static_args.begin(), static_args.end(), other.static_args.begin(),
83              other.static_args.end(),
84              [](const py::object& a, const py::object& b) {
85                try {
86                  return a.equal(b);
87                } catch (const py::error_already_set& e) {
88                  throw std::invalid_argument(absl::StrCat(
89                      "static arguments should be comparable using __eq__."
90                      "The following error was raised when comparing two "
91                      "objects of types ",
92                      py::cast<std::string>(py::str(py::type::of(a))), " and ",
93                      py::cast<std::string>(py::str(py::type::of(b))),
94                      ". The error was:\n", e.what()));
95                }
96              });
97 }
98 
IncRef() const99 void CallSignature::IncRef() const {
100   for (const auto& kw : keyword_args) {
101     kw.key.inc_ref();
102   }
103 }
104 
DecRef() const105 void CallSignature::DecRef() const {
106   for (const auto& kw : keyword_args) {
107     kw.key.dec_ref();
108   }
109 }
110 
namespace {

// Per-thread flag, only read and written through the two accessors below.
// Thread-local storage is zero-initialized, so the flag starts out `false`;
// the initializer just makes that explicit.
thread_local bool disable_jit = false;

// Sets the calling thread's flag.
void SetDisableJit(bool disable_jit_) {
  disable_jit = disable_jit_;
}

// Returns the calling thread's flag.
bool GetDisableJit() {
  return disable_jit;
}

}  // namespace
118 
DebugString() const119 std::string CallSignature::DebugString() const {
120   std::vector<std::string> static_args_str;
121   static_args_str.reserve(static_args.size());
122   for (auto& static_arg : static_args) {
123     static_args_str.emplace_back(py::cast<std::string>(py::str(static_arg)));
124   }
125 
126   std::vector<std::string> signature_str;
127   signature_str.reserve(dynamic_args_signatures.size());
128 
129   for (auto& arg_signature : dynamic_args_signatures) {
130     signature_str.emplace_back(arg_signature.DebugString());
131   }
132   std::vector<std::string> tree_def_str;
133   signature_str.reserve(dynamic_positional_args_treedef.size());
134   for (auto& tree_def : dynamic_positional_args_treedef) {
135     tree_def_str.emplace_back(tree_def.ToString());
136   }
137   std::vector<std::string> keyword_names;
138   keyword_names.reserve(keyword_args.size());
139   for (auto& kwarg_entry : keyword_args) {
140     keyword_names.emplace_back(py::cast<std::string>(kwarg_entry.key));
141     tree_def_str.emplace_back(kwarg_entry.value_treedef.ToString());
142   }
143   return absl::StrCat(
144       static_args.size(), " static_args: ", absl::StrJoin(static_args_str, ","),
145       "\n",  // new line
146       keyword_args.size(), " keyword args:", absl::StrJoin(keyword_names, ","),
147       "\n",  // new-line
148       dynamic_positional_args_treedef.size(), " positional args.\n",
149       dynamic_args_signatures.size(),
150       " dynamic args (positional+keyword):\n   - ",
151       absl::StrJoin(signature_str, ", "), "\n   - ",
152       absl::StrJoin(tree_def_str, " | "));
153 }
154 
155 template <typename H>
AbslHashValue(H h,const CallSignature & s)156 H AbslHashValue(H h, const CallSignature& s) {
157   h = H::combine_contiguous(std::move(h),
158                             s.dynamic_positional_args_treedef.data(),
159                             s.dynamic_positional_args_treedef.size());
160   h = H::combine_contiguous(std::move(h), s.keyword_args.data(),
161                             s.keyword_args.size());
162   h = H::combine_contiguous(std::move(h), s.dynamic_args_signatures.data(),
163                             s.dynamic_args_signatures.size());
164   h = H::combine(std::move(h), s.device);
165   for (const auto& static_arg : s.static_args) {
166     ssize_t hash;
167     try {
168       hash = py::hash(static_arg);
169     } catch (const py::error_already_set& e) {
170       throw std::invalid_argument(absl::StrCat(
171           "Non-hashable static arguments are not supported. An error occured "
172           "while trying to hash an object of type ",
173           py::cast<std::string>(py::str(py::type::of(static_arg))), ", ",
174           py::cast<std::string>(py::str(static_arg)), ". The error was:\n",
175           e.what(), "\n"));
176     }
177     h = H::combine(std::move(h), hash);
178   }
179   return h;
180 }
181 
182 // Filter out static arguments, flatten and concatenate other arguments (i.e.
183 // dynamic positional and keyword arguments), filling `arguments` in place.
ParseArguments(const py::args & args,const py::kwargs & py_kwargs,absl::Span<int const> static_argnums,ParsedArgumentsAsBuffers & arguments)184 xla::Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
185                            absl::Span<int const> static_argnums,
186                            ParsedArgumentsAsBuffers& arguments) {
187   if (static_argnums.size() > args.size()) {
188     return xla::InvalidArgument(
189         "%s", "[jaxjit] Error with static argnums, executing the Python path.");
190   }
191   arguments.flat_dynamic_args.reserve(args.size() + py_kwargs.size() -
192                                       static_argnums.size());
193   arguments.signature.dynamic_positional_args_treedef.reserve(
194       args.size() - static_argnums.size());
195 
196   // Positional arguments.
197   for (size_t i = 0; i < args.size(); ++i) {
198     if (std::find(static_argnums.begin(), static_argnums.end(), i) ==
199         static_argnums.end()) {
200       xla::PyTreeDef pytree_def;
201       pytree_def.FlattenInto(args[i], arguments.flat_dynamic_args);
202       arguments.signature.dynamic_positional_args_treedef.push_back(pytree_def);
203     } else {
204       arguments.signature.static_args.emplace_back(
205           // borrow is mandatory here.
206           py::reinterpret_borrow<py::object>(args[i]));
207     }
208   }
209 
210   // Keyword arguments.
211   std::vector<std::pair<py::handle, py::handle>> kwargs(py_kwargs.begin(),
212                                                         py_kwargs.end());
213   // We first intern the keys, then sort them (by name, as in the Python path)
214   // (see also xla::PyTreeDef::Flatten) and then create the signatures.
215   // TODO(jblespiau): We should be able to sort the keys by interned-key
216   // pointers, but this requires the Python compilation to do the same.
217   arguments.signature.keyword_args.resize(kwargs.size());
218   for (size_t i = 0; i < kwargs.size(); ++i) {
219     // Intern the key if not already interned.
220     if (!PyUnicode_CHECK_INTERNED(kwargs[i].first.ptr())) {
221       PyObject* key = kwargs[i].first.ptr();
222       kwargs[i].first.inc_ref();
223       PyUnicode_InternInPlace(&key);
224       arguments.keep_alive_objects.push_back(
225           py::reinterpret_steal<py::object>(key));
226       kwargs[i].first = py::handle(key);
227     }
228   }
229 
230   std::sort(kwargs.begin(), kwargs.end(),
231             [](const std::pair<py::handle, py::handle>& a,
232                const std::pair<py::handle, py::handle>& b) {
233               return a.first < b.first;
234             });
235   for (size_t i = 0; i < kwargs.size(); ++i) {
236     arguments.signature.keyword_args[i].key = kwargs[i].first;
237     arguments.signature.keyword_args[i].value_treedef.FlattenInto(
238         kwargs[i].second, arguments.flat_dynamic_args);
239   }
240   return xla::Status::OK();
241 }
242 
243 namespace {
244 
// Holds one Python object per NumPy scalar type used by this file. Each field
// is filled by GetNumpyScalarTypes() from the attribute of the `numpy` module
// with the matching name (`np_bool` comes from `numpy.bool_`).
struct NumpyScalarTypes {
  py::object np_bool;
  py::object np_int8;
  py::object np_int16;
  py::object np_int32;
  py::object np_int64;
  py::object np_uint8;
  py::object np_uint16;
  py::object np_uint32;
  py::object np_uint64;
  py::object np_float16;
  py::object np_float32;
  py::object np_float64;
  py::object np_complex64;
  py::object np_complex128;
  py::object np_longlong;
  py::object np_intc;
};
263 
GetNumpyScalarTypes()264 const NumpyScalarTypes& GetNumpyScalarTypes() {
265   static const NumpyScalarTypes* singleton = []() {
266     // Use Designated initializers when they are available.
267     const auto numpy = py::module::import("numpy");
268     NumpyScalarTypes* dtypes = new NumpyScalarTypes();
269     dtypes->np_bool = py::object(numpy.attr("bool_"));
270     dtypes->np_int8 = py::object(numpy.attr("int8"));
271     dtypes->np_int16 = py::object(numpy.attr("int16"));
272     dtypes->np_int32 = py::object(numpy.attr("int32"));
273     dtypes->np_int64 = py::object(numpy.attr("int64"));
274     dtypes->np_uint8 = py::object(numpy.attr("uint8"));
275     dtypes->np_uint16 = py::object(numpy.attr("uint16"));
276     dtypes->np_uint32 = py::object(numpy.attr("uint32"));
277     dtypes->np_uint64 = py::object(numpy.attr("uint64"));
278     dtypes->np_float16 = py::object(numpy.attr("float16"));
279     dtypes->np_float32 = py::object(numpy.attr("float32"));
280     dtypes->np_float64 = py::object(numpy.attr("float64"));
281     dtypes->np_complex64 = py::object(numpy.attr("complex64"));
282     dtypes->np_complex128 = py::object(numpy.attr("complex128"));
283     dtypes->np_longlong = py::object(numpy.attr("longlong"));
284     dtypes->np_intc = py::object(numpy.attr("intc"));
285 
286     return dtypes;
287   }();
288 
289   return *singleton;
290 }
291 
// Maps a 64-bit dtype to its narrower counterpart:
//   int64 -> int32, float64 -> float32, uint64 -> uint32,
//   complex128 -> complex64.
// Returns nullptr for any other dtype. The returned pointer refers to a
// function-local static that is never freed.
const py::dtype* DtypeTo32BitDtype(const py::dtype& dtype) {
  // TODO(jblespiau): Use GetNumpyScalarTypes instead.
  static const auto* int64_dt = new py::dtype("int64");
  static const auto* int32_dt = new py::dtype("int32");
  static const auto* uint64_dt = new py::dtype("uint64");
  static const auto* uint32_dt = new py::dtype("uint32");
  static const auto* float64_dt = new py::dtype("float64");
  static const auto* float32_dt = new py::dtype("float32");
  static const auto* complex64_dt = new py::dtype("complex64");
  static const auto* complex128_dt = new py::dtype("complex128");

  // (wide dtype, narrow dtype) pairs, checked in this order.
  const std::pair<const py::dtype*, const py::dtype*> narrowing[] = {
      {int64_dt, int32_dt},
      {float64_dt, float32_dt},
      {uint64_dt, uint32_dt},
      {complex128_dt, complex64_dt},
  };
  for (const auto& candidate : narrowing) {
    if (dtype.equal(*candidate.first)) {
      return candidate.second;
    }
  }
  return nullptr;
}
318 
319 // The equivalent of the Python jax/lazy.py::is_trivial:
320 // return (type(lexpr.input) is ArrayVar and
321 //         lexpr.dims == tuple(range(len(lexpr.shape))))
322 //
323 // Expects *only* `None` or a LazyExpr` object.
IsTrivialLazyExpr(py::handle lexpr)324 bool IsTrivialLazyExpr(py::handle lexpr) {
325   if (lexpr.is_none()) {
326     return true;
327   }
328 
329   static const auto* lazy_module =
330       new py::module(py::module::import("jax.lazy"));
331   auto input = py::getattr(lexpr, "input");
332   if (!input.get_type().is(lazy_module->attr("ArrayVar"))) {
333     return false;
334   }
335   py::tuple dims = py::cast<py::tuple>(lexpr.attr("dims"));
336   py::tuple shape = py::cast<py::tuple>(lexpr.attr("shape"));
337 
338   for (int i = 0; i < shape.size(); ++i) {
339     if (dims[i].is_none()) {
340       return false;
341     }
342     if (py::cast<int>(dims[i]) != i) {
343       return false;
344     }
345   }
346   return true;
347 }
348 
IsFloat0(py::array arg)349 bool IsFloat0(py::array arg) {
350   static const auto* dtypes_module =
351       new py::module(py::module::import("jax.dtypes"));
352   static const auto* float0_dtype =
353       new py::handle(dtypes_module->attr("float0"));
354   return float0_dtype->is(arg.attr("dtype"));
355 }
356 
357 template <typename CppType, typename Pybind11Type>
ConvertToScalarBuffer(const py::handle & scalar,xla::PjRtClient * client,xla::PjRtDevice * device)358 std::unique_ptr<xla::PjRtBuffer> ConvertToScalarBuffer(
359     const py::handle& scalar, xla::PjRtClient* client,
360     xla::PjRtDevice* device) {
361   CppType data = py::cast<Pybind11Type>(scalar);
362   // Work around for https://github.com/pybind/pybind11/issues/2786
363   if (PyErr_Occurred()) {
364     throw py::error_already_set();
365   }
366   xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<CppType>({});
367   return ValueOrThrow(client->BufferFromHostBuffer(
368       &data, shape,
369       xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr,
370       device));
371 }
372 
373 }  // namespace
374 
namespace {

// Type-erased handler computing the ArgSignature of a Python value. It is
// called with the value and the `jax_enable_x64` flag (see the call sites in
// ArgSignatureOfValue).
using ToArgSignatureHandler =
    std::function<xla::StatusOr<ArgSignature>(py::handle, bool)>;
}  // namespace
380 
ArgSignatureOfValue(pybind11::handle arg,bool jax_enable_x64)381 xla::StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
382                                                 bool jax_enable_x64) {
383   static const absl::flat_hash_map<PyObject*,
384                                    ToArgSignatureHandler>* const handlers = [] {
385     auto p = new absl::flat_hash_map<PyObject*, ToArgSignatureHandler>();
386 
387     const auto xla_module = py::module::import("jax.interpreters.xla");
388     const auto& device_array = xla_module.attr("_DeviceArray");
389 
390     const NumpyScalarTypes& dtypes = GetNumpyScalarTypes();
391 
392     // The 4 Python native types.
393     ToArgSignatureHandler bool_handler =
394         [](py::handle, bool) -> xla::StatusOr<ArgSignature> {
395       return ArgSignature(xla::PrimitiveType::PRED, {}, true);
396     };
397     ToArgSignatureHandler int_handler =
398         [](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
399       if (jax_enable_x64) {
400         return ArgSignature(xla::PrimitiveType::S64, {}, true);
401       } else {
402         return ArgSignature(xla::PrimitiveType::S32, {}, true);
403       }
404     };
405     ToArgSignatureHandler float_handler =
406         [&dtypes](py::handle h,
407                   bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
408       // Only Python native types has a True weak_type.
409       bool weak_type = !py::isinstance(h, dtypes.np_float64);
410       if (jax_enable_x64) {
411         return ArgSignature(xla::PrimitiveType::F64, {}, weak_type);
412       } else {
413         return ArgSignature(xla::PrimitiveType::F32, {}, weak_type);
414       }
415     };
416     ToArgSignatureHandler complex_handler =
417         [&dtypes](py::handle h,
418                   bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
419       // Note that this branch is also taken  for np.complex128:
420       // isinstance(np.complex128(3), complex) returns True
421       // isinstance(np.complex64(3), complex) returns False
422       bool weak_type = !py::isinstance(h, dtypes.np_complex128);
423       if (jax_enable_x64) {
424         return ArgSignature(xla::PrimitiveType::C128, {}, weak_type);
425       } else {
426         return ArgSignature(xla::PrimitiveType::C64, {}, weak_type);
427       }
428     };
429 
430     (*p)[reinterpret_cast<PyObject*>(&PyBool_Type)] = bool_handler;
431     (*p)[reinterpret_cast<PyObject*>(&PyLong_Type)] = int_handler;
432     (*p)[reinterpret_cast<PyObject*>(&PyFloat_Type)] = float_handler;
433     (*p)[reinterpret_cast<PyObject*>(&PyComplex_Type)] = complex_handler;
434 
435     // The Buffer types
436     // PyBuffer necessarily has a trivial LazyExpr, no need to check it.
437     ToArgSignatureHandler buffer_handler =
438         [](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
439       xla::PyBuffer* buffer = py::cast<xla::PyBuffer*>(h);
440       bool weak_type = py::cast<py::bool_>(h.attr("aval").attr("weak_type"));
441       return ArgSignature(buffer->buffer()->on_device_shape().element_type(),
442                           buffer->buffer()->on_device_shape().dimensions(),
443                           weak_type);
444     };
445     (*p)[py::type::handle_of<xla::DeviceArrayBase>().ptr()] = buffer_handler;
446     ToArgSignatureHandler device_array_handler =
447         [](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
448       py::handle aval = h.attr("aval");
449       TF_ASSIGN_OR_RETURN(auto dtype,
450                           xla::DtypeToPrimitiveType(aval.attr("dtype")));
451       return ArgSignature(dtype,
452                           py::cast<std::vector<xla::int64>>(aval.attr("shape")),
453                           py::cast<py::bool_>(aval.attr("weak_type")));
454     };
455     // ShardedDeviceArray is covered by the MRO fallback on _DeviceArray.
456     (*p)[device_array.ptr()] = device_array_handler;
457 
458     ToArgSignatureHandler numpy_handler =
459         [](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
460       py::array numpy_array = py::cast<py::array>(h);
461       if (IsFloat0(numpy_array)) {
462         return xla::InvalidArgument(
463             "float0 numpy arrays not supported in C++. "
464             "Falling back to Python.");
465       }
466       if (!jax_enable_x64) {
467         const py::dtype raw_dtype = numpy_array.dtype();
468         const py::dtype* to_dtype = DtypeTo32BitDtype(raw_dtype);
469 
470         xla::PrimitiveType dtype;
471         if (to_dtype) {
472           TF_ASSIGN_OR_RETURN(dtype, xla::DtypeToPrimitiveType(*to_dtype));
473         } else {
474           TF_ASSIGN_OR_RETURN(dtype, xla::DtypeToPrimitiveType(raw_dtype));
475         }
476         // We need the reinterpret_cast for the OSS version to build.
477         return ArgSignature(
478             dtype,
479             absl::MakeConstSpan(
480                 reinterpret_cast<const xla::int64*>(numpy_array.shape()),
481                 numpy_array.ndim()),
482             /*weak_type=*/false);
483       }
484       TF_ASSIGN_OR_RETURN(auto dtype,
485                           xla::DtypeToPrimitiveType(numpy_array.dtype()));
486       return ArgSignature(
487           dtype,
488           absl::MakeConstSpan(
489               reinterpret_cast<const xla::int64*>(numpy_array.shape()),
490               numpy_array.ndim()),
491           /*weak_type=*/false);
492     };
493     const auto numpy = py::module::import("numpy");
494     const auto& ndarray = numpy.attr("ndarray");
495     (*p)[ndarray.ptr()] = numpy_handler;
496 
497     ToArgSignatureHandler np_uint64_handler =
498         [](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
499       if (jax_enable_x64) {
500         return ArgSignature(xla::PrimitiveType::U64, {}, /*weak_type=*/false);
501       } else {
502         return ArgSignature(xla::PrimitiveType::U32, {}, /*weak_type=*/false);
503       }
504     };
505     ToArgSignatureHandler np_int_handler =
506         [](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
507       if (jax_enable_x64) {
508         return ArgSignature(xla::PrimitiveType::S64, {}, /*weak_type=*/false);
509       } else {
510         return ArgSignature(xla::PrimitiveType::S32, {}, /*weak_type=*/false);
511       }
512     };
513     ToArgSignatureHandler numpy_array_handler =
514         [](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
515       // This block deals with all numpy scalar types, except for int64_dt,
516       // float64_dt and complex128_dt which are taken care of in previous if
517       // blocks.
518       TF_ASSIGN_OR_RETURN(auto dtype,
519                           xla::DtypeToPrimitiveType(h.attr("dtype")));
520       return ArgSignature(dtype, {}, /*weak_type=*/false);
521     };
522 
523     // This block deals with all numpy scalar types, except for int64_dt,
524     // float64_dt and complex128_dt which are taken care of in previous if
525     // blocks.
526     (*p)[dtypes.np_bool.ptr()] = numpy_array_handler;
527     (*p)[dtypes.np_int8.ptr()] = numpy_array_handler;
528     (*p)[dtypes.np_int16.ptr()] = numpy_array_handler;
529     (*p)[dtypes.np_int32.ptr()] = numpy_array_handler;
530     (*p)[dtypes.np_int64.ptr()] = np_int_handler;
531     (*p)[dtypes.np_uint8.ptr()] = numpy_array_handler;
532     (*p)[dtypes.np_uint16.ptr()] = numpy_array_handler;
533     (*p)[dtypes.np_uint32.ptr()] = numpy_array_handler;
534     (*p)[dtypes.np_uint64.ptr()] = np_uint64_handler;
535     (*p)[dtypes.np_float16.ptr()] = numpy_array_handler;
536     (*p)[dtypes.np_float32.ptr()] = numpy_array_handler;
537     (*p)[dtypes.np_float64.ptr()] = float_handler;
538     (*p)[dtypes.np_complex64.ptr()] = numpy_array_handler;
539     (*p)[dtypes.np_complex128.ptr()] = complex_handler;
540     (*p)[dtypes.np_longlong.ptr()] = np_int_handler;
541     (*p)[dtypes.np_intc.ptr()] = numpy_array_handler;
542 
543     return p;
544   }();
545 
546   auto res = handlers->find(arg.get_type().ptr());
547   if (res == handlers->end()) {
548     // We attempt to look at the MRO classes
549     for (auto base_class : arg.get_type().attr("mro")()) {
550       res = handlers->find(base_class.ptr());
551       if (res != handlers->end()) {
552         return res->second(arg, jax_enable_x64);
553       }
554     }
555     return xla::InvalidArgument(
556         "%s", absl::StrCat("Not supported: The C++ ToArgSignature only accepts "
557                            "Buffer/DeviceArray/ShardedDeviceArray, Numpy "
558                            "arrays scalars of supported types "
559                            "(see implementation), or Python scalars. Got type ",
560                            py::cast<std::string>(py::str(arg.get_type()))));
561   } else {
562     return res->second(arg, jax_enable_x64);
563   }
564 }
565 
566 namespace {
// Type-erased handler with the same parameter list as the Handle* functions
// below: (Python value, destination device, jax_enable_x64 flag, client).
// It returns the resulting DevicePutResult or an error status.
using DevicePutFunc = std::function<xla::StatusOr<DevicePutResult>(
    py::handle, xla::PjRtDevice*, bool, xla::PyClient&)>;
569 
HandleBool(py::handle h,xla::PjRtDevice * to_device,bool jax_enable_x64,xla::PyClient & pyclient)570 DevicePutResult HandleBool(py::handle h, xla::PjRtDevice* to_device,
571                            bool jax_enable_x64, xla::PyClient& pyclient) {
572   return DevicePutResult(ConvertToScalarBuffer<bool, py::bool_>(
573                              h, pyclient.pjrt_client(), to_device),
574                          /*weak_type=*/true);
575 }
576 
HandleInt(py::handle obj,xla::PjRtDevice * to_device,bool jax_enable_x64,xla::PyClient & pyclient)577 DevicePutResult HandleInt(py::handle obj, xla::PjRtDevice* to_device,
578                           bool jax_enable_x64, xla::PyClient& pyclient) {
579   if (jax_enable_x64) {
580     return DevicePutResult(ConvertToScalarBuffer<xla::int64, py::int_>(
581                                obj, pyclient.pjrt_client(), to_device),
582                            /*weak_type=*/true);
583   } else {
584     return DevicePutResult(ConvertToScalarBuffer<int, py::int_>(
585                                obj, pyclient.pjrt_client(), to_device),
586                            /*weak_type=*/true);
587   }
588 }
589 
590 template <bool weak_type>
HandleFloat(py::handle h,xla::PjRtDevice * to_device,bool jax_enable_x64,xla::PyClient & pyclient)591 xla::StatusOr<DevicePutResult> HandleFloat(py::handle h,
592                                            xla::PjRtDevice* to_device,
593                                            bool jax_enable_x64,
594                                            xla::PyClient& pyclient) {
595   if (jax_enable_x64) {
596     return DevicePutResult(ConvertToScalarBuffer<double, py::float_>(
597                                h, pyclient.pjrt_client(), to_device),
598                            /*weak_type=*/weak_type);
599   } else {
600     return DevicePutResult(ConvertToScalarBuffer<float, py::float_>(
601                                h, pyclient.pjrt_client(), to_device),
602                            /*weak_type=*/weak_type);
603   }
604 }
605 
606 template <bool weak_type>
HandleComplex(py::handle h,xla::PjRtDevice * to_device,bool jax_enable_x64,xla::PyClient & pyclient)607 xla::StatusOr<DevicePutResult> HandleComplex(py::handle h,
608                                              xla::PjRtDevice* to_device,
609                                              bool jax_enable_x64,
610                                              xla::PyClient& pyclient) {
611   // This branch is also taken  for np.complex128:
612   // isinstance(np.complex128(3), complex) returns True
613   // isinstance(np.complex64(3), complex) returns False
614   Py_complex result = PyComplex_AsCComplex(h.ptr());
615   if (result.real == -1.0 && PyErr_Occurred()) {
616     PyErr_Clear();
617     throw std::runtime_error("Could not convert the complex number");
618   }
619   if (jax_enable_x64) {
620     xla::complex128 data(result.real, result.imag);
621     xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex128>({});
622     return DevicePutResult(
623         ValueOrThrow(pyclient.pjrt_client()->BufferFromHostBuffer(
624             &data, shape,
625             xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
626             nullptr, to_device)),
627         /*weak_type=*/weak_type);
628   } else {
629     xla::complex64 data(result.real, result.imag);
630     xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex64>({});
631     return DevicePutResult(
632         ValueOrThrow(pyclient.pjrt_client()->BufferFromHostBuffer(
633             &data, shape,
634             xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
635             nullptr, to_device)),
636         /*weak_type=*/weak_type);
637   }
638 }
639 
HandleDeviceArray(py::handle obj,xla::PjRtDevice * to_device,bool jax_enable_x64,xla::PyClient & pyclient)640 xla::StatusOr<DevicePutResult> HandleDeviceArray(py::handle obj,
641                                                  xla::PjRtDevice* to_device,
642                                                  bool jax_enable_x64,
643                                                  xla::PyClient& pyclient) {
644   if (!IsTrivialLazyExpr(py::getattr(obj, "_lazy_expr"))) {
645     return xla::InvalidArgument(
646         "Non-trivial lazy expression not supported in C++. "
647         "Falling back to Python.");
648   }
649   xla::PyBuffer* buffer = py::cast<xla::PyBuffer*>(obj.attr("device_buffer"));
650   bool weak_type = py::cast<py::bool_>(obj.attr("aval").attr("weak_type"));
651   // Same block as in the previous `if (is_py_buffer)`.
652   if (buffer->device().contents == to_device) {
653     return DevicePutResult(buffer->buffer(), weak_type);
654   } else {
655     std::unique_ptr<xla::PjRtBuffer> copied_buffer =
656         ValueOrThrow(buffer->buffer()->CopyToDevice(to_device));
657     return DevicePutResult(std::move(copied_buffer), weak_type);
658   }
659 }
660 
661 // Do not convert types, and only call PjRtBufferFromPyval, independently
662 // of the value of jax_enable_x64.
HandleBufferFromPyval(py::handle h,xla::PjRtDevice * to_device,bool jax_enable_x64,xla::PyClient & pyclient)663 DevicePutResult HandleBufferFromPyval(py::handle h, xla::PjRtDevice* to_device,
664                                       bool jax_enable_x64,
665                                       xla::PyClient& pyclient) {
666   std::unique_ptr<xla::PjRtBuffer> buffer =
667       ValueOrThrow(pyclient.PjRtBufferFromPyval(
668           h, to_device,
669           /*force_copy=*/false, /*host_buffer_semantics=*/
670           xla::PjRtClient::HostBufferSemantics::kZeroCopy));
671   return DevicePutResult(std::move(buffer), /*weak_type=*/false);
672 }
673 
HandleNpBool(py::handle h,xla::PjRtDevice * to_device,bool jax_enable_x64,xla::PyClient & pyclient)674 DevicePutResult HandleNpBool(py::handle h, xla::PjRtDevice* to_device,
675                              bool jax_enable_x64, xla::PyClient& pyclient) {
676   if (jax_enable_x64) {
677     return DevicePutResult(ConvertToScalarBuffer<xla::int64, py::int_>(
678                                h, pyclient.pjrt_client(), to_device),
679                            /*weak_type=*/false);
680   } else {
681     return DevicePutResult(ConvertToScalarBuffer<int, py::int_>(
682                                h, pyclient.pjrt_client(), to_device),
683                            /*weak_type=*/false);
684   }
685 }
686 
HandleUint64(py::handle h,xla::PjRtDevice * to_device,bool jax_enable_x64,xla::PyClient & pyclient)687 DevicePutResult HandleUint64(py::handle h, xla::PjRtDevice* to_device,
688                              bool jax_enable_x64, xla::PyClient& pyclient) {
689   if (jax_enable_x64) {
690     std::unique_ptr<xla::PjRtBuffer> buffer =
691         ValueOrThrow(pyclient.PjRtBufferFromPyval(
692             h, to_device,
693             /*force_copy=*/false, /*host_buffer_semantics=*/
694             xla::PjRtClient::HostBufferSemantics::kZeroCopy));
695     return DevicePutResult(std::move(buffer), /*weak_type=*/false);
696   } else {
697     static const auto* numpy = new py::module(py::module::import("numpy"));
698     const auto& np_array = numpy->attr("array");
699 
700     // Note that this is calling back to Python!
701     std::unique_ptr<xla::PjRtBuffer> buffer =
702         ValueOrThrow(pyclient.PjRtBufferFromPyval(
703             np_array(h, py::dtype("uint32")), to_device,
704             /*force_copy=*/false, /*host_buffer_semantics=*/
705             xla::PjRtClient::HostBufferSemantics::kZeroCopy));
706     return DevicePutResult(std::move(buffer), /*weak_type=*/false);
707   }
708 }
709 
HandleNdarray(py::handle h,xla::PjRtDevice * to_device,bool jax_enable_x64,xla::PyClient & pyclient)710 xla::StatusOr<DevicePutResult> HandleNdarray(py::handle h,
711                                              xla::PjRtDevice* to_device,
712                                              bool jax_enable_x64,
713                                              xla::PyClient& pyclient) {
714   py::array numpy_array = py::cast<py::array>(h);
715   if (IsFloat0(numpy_array)) {
716     return xla::InvalidArgument("%s",
717                                 "float0 numpy arrays not supported in C++. "
718                                 "Falling back to Python.");
719   }
720   // If jax_enable_x64 is not set, we need to coerce 32 bits types.
721   // Note that this is calling back to Python!
722   if (!jax_enable_x64) {
723     const py::dtype* to_dtype = DtypeTo32BitDtype(numpy_array.dtype());
724     if (to_dtype) {
725       static const auto* numpy = new py::module(py::module::import("numpy"));
726       const auto& np_array = numpy->attr("array");
727       numpy_array = np_array(numpy_array, *to_dtype);
728     }
729   }
730   std::unique_ptr<xla::PjRtBuffer> buffer =
731       ValueOrThrow(pyclient.PjRtBufferFromPyval(
732           numpy_array, to_device,
733           /*force_copy=*/false, /*host_buffer_semantics=*/
734           xla::PjRtClient::HostBufferSemantics::kZeroCopy));
735   return DevicePutResult(std::move(buffer), /*weak_type=*/false);
736 }
737 
738 }  // namespace
739 
DevicePut(pybind11::handle arg,xla::PjRtDevice * to_device,bool jax_enable_x64,xla::PyClient & pyclient)740 xla::StatusOr<DevicePutResult> DevicePut(pybind11::handle arg,
741                                          xla::PjRtDevice* to_device,
742                                          bool jax_enable_x64,
743                                          xla::PyClient& pyclient) {
744   static const absl::flat_hash_map<PyObject*, DevicePutFunc>* const handlers =
745       [] {
746         auto p = new absl::flat_hash_map<PyObject*, DevicePutFunc>();
747 
748         const NumpyScalarTypes& dtypes = GetNumpyScalarTypes();
749 
750         const auto numpy = py::module::import("numpy");
751         const auto xla_module = py::module::import("jax.interpreters.xla");
752         const auto& device_array = xla_module.attr("_DeviceArray");
753 
754         // Python base types.
755         (*p)[reinterpret_cast<PyObject*>(&PyBool_Type)] = HandleBool;
756         (*p)[reinterpret_cast<PyObject*>(&PyLong_Type)] = HandleInt;
757         (*p)[reinterpret_cast<PyObject*>(&PyFloat_Type)] = HandleFloat<true>;
758         (*p)[reinterpret_cast<PyObject*>(&PyComplex_Type)] =
759             HandleComplex<true>;
760 
761         // DeviceArray and co.
762         const auto pxla_module = py::module::import("jax.interpreters.pxla");
763         const auto& sda = pxla_module.attr("ShardedDeviceArray");
764         (*p)[device_array.ptr()] = HandleDeviceArray;
765         (*p)[py::type::handle_of<xla::DeviceArrayBase>().ptr()] =
766             HandleDeviceArray;
767         (*p)[sda.ptr()] = HandleBufferFromPyval;
768         // Numpy arrays.
769         (*p)[numpy.attr("ndarray").ptr()] = HandleNdarray;
770 
771         // Numpy scalar types. For some of them, we share the handler with
772         // Python types (np_int64, np_float64, np_complex128).
773         (*p)[dtypes.np_bool.ptr()] = HandleBufferFromPyval;
774         (*p)[dtypes.np_int8.ptr()] = HandleBufferFromPyval;
775         (*p)[dtypes.np_int16.ptr()] = HandleBufferFromPyval;
776         (*p)[dtypes.np_int32.ptr()] = HandleBufferFromPyval;
777         (*p)[dtypes.np_int64.ptr()] = HandleNpBool;
778         (*p)[dtypes.np_uint8.ptr()] = HandleBufferFromPyval;
779         (*p)[dtypes.np_uint16.ptr()] = HandleBufferFromPyval;
780         (*p)[dtypes.np_uint32.ptr()] = HandleBufferFromPyval;
781         (*p)[dtypes.np_uint64.ptr()] = HandleUint64;
782         (*p)[dtypes.np_float16.ptr()] = HandleBufferFromPyval;
783         (*p)[dtypes.np_float32.ptr()] = HandleBufferFromPyval;
784         (*p)[dtypes.np_float64.ptr()] = HandleFloat<false>;
785         (*p)[dtypes.np_complex64.ptr()] = HandleBufferFromPyval;
786         (*p)[dtypes.np_complex128.ptr()] = HandleComplex<false>;
787         (*p)[dtypes.np_longlong.ptr()] = HandleNpBool;
788         (*p)[dtypes.np_intc.ptr()] = HandleBufferFromPyval;
789 
790         return p;
791       }();
792 
793   auto res = handlers->find(arg.get_type().ptr());
794   if (res == handlers->end()) {
795     for (auto base_class : arg.get_type().attr("mro")()) {
796       res = handlers->find(base_class.ptr());
797       if (res != handlers->end()) {
798         return res->second(arg, to_device, jax_enable_x64, pyclient);
799       }
800     }
801     return xla::InvalidArgument(
802         "%s", absl::StrCat(
803                   "Not supported: The C++ jax jit execution path, only accepts "
804                   "DeviceArray, Numpy arrays scalars of supported types "
805                   "(see implementation), or Python scalars. Got type ",
806                   py::cast<std::string>(py::str(arg.get_type()))));
807   } else {
808     return res->second(arg, to_device, jax_enable_x64, pyclient);
809   }
810 }
811 
812 namespace {
813 
// The data cached for one call signature of a jitted function: the compiled
// executable plus what is needed to rebuild the Python outputs.
struct CacheEntry {
  // The compiled executable to run for this signature.
  std::shared_ptr<xla::PyExecutable> executable;
  // Tree structure used to unflatten the flat list of outputs.
  xla::PyTreeDef out_pytree_def;
  // We use Python types within the vector because this is what we will be
  // returning to Python. No need to convert back and forth.
  // We need py::object to keep the objects alive.
  std::vector<py::object> out_avals;
  // One element per output. `AddCacheEntry` CHECKs that each element is either
  // `py::none()` or a non-trivial LazyExpr.
  std::vector<py::object> out_lazy_exprs;
  // Passed as the sticky device of every output built from this entry.
  py::object sticky_device;

  // Ensures a single thread performs the compilation for a given executable.
  //
  // The first thread (holding the GIL) will create the CacheEntry associated to
  // a signature and if the object has been inserted already, other threads
  // will wait for the notification.
  absl::Notification compilation_complete;
  // When set, looking up this entry throws with the stored error message.
  absl::optional<xla::Status> compilation_error = absl::nullopt;
  // Trivial computation will fallback to Python.
  // Running a jax(pmap) will also fallback to Python.
  bool fall_back_to_python = false;
};
837 
838 // A `CompiledFunction` is associated to a `jax.jit(f)` and takes care of the
839 // bookkeeping of the different signatures used and the dispatch of calls to
840 // the correct underlying `PyExecutable`. This class is thread-safe.
841 class CompiledFunction {
842  public:
843   CompiledFunction(py::function fun, py::function cache_miss,
844                    py::function get_device, py::function get_jax_enable_x64,
845                    py::function get_jax_disable_jit,
846                    std::vector<int> static_argnums);
847   ~CompiledFunction();
848 
849   // This function will:
850   // (a) flatten the inputs using pytree
851   // (b) get buffer objects from the arguments
852   // (c) call the executable
853   // (d) construct `DeviceArray` objects from the outputs
854   // (e) reconstruct the `PyTree`.
855   py::object Call(py::args args, py::kwargs kwargs);
856 
857   // This allows `inspect.signature(cpp_jitted_f)` from Python.
PythonSignature()858   py::object PythonSignature() {
859     static const auto* inspect = new py::module(py::module::import("inspect"));
860     return inspect->attr("signature")(fun_);
861   }
862 
cache_size() const863   int cache_size() const { return executables_.size(); }
864 
865  private:
866   // Returns nullptr if not present in the cache.
867   CacheEntry* GetCacheEntryIfPresent(const CallSignature& signature);
868   // Should never return nullptr.
869   CacheEntry* AddCacheEntry(const py::args& args, const py::kwargs& kwargs,
870                             const CallSignature& signature,
871                             py::object out_and_fastpath_data);
JitIsDisabled()872   bool JitIsDisabled() { return GetDisableJit() || jax_disable_jit_.value(); }
873 
874   bool always_fallback_to_python_ = false;
875 
876   const py::function fun_;  // The Python function to jit.
877   // See JAX _cpp_jit in api.py for documentation.
878   const py::function cache_miss_;
879 
880   // We need to know the static arguments to remove them from the arguments
881   // passed to the underlying PyExecutable. In sorted order.
882   std::vector<int> static_argnums_;
883   // We need a `unique_ptr` here to ensure value pointer stability.
884   absl::flat_hash_map<CallSignature, std::unique_ptr<CacheEntry>> executables_;
885 
886   // As top-level functions are decorated with `jax.jit`, when
887   // `CompiledFunction` is being instantiated from Python, the clients are not
888   // yet available (done after GoogleInit). They will be during the first call
889   // to `Call`.
890   // A function taking no arguments and returning the default device and whether
891   // jax.jit has been committed to it.
892   const py::function get_jax_enable_x64_;
893   const py::function get_jax_disable_jit_;
894   const py::function get_device_;
895 
896   // The writing of the following is protected by the mutex.
897   absl::Mutex mu_;
898   // The value of the Python flag. The value will be computed only during the
899   // first object call, because GoogleInit must have been executed.
900   absl::optional<bool> jax_enable_x64_ = absl::nullopt;
901   absl::optional<bool> jax_disable_jit_ = absl::nullopt;
902 
903   // The logic if the following:
904   // - if `device` or `backend` are not specified to `jax.jit`, we will use
905   //   the input sticky buffer device, or `default_device_` if there is no
906   //   such sticky buffer.
907   // - When one of `device` or `backend` is specified, this will determine
908   //   the `default_device_` which will be used as the targeted device. In
909   //   which case, we will always copy input buffers to this device.
910   std::shared_ptr<xla::PyClient> default_pyclient_ = nullptr;
911   xla::ClientAndPtr<xla::PjRtDevice> default_pydevice_;
912   xla::PjRtDevice* default_device_ = nullptr;
913   bool is_committed_;
914 };
915 
CompiledFunction(py::function fun,py::function cache_miss,py::function get_device,py::function get_jax_enable_x64,py::function get_jax_disable_jit,std::vector<int> static_argnums)916 CompiledFunction::CompiledFunction(py::function fun, py::function cache_miss,
917                                    py::function get_device,
918                                    py::function get_jax_enable_x64,
919                                    py::function get_jax_disable_jit,
920                                    std::vector<int> static_argnums)
921     : fun_(std::move(fun)),
922       cache_miss_(std::move(cache_miss)),
923       static_argnums_(std::move(static_argnums)),
924       get_jax_enable_x64_(get_jax_enable_x64),
925       get_jax_disable_jit_(get_jax_disable_jit),
926       get_device_(std::move(get_device)) {
927   std::sort(static_argnums_.begin(), static_argnums_.end());
928 }
929 
~CompiledFunction()930 CompiledFunction::~CompiledFunction() {
931   for (const auto& entry : executables_) {
932     entry.first.DecRef();
933   }
934 }
935 
936 // Converts flattened arguments contained in ParsedArgumentsAsBuffers in
937 // place. If arguments are `DeviceArray`, they must all be on the same `Device`.
938 //
939 // Returns `Okxla::Status()` on success. Returning an error should lead to
940 // calling the Python fallback.
ConvertArgsToBuffers(bool jax_enable_x64,xla::PyClient & pyclient,xla::PjRtDevice * default_device,bool is_committed,ParsedArgumentsAsBuffers & arguments)941 xla::Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
942                                  xla::PjRtDevice* default_device,
943                                  bool is_committed,
944                                  ParsedArgumentsAsBuffers& arguments) {
945   std::vector<xla::PjRtBuffer*>& arg_buffers = arguments.arg_buffers;
946   auto& keep_alive = arguments.keep_alive;
947 
948   int num_flat_dynamic_args = arguments.flat_dynamic_args.size();
949   arg_buffers.reserve(num_flat_dynamic_args);
950   arguments.signature.dynamic_args_signatures.reserve(num_flat_dynamic_args);
951 
952   static const auto* xla_module =
953       new py::module(py::module::import("jax.interpreters.xla"));
954   const auto& device_array = xla_module->attr("_DeviceArray");
955 
956   // When the jitted function is not committed, we first check whether any
957   // sticky `DeviceArray` is present and on which device they live. See also:
958   // https://github.com/google/jax/pull/1884
959   // https://github.com/google/jax/pull/1916 for the rationale why the
960   // computation follows the data locality.
961   // It's also similar to PyTorch's behavior.
962   xla::PjRtDevice* data_device = nullptr;
963   if (is_committed) {
964     data_device = default_device;
965   } else {
966     for (py::handle arg : arguments.flat_dynamic_args) {
967       // We specically only deal with DeviceArray (not ShardedDeviceArray).
968       // (Can happen in jit(pmap), e.g. "test_jit_nested_donate_ignored").
969       if (py::isinstance<xla::PyBuffer>(arg) ||
970           arg.get_type().is(device_array)) {
971         xla::PyBuffer* buffer;
972         if (arg.attr("_device").is_none()) {  // Skip non-sticky devices.
973           continue;
974         }
975         try {
976           // This can fail, e.g. when device_buffer is a `DeviceConstant`.
977           buffer = py::cast<xla::PyBuffer*>(arg.attr("device_buffer"));
978         } catch (const py::cast_error& e) {
979           return xla::InvalidArgument(
980               "%s",
981               absl::StrCat("[jaxjit] Unsupported subclass of `DeviceArray`: "
982                            "`device_buffer` field is of type ",
983                            py::cast<std::string>(
984                                arg.attr("device_buffer").get_type().str()),
985                            " while a `PyBuffer` was expected."
986 
987                            ));
988         }
989         xla::PjRtDevice* device = buffer->buffer()->device();
990         if (data_device && (device != data_device)) {
991           throw std::invalid_argument(absl::StrCat(
992               "primitive arguments must be colocated on the same device ("
993               "C++ jax.jit). Arguments are on devices: ",
994               device->DebugString(), " and ", data_device->DebugString()));
995         } else {
996           data_device = device;
997         }
998       }
999     }
1000   }
1001   if (!data_device) {
1002     // No `DeviceArray` were found default to `default_device`.
1003     data_device = default_device;
1004   }
1005   CHECK(data_device);
1006   arguments.signature.device = data_device;
1007 
1008   for (py::handle arg : arguments.flat_dynamic_args) {
1009     TF_ASSIGN_OR_RETURN(DevicePutResult on_device,
1010                         DevicePut(arg, data_device, jax_enable_x64, pyclient));
1011 
1012     xla::PjRtBuffer* buffer = on_device.buffer;
1013     arg_buffers.push_back(buffer);
1014     if (on_device.owned_buffer) {
1015       keep_alive.emplace_back(std::move(on_device.owned_buffer));
1016     }
1017 
1018     ArgSignature sig(buffer->on_device_shape().element_type(),
1019                      buffer->on_device_shape().dimensions(),
1020                      on_device.weak_type);
1021     arguments.signature.dynamic_args_signatures.push_back(std::move(sig));
1022   }
1023   return xla::Status::OK();
1024 }
1025 
GetCacheEntryIfPresent(const CallSignature & signature)1026 CacheEntry* CompiledFunction::GetCacheEntryIfPresent(
1027     const CallSignature& signature) {
1028   auto found_iterator = executables_.find(signature);
1029   if (found_iterator != executables_.end()) {  // Cache hit!
1030     if (!found_iterator->second->compilation_complete.HasBeenNotified()) {
1031       py::gil_scoped_release gil_release;
1032       found_iterator->second->compilation_complete.WaitForNotification();
1033     }
1034     if (found_iterator->second->compilation_error) {
1035       throw std::invalid_argument(
1036           found_iterator->second->compilation_error.value().error_message());
1037     }
1038     return found_iterator->second.get();
1039   }
1040   return nullptr;
1041 }
1042 
AddCacheEntry(const py::args & args,const py::kwargs & kwargs,const CallSignature & signature,py::object out_and_fastpath_data)1043 CacheEntry* CompiledFunction::AddCacheEntry(const py::args& args,
1044                                             const py::kwargs& kwargs,
1045                                             const CallSignature& signature,
1046                                             py::object out_and_fastpath_data) {
1047   // We need to insert the element.
1048   auto result = executables_.emplace(signature, std::make_unique<CacheEntry>());
1049   auto it = result.first;
1050   CacheEntry* cache_entry = it->second.get();
1051   // CallSignatures in the cache own their keyword argument reference.
1052   result.first->first.IncRef();
1053 
1054   py::tuple tuple = py::cast<py::tuple>(out_and_fastpath_data);
1055   CHECK_EQ(tuple.size(), 2);
1056   if (tuple[1].is_none()) {
1057     cache_entry->fall_back_to_python = true;
1058     cache_entry->compilation_complete.Notify();
1059     return cache_entry;
1060   }
1061 
1062   py::tuple executable_handlers_out_tree = py::cast<py::tuple>(tuple[1]);
1063   if (executable_handlers_out_tree.size() != 5) {
1064     throw std::runtime_error(absl::StrCat(
1065         "The versions of jaxlib and Jax are incompatible (jaxlib is too recent "
1066         "compared to Jax. Upgrade Jax is advised. The C++ code expects "
1067         "5 arguments but ",
1068         executable_handlers_out_tree.size(), " where provided: ",
1069         py::cast<std::string>(
1070             py::str(py::repr(executable_handlers_out_tree)))));
1071   }
1072   // (xla_executable, out_pytree_def, sticky_device, avals, lazy_exprs)
1073   auto executable = py::cast<std::shared_ptr<xla::PyExecutable>>(
1074       executable_handlers_out_tree[0]);
1075   cache_entry->executable = std::move(executable);
1076   int num_devices =
1077       cache_entry->executable->pjrt_executable().addressable_devices().size();
1078   // The presence of jit(pmap) is detected from Python.
1079   CHECK_EQ(num_devices, 1);
1080 
1081   auto out_tree = py::cast<xla::PyTreeDef>(executable_handlers_out_tree[1]);
1082   cache_entry->out_pytree_def = std::move(out_tree);
1083 
1084   cache_entry->sticky_device =
1085       py::cast<py::object>(executable_handlers_out_tree[2]);
1086   auto avals = py::cast<py::list>(executable_handlers_out_tree[3]);
1087   auto lazy_exprs = py::cast<py::list>(executable_handlers_out_tree[4]);
1088   CHECK_EQ(avals.size(), lazy_exprs.size());
1089 
1090   cache_entry->out_avals.reserve(avals.size());
1091   cache_entry->out_lazy_exprs.reserve(avals.size());
1092   for (int i = 0; i < avals.size(); ++i) {
1093     py::object shaped_array = py::reinterpret_borrow<py::object>(avals[i]);
1094     py::object lazy_expr = py::reinterpret_borrow<py::object>(lazy_exprs[i]);
1095 
1096     cache_entry->out_avals.push_back(shaped_array);
1097     CHECK(lazy_expr.is_none() || !IsTrivialLazyExpr(lazy_expr));
1098     cache_entry->out_lazy_exprs.push_back(lazy_expr);
1099   }
1100 
1101   cache_entry->compilation_complete.Notify();
1102   return cache_entry;
1103 }
1104 
// Entry point bound to Python's `__call__`.
//
// Runs the cached executable matching the call signature when there is one.
// Otherwise the call goes through `cache_miss_`, whose result is a tuple with
// the outputs as first element; on a cache miss that result is also used to
// populate the cache. When jit is disabled, `fun_` is called directly.
py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
  // Set below when the default device cannot be used from C++.
  if (always_fallback_to_python_) {
    return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
  }
  // Delayed values are retrieved on the first call to `Call`.
  if (!default_device_) {
    // As we are calling Python code, that may release the GIL, we first hold
    // mu_ before holding the GIL.
    py::gil_scoped_release gil_release;
    {
      absl::MutexLock lock1(&mu_);
      py::gil_scoped_acquire gil_aquire;

      jax_enable_x64_ = py::cast<bool>(get_jax_enable_x64_());
      jax_disable_jit_ = py::cast<bool>(get_jax_disable_jit_());
      // Checked again: another thread may have completed the initialization
      // while this one was waiting for `mu_`.
      if (!default_device_) {
        py::object device_and_is_committed = get_device_();
        try {
          default_pydevice_ = py::cast<xla::ClientAndPtr<xla::PjRtDevice>>(
              device_and_is_committed.attr("default_device"));
        } catch (const py::cast_error& e) {
          // Pathways and Cloud TPU 2VM runtime.
          always_fallback_to_python_ = true;
          return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
        }
        default_pyclient_ = default_pydevice_.client;
        default_device_ = default_pydevice_.contents;
        if (!default_device_) {  // UPTC
          always_fallback_to_python_ = true;
          return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
        }
        is_committed_ =
            py::cast<bool>(device_and_is_committed.attr("committed_to_device"));
      }
    }
  }
  CHECK(default_device_);
  // With jit disabled, the original Python function is executed as is.
  if (JitIsDisabled()) {
    return fun_(*args, **kwargs);
  }
  // Arguments that cannot be parsed are handled by the Python path.
  ParsedArgumentsAsBuffers arguments;
  if (!ParseArguments(args, kwargs, static_argnums_, arguments).ok()) {
    return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
  }

  // The C++ jit does not support Tracer arguments yet. The Python-based
  // jit function will be called if any of the dynamic arguments is unsupported.
  if (!ConvertArgsToBuffers(jax_enable_x64_.value(), *default_pyclient_,
                            default_device_, is_committed_, arguments)
           .ok()) {
    return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
  }

  CacheEntry* cache_entry = GetCacheEntryIfPresent(arguments.signature);

  if (!cache_entry) {
    py::object out_and_fastpath_data = cache_miss_(*args, **kwargs);
    // Look up again: `cache_miss_` runs Python code, during which an entry
    // for this signature may have been added.
    cache_entry = GetCacheEntryIfPresent(arguments.signature);
    if (!cache_entry) {
      cache_entry = AddCacheEntry(args, kwargs, arguments.signature,
                                  out_and_fastpath_data);
    }
    CHECK(cache_entry);
    // NOTE(review): both paths below return the same value; the branch is
    // redundant.
    if (cache_entry->fall_back_to_python) {
      return py::cast<py::tuple>(out_and_fastpath_data)[0];
    }
    // As we have already computed the results, we can return it.
    // It's even *required* e.g. if there are donated arguments, because
    // otherwise the buffer which has been donated already will be invalid.
    return py::cast<py::tuple>(out_and_fastpath_data)[0];
  }
  CHECK(cache_entry);
  // Cache hit on an entry that must always run through Python.
  if (cache_entry->fall_back_to_python) {
    return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
  }
  // Fast path: run the cached executable on the converted buffers.
  std::vector<std::unique_ptr<xla::PyBuffer>> outputs =
      ValueOrThrow(cache_entry->executable->PjRtExecute(arguments.arg_buffers));

  const std::vector<py::object>& out_avals = cache_entry->out_avals;
  const std::vector<py::object>& out_lazy_exprs = cache_entry->out_lazy_exprs;
  const py::object& sticky_device = cache_entry->sticky_device;

  // Wrap each output buffer into the Python object handed back to the caller.
  py::list flat_device_arrays;
  for (int i = 0; i < outputs.size(); ++i) {
    auto& buffer = outputs[i];
    if (out_lazy_exprs[i].is_none()) {  // No LazyExpr.
      // The buffer itself is returned, annotated with its aval and device.
      buffer->SetAval(out_avals[i]);
      buffer->SetStickyDevice(sticky_device);
      flat_device_arrays.append(py::cast(std::move(outputs[i])));
    } else {
      // A LazyExpr is attached: build a Python `_DeviceArray` around the
      // buffer.
      static const auto* xla_module =
          new py::module(py::module::import("jax.interpreters.xla"));
      static const auto* device_array =
          new py::handle(xla_module->attr("_DeviceArray"));
      flat_device_arrays.append(
          (*device_array)(out_avals[i], sticky_device, out_lazy_exprs[i],
                          py::cast(std::move(outputs[i]))));
    }
  }
  return cache_entry->out_pytree_def.Unflatten(flat_device_arrays);
}
1206 
1207 }  // namespace
1208 
BuildJaxjitSubmodule(pybind11::module & m)1209 void BuildJaxjitSubmodule(pybind11::module& m) {
1210   py::module jitlib = m.def_submodule("jax_jit", "Jax C++ jit library");
1211 
1212   py::class_<CompiledFunction, std::unique_ptr<CompiledFunction>> cfun(
1213       jitlib, "CompiledFunction");
1214   cfun.def("__call__", &CompiledFunction::Call);
1215   cfun.def_property_readonly("__signature__",
1216                              &CompiledFunction::PythonSignature);
1217 
1218   jitlib.def("set_disable_jit", &SetDisableJit);
1219   jitlib.def("get_disable_jit", &GetDisableJit);
1220   jitlib.def(
1221       "jit",
1222       [](py::function fun, py::function cache_miss, py::function get_device,
1223          py::function get_jax_enable_x64, py::function get_jax_disable_jit,
1224          std::vector<int> static_argnums) -> std::unique_ptr<CompiledFunction> {
1225         return std::make_unique<CompiledFunction>(
1226             std::move(fun), std::move(cache_miss), std::move(get_device),
1227             std::move(get_jax_enable_x64), std::move(get_jax_disable_jit),
1228             std::move(static_argnums));
1229       });
1230 
1231   // This function is yet a full replacement for the Python one, because:
1232   // (a) it does not support abstract types,
1233   // (b) it does not set the device stickiness yet.
1234   // TODO(jblespiau): Finish the replacement of the Python feature.
1235   jitlib.def("device_put", [](py::handle obj, bool jax_enable_x64,
1236                               xla::ClientAndPtr<xla::PjRtDevice> to_device) {
1237     std::shared_ptr<xla::PyClient>& pyclient = to_device.client;
1238     xla::StatusOr<DevicePutResult> results =
1239         DevicePut(obj, to_device.contents, jax_enable_x64, *pyclient);
1240     if (!results.ok()) {
1241       throw std::runtime_error(results.status().error_message());
1242     }
1243     if (results->owned_buffer) {
1244       auto buffer = std::make_unique<xla::PyBuffer>(
1245           pyclient, std::move(results->owned_buffer), xla::Traceback::Get());
1246 
1247       static const auto* jax_core =
1248           new py::module(py::module::import("jax.core"));
1249       static const auto* shaped_array =
1250           new py::handle(jax_core->attr("ShapedArray"));
1251       buffer->SetAval((*shaped_array)(
1252           buffer->python_shape(), buffer->python_dtype(), results->weak_type));
1253       buffer->SetStickyDevice(py::none());
1254 
1255       return py::cast(std::move(buffer));
1256     } else {
1257       return py::cast<py::object>(obj);
1258     }
1259   });
1260 
1261   py::class_<ArgSignature> arg_signature(jitlib, "ArgSignature");
1262   arg_signature
1263       .def_property_readonly("dtype",
1264                              [](const ArgSignature& sig) {
1265                                return PrimitiveTypeToDtype(sig.dtype);
1266                              })
1267       .def_property_readonly("shape",
1268                              [](const ArgSignature& sig) {
1269                                return xla::IntSpanToTuple(sig.shape);
1270                              })
1271       .def_readonly("weak_type", &ArgSignature::weak_type);
1272   jitlib.def("_ArgSignatureOfValue", &ArgSignatureOfValue);
1273 
1274   // All private members are only for testing purposes
1275   cfun.def("_cache_size", &CompiledFunction::cache_size);
1276   jitlib.def("_DtypeTo32BitDtype", [](const py::object obj) -> py::object {
1277     py::dtype dtype = py::dtype::from_args(obj);
1278     const py::dtype* res = DtypeTo32BitDtype(dtype);
1279     if (res) {
1280       return *res;
1281     } else {
1282       return py::none();
1283     }
1284   });
1285   jitlib.def("_is_float0", &IsFloat0);
1286   jitlib.def("_is_trivial", &IsTrivialLazyExpr);
1287 }
1288 
1289 }  // namespace jax
1290