1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DEBUG_DEBUG_IO_UTILS_H_ 17 #define TENSORFLOW_CORE_DEBUG_DEBUG_IO_UTILS_H_ 18 19 #include <cstddef> 20 #include <functional> 21 #include <memory> 22 #include <string> 23 #include <unordered_map> 24 #include <unordered_set> 25 #include <vector> 26 27 #include "tensorflow/core/debug/debug_node_key.h" 28 #include "tensorflow/core/framework/tensor.h" 29 #include "tensorflow/core/graph/graph.h" 30 #include "tensorflow/core/lib/core/status.h" 31 #include "tensorflow/core/lib/gtl/array_slice.h" 32 #include "tensorflow/core/platform/env.h" 33 #include "tensorflow/core/util/event.pb.h" 34 35 namespace tensorflow { 36 37 Status ReadEventFromFile(const string& dump_file_path, Event* event); 38 39 struct DebugWatchAndURLSpec { DebugWatchAndURLSpecDebugWatchAndURLSpec40 DebugWatchAndURLSpec(const string& watch_key, const string& url, 41 const bool gated_grpc) 42 : watch_key(watch_key), url(url), gated_grpc(gated_grpc) {} 43 44 const string watch_key; 45 const string url; 46 const bool gated_grpc; 47 }; 48 49 // TODO(cais): Put static functions and members in a namespace, not a class. 50 class DebugIO { 51 public: 52 static const char* const kDebuggerPluginName; 53 54 static const char* const kCoreMetadataTag; 55 static const char* const kGraphTag; 56 static const char* const kHashTag; 57 58 static const char* const kFileURLScheme; 59 static const char* const kGrpcURLScheme; 60 static const char* const kMemoryURLScheme; 61 62 static Status PublishDebugMetadata( 63 const int64 global_step, const int64 session_run_index, 64 const int64 executor_step_index, const std::vector<string>& input_names, 65 const std::vector<string>& output_names, 66 const std::vector<string>& target_nodes, 67 const std::unordered_set<string>& debug_urls); 68 69 // Publishes a tensor to a debug target URL. 70 // 71 // Args: 72 // debug_node_key: A DebugNodeKey identifying the debug node. 73 // tensor: The Tensor object being published. 74 // wall_time_us: Time stamp for the Tensor. Unit: microseconds (us). 75 // debug_urls: An array of debug target URLs, e.g., 76 // "file:///foo/tfdbg_dump", "grpc://localhost:11011" 77 // gated_grpc: Whether this call is subject to gRPC gating. 78 static Status PublishDebugTensor(const DebugNodeKey& debug_node_key, 79 const Tensor& tensor, 80 const uint64 wall_time_us, 81 const gtl::ArraySlice<string>& debug_urls, 82 const bool gated_grpc); 83 84 // Convenience overload of the method above for no gated_grpc by default. 85 static Status PublishDebugTensor(const DebugNodeKey& debug_node_key, 86 const Tensor& tensor, 87 const uint64 wall_time_us, 88 const gtl::ArraySlice<string>& debug_urls); 89 90 // Publishes a graph to a set of debug URLs. 91 // 92 // Args: 93 // graph: The graph to be published. 94 // debug_urls: The set of debug URLs to publish the graph to. 95 static Status PublishGraph(const Graph& graph, const string& device_name, 96 const std::unordered_set<string>& debug_urls); 97 98 // Determines whether a copy node needs to perform deep-copy of input tensor. 99 // 100 // The input arguments contain sufficient information about the attached 101 // downstream debug ops for this method to determine whether all the said 102 // ops are disabled given the current status of the gRPC gating. 103 // 104 // Args: 105 // specs: A vector of DebugWatchAndURLSpec carrying information about the 106 // debug ops attached to the Copy node, their debug URLs and whether 107 // they have the attribute value gated_grpc == True. 108 // 109 // Returns: 110 // Whether any of the attached downstream debug ops is enabled given the 111 // current status of the gRPC gating. 112 static bool IsCopyNodeGateOpen( 113 const std::vector<DebugWatchAndURLSpec>& specs); 114 115 // Determines whether a debug node needs to proceed given the current gRPC 116 // gating status. 117 // 118 // Args: 119 // watch_key: debug tensor watch key, in the format of 120 // tensor_name:debug_op, e.g., "Weights:0:DebugIdentity". 121 // debug_urls: the debug URLs of the debug node. 122 // 123 // Returns: 124 // Whether this debug op should proceed. 125 static bool IsDebugNodeGateOpen(const string& watch_key, 126 const std::vector<string>& debug_urls); 127 128 // Determines whether debug information should be sent through a grpc:// 129 // debug URL given the current gRPC gating status. 130 // 131 // Args: 132 // watch_key: debug tensor watch key, in the format of 133 // tensor_name:debug_op, e.g., "Weights:0:DebugIdentity". 134 // debug_url: the debug URL, e.g., "grpc://localhost:3333", 135 // "file:///tmp/tfdbg_1". 136 // 137 // Returns: 138 // Whether the sending of debug data to the debug_url should 139 // proceed. 140 static bool IsDebugURLGateOpen(const string& watch_key, 141 const string& debug_url); 142 143 static Status CloseDebugURL(const string& debug_url); 144 }; 145 146 // Helper class for debug ops. 147 class DebugFileIO { 148 public: 149 // Encapsulates the Tensor in an Event protobuf and write it to a directory. 150 // The actual path of the dump file will be a contactenation of 151 // dump_root_dir, tensor_name, along with the wall_time. 152 // 153 // For example: 154 // let dump_root_dir = "/tmp/tfdbg_dump", 155 // node_name = "foo/bar", 156 // output_slot = 0, 157 // debug_op = DebugIdentity, 158 // and wall_time_us = 1467891234512345, 159 // the dump file will be generated at path: 160 // /tmp/tfdbg_dump/foo/bar_0_DebugIdentity_1467891234512345. 161 // 162 // Args: 163 // debug_node_key: A DebugNodeKey identifying the debug node. 164 // wall_time_us: Wall time at which the Tensor is generated during graph 165 // execution. Unit: microseconds (us). 166 // dump_root_dir: Root directory for dumping the tensor. 167 // dump_file_path: The actual dump file path (passed as reference). 168 static Status DumpTensorToDir(const DebugNodeKey& debug_node_key, 169 const Tensor& tensor, const uint64 wall_time_us, 170 const string& dump_root_dir, 171 string* dump_file_path); 172 173 // Get the full path to the dump file. 174 // 175 // Args: 176 // dump_root_dir: The dump root directory, e.g., /tmp/tfdbg_dump 177 // node_name: Name of the node from which the dumped tensor is generated, 178 // e.g., foo/bar/node_a 179 // output_slot: Output slot index of the said node, e.g., 0. 180 // debug_op: Name of the debug op, e.g., DebugIdentity. 181 // wall_time_us: Time stamp of the dumped tensor, in microseconds (us). 182 static string GetDumpFilePath(const string& dump_root_dir, 183 const DebugNodeKey& debug_node_key, 184 const uint64 wall_time_us); 185 186 // Dumps an Event proto to a file. 187 // 188 // Args: 189 // event_prot: The Event proto to be dumped. 190 // dir_name: Directory path. 191 // file_name: Base file name. 192 static Status DumpEventProtoToFile(const Event& event_proto, 193 const string& dir_name, 194 const string& file_name); 195 196 // Request additional bytes to be dumped to the file system. 197 // 198 // Does not actually dump the bytes, but instead just performs the 199 // bookkeeping necessary to prevent the total dumped amount of data from 200 // exceeding the limit (default 100 GBytes or set customly through the 201 // environment variable TFDBG_DISK_BYTES_LIMIT). 202 // 203 // Args: 204 // bytes: Number of bytes to request. 205 // 206 // Returns: 207 // Whether the request is approved given the total dumping 208 // limit. 209 static bool requestDiskByteUsage(uint64 bytes); 210 211 // Reset the disk byte usage to zero. 212 static void resetDiskByteUsage(); 213 214 static uint64 globalDiskBytesLimit; 215 216 private: 217 // Encapsulates the Tensor in an Event protobuf and write it to file. 218 static Status DumpTensorToEventFile(const DebugNodeKey& debug_node_key, 219 const Tensor& tensor, 220 const uint64 wall_time_us, 221 const string& file_path); 222 223 // Implemented ad hoc here for now. 224 // TODO(cais): Replace with shared implementation once http://b/30497715 is 225 // fixed. 226 static Status RecursiveCreateDir(Env* env, const string& dir); 227 228 // Tracks how much disk has been used so far. 229 static uint64 diskBytesUsed; 230 // Mutex for thread-safe access to diskBytesUsed. 231 static mutex bytes_mu; 232 // Default limit for the disk space. 233 static const uint64 defaultGlobalDiskBytesLimit; 234 235 friend class DiskUsageLimitTest; 236 }; 237 238 } // namespace tensorflow 239 240 namespace std { 241 242 template <> 243 struct hash<::tensorflow::DebugNodeKey> { 244 size_t operator()(const ::tensorflow::DebugNodeKey& k) const { 245 return ::tensorflow::Hash64( 246 ::tensorflow::strings::StrCat(k.device_name, ":", k.node_name, ":", 247 k.output_slot, ":", k.debug_op, ":")); 248 } 249 }; 250 251 } // namespace std 252 253 // TODO(cais): Support grpc:// debug URLs in open source once Python grpc 254 // genrule becomes available. See b/23796275. 255 #ifndef PLATFORM_WINDOWS 256 #include "grpcpp/channel.h" 257 #include "tensorflow/core/debug/debug_service.grpc.pb.h" 258 259 namespace tensorflow { 260 261 class DebugGrpcChannel { 262 public: 263 // Constructor of DebugGrpcChannel. 264 // 265 // Args: 266 // server_stream_addr: Address (host name and port) of the debug stream 267 // server implementing the EventListener service (see 268 // debug_service.proto). E.g., "127.0.0.1:12345". 269 DebugGrpcChannel(const string& server_stream_addr); 270 271 virtual ~DebugGrpcChannel() {} 272 273 // Attempt to establish connection with server. 274 // 275 // Args: 276 // timeout_micros: Timeout (in microseconds) for the attempt to establish 277 // the connection. 278 // 279 // Returns: 280 // OK Status iff connection is successfully established before timeout, 281 // otherwise return an error Status. 282 Status Connect(const int64 timeout_micros); 283 284 // Write an Event proto to the debug gRPC stream. 285 // 286 // Thread-safety: Safe with respect to other calls to the same method and 287 // calls to ReadEventReply() and Close(). 288 // 289 // Args: 290 // event: The event proto to be written to the stream. 291 // 292 // Returns: 293 // True iff the write is successful. 294 bool WriteEvent(const Event& event); 295 296 // Read an EventReply proto from the debug gRPC stream. 297 // 298 // This method blocks and waits for an EventReply from the server. 299 // Thread-safety: Safe with respect to other calls to the same method and 300 // calls to WriteEvent() and Close(). 301 // 302 // Args: 303 // event_reply: the to-be-modified EventReply proto passed as reference. 304 // 305 // Returns: 306 // True iff the read is successful. 307 bool ReadEventReply(EventReply* event_reply); 308 309 // Receive and process EventReply protos from the gRPC debug server. 310 // 311 // The processing includes setting debug watch key states using the 312 // DebugOpStateChange fields of the EventReply. 313 // 314 // Args: 315 // max_replies: Maximum number of replies to receive. Will receive all 316 // remaining replies iff max_replies == 0. 317 void ReceiveAndProcessEventReplies(size_t max_replies); 318 319 // Receive EventReplies from server (if any) and close the stream and the 320 // channel. 321 Status ReceiveServerRepliesAndClose(); 322 323 private: 324 string server_stream_addr_; 325 string url_; 326 ::grpc::ClientContext ctx_; 327 std::shared_ptr<::grpc::Channel> channel_; 328 std::unique_ptr<EventListener::Stub> stub_; 329 std::unique_ptr<::grpc::ClientReaderWriterInterface<Event, EventReply>> 330 reader_writer_; 331 332 mutex mu_; 333 }; 334 335 class DebugGrpcIO { 336 public: 337 static const size_t kGrpcMessageSizeLimitBytes; 338 static const size_t kGrpcMaxVarintLengthSize; 339 340 // Sends a tensor through a debug gRPC stream. 341 static Status SendTensorThroughGrpcStream(const DebugNodeKey& debug_node_key, 342 const Tensor& tensor, 343 const uint64 wall_time_us, 344 const string& grpc_stream_url, 345 const bool gated); 346 347 // Sends an Event proto through a debug gRPC stream. 348 // Thread-safety: Safe with respect to other calls to the same method and 349 // calls to CloseGrpcStream(). 350 // 351 // Args: 352 // event_proto: The Event proto to be sent. 353 // grpc_stream_url: The grpc:// URL of the stream to use, e.g., 354 // "grpc://localhost:11011", "localhost:22022". 355 // receive_reply: Whether an EventReply proto will be read after event_proto 356 // is sent and before the function returns. 357 // 358 // Returns: 359 // The Status of the operation. 360 static Status SendEventProtoThroughGrpcStream( 361 const Event& event_proto, const string& grpc_stream_url, 362 const bool receive_reply = false); 363 364 // Receive an EventReply proto through a debug gRPC stream. 365 static Status ReceiveEventReplyProtoThroughGrpcStream( 366 EventReply* event_reply, const string& grpc_stream_url); 367 368 // Check whether a debug watch key is read-activated at a given gRPC URL. 369 static bool IsReadGateOpen(const string& grpc_debug_url, 370 const string& watch_key); 371 372 // Check whether a debug watch key is write-activated (i.e., read- and 373 // write-activated) at a given gRPC URL. 374 static bool IsWriteGateOpen(const string& grpc_debug_url, 375 const string& watch_key); 376 377 // Closes a gRPC stream to the given address, if it exists. 378 // Thread-safety: Safe with respect to other calls to the same method and 379 // calls to SendTensorThroughGrpcStream(). 380 static Status CloseGrpcStream(const string& grpc_stream_url); 381 382 // Set the gRPC state of a debug node key. 383 // TODO(cais): Include device information in watch_key. 384 static void SetDebugNodeKeyGrpcState( 385 const string& grpc_debug_url, const string& watch_key, 386 const EventReply::DebugOpStateChange::State new_state); 387 388 private: 389 using DebugNodeName2State = 390 std::unordered_map<string, EventReply::DebugOpStateChange::State>; 391 392 // Returns a global map from grpc debug URLs to the corresponding 393 // DebugGrpcChannels. 394 static std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>* 395 GetStreamChannels(); 396 397 // Get a DebugGrpcChannel object at a given URL, creating one if necessary. 398 // 399 // Args: 400 // grpc_stream_url: grpc:// URL of the stream, e.g., "grpc://localhost:6064" 401 // debug_grpc_channel: A pointer to the DebugGrpcChannel object, passed as a 402 // a pointer to the pointer. The DebugGrpcChannel object is owned 403 // statically elsewhere, not by the caller of this function. 404 // 405 // Returns: 406 // Status of this operation. 407 static Status GetOrCreateDebugGrpcChannel( 408 const string& grpc_stream_url, DebugGrpcChannel** debug_grpc_channel); 409 410 // Returns a map from debug URL to a map from debug op name to enabled state. 411 static std::unordered_map<string, DebugNodeName2State>* 412 GetEnabledDebugOpStates(); 413 414 // Returns a map from debug op names to enabled state, for a given debug URL. 415 static DebugNodeName2State* GetEnabledDebugOpStatesAtUrl( 416 const string& grpc_debug_url); 417 418 // Clear enabled debug op state from all debug URLs (if any). 419 static void ClearEnabledWatchKeys(); 420 421 static mutex streams_mu; 422 static int64 channel_connection_timeout_micros; 423 424 friend class GrpcDebugTest; 425 friend class DebugNumericSummaryOpTest; 426 }; 427 428 } // namespace tensorflow 429 #endif // #ifndef(PLATFORM_WINDOWS) 430 431 #endif // TENSORFLOW_CORE_DEBUG_DEBUG_IO_UTILS_H_ 432