1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <string>
18 #include <vector>
19
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/types/variant.h"
22 #include "pybind11/pybind11.h"
23 #include "tensorflow/core/profiler/convert/xplane_to_tools_data.h"
24 #include "tensorflow/core/profiler/rpc/profiler_server.h"
25 #include "tensorflow/python/lib/core/pybind11_status.h"
26 #include "tensorflow/python/profiler/internal/profiler_pywrap_impl.h"
27
28 namespace py = ::pybind11;
29
30 using ::tensorflow::profiler::pywrap::ProfilerSessionWrapper;
31
32 namespace {
33
34 // This must be called under GIL because it reads Python objects. Reading Python
35 // objects require GIL because the objects can be mutated by other Python
36 // threads. In addition, Python objects are reference counted; reading py::dict
37 // will increase its reference count.
ConvertDictToMap(const py::dict & dict)38 absl::flat_hash_map<std::string, absl::variant<int>> ConvertDictToMap(
39 const py::dict& dict) {
40 absl::flat_hash_map<std::string, absl::variant<int>> map;
41 for (const auto& kw : dict) {
42 if (!kw.second.is_none()) {
43 map.emplace(kw.first.cast<std::string>(), kw.second.cast<int>());
44 }
45 }
46 return map;
47 }
48
49 } // namespace
50
// Defines the `_pywrap_profiler` Python extension module.
//
// Exposed API:
//   * class ProfilerSession: start(logdir, options), stop() -> bytes,
//     export_to_tb()
//   * start_server(port)
//   * trace(service_addr, logdir, worker_list, include_dataset_ops,
//           duration_ms, num_tracing_attempts, options)
//   * monitor(service_addr, duration_ms, monitoring_level, display_timestamp)
//   * xspace_to_tools_data(xspace_path_list, tool_name) -> (bytes, bool)
//
// Bindings that wrap a Status-returning profiler call share one pattern:
//   1. Convert Python arguments to C++ values while the GIL is held.
//   2. Release the GIL (py::gil_scoped_release) only around the C++ call.
//   3. With the GIL held again, hand the Status to
//      MaybeRaiseRegisteredFromStatus and build the Python return value.
PYBIND11_MODULE(_pywrap_profiler, m) {
  py::class_<ProfilerSessionWrapper> profiler_session_class(m,
                                                            "ProfilerSession");
  profiler_session_class.def(py::init<>())
      .def("start",
           [](ProfilerSessionWrapper& wrapper, const char* logdir,
              const py::dict& options) {
             tensorflow::Status status;
             // The dict must be read before the GIL is released.
             absl::flat_hash_map<std::string, absl::variant<int>> opts =
                 ConvertDictToMap(options);
             {
               py::gil_scoped_release release;
               status = wrapper.Start(logdir, opts);
             }
             // Py_INCREF and Py_DECREF must be called holding the GIL.
             tensorflow::MaybeRaiseRegisteredFromStatus(status);
           })
      .def("stop",
           [](ProfilerSessionWrapper& wrapper) {
             tensorflow::string content;
             tensorflow::Status status;
             {
               py::gil_scoped_release release;
               status = wrapper.Stop(&content);
             }
             // Py_INCREF and Py_DECREF must be called holding the GIL.
             tensorflow::MaybeRaiseRegisteredFromStatus(status);
             // The content is not valid UTF-8. It must be converted to bytes.
             return py::bytes(content);
           })
      .def("export_to_tb", [](ProfilerSessionWrapper& wrapper) {
        tensorflow::Status status;
        {
          py::gil_scoped_release release;
          status = wrapper.ExportToTensorBoard();
        }
        // Py_INCREF and Py_DECREF must be called holding the GIL.
        tensorflow::MaybeRaiseRegisteredFromStatus(status);
      });

  m.def("start_server", [](int port) {
    auto profiler_server =
        absl::make_unique<tensorflow::profiler::ProfilerServer>();
    profiler_server->StartProfilerServer(port);
    // Intentionally release (leak) the profiler server so that it is not
    // destroyed when this lambda returns. Should transfer ownership to the
    // caller instead.
    profiler_server.release();
  });

  m.def("trace",
        // `options` is taken by const reference, matching
        // ProfilerSession.start; the dict is only read, never stored.
        [](const char* service_addr, const char* logdir,
           const char* worker_list, bool include_dataset_ops, int duration_ms,
           int num_tracing_attempts, const py::dict& options) {
          tensorflow::Status status;
          // The dict must be read before the GIL is released.
          absl::flat_hash_map<std::string, absl::variant<int>> opts =
              ConvertDictToMap(options);
          {
            py::gil_scoped_release release;
            status = tensorflow::profiler::pywrap::Trace(
                service_addr, logdir, worker_list, include_dataset_ops,
                duration_ms, num_tracing_attempts, opts);
          }
          // Py_INCREF and Py_DECREF must be called holding the GIL.
          tensorflow::MaybeRaiseRegisteredFromStatus(status);
        });

  m.def("monitor", [](const char* service_addr, int duration_ms,
                      int monitoring_level, bool display_timestamp) {
    tensorflow::string content;
    tensorflow::Status status;
    {
      py::gil_scoped_release release;
      status = tensorflow::profiler::pywrap::Monitor(
          service_addr, duration_ms, monitoring_level, display_timestamp,
          &content);
    }
    // Py_INCREF and Py_DECREF must be called holding the GIL.
    tensorflow::MaybeRaiseRegisteredFromStatus(status);
    // Unlike `stop`, the result is returned as a string rather than as bytes.
    return content;
  });

  m.def("xspace_to_tools_data",
        [](const py::list& xspace_path_list, const py::str& py_tool_name) {
          // Copy the Python strings into C++ strings up front.
          std::vector<std::string> xspace_paths;
          xspace_paths.reserve(xspace_path_list.size());
          for (py::handle obj : xspace_path_list) {
            xspace_paths.push_back(std::string(py::cast<py::str>(obj)));
          }
          std::string tool_name = std::string(py_tool_name);
          // NOTE(review): unlike the other bindings, this call runs with the
          // GIL held. Releasing it looks possible since only C++ values are
          // passed in, but confirm the callee is safe to run concurrently
          // before changing this.
          auto tool_data_and_success =
              tensorflow::profiler::ConvertMultiXSpacesToToolData(xspace_paths,
                                                                  tool_name);
          // Returned to Python as a (bytes, bool) tuple.
          return py::make_tuple(py::bytes(tool_data_and_success.first),
                                py::bool_(tool_data_and_success.second));
        });
}
146