1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cstdint>
17 #include <string>
18 #include <vector>
19 
20 #include "absl/strings/str_format.h"
21 #include "absl/strings/str_join.h"
22 #include "pybind11/attr.h"
23 #include "pybind11/cast.h"
24 #include "pybind11/numpy.h"
25 #include "pybind11/pybind11.h"
26 #include "pybind11/pytypes.h"
27 #include "pybind11/stl_bind.h"
28 #include "tensorflow/compiler/xla/layout_util.h"
29 #include "tensorflow/compiler/xla/pjrt/cpu_device.h"
30 #include "tensorflow/compiler/xla/pjrt/distributed/client.h"
31 #include "tensorflow/compiler/xla/pjrt/distributed/distributed.h"
32 #include "tensorflow/compiler/xla/pjrt/distributed/service.h"
33 #include "tensorflow/compiler/xla/pjrt/gpu_device.h"
34 #include "tensorflow/compiler/xla/pjrt/interpreter_device.h"
35 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
36 #include "tensorflow/compiler/xla/pjrt/tpu_client.h"
37 #include "tensorflow/compiler/xla/python/dlpack.h"
38 #include "tensorflow/compiler/xla/python/jax_jit.h"
39 #include "tensorflow/compiler/xla/python/ops.h"
40 #include "tensorflow/compiler/xla/python/outfeed_receiver_py.h"
41 #include "tensorflow/compiler/xla/python/pmap_lib.h"
42 #include "tensorflow/compiler/xla/python/profiler.h"
43 #include "tensorflow/compiler/xla/python/py_buffer.h"
44 #include "tensorflow/compiler/xla/python/py_executable.h"
45 #include "tensorflow/compiler/xla/python/py_traceback.h"
46 #include "tensorflow/compiler/xla/python/python_ref_manager.h"
47 #include "tensorflow/compiler/xla/python/pytree.h"
48 #include "tensorflow/compiler/xla/python/types.h"
49 #include "tensorflow/compiler/xla/python/xla_compiler.h"
50 #include "tensorflow/compiler/xla/shape.h"
51 #include "tensorflow/compiler/xla/shape_util.h"
52 #include "tensorflow/compiler/xla/statusor.h"
53 #include "tensorflow/compiler/xla/util.h"
54 #include "tensorflow/core/platform/errors.h"
55 #include "tensorflow/python/lib/core/bfloat16.h"
56 
// TODO(phawkins): remove host_id properties after JAX is updated to avoid them.
58 
59 namespace xla {
60 namespace {
61 
62 namespace py = pybind11;
63 
// Returns true when this translation unit was compiled with NDEBUG defined,
// and false otherwise.
//
// The check is on whether NDEBUG is *defined*, not on its value: that is the
// condition the standard library uses to disable `assert`, and it also keeps
// the file compiling when the macro is defined with an empty value (a plain
// `#define NDEBUG`), where `#if NDEBUG` would be a preprocessor error.
bool IsOptimizedBuild() {
#ifdef NDEBUG
  return true;
#else
  return false;
#endif  // NDEBUG
}
71 
72 }  // namespace
73 
PYBIND11_MODULE(xla_extension,m)74 PYBIND11_MODULE(xla_extension, m) {
75   CHECK(tensorflow::RegisterNumpyBfloat16());
76 
77   // Types
78   py::enum_<PrimitiveType>(m, "PrimitiveType")
79       .value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID)
80       .value("PRED", PRED)
81       .value("S8", S8)
82       .value("S16", S16)
83       .value("S32", S32)
84       .value("S64", S64)
85       .value("U8", U8)
86       .value("U16", U16)
87       .value("U32", U32)
88       .value("U64", U64)
89       .value("F16", F16)
90       .value("BF16", BF16)
91       .value("F32", F32)
92       .value("F64", F64)
93       .value("C64", C64)
94       .value("C128", C128)
95       .value("TUPLE", TUPLE)
96       .value("OPAQUE_TYPE", OPAQUE_TYPE)
97       .value("TOKEN", TOKEN);
98 
99   m.def("bfloat16_dtype",
100         []() { return py::handle(tensorflow::Bfloat16Dtype()); });
101 
102   // Must be before PyClient.compile.
103   BuildXlaCompilerSubmodule(m);
104 
105   py::class_<PjRtDevice, ClientAndPtr<PjRtDevice>>(
106       m, "Device",
107       "A descriptor of an available device.\n\nSubclasses are used to "
108       "represent specific types of devices, e.g. CPUs, GPUs. Subclasses may "
109       "have additional properties specific to that device type.")
110       .def_property_readonly(
111           "id", &PjRtDevice::id,
112           "Integer ID of this device.\n\nUnique across all available devices "
113           "of this type, including remote devices on multi-host platforms.")
114       .def_property_readonly("host_id", &PjRtDevice::task_id,
115                              "Integer ID of this device's task.\n\n"
116                              "This is always 0 except on multi-task platforms.")
117       .def_property_readonly("task_id", &PjRtDevice::task_id,
118                              "Integer ID of this device's task.\n\n"
119                              "This is always 0 except on multi-task platforms.")
120       .def_property_readonly("platform",
121                              [](const PjRtDevice& device) {
122                                return device.client()->platform_name();
123                              })
124       .def_property_readonly("device_kind", &PjRtDevice::device_kind)
125       .def_property_readonly(
126           "client",
127           [](const ClientAndPtr<PjRtDevice>& device) { return device.client; })
128       .def("__str__", &PjRtDevice::DebugString)
129       .def("transfer_to_infeed",
130            [](PjRtDevice& device, const LiteralSlice& literal) {
131              GlobalPyRefManager()->CollectGarbage();
132              py::gil_scoped_release gil_release;
133              return device.TransferToInfeed(literal);
134            })
135       .def("transfer_from_outfeed",
136            [](PjRtDevice& device, const Shape& shape) -> StatusOr<py::object> {
137              GlobalPyRefManager()->CollectGarbage();
138              std::shared_ptr<Literal> literal;
139              {
140                py::gil_scoped_release gil_release;
141                Shape shape_with_layout = shape;
142                ShapeUtil::ForEachMutableSubshape(
143                    &shape_with_layout, [](Shape* subshape, const ShapeIndex&) {
144                      if (!subshape->has_layout()) {
145                        LayoutUtil::SetToDefaultLayout(subshape);
146                      }
147                    });
148                literal = std::make_shared<Literal>(shape_with_layout);
149                TF_RETURN_IF_ERROR(device.TransferFromOutfeed(literal.get()));
150              }
151              return LiteralToPython(std::move(literal));
152            });
153 
154   py::class_<CpuDevice, PjRtDevice, ClientAndPtr<CpuDevice>>(m, "CpuDevice")
155       .def("__repr__", [](const CpuDevice& device) {
156         return absl::StrFormat("CpuDevice(id=%i)", device.id());
157       });
158 
159   py::class_<GpuDevice, PjRtDevice, ClientAndPtr<GpuDevice>>(m, "GpuDevice")
160       .def("__repr__", [](const GpuDevice& device) {
161         return absl::StrFormat("GpuDevice(id=%i)", device.id());
162       });
163 
164   py::class_<PjRtTpuDevice, PjRtDevice, ClientAndPtr<PjRtTpuDevice>>(
165       m, "TpuDevice")
166       .def_property_readonly("host_id", &PjRtTpuDevice::task_id)
167       .def_property_readonly("task_id", &PjRtTpuDevice::task_id)
168       .def_property_readonly(
169           "coords",
170           [](const PjRtTpuDevice& device) -> pybind11::tuple {
171             return IntSpanToTuple(device.coords());
172           },
173           "The coordinates of this TpuDevice's chip in the TPU mesh network.")
174       .def_property_readonly(
175           "core_on_chip", &PjRtTpuDevice::core_on_chip,
176           "The index of this TpuDevice's core on the TPU chip.")
177       .def("__repr__", [](const PjRtTpuDevice& device) {
178         return absl::StrFormat(
179             "TpuDevice(id=%i, host=%i, coords=(%s), core_on_chip=%i)",
180             device.id(), device.task_id(), absl::StrJoin(device.coords(), ","),
181             device.core_on_chip());
182       });
183 
184   // Local XLA client methods.
185 
186   py::class_<GpuAllocatorConfig> alloc_config(m, "GpuAllocatorConfig");
187   alloc_config.def(py::init<>())
188       .def_readwrite("kind", &GpuAllocatorConfig::kind)
189       .def_readwrite("memory_fraction", &GpuAllocatorConfig::memory_fraction)
190       .def_readwrite("preallocate", &GpuAllocatorConfig::preallocate);
191   py::enum_<GpuAllocatorConfig::Kind>(alloc_config, "Kind")
192       .value("DEFAULT", GpuAllocatorConfig::Kind::kDefault)
193       .value("PLATFORM", GpuAllocatorConfig::Kind::kPlatform)
194       .value("BFC", GpuAllocatorConfig::Kind::kBFC);
195 
196   py::enum_<PjRtClient::HostBufferSemantics>(m, "HostBufferSemantics")
197       .value("IMMUTABLE_ONLY_DURING_CALL",
198              PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall)
199       .value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES",
200              PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes)
201       .value("ZERO_COPY", PjRtClient::HostBufferSemantics::kZeroCopy);
202 
203   py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
204   py_local_client.def_property_readonly("platform", &PyClient::platform_name)
205       .def("device_count", &PyClient::device_count)
206       .def("local_device_count", &PyClient::addressable_device_count)
207       .def("devices", &PyClient::Devices)
208       .def("local_devices", &PyClient::LocalDevices)
209       .def("live_buffers", &PyClient::LiveBuffers)
210       .def("host_id", &PyClient::task_id)
211       .def("task_id", &PyClient::task_id)
212       .def("get_default_device_assignment",
213            &PyClient::GetDefaultDeviceAssignment)
214       // TODO(skye): delete after all callers can handle 2D output
215       .def("get_default_device_assignment",
216            &PyClient::GetDefaultDeviceAssignment1D)
217       .def("create_channel_handle", &PyClient::CreateChannelHandle)
218       .def("create_device_to_host_channel_handle",
219            &PyClient::CreateDeviceToHostChannelHandle)
220       .def("create_host_to_device_channel_handle",
221            &PyClient::CreateHostToDeviceChannelHandle)
222       .def("buffer_from_pyval", &PyClient::BufferFromPyval, py::arg("argument"),
223            py::arg("device") = nullptr, py::arg("force_copy") = false,
224            py::arg("host_buffer_semantics") =
225                PjRtClient::HostBufferSemantics::kZeroCopy)
226       .def("compile", &PyClient::Compile, py::arg("computation"),
227            py::arg("compile_options") = CompileOptions())
228       .def("heap_profile", &PyClient::HeapProfile);
229 
230   m.def(
231       "get_cpu_client",
232       [](bool asynchronous) -> StatusOr<std::shared_ptr<PyClient>> {
233         TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client,
234                             GetCpuClient(asynchronous));
235         return std::make_shared<PyClient>(std::move(client));
236       },
237       py::arg("asynchronous") = true);
238   m.def("get_interpreter_client", []() -> StatusOr<std::shared_ptr<PyClient>> {
239     TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client,
240                         GetInterpreterClient());
241     return std::make_shared<PyClient>(std::move(client));
242   });
243   m.def(
244       "get_gpu_client",
245       [](bool asynchronous, const GpuAllocatorConfig& allocator_config,
246          std::shared_ptr<DistributedRuntimeClient> distributed_client,
247          int node_id) -> StatusOr<std::shared_ptr<PyClient>> {
248         TF_ASSIGN_OR_RETURN(
249             std::unique_ptr<PjRtClient> client,
250             GetGpuClient(asynchronous, allocator_config,
251                          std::move(distributed_client), node_id));
252         return std::make_shared<PyClient>(std::move(client));
253       },
254       py::arg("asynchronous") = true,
255       py::arg("allocator_config") = GpuAllocatorConfig(),
256       py::arg("distributed_client") = nullptr, py::arg("node_id") = 0);
257   m.def(
258       "get_tpu_client",
259       [](bool asynchronous) -> StatusOr<std::shared_ptr<PyClient>> {
260         TF_ASSIGN_OR_RETURN(std::shared_ptr<PjRtClient> client,
261                             GetTpuClient(asynchronous));
262         return std::make_shared<PyClient>(std::move(client));
263       },
264       py::arg("asynchronous") = true);
265 
266   py::class_<DeviceArrayBase> device_array_base(m, "DeviceArrayBase");
267   device_array_base.def(py::init<>());
268 
269   py::class_<PyBuffer, DeviceArrayBase, std::unique_ptr<PyBuffer>> buffer(
270       m, "Buffer");
271   // TODO(phawkins): alias for backward compatibility. Remove after JAX no
272   // longer uses this name.
273   m.add_object("PyLocalBuffer", buffer);
274   buffer
275       .def_property_readonly("__array_priority__",
276                              [](py::object) { return 100; })
277       .def_property("_device", &PyBuffer::GetStickyDevice,
278                     &PyBuffer::SetStickyDevice)
279       .def_property("aval", &PyBuffer::GetAval, &PyBuffer::SetAval)
280       .def_property_readonly("_lazy_expr",
281                              [](py::object buffer) { return py::none(); })
282       .def_property_readonly("device_buffer",
283                              [](py::object buffer) { return buffer; })
284       .def_property_readonly(
285           "shape",
286           [](const PyBuffer& pybuffer) -> pybind11::tuple {
287             return IntSpanToTuple(
288                 pybuffer.buffer()->on_device_shape().dimensions());
289           })
290       .def_property_readonly(
291           "dtype",
292           [](const PyBuffer& buffer) {
293             PrimitiveType primitive =
294                 buffer.buffer()->on_device_shape().element_type();
295             return PrimitiveTypeToDtype(primitive).ValueOrDie();
296           })
297       .def_property_readonly("size", &PyBuffer::size)
298       .def_property_readonly("ndim", &PyBuffer::ndim)
299       .def_property_readonly(
300           "_value",
301           [](py::handle buffer_obj) -> StatusOr<pybind11::object> {
302             GlobalPyRefManager()->CollectGarbage();
303             PyBuffer* buffer = buffer_obj.cast<PyBuffer*>();
304             return buffer->AsNumPyArray(buffer_obj);
305           })
306       .def("copy_to_device", &PyBuffer::CopyToDevice)
307       .def("on_device_size_in_bytes", &PyBuffer::OnDeviceSizeInBytes)
308       .def("delete", &PyBuffer::Delete)
309       // The GIL is released within BlockHostUntilReady.
310       .def("block_until_ready",
311            [](py::object buffer_obj) -> xla::StatusOr<py::object> {
312              PyBuffer* buffer = buffer_obj.cast<PyBuffer*>();
313              TF_RETURN_IF_ERROR(buffer->BlockHostUntilReady());
314              return buffer_obj;
315            })
316       .def("block_host_until_ready", &PyBuffer::BlockHostUntilReady)
317       .def("copy_to_host_async", &PyBuffer::CopyToHostAsync)
318       .def("to_py",
319            [](py::handle buffer_obj) {
320              PyBuffer* buffer = buffer_obj.cast<PyBuffer*>();
321              return buffer->AsNumPyArray(buffer_obj);
322            })
323       .def("xla_shape", &PyBuffer::shape)
324       .def_property_readonly("client", &PyBuffer::client)
325       .def("device", &PyBuffer::device)
326       .def("platform", &PyBuffer::platform_name)
327       .def("is_deleted", &PyBuffer::is_deleted)
328       .def("unsafe_buffer_pointer", &PyBuffer::UnsafeBufferPointer)
329       .def_property_readonly("__cuda_array_interface__",
330                              &PyBuffer::CudaArrayInterface)
331       .def_property_readonly("traceback", &PyBuffer::traceback);
332 
333   // pybind11's implementation of the buffer protocol doesn't allow for correct
334   // error handling. We bypass it and implement the buffer protocol ourselves.
335   PyTypeObject* buffer_type = reinterpret_cast<PyTypeObject*>(buffer.ptr());
336   buffer_type->tp_as_buffer = PyBuffer::BufferProtocol();
337 
338   py::class_<PyExecutable, std::shared_ptr<PyExecutable>> executable(
339       m, "Executable");
340   executable.def_property_readonly("client", &PyExecutable::client)
341       .def("local_logical_device_ids",
342            [](PyExecutable* exec) {
343              auto span = exec->addressable_device_logical_ids();
344              // Not on dispatch critical path, so ok to have heap allocation.
345              std::vector<std::pair<int, int>> addressable_device_logic_ids;
346              addressable_device_logic_ids.reserve(span.size());
347              for (const auto& logical_device_id : span) {
348                addressable_device_logic_ids.push_back(std::make_pair(
349                    logical_device_id.replica, logical_device_id.partition));
350              }
351            })
352       .def("local_devices", &PyExecutable::AddressableDevices)
353       .def("size_of_generated_code_in_bytes",
354            &PyExecutable::SizeOfGeneratedCodeInBytes)
355       .def("delete", &PyExecutable::Delete)
356       .def("execute", &PyExecutable::Execute, py::arg("arguments"))
357       .def("execute_on_local_devices", &PyExecutable::ExecuteOnLocalDevices,
358            py::arg("arguments"))
359       .def("execute_sharded_on_local_devices",
360            &PyExecutable::ExecuteShardedOnLocalDevices, py::arg("arguments"))
361       .def("hlo_modules", &PyExecutable::HloModules)
362       .def_property_readonly("traceback", &PyExecutable::traceback);
363 
364   m.def("buffer_to_dlpack_managed_tensor", BufferToDLPackManagedTensor,
365         py::arg("buffer"), py::arg("take_ownership") = true);
366   m.def("dlpack_managed_tensor_to_buffer", DLPackManagedTensorToBuffer);
367 
368   BuildProfilerSubmodule(&m);
369   BuildOpsSubmodule(&m);
370   BuildOutfeedReceiverSubmodule(&m);
371   BuildPytreeSubmodule(m);
372   jax::BuildJaxjitSubmodule(m);
373   jax::BuildPmapSubmodule(m);
374   BuildTracebackSubmodule(m);
375 
376   py::class_<DistributedRuntimeService,
377              std::unique_ptr<DistributedRuntimeService>>
378       distributed_runtime_service(m, "DistributedRuntimeService");
379   py::class_<DistributedRuntimeClient,
380              std::shared_ptr<DistributedRuntimeClient>>
381       distributed_runtime_client(m, "DistributedRuntimeClient");
382   distributed_runtime_client.def("connect", &DistributedRuntimeClient::Connect)
383       .def("shutdown", &DistributedRuntimeClient::Shutdown);
384 
385   m.def("get_distributed_runtime_service", &GetDistributedRuntimeService);
386   m.def("get_distributed_runtime_client", &GetDistributedRuntimeClient);
387 
388   m.def("collect_garbage", []() { GlobalPyRefManager()->CollectGarbage(); });
389 
390   m.def("is_optimized_build", &IsOptimizedBuild);
391 }  // NOLINT(readability/fn_size)
392 
393 }  // namespace xla
394