1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/utils/colocation.h"
17 
18 #include "tensorflow/core/framework/function_testlib.h"
19 #include "tensorflow/core/framework/node_def_builder.h"
20 #include "tensorflow/core/lib/core/status_test_util.h"
21 #include "tensorflow/core/platform/test.h"
22 
23 namespace tensorflow {
24 namespace grappler {
25 
26 class ColocationTest : public ::testing::Test {};
27 
VerifyNodeHasColocation(const NodeDef & ndef,const string & coloc)28 bool VerifyNodeHasColocation(const NodeDef& ndef, const string& coloc) {
29   if (ndef.attr().empty()) {
30     return false;
31   }
32   if (ndef.attr().find("_class") == ndef.attr().end()) {
33     return false;
34   }
35   return ndef.attr().at("_class").list().s(0) == coloc;
36 }
37 
TEST(ColocationTest, ReassignColocation_SingleNode) {
  // Node A colocates with B, but node B is not in the graph.
  //   A
  //   |
  //   |
  //  [B]
  //
  // After re-assignment, A is expected to carry no attributes at all.

  NodeDef ndef;
  const Status status =
      NodeDefBuilder("A", "Const").Attr("_class", {"loc:@B"}).Finalize(&ndef);
  // Fatal check: every assertion below inspects the built node.
  TF_ASSERT_OK(status);
  GraphDef gdef = test::function::GDef({ndef});

  // Fatal check: gdef.node(0) below would be out of range otherwise.
  ASSERT_EQ(1, gdef.node_size());
  EXPECT_EQ(1, gdef.node(0).attr_size());

  ReassignColocation(&gdef);

  // Validates that node A's colocation info is cleared.
  ASSERT_EQ(1, gdef.node_size());
  EXPECT_EQ(0, gdef.node(0).attr_size());
}
60 
TEST(ColocationTest, ReassignColocation_MultiNode_SingleGroup) {
  // Nodes A, B and C colocate with X. D colocates with C. E colocates with D.
  // Node X is not in the graph.
  //  A   B   C---D---E
  //  |   |   |
  //  |   |   |
  //  +--[X]--+
  // After re-assign of colocation, A, B, C, D should colocate with E, and E
  // itself should be left with no attributes.
  // A   B   C   D
  // |   |   |   |
  // |   |   |   |
  // +---+-E-+---+

  NodeDef ndef_a, ndef_b, ndef_c, ndef_d, ndef_e;
  // TF_ASSERT_OK (fatal) rather than TF_EXPECT_OK: the rest of the test
  // inspects the built nodes, so stop at the first construction failure.
  Status status =
      NodeDefBuilder("A", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_a);
  TF_ASSERT_OK(status);
  status =
      NodeDefBuilder("B", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_b);
  TF_ASSERT_OK(status);
  status =
      NodeDefBuilder("C", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_c);
  TF_ASSERT_OK(status);
  status =
      NodeDefBuilder("D", "Const").Attr("_class", {"loc:@C"}).Finalize(&ndef_d);
  TF_ASSERT_OK(status);
  status =
      NodeDefBuilder("E", "Const").Attr("_class", {"loc:@D"}).Finalize(&ndef_e);
  TF_ASSERT_OK(status);
  GraphDef gdef =
      test::function::GDef({ndef_a, ndef_b, ndef_c, ndef_d, ndef_e});

  // Fatal check: the expectations below index nodes 0..4 directly.
  ASSERT_EQ(5, gdef.node_size());
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(0), "loc:@X"));  // A
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(1), "loc:@X"));  // B
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(2), "loc:@X"));  // C
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(3), "loc:@C"));  // D
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(4), "loc:@D"));  // E

  ReassignColocation(&gdef);

  ASSERT_EQ(5, gdef.node_size());
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(0), "loc:@E"));  // A
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(1), "loc:@E"));  // B
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(2), "loc:@E"));  // C
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(3), "loc:@E"));  // D
  EXPECT_EQ(0, gdef.node(4).attr_size());                        // E
}
109 
TEST(ColocationTest,ReassignColocation_MultiNode_MultiGroup)110 TEST(ColocationTest, ReassignColocation_MultiNode_MultiGroup) {
111   // Before re-assign:
112   // Node A, B, C colocate with X. D colocates with C. E colocates with D.
113   // Node U, V colocates with W. Node X, W are not in the graph:
114   //  A   B   C---D---E
115   //  |   |   |
116   //  |   |   |
117   //  +--[X]--+
118   //
119   //  U       V
120   //  |       |
121   //  |       |
122   //  +--[W]--+
123   //
124   // After re-assign:
125   // A, B, C, D should colocate with E. U should colocate with V.
126   // A   B   C   D
127   // |   |   |   |
128   // |   |   |   |
129   // +---+-E-+---+
130   //
131   // U
132   // |
133   // |
134   // V
135 
136   NodeDef ndef_a, ndef_b, ndef_c, ndef_d, ndef_e, ndef_u, ndef_v;
137   Status status =
138       NodeDefBuilder("A", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_a);
139   TF_EXPECT_OK(status);
140   status =
141       NodeDefBuilder("B", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_b);
142   TF_EXPECT_OK(status);
143   status =
144       NodeDefBuilder("C", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_c);
145   TF_EXPECT_OK(status);
146   status =
147       NodeDefBuilder("D", "Const").Attr("_class", {"loc:@C"}).Finalize(&ndef_d);
148   TF_EXPECT_OK(status);
149   status =
150       NodeDefBuilder("E", "Const").Attr("_class", {"loc:@D"}).Finalize(&ndef_e);
151   TF_EXPECT_OK(status);
152   status =
153       NodeDefBuilder("U", "Const").Attr("_class", {"loc:@W"}).Finalize(&ndef_u);
154   TF_EXPECT_OK(status);
155   status =
156       NodeDefBuilder("V", "Const").Attr("_class", {"loc:@W"}).Finalize(&ndef_v);
157   TF_EXPECT_OK(status);
158   GraphDef gdef = test::function::GDef(
159       {ndef_a, ndef_b, ndef_c, ndef_d, ndef_e, ndef_u, ndef_v});
160 
161   EXPECT_EQ(7, gdef.node_size());
162   EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(0), "loc:@X"));  // A
163   EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(1), "loc:@X"));  // B
164   EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(2), "loc:@X"));  // C
165   EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(3), "loc:@C"));  // D
166   EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(4), "loc:@D"));  // E
167   EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(5), "loc:@W"));  // U
168   EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(6), "loc:@W"));  // V
169 
170   ReassignColocation(&gdef);
171 
172   EXPECT_EQ(7, gdef.node_size());
173   EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(0), "loc:@E"));  // A
174   EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(1), "loc:@E"));  // B
175   EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(2), "loc:@E"));  // C
176   EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(3), "loc:@E"));  // D
177   EXPECT_EQ(0, gdef.node(4).attr_size());                        // E
178   EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(5), "loc:@V"));  // U
179   EXPECT_EQ(0, gdef.node(6).attr_size());                        // V
180 }
181 
182 }  // namespace grappler
183 }  // namespace tensorflow
184