1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/utils/colocation.h"
17
18 #include "tensorflow/core/framework/function_testlib.h"
19 #include "tensorflow/core/framework/node_def_builder.h"
20 #include "tensorflow/core/lib/core/status_test_util.h"
21 #include "tensorflow/core/platform/test.h"
22
23 namespace tensorflow {
24 namespace grappler {
25
26 class ColocationTest : public ::testing::Test {};
27
VerifyNodeHasColocation(const NodeDef & ndef,const string & coloc)28 bool VerifyNodeHasColocation(const NodeDef& ndef, const string& coloc) {
29 if (ndef.attr().empty()) {
30 return false;
31 }
32 if (ndef.attr().find("_class") == ndef.attr().end()) {
33 return false;
34 }
35 return ndef.attr().at("_class").list().s(0) == coloc;
36 }
37
TEST(ColocationTest, ReassignColocation_SingleNode) {
  // Node A colocates with B, but node B is not in the graph.
  //   A
  //   |
  //   |
  //  [B]

  NodeDef ndef;
  // Setup failures are fatal: continuing with a half-built NodeDef would only
  // produce misleading follow-on failures.
  TF_ASSERT_OK(
      NodeDefBuilder("A", "Const").Attr("_class", {"loc:@B"}).Finalize(&ndef));
  GraphDef gdef = test::function::GDef({ndef});

  // ASSERT (not EXPECT) on the node count so gdef.node(0) below is never
  // indexed out of range if the graph was built incorrectly.
  ASSERT_EQ(1, gdef.node_size());
  EXPECT_EQ(1, gdef.node(0).attr_size());

  ReassignColocation(&gdef);

  // Validates that node A's colocation info is cleared.
  ASSERT_EQ(1, gdef.node_size());
  EXPECT_EQ(0, gdef.node(0).attr_size());
}
60
TEST(ColocationTest, ReassignColocation_MultiNode_SingleGroup) {
  // Nodes A, B and C colocate with X. D colocates with C. E colocates with D.
  // Node X is not in the graph.
  //   A   B   C---D---E
  //   |   |   |
  //   |   |   |
  //   +--[X]--+
  // After re-assignment of colocation, A, B, C and D should colocate with E.
  //   A   B   C   D
  //   |   |   |   |
  //   |   |   |   |
  //   +---+-E-+---+

  NodeDef ndef_a, ndef_b, ndef_c, ndef_d, ndef_e;
  // Setup failures are fatal: the assertions below are meaningless if any
  // NodeDef failed to build.
  TF_ASSERT_OK(
      NodeDefBuilder("A", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_a));
  TF_ASSERT_OK(
      NodeDefBuilder("B", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_b));
  TF_ASSERT_OK(
      NodeDefBuilder("C", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_c));
  TF_ASSERT_OK(
      NodeDefBuilder("D", "Const").Attr("_class", {"loc:@C"}).Finalize(&ndef_d));
  TF_ASSERT_OK(
      NodeDefBuilder("E", "Const").Attr("_class", {"loc:@D"}).Finalize(&ndef_e));
  GraphDef gdef =
      test::function::GDef({ndef_a, ndef_b, ndef_c, ndef_d, ndef_e});

  // ASSERT (not EXPECT) on the node count so gdef.node(0..4) below is never
  // indexed out of range.
  ASSERT_EQ(5, gdef.node_size());
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(0), "loc:@X"));  // A
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(1), "loc:@X"));  // B
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(2), "loc:@X"));  // C
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(3), "loc:@C"));  // D
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(4), "loc:@D"));  // E

  ReassignColocation(&gdef);

  ASSERT_EQ(5, gdef.node_size());
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(0), "loc:@E"));  // A
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(1), "loc:@E"));  // B
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(2), "loc:@E"));  // C
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(3), "loc:@E"));  // D
  EXPECT_EQ(0, gdef.node(4).attr_size());                        // E
}
109
TEST(ColocationTest, ReassignColocation_MultiNode_MultiGroup) {
  // Before re-assignment:
  // Nodes A, B and C colocate with X. D colocates with C. E colocates with D.
  // Nodes U and V colocate with W. Nodes X and W are not in the graph:
  //   A   B   C---D---E
  //   |   |   |
  //   |   |   |
  //   +--[X]--+
  //
  //   U       V
  //   |       |
  //   |       |
  //   +--[W]--+
  //
  // After re-assignment:
  // A, B, C and D should colocate with E. U should colocate with V.
  //   A   B   C   D
  //   |   |   |   |
  //   |   |   |   |
  //   +---+-E-+---+
  //
  //   U
  //   |
  //   |
  //   V

  NodeDef ndef_a, ndef_b, ndef_c, ndef_d, ndef_e, ndef_u, ndef_v;
  // Setup failures are fatal: the assertions below are meaningless if any
  // NodeDef failed to build.
  TF_ASSERT_OK(
      NodeDefBuilder("A", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_a));
  TF_ASSERT_OK(
      NodeDefBuilder("B", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_b));
  TF_ASSERT_OK(
      NodeDefBuilder("C", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_c));
  TF_ASSERT_OK(
      NodeDefBuilder("D", "Const").Attr("_class", {"loc:@C"}).Finalize(&ndef_d));
  TF_ASSERT_OK(
      NodeDefBuilder("E", "Const").Attr("_class", {"loc:@D"}).Finalize(&ndef_e));
  TF_ASSERT_OK(
      NodeDefBuilder("U", "Const").Attr("_class", {"loc:@W"}).Finalize(&ndef_u));
  TF_ASSERT_OK(
      NodeDefBuilder("V", "Const").Attr("_class", {"loc:@W"}).Finalize(&ndef_v));
  GraphDef gdef = test::function::GDef(
      {ndef_a, ndef_b, ndef_c, ndef_d, ndef_e, ndef_u, ndef_v});

  // ASSERT (not EXPECT) on the node count so gdef.node(0..6) below is never
  // indexed out of range.
  ASSERT_EQ(7, gdef.node_size());
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(0), "loc:@X"));  // A
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(1), "loc:@X"));  // B
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(2), "loc:@X"));  // C
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(3), "loc:@C"));  // D
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(4), "loc:@D"));  // E
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(5), "loc:@W"));  // U
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(6), "loc:@W"));  // V

  ReassignColocation(&gdef);

  ASSERT_EQ(7, gdef.node_size());
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(0), "loc:@E"));  // A
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(1), "loc:@E"));  // B
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(2), "loc:@E"));  // C
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(3), "loc:@E"));  // D
  EXPECT_EQ(0, gdef.node(4).attr_size());                        // E
  EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(5), "loc:@V"));  // U
  EXPECT_EQ(0, gdef.node(6).attr_size());                        // V
}
181
182 } // namespace grappler
183 } // namespace tensorflow
184