1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
17 
18 #include "tensorflow/cc/framework/ops.h"
19 #include "tensorflow/cc/ops/array_ops.h"
20 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
21 #include "tensorflow/cc/ops/function_ops.h"
22 #include "tensorflow/cc/ops/functional_ops.h"
23 #include "tensorflow/cc/ops/resource_variable_ops.h"
24 #include "tensorflow/cc/ops/sendrecv_ops.h"
25 #include "tensorflow/cc/ops/standard_ops.h"
26 #include "tensorflow/compiler/jit/defs.h"
27 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
28 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
29 #include "tensorflow/core/framework/node_def_util.h"
30 #include "tensorflow/core/framework/op.h"
31 #include "tensorflow/core/graph/algorithm.h"
32 #include "tensorflow/core/graph/graph_constructor.h"
33 #include "tensorflow/core/graph/graph_def_builder.h"
34 #include "tensorflow/core/graph/graph_def_builder_util.h"
35 #include "tensorflow/core/lib/core/status_test_util.h"
36 #include "tensorflow/core/lib/strings/str_util.h"
37 #include "tensorflow/core/platform/test.h"
38 
39 namespace tensorflow {
40 namespace {
41 
MakeRead(const Scope & scope,const string & id)42 Node* MakeRead(const Scope& scope, const string& id) {
43   Output var_handle =
44       ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
45   Output read =
46       ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT);
47   return read.node();
48 }
49 
MakeWrite(const Scope & scope,const string & id)50 Node* MakeWrite(const Scope& scope, const string& id) {
51   Output var_handle =
52       ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
53   Output value_to_write =
54       ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f);
55   ops::AssignVariableOp assign_op(scope.WithOpName("Assignee" + id), var_handle,
56                                   value_to_write);
57   return assign_op.operation.node();
58 }
59 
MakeModify(const Scope & scope,const string & id)60 Node* MakeModify(const Scope& scope, const string& id) {
61   Output var_handle =
62       ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
63   Output value_to_write = ops::Const(scope.WithOpName("Increment" + id), 1.0f);
64   ops::AssignAddVariableOp assign_add_op(scope.WithOpName("Increment" + id),
65                                          var_handle, value_to_write);
66   return assign_add_op.operation.node();
67 }
68 
MakeNeutral(const Scope & scope,const string & id)69 Node* MakeNeutral(const Scope& scope, const string& id) {
70   return ops::Const(scope.WithOpName("Const" + id), 42.0f).node();
71 }
72 
ComputeIncompatiblePairs(Graph * g,std::vector<std::pair<int,int>> * result)73 Status ComputeIncompatiblePairs(Graph* g,
74                                 std::vector<std::pair<int, int>>* result) {
75   FixupSourceAndSinkEdges(g);
76   return ComputeIncompatibleResourceOperationPairs(*g, &g->flib_def(), {},
77                                                    result);
78 }
79 
// write -> read: expects exactly one incompatible pair, (write, read).
TEST(ResourceOperationSafetyAnalysisTest, WriteRead) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* read = MakeRead(root, "R");
  Node* write = MakeWrite(root, "W");
  graph->AddControlEdge(write, read);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  ASSERT_EQ(pairs.size(), 1);
  const std::pair<int, int> expected(write->id(), read->id());
  EXPECT_EQ(pairs.front(), expected);
}
95 
// read -> write: expects no incompatible pairs.
TEST(ResourceOperationSafetyAnalysisTest, ReadWrite) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* read = MakeRead(root, "R");
  Node* write = MakeWrite(root, "W");
  graph->AddControlEdge(read, write);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  EXPECT_TRUE(pairs.empty());
}
109 
// A read and a write with no control edge between them: expects no
// incompatible pairs.
TEST(ResourceOperationSafetyAnalysisTest, ReadWriteNoEdges) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  MakeRead(root, "R");
  MakeWrite(root, "W");

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  EXPECT_TRUE(pairs.empty());
}
121 
// read -> modify: expects no incompatible pairs.
TEST(ResourceOperationSafetyAnalysisTest, ReadModify) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* read = MakeRead(root, "R");
  Node* modify = MakeModify(root, "M");
  graph->AddControlEdge(read, modify);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  EXPECT_TRUE(pairs.empty());
}
135 
// modify -> read: expects exactly one incompatible pair, (modify, read).
TEST(ResourceOperationSafetyAnalysisTest, ModifyRead) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* read = MakeRead(root, "R");
  Node* modify = MakeModify(root, "M");
  graph->AddControlEdge(modify, read);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  ASSERT_EQ(pairs.size(), 1);
  const std::pair<int, int> expected(modify->id(), read->id());
  EXPECT_EQ(pairs.front(), expected);
}
151 
// modify -> write: expects no incompatible pairs.
TEST(ResourceOperationSafetyAnalysisTest, ModifyWrite) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* modify = MakeModify(root, "M");
  Node* write = MakeWrite(root, "W");
  graph->AddControlEdge(modify, write);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  EXPECT_TRUE(pairs.empty());
}
165 
// write -> modify: expects exactly one incompatible pair, (write, modify).
TEST(ResourceOperationSafetyAnalysisTest, WriteModify) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* modify = MakeModify(root, "M");
  Node* write = MakeWrite(root, "W");
  graph->AddControlEdge(write, modify);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  ASSERT_EQ(pairs.size(), 1);
  const std::pair<int, int> expected(write->id(), modify->id());
  EXPECT_EQ(pairs.front(), expected);
}
181 
// read -> modify -> write: expects no incompatible pairs.
TEST(ResourceOperationSafetyAnalysisTest, ReadModifyWrite) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* read = MakeRead(root, "R");
  Node* modify = MakeModify(root, "M");
  Node* write = MakeWrite(root, "W");
  graph->AddControlEdge(read, modify);
  graph->AddControlEdge(modify, write);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  EXPECT_TRUE(pairs.empty());
}
197 
// write -> modify -> read: expects three incompatible pairs, in the order
// (modify, read), (write, read), (write, modify).
TEST(ResourceOperationSafetyAnalysisTest, WriteModifyRead) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* read = MakeRead(root, "R");
  Node* modify = MakeModify(root, "M");
  Node* write = MakeWrite(root, "W");
  graph->AddControlEdge(write, modify);
  graph->AddControlEdge(modify, read);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  const std::vector<std::pair<int, int>> expected = {
      {modify->id(), read->id()},
      {write->id(), read->id()},
      {write->id(), modify->id()},
  };
  EXPECT_EQ(pairs, expected);
}
220 
// write -> read -> modify: expects two incompatible pairs, in the order
// (write, read), (write, modify).
TEST(ResourceOperationSafetyAnalysisTest, WriteReadModify) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* read = MakeRead(root, "R");
  Node* modify = MakeModify(root, "M");
  Node* write = MakeWrite(root, "W");
  graph->AddControlEdge(write, read);
  graph->AddControlEdge(read, modify);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  const std::vector<std::pair<int, int>> expected = {
      {write->id(), read->id()},
      {write->id(), modify->id()},
  };
  EXPECT_EQ(pairs, expected);
}
241 
CreateFunctionDefLibWithConstFunction(const string & name)242 FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) {
243   FunctionDefLibrary flib_def;
244   FunctionDef func = FunctionDefHelper::Create(
245       /*function_name=*/name, /*in_def=*/{}, /*out_def=*/{"out: float"},
246       /*attr_def*/
247       {}, /*node_def=*/{FunctionDefHelper::Const("one", 1.0f)},
248       /*ret_def=*/{{"out", "out:output:0"}});
249   *flib_def.add_function() = std::move(func);
250   return flib_def;
251 }
252 
MakeCall(Graph * graph,const string & callee_name,const string & node_name,Status * status)253 Node* MakeCall(Graph* graph, const string& callee_name, const string& node_name,
254                Status* status) {
255   NodeDef call_node;
256   call_node.set_name(node_name);
257   call_node.set_op(callee_name);
258   return graph->AddNode(call_node, status);
259 }
260 
// function call -> read: expects exactly one incompatible pair, (call, read).
TEST(ResourceOperationSafetyAnalysisTest, CallRead) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  const FunctionDefLibrary library =
      CreateFunctionDefLibWithConstFunction("Const_func");
  TF_ASSERT_OK(graph->AddFunctionLibrary(library));

  Node* read = MakeRead(root, "R");
  Status call_status;
  Node* call = MakeCall(graph, "Const_func", "C", &call_status);
  TF_ASSERT_OK(call_status);

  graph->AddControlEdge(call, read);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  ASSERT_EQ(pairs.size(), 1);
  const std::pair<int, int> expected(call->id(), read->id());
  EXPECT_EQ(pairs.front(), expected);
}
282 
// read -> function call: expects no incompatible pairs.
TEST(ResourceOperationSafetyAnalysisTest, ReadCall) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  const FunctionDefLibrary library =
      CreateFunctionDefLibWithConstFunction("Const_func");
  TF_ASSERT_OK(graph->AddFunctionLibrary(library));

  Node* read = MakeRead(root, "R");
  Status call_status;
  Node* call = MakeCall(graph, "Const_func", "C", &call_status);
  TF_ASSERT_OK(call_status);

  graph->AddControlEdge(read, call);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  EXPECT_TRUE(pairs.empty());
}
302 
// function call -> write: expects no incompatible pairs.
TEST(ResourceOperationSafetyAnalysisTest, CallWrite) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  const FunctionDefLibrary library =
      CreateFunctionDefLibWithConstFunction("Const_func");
  TF_ASSERT_OK(graph->AddFunctionLibrary(library));

  Node* write = MakeWrite(root, "W");
  Status call_status;
  Node* call = MakeCall(graph, "Const_func", "C", &call_status);
  TF_ASSERT_OK(call_status);

  graph->AddControlEdge(call, write);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  EXPECT_TRUE(pairs.empty());
}
322 
// write -> function call: expects exactly one incompatible pair,
// (write, call).
TEST(ResourceOperationSafetyAnalysisTest, WriteCall) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  const FunctionDefLibrary library =
      CreateFunctionDefLibWithConstFunction("Const_func");
  TF_ASSERT_OK(graph->AddFunctionLibrary(library));

  Node* write = MakeWrite(root, "W");
  Status call_status;
  Node* call = MakeCall(graph, "Const_func", "C", &call_status);
  TF_ASSERT_OK(call_status);

  graph->AddControlEdge(write, call);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  ASSERT_EQ(pairs.size(), 1);
  const std::pair<int, int> expected(write->id(), call->id());
  EXPECT_EQ(pairs.front(), expected);
}
344 
// SymbolicGradient(Const_func) -> read: expects exactly one incompatible pair,
// (symbolic_gradient, read).
TEST(ResourceOperationSafetyAnalysisTest, SymbolicGradientRead) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  const FunctionDefLibrary library =
      CreateFunctionDefLibWithConstFunction("Const_func");
  TF_ASSERT_OK(graph->AddFunctionLibrary(library));

  Node* read = MakeRead(root, "R");

  NameAttrList callee;
  callee.set_name("Const_func");
  ops::SymbolicGradient gradient_op(root,
                                    /*input=*/{ops::Const(root, 1.0f)},
                                    /*Tout=*/{DT_FLOAT}, callee);
  Node* symbolic_gradient = gradient_op.output[0].node();

  graph->AddControlEdge(symbolic_gradient, read);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  ASSERT_EQ(pairs.size(), 1);
  const std::pair<int, int> expected(symbolic_gradient->id(), read->id());
  EXPECT_EQ(pairs.front(), expected);
}
371 
// write -> SymbolicGradient(Const_func): expects exactly one incompatible
// pair, (write, symbolic_gradient).
TEST(ResourceOperationSafetyAnalysisTest, WriteSymbolicGradient) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  const FunctionDefLibrary library =
      CreateFunctionDefLibWithConstFunction("Const_func");
  TF_ASSERT_OK(graph->AddFunctionLibrary(library));

  Node* write = MakeWrite(root, "W");

  NameAttrList callee;
  callee.set_name("Const_func");
  ops::SymbolicGradient gradient_op(root,
                                    /*input=*/{ops::Const(root, 1.0f)},
                                    /*Tout=*/{DT_FLOAT}, callee);
  Node* symbolic_gradient = gradient_op.output[0].node();

  graph->AddControlEdge(write, symbolic_gradient);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  ASSERT_EQ(pairs.size(), 1);
  const std::pair<int, int> expected(write->id(), symbolic_gradient->id());
  EXPECT_EQ(pairs.front(), expected);
}
398 
// Linear chain write_0 -> neutral_0 -> read_0 -> write_1 -> neutral_1 ->
// read_1: expects three incompatible pairs, in the order (write_0, read_0),
// (write_0, read_1), (write_1, read_1).
TEST(ResourceOperationSafetyAnalysisTest, ChainOfOps) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* write_0 = MakeWrite(root, "W0");
  Node* neutral_0 = MakeNeutral(root, "N0");
  Node* read_0 = MakeRead(root, "R0");
  Node* write_1 = MakeWrite(root, "W1");
  Node* neutral_1 = MakeNeutral(root, "N1");
  Node* read_1 = MakeRead(root, "R1");

  graph->AddControlEdge(write_0, neutral_0);
  graph->AddControlEdge(neutral_0, read_0);
  graph->AddControlEdge(read_0, write_1);
  graph->AddControlEdge(write_1, neutral_1);
  graph->AddControlEdge(neutral_1, read_1);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  const std::vector<std::pair<int, int>> expected = {
      {write_0->id(), read_0->id()},
      {write_0->id(), read_1->id()},
      {write_1->id(), read_1->id()},
  };
  EXPECT_EQ(pairs, expected);
}
427 
// Two writes feed one neutral node, which feeds two reads: expects every
// (write_i, read_j) combination, four pairs in total.
TEST(ResourceOperationSafetyAnalysisTest, DagOfOps) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* write_0 = MakeWrite(root, "W0");
  Node* write_1 = MakeWrite(root, "W1");
  Node* neutral = MakeNeutral(root, "N");
  Node* read_0 = MakeRead(root, "R0");
  Node* read_1 = MakeRead(root, "R1");

  graph->AddControlEdge(write_0, neutral);
  graph->AddControlEdge(write_1, neutral);
  graph->AddControlEdge(neutral, read_0);
  graph->AddControlEdge(neutral, read_1);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  const std::vector<std::pair<int, int>> expected = {
      {write_0->id(), read_0->id()},
      {write_0->id(), read_1->id()},
      {write_1->id(), read_0->id()},
      {write_1->id(), read_1->id()},
  };
  EXPECT_EQ(pairs, expected);
}
456 
// Same DAG as two writes -> neutral -> two reads, plus a direct
// write_1 -> read_1 edge, so that pair is reachable along two paths.
// Expects the same four (write_i, read_j) pairs, each reported once.
TEST(ResourceOperationSafetyAnalysisTest, DagOfOpsWithRepeatedPaths) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* write_0 = MakeWrite(root, "W0");
  Node* write_1 = MakeWrite(root, "W1");
  Node* neutral = MakeNeutral(root, "N");
  Node* read_0 = MakeRead(root, "R0");
  Node* read_1 = MakeRead(root, "R1");

  graph->AddControlEdge(write_0, neutral);
  graph->AddControlEdge(write_1, neutral);
  graph->AddControlEdge(neutral, read_0);
  graph->AddControlEdge(neutral, read_1);
  graph->AddControlEdge(write_1, read_1);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  const std::vector<std::pair<int, int>> expected = {
      {write_0->id(), read_0->id()},
      {write_0->id(), read_1->id()},
      {write_1->id(), read_0->id()},
      {write_1->id(), read_1->id()},
  };
  EXPECT_EQ(pairs, expected);
}
486 
// Builds a small loop out of Enter / Merge / Switch / Exit / NextIteration
// nodes, places write -> read inside the loop body via control edges, and
// expects exactly one incompatible pair, (write, read).
TEST(ResourceOperationSafetyAnalysisTest, Loop) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output init_value = ops::Placeholder(root.WithOpName("init"), DT_FLOAT);
  // The loop condition gets its own name rather than requesting "init" again.
  Output loop_cond = ops::Placeholder(root.WithOpName("loop_cond"), DT_BOOL);
  Output enter_value =
      ops::internal::Enter(root.WithOpName("enter"), init_value, "fr");
  // Merge is first built with `enter_value` on both inputs; input 1 is
  // rewired to the NextIteration output below, once that node exists.
  ops::Merge iv(root.WithOpName("iv"), {enter_value, enter_value});
  ops::Switch latch(root.WithOpName("latch"), iv.output, loop_cond);
  // Named `loop_exit` so the local does not shadow ::exit.
  ops::internal::Exit loop_exit(root.WithOpName("exit"), iv.output);
  Output next_iteration =
      ops::NextIteration(root.WithOpName("next_iteration"), latch.output_true);
  TF_ASSERT_OK(
      root.graph()->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1));

  Node* write = MakeWrite(root, "W");
  Node* read = MakeRead(root, "R");

  // iv -> write -> read -> next_iteration, all as control edges.
  root.graph()->AddControlEdge(iv.output.node(), write);
  root.graph()->AddControlEdge(write, read);
  root.graph()->AddControlEdge(read, next_iteration.node());

  std::vector<std::pair<int, int>> incompatible_pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));

  ASSERT_EQ(incompatible_pairs.size(), 1);

  std::pair<int, int> write_read_pair = {write->id(), read->id()};
  EXPECT_EQ(incompatible_pairs[0], write_read_pair);
}
517 
IsResourceArgDef(const OpDef::ArgDef & arg_def)518 bool IsResourceArgDef(const OpDef::ArgDef& arg_def) {
519   return arg_def.type() == DT_RESOURCE;
520 }
521 }  // namespace
522 }  // namespace tensorflow
523