1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
18 
19 #include <memory>
20 #include <string>
21 #include <vector>
22 
23 #include "absl/types/optional.h"
24 #include "pybind11/pybind11.h"
25 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
26 #include "tensorflow/compiler/xla/statusor.h"
27 #include "tensorflow/compiler/xla/types.h"
28 
29 namespace xla {
30 
31 class PyBuffer;
32 class PyClient;
33 class PyExecutable;
34 
35 // Custom holder types.
36 //
37 // We must keep the PyClient object alive as long as any of the runtime
38 // objects are alive. Since we don't have a lot of control over Python
39 // destructor ordering, we keep the PyClient object as a std::shared_ptr<>,
40 // and ensure that each Python runtime object holds a reference to the
41 // PyClient. An alternative design would be to keep a single global
42 // singleton PyClient, although this seems less flexible, especially for
43 // writing tests.
44 //
45 // To maintain PyClient references, we define pybind11 holder classes that
46 // are custom smart pointers that also keep a reference to a PyClient.
47 // pybind11 has a `keep_alive` feature that has a similar goal, but it doesn't
48 // seem sufficiently flexible to describe ownership relationships in cases where
49 // the ownership doesn't pertain to a direct argument or return value of a
50 // function. Another alternative to the holder classes would be to create proxy
51 // objects that contain both a reference and a runtime class; holder classes
52 // seem less tedious to define.
53 
54 // A pair of a PyClient reference and an unowned pointer to T.
55 template <typename T>
56 struct ClientAndPtr {
57   ClientAndPtr() = default;
58   // pybind11 requires that we define a constructor that takes a raw pointer,
59   // but it should be unreachable.
ClientAndPtrClientAndPtr60   explicit ClientAndPtr(T*) {
61     LOG(FATAL) << "ClientAndPtr should constructed via WrapWithClient.";
62   }
63 
64   ClientAndPtr(const ClientAndPtr&) = default;
65   ClientAndPtr(ClientAndPtr&&) = default;
66   ClientAndPtr& operator=(const ClientAndPtr&) = default;
67   ClientAndPtr& operator=(ClientAndPtr&&) = default;
68 
69   std::shared_ptr<PyClient> client;
70   T* contents;
71 
getClientAndPtr72   T* get() const { return contents; }
73   T* operator->() const { return contents; }
74   T& operator*() const { return *contents; }
75 };
76 
77 // By defining a templated helper function, we can use return type deduction
78 // and avoid specifying types at the caller.
79 template <typename T>
WrapWithClient(std::shared_ptr<PyClient> client,T * contents)80 ClientAndPtr<T> WrapWithClient(std::shared_ptr<PyClient> client, T* contents) {
81   ClientAndPtr<T> result;
82   result.client = std::move(client);
83   result.contents = contents;
84   return result;
85 }
86 
87 // Python wrapper around PjRtClient.
88 // We use a wrapper class to add Python-specific functionality.
89 class PyClient : public std::enable_shared_from_this<PyClient> {
90  public:
91   explicit PyClient(std::unique_ptr<PjRtClient> pjrt_client);
92   explicit PyClient(std::shared_ptr<PjRtClient> pjrt_client);
93 
pjrt_client()94   PjRtClient* pjrt_client() const { return pjrt_client_.get(); }
shared_pjrt_client()95   std::shared_ptr<PjRtClient> shared_pjrt_client() { return pjrt_client_; }
96 
platform_name()97   absl::string_view platform_name() const {
98     return pjrt_client_->platform_name();
99   }
addressable_device_count()100   int addressable_device_count() const {
101     return pjrt_client_->addressable_device_count();
102   }
device_count()103   int device_count() const { return pjrt_client_->device_count(); }
task_id()104   int task_id() const { return pjrt_client_->task_id(); }
105 
106   std::vector<ClientAndPtr<PjRtDevice>> Devices();
107   std::vector<ClientAndPtr<PjRtDevice>> LocalDevices();
108 
109   std::vector<ClientAndPtr<PyBuffer>> LiveBuffers();
110 
111   StatusOr<std::vector<std::vector<ClientAndPtr<PjRtDevice>>>>
112   GetDefaultDeviceAssignment(int num_replicas, int num_partitions);
113 
114   // TODO(skye): delete after all callers can handle 2D output
115   StatusOr<std::vector<ClientAndPtr<PjRtDevice>>> GetDefaultDeviceAssignment1D(
116       int num_replicas);
117 
CreateChannelHandle()118   StatusOr<ChannelHandle> CreateChannelHandle() {
119     return pjrt_client_->CreateChannelHandle();
120   }
CreateDeviceToHostChannelHandle()121   StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() {
122     return pjrt_client_->CreateDeviceToHostChannelHandle();
123   }
CreateHostToDeviceChannelHandle()124   StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() {
125     return pjrt_client_->CreateHostToDeviceChannelHandle();
126   }
127 
128   StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBufferFromPyval(
129       pybind11::handle argument, PjRtDevice* device, bool force_copy,
130       PjRtClient::HostBufferSemantics host_buffer_semantics);
131   StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyval(
132       pybind11::handle argument, PjRtDevice* device, bool force_copy,
133       PjRtClient::HostBufferSemantics host_buffer_semantics);
134 
135   StatusOr<std::shared_ptr<PyExecutable>> Compile(
136       const XlaComputation& computation, CompileOptions options);
137 
138   pybind11::bytes HeapProfile();
139 
140  private:
141   friend class PyBuffer;
142   friend class PyExecutable;
143 
144   std::shared_ptr<PjRtClient> pjrt_client_;
145 
146   // Pointers to intrusive doubly-linked lists of buffers and executables, used
147   // to iterate over all known objects when heap profiling. The list structure
148   // is protected by the GIL.
149   PyBuffer* buffers_ = nullptr;
150   PyExecutable* executables_ = nullptr;
151 };
152 
153 }  // namespace xla
154 
// Registers xla::ClientAndPtr<T> with pybind11 as a custom holder type, so that
// bound classes can use ClientAndPtr<...> as their holder. This must appear at
// global scope, outside namespace xla.
PYBIND11_DECLARE_HOLDER_TYPE(T, xla::ClientAndPtr<T>);
156 
157 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
158