1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/debug/debug_io_utils.h"
17 
18 #include <stddef.h>
19 #include <string.h>
20 #include <cmath>
21 #include <cstdlib>
22 #include <cstring>
23 #include <limits>
24 #include <utility>
25 #include <vector>
26 
27 #ifndef PLATFORM_WINDOWS
28 #include "grpcpp/create_channel.h"
29 #else
30 // winsock2.h is used in grpc, so Ws2_32.lib is needed
31 #pragma comment(lib, "Ws2_32.lib")
32 #endif  // #ifndef PLATFORM_WINDOWS
33 
34 #include "tensorflow/core/debug/debug_callback_registry.h"
35 #include "tensorflow/core/debug/debugger_event_metadata.pb.h"
36 #include "tensorflow/core/framework/graph.pb.h"
37 #include "tensorflow/core/framework/summary.pb.h"
38 #include "tensorflow/core/framework/tensor.pb.h"
39 #include "tensorflow/core/framework/tensor_shape.pb.h"
40 #include "tensorflow/core/lib/core/bits.h"
41 #include "tensorflow/core/lib/hash/hash.h"
42 #include "tensorflow/core/lib/io/path.h"
43 #include "tensorflow/core/lib/strings/str_util.h"
44 #include "tensorflow/core/lib/strings/stringprintf.h"
45 #include "tensorflow/core/platform/protobuf.h"
46 #include "tensorflow/core/util/event.pb.h"
47 
// Expands to a `return` statement producing an Unimplemented error. It is used
// in the PLATFORM_WINDOWS branches below, where the grpc:// debug URL scheme is
// not available. The expansion refers to an unqualified `kGrpcURLScheme`, so it
// must be used where that name is in scope (e.g. inside DebugIO members).
#define GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR \
  return errors::Unimplemented(              \
      kGrpcURLScheme, " debug URL scheme is not implemented on Windows yet.")
