1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
17 
18 #include "absl/strings/match.h"
19 #include "absl/strings/str_cat.h"
20 #include "absl/strings/string_view.h"
21 #include "tensorflow/cc/framework/ops.h"
22 #include "tensorflow/cc/ops/data_flow_ops.h"
23 #include "tensorflow/cc/ops/function_ops.h"
24 #include "tensorflow/cc/ops/functional_ops.h"
25 #include "tensorflow/cc/ops/standard_ops.h"
26 #include "tensorflow/compiler/tf2xla/sharding_util.h"
27 #include "tensorflow/core/common_runtime/graph_optimizer.h"
28 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
29 #include "tensorflow/core/framework/function.h"
30 #include "tensorflow/core/framework/graph_to_functiondef.h"
31 #include "tensorflow/core/framework/node_def.pb.h"
32 #include "tensorflow/core/graph/graph.h"
33 #include "tensorflow/core/lib/core/status.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35 #include "tensorflow/core/platform/test.h"
36 #include "tensorflow/core/public/version.h"
37 
38 namespace tensorflow {
39 namespace {
40 
// Asserts that `status` is a non-OK status whose error message contains the
// substring `str`. Shared by the ValidateConfig/PruneGraphDefInto tests below
// to check that the expected validation error is reported.
void ExpectErrorContains(const Status& status, absl::string_view str) {
  EXPECT_NE(Status::OK(), status);
  EXPECT_TRUE(absl::StrContains(status.error_message(), str))
      << "expected error: " << status.error_message() << " to contain: " << str;
}
46 
// A fully-populated config -- two feeds and two fetches, with and without
// debug names -- passes validation.
TEST(ValidateConfig, Good) {
  tf2xla::Config config;
  tf2xla::Feed* first_feed = config.add_feed();
  first_feed->mutable_id()->set_node_name("foo");
  first_feed->mutable_id()->set_output_index(123);
  first_feed->set_name("foo_debug");
  tf2xla::Feed* second_feed = config.add_feed();
  second_feed->mutable_id()->set_node_name("bar");
  second_feed->mutable_id()->set_output_index(0);
  tf2xla::Fetch* first_fetch = config.add_fetch();
  first_fetch->mutable_id()->set_node_name("baz");
  first_fetch->mutable_id()->set_output_index(456);
  first_fetch->set_name("baz_debug");
  tf2xla::Fetch* second_fetch = config.add_fetch();
  second_fetch->mutable_id()->set_node_name("banana");
  second_fetch->mutable_id()->set_output_index(0);
  TF_EXPECT_OK(ValidateConfig(config));
}
65 
// An entirely empty config is rejected: at least one fetch is required.
TEST(ValidateConfig, BadEmpty) {
  tf2xla::Config empty_config;
  ExpectErrorContains(ValidateConfig(empty_config),
                      "fetches must be specified");
}
70 
// A config that has feeds but no fetches is rejected.
TEST(ValidateConfig, BadNoFetch) {
  tf2xla::Config config;
  tf2xla::Feed* foo_feed = config.add_feed();
  foo_feed->mutable_id()->set_node_name("foo");
  ExpectErrorContains(ValidateConfig(config), "fetches must be specified");
}
77 
// A feed whose TensorId has an empty node name is rejected.
TEST(ValidateConfig, BadFeedNodeName) {
  tf2xla::Config config;
  config.add_feed();  // node_name deliberately left unset.
  ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty");
}
83 
// A feed with a negative output index is rejected.
TEST(ValidateConfig, BadFeedOutputIndex) {
  tf2xla::Config config;
  tf2xla::Feed* bad_feed = config.add_feed();
  bad_feed->mutable_id()->set_node_name("foo");
  bad_feed->mutable_id()->set_output_index(-1);
  ExpectErrorContains(ValidateConfig(config), "output_index must be positive");
}
91 
// A fetch whose TensorId has an empty node name is rejected, even when the
// feeds are valid.
TEST(ValidateConfig, BadFetchNodeName) {
  tf2xla::Config config;
  tf2xla::Feed* valid_feed = config.add_feed();
  valid_feed->mutable_id()->set_node_name("foo");
  config.add_fetch();  // node_name deliberately left unset.
  ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty");
}
99 
// A fetch with a negative output index is rejected, even when the feeds are
// valid.
TEST(ValidateConfig, BadFetchOutputIndex) {
  tf2xla::Config config;
  tf2xla::Feed* valid_feed = config.add_feed();
  valid_feed->mutable_id()->set_node_name("foo");
  tf2xla::Fetch* bad_fetch = config.add_fetch();
  bad_fetch->mutable_id()->set_node_name("bar");
  bad_fetch->mutable_id()->set_output_index(-1);
  ExpectErrorContains(ValidateConfig(config), "output_index must be positive");
}
109 
// Two feeds sharing the same debug name are rejected.
TEST(ValidateConfig, DuplicateFeedName) {
  tf2xla::Config config;
  tf2xla::Feed* first = config.add_feed();
  first->mutable_id()->set_node_name("foo");
  first->set_name("dup");
  tf2xla::Feed* second = config.add_feed();
  second->mutable_id()->set_node_name("bar");
  second->set_name("dup");
  ExpectErrorContains(ValidateConfig(config), "duplicate feed name");
}
120 
// Two fetches sharing the same debug name are rejected.
TEST(ValidateConfig, DuplicateFetchName) {
  tf2xla::Config config;
  tf2xla::Feed* feed = config.add_feed();
  feed->mutable_id()->set_node_name("foo");
  tf2xla::Fetch* first = config.add_fetch();
  first->mutable_id()->set_node_name("bar");
  first->set_name("dup");
  tf2xla::Fetch* second = config.add_fetch();
  second->mutable_id()->set_node_name("baz");
  second->set_name("dup");
  ExpectErrorContains(ValidateConfig(config), "duplicate fetch name");
}
133 
// Feed names "conflict" and "conflict_data" clash with each other and are
// rejected.
TEST(ValidateConfig, ConflictingFeedName) {
  tf2xla::Config config;
  tf2xla::Feed* first = config.add_feed();
  first->mutable_id()->set_node_name("foo");
  first->set_name("conflict");
  tf2xla::Feed* second = config.add_feed();
  second->mutable_id()->set_node_name("bar");
  second->set_name("conflict_data");
  ExpectErrorContains(ValidateConfig(config), "conflicting feed name");
}
144 
// Fetch names "conflict" and "conflict_data" clash with each other and are
// rejected.
TEST(ValidateConfig, ConflictingFetchName) {
  tf2xla::Config config;
  tf2xla::Feed* feed = config.add_feed();
  feed->mutable_id()->set_node_name("foo");
  tf2xla::Fetch* first = config.add_fetch();
  first->mutable_id()->set_node_name("bar");
  first->set_name("conflict");
  tf2xla::Fetch* second = config.add_fetch();
  second->mutable_id()->set_node_name("baz");
  second->set_name("conflict_data");
  ExpectErrorContains(ValidateConfig(config), "conflicting fetch name");
}
157 
FetchesConfig(std::vector<string> fetches)158 static tf2xla::Config FetchesConfig(std::vector<string> fetches) {
159   tf2xla::Config config;
160   for (const auto& fetch_node_name : fetches) {
161     auto* fetch = config.add_fetch();
162     fetch->set_name(absl::StrCat("fetch_", fetch_node_name));
163     fetch->mutable_id()->set_node_name(fetch_node_name);
164   }
165   return config;
166 }
167 
// Exercises PruneGraphDefInto: the output graph is pruned to the transitive
// fan-in of the fetched nodes, and an error is reported if a needed node is
// absent from the input GraphDef.
TEST(PruneGraphDefInto, Basic) {
  GraphDef def;
  auto* n = def.add_node();
  n->set_name("a");
  n->add_input("b:0");  // Data dependency on node "b".
  n->add_input("^c");   // Control dependency on node "c".

  GraphDef copy;
  // Fetching a node that does not exist in `def` is an error.
  ExpectErrorContains(PruneGraphDefInto(FetchesConfig({"missing"}), def, &copy),
                      "node missing needed");
  // "a" exists, but its data input "b" has not been added yet.
  ExpectErrorContains(PruneGraphDefInto(FetchesConfig({"a"}), def, &copy),
                      "node b needed");

  n = def.add_node();
  n->set_name("b");
  // "b" now exists, but the control dependency "c" is still missing.
  ExpectErrorContains(PruneGraphDefInto(FetchesConfig({"a"}), def, &copy),
                      "node c needed");
  n->add_input("d:1");

  n = def.add_node();
  n->set_name("c");
  n->add_input("d:1");

  n = def.add_node();
  n->set_name("d");

  // Graph is full, no pruning done.
  // Graph right now has diamond from d:
  //   d --> b --> a
  //   d --> c --> a
  TF_EXPECT_OK(PruneGraphDefInto(FetchesConfig({"a"}), def, &copy));
  EXPECT_EQ(def.DebugString(), copy.DebugString());
  GraphDef pruned_a = copy;  // Snapshot of the graph pruned for fetch "a".

  // Add some unrelated fields that use b and c, but are not needed for a.
  n = def.add_node();
  n->set_name("e");
  n->add_input("^d");
  n->add_input("b:2");
  copy.Clear();
  // Fetching only "a" prunes the unrelated node "e" away again.
  TF_EXPECT_OK(PruneGraphDefInto(FetchesConfig({"a"}), def, &copy));
  EXPECT_EQ(pruned_a.DebugString(), copy.DebugString());

  // Fetch "a" and "e" to get the original graph.
  copy.Clear();
  TF_EXPECT_OK(PruneGraphDefInto(FetchesConfig({"a", "e"}), def, &copy));
  EXPECT_EQ(def.DebugString(), copy.DebugString());
}
216 
// Verifies SetNodeShardingFromNeighbors: a node derives its sharding from the
// TPU_REPLICATED_CORE assigned-device of its in-edge (or out-edge) neighbors,
// which ParseShardingFromDevice can then read back off the node.
TEST(SetNodeShardingFromNeighbors, Basic) {
  // Builds a graph that adds two Tensors.
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
  auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
  auto c = ops::Add(scope.WithOpName("C"), a, b);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Recover the Node* handles by name after the Scope -> Graph conversion.
  Node* a_node = nullptr;
  Node* b_node = nullptr;
  Node* c_node = nullptr;
  for (Node* n : graph->nodes()) {
    if (n->name() == "A") a_node = n;
    if (n->name() == "B") b_node = n;
    if (n->name() == "C") c_node = n;
  }

  const int num_cores_per_replica = 4;

  // "foo" is not a valid device string, so propagating from that neighbor
  // must fail.
  a_node->set_assigned_device_name("foo");
  EXPECT_FALSE(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false).ok());

  // Test where one input to c_node has a device.
  a_node->set_assigned_device_name("/device:TPU_REPLICATED_CORE:2");
  TF_ASSERT_OK(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false));
  auto parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica);
  TF_ASSERT_OK(parse_status.status());
  ASSERT_TRUE(parse_status.ValueOrDie().has_value());
  EXPECT_EQ(2, parse_status.ValueOrDie().value().tile_assignment_devices(0));

  // Test where two inputs to c_node have a device.
  b_node->set_assigned_device_name("/device:TPU_REPLICATED_CORE:1");
  TF_ASSERT_OK(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false));
  parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica);
  TF_ASSERT_OK(parse_status.status());
  ASSERT_TRUE(parse_status.ValueOrDie().has_value());
  EXPECT_EQ(1, parse_status.ValueOrDie().value().tile_assignment_devices(0));

  // Test setting based on out edges.
  TF_ASSERT_OK(SetNodeShardingFromNeighbors(a_node, /*out_edges=*/true));
  parse_status = ParseShardingFromDevice(*a_node, num_cores_per_replica);
  TF_ASSERT_OK(parse_status.status());
  ASSERT_TRUE(parse_status.ValueOrDie().has_value());
  EXPECT_EQ(1, parse_status.ValueOrDie().value().tile_assignment_devices(0));
}
263 
// Minimal op used by the CachedFunctionHandles test below; its type attribute
// "T" lets the test instantiate the same function with different attr values.
REGISTER_OP("One")
    .Output("y: T")
    .Attr("T: {float, double, int32, int64}")
    .Doc(R"doc(
Returns a tensor with a single element (1) of type T.

y: A scalar in type T.

)doc");
273 
// Tests that CachedFunctionHandles class works: GetOrInstantiate() returns the
// same handle for identical (name, attrs) requests, a fresh handle for new
// attrs, and ReleaseAllHandles() releases everything it cached.
TEST(CachedFunctionHandles, Basic) {
  FunctionDef func = FunctionDefHelper::Define(
      // Name
      "TestFunc",
      // Args
      {},
      // Return values
      {"y:T"},
      // Attr def
      {"T:{float, double, int32, int64}"},
      // Nodes
      {
          {{"y"}, "One", {}, {{"T", "$T"}}},
      });
  FunctionDefLibrary proto;
  *proto.add_function() = func;
  FunctionLibraryDefinition fld(OpRegistry::Global(), proto);
  // Build a minimal runtime (no devices) so the function can be instantiated.
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
      new ProcessFunctionLibraryRuntime(
          /*device_mgr=*/nullptr, Env::Default(), TF_GRAPH_DEF_VERSION, &fld,
          OptimizerOptions()));
  FunctionLibraryRuntime* flr =
      pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);

  CachedFunctionHandles cached_function_handles(flr);

  // Tests that GetOrInstantiate() works.
  FunctionLibraryRuntime::Handle first_handle;
  AttrValue attr;
  attr.set_type(DT_FLOAT);
  AttrValueMap attrs;
  attrs["T"] = attr;
  TF_ASSERT_OK(cached_function_handles.GetOrInstantiate(
      "TestFunc", AttrSlice(&attrs), &first_handle));

  // Tests that we can get FunctionBody.
  const FunctionBody* body = flr->GetFunctionBody(first_handle);
  EXPECT_NE(body, nullptr);

  // Tests that GetOrInstantiate() returns cached handle when called with same
  // function name and attributes.
  FunctionLibraryRuntime::Handle second_handle;
  TF_ASSERT_OK(cached_function_handles.GetOrInstantiate(
      "TestFunc", AttrSlice(&attrs), &second_handle));
  EXPECT_EQ(first_handle, second_handle);

  // Tests that GetOrInstantiate() returns new handle when called with same
  // function name but different attributes.
  attr.set_type(DT_INT32);
  attrs["T"] = attr;
  FunctionLibraryRuntime::Handle third_handle;
  TF_ASSERT_OK(cached_function_handles.GetOrInstantiate(
      "TestFunc", AttrSlice(&attrs), &third_handle));
  EXPECT_NE(first_handle, third_handle);

  // Tests that ReleaseAllHandles() works.
  TF_EXPECT_OK(cached_function_handles.ReleaseAllHandles());
}
333 
// PropagateConstIntoFunctionalNodes should succeed on a While node whose
// cond/body functions declare a DT_RESOURCE argument alongside the constant
// loop inputs (i.e. the resource arg must not break const propagation).
TEST(PropagateConstIntoFunctionalNodes, WhileLoopWithResourceInput) {
  FunctionLibraryDefinition fld(OpRegistry::Global(), {});
  {
    // Cond graph & body graph.
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto pred = ops::_Arg(scope.WithOpName("pred"), DT_BOOL, 0);
    auto input = ops::_Arg(scope.WithOpName("input"), DT_RESOURCE, 1);
    auto ret = ops::_Retval(scope.WithOpName("ret"), pred, 0);
    Graph graph(OpRegistry::Global());
    TF_ASSERT_OK(scope.ToGraph(&graph));
    // The same graph serves as both the cond and the body function.
    FunctionDef cond_fdef;
    TF_ASSERT_OK(GraphToFunctionDef(graph, "cond", &cond_fdef));
    TF_ASSERT_OK(fld.AddFunctionDef(cond_fdef));
    FunctionDef body_fdef;
    TF_ASSERT_OK(GraphToFunctionDef(graph, "body", &body_fdef));
    TF_ASSERT_OK(fld.AddFunctionDef(body_fdef));
  }
  // Outer graph: a While node fed by two Const inputs.
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto pred = ops::Const(scope.WithOpName("pred"), false, TensorShape({}));
  auto input = ops::Const(scope.WithOpName("input"), 0, TensorShape({}));
  NameAttrList cond_fn, body_fn;
  cond_fn.set_name("cond");
  body_fn.set_name("body");
  auto while_op =
      ops::While(scope.WithOpName("while"),
                 std::initializer_list<Input>{pred, input}, cond_fn, body_fn);
  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));

  TF_EXPECT_OK(PropagateConstIntoFunctionalNodes(&graph, &fld, &fld));
}
365 
// When a Const is copied into a function body whose graph already contains a
// node with the same name, the copy must be renamed to stay unique.
TEST(PropagateConstIntoFunctionalNodes, CopiedConstNodeHasUniqueName) {
  FunctionLibraryDefinition fld(OpRegistry::Global(), {});
  {
    // Cond graph & body graph.
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto pred = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0);
    auto input = ops::_Arg(scope.WithOpName("arg1"), DT_BOOL, 1);
    // This NoOp's name collides with the outer Const node's name below.
    auto duplicate_name = ops::NoOp(scope.WithOpName("duplicate_name"));
    auto ret = ops::_Retval(scope.WithOpName("ret"), pred, 0);
    Graph graph(OpRegistry::Global());
    TF_ASSERT_OK(scope.ToGraph(&graph));
    FunctionDef cond_fdef;
    TF_ASSERT_OK(GraphToFunctionDef(graph, "cond", &cond_fdef));
    TF_ASSERT_OK(fld.AddFunctionDef(cond_fdef));
    FunctionDef body_fdef;
    TF_ASSERT_OK(GraphToFunctionDef(graph, "body", &body_fdef));
    TF_ASSERT_OK(fld.AddFunctionDef(body_fdef));
  }
  // Outer graph: the Const feeding the While is itself named "duplicate_name".
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto pred =
      ops::Const(scope.WithOpName("duplicate_name"), false, TensorShape({}));
  auto input = ops::Const(scope.WithOpName("input"), false, TensorShape({}));
  NameAttrList cond_fn, body_fn;
  cond_fn.set_name("cond");
  body_fn.set_name("body");
  auto while_op =
      ops::While(scope.WithOpName("while"),
                 std::initializer_list<Input>{pred, input}, cond_fn, body_fn);
  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));

  TF_EXPECT_OK(PropagateConstIntoFunctionalNodes(&graph, &fld, &fld));

  // Check that in rewritten body function, the NoOp node still has name
  // "duplicate_name", and the copied Const node has name "duplicate_name/_0".
  auto node_name_index = graph.BuildNodeNameIndex();
  Node* while_node = node_name_index["while"];
  ASSERT_NE(while_node, nullptr);
  // The rewrite may have repointed the While at a new body function; read the
  // current "body" attr rather than assuming the original name.
  TF_ASSERT_OK(GetNodeAttr(while_node->def(), "body", &body_fn));
  const FunctionDef* rewritten_body_fn = fld.Find(body_fn.name());
  ASSERT_NE(rewritten_body_fn, nullptr);
  // Index the rewritten body's nodes by name for the assertions below.
  std::unordered_map<string, NodeDef> nodes;
  for (const NodeDef& node_def : rewritten_body_fn->node_def()) {
    nodes[node_def.name()] = node_def;
  }
  auto noop_def = nodes.find("duplicate_name");
  ASSERT_NE(noop_def, nodes.end());
  EXPECT_EQ(noop_def->second.op(), "NoOp");
  auto const_def = nodes.find("duplicate_name/_0");
  ASSERT_NE(const_def, nodes.end());
  EXPECT_EQ(const_def->second.op(), "Const");
}
418 
419 }  // namespace
420 }  // namespace tensorflow
421