1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cmath>
17 #include <memory>
18 #include <unordered_map>
19 
20 #include "tensorflow/c/checkpoint_reader.h"
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/graph/graph_constructor.h"
23 #include "tensorflow/core/graph/node_builder.h"
24 #include "tensorflow/core/graph/subgraph.h"
25 #include "tensorflow/core/lib/strings/str_util.h"
26 #include "tensorflow/core/platform/init_main.h"
27 #include "tensorflow/core/public/session.h"
28 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
29 #include "tensorflow/tools/graph_transforms/transform_utils.h"
30 
31 namespace tensorflow {
32 using str_util::Join;
33 using str_util::Split;
34 using str_util::StringReplace;
35 using strings::StrCat;
36 
37 namespace graph_transforms {
38 
39 // Sparsify Tensor of shape [N, 1]. Return the indices and values vectors for
40 // non-zero tensor content.
SparsifyWeights(const Tensor & tensor,Tensor * indices_tensor,Tensor * values_tensor)41 Status SparsifyWeights(const Tensor& tensor, Tensor* indices_tensor,
42                        Tensor* values_tensor) {
43   if (tensor.dims() != 2 || tensor.dim_size(1) != 1) {
44     return tensorflow::errors::FailedPrecondition(
45         "Transform only applicable to subgraph with 'Const' with "
46         "tensor of shape [N, 1]. But instead get shape ",
47         tensor.shape().DebugString(), ".");
48   }
49 
50   auto flat = tensor.flat<float>();
51   std::vector<int64> indices;
52   std::vector<float> values;
53 
54   for (int64 i = 0; i < flat.size(); i++) {
55     float val = flat(i);
56     if (std::abs(val) >= 1.0e-5) {
57       indices.push_back(i);
58       values.push_back(val);
59     }
60   }
61 
62   // During model initialization, InitializeTableOp makes use of
63   // KeyValueTensorIterator, which does not accept empty keys or values.
64   // Consequently, adding a dummy pair of indices and values as a walkaround.
65   if (indices.empty() || values.empty()) {
66     indices.push_back(0);
67     values.push_back(0);
68   }
69   *indices_tensor = Tensor(DataTypeToEnum<int64>::value,
70                            {static_cast<int64>(indices.size())});
71   std::copy_n(indices.begin(), indices.size(),
72               indices_tensor->flat<int64>().data());
73 
74   *values_tensor =
75       Tensor(DataTypeToEnum<float>::value, {static_cast<int64>(values.size())});
76   std::copy_n(values.begin(), values.size(),
77               values_tensor->flat<float>().data());
78 
79   return Status::OK();
80 }
81 
CreateConstNode(const Tensor & tensor,const string & name,NodeDef * node_def)82 void CreateConstNode(const Tensor& tensor, const string& name,
83                      NodeDef* node_def) {
84   node_def->set_op("Const");
85   node_def->set_name(name);
86   SetNodeTensorAttr<float>("value", tensor, node_def);
87 }
88 
// Maps the name of a (possibly partitioned) variable to the name of the
// monolithic tensor it belongs to.
//
// The name is split on "/"; if its last component starts with "part_" that
// component is dropped (e.g. "emb/w/part_0" -> "emb/w"). Any other name is
// returned re-joined and otherwise unchanged. A name that consists solely of
// a "part_*" component is treated as a programming error and CHECK-fails.
string GetMonolithicTensorKey(const string& tensor_slice_name) {
  std::vector<string> names = Split(tensor_slice_name, "/");
  // Split may yield no components at all (e.g. for an empty name), so the
  // last component must not be inspected before checking for emptiness.
  if (!names.empty() && str_util::StartsWith(names.back(), "part_")) {
    CHECK_GE(names.size(), 2);
    names.pop_back();
  }
  return Join(names, "/");
}
97 
ObtainTensorSlice(const GraphDef & input_graph_def,const string & target_name,string * shape_slice_string)98 Status ObtainTensorSlice(const GraphDef& input_graph_def,
99                          const string& target_name,
100                          string* shape_slice_string) {
101   string restore_node_name;
102   for (const auto& node : input_graph_def.node()) {
103     std::vector<string> node_name_parts = Split(node.name(), "/");
104     if (node_name_parts.size() == 2 &&
105         str_util::StartsWith(node_name_parts[0], "save") &&
106         str_util::StartsWith(node_name_parts[1], "Assign") &&
107         node.input(0) == target_name) {
108       restore_node_name = node.input(1);
109       break;
110     }
111   }
112 
113   std::vector<string> restore_node_parts = Split(restore_node_name, ":");
114   CHECK_LE(restore_node_parts.size(), 2);
115   string tensor_names_node;
116   string shape_and_slices_node;
117   for (const auto& node : input_graph_def.node()) {
118     if ((node.name() == restore_node_parts[0]) && (node.op() == "RestoreV2")) {
119       tensor_names_node = node.input(1);
120       shape_and_slices_node = node.input(2);
121       break;
122     }
123   }
124 
125   int offset = -1;
126   for (const auto& node : input_graph_def.node()) {
127     if (node.name() == tensor_names_node) {
128       Tensor tensor_names_tensor;
129       TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &tensor_names_tensor));
130       const auto& tensor_names_value = tensor_names_tensor.flat<string>();
131       for (int i = 0; i < tensor_names_value.size(); i++) {
132         if (tensor_names_value(i) == GetMonolithicTensorKey(target_name)) {
133           offset = i;
134           break;
135         }
136       }
137     }
138   }
139   if (offset == -1) {
140     return errors::Internal("Unable to find RestoreV2 entry for variable: ",
141                             target_name);
142   }
143   for (const auto& node : input_graph_def.node()) {
144     if (node.name() == shape_and_slices_node) {
145       Tensor shape_and_slices_tensor;
146       TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &shape_and_slices_tensor));
147       const auto& shape_and_slices_value =
148           shape_and_slices_tensor.flat<string>();
149       *shape_slice_string = shape_and_slices_value(offset);
150       return Status::OK();
151     }
152   }
153   return errors::Internal("Unable to find slice for variable: ", target_name);
154 }
155 
ReadTensorFromCheckpoint(const string & tensor_name,const std::unique_ptr<BundleReader> & ckpt_reader,const string & shape_and_slice,Tensor * tensor)156 Status ReadTensorFromCheckpoint(
157     const string& tensor_name, const std::unique_ptr<BundleReader>& ckpt_reader,
158     const string& shape_and_slice, Tensor* tensor) {
159   if (ckpt_reader) {
160     TensorShape parsed_full_shape;
161     TensorSlice parsed_slice;
162     TensorShape parsed_slice_shape;
163 
164     bool get_slice = false;
165     if (!shape_and_slice.empty()) {
166       TF_RETURN_IF_ERROR(
167           checkpoint::ParseShapeAndSlice(shape_and_slice, &parsed_full_shape,
168                                          &parsed_slice, &parsed_slice_shape));
169       get_slice = (parsed_full_shape != parsed_slice_shape);
170     }
171     if (get_slice) {
172       TF_RETURN_IF_ERROR(ckpt_reader->LookupSlice(
173           GetMonolithicTensorKey(tensor_name), parsed_slice, tensor));
174     } else {
175       TF_RETURN_IF_ERROR(
176           ckpt_reader->Lookup(GetMonolithicTensorKey(tensor_name), tensor));
177     }
178     return Status::OK();
179   }
180   return errors::Internal("Checkpoint reader was not initialized. ");
181 }
182 
InitializeCheckpointReader(const TransformFuncContext & context,std::unique_ptr<BundleReader> * ckpt_reader)183 Status InitializeCheckpointReader(const TransformFuncContext& context,
184                                   std::unique_ptr<BundleReader>* ckpt_reader) {
185   if (context.params.count("input_checkpoint")) {
186     const string input_checkpoint = context.params.at("input_checkpoint")[0];
187     ckpt_reader->reset(new BundleReader(Env::Default(), input_checkpoint));
188     TF_RETURN_IF_ERROR((*ckpt_reader)->status());
189   }
190   return Status::OK();
191 }
192 
ObtainVariableInfo(const GraphDef & input_graph_def,std::unique_ptr<std::unordered_map<string,string>> * shapes_and_slices)193 Status ObtainVariableInfo(
194     const GraphDef& input_graph_def,
195     std::unique_ptr<std::unordered_map<string, string> >* shapes_and_slices) {
196   shapes_and_slices->reset(new std::unordered_map<string, string>());
197   for (const auto& node : input_graph_def.node()) {
198     if ((node.op() == "Variable") || (node.op() == "VariableV2")) {
199       string s;
200       TF_RETURN_IF_ERROR(ObtainTensorSlice(input_graph_def, node.name(), &s));
201       (**shapes_and_slices)[node.name()] = s;
202     }
203   }
204   return Status::OK();
205 }
206 
RemoveInputAtIndex(NodeDef * n,int index)207 Status RemoveInputAtIndex(NodeDef* n, int index) {
208   for (int i = index; i < n->input_size() - 1; i++) {
209     n->mutable_input()->SwapElements(i, i + 1);
210   }
211   n->mutable_input()->RemoveLast();
212   return Status::OK();
213 }
214 
RemoveNodeAtIndex(GraphDef * g,int index)215 Status RemoveNodeAtIndex(GraphDef* g, int index) {
216   for (int i = index; i < g->node_size() - 1; i++) {
217     g->mutable_node()->SwapElements(i, i + 1);
218   }
219   g->mutable_node()->RemoveLast();
220   return Status::OK();
221 }
222 
SparsifyGatherInternal(const GraphDef & input_graph_def,const std::unique_ptr<std::unordered_map<string,string>> & shapes_and_slices,const TransformFuncContext & context,const OpTypePattern & pattern,const std::unique_ptr<BundleReader> & ckpt_reader,GraphDef * output_graph_def)223 Status SparsifyGatherInternal(
224     const GraphDef& input_graph_def,
225     const std::unique_ptr<std::unordered_map<string, string> >&
226         shapes_and_slices,
227     const TransformFuncContext& context, const OpTypePattern& pattern,
228     const std::unique_ptr<BundleReader>& ckpt_reader,
229     GraphDef* output_graph_def) {
230   string group_init_node = "group_deps";
231   if (context.params.count("group_init_node")) {
232     group_init_node = context.params.at("group_init_node")[0];
233   }
234   GraphDef current_graph_def = input_graph_def;
235   bool any_match_found = false;
236 
237   // Populate references.
238   std::unordered_map<string, int> refs;
239   for (const auto& node : current_graph_def.node()) {
240     for (const auto& input : node.input()) {
241       auto parsed_input = StringReplace(input, "^", "", true);
242       refs[parsed_input] += 1;
243     }
244   }
245 
246   // The subgraphs may have overlapping components, therefore GraphMatcher
247   // doesn't return all subgraphs in one round -- this has to be multi-round
248   // update.
249   do {
250     any_match_found = false;
251     GraphDef replaced_graph_def = current_graph_def;
252     std::vector<string> init_table_node_names;
253     std::vector<string> removed_node_names;
254 
255     TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
256         current_graph_def, pattern,
257         [&ckpt_reader, &any_match_found, &init_table_node_names,
258          &shapes_and_slices, &removed_node_names,
259          &refs](const NodeMatch& match, const std::set<string>& input_nodes,
260                 const std::set<string>& output_nodes,
261                 std::vector<NodeDef>* new_nodes) {
262           any_match_found = true;
263 
264           // The captured subgraph should be of the following pattern:
265           // Const --> Identity --> Gather --> ...
266           //                          ^
267           //                          |
268           //                        (ids)
269           //
270           // After transform, it becomes:
271           //                   --> NoOp(group_deps)
272           //                   |
273           // Const --> InitializeTable --> HashTable
274           //                   ^              |
275           //                   |              |
276           // Const -------------              |
277           //                                  v
278           //               (ids) ---> LookupTableFind <--- Const(default)
279           //                                  |
280           //                                  v
281           //                                 ...
282 
283           // clang-format off
284           // For each subgraph, do the following
285           // 1. Sparsify the `Const`, creating two `Const`, for hashtable
286           // key/val.
287           // 2. Create a `InitializeTable` op connecting to the above 2 `Const`.
288           // 3. Create a `HashTable` op connecting to `InitializeTable` op.
289           // 4. Replace the `Gather` with a `LookupTableFind` op.
290           // 5. Connect the `LookupTableFind` with
291           //    a. `HashTable`
292           //    b. `Gather`'s ids input
293           //    c. a `default_val` arg, valued at 0
294           // clang-format on
295           const NodeDef& gather_node = match.node;
296 
297           // GatherV2 adds an "axis" parameter. sparsify_gather only supports
298           // axis 0 gathers.
299           if (gather_node.op() == "GatherV2") {
300             // Per the OpTypePattern, the 3rd input to Gather must be a Const.
301             const NodeDef& axis_node = match.inputs[2].node;
302 
303             Tensor axis_t;
304             TF_RETURN_IF_ERROR(GetNodeAttr(axis_node, "value", &axis_t));
305             int64 axis = 0;
306             if (axis_t.dtype() == DT_INT32) {
307               axis = axis_t.scalar<int32>()();
308             } else if (axis_t.dtype() == DT_INT64) {
309               axis = axis_t.scalar<int64>()();
310             } else {
311               return tensorflow::errors::FailedPrecondition(
312                   "Gather axis was not int32 or int64.");
313             }
314 
315             if (axis != 0) {
316               return tensorflow::errors::FailedPrecondition(
317                   "Transform only applicable to subgraph with GatherV2 over "
318                   "axis 0. Found axis ",
319                   axis, ".");
320             }
321           }
322 
323           const NodeDef& weights_node = match.inputs[0].inputs[0].node;
324 
325           DataType data_type;
326           TF_RETURN_IF_ERROR(GetNodeAttr(weights_node, "dtype", &data_type));
327           if (data_type != DT_FLOAT) {
328             return tensorflow::errors::FailedPrecondition(
329                 "Transform only applicable to subgraph with 'Const',"
330                 "'Variable', or 'VariableV2' of dtype "
331                 "'DT_FLOAT'. Found '" +
332                     weights_node.op() + "' with name '",
333                 weights_node.name(), "' and dtype '", data_type, "'.");
334           }
335 
336           Tensor weight;
337           if (weights_node.op() == "Const") {
338             weight = GetNodeTensorAttr(weights_node, "value");
339           } else {
340             TF_RETURN_IF_ERROR(ReadTensorFromCheckpoint(
341                 weights_node.name(), ckpt_reader,
342                 (*shapes_and_slices)[weights_node.name()], &weight));
343           }
344           // Add both both weight and identity node names.
345           removed_node_names.push_back(weights_node.name());
346           removed_node_names.push_back(match.inputs[0].node.name());
347           for (auto input_node : match.inputs[0].node.input()) {
348             auto parsed_input = StringReplace(input_node, "^", "", true);
349             refs[parsed_input]--;
350           }
351           Tensor indices_tensor;
352           Tensor values_tensor;
353           TF_RETURN_IF_ERROR(
354               SparsifyWeights(weight, &indices_tensor, &values_tensor));
355 
356           // indices and values of sparsified `Const`
357           DataType key_dtype = DT_INT64;
358           NodeDef indices_node;
359           CreateConstNode(indices_tensor,
360                           StrCat(weights_node.name(), "/indices"),
361                           &indices_node);
362           SetNodeAttr("dtype", key_dtype, &indices_node);
363 
364           NodeDef values_node;
365           CreateConstNode(values_tensor, StrCat(weights_node.name(), "/values"),
366                           &values_node);
367           SetNodeAttr("dtype", data_type, &values_node);
368 
369           // HashTable node
370           NodeDef hashtable_node;
371           hashtable_node.set_op("HashTable");
372           hashtable_node.set_name(StrCat(weights_node.name(), "/HashTable"));
373           SetNodeAttr("key_dtype", key_dtype, &hashtable_node);
374           SetNodeAttr("value_dtype", data_type, &hashtable_node);
375 
376           // InitializeTable node
377           NodeDef init_table_node;
378           init_table_node.set_op("InitializeTable");
379           init_table_node.set_name(
380               StrCat(weights_node.name(), "/InitializeTable"));
381           SetNodeAttr("Tkey", key_dtype, &init_table_node);
382           SetNodeAttr("Tval", data_type, &init_table_node);
383           init_table_node_names.push_back(init_table_node.name());
384 
385           // LookupTableFind node
386           NodeDef lookup_node;
387           lookup_node.set_op("LookupTableFind");
388           lookup_node.set_name(StrCat(gather_node.name(), "/LookupTableFind"));
389           SetNodeAttr("Tin", key_dtype, &lookup_node);
390           SetNodeAttr("Tout", data_type, &lookup_node);
391 
392           // Default return value of hashtable lookup
393           Tensor zero_tensor(data_type, TensorShape({}));
394           zero_tensor.flat<float>()(0) = 0.0;
395           NodeDef default_value_node;
396           CreateConstNode(zero_tensor, StrCat(gather_node.name(), "/Const"),
397                           &default_value_node);
398           SetNodeAttr("dtype", data_type, &default_value_node);
399 
400           // ExpandDims argument
401           Tensor dim_idx(DT_INT32, TensorShape({}));
402           dim_idx.flat<int32>()(0) = -1;
403           NodeDef dim_idx_node;
404           dim_idx_node.set_op("Const");
405           dim_idx_node.set_name(
406               StrCat(gather_node.name(), "/ExpandDims/Const"));
407           SetNodeAttr("value", dim_idx, &dim_idx_node);
408           SetNodeAttr("dtype", DT_INT32, &dim_idx_node);
409 
410           // ExpandDims node
411           NodeDef expand_dims_node;
412           expand_dims_node.set_op("ExpandDims");
413           // Reuse gather_node's name so not to change dependent's inputs
414           expand_dims_node.set_name(gather_node.name());
415           SetNodeAttr("T", data_type, &expand_dims_node);
416 
417           // Connect nodes
418           AddNodeInput(hashtable_node.name(), &init_table_node);
419           refs[hashtable_node.name()]++;
420           AddNodeInput(indices_node.name(), &init_table_node);
421           refs[indices_node.name()]++;
422           AddNodeInput(values_node.name(), &init_table_node);
423           refs[values_node.name()]++;
424 
425           AddNodeInput(hashtable_node.name(), &lookup_node);
426           refs[hashtable_node.name()]++;
427           AddNodeInput(gather_node.input(1), &lookup_node);
428           refs[gather_node.input(1)]++;
429           AddNodeInput(default_value_node.name(), &lookup_node);
430           refs[default_value_node.name()]++;
431 
432           AddNodeInput(lookup_node.name(), &expand_dims_node);
433           refs[lookup_node.name()]++;
434           AddNodeInput(dim_idx_node.name(), &expand_dims_node);
435           refs[dim_idx_node.name()]++;
436 
437           // Copy 'ids' input of original 'Gather'
438           new_nodes->push_back(match.inputs[1].node);
439           new_nodes->push_back(indices_node);
440           new_nodes->push_back(values_node);
441           new_nodes->push_back(hashtable_node);
442           new_nodes->push_back(init_table_node);
443           new_nodes->push_back(lookup_node);
444           new_nodes->push_back(default_value_node);
445           new_nodes->push_back(dim_idx_node);
446           new_nodes->push_back(expand_dims_node);
447 
448           return Status::OK();
449         },
450         {true}, &replaced_graph_def));
451 
452     NodeDef* init_op = nullptr;
453     for (int i = 0; i < replaced_graph_def.node_size(); i++) {
454       if (replaced_graph_def.node(i).name() == group_init_node &&
455           replaced_graph_def.node(i).op() == "NoOp") {
456         init_op = replaced_graph_def.mutable_node(i);
457         break;
458       }
459     }
460     if (!init_op) {
461       // Init node
462       init_op = replaced_graph_def.mutable_node()->Add();
463       init_op->set_op("NoOp");
464       init_op->set_name(group_init_node);
465     }
466     for (const string& name : init_table_node_names) {
467       // Add control dependence from init_table_node to group_deps_node
468       AddNodeInput(StrCat("^", name), init_op);
469       refs[name]++;
470     }
471 
472     // Erase inputs and outputs as they are not considered for deletion.
473     for (const auto& output : context.output_names) {
474       refs.erase(output);
475     }
476 
477     for (const auto& input : context.input_names) {
478       refs.erase(input);
479     }
480 
481     // Add nodes with a reference count of 0 for deletion.
482     for (auto entry : refs) {
483       if (entry.second == 0) {
484         removed_node_names.push_back(entry.first);
485       }
486     }
487 
488     while (!removed_node_names.empty()) {
489       auto name = removed_node_names.back();
490       removed_node_names.pop_back();
491 
492       int i = 0;
493       while (i < replaced_graph_def.node_size()) {
494         // Revisit this to see if we can safely remove RestoreV2 nodes.
495         if ((replaced_graph_def.node(i).name() == name) &&
496             (replaced_graph_def.node(i).op() != "RestoreV2")) {
497           for (const auto& input : replaced_graph_def.node(i).input()) {
498             auto parsed_input = StringReplace(input, "^", "", true);
499             refs[parsed_input] -= 1;
500             if (refs[parsed_input] == 0) {
501               removed_node_names.push_back(parsed_input);
502             }
503           }
504           TF_RETURN_IF_ERROR(RemoveNodeAtIndex(&replaced_graph_def, i));
505           continue;
506         }
507         int j = 0;
508         bool deleted_inputs = false;
509         while (j < replaced_graph_def.node(i).input_size()) {
510           if (replaced_graph_def.node(i).input(j) == name ||
511               replaced_graph_def.node(i).input(j) == ("^" + name)) {
512             TF_RETURN_IF_ERROR(
513                 RemoveInputAtIndex(replaced_graph_def.mutable_node(i), j));
514             deleted_inputs = true;
515             continue;
516           }
517           j++;
518         }
519         if (deleted_inputs) {
520           if (replaced_graph_def.node(i).op() == "ConcatV2") {
521             if (replaced_graph_def.node(i).input_size() > 2) {
522               SetNodeAttr("N", replaced_graph_def.node(i).input_size() - 1,
523                           replaced_graph_def.mutable_node(i));
524             } else if (replaced_graph_def.node(i).input_size() == 2) {
525               if (refs[replaced_graph_def.node(i).input(1)] != 1) {
526                 return errors::Internal(
527                     "Expect axis tensor of ConcatV2 node to only be referenced "
528                     "once.");
529               }
530               refs[replaced_graph_def.node(i).input(1)] -= 1;
531               removed_node_names.push_back(replaced_graph_def.node(i).input(1));
532               replaced_graph_def.mutable_node(i)->mutable_input()->RemoveLast();
533               replaced_graph_def.mutable_node(i)->mutable_attr()->erase("N");
534               replaced_graph_def.mutable_node(i)->set_op("Identity");
535             } else {
536               return errors::Internal(
537                   "ConcatV2 should have at least two elements");
538             }
539           }
540           if ((replaced_graph_def.node(i).op() == "Assign" ||
541                replaced_graph_def.node(i).op() == "Reshape" ||
542                replaced_graph_def.node(i).op() == "Equal" ||
543                replaced_graph_def.node(i).op() == "Mean" ||
544                replaced_graph_def.node(i).op() == "ScalarSummary") &&
545               replaced_graph_def.node(i).input_size() == 1) {
546             removed_node_names.push_back(replaced_graph_def.node(i).name());
547           }
548           if (!replaced_graph_def.node(i).input_size()) {
549             removed_node_names.push_back(replaced_graph_def.node(i).name());
550           }
551         }
552         i++;
553       }
554     }
555     current_graph_def = replaced_graph_def;
556   } while (any_match_found);
557   *output_graph_def = current_graph_def;
558   return Status::OK();
559 }
560 
SparsifyGather(const GraphDef & input_graph_def,const TransformFuncContext & context,GraphDef * output_graph_def)561 Status SparsifyGather(const GraphDef& input_graph_def,
562                       const TransformFuncContext& context,
563                       GraphDef* output_graph_def) {
564   // clang-format off
565   const OpTypePattern gather_pattern =
566     {"Gather",
567      {
568        {"Identity",
569         {
570           {"Const|Variable|VariableV2"}
571         }
572        },
573        {"*"},
574      }
575     };
576   const OpTypePattern gather_v2_pattern =
577     {"GatherV2",
578       {
579         {"Identity",
580           {
581             {"Const|Variable|VariableV2"}
582           }
583         },
584         {"*"},
585         // GatherV2's axis must be constant.
586         {"Const"},
587       }
588     };
589   // clang-format on
590 
591   GraphDef cleaned_input_graph_def;
592   RemoveAttributes(input_graph_def, {"_output_shapes"},
593                    &cleaned_input_graph_def);
594 
595   GraphDef temp_output;
596 
597   std::unique_ptr<BundleReader> ckpt_reader;
598   TF_RETURN_IF_ERROR(InitializeCheckpointReader(context, &ckpt_reader));
599 
600   std::unique_ptr<std::unordered_map<string, string> > shapes_and_slices;
601   TF_RETURN_IF_ERROR(
602       ObtainVariableInfo(cleaned_input_graph_def, &shapes_and_slices));
603 
604   TF_RETURN_IF_ERROR(SparsifyGatherInternal(
605       cleaned_input_graph_def, shapes_and_slices, context, gather_pattern,
606       ckpt_reader, &temp_output));
607 
608   TF_RETURN_IF_ERROR(SparsifyGatherInternal(temp_output, shapes_and_slices,
609                                             context, gather_v2_pattern,
610                                             ckpt_reader, output_graph_def));
611 
612   return Status::OK();
613 }
614 
// Associates the transform name "sparsify_gather" with SparsifyGather.
// NOTE(review): presumably this makes the transform selectable by that name
// in the graph transform tool -- confirm against the macro's definition.
REGISTER_GRAPH_TRANSFORM("sparsify_gather", SparsifyGather);
616 
617 }  // namespace graph_transforms
618 }  // namespace tensorflow
619