1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/debug/debug_io_utils.h"
17 
18 #include <stddef.h>
19 #include <string.h>
20 #include <cmath>
21 #include <cstdlib>
22 #include <cstring>
23 #include <limits>
24 #include <utility>
25 #include <vector>
26 
27 #ifndef PLATFORM_WINDOWS
28 #include "grpcpp/create_channel.h"
29 #else
30 #endif  // #ifndef PLATFORM_WINDOWS
31 
32 #include "absl/strings/ascii.h"
33 #include "absl/strings/match.h"
34 #include "tensorflow/core/debug/debug_callback_registry.h"
35 #include "tensorflow/core/debug/debugger_event_metadata.pb.h"
36 #include "tensorflow/core/framework/graph.pb.h"
37 #include "tensorflow/core/framework/summary.pb.h"
38 #include "tensorflow/core/framework/tensor.pb.h"
39 #include "tensorflow/core/framework/tensor_shape.pb.h"
40 #include "tensorflow/core/lib/core/bits.h"
41 #include "tensorflow/core/lib/hash/hash.h"
42 #include "tensorflow/core/lib/io/path.h"
43 #include "tensorflow/core/lib/strings/str_util.h"
44 #include "tensorflow/core/lib/strings/stringprintf.h"
45 #include "tensorflow/core/platform/protobuf.h"
46 #include "tensorflow/core/util/event.pb.h"
47 
// Expands to a `return` statement that yields an Unimplemented status naming
// the gRPC debug URL scheme. It is used in the `#else` (Windows) branches of
// the `#ifndef PLATFORM_WINDOWS` guards in this file. Because the expansion is
// a `return`, the macro may only appear inside functions that return Status
// and in which `kGrpcURLScheme` is in scope.
#define GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR \
  return errors::Unimplemented(              \
      kGrpcURLScheme, " debug URL scheme is not implemented on Windows yet.")
