1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/cc/ops/const_op.h"
16 #include "tensorflow/cc/ops/sendrecv_ops.h"
17 #include "tensorflow/cc/ops/standard_ops.h"
18 #include "tensorflow/core/framework/tensor_testutil.h"
19 #include "tensorflow/core/lib/core/status_test_util.h"
20 #include "tensorflow/core/lib/io/path.h"
21 #include "tensorflow/core/lib/strings/strcat.h"
22 #include "tensorflow/core/platform/test.h"
23 #include "tensorflow/core/platform/test_benchmark.h"
24 #include "tensorflow/core/public/session.h"
25 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
26 #include "tensorflow/tools/graph_transforms/transform_utils.h"
27
28 namespace tensorflow {
29 namespace graph_transforms {
30
// Forward declarations of the functions under test, repeated here so that the
// transform does not need a public header.
//
// NOTE(review): the implementations are not visible from this file; the
// descriptions below are inferred from how the tests call them - confirm
// against the implementation.

// Graph transform entry point. The tests pass the graph to rewrite as
// `input_graph_def`, the transform configuration (input/output names and
// params) as `context`, and read the rewritten graph from `output_graph_def`.
Status SparsifyGather(const GraphDef& input_graph_def,
                      const TransformFuncContext& context,
                      GraphDef* output_graph_def);
// Presumably reads the tensor identified by `tensor_name` (restricted to
// `shape_and_slice`) through `ckpt_reader` and stores it in `tensor`.
Status ReadTensorFromCheckpoint(
    const string& tensor_name, const std::unique_ptr<BundleReader>& ckpt_reader,
    const string& shape_and_slice, Tensor* tensor);
38
39 class SparsifyGatherTest : public ::testing::Test {
40 protected:
CreateNode(const StringPiece name,const StringPiece op,const std::vector<NodeDef * > & inputs,GraphDef * graph_def,bool control_dep=false)41 NodeDef* CreateNode(const StringPiece name, const StringPiece op,
42 const std::vector<NodeDef*>& inputs, GraphDef* graph_def,
43 bool control_dep = false) {
44 NodeDef* node_def = graph_def->add_node();
45 node_def->set_name(string(name));
46 node_def->set_op(string(op));
47 if (!control_dep) {
48 std::for_each(inputs.begin(), inputs.end(), [&node_def](NodeDef* input) {
49 node_def->add_input(input->name());
50 });
51 } else {
52 std::for_each(inputs.begin(), inputs.end(), [&node_def](NodeDef* input) {
53 node_def->add_input(strings::StrCat("^", input->name()));
54 });
55 }
56 return node_def;
57 }
58
MakeGather(StringPiece name,bool gather_v2,NodeDef * params,NodeDef * indices,GraphDef * graph_def)59 void MakeGather(StringPiece name, bool gather_v2, NodeDef* params,
60 NodeDef* indices, GraphDef* graph_def) {
61 if (gather_v2) {
62 NodeDef* axis_node =
63 CreateNode(strings::StrCat(name, "_axis"), "Const", {}, graph_def);
64 Tensor axis_t(DT_INT32, TensorShape({}));
65 axis_t.scalar<int32>()() = 0;
66 SetNodeTensorAttr<int32>("value", axis_t, axis_node);
67 CreateNode(name, "GatherV2", {params, indices, axis_node}, graph_def);
68 } else {
69 CreateNode(name, "Gather", {params, indices}, graph_def);
70 }
71 }
72
TestSinglePartition(bool gather_v2,bool include_shared_init,bool test_variable,bool test_kept_concat,const string & shared_init_name="group_deps")73 void TestSinglePartition(bool gather_v2, bool include_shared_init,
74 bool test_variable, bool test_kept_concat,
75 const string& shared_init_name = "group_deps") {
76 GraphDef graph_def;
77
78 const auto checkpoint_path =
79 io::JoinPath(testing::TmpDir(), "checkpoint_single");
80 // Build the graph.
81 NodeDef* input_node = CreateNode("ids", "Const", {}, &graph_def);
82 NodeDef* w_node;
83 NodeDef* zeros_const;
84 NodeDef* zeros_shape;
85 NodeDef* zeros_node;
86 NodeDef* assign_node;
87
88 Tensor weights(DT_FLOAT, TensorShape({4, 1}));
89 test::FillValues<float>(&weights, {0.2, 0.000001, 1.2, 0.001});
90
91 if (!test_variable) {
92 w_node = CreateNode("w/part_1", "Const", {}, &graph_def);
93 SetNodeTensorAttr<float>("value", weights, w_node);
94 } else {
95 w_node = CreateNode("w/part_1", "VariableV2", {}, &graph_def);
96
97 zeros_shape = CreateNode("w/part_1/Initializer/zeros/shape_as_tensor",
98 "Const", {}, &graph_def);
99 zeros_const = CreateNode("w/part_1/Initializer/zeros/Const", "Const", {},
100 &graph_def);
101 zeros_node = CreateNode("w/part_1/Initializer/zeros", "Fill",
102 {zeros_shape, zeros_const}, &graph_def);
103 assign_node = CreateNode("w/part_1/Assign", "Assign",
104 {w_node, zeros_node}, &graph_def);
105
106 NodeDef* save_const_node =
107 CreateNode("save/Const", "Const", {}, &graph_def);
108
109 Tensor tensor_names_values(DT_STRING, TensorShape({1}));
110 test::FillValues<string>(&tensor_names_values, {"w"});
111 NodeDef* tensor_names_node =
112 CreateNode("save/RestoreV2/tensor_names", "Const", {}, &graph_def);
113 SetNodeTensorAttr<string>("value", tensor_names_values,
114 tensor_names_node);
115
116 NodeDef* tensor_shapes_slices_node = CreateNode(
117 "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def);
118 Tensor shapes_slices_val(DT_STRING, TensorShape({1}));
119 shapes_slices_val.flat<string>()(0) = "4 1 0,4:0,1";
120 SetNodeTensorAttr<string>("value", shapes_slices_val,
121 tensor_shapes_slices_node);
122
123 NodeDef* restore_node = CreateNode(
124 "save/RestoreV2", "RestoreV2",
125 {save_const_node, tensor_names_node, tensor_shapes_slices_node},
126 &graph_def);
127 CreateNode("save/Assign", "Assign", {w_node, restore_node}, &graph_def);
128
129 BundleWriter writer(Env::Default(), checkpoint_path);
130 TF_ASSERT_OK(writer.Add("w", weights));
131 TF_ASSERT_OK(writer.Finish());
132 }
133 SetNodeAttr("dtype", DT_FLOAT, w_node);
134
135 NodeDef* identity_node =
136 CreateNode("w/read", "Identity", {w_node}, &graph_def);
137 MakeGather("gather", gather_v2, identity_node, input_node, &graph_def);
138 if (include_shared_init) {
139 if (!test_variable) {
140 CreateNode(shared_init_name, "NoOp", {}, &graph_def);
141 } else {
142 CreateNode(shared_init_name, "NoOp", {assign_node}, &graph_def, true);
143 }
144 }
145
146 NodeDef* concat_axis_node =
147 CreateNode("linear/concat/axis", "Const", {}, &graph_def);
148 NodeDef* concat_input_node =
149 CreateNode("concat/input/node", "Const", {}, &graph_def);
150 NodeDef* concat_node = nullptr;
151 if (!test_kept_concat) {
152 concat_node = CreateNode(
153 "concat/node", "ConcatV2",
154 {identity_node, concat_input_node, concat_axis_node}, &graph_def);
155 SetNodeAttr("N", 2, concat_node);
156 } else {
157 NodeDef* concat_input_node_2 =
158 CreateNode("concat/input/node_2", "Const", {}, &graph_def);
159 concat_node = CreateNode("concat/node", "ConcatV2",
160 {identity_node, concat_input_node,
161 concat_input_node_2, concat_axis_node},
162 &graph_def);
163 SetNodeAttr("N", 3, concat_node);
164 }
165
166 // Run the op.
167 GraphDef result;
168 TransformFuncContext context;
169 context.input_names = {"ids"};
170 context.output_names = {"gather"};
171 if (test_variable) {
172 context.params["input_checkpoint"] = {checkpoint_path};
173 }
174 if (shared_init_name != "group_deps") {
175 context.params["group_init_node"] = {shared_init_name};
176 }
177 TF_ASSERT_OK(SparsifyGather(graph_def, context, &result));
178
179 // Validation begins.
180 std::map<string, const NodeDef*> node_lookup;
181 MapNamesToNodes(result, &node_lookup);
182
183 // Check nodes.
184 EXPECT_EQ(0,
185 node_lookup.count("w/part_1/Initializer/zeros/shape_as_tensor"));
186 EXPECT_EQ(0, node_lookup.count("w/part_1/Initializer/zeros/Const"));
187 EXPECT_EQ(0, node_lookup.count("w/part_1/Initializer/zeros"));
188 EXPECT_EQ(0, node_lookup.count("w/part_1/Assign"));
189
190 EXPECT_EQ(1, node_lookup.count("ids"));
191 EXPECT_EQ("Const", node_lookup.at("ids")->op());
192
193 EXPECT_EQ(1, node_lookup.count("concat/node"));
194
195 if (!test_kept_concat) {
196 EXPECT_EQ(0, node_lookup.count("linear/concat/axis"));
197 EXPECT_EQ("Identity", node_lookup.at("concat/node")->op());
198 EXPECT_EQ(1, node_lookup.at("concat/node")->input_size());
199 EXPECT_EQ("concat/input/node", node_lookup.at("concat/node")->input(0));
200 } else {
201 EXPECT_EQ(1, node_lookup.count("linear/concat/axis"));
202 EXPECT_EQ("ConcatV2", node_lookup.at("concat/node")->op());
203 EXPECT_EQ(3, node_lookup.at("concat/node")->input_size());
204 EXPECT_EQ("concat/input/node", node_lookup.at("concat/node")->input(0));
205 EXPECT_EQ("concat/input/node_2", node_lookup.at("concat/node")->input(1));
206 EXPECT_EQ("linear/concat/axis", node_lookup.at("concat/node")->input(2));
207 EXPECT_EQ(2, node_lookup.at("concat/node")->attr().at("N").i());
208 }
209
210 EXPECT_EQ(1, node_lookup.count("w/part_1/indices"));
211 EXPECT_EQ("Const", node_lookup.at("w/part_1/indices")->op());
212 Tensor expected_indices_tensor(DT_INT64, TensorShape({3}));
213 test::FillValues<int64>(&expected_indices_tensor, {0, 2, 3});
214 test::ExpectTensorEqual<int64>(
215 expected_indices_tensor,
216 GetNodeTensorAttr(*(node_lookup.at("w/part_1/indices")), "value"));
217
218 EXPECT_EQ(1, node_lookup.count("w/part_1/values"));
219 EXPECT_EQ("Const", node_lookup.at("w/part_1/values")->op());
220 Tensor expected_values_tensor(DT_FLOAT, TensorShape({3}));
221 test::FillValues<float>(&expected_values_tensor, {0.2, 1.2, 0.001});
222 test::ExpectTensorNear<float>(
223 expected_values_tensor,
224 GetNodeTensorAttr(*(node_lookup.at("w/part_1/values")), "value"), 1e-5);
225
226 EXPECT_EQ(1, node_lookup.count("w/part_1/HashTable"));
227 EXPECT_EQ("HashTable", node_lookup.at("w/part_1/HashTable")->op());
228
229 EXPECT_EQ(1, node_lookup.count("w/part_1/InitializeTable"));
230 EXPECT_EQ("InitializeTable",
231 node_lookup.at("w/part_1/InitializeTable")->op());
232
233 // Nodes in "gather" scope.
234 EXPECT_EQ(1, node_lookup.count("gather/LookupTableFind"));
235 EXPECT_EQ("LookupTableFind",
236 node_lookup.at("gather/LookupTableFind")->op());
237
238 EXPECT_EQ(1, node_lookup.count("gather/Const"));
239 EXPECT_EQ("Const", node_lookup.at("gather/Const")->op());
240 Tensor expected_gather_default_tensor(DT_FLOAT, TensorShape({}));
241 test::FillValues<float>(&expected_gather_default_tensor, {0.0});
242 test::ExpectTensorNear<float>(
243 expected_gather_default_tensor,
244 GetNodeTensorAttr(*(node_lookup.at("gather/Const")), "value"), 1e-5);
245
246 EXPECT_EQ(1, node_lookup.count("gather/ExpandDims/Const"));
247 EXPECT_EQ("Const", node_lookup.at("gather/ExpandDims/Const")->op());
248 Tensor expected_expand_dims_tensor(DT_INT32, TensorShape({}));
249 test::FillValues<int32>(&expected_expand_dims_tensor, {-1});
250 test::ExpectTensorEqual<int32>(
251 expected_expand_dims_tensor,
252 GetNodeTensorAttr(*(node_lookup.at("gather/ExpandDims/Const")),
253 "value"));
254
255 EXPECT_EQ(1, node_lookup.count("gather"));
256 EXPECT_EQ("ExpandDims", node_lookup.at("gather")->op());
257
258 EXPECT_EQ(1, node_lookup.count(shared_init_name));
259 EXPECT_EQ("NoOp", node_lookup.at(shared_init_name)->op());
260
261 // Check connections
262 EXPECT_EQ("w/part_1/HashTable",
263 node_lookup.at("w/part_1/InitializeTable")->input(0));
264 EXPECT_EQ("w/part_1/indices",
265 node_lookup.at("w/part_1/InitializeTable")->input(1));
266 EXPECT_EQ("w/part_1/values",
267 node_lookup.at("w/part_1/InitializeTable")->input(2));
268
269 EXPECT_EQ("w/part_1/HashTable",
270 node_lookup.at("gather/LookupTableFind")->input(0));
271 EXPECT_EQ("ids", node_lookup.at("gather/LookupTableFind")->input(1));
272 EXPECT_EQ("gather/Const",
273 node_lookup.at("gather/LookupTableFind")->input(2));
274
275 EXPECT_EQ("gather/LookupTableFind", node_lookup.at("gather")->input(0));
276
277 // Check control dependency.
278 EXPECT_NE(std::find(node_lookup.at(shared_init_name)->input().begin(),
279 node_lookup.at(shared_init_name)->input().end(),
280 "^w/part_1/InitializeTable"),
281 node_lookup.at(shared_init_name)->input().end());
282 EXPECT_EQ(1, node_lookup.at(shared_init_name)->input().size());
283 }
284
TestMultiPartition(bool gather_v2,bool include_shared_init,bool test_variable,const string & shared_init_name="group_deps")285 void TestMultiPartition(bool gather_v2, bool include_shared_init,
286 bool test_variable,
287 const string& shared_init_name = "group_deps") {
288 // The 'ids' node is served input for two 'Gather's.
289 GraphDef graph_def;
290
291 const auto checkpoint_path =
292 io::JoinPath(testing::TmpDir(), "checkpoint_multiple");
293 // Build Graph:
294 // Shared input node
295 NodeDef* input_node = CreateNode("ids", "Const", {}, &graph_def);
296
297 // Two partitions
298 NodeDef* w_node1;
299 NodeDef* w_node2;
300 NodeDef* zeros_const1;
301 NodeDef* zeros_shape1;
302 NodeDef* zeros_node1;
303 NodeDef* zeros_const2;
304 NodeDef* zeros_shape2;
305 NodeDef* zeros_node2;
306 NodeDef* assign_node1;
307 NodeDef* assign_node2;
308
309 Tensor weights(DT_FLOAT, TensorShape({4, 1}));
310 test::FillValues<float>(&weights, {0.2, 0.000001, 1.2, 0.001});
311 if (!test_variable) {
312 w_node1 = CreateNode("w1/part_1", "Const", {}, &graph_def);
313 w_node2 = CreateNode("w2/part_1", "Const", {}, &graph_def);
314 SetNodeTensorAttr<float>("value", weights, w_node1);
315 SetNodeTensorAttr<float>("value", weights, w_node2);
316 } else {
317 NodeDef* save_const_node =
318 CreateNode("save/Const", "Const", {}, &graph_def);
319
320 NodeDef* tensor_names_node =
321 CreateNode("save/RestoreV2/tensor_names", "Const", {}, &graph_def);
322 Tensor tensor_names_values(DT_STRING, TensorShape({2}));
323 test::FillValues<string>(&tensor_names_values, {"w1", "w2"});
324 SetNodeTensorAttr<string>("value", tensor_names_values,
325 tensor_names_node);
326
327 NodeDef* tensor_shapes_slices_node = CreateNode(
328 "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def);
329 Tensor shapes_slices_val(DT_STRING, TensorShape({2}));
330 shapes_slices_val.flat<string>()(0) = "4 1 0,4:0,1";
331 shapes_slices_val.flat<string>()(1) = "4 1 0,4:0,1";
332 SetNodeTensorAttr<string>("value", shapes_slices_val,
333 tensor_shapes_slices_node);
334
335 NodeDef* restore_node = CreateNode(
336 "save/RestoreV2", "RestoreV2",
337 {save_const_node, tensor_names_node, tensor_shapes_slices_node},
338 &graph_def);
339
340 w_node1 = CreateNode("w1/part_1", "VariableV2", {}, &graph_def);
341
342 zeros_shape1 = CreateNode("w1/part_1/Initializer/zeros/shape_as_tensor",
343 "Const", {}, &graph_def);
344 zeros_const1 = CreateNode("w1/part_1/Initializer/zeros/Const", "Const",
345 {}, &graph_def);
346 zeros_node1 = CreateNode("w1/part_1/Initializer/zeros", "Fill",
347 {zeros_shape1, zeros_const1}, &graph_def);
348 assign_node1 = CreateNode("w1/part_1/Assign", "Assign",
349 {w_node1, zeros_node1}, &graph_def);
350
351 CreateNode("save/Assign", "Assign", {w_node1, restore_node}, &graph_def);
352
353 w_node2 = CreateNode("w2/part_1", "VariableV2", {}, &graph_def);
354 zeros_shape2 = CreateNode("w2/part_1/Initializer/zeros/shape_as_tensor",
355 "Const", {}, &graph_def);
356 zeros_const2 = CreateNode("w2/part_1/Initializer/zeros/Const", "Const",
357 {}, &graph_def);
358 zeros_node2 = CreateNode("w2/part_1/Initializer/zeros", "Fill",
359 {zeros_shape2, zeros_const2}, &graph_def);
360 assign_node2 = CreateNode("w2/part_1/Assign", "Assign",
361 {w_node2, zeros_node2}, &graph_def);
362
363 CreateNode("save/Assign_1", "Assign", {w_node2, restore_node},
364 &graph_def);
365
366 BundleWriter writer(Env::Default(), checkpoint_path);
367 TF_ASSERT_OK(writer.Add("w1", weights));
368 TF_ASSERT_OK(writer.Add("w2", weights));
369 TF_ASSERT_OK(writer.Finish());
370 }
371 SetNodeAttr("dtype", DT_FLOAT, w_node1);
372 SetNodeAttr("dtype", DT_FLOAT, w_node2);
373
374 NodeDef* identity_node1 =
375 CreateNode("w1/part_1/read", "Identity", {w_node1}, &graph_def);
376 NodeDef* identity_node2 =
377 CreateNode("w2/part_1/read", "Identity", {w_node2}, &graph_def);
378 MakeGather("gather1", gather_v2, identity_node1, input_node, &graph_def);
379 MakeGather("gather2", gather_v2, identity_node2, input_node, &graph_def);
380
381 NodeDef* concat_axis_node =
382 CreateNode("linear/concat/axis", "Const", {}, &graph_def);
383 NodeDef* concat_node = CreateNode(
384 "concat/node", "ConcatV2",
385 {identity_node1, identity_node2, concat_axis_node}, &graph_def);
386 SetNodeAttr("N", 2, concat_node);
387
388 // Shared init node
389 if (include_shared_init) {
390 if (!test_variable) {
391 CreateNode(shared_init_name, "NoOp", {}, &graph_def);
392 } else {
393 CreateNode(shared_init_name, "NoOp", {assign_node1, assign_node2},
394 &graph_def, true);
395 }
396 }
397
398 // Run the op.
399 GraphDef result;
400 TransformFuncContext context;
401 context.input_names = {"ids"};
402 context.output_names = {"gather1", "gather2"};
403 if (test_variable) {
404 context.params["input_checkpoint"] = {checkpoint_path};
405 }
406 if (shared_init_name != "group_deps") {
407 context.params["group_init_node"] = {shared_init_name};
408 }
409 TF_ASSERT_OK(SparsifyGather(graph_def, context, &result));
410
411 // Validation begins.
412 std::map<string, const NodeDef*> node_lookup;
413 MapNamesToNodes(result, &node_lookup);
414
415 // Check nodes.
416 EXPECT_EQ(0,
417 node_lookup.count("w1/part_1/Initializer/zeros/shape_as_tensor"));
418 EXPECT_EQ(0, node_lookup.count("w1/part_1/Initializer/zeros/Const"));
419 EXPECT_EQ(0, node_lookup.count("w1/part_1/Initializer/zeros"));
420 EXPECT_EQ(0, node_lookup.count("w1/part_1/Assign"));
421 EXPECT_EQ(0,
422 node_lookup.count("w2/part_1/Initializer/zeros/shape_as_tensor"));
423 EXPECT_EQ(0, node_lookup.count("w2/part_1/Initializer/zeros/Const"));
424 EXPECT_EQ(0, node_lookup.count("w2/part_1/Initializer/zeros"));
425 EXPECT_EQ(0, node_lookup.count("w2/part_1/Assign"));
426 EXPECT_EQ(1, node_lookup.count("ids"));
427 EXPECT_EQ("Const", node_lookup.at("ids")->op());
428
429 EXPECT_EQ(1, node_lookup.count(shared_init_name));
430 EXPECT_EQ("NoOp", node_lookup.at(shared_init_name)->op());
431
432 EXPECT_EQ(1, node_lookup.count("w1/part_1/indices"));
433 EXPECT_EQ("Const", node_lookup.at("w1/part_1/indices")->op());
434 Tensor expected_indices_tensor1(DT_INT64, TensorShape({3}));
435 test::FillValues<int64>(&expected_indices_tensor1, {0, 2, 3});
436 test::ExpectTensorEqual<int64>(
437 expected_indices_tensor1,
438 GetNodeTensorAttr(*(node_lookup.at("w1/part_1/indices")), "value"));
439
440 EXPECT_EQ(1, node_lookup.count("w1/part_1/values"));
441 EXPECT_EQ("Const", node_lookup.at("w1/part_1/values")->op());
442 Tensor expected_values_tensor1(DT_FLOAT, TensorShape({3}));
443 test::FillValues<float>(&expected_values_tensor1, {0.2, 1.2, 0.001});
444 test::ExpectTensorNear<float>(
445 expected_values_tensor1,
446 GetNodeTensorAttr(*(node_lookup.at("w1/part_1/values")), "value"),
447 1e-5);
448
449 EXPECT_EQ(1, node_lookup.count("w1/part_1/HashTable"));
450 EXPECT_EQ("HashTable", node_lookup.at("w1/part_1/HashTable")->op());
451
452 EXPECT_EQ(1, node_lookup.count("w1/part_1/InitializeTable"));
453 EXPECT_EQ("InitializeTable",
454 node_lookup.at("w1/part_1/InitializeTable")->op());
455
456 // Nodes in "gather1" scope.
457 EXPECT_EQ(1, node_lookup.count("gather1/LookupTableFind"));
458 EXPECT_EQ("LookupTableFind",
459 node_lookup.at("gather1/LookupTableFind")->op());
460
461 EXPECT_EQ(1, node_lookup.count("gather1/Const"));
462 EXPECT_EQ("Const", node_lookup.at("gather1/Const")->op());
463 Tensor expected_gather_default_tensor1(DT_FLOAT, TensorShape({}));
464 test::FillValues<float>(&expected_gather_default_tensor1, {0.0});
465 test::ExpectTensorNear<float>(
466 expected_gather_default_tensor1,
467 GetNodeTensorAttr(*(node_lookup.at("gather1/Const")), "value"), 1e-5);
468
469 EXPECT_EQ(1, node_lookup.count("gather1/ExpandDims/Const"));
470 EXPECT_EQ("Const", node_lookup.at("gather1/ExpandDims/Const")->op());
471 Tensor expected_expand_dims_tensor1(DT_INT32, TensorShape({}));
472 test::FillValues<int32>(&expected_expand_dims_tensor1, {-1});
473 test::ExpectTensorEqual<int32>(
474 expected_expand_dims_tensor1,
475 GetNodeTensorAttr(*(node_lookup.at("gather1/ExpandDims/Const")),
476 "value"));
477
478 EXPECT_EQ(1, node_lookup.count("gather1"));
479 EXPECT_EQ("ExpandDims", node_lookup.at("gather1")->op());
480
481 EXPECT_EQ(1, node_lookup.count("w2/part_1/indices"));
482 EXPECT_EQ("Const", node_lookup.at("w2/part_1/indices")->op());
483 Tensor expected_indices_tensor2(DT_INT64, TensorShape({3}));
484 test::FillValues<int64>(&expected_indices_tensor2, {0, 2, 3});
485 test::ExpectTensorEqual<int64>(
486 expected_indices_tensor2,
487 GetNodeTensorAttr(*(node_lookup.at("w2/part_1/indices")), "value"));
488
489 EXPECT_EQ(1, node_lookup.count("w2/part_1/values"));
490 EXPECT_EQ("Const", node_lookup.at("w2/part_1/values")->op());
491 Tensor expected_values_tensor2(DT_FLOAT, TensorShape({3}));
492 test::FillValues<float>(&expected_values_tensor2, {0.2, 1.2, 0.001});
493 test::ExpectTensorNear<float>(
494 expected_values_tensor2,
495 GetNodeTensorAttr(*(node_lookup.at("w2/part_1/values")), "value"),
496 1e-5);
497
498 EXPECT_EQ(1, node_lookup.count("w2/part_1/HashTable"));
499 EXPECT_EQ("HashTable", node_lookup.at("w2/part_1/HashTable")->op());
500
501 EXPECT_EQ(1, node_lookup.count("w2/part_1/InitializeTable"));
502 EXPECT_EQ("InitializeTable",
503 node_lookup.at("w2/part_1/InitializeTable")->op());
504
505 // Nodes in "gather2" scope.
506 EXPECT_EQ(1, node_lookup.count("gather2/LookupTableFind"));
507 EXPECT_EQ("LookupTableFind",
508 node_lookup.at("gather2/LookupTableFind")->op());
509
510 EXPECT_EQ(1, node_lookup.count("gather2/Const"));
511 EXPECT_EQ("Const", node_lookup.at("gather2/Const")->op());
512 Tensor expected_gather_default_tensor2(DT_FLOAT, TensorShape({}));
513 test::FillValues<float>(&expected_gather_default_tensor2, {0.0});
514 test::ExpectTensorNear<float>(
515 expected_gather_default_tensor2,
516 GetNodeTensorAttr(*(node_lookup.at("gather2/Const")), "value"), 1e-5);
517
518 EXPECT_EQ(1, node_lookup.count("gather2/ExpandDims/Const"));
519 EXPECT_EQ("Const", node_lookup.at("gather2/ExpandDims/Const")->op());
520 Tensor expected_expand_dims_tensor2(DT_INT32, TensorShape({}));
521 test::FillValues<int32>(&expected_expand_dims_tensor2, {-1});
522 test::ExpectTensorEqual<int32>(
523 expected_expand_dims_tensor2,
524 GetNodeTensorAttr(*(node_lookup.at("gather2/ExpandDims/Const")),
525 "value"));
526
527 EXPECT_EQ(1, node_lookup.count("gather2"));
528 EXPECT_EQ("ExpandDims", node_lookup.at("gather2")->op());
529
530 // Check connections
531 EXPECT_EQ("w1/part_1/HashTable",
532 node_lookup.at("w1/part_1/InitializeTable")->input(0));
533 EXPECT_EQ("w1/part_1/indices",
534 node_lookup.at("w1/part_1/InitializeTable")->input(1));
535 EXPECT_EQ("w1/part_1/values",
536 node_lookup.at("w1/part_1/InitializeTable")->input(2));
537
538 EXPECT_EQ("w2/part_1/HashTable",
539 node_lookup.at("w2/part_1/InitializeTable")->input(0));
540 EXPECT_EQ("w2/part_1/indices",
541 node_lookup.at("w2/part_1/InitializeTable")->input(1));
542 EXPECT_EQ("w2/part_1/values",
543 node_lookup.at("w2/part_1/InitializeTable")->input(2));
544
545 EXPECT_EQ("w1/part_1/HashTable",
546 node_lookup.at("gather1/LookupTableFind")->input(0));
547 EXPECT_EQ("ids", node_lookup.at("gather1/LookupTableFind")->input(1));
548 EXPECT_EQ("gather1/Const",
549 node_lookup.at("gather1/LookupTableFind")->input(2));
550 EXPECT_EQ("gather1/LookupTableFind", node_lookup.at("gather1")->input(0));
551
552 EXPECT_EQ("w2/part_1/HashTable",
553 node_lookup.at("gather2/LookupTableFind")->input(0));
554 EXPECT_EQ("ids", node_lookup.at("gather2/LookupTableFind")->input(1));
555 EXPECT_EQ("gather2/Const",
556 node_lookup.at("gather2/LookupTableFind")->input(2));
557 EXPECT_EQ("gather2/LookupTableFind", node_lookup.at("gather2")->input(0));
558
559 EXPECT_EQ(0, node_lookup.count("linear/concat/axis"));
560 EXPECT_EQ(0, node_lookup.count("concat/node"));
561
562 // Check control deps.
563 EXPECT_EQ(2, node_lookup.at(shared_init_name)->input_size());
564 EXPECT_NE(std::find(node_lookup.at(shared_init_name)->input().begin(),
565 node_lookup.at(shared_init_name)->input().end(),
566 "^w1/part_1/InitializeTable"),
567 node_lookup.at(shared_init_name)->input().end());
568
569 EXPECT_NE(std::find(node_lookup.at(shared_init_name)->input().begin(),
570 node_lookup.at(shared_init_name)->input().end(),
571 "^w2/part_1/InitializeTable"),
572 node_lookup.at(shared_init_name)->input().end());
573 }
TestReadTensorSlice()574 void TestReadTensorSlice() {
575 const auto checkpoint_path =
576 io::JoinPath(testing::TmpDir(), "checkpoint_slice");
577
578 Tensor weights(DT_FLOAT, TensorShape({2, 1}));
579 test::FillValues<float>(&weights, {0.2, 0.000001});
580 BundleWriter writer(Env::Default(), checkpoint_path);
581 TF_ASSERT_OK(writer.AddSlice("w", TensorShape({4, 1}),
582 TensorSlice::ParseOrDie("0,2:0,1"), weights));
583 TF_ASSERT_OK(writer.Finish());
584
585 std::unique_ptr<BundleReader> reader(
586 new BundleReader(Env::Default(), checkpoint_path));
587
588 Tensor results;
589 TF_ASSERT_OK(
590 ReadTensorFromCheckpoint("w/part_0", reader, "4 1 0,2:0,1", &results));
591
592 test::ExpectTensorEqual<float>(weights, results);
593 }
594 };
595
// Runs TestSinglePartition for every combination of its four boolean flags
// with the default init-node name, and additionally for every combination
// that includes the shared init node under the custom name "shared_inits".
// The calls are issued in the same order as an explicit enumeration grouped
// by test_kept_concat, then test_variable, then gather_v2. Each call is
// wrapped in a SCOPED_TRACE so that a failure inside the shared helper can be
// attributed to the flag combination that triggered it.
TEST_F(SparsifyGatherTest, TestSinglePartition) {
  for (const bool test_kept_concat : {false, true}) {
    // Default init-node name.
    for (const bool test_variable : {false, true}) {
      for (const bool gather_v2 : {false, true}) {
        for (const bool include_shared_init : {false, true}) {
          SCOPED_TRACE(strings::StrCat(
              "gather_v2=", static_cast<int>(gather_v2),
              " include_shared_init=", static_cast<int>(include_shared_init),
              " test_variable=", static_cast<int>(test_variable),
              " test_kept_concat=", static_cast<int>(test_kept_concat)));
          TestSinglePartition(gather_v2, include_shared_init, test_variable,
                              test_kept_concat);
        }
      }
    }
    // Custom init-node name; only meaningful with the init node present.
    for (const bool test_variable : {false, true}) {
      for (const bool gather_v2 : {false, true}) {
        SCOPED_TRACE(strings::StrCat(
            "gather_v2=", static_cast<int>(gather_v2),
            " include_shared_init=1",
            " test_variable=", static_cast<int>(test_variable),
            " test_kept_concat=", static_cast<int>(test_kept_concat),
            " shared_init_name=shared_inits"));
        TestSinglePartition(gather_v2, true, test_variable, test_kept_concat,
                            "shared_inits");
      }
    }
  }
}
623
// Runs TestMultiPartition for every combination of its three boolean flags
// with the default init-node name, and additionally for every combination
// that includes the shared init node under the custom name "shared_inits".
// The calls are issued in the same order as an explicit enumeration grouped
// by test_variable, then gather_v2. Each call is wrapped in a SCOPED_TRACE so
// that a failure inside the shared helper can be attributed to the flag
// combination that triggered it.
TEST_F(SparsifyGatherTest, TestMultiPartition) {
  // Default init-node name.
  for (const bool test_variable : {false, true}) {
    for (const bool gather_v2 : {false, true}) {
      for (const bool include_shared_init : {false, true}) {
        SCOPED_TRACE(strings::StrCat(
            "gather_v2=", static_cast<int>(gather_v2),
            " include_shared_init=", static_cast<int>(include_shared_init),
            " test_variable=", static_cast<int>(test_variable)));
        TestMultiPartition(gather_v2, include_shared_init, test_variable);
      }
    }
  }
  // Custom init-node name; only meaningful with the init node present.
  for (const bool test_variable : {false, true}) {
    for (const bool gather_v2 : {false, true}) {
      SCOPED_TRACE(strings::StrCat(
          "gather_v2=", static_cast<int>(gather_v2), " include_shared_init=1",
          " test_variable=", static_cast<int>(test_variable),
          " shared_init_name=shared_inits"));
      TestMultiPartition(gather_v2, true, test_variable, "shared_inits");
    }
  }
}
638
// Checks reading a tensor slice back from a checkpoint; the whole scenario
// lives in the fixture helper TestReadTensorSlice.
TEST_F(SparsifyGatherTest, TestTensorSlice) {
  TestReadTensorSlice();
}
640
641 } // namespace graph_transforms
642 } // namespace tensorflow
643