51 
52 namespace tensorflow {
53 
54 namespace {
55 
56 // Creates an Event proto representing a chunk of a Tensor. This method only
57 // populates the field of the Event proto that represent the envelope
58 // information (e.g., timestamp, device_name, num_chunks, chunk_index, dtype,
59 // shape). It does not set the value.tensor field, which should be set by the
60 // caller separately.
PrepareChunkEventProto(const DebugNodeKey & debug_node_key,const uint64 wall_time_us,const size_t num_chunks,const size_t chunk_index,const DataType & tensor_dtype,const TensorShapeProto & tensor_shape)61 Event PrepareChunkEventProto(const DebugNodeKey& debug_node_key,
62                              const uint64 wall_time_us, const size_t num_chunks,
63                              const size_t chunk_index,
64                              const DataType& tensor_dtype,
65                              const TensorShapeProto& tensor_shape) {
66   Event event;
67   event.set_wall_time(static_cast<double>(wall_time_us));
68   Summary::Value* value = event.mutable_summary()->add_value();
69 
70   // Create the debug node_name in the Summary proto.
71   // For example, if tensor_name = "foo/node_a:0", and the debug_op is
72   // "DebugIdentity", the debug node_name in the Summary proto will be
73   // "foo/node_a:0:DebugIdentity".
74   value->set_node_name(debug_node_key.debug_node_name);
75 
76   // Tag by the node name. This allows TensorBoard to quickly fetch data
77   // per op.
78   value->set_tag(debug_node_key.node_name);
79 
80   // Store data within debugger metadata to be stored for each event.
81   third_party::tensorflow::core::debug::DebuggerEventMetadata metadata;
82   metadata.set_device(debug_node_key.device_name);
83   metadata.set_output_slot(debug_node_key.output_slot);
84   metadata.set_num_chunks(num_chunks);
85   metadata.set_chunk_index(chunk_index);
86 
87   // Encode the data in JSON.
88   string json_output;
89   tensorflow::protobuf::util::JsonPrintOptions json_options;
90   json_options.always_print_primitive_fields = true;
91   auto status = tensorflow::protobuf::util::MessageToJsonString(
92       metadata, &json_output, json_options);
93   if (status.ok()) {
94     // Store summary metadata. Set the plugin to use this data as "debugger".
95     SummaryMetadata::PluginData* plugin_data =
96         value->mutable_metadata()->mutable_plugin_data();
97     plugin_data->set_plugin_name(DebugIO::kDebuggerPluginName);
98     plugin_data->set_content(json_output);
99   } else {
100     LOG(WARNING) << "Failed to convert DebuggerEventMetadata proto to JSON. "
101                  << "The debug_node_name is " << debug_node_key.debug_node_name
102                  << ".";
103   }
104 
105   value->mutable_tensor()->set_dtype(tensor_dtype);
106   *value->mutable_tensor()->mutable_tensor_shape() = tensor_shape;
107 
108   return event;
109 }
110 
// Estimates how many bytes `str` occupies when stored as a `bytes` element of a
// protobuf message (e.g. one TensorProto.string_val entry).
//
// On PLATFORM_GOOGLE builds, the estimate adds
// DebugGrpcIO::kGrpcMaxVarintLengthSize to the raw size. This makes it a
// conservative estimate of the Varint-encoded length: usually too large, but
// never too small under the gRPC message size limit. It works around the lack
// of a portable length function. On all other builds it is just `str.size()`,
// so the length-prefix overhead is NOT included there.
size_t StringValMaxBytesInProto(const string& str) {
#if defined(PLATFORM_GOOGLE)
  return str.size() + DebugGrpcIO::kGrpcMaxVarintLengthSize;
#else
  return str.size();
#endif
}
123 
124 // Breaks a string Tensor (represented as a TensorProto) as a vector of Event
125 // protos.
WrapStringTensorAsEvents(const DebugNodeKey & debug_node_key,const uint64 wall_time_us,const size_t chunk_size_limit,TensorProto * tensor_proto,std::vector<Event> * events)126 Status WrapStringTensorAsEvents(const DebugNodeKey& debug_node_key,
127                                 const uint64 wall_time_us,
128                                 const size_t chunk_size_limit,
129                                 TensorProto* tensor_proto,
130                                 std::vector<Event>* events) {
131   const protobuf::RepeatedPtrField<string>& strs = tensor_proto->string_val();
132   const size_t num_strs = strs.size();
133   const size_t chunk_size_ub = chunk_size_limit > 0
134                                    ? chunk_size_limit
135                                    : std::numeric_limits<size_t>::max();
136 
137   // E.g., if cutoffs is {j, k, l}, the chunks will have index ranges:
138   //   [0:a), [a:b), [c:<end>].
139   std::vector<size_t> cutoffs;
140   size_t chunk_size = 0;
141   for (size_t i = 0; i < num_strs; ++i) {
142     // Take into account the extra bytes in proto buffer.
143     if (StringValMaxBytesInProto(strs[i]) > chunk_size_ub) {
144       return errors::FailedPrecondition(
145           "string value at index ", i, " from debug node ",
146           debug_node_key.debug_node_name,
147           " does not fit gRPC message size limit (", chunk_size_ub, ")");
148     }
149     if (chunk_size + StringValMaxBytesInProto(strs[i]) > chunk_size_ub) {
150       cutoffs.push_back(i);
151       chunk_size = 0;
152     }
153     chunk_size += StringValMaxBytesInProto(strs[i]);
154   }
155   cutoffs.push_back(num_strs);
156   const size_t num_chunks = cutoffs.size();
157 
158   for (size_t i = 0; i < num_chunks; ++i) {
159     Event event = PrepareChunkEventProto(debug_node_key, wall_time_us,
160                                          num_chunks, i, tensor_proto->dtype(),
161                                          tensor_proto->tensor_shape());
162     Summary::Value* value = event.mutable_summary()->mutable_value(0);
163 
164     if (cutoffs.size() == 1) {
165       value->mutable_tensor()->mutable_string_val()->Swap(
166           tensor_proto->mutable_string_val());
167     } else {
168       const size_t begin = (i == 0) ? 0 : cutoffs[i - 1];
169       const size_t end = cutoffs[i];
170       for (size_t j = begin; j < end; ++j) {
171         value->mutable_tensor()->add_string_val(strs[j]);
172       }
173     }
174 
175     events->push_back(std::move(event));
176   }
177 
178   return Status::OK();
179 }
180 
181 // Encapsulates the tensor value inside a vector of Event protos. Large tensors
182 // are broken up to multiple protos to fit the chunk_size_limit. In each Event
183 // proto the field summary.tensor carries the content of the tensor.
184 // If chunk_size_limit <= 0, the tensor will not be broken into chunks, i.e., a
185 // length-1 vector will be returned, regardless of the size of the tensor.
WrapTensorAsEvents(const DebugNodeKey & debug_node_key,const Tensor & tensor,const uint64 wall_time_us,const size_t chunk_size_limit,std::vector<Event> * events)186 Status WrapTensorAsEvents(const DebugNodeKey& debug_node_key,
187                           const Tensor& tensor, const uint64 wall_time_us,
188                           const size_t chunk_size_limit,
189                           std::vector<Event>* events) {
190   TensorProto tensor_proto;
191   if (tensor.dtype() == DT_STRING) {
192     // Treat DT_STRING specially, so that tensor_util.MakeNdarray in Python can
193     // convert the TensorProto to string-type numpy array. MakeNdarray does not
194     // work with strings encoded by AsProtoTensorContent() in tensor_content.
195     tensor.AsProtoField(&tensor_proto);
196 
197     TF_RETURN_IF_ERROR(WrapStringTensorAsEvents(
198         debug_node_key, wall_time_us, chunk_size_limit, &tensor_proto, events));
199   } else {
200     tensor.AsProtoTensorContent(&tensor_proto);
201 
202     const size_t total_length = tensor_proto.tensor_content().size();
203     const size_t chunk_size_ub =
204         chunk_size_limit > 0 ? chunk_size_limit : total_length;
205     const size_t num_chunks =
206         (total_length == 0)
207             ? 1
208             : (total_length + chunk_size_ub - 1) / chunk_size_ub;
209     for (size_t i = 0; i < num_chunks; ++i) {
210       const size_t pos = i * chunk_size_ub;
211       const size_t len =
212           (i == num_chunks - 1) ? (total_length - pos) : chunk_size_ub;
213       Event event = PrepareChunkEventProto(debug_node_key, wall_time_us,
214                                            num_chunks, i, tensor_proto.dtype(),
215                                            tensor_proto.tensor_shape());
216       event.mutable_summary()
217           ->mutable_value(0)
218           ->mutable_tensor()
219           ->set_tensor_content(tensor_proto.tensor_content().substr(pos, len));
220       events->push_back(std::move(event));
221     }
222   }
223 
224   return Status::OK();
225 }
226 
227 // Appends an underscore and a timestamp to a file path. If the path already
228 // exists on the file system, append a hyphen and a 1-up index. Consecutive
229 // values of the index will be tried until the first unused one is found.
230 // TOCTOU race condition is not of concern here due to the fact that tfdbg
231 // sets parallel_iterations attribute of all while_loops to 1 to prevent
232 // the same node from between executed multiple times concurrently.
AppendTimestampToFilePath(const string & in,const uint64 timestamp)233 string AppendTimestampToFilePath(const string& in, const uint64 timestamp) {
234   string out = strings::StrCat(in, "_", timestamp);
235 
236   uint64 i = 1;
237   while (Env::Default()->FileExists(out).ok()) {
238     out = strings::StrCat(in, "_", timestamp, "-", i);
239     ++i;
240   }
241   return out;
242 }
243 
244 #ifndef PLATFORM_WINDOWS
245 // Publishes encoded GraphDef through a gRPC debugger stream, in chunks,
246 // conforming to the gRPC message size limit.
PublishEncodedGraphDefInChunks(const string & encoded_graph_def,const string & device_name,const int64 wall_time,const string & debug_url)247 Status PublishEncodedGraphDefInChunks(const string& encoded_graph_def,
248                                       const string& device_name,
249                                       const int64 wall_time,
250                                       const string& debug_url) {
251   const uint64 hash = ::tensorflow::Hash64(encoded_graph_def);
252   const size_t total_length = encoded_graph_def.size();
253   const size_t num_chunks =
254       static_cast<size_t>(std::ceil(static_cast<float>(total_length) /
255                                     DebugGrpcIO::kGrpcMessageSizeLimitBytes));
256   for (size_t i = 0; i < num_chunks; ++i) {
257     const size_t pos = i * DebugGrpcIO::kGrpcMessageSizeLimitBytes;
258     const size_t len = (i == num_chunks - 1)
259                            ? (total_length - pos)
260                            : DebugGrpcIO::kGrpcMessageSizeLimitBytes;
261     Event event;
262     event.set_wall_time(static_cast<double>(wall_time));
263     // Prefix the chunk with
264     //   <hash64>,<device_name>,<wall_time>|<index>|<num_chunks>|.
265     // TODO(cais): Use DebuggerEventMetadata to store device_name, num_chunks
266     // and chunk_index, instead.
267     event.set_graph_def(strings::StrCat(hash, ",", device_name, ",", wall_time,
268                                         "|", i, "|", num_chunks, "|",
269                                         encoded_graph_def.substr(pos, len)));
270     const Status s = DebugGrpcIO::SendEventProtoThroughGrpcStream(
271         event, debug_url, num_chunks - 1 == i);
272     if (!s.ok()) {
273       return errors::FailedPrecondition(
274           "Failed to send chunk ", i, " of ", num_chunks,
275           " of encoded GraphDef of size ", encoded_graph_def.size(), " bytes, ",
276           "due to: ", s.error_message());
277     }
278   }
279   return Status::OK();
280 }
281 #endif  // #ifndef PLATFORM_WINDOWS
282 
283 }  // namespace
284 
// Plugin name stored in SummaryMetadata.plugin_data.plugin_name of the Events
// built by PrepareChunkEventProto.
const char* const DebugIO::kDebuggerPluginName = "debugger";

// Tag embedded in the file names of the core-metadata dumps that
// PublishDebugMetadata writes for file:// URLs.
const char* const DebugIO::kCoreMetadataTag = "core_metadata_";

// kGraphTag and kHashTag are concatenated (followed by the graph's Hash64) in
// the file names of the GraphDef dumps that PublishGraph writes for file://
// URLs.
const char* const DebugIO::kGraphTag = "graph_";

const char* const DebugIO::kHashTag = "hash";
292 
ReadEventFromFile(const string & dump_file_path,Event * event)293 Status ReadEventFromFile(const string& dump_file_path, Event* event) {
294   Env* env(Env::Default());
295 
296   string content;
297   uint64 file_size = 0;
298 
299   Status s = env->GetFileSize(dump_file_path, &file_size);
300   if (!s.ok()) {
301     return s;
302   }
303 
304   content.resize(file_size);
305 
306   std::unique_ptr<RandomAccessFile> file;
307   s = env->NewRandomAccessFile(dump_file_path, &file);
308   if (!s.ok()) {
309     return s;
310   }
311 
312   StringPiece result;
313   s = file->Read(0, file_size, &result, &(content)[0]);
314   if (!s.ok()) {
315     return s;
316   }
317 
318   event->ParseFromString(content);
319   return Status::OK();
320 }
321 
// URL scheme prefixes that the DebugIO publishing functions dispatch on:
// file:// dumps to the file system, grpc:// sends to a gRPC debug server (not
// implemented on Windows), and memcbk:// invokes the callback that
// DebugCallbackRegistry holds under the remainder of the URL.
const char* const DebugIO::kFileURLScheme = "file://";
const char* const DebugIO::kGrpcURLScheme = "grpc://";
const char* const DebugIO::kMemoryURLScheme = "memcbk://";
325 
326 // Publishes debug metadata to a set of debug URLs.
PublishDebugMetadata(const int64 global_step,const int64 session_run_index,const int64 executor_step_index,const std::vector<string> & input_names,const std::vector<string> & output_names,const std::vector<string> & target_nodes,const std::unordered_set<string> & debug_urls)327 Status DebugIO::PublishDebugMetadata(
328     const int64 global_step, const int64 session_run_index,
329     const int64 executor_step_index, const std::vector<string>& input_names,
330     const std::vector<string>& output_names,
331     const std::vector<string>& target_nodes,
332     const std::unordered_set<string>& debug_urls) {
333   std::ostringstream oss;
334 
335   // Construct a JSON string to carry the metadata.
336   oss << "{";
337   oss << "\"global_step\":" << global_step << ",";
338   oss << "\"session_run_index\":" << session_run_index << ",";
339   oss << "\"executor_step_index\":" << executor_step_index << ",";
340   oss << "\"input_names\":[";
341   for (size_t i = 0; i < input_names.size(); ++i) {
342     oss << "\"" << input_names[i] << "\"";
343     if (i < input_names.size() - 1) {
344       oss << ",";
345     }
346   }
347   oss << "],";
348   oss << "\"output_names\":[";
349   for (size_t i = 0; i < output_names.size(); ++i) {
350     oss << "\"" << output_names[i] << "\"";
351     if (i < output_names.size() - 1) {
352       oss << ",";
353     }
354   }
355   oss << "],";
356   oss << "\"target_nodes\":[";
357   for (size_t i = 0; i < target_nodes.size(); ++i) {
358     oss << "\"" << target_nodes[i] << "\"";
359     if (i < target_nodes.size() - 1) {
360       oss << ",";
361     }
362   }
363   oss << "]";
364   oss << "}";
365 
366   const string json_metadata = oss.str();
367   Event event;
368   event.set_wall_time(static_cast<double>(Env::Default()->NowMicros()));
369   LogMessage* log_message = event.mutable_log_message();
370   log_message->set_message(json_metadata);
371 
372   Status status;
373   for (const string& url : debug_urls) {
374     if (str_util::Lowercase(url).find(kGrpcURLScheme) == 0) {
375 #ifndef PLATFORM_WINDOWS
376       Event grpc_event;
377 
378       // Determine the path (if any) in the grpc:// URL, and add it as a field
379       // of the JSON string.
380       const string address = url.substr(strlen(DebugIO::kFileURLScheme));
381       const string path = address.find("/") == string::npos
382                               ? ""
383                               : address.substr(address.find("/"));
384       grpc_event.set_wall_time(event.wall_time());
385       LogMessage* log_message_grpc = grpc_event.mutable_log_message();
386       log_message_grpc->set_message(
387           strings::StrCat(json_metadata.substr(0, json_metadata.size() - 1),
388                           ",\"grpc_path\":\"", path, "\"}"));
389 
390       status.Update(
391           DebugGrpcIO::SendEventProtoThroughGrpcStream(grpc_event, url, true));
392 #else
393       GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
394 #endif
395     } else if (str_util::Lowercase(url).find(kFileURLScheme) == 0) {
396       const string dump_root_dir = url.substr(strlen(kFileURLScheme));
397       const string core_metadata_path = AppendTimestampToFilePath(
398           io::JoinPath(
399               dump_root_dir,
400               strings::StrCat(DebugNodeKey::kMetadataFilePrefix,
401                               DebugIO::kCoreMetadataTag, "sessionrun",
402                               strings::Printf("%.14lld", session_run_index))),
403           Env::Default()->NowMicros());
404       status.Update(DebugFileIO::DumpEventProtoToFile(
405           event, string(io::Dirname(core_metadata_path)),
406           string(io::Basename(core_metadata_path))));
407     }
408   }
409 
410   return status;
411 }
412 
PublishDebugTensor(const DebugNodeKey & debug_node_key,const Tensor & tensor,const uint64 wall_time_us,const gtl::ArraySlice<string> & debug_urls,const bool gated_grpc)413 Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
414                                    const Tensor& tensor,
415                                    const uint64 wall_time_us,
416                                    const gtl::ArraySlice<string>& debug_urls,
417                                    const bool gated_grpc) {
418   int32 num_failed_urls = 0;
419   std::vector<Status> fail_statuses;
420   for (const string& url : debug_urls) {
421     if (str_util::Lowercase(url).find(kFileURLScheme) == 0) {
422       const string dump_root_dir = url.substr(strlen(kFileURLScheme));
423 
424       const int64 tensorBytes =
425           tensor.IsInitialized() ? tensor.TotalBytes() : 0;
426       if (!DebugFileIO::requestDiskByteUsage(tensorBytes)) {
427         return errors::ResourceExhausted(
428             "TensorFlow Debugger has exhausted file-system byte-size "
429             "allowance (",
430             DebugFileIO::globalDiskBytesLimit, "), therefore it cannot ",
431             "dump an additional ", tensorBytes, " byte(s) of tensor data ",
432             "for the debug tensor ", debug_node_key.node_name, ":",
433             debug_node_key.output_slot, ". You may use the environment ",
434             "variable TFDBG_DISK_BYTES_LIMIT to set a higher limit.");
435       }
436 
437       Status s = DebugFileIO::DumpTensorToDir(
438           debug_node_key, tensor, wall_time_us, dump_root_dir, nullptr);
439       if (!s.ok()) {
440         num_failed_urls++;
441         fail_statuses.push_back(s);
442       }
443     } else if (str_util::Lowercase(url).find(kGrpcURLScheme) == 0) {
444 #ifndef PLATFORM_WINDOWS
445       Status s = DebugGrpcIO::SendTensorThroughGrpcStream(
446           debug_node_key, tensor, wall_time_us, url, gated_grpc);
447 
448       if (!s.ok()) {
449         num_failed_urls++;
450         fail_statuses.push_back(s);
451       }
452 #else
453       GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
454 #endif
455     } else if (str_util::Lowercase(url).find(kMemoryURLScheme) == 0) {
456       const string dump_root_dir = url.substr(strlen(kMemoryURLScheme));
457       auto* callback_registry = DebugCallbackRegistry::singleton();
458       auto* callback = callback_registry->GetCallback(dump_root_dir);
459       CHECK(callback) << "No callback registered for: " << dump_root_dir;
460       (*callback)(debug_node_key, tensor);
461     } else {
462       return Status(error::UNAVAILABLE,
463                     strings::StrCat("Invalid debug target URL: ", url));
464     }
465   }
466 
467   if (num_failed_urls == 0) {
468     return Status::OK();
469   } else {
470     string error_message = strings::StrCat(
471         "Publishing to ", num_failed_urls, " of ", debug_urls.size(),
472         " debug target URLs failed, due to the following errors:");
473     for (Status& status : fail_statuses) {
474       error_message =
475           strings::StrCat(error_message, " ", status.error_message(), ";");
476     }
477 
478     return Status(error::INTERNAL, error_message);
479   }
480 }
481 
PublishDebugTensor(const DebugNodeKey & debug_node_key,const Tensor & tensor,const uint64 wall_time_us,const gtl::ArraySlice<string> & debug_urls)482 Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
483                                    const Tensor& tensor,
484                                    const uint64 wall_time_us,
485                                    const gtl::ArraySlice<string>& debug_urls) {
486   return PublishDebugTensor(debug_node_key, tensor, wall_time_us, debug_urls,
487                             false);
488 }
489 
PublishGraph(const Graph & graph,const string & device_name,const std::unordered_set<string> & debug_urls)490 Status DebugIO::PublishGraph(const Graph& graph, const string& device_name,
491                              const std::unordered_set<string>& debug_urls) {
492   GraphDef graph_def;
493   graph.ToGraphDef(&graph_def);
494 
495   string buf;
496   graph_def.SerializeToString(&buf);
497 
498   const int64 now_micros = Env::Default()->NowMicros();
499   Event event;
500   event.set_wall_time(static_cast<double>(now_micros));
501   event.set_graph_def(buf);
502 
503   Status status = Status::OK();
504   for (const string& debug_url : debug_urls) {
505     if (debug_url.find(kFileURLScheme) == 0) {
506       const string dump_root_dir =
507           io::JoinPath(debug_url.substr(strlen(kFileURLScheme)),
508                        DebugNodeKey::DeviceNameToDevicePath(device_name));
509       const uint64 graph_hash = ::tensorflow::Hash64(buf);
510       const string file_name =
511           strings::StrCat(DebugNodeKey::kMetadataFilePrefix, DebugIO::kGraphTag,
512                           DebugIO::kHashTag, graph_hash, "_", now_micros);
513 
514       status.Update(
515           DebugFileIO::DumpEventProtoToFile(event, dump_root_dir, file_name));
516     } else if (debug_url.find(kGrpcURLScheme) == 0) {
517 #ifndef PLATFORM_WINDOWS
518       status.Update(PublishEncodedGraphDefInChunks(buf, device_name, now_micros,
519                                                    debug_url));
520 #else
521       GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
522 #endif
523     }
524   }
525 
526   return status;
527 }
528 
IsCopyNodeGateOpen(const std::vector<DebugWatchAndURLSpec> & specs)529 bool DebugIO::IsCopyNodeGateOpen(
530     const std::vector<DebugWatchAndURLSpec>& specs) {
531 #ifndef PLATFORM_WINDOWS
532   for (const DebugWatchAndURLSpec& spec : specs) {
533     if (!spec.gated_grpc || spec.url.compare(0, strlen(DebugIO::kGrpcURLScheme),
534                                              DebugIO::kGrpcURLScheme)) {
535       return true;
536     } else {
537       if (DebugGrpcIO::IsReadGateOpen(spec.url, spec.watch_key)) {
538         return true;
539       }
540     }
541   }
542   return false;
543 #else
544   return true;
545 #endif
546 }
547 
IsDebugNodeGateOpen(const string & watch_key,const std::vector<string> & debug_urls)548 bool DebugIO::IsDebugNodeGateOpen(const string& watch_key,
549                                   const std::vector<string>& debug_urls) {
550 #ifndef PLATFORM_WINDOWS
551   for (const string& debug_url : debug_urls) {
552     if (debug_url.compare(0, strlen(DebugIO::kGrpcURLScheme),
553                           DebugIO::kGrpcURLScheme)) {
554       return true;
555     } else {
556       if (DebugGrpcIO::IsReadGateOpen(debug_url, watch_key)) {
557         return true;
558       }
559     }
560   }
561   return false;
562 #else
563   return true;
564 #endif
565 }
566 
IsDebugURLGateOpen(const string & watch_key,const string & debug_url)567 bool DebugIO::IsDebugURLGateOpen(const string& watch_key,
568                                  const string& debug_url) {
569 #ifndef PLATFORM_WINDOWS
570   if (debug_url.find(kGrpcURLScheme) != 0) {
571     return true;
572   } else {
573     return DebugGrpcIO::IsReadGateOpen(debug_url, watch_key);
574   }
575 #else
576   return true;
577 #endif
578 }
579 
CloseDebugURL(const string & debug_url)580 Status DebugIO::CloseDebugURL(const string& debug_url) {
581   if (debug_url.find(DebugIO::kGrpcURLScheme) == 0) {
582 #ifndef PLATFORM_WINDOWS
583     return DebugGrpcIO::CloseGrpcStream(debug_url);
584 #else
585     GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
586 #endif
587   } else {
588     // No-op for non-gRPC URLs.
589     return Status::OK();
590   }
591 }
592 
DumpTensorToDir(const DebugNodeKey & debug_node_key,const Tensor & tensor,const uint64 wall_time_us,const string & dump_root_dir,string * dump_file_path)593 Status DebugFileIO::DumpTensorToDir(const DebugNodeKey& debug_node_key,
594                                     const Tensor& tensor,
595                                     const uint64 wall_time_us,
596                                     const string& dump_root_dir,
597                                     string* dump_file_path) {
598   const string file_path =
599       GetDumpFilePath(dump_root_dir, debug_node_key, wall_time_us);
600 
601   if (dump_file_path != nullptr) {
602     *dump_file_path = file_path;
603   }
604 
605   return DumpTensorToEventFile(debug_node_key, tensor, wall_time_us, file_path);
606 }
607 
// Returns the dump file path for `debug_node_key`:
//   <dump_root_dir>/<device_path>/<node_name>_<output_slot>_<debug_op>_<ts>
// where <ts> is `wall_time_us`. If that path is already taken,
// AppendTimestampToFilePath adds a "-N" disambiguation suffix.
string DebugFileIO::GetDumpFilePath(const string& dump_root_dir,
                                    const DebugNodeKey& debug_node_key,
                                    const uint64 wall_time_us) {
  const string file_stem =
      strings::StrCat(debug_node_key.node_name, "_", debug_node_key.output_slot,
                      "_", debug_node_key.debug_op);
  const string stem_path =
      io::JoinPath(dump_root_dir, debug_node_key.device_path, file_stem);
  return AppendTimestampToFilePath(stem_path, wall_time_us);
}
618 
DumpEventProtoToFile(const Event & event_proto,const string & dir_name,const string & file_name)619 Status DebugFileIO::DumpEventProtoToFile(const Event& event_proto,
620                                          const string& dir_name,
621                                          const string& file_name) {
622   Env* env(Env::Default());
623 
624   Status s = RecursiveCreateDir(env, dir_name);
625   if (!s.ok()) {
626     return Status(error::FAILED_PRECONDITION,
627                   strings::StrCat("Failed to create directory  ", dir_name,
628                                   ", due to: ", s.error_message()));
629   }
630 
631   const string file_path = io::JoinPath(dir_name, file_name);
632 
633   string event_str;
634   event_proto.SerializeToString(&event_str);
635 
636   std::unique_ptr<WritableFile> f = nullptr;
637   TF_CHECK_OK(env->NewWritableFile(file_path, &f));
638   f->Append(event_str).IgnoreError();
639   TF_CHECK_OK(f->Close());
640 
641   return Status::OK();
642 }
643 
DumpTensorToEventFile(const DebugNodeKey & debug_node_key,const Tensor & tensor,const uint64 wall_time_us,const string & file_path)644 Status DebugFileIO::DumpTensorToEventFile(const DebugNodeKey& debug_node_key,
645                                           const Tensor& tensor,
646                                           const uint64 wall_time_us,
647                                           const string& file_path) {
648   std::vector<Event> events;
649   TF_RETURN_IF_ERROR(
650       WrapTensorAsEvents(debug_node_key, tensor, wall_time_us, 0, &events));
651   return DumpEventProtoToFile(events[0], string(io::Dirname(file_path)),
652                               string(io::Basename(file_path)));
653 }
654 
RecursiveCreateDir(Env * env,const string & dir)655 Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) {
656   if (env->FileExists(dir).ok() && env->IsDirectory(dir).ok()) {
657     // The path already exists as a directory. Return OK right away.
658     return Status::OK();
659   }
660 
661   string parent_dir(io::Dirname(dir));
662   if (!env->FileExists(parent_dir).ok()) {
663     // The parent path does not exist yet, create it first.
664     Status s = RecursiveCreateDir(env, parent_dir);  // Recursive call
665     if (!s.ok()) {
666       return Status(
667           error::FAILED_PRECONDITION,
668           strings::StrCat("Failed to create directory  ", parent_dir));
669     }
670   } else if (env->FileExists(parent_dir).ok() &&
671              !env->IsDirectory(parent_dir).ok()) {
672     // The path exists, but it is a file.
673     return Status(error::FAILED_PRECONDITION,
674                   strings::StrCat("Failed to create directory  ", parent_dir,
675                                   " because the path exists as a file "));
676   }
677 
678   env->CreateDir(dir).IgnoreError();
679   // Guard against potential race in creating directories by doing a check
680   // after the CreateDir call.
681   if (env->FileExists(dir).ok() && env->IsDirectory(dir).ok()) {
682     return Status::OK();
683   } else {
684     return Status(error::ABORTED,
685                   strings::StrCat("Failed to create directory  ", parent_dir));
686   }
687 }
688 
// Default cap on the total tensor bytes that may be dumped to disk:
// 100 GiB (100 * 1024^3 bytes). requestDiskByteUsage uses it when the
// TFDBG_DISK_BYTES_LIMIT environment variable is unset or empty.
const uint64 DebugFileIO::defaultGlobalDiskBytesLimit = 107374182400L;
// Effective limit. 0 means "not yet initialized": requestDiskByteUsage fills
// it in lazily.
uint64 DebugFileIO::globalDiskBytesLimit = 0;
// Bytes reserved so far by requestDiskByteUsage; cleared by resetDiskByteUsage.
uint64 DebugFileIO::diskBytesUsed = 0;

