/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/tf2xla_util.h"

#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/data_flow_ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {
namespace {

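// Expects that `status` is an error whose message contains `str`.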
void ExpectErrorContains(const Status& status, absl::string_view str) {
  EXPECT_NE(Status::OK(), status);
  EXPECT_TRUE(absl::StrContains(status.error_message(), str))
      << "expected error: " << status.error_message() << " to contain: " << str;
}

TEST(ValidateConfig, Good) {
  tf2xla::Config config;
  tf2xla::Feed* feed = config.add_feed();
  feed->mutable_id()->set_node_name("foo");
  feed->mutable_id()->set_output_index(123);
  feed->set_name("foo_debug");
  feed = config.add_feed();
  feed->mutable_id()->set_node_name("bar");
  feed->mutable_id()->set_output_index(0);
  tf2xla::Fetch* fetch = config.add_fetch();
  fetch->mutable_id()->set_node_name("baz");
  fetch->mutable_id()->set_output_index(456);
  fetch->set_name("baz_debug");
  fetch = config.add_fetch();
  fetch->mutable_id()->set_node_name("banana");
  fetch->mutable_id()->set_output_index(0);
  TF_EXPECT_OK(ValidateConfig(config));
}

TEST(ValidateConfig, BadEmpty) {
  tf2xla::Config config;
  ExpectErrorContains(ValidateConfig(config), "fetches must be specified");
}

TEST(ValidateConfig, BadNoFetch) {
  tf2xla::Config config;
  tf2xla::Feed* feed = config.add_feed();
  feed->mutable_id()->set_node_name("foo");
  ExpectErrorContains(ValidateConfig(config), "fetches must be specified");
}

TEST(ValidateConfig, BadFeedNodeName) {
  tf2xla::Config config;
  config.add_feed();
  ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty");
}

TEST(ValidateConfig, BadFeedOutputIndex) {
  tf2xla::Config config;
  tf2xla::Feed* feed = config.add_feed();
  feed->mutable_id()->set_node_name("foo");
  feed->mutable_id()->set_output_index(-1);
  ExpectErrorContains(ValidateConfig(config), "output_index must be positive");
}

TEST(ValidateConfig, BadFetchNodeName) {
  tf2xla::Config config;
  tf2xla::Feed* feed = config.add_feed();
  feed->mutable_id()->set_node_name("foo");
  config.add_fetch();
  ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty");
}

TEST(ValidateConfig, BadFetchOutputIndex) {
  tf2xla::Config config;
  tf2xla::Feed* feed = config.add_feed();
  feed->mutable_id()->set_node_name("foo");
  tf2xla::Fetch* fetch = config.add_fetch();
  fetch->mutable_id()->set_node_name("bar");
  fetch->mutable_id()->set_output_index(-1);
  ExpectErrorContains(ValidateConfig(config), "output_index must be positive");
}

TEST(ValidateConfig, DuplicateFeedName) {
  tf2xla::Config config;
  tf2xla::Feed* feed = config.add_feed();
  feed->mutable_id()->set_node_name("foo");
  feed->set_name("dup");
  feed = config.add_feed();
  feed->mutable_id()->set_node_name("bar");
  feed->set_name("dup");
  ExpectErrorContains(ValidateConfig(config), "duplicate feed name");
}

TEST(ValidateConfig, DuplicateFetchName) {
  tf2xla::Config config;
  tf2xla::Feed* feed = config.add_feed();
  feed->mutable_id()->set_node_name("foo");
  tf2xla::Fetch* fetch = config.add_fetch();
  fetch->mutable_id()->set_node_name("bar");
  fetch->set_name("dup");
  fetch = config.add_fetch();
  fetch->mutable_id()->set_node_name("baz");
  fetch->set_name("dup");
  ExpectErrorContains(ValidateConfig(config), "duplicate fetch name");
}

TEST(ValidateConfig, ConflictingFeedName) {
  tf2xla::Config config;
  tf2xla::Feed* feed = config.add_feed();
  feed->mutable_id()->set_node_name("foo");
  feed->set_name("conflict");
  feed = config.add_feed();
  feed->mutable_id()->set_node_name("bar");
  feed->set_name("conflict_data");
  ExpectErrorContains(ValidateConfig(config), "conflicting feed name");
}

TEST(ValidateConfig, ConflictingFetchName) {
  tf2xla::Config config;
  tf2xla::Feed* feed = config.add_feed();
  feed->mutable_id()->set_node_name("foo");
  tf2xla::Fetch* fetch = config.add_fetch();
  fetch->mutable_id()->set_node_name("bar");
  fetch->set_name("conflict");
  fetch = config.add_fetch();
  fetch->mutable_id()->set_node_name("baz");
  fetch->set_name("conflict_data");
  ExpectErrorContains(ValidateConfig(config), "conflicting fetch name");
}

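// Builds a tf2xla::Config that fetches output 0 of each node named in
// `fetches`, giving each fetch the name "fetch_<node_name>".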
static tf2xla::Config FetchesConfig(std::vector<string> fetches) {
  tf2xla::Config config;
  for (const auto& fetch_node_name : fetches) {
    auto* fetch = config.add_fetch();
    fetch->set_name(absl::StrCat("fetch_", fetch_node_name));
    fetch->mutable_id()->set_node_name(fetch_node_name);
  }
  return config;
}

TEST(PruneGraphDefInto, Basic) {
  GraphDef def;
  auto* n = def.add_node();
  n->set_name("a");
  n->add_input("b:0");
  n->add_input("^c");

  GraphDef copy;
  ExpectErrorContains(PruneGraphDefInto(FetchesConfig({"missing"}), def, &copy),
                      "node missing needed");
  ExpectErrorContains(PruneGraphDefInto(FetchesConfig({"a"}), def, &copy),
                      "node b needed");

  n = def.add_node();
  n->set_name("b");
  ExpectErrorContains(PruneGraphDefInto(FetchesConfig({"a"}), def, &copy),
                      "node c needed");
  n->add_input("d:1");

  n = def.add_node();
  n->set_name("c");
  n->add_input("d:1");

  n = def.add_node();
  n->set_name("d");

  // Graph is full, no pruning done.
  // Graph right now has a diamond from d:
  //   d --> b --> a
  //   d --> c --> a
  TF_EXPECT_OK(PruneGraphDefInto(FetchesConfig({"a"}), def, &copy));
  EXPECT_EQ(def.DebugString(), copy.DebugString());
  GraphDef pruned_a = copy;

  // Add some unrelated nodes that use b and c, but are not needed for a.
  n = def.add_node();
  n->set_name("e");
  n->add_input("^d");
  n->add_input("b:2");
  copy.Clear();
  TF_EXPECT_OK(PruneGraphDefInto(FetchesConfig({"a"}), def, &copy));
  EXPECT_EQ(pruned_a.DebugString(), copy.DebugString());

  // Fetch "a" and "e" to get the original graph.
  copy.Clear();
  TF_EXPECT_OK(PruneGraphDefInto(FetchesConfig({"a", "e"}), def, &copy));
  EXPECT_EQ(def.DebugString(), copy.DebugString());
}

TEST(SetNodeShardingFromNeighbors, Basic) {
  // Builds a graph that adds two Tensors.
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
  auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
  auto c = ops::Add(scope.WithOpName("C"), a, b);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  Node* a_node = nullptr;
  Node* b_node = nullptr;
  Node* c_node = nullptr;
  for (Node* n : graph->nodes()) {
    if (n->name() == "A") a_node = n;
    if (n->name() == "B") b_node = n;
    if (n->name() == "C") c_node = n;
  }

  const int num_cores_per_replica = 4;

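  // A malformed assigned device name on an input makes sharding derivation
  // fail.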
  a_node->set_assigned_device_name("foo");
  EXPECT_FALSE(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false).ok());

  // Test where one input to c_node has a device.
  a_node->set_assigned_device_name("/device:TPU_REPLICATED_CORE:2");
  TF_ASSERT_OK(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false));
  auto parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica);
  TF_ASSERT_OK(parse_status.status());
  ASSERT_TRUE(parse_status.ValueOrDie().has_value());
  EXPECT_EQ(2, parse_status.ValueOrDie().value().tile_assignment_devices(0));

  // Test where two inputs to c_node have a device.
  b_node->set_assigned_device_name("/device:TPU_REPLICATED_CORE:1");
  TF_ASSERT_OK(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false));
  parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica);
  TF_ASSERT_OK(parse_status.status());
  ASSERT_TRUE(parse_status.ValueOrDie().has_value());
  EXPECT_EQ(1, parse_status.ValueOrDie().value().tile_assignment_devices(0));

  // Test setting based on out edges.
  TF_ASSERT_OK(SetNodeShardingFromNeighbors(a_node, /*out_edges=*/true));
  parse_status = ParseShardingFromDevice(*a_node, num_cores_per_replica);
  TF_ASSERT_OK(parse_status.status());
  ASSERT_TRUE(parse_status.ValueOrDie().has_value());
  EXPECT_EQ(1, parse_status.ValueOrDie().value().tile_assignment_devices(0));
}

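// A trivial op used by the "TestFunc" function in the CachedFunctionHandles
// test below.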
REGISTER_OP("One")
    .Output("y: T")
    .Attr("T: {float, double, int32, int64}")
    .Doc(R"doc(
Returns a tensor with a single element (1) of type T.

y: A scalar of type T.

)doc");

// Tests that the CachedFunctionHandles class works.
TEST(CachedFunctionHandles, Basic) {
  FunctionDef func = FunctionDefHelper::Define(
      // Name
      "TestFunc",
      // Args
      {},
      // Return values
      {"y:T"},
      // Attr def
      {"T:{float, double, int32, int64}"},
      // Nodes
      {
          {{"y"}, "One", {}, {{"T", "$T"}}},
      });
  FunctionDefLibrary proto;
  *proto.add_function() = func;
  FunctionLibraryDefinition fld(OpRegistry::Global(), proto);
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
      new ProcessFunctionLibraryRuntime(
          /*device_mgr=*/nullptr, Env::Default(), TF_GRAPH_DEF_VERSION, &fld,
          OptimizerOptions()));
  FunctionLibraryRuntime* flr =
      pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);

  CachedFunctionHandles cached_function_handles(flr);

  // Tests that GetOrInstantiate() works.
  FunctionLibraryRuntime::Handle first_handle;
  AttrValue attr;
  attr.set_type(DT_FLOAT);
  AttrValueMap attrs;
  attrs["T"] = attr;
  TF_ASSERT_OK(cached_function_handles.GetOrInstantiate(
      "TestFunc", AttrSlice(&attrs), &first_handle));

  // Tests that we can get the FunctionBody for the instantiated handle.
  const FunctionBody* body = flr->GetFunctionBody(first_handle);
  EXPECT_NE(body, nullptr);

  // Tests that GetOrInstantiate() returns the cached handle when called with
  // the same function name and attributes.
  FunctionLibraryRuntime::Handle second_handle;
  TF_ASSERT_OK(cached_function_handles.GetOrInstantiate(
      "TestFunc", AttrSlice(&attrs), &second_handle));
  EXPECT_EQ(first_handle, second_handle);

  // Tests that GetOrInstantiate() returns a new handle when called with the
  // same function name but different attributes.
  attr.set_type(DT_INT32);
  attrs["T"] = attr;
  FunctionLibraryRuntime::Handle third_handle;
  TF_ASSERT_OK(cached_function_handles.GetOrInstantiate(
      "TestFunc", AttrSlice(&attrs), &third_handle));
  EXPECT_NE(first_handle, third_handle);

  // Tests that ReleaseAllHandles() works.
  TF_EXPECT_OK(cached_function_handles.ReleaseAllHandles());
}

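// Tests const propagation into a While node whose cond/body signature
// includes a DT_RESOURCE argument.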
TEST(PropagateConstIntoFunctionalNodes, WhileLoopWithResourceInput) {
  FunctionLibraryDefinition fld(OpRegistry::Global(), {});
  {
    // Cond graph & body graph.
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto pred = ops::_Arg(scope.WithOpName("pred"), DT_BOOL, 0);
    auto input = ops::_Arg(scope.WithOpName("input"), DT_RESOURCE, 1);
    auto ret = ops::_Retval(scope.WithOpName("ret"), pred, 0);
    Graph graph(OpRegistry::Global());
    TF_ASSERT_OK(scope.ToGraph(&graph));
    FunctionDef cond_fdef;
    TF_ASSERT_OK(GraphToFunctionDef(graph, "cond", &cond_fdef));
    TF_ASSERT_OK(fld.AddFunctionDef(cond_fdef));
    FunctionDef body_fdef;
    TF_ASSERT_OK(GraphToFunctionDef(graph, "body", &body_fdef));
    TF_ASSERT_OK(fld.AddFunctionDef(body_fdef));
  }
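  // Main graph: a While node with constant inputs that calls the "cond" and
  // "body" functions defined above.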
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto pred = ops::Const(scope.WithOpName("pred"), false, TensorShape({}));
  auto input = ops::Const(scope.WithOpName("input"), 0, TensorShape({}));
  NameAttrList cond_fn, body_fn;
  cond_fn.set_name("cond");
  body_fn.set_name("body");
  auto while_op =
      ops::While(scope.WithOpName("while"),
                 std::initializer_list<Input>{pred, input}, cond_fn, body_fn);
  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));

  TF_EXPECT_OK(PropagateConstIntoFunctionalNodes(&graph, &fld, &fld));
}

TEST(PropagateConstIntoFunctionalNodes, CopiedConstNodeHasUniqueName) {
  FunctionLibraryDefinition fld(OpRegistry::Global(), {});
  {
    // Cond graph & body graph.
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto pred = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0);
    auto input = ops::_Arg(scope.WithOpName("arg1"), DT_BOOL, 1);
    auto duplicate_name = ops::NoOp(scope.WithOpName("duplicate_name"));
    auto ret = ops::_Retval(scope.WithOpName("ret"), pred, 0);
    Graph graph(OpRegistry::Global());
    TF_ASSERT_OK(scope.ToGraph(&graph));
    FunctionDef cond_fdef;
    TF_ASSERT_OK(GraphToFunctionDef(graph, "cond", &cond_fdef));
    TF_ASSERT_OK(fld.AddFunctionDef(cond_fdef));
    FunctionDef body_fdef;
    TF_ASSERT_OK(GraphToFunctionDef(graph, "body", &body_fdef));
    TF_ASSERT_OK(fld.AddFunctionDef(body_fdef));
  }
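  // Main graph: the constant fed into the While node is itself named
  // "duplicate_name", colliding with the NoOp inside the body function.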
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto pred =
      ops::Const(scope.WithOpName("duplicate_name"), false, TensorShape({}));
  auto input = ops::Const(scope.WithOpName("input"), false, TensorShape({}));
  NameAttrList cond_fn, body_fn;
  cond_fn.set_name("cond");
  body_fn.set_name("body");
  auto while_op =
      ops::While(scope.WithOpName("while"),
                 std::initializer_list<Input>{pred, input}, cond_fn, body_fn);
  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));

  TF_EXPECT_OK(PropagateConstIntoFunctionalNodes(&graph, &fld, &fld));

  // Check that in the rewritten body function, the NoOp node still has the
  // name "duplicate_name", and the copied Const node has the name
  // "duplicate_name/_0".
  auto node_name_index = graph.BuildNodeNameIndex();
  Node* while_node = node_name_index["while"];
  ASSERT_NE(while_node, nullptr);
  TF_ASSERT_OK(GetNodeAttr(while_node->def(), "body", &body_fn));
  const FunctionDef* rewritten_body_fn = fld.Find(body_fn.name());
  ASSERT_NE(rewritten_body_fn, nullptr);
  std::unordered_map<string, NodeDef> nodes;
  for (const NodeDef& node_def : rewritten_body_fn->node_def()) {
    nodes[node_def.name()] = node_def;
  }
  auto noop_def = nodes.find("duplicate_name");
  ASSERT_NE(noop_def, nodes.end());
  EXPECT_EQ(noop_def->second.op(), "NoOp");
  auto const_def = nodes.find("duplicate_name/_0");
  ASSERT_NE(const_def, nodes.end());
  EXPECT_EQ(const_def->second.op(), "Const");
}

}  // namespace
}  // namespace tensorflow