51 
52 namespace tensorflow {
53 
54 namespace {
55 
56 // Creates an Event proto representing a chunk of a Tensor. This method only
57 // populates the field of the Event proto that represent the envelope
58 // information (e.g., timestamp, device_name, num_chunks, chunk_index, dtype,
59 // shape). It does not set the value.tensor field, which should be set by the
60 // caller separately.
PrepareChunkEventProto(const DebugNodeKey & debug_node_key,const uint64 wall_time_us,const size_t num_chunks,const size_t chunk_index,const DataType & tensor_dtype,const TensorShapeProto & tensor_shape)61 Event PrepareChunkEventProto(const DebugNodeKey& debug_node_key,
62                              const uint64 wall_time_us, const size_t num_chunks,
63                              const size_t chunk_index,
64                              const DataType& tensor_dtype,
65                              const TensorShapeProto& tensor_shape) {
66   Event event;
67   event.set_wall_time(static_cast<double>(wall_time_us));
68   Summary::Value* value = event.mutable_summary()->add_value();
69 
70   // Create the debug node_name in the Summary proto.
71   // For example, if tensor_name = "foo/node_a:0", and the debug_op is
72   // "DebugIdentity", the debug node_name in the Summary proto will be
73   // "foo/node_a:0:DebugIdentity".
74   value->set_node_name(debug_node_key.debug_node_name);
75 
76   // Tag by the node name. This allows TensorBoard to quickly fetch data
77   // per op.
78   value->set_tag(debug_node_key.node_name);
79 
80   // Store data within debugger metadata to be stored for each event.
81   third_party::tensorflow::core::debug::DebuggerEventMetadata metadata;
82   metadata.set_device(debug_node_key.device_name);
83   metadata.set_output_slot(debug_node_key.output_slot);
84   metadata.set_num_chunks(num_chunks);
85   metadata.set_chunk_index(chunk_index);
86 
87   // Encode the data in JSON.
88   string json_output;
89   tensorflow::protobuf::util::JsonPrintOptions json_options;
90   json_options.always_print_primitive_fields = true;
91   auto status = tensorflow::protobuf::util::MessageToJsonString(
92       metadata, &json_output, json_options);
93   if (status.ok()) {
94     // Store summary metadata. Set the plugin to use this data as "debugger".
95     SummaryMetadata::PluginData* plugin_data =
96         value->mutable_metadata()->mutable_plugin_data();
97     plugin_data->set_plugin_name(DebugIO::kDebuggerPluginName);
98     plugin_data->set_content(json_output);
99   } else {
100     LOG(WARNING) << "Failed to convert DebuggerEventMetadata proto to JSON. "
101                  << "The debug_node_name is " << debug_node_key.debug_node_name
102                  << ".";
103   }
104 
105   value->mutable_tensor()->set_dtype(tensor_dtype);
106   *value->mutable_tensor()->mutable_tensor_shape() = tensor_shape;
107 
108   return event;
109 }
110 
// Returns the number of bytes budgeted for `str` when it is stored as a bytes
// value in a protobuf.
//
// When PLATFORM_GOOGLE is defined, the result is the string's size plus
// DebugGrpcIO::kGrpcMaxVarintLengthSize: a conservative estimate (usually too
// large, but never too small under the gRPC message size limit) of the
// Varint-encoded length, used to work around the lack of a portable length
// function. Otherwise the result is just the string's size.
//
// The return type is a plain `size_t`: a top-level `const` on a by-value
// return has no effect on callers.
size_t StringValMaxBytesInProto(const string& str) {
#if defined(PLATFORM_GOOGLE)
  return str.size() + DebugGrpcIO::kGrpcMaxVarintLengthSize;
#else
  return str.size();
#endif
}
123 
124 // Breaks a string Tensor (represented as a TensorProto) as a vector of Event
125 // protos.
WrapStringTensorAsEvents(const DebugNodeKey & debug_node_key,const uint64 wall_time_us,const size_t chunk_size_limit,TensorProto * tensor_proto,std::vector<Event> * events)126 Status WrapStringTensorAsEvents(const DebugNodeKey& debug_node_key,
127                                 const uint64 wall_time_us,
128                                 const size_t chunk_size_limit,
129                                 TensorProto* tensor_proto,
130                                 std::vector<Event>* events) {
131   const protobuf::RepeatedPtrField<string>& strs = tensor_proto->string_val();
132   const size_t num_strs = strs.size();
133   const size_t chunk_size_ub = chunk_size_limit > 0
134                                    ? chunk_size_limit
135                                    : std::numeric_limits<size_t>::max();
136 
137   // E.g., if cutoffs is {j, k, l}, the chunks will have index ranges:
138   //   [0:a), [a:b), [c:<end>].
139   std::vector<size_t> cutoffs;
140   size_t chunk_size = 0;
141   for (size_t i = 0; i < num_strs; ++i) {
142     // Take into account the extra bytes in proto buffer.
143     if (StringValMaxBytesInProto(strs[i]) > chunk_size_ub) {
144       return errors::FailedPrecondition(
145           "string value at index ", i, " from debug node ",
146           debug_node_key.debug_node_name,
147           " does not fit gRPC message size limit (", chunk_size_ub, ")");
148     }
149     if (chunk_size + StringValMaxBytesInProto(strs[i]) > chunk_size_ub) {
150       cutoffs.push_back(i);
151       chunk_size = 0;
152     }
153     chunk_size += StringValMaxBytesInProto(strs[i]);
154   }
155   cutoffs.push_back(num_strs);
156   const size_t num_chunks = cutoffs.size();
157 
158   for (size_t i = 0; i < num_chunks; ++i) {
159     Event event = PrepareChunkEventProto(debug_node_key, wall_time_us,
160                                          num_chunks, i, tensor_proto->dtype(),
161                                          tensor_proto->tensor_shape());
162     Summary::Value* value = event.mutable_summary()->mutable_value(0);
163 
164     if (cutoffs.size() == 1) {
165       value->mutable_tensor()->mutable_string_val()->Swap(
166           tensor_proto->mutable_string_val());
167     } else {
168       const size_t begin = (i == 0) ? 0 : cutoffs[i - 1];
169       const size_t end = cutoffs[i];
170       for (size_t j = begin; j < end; ++j) {
171         value->mutable_tensor()->add_string_val(strs[j]);
172       }
173     }
174 
175     events->push_back(std::move(event));
176   }
177 
178   return Status::OK();
179 }
180 
181 // Encapsulates the tensor value inside a vector of Event protos. Large tensors
182 // are broken up to multiple protos to fit the chunk_size_limit. In each Event
183 // proto the field summary.tensor carries the content of the tensor.
184 // If chunk_size_limit <= 0, the tensor will not be broken into chunks, i.e., a
185 // length-1 vector will be returned, regardless of the size of the tensor.
WrapTensorAsEvents(const DebugNodeKey & debug_node_key,const Tensor & tensor,const uint64 wall_time_us,const size_t chunk_size_limit,std::vector<Event> * events)186 Status WrapTensorAsEvents(const DebugNodeKey& debug_node_key,
187                           const Tensor& tensor, const uint64 wall_time_us,
188                           const size_t chunk_size_limit,
189                           std::vector<Event>* events) {
190   TensorProto tensor_proto;
191   if (tensor.dtype() == DT_STRING) {
192     // Treat DT_STRING specially, so that tensor_util.MakeNdarray in Python can
193     // convert the TensorProto to string-type numpy array. MakeNdarray does not
194     // work with strings encoded by AsProtoTensorContent() in tensor_content.
195     tensor.AsProtoField(&tensor_proto);
196 
197     TF_RETURN_IF_ERROR(WrapStringTensorAsEvents(
198         debug_node_key, wall_time_us, chunk_size_limit, &tensor_proto, events));
199   } else {
200     tensor.AsProtoTensorContent(&tensor_proto);
201 
202     const size_t total_length = tensor_proto.tensor_content().size();
203     const size_t chunk_size_ub =
204         chunk_size_limit > 0 ? chunk_size_limit : total_length;
205     const size_t num_chunks =
206         (total_length == 0)
207             ? 1
208             : (total_length + chunk_size_ub - 1) / chunk_size_ub;
209     for (size_t i = 0; i < num_chunks; ++i) {
210       const size_t pos = i * chunk_size_ub;
211       const size_t len =
212           (i == num_chunks - 1) ? (total_length - pos) : chunk_size_ub;
213       Event event = PrepareChunkEventProto(debug_node_key, wall_time_us,
214                                            num_chunks, i, tensor_proto.dtype(),
215                                            tensor_proto.tensor_shape());
216       event.mutable_summary()
217           ->mutable_value(0)
218           ->mutable_tensor()
219           ->set_tensor_content(tensor_proto.tensor_content().substr(pos, len));
220       events->push_back(std::move(event));
221     }
222   }
223 
224   return Status::OK();
225 }
226 
227 // Appends an underscore and a timestamp to a file path. If the path already
228 // exists on the file system, append a hyphen and a 1-up index. Consecutive
229 // values of the index will be tried until the first unused one is found.
230 // TOCTOU race condition is not of concern here due to the fact that tfdbg
231 // sets parallel_iterations attribute of all while_loops to 1 to prevent
232 // the same node from between executed multiple times concurrently.
AppendTimestampToFilePath(const string & in,const uint64 timestamp)233 string AppendTimestampToFilePath(const string& in, const uint64 timestamp) {
234   string out = strings::StrCat(in, "_", timestamp);
235 
236   uint64 i = 1;
237   while (Env::Default()->FileExists(out).ok()) {
238     out = strings::StrCat(in, "_", timestamp, "-", i);
239     ++i;
240   }
241   return out;
242 }
243 
244 #ifndef PLATFORM_WINDOWS
245 // Publishes encoded GraphDef through a gRPC debugger stream, in chunks,
246 // conforming to the gRPC message size limit.
PublishEncodedGraphDefInChunks(const string & encoded_graph_def,const string & device_name,const int64 wall_time,const string & debug_url)247 Status PublishEncodedGraphDefInChunks(const string& encoded_graph_def,
248                                       const string& device_name,
249                                       const int64 wall_time,
250                                       const string& debug_url) {
251   const uint64 hash = ::tensorflow::Hash64(encoded_graph_def);
252   const size_t total_length = encoded_graph_def.size();
253   const size_t num_chunks =
254       static_cast<size_t>(std::ceil(static_cast<float>(total_length) /
255                                     DebugGrpcIO::kGrpcMessageSizeLimitBytes));
256   for (size_t i = 0; i < num_chunks; ++i) {
257     const size_t pos = i * DebugGrpcIO::kGrpcMessageSizeLimitBytes;
258     const size_t len = (i == num_chunks - 1)
259                            ? (total_length - pos)
260                            : DebugGrpcIO::kGrpcMessageSizeLimitBytes;
261     Event event;
262     event.set_wall_time(static_cast<double>(wall_time));
263     // Prefix the chunk with
264     //   <hash64>,<device_name>,<wall_time>|<index>|<num_chunks>|.
265     // TODO(cais): Use DebuggerEventMetadata to store device_name, num_chunks
266     // and chunk_index, instead.
267     event.set_graph_def(strings::StrCat(hash, ",", device_name, ",", wall_time,
268                                         "|", i, "|", num_chunks, "|",
269                                         encoded_graph_def.substr(pos, len)));
270     const Status s = DebugGrpcIO::SendEventProtoThroughGrpcStream(
271         event, debug_url, num_chunks - 1 == i);
272     if (!s.ok()) {
273       return errors::FailedPrecondition(
274           "Failed to send chunk ", i, " of ", num_chunks,
275           " of encoded GraphDef of size ", encoded_graph_def.size(), " bytes, ",
276           "due to: ", s.error_message());
277     }
278   }
279   return Status::OK();
280 }
281 #endif  // #ifndef PLATFORM_WINDOWS
282 
283 }  // namespace
284 
// Plugin name stored in the SummaryMetadata plugin data of each chunk Event
// (see PrepareChunkEventProto).
const char* const DebugIO::kDebuggerPluginName = "debugger";

// Tag embedded in the file name of core-metadata dumps (see
// PublishDebugMetadata).
const char* const DebugIO::kCoreMetadataTag = "core_metadata_";

// Tag embedded in the file name of graph dumps (see PublishGraph).
const char* const DebugIO::kGraphTag = "graph_";

// Tag that immediately precedes the graph hash in the file name of graph dumps
// (see PublishGraph).
const char* const DebugIO::kHashTag = "hash";
292 
ReadEventFromFile(const string & dump_file_path,Event * event)293 Status ReadEventFromFile(const string& dump_file_path, Event* event) {
294   Env* env(Env::Default());
295 
296   string content;
297   uint64 file_size = 0;
298 
299   Status s = env->GetFileSize(dump_file_path, &file_size);
300   if (!s.ok()) {
301     return s;
302   }
303 
304   content.resize(file_size);
305 
306   std::unique_ptr<RandomAccessFile> file;
307   s = env->NewRandomAccessFile(dump_file_path, &file);
308   if (!s.ok()) {
309     return s;
310   }
311 
312   StringPiece result;
313   s = file->Read(0, file_size, &result, &(content)[0]);
314   if (!s.ok()) {
315     return s;
316   }
317 
318   event->ParseFromString(content);
319   return Status::OK();
320 }
321 
// URL scheme prefixes recognized in debug URLs (see PublishDebugTensor):
//   file://   - dump to a directory on the file system;
//   grpc://   - send through a gRPC debugger stream;
//   memcbk:// - invoke a callback registered with DebugCallbackRegistry.
const char* const DebugIO::kFileURLScheme = "file://";
const char* const DebugIO::kGrpcURLScheme = "grpc://";
const char* const DebugIO::kMemoryURLScheme = "memcbk://";
325 
326 // Publishes debug metadata to a set of debug URLs.
PublishDebugMetadata(const int64 global_step,const int64 session_run_index,const int64 executor_step_index,const std::vector<string> & input_names,const std::vector<string> & output_names,const std::vector<string> & target_nodes,const std::unordered_set<string> & debug_urls)327 Status DebugIO::PublishDebugMetadata(
328     const int64 global_step, const int64 session_run_index,
329     const int64 executor_step_index, const std::vector<string>& input_names,
330     const std::vector<string>& output_names,
331     const std::vector<string>& target_nodes,
332     const std::unordered_set<string>& debug_urls) {
333   std::ostringstream oss;
334 
335   // Construct a JSON string to carry the metadata.
336   oss << "{";
337   oss << "\"global_step\":" << global_step << ",";
338   oss << "\"session_run_index\":" << session_run_index << ",";
339   oss << "\"executor_step_index\":" << executor_step_index << ",";
340   oss << "\"input_names\":[";
341   for (size_t i = 0; i < input_names.size(); ++i) {
342     oss << "\"" << input_names[i] << "\"";
343     if (i < input_names.size() - 1) {
344       oss << ",";
345     }
346   }
347   oss << "],";
348   oss << "\"output_names\":[";
349   for (size_t i = 0; i < output_names.size(); ++i) {
350     oss << "\"" << output_names[i] << "\"";
351     if (i < output_names.size() - 1) {
352       oss << ",";
353     }
354   }
355   oss << "],";
356   oss << "\"target_nodes\":[";
357   for (size_t i = 0; i < target_nodes.size(); ++i) {
358     oss << "\"" << target_nodes[i] << "\"";
359     if (i < target_nodes.size() - 1) {
360       oss << ",";
361     }
362   }
363   oss << "]";
364   oss << "}";
365 
366   const string json_metadata = oss.str();
367   Event event;
368   event.set_wall_time(static_cast<double>(Env::Default()->NowMicros()));
369   LogMessage* log_message = event.mutable_log_message();
370   log_message->set_message(json_metadata);
371 
372   Status status;
373   for (const string& url : debug_urls) {
374     if (absl::StartsWith(absl::AsciiStrToLower(url), kGrpcURLScheme)) {
375 #ifndef PLATFORM_WINDOWS
376       Event grpc_event;
377 
378       // Determine the path (if any) in the grpc:// URL, and add it as a field
379       // of the JSON string.
380       const string address = url.substr(strlen(DebugIO::kFileURLScheme));
381       const string path = address.find('/') == string::npos
382                               ? ""
383                               : address.substr(address.find('/'));
384       grpc_event.set_wall_time(event.wall_time());
385       LogMessage* log_message_grpc = grpc_event.mutable_log_message();
386       log_message_grpc->set_message(
387           strings::StrCat(json_metadata.substr(0, json_metadata.size() - 1),
388                           ",\"grpc_path\":\"", path, "\"}"));
389 
390       status.Update(
391           DebugGrpcIO::SendEventProtoThroughGrpcStream(grpc_event, url, true));
392 #else
393       GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
394 #endif
395     } else if (absl::StartsWith(absl::AsciiStrToLower(url), kFileURLScheme)) {
396       const string dump_root_dir = url.substr(strlen(kFileURLScheme));
397       const string core_metadata_path = AppendTimestampToFilePath(
398           io::JoinPath(dump_root_dir,
399                        strings::StrCat(
400                            DebugNodeKey::kMetadataFilePrefix,
401                            DebugIO::kCoreMetadataTag, "sessionrun",
402                            strings::Printf("%.14lld", static_cast<long long>(
403                                                           session_run_index)))),
404           Env::Default()->NowMicros());
405       status.Update(DebugFileIO::DumpEventProtoToFile(
406           event, string(io::Dirname(core_metadata_path)),
407           string(io::Basename(core_metadata_path))));
408     }
409   }
410 
411   return status;
412 }
413 
PublishDebugTensor(const DebugNodeKey & debug_node_key,const Tensor & tensor,const uint64 wall_time_us,const gtl::ArraySlice<string> debug_urls,const bool gated_grpc)414 Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
415                                    const Tensor& tensor,
416                                    const uint64 wall_time_us,
417                                    const gtl::ArraySlice<string> debug_urls,
418                                    const bool gated_grpc) {
419   int32 num_failed_urls = 0;
420   std::vector<Status> fail_statuses;
421   for (const string& url : debug_urls) {
422     if (absl::StartsWith(absl::AsciiStrToLower(url), kFileURLScheme)) {
423       const string dump_root_dir = url.substr(strlen(kFileURLScheme));
424 
425       const int64 tensorBytes =
426           tensor.IsInitialized() ? tensor.TotalBytes() : 0;
427       if (!DebugFileIO::requestDiskByteUsage(tensorBytes)) {
428         return errors::ResourceExhausted(
429             "TensorFlow Debugger has exhausted file-system byte-size "
430             "allowance (",
431             DebugFileIO::global_disk_bytes_limit_, "), therefore it cannot ",
432             "dump an additional ", tensorBytes, " byte(s) of tensor data ",
433             "for the debug tensor ", debug_node_key.node_name, ":",
434             debug_node_key.output_slot, ". You may use the environment ",
435             "variable TFDBG_DISK_BYTES_LIMIT to set a higher limit.");
436       }
437 
438       Status s = DebugFileIO::DumpTensorToDir(
439           debug_node_key, tensor, wall_time_us, dump_root_dir, nullptr);
440       if (!s.ok()) {
441         num_failed_urls++;
442         fail_statuses.push_back(s);
443       }
444     } else if (absl::StartsWith(absl::AsciiStrToLower(url), kGrpcURLScheme)) {
445 #ifndef PLATFORM_WINDOWS
446       Status s = DebugGrpcIO::SendTensorThroughGrpcStream(
447           debug_node_key, tensor, wall_time_us, url, gated_grpc);
448 
449       if (!s.ok()) {
450         num_failed_urls++;
451         fail_statuses.push_back(s);
452       }
453 #else
454       GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
455 #endif
456     } else if (absl::StartsWith(absl::AsciiStrToLower(url), kMemoryURLScheme)) {
457       const string dump_root_dir = url.substr(strlen(kMemoryURLScheme));
458       auto* callback_registry = DebugCallbackRegistry::singleton();
459       auto* callback = callback_registry->GetCallback(dump_root_dir);
460       CHECK(callback) << "No callback registered for: " << dump_root_dir;
461       (*callback)(debug_node_key, tensor);
462     } else {
463       return Status(error::UNAVAILABLE,
464                     strings::StrCat("Invalid debug target URL: ", url));
465     }
466   }
467 
468   if (num_failed_urls == 0) {
469     return Status::OK();
470   } else {
471     string error_message = strings::StrCat(
472         "Publishing to ", num_failed_urls, " of ", debug_urls.size(),
473         " debug target URLs failed, due to the following errors:");
474     for (Status& status : fail_statuses) {
475       error_message =
476           strings::StrCat(error_message, " ", status.error_message(), ";");
477     }
478 
479     return Status(error::INTERNAL, error_message);
480   }
481 }
482 
PublishDebugTensor(const DebugNodeKey & debug_node_key,const Tensor & tensor,const uint64 wall_time_us,const gtl::ArraySlice<string> debug_urls)483 Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
484                                    const Tensor& tensor,
485                                    const uint64 wall_time_us,
486                                    const gtl::ArraySlice<string> debug_urls) {
487   return PublishDebugTensor(debug_node_key, tensor, wall_time_us, debug_urls,
488                             false);
489 }
490 
PublishGraph(const Graph & graph,const string & device_name,const std::unordered_set<string> & debug_urls)491 Status DebugIO::PublishGraph(const Graph& graph, const string& device_name,
492                              const std::unordered_set<string>& debug_urls) {
493   GraphDef graph_def;
494   graph.ToGraphDef(&graph_def);
495 
496   string buf;
497   graph_def.SerializeToString(&buf);
498 
499   const int64 now_micros = Env::Default()->NowMicros();
500   Event event;
501   event.set_wall_time(static_cast<double>(now_micros));
502   event.set_graph_def(buf);
503 
504   Status status = Status::OK();
505   for (const string& debug_url : debug_urls) {
506     if (absl::StartsWith(debug_url, kFileURLScheme)) {
507       const string dump_root_dir =
508           io::JoinPath(debug_url.substr(strlen(kFileURLScheme)),
509                        DebugNodeKey::DeviceNameToDevicePath(device_name));
510       const uint64 graph_hash = ::tensorflow::Hash64(buf);
511       const string file_name =
512           strings::StrCat(DebugNodeKey::kMetadataFilePrefix, DebugIO::kGraphTag,
513                           DebugIO::kHashTag, graph_hash, "_", now_micros);
514 
515       status.Update(
516           DebugFileIO::DumpEventProtoToFile(event, dump_root_dir, file_name));
517     } else if (absl::StartsWith(debug_url, kGrpcURLScheme)) {
518 #ifndef PLATFORM_WINDOWS
519       status.Update(PublishEncodedGraphDefInChunks(buf, device_name, now_micros,
520                                                    debug_url));
521 #else
522       GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
523 #endif
524     }
525   }
526 
527   return status;
528 }
529 
IsCopyNodeGateOpen(const std::vector<DebugWatchAndURLSpec> & specs)530 bool DebugIO::IsCopyNodeGateOpen(
531     const std::vector<DebugWatchAndURLSpec>& specs) {
532 #ifndef PLATFORM_WINDOWS
533   for (const DebugWatchAndURLSpec& spec : specs) {
534     if (!spec.gated_grpc || spec.url.compare(0, strlen(DebugIO::kGrpcURLScheme),
535                                              DebugIO::kGrpcURLScheme)) {
536       return true;
537     } else {
538       if (DebugGrpcIO::IsReadGateOpen(spec.url, spec.watch_key)) {
539         return true;
540       }
541     }
542   }
543   return false;
544 #else
545   return true;
546 #endif
547 }
548 
IsDebugNodeGateOpen(const string & watch_key,const std::vector<string> & debug_urls)549 bool DebugIO::IsDebugNodeGateOpen(const string& watch_key,
550                                   const std::vector<string>& debug_urls) {
551 #ifndef PLATFORM_WINDOWS
552   for (const string& debug_url : debug_urls) {
553     if (debug_url.compare(0, strlen(DebugIO::kGrpcURLScheme),
554                           DebugIO::kGrpcURLScheme)) {
555       return true;
556     } else {
557       if (DebugGrpcIO::IsReadGateOpen(debug_url, watch_key)) {
558         return true;
559       }
560     }
561   }
562   return false;
563 #else
564   return true;
565 #endif
566 }
567 
IsDebugURLGateOpen(const string & watch_key,const string & debug_url)568 bool DebugIO::IsDebugURLGateOpen(const string& watch_key,
569                                  const string& debug_url) {
570 #ifndef PLATFORM_WINDOWS
571   if (debug_url != kGrpcURLScheme) {
572     return true;
573   } else {
574     return DebugGrpcIO::IsReadGateOpen(debug_url, watch_key);
575   }
576 #else
577   return true;
578 #endif
579 }
580 
CloseDebugURL(const string & debug_url)581 Status DebugIO::CloseDebugURL(const string& debug_url) {
582   if (absl::StartsWith(debug_url, DebugIO::kGrpcURLScheme)) {
583 #ifndef PLATFORM_WINDOWS
584     return DebugGrpcIO::CloseGrpcStream(debug_url);
585 #else
586     GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
587 #endif
588   } else {
589     // No-op for non-gRPC URLs.
590     return Status::OK();
591   }
592 }
593 
DumpTensorToDir(const DebugNodeKey & debug_node_key,const Tensor & tensor,const uint64 wall_time_us,const string & dump_root_dir,string * dump_file_path)594 Status DebugFileIO::DumpTensorToDir(const DebugNodeKey& debug_node_key,
595                                     const Tensor& tensor,
596                                     const uint64 wall_time_us,
597                                     const string& dump_root_dir,
598                                     string* dump_file_path) {
599   const string file_path =
600       GetDumpFilePath(dump_root_dir, debug_node_key, wall_time_us);
601 
602   if (dump_file_path != nullptr) {
603     *dump_file_path = file_path;
604   }
605 
606   return DumpTensorToEventFile(debug_node_key, tensor, wall_time_us, file_path);
607 }
608 
// Computes the path of the dump file for a debug node.
//
// The path is <dump_root_dir>/<device_path>/<node_name>_<output_slot>_<debug_op>
// with an underscore and `wall_time_us` appended, plus a "-<index>" suffix if
// that path already exists (see AppendTimestampToFilePath()).
string DebugFileIO::GetDumpFilePath(const string& dump_root_dir,
                                    const DebugNodeKey& debug_node_key,
                                    const uint64 wall_time_us) {
  const string base_name =
      strings::StrCat(debug_node_key.node_name, "_", debug_node_key.output_slot,
                      "_", debug_node_key.debug_op);
  const string path_without_timestamp =
      io::JoinPath(dump_root_dir, debug_node_key.device_path, base_name);
  return AppendTimestampToFilePath(path_without_timestamp, wall_time_us);
}
619 
DumpEventProtoToFile(const Event & event_proto,const string & dir_name,const string & file_name)620 Status DebugFileIO::DumpEventProtoToFile(const Event& event_proto,
621                                          const string& dir_name,
622                                          const string& file_name) {
623   Env* env(Env::Default());
624 
625   Status s = RecursiveCreateDir(env, dir_name);
626   if (!s.ok()) {
627     return Status(error::FAILED_PRECONDITION,
628                   strings::StrCat("Failed to create directory  ", dir_name,
629                                   ", due to: ", s.error_message()));
630   }
631 
632   const string file_path = io::JoinPath(dir_name, file_name);
633 
634   string event_str;
635   event_proto.SerializeToString(&event_str);
636 
637   std::unique_ptr<WritableFile> f = nullptr;
638   TF_CHECK_OK(env->NewWritableFile(file_path, &f));
639   f->Append(event_str).IgnoreError();
640   TF_CHECK_OK(f->Close());
641 
642   return Status::OK();
643 }
644 
DumpTensorToEventFile(const DebugNodeKey & debug_node_key,const Tensor & tensor,const uint64 wall_time_us,const string & file_path)645 Status DebugFileIO::DumpTensorToEventFile(const DebugNodeKey& debug_node_key,
646                                           const Tensor& tensor,
647                                           const uint64 wall_time_us,
648                                           const string& file_path) {
649   std::vector<Event> events;
650   TF_RETURN_IF_ERROR(
651       WrapTensorAsEvents(debug_node_key, tensor, wall_time_us, 0, &events));
652   return DumpEventProtoToFile(events[0], string(io::Dirname(file_path)),
653                               string(io::Basename(file_path)));
654 }
655 
RecursiveCreateDir(Env * env,const string & dir)656 Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) {
657   if (env->FileExists(dir).ok() && env->IsDirectory(dir).ok()) {
658     // The path already exists as a directory. Return OK right away.
659     return Status::OK();
660   }
661 
662   string parent_dir(io::Dirname(dir));
663   if (!env->FileExists(parent_dir).ok()) {
664     // The parent path does not exist yet, create it first.
665     Status s = RecursiveCreateDir(env, parent_dir);  // Recursive call
666     if (!s.ok()) {
667       return Status(
668           error::FAILED_PRECONDITION,
669           strings::StrCat("Failed to create directory  ", parent_dir));
670     }
671   } else if (env->FileExists(parent_dir).ok() &&
672              !env->IsDirectory(parent_dir).ok()) {
673     // The path exists, but it is a file.
674     return Status(error::FAILED_PRECONDITION,
675                   strings::StrCat("Failed to create directory  ", parent_dir,
676                                   " because the path exists as a file "));
677   }
678 
679   env->CreateDir(dir).IgnoreError();
680   // Guard against potential race in creating directories by doing a check
681   // after the CreateDir call.
682   if (env->FileExists(dir).ok() && env->IsDirectory(dir).ok()) {
683     return Status::OK();
684   } else {
685     return Status(error::ABORTED,
686                   strings::StrCat("Failed to create directory  ", parent_dir));
687   }
688 }
689 
// Default total disk usage limit: 100 GBytes (100 * 2^30 bytes).
const uint64 DebugFileIO::kDefaultGlobalDiskBytesLimit = 107374182400L;
// Effective disk usage limit in bytes. The initial value 0 means "not yet
// determined"; requestDiskByteUsage() sets it on demand from the environment
// variable TFDBG_DISK_BYTES_LIMIT or from the default above.
uint64 DebugFileIO::global_disk_bytes_limit_ = 0;
// Running total of the bytes granted by requestDiskByteUsage(); cleared by
// resetDiskByteUsage().
uint64 DebugFileIO::disk_bytes_used_ = 0;