// Guards globalDiskBytesLimit and diskBytesUsed in requestDiskByteUsage and
// resetDiskByteUsage.
mutex DebugFileIO::bytes_mu(LINKER_INITIALIZED);
695 
requestDiskByteUsage(uint64 bytes)696 bool DebugFileIO::requestDiskByteUsage(uint64 bytes) {
697   mutex_lock l(bytes_mu);
698   if (globalDiskBytesLimit == 0) {
699     const char* env_tfdbg_disk_bytes_limit = getenv("TFDBG_DISK_BYTES_LIMIT");
700     if (env_tfdbg_disk_bytes_limit == nullptr ||
701         strlen(env_tfdbg_disk_bytes_limit) == 0) {
702       globalDiskBytesLimit = defaultGlobalDiskBytesLimit;
703     } else {
704       strings::safe_strtou64(string(env_tfdbg_disk_bytes_limit),
705                              &globalDiskBytesLimit);
706     }
707   }
708 
709   if (bytes == 0) {
710     return true;
711   }
712   if (diskBytesUsed + bytes < globalDiskBytesLimit) {
713     diskBytesUsed += bytes;
714     return true;
715   } else {
716     return false;
717   }
718 }
719 
resetDiskByteUsage()720 void DebugFileIO::resetDiskByteUsage() {
721   mutex_lock l(bytes_mu);
722   diskBytesUsed = 0;
723 }
724 
725 #ifndef PLATFORM_WINDOWS
DebugGrpcChannel(const string & server_stream_addr)726 DebugGrpcChannel::DebugGrpcChannel(const string& server_stream_addr)
727     : server_stream_addr_(server_stream_addr),
728       url_(strings::StrCat(DebugIO::kGrpcURLScheme, server_stream_addr)) {}
729 
Connect(const int64 timeout_micros)730 Status DebugGrpcChannel::Connect(const int64 timeout_micros) {
731   ::grpc::ChannelArguments args;
732   args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
733   // Avoid problems where default reconnect backoff is too long (e.g., 20 s).
734   args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 1000);
735   channel_ = ::grpc::CreateCustomChannel(
736       server_stream_addr_, ::grpc::InsecureChannelCredentials(), args);
737   if (!channel_->WaitForConnected(
738           gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
739                        gpr_time_from_micros(timeout_micros, GPR_TIMESPAN)))) {
740     return errors::FailedPrecondition(
741         "Failed to connect to gRPC channel at ", server_stream_addr_,
742         " within a timeout of ", timeout_micros / 1e6, " s.");
743   }
744   stub_ = EventListener::NewStub(channel_);
745   reader_writer_ = stub_->SendEvents(&ctx_);
746 
747   return Status::OK();
748 }
749 
WriteEvent(const Event & event)750 bool DebugGrpcChannel::WriteEvent(const Event& event) {
751   mutex_lock l(mu_);
752   return reader_writer_->Write(event);
753 }
754 
ReadEventReply(EventReply * event_reply)755 bool DebugGrpcChannel::ReadEventReply(EventReply* event_reply) {
756   mutex_lock l(mu_);
757   return reader_writer_->Read(event_reply);
758 }
759 
ReceiveAndProcessEventReplies(const size_t max_replies)760 void DebugGrpcChannel::ReceiveAndProcessEventReplies(const size_t max_replies) {
761   EventReply event_reply;
762   size_t num_replies = 0;
763   while ((max_replies == 0 || ++num_replies <= max_replies) &&
764          ReadEventReply(&event_reply)) {
765     for (const EventReply::DebugOpStateChange& debug_op_state_change :
766          event_reply.debug_op_state_changes()) {
767       string watch_key = strings::StrCat(debug_op_state_change.node_name(), ":",
768                                          debug_op_state_change.output_slot(),
769                                          ":", debug_op_state_change.debug_op());
770       DebugGrpcIO::SetDebugNodeKeyGrpcState(url_, watch_key,
771                                             debug_op_state_change.state());
772     }
773   }
774 }
775 
ReceiveServerRepliesAndClose()776 Status DebugGrpcChannel::ReceiveServerRepliesAndClose() {
777   reader_writer_->WritesDone();
778   // Read all EventReply messages (if any) from the server.
779   ReceiveAndProcessEventReplies(0);
780 
781   if (reader_writer_->Finish().ok()) {
782     return Status::OK();
783   } else {
784     return Status(error::FAILED_PRECONDITION,
785                   "Failed to close debug GRPC stream.");
786   }
787 }
788 
789 mutex DebugGrpcIO::streams_mu(LINKER_INITIALIZED);
790 
791 int64 DebugGrpcIO::channel_connection_timeout_micros = 900 * 1000 * 1000;
792 // TODO(cais): Make this configurable?
793 
794 const size_t DebugGrpcIO::kGrpcMessageSizeLimitBytes = 4000 * 1024;
795 
796 const size_t DebugGrpcIO::kGrpcMaxVarintLengthSize = 6;
797 
798 std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
GetStreamChannels()799 DebugGrpcIO::GetStreamChannels() {
800   static std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
801       stream_channels =
802           new std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>();
803   return stream_channels;
804 }
805 
SendTensorThroughGrpcStream(const DebugNodeKey & debug_node_key,const Tensor & tensor,const uint64 wall_time_us,const string & grpc_stream_url,const bool gated)806 Status DebugGrpcIO::SendTensorThroughGrpcStream(
807     const DebugNodeKey& debug_node_key, const Tensor& tensor,
808     const uint64 wall_time_us, const string& grpc_stream_url,
809     const bool gated) {
810   if (gated &&
811       !IsReadGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) {
812     return Status::OK();
813   } else {
814     std::vector<Event> events;
815     TF_RETURN_IF_ERROR(WrapTensorAsEvents(debug_node_key, tensor, wall_time_us,
816                                           kGrpcMessageSizeLimitBytes, &events));
817     for (const Event& event : events) {
818       TF_RETURN_IF_ERROR(
819           SendEventProtoThroughGrpcStream(event, grpc_stream_url));
820     }
821     if (IsWriteGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) {
822       DebugGrpcChannel* debug_grpc_channel = nullptr;
823       TF_RETURN_IF_ERROR(
824           GetOrCreateDebugGrpcChannel(grpc_stream_url, &debug_grpc_channel));
825       debug_grpc_channel->ReceiveAndProcessEventReplies(1);
826       // TODO(cais): Support new tensor value carried in the EventReply for
827       // overriding the value of the tensor being published.
828     }
829     return Status::OK();
830   }
831 }
832 
ReceiveEventReplyProtoThroughGrpcStream(EventReply * event_reply,const string & grpc_stream_url)833 Status DebugGrpcIO::ReceiveEventReplyProtoThroughGrpcStream(
834     EventReply* event_reply, const string& grpc_stream_url) {
835   DebugGrpcChannel* debug_grpc_channel = nullptr;
836   TF_RETURN_IF_ERROR(
837       GetOrCreateDebugGrpcChannel(grpc_stream_url, &debug_grpc_channel));
838   if (debug_grpc_channel->ReadEventReply(event_reply)) {
839     return Status::OK();
840   } else {
841     return errors::Cancelled(strings::StrCat(
842         "Reading EventReply from stream URL ", grpc_stream_url, " failed."));
843   }
844 }
845 
GetOrCreateDebugGrpcChannel(const string & grpc_stream_url,DebugGrpcChannel ** debug_grpc_channel)846 Status DebugGrpcIO::GetOrCreateDebugGrpcChannel(
847     const string& grpc_stream_url, DebugGrpcChannel** debug_grpc_channel) {
848   const string addr_with_path =
849       grpc_stream_url.find(DebugIO::kGrpcURLScheme) == 0
850           ? grpc_stream_url.substr(strlen(DebugIO::kGrpcURLScheme))
851           : grpc_stream_url;
852   const string server_stream_addr =
853       addr_with_path.substr(0, addr_with_path.find('/'));
854   {
855     mutex_lock l(streams_mu);
856     std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
857         stream_channels = GetStreamChannels();
858     if (stream_channels->find(grpc_stream_url) == stream_channels->end()) {
859       std::unique_ptr<DebugGrpcChannel> channel(
860           new DebugGrpcChannel(server_stream_addr));
861       TF_RETURN_IF_ERROR(channel->Connect(channel_connection_timeout_micros));
862       stream_channels->insert(
863           std::make_pair(grpc_stream_url, std::move(channel)));
864     }
865     *debug_grpc_channel = (*stream_channels)[grpc_stream_url].get();
866   }
867   return Status::OK();
868 }
869 
SendEventProtoThroughGrpcStream(const Event & event_proto,const string & grpc_stream_url,const bool receive_reply)870 Status DebugGrpcIO::SendEventProtoThroughGrpcStream(
871     const Event& event_proto, const string& grpc_stream_url,
872     const bool receive_reply) {
873   DebugGrpcChannel* debug_grpc_channel;
874   TF_RETURN_IF_ERROR(
875       GetOrCreateDebugGrpcChannel(grpc_stream_url, &debug_grpc_channel));
876 
877   bool write_ok = debug_grpc_channel->WriteEvent(event_proto);
878   if (!write_ok) {
879     return errors::Cancelled(strings::StrCat("Write event to stream URL ",
880                                              grpc_stream_url, " failed."));
881   }
882 
883   if (receive_reply) {
884     debug_grpc_channel->ReceiveAndProcessEventReplies(1);
885   }
886 
887   return Status::OK();
888 }
889 
IsReadGateOpen(const string & grpc_debug_url,const string & watch_key)890 bool DebugGrpcIO::IsReadGateOpen(const string& grpc_debug_url,
891                                  const string& watch_key) {
892   const DebugNodeName2State* enabled_node_to_state =
893       GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
894   return enabled_node_to_state->find(watch_key) != enabled_node_to_state->end();
895 }
896 
IsWriteGateOpen(const string & grpc_debug_url,const string & watch_key)897 bool DebugGrpcIO::IsWriteGateOpen(const string& grpc_debug_url,
898                                   const string& watch_key) {
899   const DebugNodeName2State* enabled_node_to_state =
900       GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
901   auto it = enabled_node_to_state->find(watch_key);
902   if (it == enabled_node_to_state->end()) {
903     return false;
904   } else {
905     return it->second == EventReply::DebugOpStateChange::READ_WRITE;
906   }
907 }
908 
CloseGrpcStream(const string & grpc_stream_url)909 Status DebugGrpcIO::CloseGrpcStream(const string& grpc_stream_url) {
910   mutex_lock l(streams_mu);
911 
912   std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
913       stream_channels = GetStreamChannels();
914   if (stream_channels->find(grpc_stream_url) != stream_channels->end()) {
915     // Stream of the specified address exists. Close it and remove it from
916     // record.
917     Status s =
918         (*stream_channels)[grpc_stream_url]->ReceiveServerRepliesAndClose();
919     (*stream_channels).erase(grpc_stream_url);
920     return s;
921   } else {
922     // Stream of the specified address does not exist. No action.
923     return Status::OK();
924   }
925 }
926 
927 std::unordered_map<string, DebugGrpcIO::DebugNodeName2State>*
GetEnabledDebugOpStates()928 DebugGrpcIO::GetEnabledDebugOpStates() {
929   static std::unordered_map<string, DebugNodeName2State>*
930       enabled_debug_op_states =
931           new std::unordered_map<string, DebugNodeName2State>();
932   return enabled_debug_op_states;
933 }
934 
GetEnabledDebugOpStatesAtUrl(const string & grpc_debug_url)935 DebugGrpcIO::DebugNodeName2State* DebugGrpcIO::GetEnabledDebugOpStatesAtUrl(
936     const string& grpc_debug_url) {
937   static mutex* debug_ops_state_mu = new mutex();
938   std::unordered_map<string, DebugNodeName2State>* states =
939       GetEnabledDebugOpStates();
940 
941   mutex_lock l(*debug_ops_state_mu);
942   if (states->find(grpc_debug_url) == states->end()) {
943     DebugNodeName2State url_enabled_debug_op_states;
944     (*states)[grpc_debug_url] = url_enabled_debug_op_states;
945   }
946   return &(*states)[grpc_debug_url];
947 }
948 
SetDebugNodeKeyGrpcState(const string & grpc_debug_url,const string & watch_key,const EventReply::DebugOpStateChange::State new_state)949 void DebugGrpcIO::SetDebugNodeKeyGrpcState(
950     const string& grpc_debug_url, const string& watch_key,
951     const EventReply::DebugOpStateChange::State new_state) {
952   DebugNodeName2State* states = GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
953   if (new_state == EventReply::DebugOpStateChange::DISABLED) {
954     if (states->find(watch_key) == states->end()) {
955       LOG(ERROR) << "Attempt to disable a watch key that is not currently "
956                  << "enabled at " << grpc_debug_url << ": " << watch_key;
957     } else {
958       states->erase(watch_key);
959     }
960   } else if (new_state != EventReply::DebugOpStateChange::STATE_UNSPECIFIED) {
961     (*states)[watch_key] = new_state;
962   }
963 }
964 
ClearEnabledWatchKeys()965 void DebugGrpcIO::ClearEnabledWatchKeys() {
966   GetEnabledDebugOpStates()->clear();
967 }
968 
969 #endif  // #ifndef PLATFORM_WINDOWS
970 
971 }  // namespace tensorflow
972