1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/graph_def_util.h"
17 
18 #include "tensorflow/core/framework/function.h"
19 #include "tensorflow/core/framework/graph.pb.h"
20 #include "tensorflow/core/framework/node_def_builder.h"
21 #include "tensorflow/core/framework/op.h"
22 #include "tensorflow/core/framework/op_def.pb.h"
23 #include "tensorflow/core/framework/op_def_builder.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26 #include "tensorflow/core/util/equal_graph_def.h"
27 
28 namespace tensorflow {
29 namespace {
30 
FinalizeOpDef(const OpDefBuilder & b,OpDef * op_def)31 Status FinalizeOpDef(const OpDefBuilder& b, OpDef* op_def) {
32   OpRegistrationData op_reg_data;
33   const Status s = b.Finalize(&op_reg_data);
34   *op_def = op_reg_data.op_def;
35   return s;
36 }
37 
38 // Producer and consumer have default for an attr -> graph unchanged.
TEST(RemoveNewDefaultAttrsFromGraphDefTest,NoChangeWithDefault)39 TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeWithDefault) {
40   OpList op_list;
41   TF_ASSERT_OK(
42       FinalizeOpDef(OpDefBuilder("NoChangeWithDefault").Attr("a: int = 12"),
43                     op_list.add_op()));
44   OpListOpRegistry registry(&op_list);
45 
46   GraphDef graph_def;
47   TF_ASSERT_OK(NodeDefBuilder("ncwd", "NoChangeWithDefault", &registry)
48                    .Finalize(graph_def.add_node()));
49   GraphDef expected_graph_def = graph_def;
50 
51   std::set<std::pair<string, string>> op_attr_removed;
52   TF_ASSERT_OK(RemoveNewDefaultAttrsFromGraphDef(&graph_def, registry, registry,
53                                                  &op_attr_removed));
54 
55   TF_EXPECT_GRAPH_EQ(expected_graph_def, graph_def);
56   EXPECT_TRUE(op_attr_removed.empty());
57 }
58 
59 // Producer and consumer both have an attr -> graph unchanged.
TEST(RemoveNewDefaultAttrsFromGraphDefTest,NoChangeNoDefault)60 TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeNoDefault) {
61   OpList op_list;
62   TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("NoChangeNoDefault").Attr("a: int"),
63                              op_list.add_op()));
64   OpListOpRegistry registry(&op_list);
65 
66   GraphDef graph_def;
67   TF_ASSERT_OK(NodeDefBuilder("ncnd", "NoChangeNoDefault", &registry)
68                    .Attr("a", 42)
69                    .Finalize(graph_def.add_node()));
70   GraphDef expected_graph_def = graph_def;
71 
72   std::set<std::pair<string, string>> op_attr_removed;
73   TF_ASSERT_OK(RemoveNewDefaultAttrsFromGraphDef(&graph_def, registry, registry,
74                                                  &op_attr_removed));
75 
76   TF_EXPECT_GRAPH_EQ(expected_graph_def, graph_def);
77   EXPECT_TRUE(op_attr_removed.empty());
78 }
79 
80 // Producer has default for an attr that the consumer does not know
81 // about, and the produced graph has the default value for the attr ->
82 // attr removed from graph (and so able to be consumed).
TEST(RemoveNewDefaultAttrsFromGraphDefTest,UsesDefault)83 TEST(RemoveNewDefaultAttrsFromGraphDefTest, UsesDefault) {
84   OpList consumer_op_list;
85   TF_ASSERT_OK(
86       FinalizeOpDef(OpDefBuilder("UsesDefault"), consumer_op_list.add_op()));
87   OpListOpRegistry consumer_registry(&consumer_op_list);
88 
89   OpList producer_op_list;
90   TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("UsesDefault").Attr("a: int = 17"),
91                              producer_op_list.add_op()));
92   OpListOpRegistry producer_registry(&producer_op_list);
93 
94   GraphDef produced_graph_def;
95   TF_ASSERT_OK(NodeDefBuilder("uses_default", "UsesDefault", &producer_registry)
96                    .Finalize(produced_graph_def.add_node()));
97 
98   std::set<std::pair<string, string>> op_attr_removed;
99   TF_ASSERT_OK(
100       RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry,
101                                         producer_registry, &op_attr_removed));
102 
103   GraphDef expected_graph_def;
104   TF_ASSERT_OK(NodeDefBuilder("uses_default", "UsesDefault", &consumer_registry)
105                    .Finalize(expected_graph_def.add_node()));
106   TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def);
107 
108   std::set<std::pair<string, string>> expected_removed({{"UsesDefault", "a"}});
109   EXPECT_EQ(expected_removed, op_attr_removed);
110 }
111 
112 // Producer has default for an attr that the consumer does not know
113 // about, graph sets the attr to a value different from the default ->
114 // graph unchanged (but not able to be consumed by consumer).
TEST(RemoveNewDefaultAttrsFromGraphDefTest,ChangedFromDefault)115 TEST(RemoveNewDefaultAttrsFromGraphDefTest, ChangedFromDefault) {
116   OpList consumer_op_list;
117   TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("ChangedFromDefault"),
118                              consumer_op_list.add_op()));
119   OpListOpRegistry consumer_registry(&consumer_op_list);
120 
121   OpList producer_op_list;
122   TF_ASSERT_OK(
123       FinalizeOpDef(OpDefBuilder("ChangedFromDefault").Attr("a: int = 17"),
124                     producer_op_list.add_op()));
125   OpListOpRegistry producer_registry(&producer_op_list);
126 
127   GraphDef produced_graph_def;
128   TF_ASSERT_OK(NodeDefBuilder("changed_from_default", "ChangedFromDefault",
129                               &producer_registry)
130                    .Attr("a", 9)
131                    .Finalize(produced_graph_def.add_node()));
132   GraphDef expected_graph_def = produced_graph_def;
133 
134   std::set<std::pair<string, string>> op_attr_removed;
135   TF_ASSERT_OK(
136       RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry,
137                                         producer_registry, &op_attr_removed));
138 
139   TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def);
140   EXPECT_TRUE(op_attr_removed.empty());
141 }
142 
143 // Attrs starting with underscores should not be removed.
TEST(RemoveNewDefaultAttrsFromGraphDefTest,UnderscoreAttrs)144 TEST(RemoveNewDefaultAttrsFromGraphDefTest, UnderscoreAttrs) {
145   OpList consumer_op_list;
146   TF_ASSERT_OK(
147       FinalizeOpDef(OpDefBuilder("Underscore"), consumer_op_list.add_op()));
148   OpListOpRegistry consumer_registry(&consumer_op_list);
149 
150   OpList producer_op_list;
151   TF_ASSERT_OK(
152       FinalizeOpDef(OpDefBuilder("Underscore"), producer_op_list.add_op()));
153   // Add the _underscore attr manually since OpDefBuilder would complain
154   OpDef::AttrDef* attr = producer_op_list.mutable_op(0)->add_attr();
155   attr->set_name("_underscore");
156   attr->set_type("int");
157   attr->mutable_default_value()->set_i(17);
158   OpListOpRegistry producer_registry(&producer_op_list);
159 
160   GraphDef produced_graph_def;
161   TF_ASSERT_OK(NodeDefBuilder("node", "Underscore", &producer_registry)
162                    .Attr("_underscore", 17)
163                    .Finalize(produced_graph_def.add_node()));
164   GraphDef expected_graph_def = produced_graph_def;
165 
166   std::set<std::pair<string, string>> op_attr_removed;
167   TF_ASSERT_OK(
168       RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry,
169                                         producer_registry, &op_attr_removed));
170 
171   TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def);
172   EXPECT_EQ(op_attr_removed.size(), 0);
173 }
174 
// Default-attr removal must also apply to nodes inside library functions.
// In my_func, the UsesDefault node (whose "a" equals the producer default)
// loses "a". The ChangedFromDefault node (whose "a" differs) keeps it.
TEST(RemoveNewDefaultAttrsFromGraphDefTest, HasFunction) {
  // The consumer knows both ops without attrs. The producer gives each op
  // `a: int = 17`. Op order within each list is UsesDefault, then
  // ChangedFromDefault.
  OpList consumer_op_list;
  OpList producer_op_list;
  for (const char* op_name : {"UsesDefault", "ChangedFromDefault"}) {
    TF_ASSERT_OK(
        FinalizeOpDef(OpDefBuilder(op_name), consumer_op_list.add_op()));
    TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder(op_name).Attr("a: int = 17"),
                               producer_op_list.add_op()));
  }
  OpListOpRegistry consumer_registry(&consumer_op_list);
  OpListOpRegistry producer_registry(&producer_op_list);

  GraphDef produced_graph_def;
  *produced_graph_def.mutable_library()->add_function() =
      FunctionDefHelper::Create(
          "my_func", {}, {}, {},
          {{{"x"}, "UsesDefault", {}, {{"a", 17}}},
           {{"y"}, "ChangedFromDefault", {}, {{"a", 99}}}},
          {});
  // Expose my_func's signature as an op so a graph node can call it.
  OpList function_op_list;
  *function_op_list.add_op() =
      produced_graph_def.library().function(0).signature();
  OpListOpRegistry function_registry(&function_op_list);
  TF_ASSERT_OK(NodeDefBuilder("call_func", "my_func", &function_registry)
                   .Finalize(produced_graph_def.add_node()));

  std::set<std::pair<string, string>> op_attr_removed;
  TF_ASSERT_OK(
      RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry,
                                        producer_registry, &op_attr_removed));

  // Expected result: "x" has no attrs left, "y" still has a=99.
  GraphDef expected_graph_def;
  *expected_graph_def.mutable_library()->add_function() =
      FunctionDefHelper::Create(
          "my_func", {}, {}, {},
          {{{"x"}, "UsesDefault", {}, {}},
           {{"y"}, "ChangedFromDefault", {}, {{"a", 99}}}},
          {});
  TF_ASSERT_OK(NodeDefBuilder("call_func", "my_func", &function_registry)
                   .Finalize(expected_graph_def.add_node()));
  TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def);
  EXPECT_EQ(expected_graph_def.library().DebugString(),
            produced_graph_def.library().DebugString());

  const std::set<std::pair<string, string>> expected_removed(
      {{"UsesDefault", "a"}});
  EXPECT_EQ(expected_removed, op_attr_removed);
}
226 
// StrippedOpListForGraph must return only the ops a graph uses. They come back
// in the same order (B, then C) whatever the usage order or count, whether
// they appear as graph nodes or inside a library function. Summary and
// description are cleared, while is_commutative is kept.
// OpsUsedByGraph must agree on the set of ops.
TEST(StrippedOpListForGraphTest, FlatTest) {
  // Register four documented ops. Only B is commutative, which lets the checks
  // below tell B apart from C and confirm that field survives stripping.
  OpList op_list;
  for (const char* name : {"A", "B", "C", "D"}) {
    OpDef* op_def = op_list.add_op();
    op_def->set_name(name);
    op_def->set_summary("summary");
    op_def->set_description("description");
    op_def->set_is_commutative(op_def->name() == "B");
  }

  // Each row uses one of B/C once and the other twice, in varying order.
  const string graph_ops[4][3] = {
      {"C", "B", "B"}, {"B", "C", "B"}, {"B", "B", "C"}, {"C", "C", "B"}};
  const char* const kExpectedOps[] = {"B", "C"};

  for (const bool use_function : {false, true}) {
    for (const auto& ops : graph_ops) {
      GraphDef graph_def;
      if (!use_function) {
        // Put each op directly into the graph as a uniquely named node.
        for (const string& op : ops) {
          const string name = strings::StrCat("name", graph_def.node_size());
          NodeDef* node = graph_def.add_node();
          node->set_name(name);
          node->set_op(op);
        }
      } else {
        // Put the ops inside function F and call F from the graph.
        FunctionDef* function_def = graph_def.mutable_library()->add_function();
        function_def->mutable_signature()->set_name("F");
        for (const string& op : ops) {
          function_def->add_node_def()->set_op(op);
        }
        graph_def.add_node()->set_op("F");
      }

      OpList stripped_op_list;
      TF_ASSERT_OK(StrippedOpListForGraph(graph_def, OpListOpRegistry(&op_list),
                                          &stripped_op_list));

      // Exactly B and C remain, in that order, with their docs removed.
      ASSERT_EQ(stripped_op_list.op_size(), 2);
      for (int i = 0; i < stripped_op_list.op_size(); ++i) {
        const OpDef& op = stripped_op_list.op(i);
        EXPECT_EQ(op.name(), kExpectedOps[i]);
        EXPECT_EQ(op.summary(), "");
        EXPECT_EQ(op.description(), "");
        EXPECT_EQ(op.is_commutative(), i == 0);
      }

      // OpsUsedByGraph must report the same set of ops.
      std::set<string> used_ops;
      OpsUsedByGraph(graph_def, &used_ops);
      ASSERT_EQ(std::set<string>({"B", "C"}), used_ops);
    }
  }
}
283 
// Ops reachable only through nested library functions must still be found.
// The graph calls C, C calls B, and B calls the primitive op A. With
// `recursive`, B and C also call themselves. Only A may appear in the result.
TEST(StrippedOpListForGraphTest, NestedFunctionTest) {
  // A is the only registered op. B and C exist solely as library functions.
  OpList op_list;
  op_list.add_op()->set_name("A");

  for (const bool recursive : {false, true}) {
    GraphDef graph_def;
    auto* library = graph_def.mutable_library();
    FunctionDef* b = library->add_function();
    FunctionDef* c = library->add_function();

    b->mutable_signature()->set_name("B");
    b->add_node_def()->set_op("A");
    if (recursive) b->add_node_def()->set_op("B");

    c->mutable_signature()->set_name("C");
    c->add_node_def()->set_op("B");
    if (recursive) c->add_node_def()->set_op("C");

    // The graph itself references only C.
    graph_def.add_node()->set_op("C");

    OpList stripped_op_list;
    TF_ASSERT_OK(StrippedOpListForGraph(graph_def, OpListOpRegistry(&op_list),
                                        &stripped_op_list));
    ASSERT_EQ(stripped_op_list.op_size(), 1);
    ASSERT_EQ(stripped_op_list.op(0).name(), "A");

    // OpsUsedByGraph must report the same single op.
    std::set<string> used_ops;
    OpsUsedByGraph(graph_def, &used_ops);
    ASSERT_EQ(std::set<string>({"A"}), used_ops);
  }
}
319 
320 }  // namespace
321 }  // namespace tensorflow
322