1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_PROFILER_UTILS_TF_OP_UTILS_H_
17 #define TENSORFLOW_CORE_PROFILER_UTILS_TF_OP_UTILS_H_
18 
19 #include <string>
20 #include <vector>
21 
22 #include "absl/strings/match.h"
23 #include "absl/strings/string_view.h"
24 #include "tensorflow/core/platform/macros.h"
25 
26 namespace tensorflow {
27 namespace profiler {
28 
// Special op types.
// These are declared `extern` and defined out of line. TF_CONST_INIT comes
// from tensorflow/core/platform/macros.h; presumably it asks the compiler to
// enforce constant initialization so the values are usable during static
// initialization — confirm against that header.
// Presumably the type used for ops whose type cannot be determined — confirm.
TF_CONST_INIT extern const absl::string_view kUnknownOp;
// Op type that IsDatasetOp(absl::string_view) compares against.
TF_CONST_INIT extern const absl::string_view kDatasetOp;
// Op type that IsMemcpyHToDOp compares against.
TF_CONST_INIT extern const absl::string_view kMemcpyHToDOp;
// Op type that IsMemcpyDToHOp compares against.
TF_CONST_INIT extern const absl::string_view kMemcpyDToHOp;
34 
// Coarse category of a profiled op, carried in TfOp::category (the struct
// returned by ParseTfOpFullname).
enum class Category {
  kTensorFlow,  // TensorFlow op.
  kJax,         // JAX op (cf. IsJaxOpType / IsJaxOpNameAndType).
  kTfData,      // tf.data op; IsDatasetOp(const TfOp&) tests for this value.
  kMemcpyHToD,  // Host-to-device memcpy (cf. kMemcpyHToDOp).
  kMemcpyDToH,  // Device-to-host memcpy (cf. kMemcpyDToHOp).
  kUnknown,     // Presumably the fallback for unrecognized fullnames —
                // confirm in ParseTfOpFullname.
};
43 
44 // Breaks a TensorFlow op fullname into name and type.
45 struct TfOp {
46   Category category;
47   absl::string_view name;
48   absl::string_view type;
49 };
50 TfOp ParseTfOpFullname(absl::string_view tf_op_fullname);
51 
// Returns a vector of TF name scopes extracted from an op that has already
// been parsed (see ParseTfOpFullname).
// NOTE(review): the result holds string_views; presumably they point into
// `tf_op.name`, so the underlying fullname must outlive it — confirm.
std::vector<absl::string_view> ParseTfNameScopes(const TfOp& tf_op);
54 
// Returns the trace event name for a TF op. The name is the op type, so that
// all ops of the same type get the same color in trace viewer.
// The first overload takes an already-parsed op. The second takes the raw op
// fullname and presumably parses it first (cf. ParseTfOpFullname) — confirm.
std::string TfOpEventName(const TfOp& tf_op);
std::string TfOpEventName(absl::string_view tf_op_fullname);
59 
// Returns the trace event name to use for a dataset op, given the op's full
// name.
std::string DatasetOpEventName(absl::string_view full_name);
62 
// Returns the iterator name from `full_name` after stripping its prefix and
// the names of its parent iterators. The result is an owned std::string.
std::string IteratorName(absl::string_view full_name);
65 
66 // Returns true if the given name is a TensorFlow Dataset Op.
IsDatasetOp(absl::string_view tf_op_type)67 inline bool IsDatasetOp(absl::string_view tf_op_type) {
68   return tf_op_type == kDatasetOp;
69 }
IsDatasetOp(const TfOp & tf_op)70 inline bool IsDatasetOp(const TfOp& tf_op) {
71   return tf_op.category == Category::kTfData;
72 }
73 
74 // Returns true if the given name is a TensorFlow Infeed Enqueue Op.
IsInfeedEnqueueOp(absl::string_view tf_op_type)75 inline bool IsInfeedEnqueueOp(absl::string_view tf_op_type) {
76   return tf_op_type == "InfeedEnqueue" || tf_op_type == "InfeedEnqueueTuple";
77 }
78 
79 // Returns true if the given op is for outside compilation.
IsOutsideCompilationOp(absl::string_view tf_op_fullname,absl::string_view hlo_expression)80 inline bool IsOutsideCompilationOp(absl::string_view tf_op_fullname,
81                                    absl::string_view hlo_expression) {
82   if (absl::EndsWith(tf_op_fullname, ":XlaSendToHost")) return true;
83   if (absl::StrContains(hlo_expression, "send-done") &&
84       absl::StrContains(hlo_expression, "is_host_transfer=true"))
85     return true;
86   return false;
87 }
88 
89 // Returns true if the given name is a TensorFlow embedding op.
IsEmbeddingOp(absl::string_view tf_op_fullname)90 inline bool IsEmbeddingOp(absl::string_view tf_op_fullname) {
91   return absl::StrContains(tf_op_fullname, "Embedding");
92 }
93 
94 // Returns true if the given op is for copying data from host to device.
IsMemcpyHToDOp(absl::string_view tf_op_type)95 inline bool IsMemcpyHToDOp(absl::string_view tf_op_type) {
96   return tf_op_type == kMemcpyHToDOp;
97 }
98 
99 // Returns true if the given op is for copying data from device to host.
IsMemcpyDToHOp(absl::string_view tf_op_type)100 inline bool IsMemcpyDToHOp(absl::string_view tf_op_type) {
101   return tf_op_type == kMemcpyDToHOp;
102 }
103 
// Splits a string of tensor shapes in "(shape1;shape2;...)" format, i.e.,
// delimited by '(' and ')' and separated by ';', into the individual shapes.
// For example, "(a;b)" yields {"a", "b"}.
// NOTE(review): the result holds string_views, presumably into
// `tensor_shapes`, so the input must outlive the result — confirm.
std::vector<absl::string_view> ParseTensorShapes(
    absl::string_view tensor_shapes);
108 
// Returns true if the given string matches the NodeDef.name pattern, i.e. it
// is a plausible TF op (node) name. (A TF op's name is its NodeDef.name; its
// type is NodeDef.op, which names an OpDef.)
bool IsTfOpName(absl::string_view op_name);

// Returns true if the given string matches the OpDef.name pattern, i.e. it is
// a plausible TF op type.
bool IsTfOpType(absl::string_view op_type);
114 
// Returns true if the given string matches the pattern used for JAX op types.
// The exact pattern is defined in the implementation file.
bool IsJaxOpType(absl::string_view op_type);
117 
// Returns true if `op_name` and `op_type`, taken together, match the JAX op
// naming pattern. The exact pattern is defined in the implementation file.
bool IsJaxOpNameAndType(absl::string_view op_name, absl::string_view op_type);
120 
121 }  // namespace profiler
122 }  // namespace tensorflow
123 
124 #endif  // TENSORFLOW_CORE_PROFILER_UTILS_TF_OP_UTILS_H_
125