// Held by requestDiskByteUsage() and resetDiskByteUsage() while they access
// the two counters above.
mutex DebugFileIO::bytes_mu_(LINKER_INITIALIZED);
696 
requestDiskByteUsage(uint64 bytes)697 bool DebugFileIO::requestDiskByteUsage(uint64 bytes) {
698   mutex_lock l(bytes_mu_);
699   if (global_disk_bytes_limit_ == 0) {
700     const char* env_tfdbg_disk_bytes_limit = getenv("TFDBG_DISK_BYTES_LIMIT");
701     if (env_tfdbg_disk_bytes_limit == nullptr ||
702         strlen(env_tfdbg_disk_bytes_limit) == 0) {
703       global_disk_bytes_limit_ = kDefaultGlobalDiskBytesLimit;
704     } else {
705       strings::safe_strtou64(string(env_tfdbg_disk_bytes_limit),
706                              &global_disk_bytes_limit_);
707     }
708   }
709 
710   if (bytes == 0) {
711     return true;
712   }
713   if (disk_bytes_used_ + bytes < global_disk_bytes_limit_) {
714     disk_bytes_used_ += bytes;
715     return true;
716   } else {
717     return false;
718   }
719 }
720 
resetDiskByteUsage()721 void DebugFileIO::resetDiskByteUsage() {
722   mutex_lock l(bytes_mu_);
723   disk_bytes_used_ = 0;
724 }
725 
726 #ifndef PLATFORM_WINDOWS
// Stores the server address and builds the channel's URL by prepending
// DebugIO::kGrpcURLScheme to that address. The constructor body is empty;
// the connection itself is established by Connect().
//
// Args:
//   server_stream_addr: Address of the debug server, without the URL scheme.
DebugGrpcChannel::DebugGrpcChannel(const string& server_stream_addr)
    : server_stream_addr_(server_stream_addr),
      url_(strings::StrCat(DebugIO::kGrpcURLScheme, server_stream_addr)) {}
