1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/python/py_client.h"
17 
18 #include <memory>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "tensorflow/compiler/xla/python/py_buffer.h"
22 #include "tensorflow/compiler/xla/python/py_executable.h"
23 #include "tensorflow/compiler/xla/python/python_ref_manager.h"
24 #include "tensorflow/compiler/xla/python/traceback.h"
25 #include "tensorflow/compiler/xla/python/types.h"
26 #include "tensorflow/core/profiler/profile.pb.h"
27 
28 namespace xla {
29 
30 namespace py = pybind11;
31 namespace pprof = tensorflow::tfprof::pprof;
32 
PyClient(std::unique_ptr<PjRtClient> pjrt_client)33 PyClient::PyClient(std::unique_ptr<PjRtClient> pjrt_client)
34     : pjrt_client_(std::move(pjrt_client)) {}
PyClient(std::shared_ptr<PjRtClient> pjrt_client)35 PyClient::PyClient(std::shared_ptr<PjRtClient> pjrt_client)
36     : pjrt_client_(std::move(pjrt_client)) {}
37 
Devices()38 std::vector<ClientAndPtr<PjRtDevice>> PyClient::Devices() {
39   std::vector<ClientAndPtr<PjRtDevice>> devices;
40   auto span = pjrt_client_->devices();
41   devices.reserve(span.size());
42   for (PjRtDevice* device : span) {
43     devices.push_back(WrapWithClient(shared_from_this(), device));
44   }
45   return devices;
46 }
47 
LocalDevices()48 std::vector<ClientAndPtr<PjRtDevice>> PyClient::LocalDevices() {
49   std::vector<ClientAndPtr<PjRtDevice>> devices;
50   devices.reserve(pjrt_client_->addressable_devices().size());
51   for (PjRtDevice* device : pjrt_client_->addressable_devices()) {
52     devices.push_back(WrapWithClient(shared_from_this(), device));
53   }
54   return devices;
55 }
56 
LiveBuffers()57 std::vector<ClientAndPtr<PyBuffer>> PyClient::LiveBuffers() {
58   CHECK(PyGILState_Check());
59   std::vector<ClientAndPtr<PyBuffer>> buffers;
60   for (PyBuffer* buffer = buffers_; buffer; buffer = buffer->next_) {
61     if (!buffer->is_deleted()) {
62       buffers.push_back(WrapWithClient(shared_from_this(), buffer));
63     }
64   }
65   return buffers;
66 }
67 
68 StatusOr<std::vector<std::vector<ClientAndPtr<PjRtDevice>>>>
GetDefaultDeviceAssignment(int num_replicas,int num_partitions)69 PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) {
70   TF_ASSIGN_OR_RETURN(
71       DeviceAssignment device_assignment,
72       pjrt_client_->GetDefaultDeviceAssignment(num_replicas, num_partitions));
73   std::vector<std::vector<ClientAndPtr<PjRtDevice>>> result;
74   result.resize(num_replicas);
75   for (int r = 0; r < num_replicas; ++r) {
76     result[r].resize(num_partitions);
77     for (int p = 0; p < num_partitions; ++p) {
78       int device_id = device_assignment(r, p);
79       TF_ASSIGN_OR_RETURN(PjRtDevice * device,
80                           pjrt_client_->LookupDevice(device_id));
81       result[r][p] = WrapWithClient(shared_from_this(), device);
82     }
83   }
84   return result;
85 }
86 
87 StatusOr<std::vector<ClientAndPtr<PjRtDevice>>>
GetDefaultDeviceAssignment1D(int num_replicas)88 PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
89   TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
90                       pjrt_client_->GetDefaultDeviceAssignment(
91                           num_replicas, /*num_partitions=*/1));
92   std::vector<ClientAndPtr<PjRtDevice>> result;
93   for (int i = 0; i < num_replicas; ++i) {
94     int device_id = device_assignment(i, 0);
95     TF_ASSIGN_OR_RETURN(PjRtDevice * device,
96                         pjrt_client_->LookupDevice(device_id));
97     result.push_back(WrapWithClient(shared_from_this(), device));
98   }
99   return result;
100 }
101 
PjRtBufferFromPyval(pybind11::handle argument,PjRtDevice * device,bool force_copy,PjRtClient::HostBufferSemantics host_buffer_semantics)102 StatusOr<std::unique_ptr<PjRtBuffer>> PyClient::PjRtBufferFromPyval(
103     pybind11::handle argument, PjRtDevice* device, bool force_copy,
104     PjRtClient::HostBufferSemantics host_buffer_semantics) {
105   if (device == nullptr) {
106     TF_RET_CHECK(!pjrt_client_->addressable_devices().empty());
107     device = pjrt_client_->addressable_devices().front();
108   }
109   CHECK(device != nullptr);
110   TF_ASSIGN_OR_RETURN(PjRtDevice * found_device,
111                       pjrt_client_->LookupDevice(device->id()));
112   if (found_device != device) {
113     return InvalidArgument("Cannot copy value to device '%s' with '%s' backend",
114                            device->DebugString(),
115                            pjrt_client_->platform_name());
116   }
117   GlobalPyRefManager()->CollectGarbage();
118 
119   absl::optional<CastToArrayResult> c = CastToArray(argument);
120   if (!c) {
121     return InvalidArgument(
122         "from_python argument must be an array, got value %s",
123         py::cast<std::string>(py::repr(argument)));
124   }
125 
126   std::function<void()> on_done_with_host_buffer;
127   if (host_buffer_semantics !=
128       PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall) {
129     std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
130         GlobalPyRefManager()->ManageReference(std::move(c->array));
131     on_done_with_host_buffer =
132         [py_buffer_ref{
133             std::move(py_buffer_ref)}]() { /* keeps py_buffer_ref alive */ };
134   }
135 
136   std::unique_ptr<PjRtBuffer> buffer;
137   {
138     py::gil_scoped_release gil_release;
139     TF_ASSIGN_OR_RETURN(buffer,
140                         pjrt_client_->BufferFromHostBuffer(
141                             c->buf_ptr, c->shape, host_buffer_semantics,
142                             std::move(on_done_with_host_buffer), device));
143   }
144   return buffer;
145 }
BufferFromPyval(pybind11::handle argument,PjRtDevice * device,bool force_copy,PjRtClient::HostBufferSemantics host_buffer_semantics)146 StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
147     pybind11::handle argument, PjRtDevice* device, bool force_copy,
148     PjRtClient::HostBufferSemantics host_buffer_semantics) {
149   TF_ASSIGN_OR_RETURN(
150       std::unique_ptr<PjRtBuffer> buffer,
151       PjRtBufferFromPyval(argument, device, force_copy, host_buffer_semantics));
152 
153   auto traceback = Traceback::Get();
154   return std::make_unique<PyBuffer>(shared_from_this(), std::move(buffer),
155                                     std::move(traceback));
156 }
157 
Compile(const XlaComputation & computation,CompileOptions options)158 StatusOr<std::shared_ptr<PyExecutable>> PyClient::Compile(
159     const XlaComputation& computation, CompileOptions options) {
160   std::unique_ptr<PjRtExecutable> executable;
161   absl::optional<std::string> fingerprint;
162   {
163     py::gil_scoped_release gil_release;
164     TF_ASSIGN_OR_RETURN(executable,
165                         pjrt_client_->Compile(computation, std::move(options)));
166     TF_ASSIGN_OR_RETURN(fingerprint,
167                         pjrt_client_->ExecutableFingerprint(*executable));
168   }
169   auto traceback = Traceback::Get();
170   return std::make_shared<PyExecutable>(
171       shared_from_this(), std::move(executable), std::move(traceback),
172       std::move(fingerprint));
173 }
174 
175 class ProfileBuilder {
176  public:
177   ProfileBuilder();
profile()178   pprof::Profile& profile() { return profile_; }
179 
180   // Adds or returns the ID of `s` in the table.
181   int StringId(const std::string& s);
182 
183   // Adds or returns the ID of a function.
184   int FunctionId(PyCodeObject* code);
185 
186   // Adds or returns the ID of a code location.
187   int LocationId(PyCodeObject* code, int instruction);
188 
189  private:
190   pprof::Profile profile_;
191 
192   absl::flat_hash_map<std::string, int> strings_;
193   absl::flat_hash_map<PyCodeObject*, int> functions_;
194   absl::flat_hash_map<std::pair<PyCodeObject*, int>, int> locations_;
195 };
196 
ProfileBuilder()197 ProfileBuilder::ProfileBuilder() { CHECK_EQ(0, StringId("")); }
198 
StringId(const std::string & s)199 int ProfileBuilder::StringId(const std::string& s) {
200   auto ret = strings_.emplace(s, profile_.string_table_size());
201   if (ret.second) {
202     profile_.add_string_table(s);
203   }
204   return ret.first->second;
205 }
206 
FunctionId(PyCodeObject * code)207 int ProfileBuilder::FunctionId(PyCodeObject* code) {
208   // +1 because id 0 is reserved.
209   auto ret = functions_.emplace(code, profile_.function_size() + 1);
210   if (ret.second) {
211     auto* function = profile_.add_function();
212     function->set_id(ret.first->second);
213     int name = StringId(py::str(code->co_name));
214     function->set_name(name);
215     function->set_system_name(name);
216     function->set_filename(StringId(py::str(code->co_filename)));
217     function->set_start_line(code->co_firstlineno);
218   }
219   return ret.first->second;
220 }
221 
LocationId(PyCodeObject * code,int instruction)222 int ProfileBuilder::LocationId(PyCodeObject* code, int instruction) {
223   // +1 because id 0 is reserved.
224   auto ret = locations_.emplace(std::make_pair(code, instruction),
225                                 profile_.location_size() + 1);
226   if (ret.second) {
227     auto* location = profile_.add_location();
228     location->set_id(ret.first->second);
229     auto* line = location->add_line();
230     line->set_function_id(FunctionId(code));
231     line->set_line(PyCode_Addr2Line(code, instruction));
232   }
233   return ret.first->second;
234 }
235 
236 namespace {
237 
238 struct HeapProfileKey {
239   Traceback* traceback;
240   int64 size;
241   PjRtDevice* device;
242   bool operator==(const HeapProfileKey& other) const;
243 };
244 
operator ==(const HeapProfileKey & other) const245 bool HeapProfileKey::operator==(const HeapProfileKey& other) const {
246   if (size != other.size || device != other.device) {
247     return false;
248   }
249   if ((traceback == nullptr) != (other.traceback == nullptr)) {
250     return false;
251   }
252   if (traceback && traceback->raw_frames() != other.traceback->raw_frames()) {
253     return false;
254   }
255   return true;
256 }
257 
258 template <typename H>
AbslHashValue(H h,const HeapProfileKey & key)259 H AbslHashValue(H h, const HeapProfileKey& key) {
260   if (key.traceback) {
261     h = H::combine_contiguous(std::move(h), key.traceback->raw_frames().begin(),
262                               key.traceback->raw_frames().size());
263   }
264   h = H::combine(std::move(h), key.size, key.device);
265   return h;
266 }
267 
268 }  // namespace
269 
HeapProfile()270 py::bytes PyClient::HeapProfile() {
271   CHECK(PyGILState_Check());
272   absl::flat_hash_map<HeapProfileKey, int64> entries;
273   for (PyBuffer* buffer = buffers_; buffer; buffer = buffer->next_) {
274     HeapProfileKey key{buffer->traceback(),
275                        buffer->buffer()->OnDeviceSizeInBytes(),
276                        buffer->buffer()->device()};
277     ++entries[key];
278   }
279 
280   for (PyExecutable* executable = executables_; executable;
281        executable = executable->next_) {
282     HeapProfileKey key{executable->traceback(),
283                        executable->SizeOfGeneratedCodeInBytes(), nullptr};
284     ++entries[key];
285   }
286 
287   ProfileBuilder builder;
288   auto* allocations = builder.profile().add_sample_type();
289   allocations->set_type(builder.StringId("allocations"));
290   allocations->set_unit(builder.StringId("count"));
291   auto* space = builder.profile().add_sample_type();
292   space->set_type(builder.StringId("space"));
293   space->set_unit(builder.StringId("bytes"));
294 
295   const int kind_string_id = builder.StringId("kind");
296   const int buffer_string_id = builder.StringId("buffer");
297   const int executable_string_id = builder.StringId("executable");
298   const int device_string_id = builder.StringId("device");
299   for (const auto& entry : entries) {
300     auto* sample = builder.profile().add_sample();
301     if (entry.first.traceback) {
302       for (const auto& frame : entry.first.traceback->raw_frames()) {
303         sample->add_location_id(builder.LocationId(frame.first, frame.second));
304       }
305     }
306     sample->add_value(entry.second);
307     sample->add_value(entry.first.size * entry.second);
308 
309     auto* kind_label = sample->add_label();
310     kind_label->set_key(kind_string_id);
311     if (entry.first.device) {
312       kind_label->set_str(buffer_string_id);
313       auto* device_label = sample->add_label();
314       device_label->set_key(device_string_id);
315       device_label->set_str(
316           builder.StringId(entry.first.device->DebugString()));
317     } else {
318       kind_label->set_str(executable_string_id);
319     }
320   }
321   return builder.profile().SerializeAsString();
322 }
323 
324 }  // namespace xla
325