1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/utils/grappler_test.h"
17
18 #include <memory>
19
20 #include "absl/algorithm/container.h"
21 #include "tensorflow/core/framework/attr_value_util.h"
22 #include "tensorflow/core/grappler/utils.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
25 #include "tensorflow/core/public/session.h"
26
27 namespace tensorflow {
28 namespace grappler {
29
30 namespace {
CompareGraphNodes(protobuf::RepeatedPtrField<NodeDef> * want,protobuf::RepeatedPtrField<NodeDef> * got)31 void CompareGraphNodes(protobuf::RepeatedPtrField<NodeDef>* want,
32 protobuf::RepeatedPtrField<NodeDef>* got) {
33 auto comparator = [](const NodeDef& n1, const NodeDef& n2) -> bool {
34 return n1.name() < n2.name();
35 };
36
37 std::sort(want->begin(), want->end(), comparator);
38 std::sort(got->begin(), got->end(), comparator);
39
40 ASSERT_EQ(want->size(), got->size());
41
42 for (int i = 0; i < want->size(); ++i) {
43 NodeDef& want_node = (*want)[i];
44 NodeDef& got_node = (*got)[i];
45
46 EXPECT_EQ(want_node.op(), got_node.op());
47 EXPECT_EQ(want_node.name(), got_node.name());
48 EXPECT_EQ(want_node.device(), got_node.device());
49 ASSERT_EQ(want_node.input_size(), got_node.input_size());
50
51 // Order of control dependencies doesn't matter, so we sort them first.
52 const auto is_control = [](const string& input) -> bool {
53 return ParseTensorName(input).index() < 0;
54 };
55
56 auto want_inputs = want_node.mutable_input();
57 auto got_inputs = got_node.mutable_input();
58 std::sort(absl::c_find_if(*want_inputs, is_control), want_inputs->end());
59 std::sort(absl::c_find_if(*got_inputs, is_control), got_inputs->end());
60
61 for (int j = 0; j < want_node.input_size(); ++j) {
62 const TensorId want_tensor = ParseTensorName(want_node.input(j));
63 const TensorId got_tensor = ParseTensorName(got_node.input(j));
64 EXPECT_EQ(want_tensor.ToString(), got_tensor.ToString());
65 }
66 }
67 }
68 } // namespace
69
GrapplerTest()70 GrapplerTest::GrapplerTest() {
71 // Turn off all the automatic optimizations to ensure that we run the graph
72 // exactly as it is given to us. This ensures that we can compare the results
73 // before and after manual optimization, without any of the automatic
74 // optimizations interfering in the comparison.
75 RewriterConfig* cfg =
76 options_.config.mutable_graph_options()->mutable_rewrite_options();
77 // TODO(rmlarsen): Add utility to generate config w/ all optimizers turned
78 // off.
79 cfg->set_arithmetic_optimization(RewriterConfig::OFF);
80 cfg->set_constant_folding(RewriterConfig::OFF);
81 cfg->set_debug_stripper(RewriterConfig::OFF);
82 cfg->set_dependency_optimization(RewriterConfig::OFF);
83 cfg->set_function_optimization(RewriterConfig::OFF);
84 cfg->set_implementation_selector(RewriterConfig::OFF);
85 cfg->set_layout_optimizer(RewriterConfig::OFF);
86 cfg->set_loop_optimization(RewriterConfig::OFF);
87 cfg->set_pin_to_host_optimization(RewriterConfig::OFF);
88 }
89
EvaluateNodes(const GraphDef & graph,const std::vector<string> & node_names) const90 std::vector<Tensor> GrapplerTest::EvaluateNodes(
91 const GraphDef& graph, const std::vector<string>& node_names) const {
92 return EvaluateNodes(graph, node_names, {});
93 }
94
EvaluateNodes(const GraphDef & graph,const std::vector<string> & node_names,const std::vector<std::pair<string,Tensor>> & inputs) const95 std::vector<Tensor> GrapplerTest::EvaluateNodes(
96 const GraphDef& graph, const std::vector<string>& node_names,
97 const std::vector<std::pair<string, Tensor>>& inputs) const {
98 std::unique_ptr<tensorflow::Session> session(NewSession(options_));
99 TF_CHECK_OK(session->Create(graph));
100 RunOptions run_options;
101 std::vector<Tensor> output_tensors;
102 TF_CHECK_OK(session->Run(run_options, inputs, node_names, node_names,
103 &output_tensors, nullptr));
104 TF_CHECK_OK(session->Close());
105 return output_tensors;
106 }
107
EvaluateFetchNodes(const GrapplerItem & item) const108 std::vector<Tensor> GrapplerTest::EvaluateFetchNodes(
109 const GrapplerItem& item) const {
110 std::unique_ptr<tensorflow::Session> session(NewSession(options_));
111 TF_CHECK_OK(session->Create(item.graph));
112 RunOptions run_options;
113 if (!item.init_ops.empty()) {
114 std::vector<Tensor> dummy;
115 TF_CHECK_OK(
116 session->Run(run_options, {}, {}, item.init_ops, &dummy, nullptr));
117 }
118 std::vector<Tensor> output_tensors;
119 TF_CHECK_OK(session->Run(run_options, item.feed, item.fetch, {},
120 &output_tensors, nullptr));
121 TF_CHECK_OK(session->Close());
122 return output_tensors;
123 }
124
AddNode(const string & name,const string & op,const std::vector<string> & inputs,const std::vector<std::pair<string,AttrValue>> & attributes,GraphDef * graph) const125 NodeDef* GrapplerTest::AddNode(
126 const string& name, const string& op, const std::vector<string>& inputs,
127 const std::vector<std::pair<string, AttrValue>>& attributes,
128 GraphDef* graph) const {
129 NodeDef* node = graph->add_node();
130 node->set_name(name);
131 node->set_op(op);
132 for (const string& input : inputs) {
133 node->add_input(input);
134 }
135 for (auto attr : attributes) {
136 (*node->mutable_attr())[attr.first] = attr.second;
137 }
138 return node;
139 }
140
CompareGraphs(GraphDef want,GraphDef got) const141 void GrapplerTest::CompareGraphs(GraphDef want, GraphDef got) const {
142 CompareGraphNodes(want.mutable_node(), got.mutable_node());
143 }
144
CompareFunctions(FunctionDef want,FunctionDef got) const145 void GrapplerTest::CompareFunctions(FunctionDef want, FunctionDef got) const {
146 CompareGraphNodes(want.mutable_node_def(), got.mutable_node_def());
147 }
148
CompareNodes(const NodeDef & want,const NodeDef & got) const149 void GrapplerTest::CompareNodes(const NodeDef& want, const NodeDef& got) const {
150 EXPECT_EQ(want.name(), got.name());
151 EXPECT_EQ(want.op(), got.op());
152
153 std::vector<string> want_inputs(want.input().begin(), want.input().end());
154 std::vector<string> got_inputs(got.input().begin(), got.input().end());
155 EXPECT_EQ(want_inputs, got_inputs);
156
157 const auto attr_name = [](const std::pair<const string, AttrValue>& attr) {
158 return attr.first;
159 };
160
161 std::vector<string> want_attrs;
162 std::vector<string> got_attrs;
163 absl::c_transform(want.attr(), std::back_inserter(want_attrs), attr_name);
164 absl::c_transform(got.attr(), std::back_inserter(got_attrs), attr_name);
165 absl::c_sort(want_attrs);
166 absl::c_sort(got_attrs);
167 EXPECT_EQ(want_attrs, got_attrs);
168
169 for (const string& attr : want_attrs) {
170 EXPECT_TRUE(AreAttrValuesEqual(want.attr().at(attr), got.attr().at(attr)));
171 }
172 }
173
IsNodesDirectlyConnected(const NodeMap & node_map,const string & src,const string & dst,int position)174 bool GrapplerTest::IsNodesDirectlyConnected(const NodeMap& node_map,
175 const string& src,
176 const string& dst, int position) {
177 const NodeDef* src_node = node_map.GetNode(src);
178 const NodeDef* dst_node = node_map.GetNode(dst);
179 EXPECT_TRUE(src_node != nullptr) << src << " node not found";
180 EXPECT_TRUE(dst_node != nullptr) << dst << " node not found";
181 return src_node && dst_node && dst_node->input(position) == src_node->name();
182 }
183
CountOpNodes(const GraphDef & graph,const string & op)184 int GrapplerTest::CountOpNodes(const GraphDef& graph, const string& op) {
185 return std::count_if(graph.node().begin(), graph.node().end(),
186 [&op](const NodeDef& node) { return node.op() == op; });
187 }
188
189 } // namespace grappler
190 } // namespace tensorflow
191