730 
Connect(const int64 timeout_micros)731 Status DebugGrpcChannel::Connect(const int64 timeout_micros) {
732   ::grpc::ChannelArguments args;
733   args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
734   // Avoid problems where default reconnect backoff is too long (e.g., 20 s).
735   args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 1000);
736   channel_ = ::grpc::CreateCustomChannel(
737       server_stream_addr_, ::grpc::InsecureChannelCredentials(), args);
738   if (!channel_->WaitForConnected(
739           gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
740                        gpr_time_from_micros(timeout_micros, GPR_TIMESPAN)))) {
741     return errors::FailedPrecondition(
742         "Failed to connect to gRPC channel at ", server_stream_addr_,
743         " within a timeout of ", timeout_micros / 1e6, " s.");
744   }
745   stub_ = EventListener::NewStub(channel_);
746   reader_writer_ = stub_->SendEvents(&ctx_);
747 
748   return Status::OK();
749 }
750 
WriteEvent(const Event & event)751 bool DebugGrpcChannel::WriteEvent(const Event& event) {
752   mutex_lock l(mu_);
753   return reader_writer_->Write(event);
754 }
755 
ReadEventReply(EventReply * event_reply)756 bool DebugGrpcChannel::ReadEventReply(EventReply* event_reply) {
757   mutex_lock l(mu_);
758   return reader_writer_->Read(event_reply);
759 }
760 
ReceiveAndProcessEventReplies(const size_t max_replies)761 void DebugGrpcChannel::ReceiveAndProcessEventReplies(const size_t max_replies) {
762   EventReply event_reply;
763   size_t num_replies = 0;
764   while ((max_replies == 0 || ++num_replies <= max_replies) &&
765          ReadEventReply(&event_reply)) {
766     for (const EventReply::DebugOpStateChange& debug_op_state_change :
767          event_reply.debug_op_state_changes()) {
768       string watch_key = strings::StrCat(debug_op_state_change.node_name(), ":",
769                                          debug_op_state_change.output_slot(),
770                                          ":", debug_op_state_change.debug_op());
771       DebugGrpcIO::SetDebugNodeKeyGrpcState(url_, watch_key,
772                                             debug_op_state_change.state());
773     }
774   }
775 }
776 
ReceiveServerRepliesAndClose()777 Status DebugGrpcChannel::ReceiveServerRepliesAndClose() {
778   reader_writer_->WritesDone();
779   // Read all EventReply messages (if any) from the server.
780   ReceiveAndProcessEventReplies(0);
781 
782   if (reader_writer_->Finish().ok()) {
783     return Status::OK();
784   } else {
785     return Status(error::FAILED_PRECONDITION,
786                   "Failed to close debug GRPC stream.");
787   }
788 }
789 
// Guards the map of open stream channels returned by GetStreamChannels().
mutex DebugGrpcIO::streams_mu_(LINKER_INITIALIZED);

