1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_GRAPH_BENCHMARK_TESTLIB_H_
17 #define TENSORFLOW_CORE_GRAPH_BENCHMARK_TESTLIB_H_
18
19 #include "absl/strings/str_cat.h"
20 #include "absl/strings/str_format.h"
21 #include "tensorflow/core/framework/attr_value.pb.h"
22 #include "tensorflow/core/framework/graph.pb.h"
23 #include "tensorflow/core/framework/node_def.pb.h"
24 #include "tensorflow/core/framework/op.h"
25 #include "tensorflow/core/lib/random/philox_random.h"
26 #include "tensorflow/core/lib/random/simple_philox.h"
27
28 namespace tensorflow {
29 namespace test {
30
// Test-only op signatures used by the graph builders below. "Input" and
// "Output" are used by CreateGraphDef(), which also picks one of the
// "In<k>Out1" ops by name. Only op signatures are declared here.
//
// NOTE(review): REGISTER_OP expands to a static registrar (see
// tensorflow/core/framework/op.h), so placing it at namespace scope in a
// header presumably registers these ops once per translation unit that
// includes this file. If two such TUs are linked into one binary, the ops
// would be registered twice. Presumably each benchmark binary includes this
// header from a single TU — confirm, or move these registrations into a .cc.

// Source op: no inputs, one float output.
REGISTER_OP("Input").Output("y: float");
// Sink op: consumes N (>= 1) float tensors and produces one float output.
REGISTER_OP("Output")
    .Input("x: N * float")
    .Attr("N: int >= 1")
    .Output("y: float");
// "In<k>Out1": k float inputs, one float output. Only k in {2, 4, 8, 16} is
// registered.
REGISTER_OP("In2Out1").Input("a: float").Input("b: float").Output("y: float");
REGISTER_OP("In4Out1")
    .Input("a: float")
    .Input("b: float")
    .Input("c: float")
    .Input("d: float")
    .Output("y: float");
REGISTER_OP("In8Out1")
    .Input("a: float")
    .Input("b: float")
    .Input("c: float")
    .Input("d: float")
    .Input("e: float")
    .Input("f: float")
    .Input("g: float")
    .Input("h: float")
    .Output("y: float");
REGISTER_OP("In16Out1")
    .Input("a: float")
    .Input("b: float")
    .Input("c: float")
    .Input("d: float")
    .Input("e: float")
    .Input("f: float")
    .Input("g: float")
    .Input("h: float")
    .Input("i: float")
    .Input("j: float")
    .Input("k: float")
    .Input("l: float")
    .Input("m: float")
    .Input("n: float")
    .Input("o: float")
    .Input("p: float")
    .Output("y: float");
71
CreateGraphDef(int num_nodes,int num_edges_per_node)72 GraphDef CreateGraphDef(int num_nodes, int num_edges_per_node) {
73 const int kNumInNodes = 10 * num_edges_per_node;
74 GraphDef graph_def;
75
76 auto create_node = [](const string& name, const string& op) {
77 NodeDef node;
78 node.set_name(name);
79 node.set_op(op);
80 return node;
81 };
82
83 NodeDef node;
84 for (int in = 0; in < kNumInNodes; ++in) {
85 node = create_node(/*name=*/absl::StrFormat("in%04d", in), /*op=*/"Input");
86 *graph_def.add_node() = std::move(node);
87 }
88
89 random::PhiloxRandom philox(301, 17);
90 random::SimplePhilox rnd(&philox);
91 for (int op = 0; op < num_nodes; ++op) {
92 node = create_node(/*name=*/absl::StrFormat("op%05d", op),
93 /*op=*/absl::StrFormat("In%dOut1", num_edges_per_node));
94 for (int edge = 0; edge < num_edges_per_node; ++edge) {
95 node.add_input(absl::StrFormat("in%04d", rnd.Uniform(kNumInNodes)));
96 }
97 *graph_def.add_node() = std::move(node);
98 }
99
100 // Add a single sink node. Otherwise a lot of time is spent in
101 // FixupSourceAndSinkEdges().
102 node = create_node(/*name=*/"out", /*op=*/"Output");
103 for (int op = 0; op < num_nodes; ++op) {
104 node.add_input(absl::StrFormat("op%05d", op));
105 }
106 AttrValue attr;
107 attr.set_i(num_nodes);
108 node.mutable_attr()->insert({"N", std::move(attr)});
109 *graph_def.add_node() = std::move(node);
110
111 return graph_def;
112 }
113
CreateRandomGraph(int size)114 GraphDef CreateRandomGraph(int size) {
115 random::PhiloxRandom philox(0x12345);
116 random::SimplePhilox rnd(&philox);
117
118 string prefix = "long_node_name_prefix_to_measure_string_copy_overhead";
119
120 GraphDef graph;
121 for (int i = 0; i < size; ++i) {
122 const string name = absl::StrCat(prefix, i);
123 const uint32 num_inputs = rnd.Uniform(std::min(i, 5));
124
125 NodeDef node;
126 node.set_name(name);
127 for (int n = 0; n < num_inputs; ++n) {
128 const uint32 input_node = rnd.Uniform(i);
129 node.add_input(absl::StrCat(prefix, input_node));
130 }
131
132 *graph.add_node() = std::move(node);
133 }
134
135 return graph;
136 }
137
CreateFaninFanoutNodeGraph(int num_regular_fanins,int num_regular_fanouts,int num_controlling_fanins,int num_controlled_fanouts,bool fanout_unique_index)138 GraphDef CreateFaninFanoutNodeGraph(int num_regular_fanins,
139 int num_regular_fanouts,
140 int num_controlling_fanins,
141 int num_controlled_fanouts,
142 bool fanout_unique_index) {
143 GraphDef graph;
144
145 auto create_node = [](const string& name) {
146 NodeDef node;
147 node.set_name(name);
148 return node;
149 };
150
151 NodeDef node = create_node(/*name=*/"node");
152
153 for (int i = 0; i < num_regular_fanins; ++i) {
154 const string input_node_name = absl::StrFormat("in%05d", i);
155 NodeDef input_node = create_node(/*name=*/input_node_name);
156 *graph.add_node() = std::move(input_node);
157 node.add_input(input_node_name);
158 }
159
160 for (int i = 0; i < num_controlling_fanins; ++i) {
161 const string input_node_name = absl::StrFormat("control_in%05d", i);
162 NodeDef input_node = create_node(/*name=*/input_node_name);
163 *graph.add_node() = std::move(input_node);
164 node.add_input(absl::StrCat("^", input_node_name));
165 }
166
167 for (int i = 0; i < num_regular_fanouts; ++i) {
168 NodeDef output_node = create_node(/*name=*/absl::StrFormat("out%05d", i));
169 const string input_node_index =
170 fanout_unique_index ? absl::StrCat(node.name(), ":", i) : node.name();
171 output_node.add_input(input_node_index);
172 *graph.add_node() = std::move(output_node);
173 }
174
175 const string controlled_fanout_input = absl::StrCat("^", node.name());
176 for (int i = 0; i < num_controlled_fanouts; ++i) {
177 NodeDef output_node =
178 create_node(/*name=*/absl::StrFormat("control_out%05d", i));
179 output_node.add_input(controlled_fanout_input);
180 *graph.add_node() = std::move(output_node);
181 }
182
183 *graph.add_node() = std::move(node);
184
185 return graph;
186 }
187
188 } // namespace test
189 } // namespace tensorflow
190
191 #endif // TENSORFLOW_CORE_GRAPH_BENCHMARK_TESTLIB_H_
192