1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/control_flow.h"
17 
18 #include <string>
19 #include <vector>
20 
21 #include "tensorflow/cc/ops/standard_ops.h"
22 #include "tensorflow/cc/ops/while_loop.h"
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25 #include "tensorflow/core/platform/test.h"
26 
27 namespace tensorflow {
28 namespace {
LessThanTenCond(const Scope & scope,const std::vector<Output> & inputs,Output * output)29 Status LessThanTenCond(const Scope& scope, const std::vector<Output>& inputs,
30                        Output* output) {
31   *output = ops::Less(scope, inputs[0], 10);
32   return scope.status();
33 }
34 
AddOneBody(const Scope & scope,const std::vector<Output> & inputs,std::vector<Output> * outputs)35 Status AddOneBody(const Scope& scope, const std::vector<Output>& inputs,
36                   std::vector<Output>* outputs) {
37   outputs->push_back(ops::AddN(scope, {inputs[0], 1}));
38   return scope.status();
39 }
40 
NestedLoopBody(const Scope & scope,const std::vector<Output> & inputs,std::vector<Output> * outputs)41 Status NestedLoopBody(const Scope& scope, const std::vector<Output>& inputs,
42                       std::vector<Output>* outputs) {
43   return ops::BuildWhileLoop(scope.NewSubScope("inner"), inputs,
44                              LessThanTenCond, AddOneBody, "inner_loop",
45                              outputs);
46 }
47 
TEST(ValidateControlFlowTest, InputsFromDifferentFrames) {
  // Build an outer while loop (frame "outer_loop") whose body is a nested
  // while loop (frame "inner_loop").
  Scope scope = Scope::NewRootScope().ExitOnError();
  std::vector<Output> loop_inputs;
  loop_inputs.push_back(ops::Placeholder(scope, DT_INT32));
  std::vector<Output> loop_outputs;
  TF_ASSERT_OK(ops::BuildWhileLoop(scope.NewSubScope("outer"), loop_inputs,
                                   LessThanTenCond, NestedLoopBody,
                                   "outer_loop", &loop_outputs));
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // The resulting graph has {'inner/Enter', 'outer/Switch'} --> 'inner/Merge'.
  // 'inner/Enter' is in frame 'inner_loop' while 'outer/Switch' is in frame
  // 'outer_loop', so BuildControlFlowInfo is expected to reject the graph and
  // name all three nodes in its error.
  std::vector<ControlFlowInfo> info;
  const Status status = BuildControlFlowInfo(graph.get(), &info);
  EXPECT_FALSE(status.ok());

  const std::string message(status.error_message());
  EXPECT_TRUE(absl::StrContains(message, "has inputs from different frames"))
      << message;
  EXPECT_TRUE(absl::StrContains(message, "{{node outer/body/inner/Merge}}"))
      << message;
  EXPECT_TRUE(absl::StrContains(message, "{{node outer/body/inner/Enter}}"))
      << message;
  EXPECT_TRUE(absl::StrContains(message, "{{node outer/Switch}}")) << message;
}
76 
TEST(ValidateControlFlowTest, MismatchedParentFrames) {
  // Start from a single well-formed while loop in frame "test_loop".
  Scope scope = Scope::NewRootScope().ExitOnError();
  std::vector<Output> loop_inputs;
  loop_inputs.push_back(ops::Placeholder(scope, DT_INT32));
  std::vector<Output> loop_outputs;
  TF_ASSERT_OK(ops::BuildWhileLoop(scope, loop_inputs, LessThanTenCond,
                                   AddOneBody, "test_loop", &loop_outputs));
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Remember the (last) Enter op node the loop builder created.
  Node* enter_1 = nullptr;
  for (Node* candidate : graph->op_nodes()) {
    if (IsEnter(candidate)) enter_1 = candidate;
  }
  ASSERT_NE(enter_1, nullptr);

  // Hand-craft a second Enter node, "Enter2", for the same frame "test_loop",
  // and make it reachable from the first Enter through a control edge.
  NodeDef enter_def;
  enter_def.set_name("Enter2");
  enter_def.set_op("Enter");
  auto& attrs = *enter_def.mutable_attr();
  attrs["T"].set_type(DT_INT32);
  attrs["frame_name"].set_s("test_loop");
  enter_def.add_input("Enter");
  Status status;
  Node* enter_2 = graph->AddNode(enter_def, &status);
  TF_ASSERT_OK(status);
  graph->AddControlEdge(enter_1, enter_2);

  // SOURCE("") --> Enter("test_loop") --> Enter2("test_loop")
  // Seen from 'Enter', the parent frame of "test_loop" is the empty root
  // frame; seen from 'Enter2', it is "test_loop" itself. Validation is
  // expected to fail and to point at 'Enter2'.
  std::vector<ControlFlowInfo> info;
  status = BuildControlFlowInfo(graph.get(), &info);
  EXPECT_FALSE(status.ok());

  const std::string message(status.error_message());
  EXPECT_TRUE(absl::StrContains(message, "Mismatched parent frames"))
      << message;
  EXPECT_TRUE(absl::StrContains(message, "{{node Enter2}}")) << message;
}
117 
TEST(ValidateControlFlowTest, TwoLoopCond) {
  // A frame may contain at most one LoopCond node; functionalizing control
  // flow depends on this. Build two loops that share the frame name
  // "test_loop" (the second under sub-scope "sub", with create_while_ctx set
  // to false) and check that validation reports both LoopCond nodes.
  Scope scope = Scope::NewRootScope().ExitOnError();
  std::vector<Output> loop_inputs;
  loop_inputs.push_back(ops::Placeholder(scope, DT_INT32));

  std::vector<Output> first_outputs;
  TF_ASSERT_OK(ops::BuildWhileLoop(scope, loop_inputs, LessThanTenCond,
                                   AddOneBody, "test_loop", &first_outputs));
  std::vector<Output> second_outputs;
  TF_ASSERT_OK(ops::BuildWhileLoop(scope.NewSubScope("sub"), loop_inputs,
                                   LessThanTenCond, AddOneBody, "test_loop",
                                   &second_outputs, false));

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));
  std::vector<ControlFlowInfo> info;
  const Status status = BuildControlFlowInfo(graph.get(), &info);
  EXPECT_FALSE(status.ok());

  const std::string message(status.error_message());
  EXPECT_TRUE(absl::StrContains(message, "more than one LoopCond node"))
      << message;
  EXPECT_TRUE(absl::StrContains(message, "{{node sub/LoopCond}}")) << message;
  EXPECT_TRUE(absl::StrContains(message, "{{node LoopCond}}")) << message;
}
145 
146 }  // namespace
147 }  // namespace tensorflow
148