// Timeout passed to DebugGrpcChannel::Connect(), in microseconds
// (900 * 1000 * 1000 us = 900 s).
int64 DebugGrpcIO::channel_connection_timeout_micros_ = 900 * 1000 * 1000;
// TODO(cais): Make this configurable?

// Size limit handed to WrapTensorAsEvents() when a tensor is split into
// events for sending over gRPC: 4000 KiB.
const size_t DebugGrpcIO::kGrpcMessageSizeLimitBytes = 4000 * 1024;

// NOTE(review): presumably an upper bound on the length of a varint length
// prefix used in message-size accounting; confirm against the users of this
// constant, which are not in this part of the file.
const size_t DebugGrpcIO::kGrpcMaxVarintLengthSize = 6;
798 
799 std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
GetStreamChannels()800 DebugGrpcIO::GetStreamChannels() {
801   static std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
802       stream_channels =
803           new std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>();
804   return stream_channels;
805 }
806 
SendTensorThroughGrpcStream(const DebugNodeKey & debug_node_key,const Tensor & tensor,const uint64 wall_time_us,const string & grpc_stream_url,const bool gated)807 Status DebugGrpcIO::SendTensorThroughGrpcStream(
808     const DebugNodeKey& debug_node_key, const Tensor& tensor,
809     const uint64 wall_time_us, const string& grpc_stream_url,
810     const bool gated) {
811   if (gated &&
812       !IsReadGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) {
813     return Status::OK();
814   } else {
815     std::vector<Event> events;
816     TF_RETURN_IF_ERROR(WrapTensorAsEvents(debug_node_key, tensor, wall_time_us,
817                                           kGrpcMessageSizeLimitBytes, &events));
818     for (const Event& event : events) {
819       TF_RETURN_IF_ERROR(
820           SendEventProtoThroughGrpcStream(event, grpc_stream_url));
821     }
822     if (IsWriteGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) {
823       DebugGrpcChannel* debug_grpc_channel = nullptr;
824       TF_RETURN_IF_ERROR(
825           GetOrCreateDebugGrpcChannel(grpc_stream_url, &debug_grpc_channel));
826       debug_grpc_channel->ReceiveAndProcessEventReplies(1);
827       // TODO(cais): Support new tensor value carried in the EventReply for
828       // overriding the value of the tensor being published.
829     }
830     return Status::OK();
831   }
832 }
833 
ReceiveEventReplyProtoThroughGrpcStream(EventReply * event_reply,const string & grpc_stream_url)834 Status DebugGrpcIO::ReceiveEventReplyProtoThroughGrpcStream(
835     EventReply* event_reply, const string& grpc_stream_url) {
836   DebugGrpcChannel* debug_grpc_channel = nullptr;
837   TF_RETURN_IF_ERROR(
838       GetOrCreateDebugGrpcChannel(grpc_stream_url, &debug_grpc_channel));
839   if (debug_grpc_channel->ReadEventReply(event_reply)) {
840     return Status::OK();
841   } else {
842     return errors::Cancelled(strings::StrCat(
843         "Reading EventReply from stream URL ", grpc_stream_url, " failed."));
844   }
845 }
846 
GetOrCreateDebugGrpcChannel(const string & grpc_stream_url,DebugGrpcChannel ** debug_grpc_channel)847 Status DebugGrpcIO::GetOrCreateDebugGrpcChannel(
848     const string& grpc_stream_url, DebugGrpcChannel** debug_grpc_channel) {
849   const string addr_with_path =
850       absl::StartsWith(grpc_stream_url, DebugIO::kGrpcURLScheme)
851           ? grpc_stream_url.substr(strlen(DebugIO::kGrpcURLScheme))
852           : grpc_stream_url;
853   const string server_stream_addr =
854       addr_with_path.substr(0, addr_with_path.find('/'));
855   {
856     mutex_lock l(streams_mu_);
857     std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
858         stream_channels = GetStreamChannels();
859     if (stream_channels->find(grpc_stream_url) == stream_channels->end()) {
860       std::unique_ptr<DebugGrpcChannel> channel(
861           new DebugGrpcChannel(server_stream_addr));
862       TF_RETURN_IF_ERROR(channel->Connect(channel_connection_timeout_micros_));
863       stream_channels->insert(
864           std::make_pair(grpc_stream_url, std::move(channel)));
865     }
866     *debug_grpc_channel = (*stream_channels)[grpc_stream_url].get();
867   }
868   return Status::OK();
869 }
870 
SendEventProtoThroughGrpcStream(const Event & event_proto,const string & grpc_stream_url,const bool receive_reply)871 Status DebugGrpcIO::SendEventProtoThroughGrpcStream(
872     const Event& event_proto, const string& grpc_stream_url,
873     const bool receive_reply) {
874   DebugGrpcChannel* debug_grpc_channel;
875   TF_RETURN_IF_ERROR(
876       GetOrCreateDebugGrpcChannel(grpc_stream_url, &debug_grpc_channel));
877 
878   bool write_ok = debug_grpc_channel->WriteEvent(event_proto);
879   if (!write_ok) {
880     return errors::Cancelled(strings::StrCat("Write event to stream URL ",
881                                              grpc_stream_url, " failed."));
882   }
883 
884   if (receive_reply) {
885     debug_grpc_channel->ReceiveAndProcessEventReplies(1);
886   }
887 
888   return Status::OK();
889 }
890 
IsReadGateOpen(const string & grpc_debug_url,const string & watch_key)891 bool DebugGrpcIO::IsReadGateOpen(const string& grpc_debug_url,
892                                  const string& watch_key) {
893   const DebugNodeName2State* enabled_node_to_state =
894       GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
895   return enabled_node_to_state->find(watch_key) != enabled_node_to_state->end();
896 }
897 
IsWriteGateOpen(const string & grpc_debug_url,const string & watch_key)898 bool DebugGrpcIO::IsWriteGateOpen(const string& grpc_debug_url,
899                                   const string& watch_key) {
900   const DebugNodeName2State* enabled_node_to_state =
901       GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
902   auto it = enabled_node_to_state->find(watch_key);
903   if (it == enabled_node_to_state->end()) {
904     return false;
905   } else {
906     return it->second == EventReply::DebugOpStateChange::READ_WRITE;
907   }
908 }
909 
CloseGrpcStream(const string & grpc_stream_url)910 Status DebugGrpcIO::CloseGrpcStream(const string& grpc_stream_url) {
911   mutex_lock l(streams_mu_);
912 
913   std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
914       stream_channels = GetStreamChannels();
915   if (stream_channels->find(grpc_stream_url) != stream_channels->end()) {
916     // Stream of the specified address exists. Close it and remove it from
917     // record.
918     Status s =
919         (*stream_channels)[grpc_stream_url]->ReceiveServerRepliesAndClose();
920     (*stream_channels).erase(grpc_stream_url);
921     return s;
922   } else {
923     // Stream of the specified address does not exist. No action.
924     return Status::OK();
925   }
926 }
927 
928 std::unordered_map<string, DebugGrpcIO::DebugNodeName2State>*
GetEnabledDebugOpStates()929 DebugGrpcIO::GetEnabledDebugOpStates() {
930   static std::unordered_map<string, DebugNodeName2State>*
931       enabled_debug_op_states =
932           new std::unordered_map<string, DebugNodeName2State>();
933   return enabled_debug_op_states;
934 }
935 
GetEnabledDebugOpStatesAtUrl(const string & grpc_debug_url)936 DebugGrpcIO::DebugNodeName2State* DebugGrpcIO::GetEnabledDebugOpStatesAtUrl(
937     const string& grpc_debug_url) {
938   static mutex* debug_ops_state_mu = new mutex();
939   std::unordered_map<string, DebugNodeName2State>* states =
940       GetEnabledDebugOpStates();
941 
942   mutex_lock l(*debug_ops_state_mu);
943   if (states->find(grpc_debug_url) == states->end()) {
944     DebugNodeName2State url_enabled_debug_op_states;
945     (*states)[grpc_debug_url] = url_enabled_debug_op_states;
946   }
947   return &(*states)[grpc_debug_url];
948 }
949 
SetDebugNodeKeyGrpcState(const string & grpc_debug_url,const string & watch_key,const EventReply::DebugOpStateChange::State new_state)950 void DebugGrpcIO::SetDebugNodeKeyGrpcState(
951     const string& grpc_debug_url, const string& watch_key,
952     const EventReply::DebugOpStateChange::State new_state) {
953   DebugNodeName2State* states = GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
954   if (new_state == EventReply::DebugOpStateChange::DISABLED) {
955     if (states->find(watch_key) == states->end()) {
956       LOG(ERROR) << "Attempt to disable a watch key that is not currently "
957                  << "enabled at " << grpc_debug_url << ": " << watch_key;
958     } else {
959       states->erase(watch_key);
960     }
961   } else if (new_state != EventReply::DebugOpStateChange::STATE_UNSPECIFIED) {
962     (*states)[watch_key] = new_state;
963   }
964 }
965 
ClearEnabledWatchKeys()966 void DebugGrpcIO::ClearEnabledWatchKeys() {
967   GetEnabledDebugOpStates()->clear();
968 }
969 
970 #endif  // #ifndef PLATFORM_WINDOWS
971 
972 }  // namespace tensorflow
973