1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/profiler/utils/tf_op_utils.h"
17 
18 #include <string>
19 #include <vector>
20 
21 #include "absl/strings/ascii.h"
22 #include "absl/strings/match.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_split.h"
25 #include "absl/strings/string_view.h"
26 #include "absl/strings/strip.h"
27 #include "tensorflow/core/platform/regexp.h"
28 
29 namespace tensorflow {
30 namespace profiler {
namespace {

// Leading component of tf.data iterator event names such as
// "Iterator::Batch::Map::TFRecord"; ParseTfOpFullname compares against it to
// detect dataset ops, and DatasetOpEventName uses it as the output prefix.
const absl::string_view kIterator = "Iterator";
// Separator between the components of a dataset iterator name.
const absl::string_view kSeparator = "::";
// Separator between the name scopes of an op name ("scope/sub_scope/op").
constexpr char kNameScopeSeparator = '/';

}  // namespace

// Type given to ops whose type could not be determined.
const absl::string_view kUnknownOp = "";  // op types are non-empty strings
// Type given to dataset ops (names starting with "Iterator:").
const absl::string_view kDatasetOp = "Dataset";
// Types given to GPU host-to-device and device-to-host memcpy events.
const absl::string_view kMemcpyHToDOp = "MemcpyHToD";
const absl::string_view kMemcpyDToHOp = "MemcpyDToH";
43 
IsTfOpName(absl::string_view op_name)44 bool IsTfOpName(absl::string_view op_name) {
45   // TODO(b/177602927): Confirm the naming convention with the TF team.
46   static const LazyRE2 kTfOpNameRegEx = {"[A-Za-z0-9.][A-Za-z0-9_.\\/>-]*"};
47   return RE2::FullMatch(op_name, *kTfOpNameRegEx);
48 }
49 
IsTfOpType(absl::string_view op_type)50 bool IsTfOpType(absl::string_view op_type) {
51   static const LazyRE2 kTfOpTypeRegEx = {"[A-Z_][a-zA-Z0-9_]*"};
52   return RE2::FullMatch(op_type, *kTfOpTypeRegEx);
53 }
54 
IsJaxOpType(absl::string_view op_type)55 bool IsJaxOpType(absl::string_view op_type) {
56   static const LazyRE2 kJaxOpTypeRegEx = {"[a-z_][a-z0-9_]*"};
57   return RE2::FullMatch(op_type, *kJaxOpTypeRegEx);
58 }
59 
IsJaxOpNameAndType(absl::string_view op_name,absl::string_view op_type)60 bool IsJaxOpNameAndType(absl::string_view op_name, absl::string_view op_type) {
61   if (op_name.empty() || !IsJaxOpType(op_type)) return false;
62   std::vector<absl::string_view> split_result =
63       absl::StrSplit(op_name, kNameScopeSeparator);
64   return absl::StrContains(split_result.back(), op_type);
65 }
66 
ParseTfOpFullname(absl::string_view tf_op_fullname)67 TfOp ParseTfOpFullname(absl::string_view tf_op_fullname) {
68   // TF Op names have the format "name:type".
69   TfOp tf_op = {Category::kUnknown, tf_op_fullname, kUnknownOp};
70   std::vector<absl::string_view> parts =
71       absl::StrSplit(tf_op_fullname, absl::MaxSplits(':', 1));
72   if (parts.size() != 2) {
73     // GPU-related Ops that need to be tracked.
74     if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYHToD")) {
75       tf_op.category = Category::kMemcpyHToD;
76       tf_op.type = kMemcpyHToDOp;
77     } else if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYDToH")) {
78       tf_op.category = Category::kMemcpyDToH;
79       tf_op.type = kMemcpyDToHOp;
80     }
81     // TODO(ckluk): Include the corresponding Ops on TPU.
82   } else if (parts[0] == kIterator) {
83     // Dataset Op names (e.g., Iterator::Batch::Map::TFRecord) do not follow the
84     // format of TF Op names. But we still want to capture them for
85     // input-pipeline analysis.
86     tf_op.category = Category::kTfData;
87     tf_op.type = kDatasetOp;
88   } else if (IsTfOpType(parts[1]) && IsTfOpName(parts[0])) {
89     tf_op = {Category::kTensorFlow, parts[0], parts[1]};
90   } else if (IsJaxOpType(parts[1])) {
91     tf_op = {Category::kJax, parts[0], parts[1]};
92   } else if (parts[1].empty()) {
93     tf_op.name = parts[0];  // remove trailing ':'
94   }
95   return tf_op;
96 }
97 
ParseTfNameScopes(const TfOp & tf_op)98 std::vector<absl::string_view> ParseTfNameScopes(const TfOp& tf_op) {
99   std::vector<absl::string_view> name_scopes =
100       absl::StrSplit(tf_op.name, kNameScopeSeparator);
101   // The last element is an op name not TF name scope.
102   if (!name_scopes.empty()) name_scopes.pop_back();
103   return name_scopes;
104 }
105 
TfOpEventName(const TfOp & tf_op)106 std::string TfOpEventName(const TfOp& tf_op) {
107   std::string event_name;
108   if (tf_op.category == Category::kUnknown) {
109     // Some TraceMe names contain trailing whitespace, remove it.
110     event_name = std::string(absl::StripTrailingAsciiWhitespace(tf_op.name));
111   } else if (tf_op.category == Category::kTfData) {
112     event_name = DatasetOpEventName(tf_op.name);
113   } else {
114     event_name = std::string(tf_op.type);
115   }
116   return event_name;
117 }
118 
TfOpEventName(absl::string_view tf_op_fullname)119 std::string TfOpEventName(absl::string_view tf_op_fullname) {
120   return TfOpEventName(ParseTfOpFullname(tf_op_fullname));
121 }
122 
DatasetOpEventName(absl::string_view full_name)123 std::string DatasetOpEventName(absl::string_view full_name) {
124   std::vector<absl::string_view> split_result =
125       absl::StrSplit(full_name, kSeparator);
126   return absl::StrCat(kIterator, kSeparator, split_result.back());
127 }
128 
IteratorName(absl::string_view full_name)129 std::string IteratorName(absl::string_view full_name) {
130   std::vector<absl::string_view> split_result =
131       absl::StrSplit(full_name, kSeparator);
132   return std::string(split_result.back());
133 }
134 
ParseTensorShapes(absl::string_view tensor_shapes)135 std::vector<absl::string_view> ParseTensorShapes(
136     absl::string_view tensor_shapes) {
137   absl::ConsumePrefix(&tensor_shapes, "(");
138   absl::ConsumeSuffix(&tensor_shapes, ")");
139   return absl::StrSplit(tensor_shapes, ';');
140 }
141 
142 }  // namespace profiler
143 }  // namespace tensorflow
144