1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
17 
18 #include <algorithm>
19 #include <deque>
20 #include <limits>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <vector>
24 
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/string_view.h"
27 #include "tensorflow/core/common_runtime/device.h"
28 #include "tensorflow/core/framework/allocator.h"
29 #include "tensorflow/core/framework/attr_value.pb.h"
30 #include "tensorflow/core/framework/node_def.pb.h"
31 #include "tensorflow/core/framework/op.h"
32 #include "tensorflow/core/framework/tensor.pb.h"
33 #include "tensorflow/core/framework/types.h"
34 #include "tensorflow/core/grappler/graph_topology_view.h"
35 #include "tensorflow/core/grappler/grappler_item.h"
36 #include "tensorflow/core/grappler/mutable_graph_view.h"
37 #include "tensorflow/core/grappler/op_types.h"
38 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
39 #include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
40 #include "tensorflow/core/grappler/utils/frame.h"
41 #include "tensorflow/core/grappler/utils/traversal.h"
42 #include "tensorflow/core/lib/core/errors.h"
43 #include "tensorflow/core/lib/core/stringpiece.h"
44 #include "tensorflow/core/lib/gtl/inlined_vector.h"
45 #include "tensorflow/core/lib/strings/strcat.h"
46 #include "tensorflow/core/platform/tensor_coding.h"
47 #include "tensorflow/core/public/version.h"
48 #include "tensorflow/core/util/device_name_utils.h"
49 #include "tensorflow/core/util/saved_tensor_slice_util.h"
50 
51 using tensorflow::strings::StrCat;
52 
53 namespace tensorflow {
54 namespace grappler {
55 namespace {
56 
57 using TensorVector = gtl::InlinedVector<TensorValue, 4>;
58 
59 class LoopInvariantNodeMotionOptimizer {
60  public:
LoopInvariantNodeMotionOptimizer(GraphDef * optimized_graph)61   explicit LoopInvariantNodeMotionOptimizer(GraphDef* optimized_graph)
62       : optimized_graph_(optimized_graph) {}
63   virtual ~LoopInvariantNodeMotionOptimizer() = default;
64   Status Optimize();
65 
66  private:
67   Status FindInvariantNodes(NodeDef* node);
68   Status RevertInvariantNodes();
69   Status MoveInvariantNodes(const int frame_id);
70   Status HandleInvariantNode(NodeDef* node, const int num_outputs,
71                              const int frame_id);
72   Status HandleConst(NodeDef* node, const int num_outputs, const int frame_id);
73   Status HandleInvariantEnter(NodeDef* node, const int num_outputs);
74 
75   GraphDef* optimized_graph_;  // Not owned.
76   std::unique_ptr<NodeMap> node_map_;
77   std::map<NodeDef*, int> invariant_nodes_;
78   std::set<int> empty_set_;
79   // TODO(rmlarsen): Use vector instead of map, since frames ids are dense.
80   std::map<int, std::set<int>> frame_children_;
81   std::map<int, int> frame_parent_;
82   std::map<int, const NodeDef*> loop_cond_;
83   std::map<int, std::vector<NodeDef*>> invariant_enters_;
84   int new_enter_id_;
85 };
86 
HandleInvariantEnter(NodeDef * node,const int num_outputs)87 Status LoopInvariantNodeMotionOptimizer::HandleInvariantEnter(
88     NodeDef* node, const int num_outputs) {
89   auto consumers = node_map_->GetOutputs(node->name());
90   std::vector<string> enter_control_inputs;
91   string enter_input;
92   for (auto& input : node->input()) {
93     if (IsControlInput(input)) {
94       enter_control_inputs.push_back(input);
95     } else {
96       enter_input = input;
97     }
98   }
99   for (auto* consumer : consumers) {
100     if (invariant_nodes_.count(consumer)) {
101       for (int i = 0; i < consumer->input_size(); ++i) {
102         if (NodeName(consumer->input(i)) == node->name()) {
103           consumer->set_input(i, enter_input);
104           node_map_->AddOutput(NodeName(enter_input), consumer->name());
105           node_map_->RemoveOutput(node->name(), consumer->name());
106         }
107       }
108       for (auto& control_input : enter_control_inputs) {
109         consumer->add_input(control_input);
110         node_map_->AddOutput(NodeName(control_input), consumer->name());
111       }
112     }
113   }
114   return Status::OK();
115 }
116 
HandleConst(NodeDef * node,const int num_outputs,const int frame_id)117 Status LoopInvariantNodeMotionOptimizer::HandleConst(NodeDef* node,
118                                                      const int num_outputs,
119                                                      const int frame_id) {
120   NodeDef* const_node = nullptr;
121   if (num_outputs == 0) {
122     // all successor nodes are invariant
123     // Remove the control inputs from this frame to the const node,
124     // when moving it out of the frame (in parent frame)
125     const_node = node;
126     node_map_->RemoveInputs(node->name());
127     node->clear_input();
128   } else {
129     // some successor nodes are variant
130     // Have to keep the const node in the frame,
131     // so create a new one outside the frame (in parent frame)
132     const string const_node_name =
133         AddPrefixToNodeName(node->name(), kLoopOptimizer);
134     const_node = node_map_->GetNode(const_node_name);
135     if (const_node == nullptr) {
136       const_node = optimized_graph_->add_node();
137       const_node->set_name(const_node_name);
138       const_node->set_op("Const");
139       const_node->set_device(node->device());
140       *const_node->mutable_attr() = node->attr();
141       node_map_->AddNode(const_node->name(), const_node);
142     }
143     auto consumers = node_map_->GetOutputs(node->name());
144     for (auto* consumer : consumers) {
145       if (invariant_nodes_.count(consumer)) {
146         for (int i = 0; i < consumer->input_size(); ++i) {
147           if (NodeName(consumer->input(i)) == node->name()) {
148             if (IsControlInput(consumer->input(i))) {
149               *consumer->mutable_input(i) = AsControlDependency(*const_node);
150             } else {
151               *consumer->mutable_input(i) = const_node->name();
152             }
153             node_map_->AddOutput(const_node->name(), consumer->name());
154             node_map_->RemoveOutput(node->name(), consumer->name());
155           }
156         }
157       }
158     }
159   }
160   // add a control input from the parent frame
161   auto parent_it = frame_parent_.find(frame_id);
162   if (parent_it != frame_parent_.end()) {
163     int parent_id = parent_it->second;
164     auto loop_cond_it = loop_cond_.find(parent_id);
165     if (loop_cond_it == loop_cond_.end()) {
166       return errors::InvalidArgument("Frame ", frame_id,
167                                      " doesn't have a LoopCond node");
168     }
169     auto& loop_cond_name = loop_cond_it->second->name();
170     NodeDef* switch_node = nullptr;
171     for (auto* node : node_map_->GetOutputs(loop_cond_name)) {
172       if (node->op() == "Switch") {
173         switch_node = node;
174         break;
175       }
176     }
177     if (!switch_node) {
178       return errors::InvalidArgument("LoopCond node of Frame ", frame_id,
179                                      " doesn't connect to any Switch node");
180     }
181     string switch_output = StrCat(switch_node->name(), ":1");
182     const string ctrl_dep = ConstantFolding::AddControlDependency(
183         switch_output, optimized_graph_, node_map_.get());
184     const_node->add_input(ctrl_dep);
185     node_map_->AddOutput(NodeName(ctrl_dep), const_node->name());
186   }
187   return Status::OK();
188 }
189 
HandleInvariantNode(NodeDef * node,const int num_outputs,const int frame_id)190 Status LoopInvariantNodeMotionOptimizer::HandleInvariantNode(
191     NodeDef* node, const int num_outputs, const int frame_id) {
192   // have to remove control inputs to the invariant node from the same frame
193   // when moving this node out of this frame
194   for (int i = 0; i < node->input_size(); ++i) {
195     if (IsControlInput(node->input(i))) {
196       node->mutable_input()->SwapElements(i, node->input_size() - 1);
197       node->mutable_input()->RemoveLast();
198     }
199   }
200   if (num_outputs == 0) {
201     return Status::OK();
202   }
203 
204   DataTypeVector input_types;
205   DataTypeVector output_types;
206   OpRegistryInterface* op_registry = OpRegistry::Global();
207   const OpRegistrationData* op_reg_data = nullptr;
208   TF_RETURN_IF_ERROR(op_registry->LookUp(node->op(), &op_reg_data));
209   TF_RETURN_IF_ERROR(InOutTypesForNode(*node, op_reg_data->op_def, &input_types,
210                                        &output_types));
211 
212   auto consumers = node_map_->GetOutputs(node->name());
213   string fname = invariant_enters_[frame_id][0]->attr().at("frame_name").s();
214   int piterations =
215       invariant_enters_[frame_id][0]->attr().at("parallel_iterations").i();
216   for (auto* consumer : consumers) {
217     if (!invariant_nodes_.count(consumer)) {
218       for (int i = 0; i < consumer->input_size(); ++i) {
219         int port;
220         string node_name = ParseNodeName(consumer->input(i), &port);
221         if (node_name != node->name()) {
222           continue;
223         }
224         if (port < 0) {
225           return errors::InvalidArgument(
226               "Invariant node should not have control outputs "
227               "to variant node");
228         }
229         DataType output_type = output_types[port];
230         NodeDef* new_enter = optimized_graph_->add_node();
231         new_enter->set_op("Enter");
232         new_enter->set_device(node->device());
233         new_enter->set_name(AddPrefixToNodeName(
234             StrCat(fname, "_enter_", new_enter_id_++), kLoopOptimizer));
235         AttrValue data_type;
236         data_type.set_type(output_type);
237         new_enter->mutable_attr()->insert({"T", data_type});
238         AttrValue frame_name;
239         frame_name.set_s(fname);
240         new_enter->mutable_attr()->insert({"frame_name", frame_name});
241         AttrValue is_const;
242         is_const.set_b(true);
243         new_enter->mutable_attr()->insert({"is_constant", is_const});
244         AttrValue parallel_iterations;
245         parallel_iterations.set_i(piterations);
246         new_enter->mutable_attr()->insert(
247             {"parallel_iterations", parallel_iterations});
248         new_enter->add_input(consumer->input(i));
249         *consumer->mutable_input(i) = new_enter->name();
250         node_map_->AddNode(new_enter->name(), new_enter);
251         node_map_->AddOutput(node->name(), new_enter->name());
252         node_map_->AddOutput(new_enter->name(), consumer->name());
253       }
254     }
255   }
256   return Status::OK();
257 }
258 
MoveInvariantNodes(const int frame_id)259 Status LoopInvariantNodeMotionOptimizer::MoveInvariantNodes(
260     const int frame_id) {
261   for (auto iter = invariant_nodes_.begin(); iter != invariant_nodes_.end();
262        ++iter) {
263     auto* invariant_node = iter->first;
264     const int num_outputs = iter->second;
265     if (IsEnter(*invariant_node)) {
266       TF_RETURN_IF_ERROR(HandleInvariantEnter(invariant_node, num_outputs));
267     } else if (IsConstant(*invariant_node)) {
268       TF_RETURN_IF_ERROR(HandleConst(invariant_node, num_outputs, frame_id));
269     } else {
270       TF_RETURN_IF_ERROR(
271           HandleInvariantNode(invariant_node, num_outputs, frame_id));
272     }
273   }
274   return Status::OK();
275 }
276 
RevertInvariantNodes()277 Status LoopInvariantNodeMotionOptimizer::RevertInvariantNodes() {
278   std::deque<const NodeDef*> reverted_nodes;
279   for (auto iter = invariant_nodes_.begin(); iter != invariant_nodes_.end();) {
280     bool erased = false;
281     const auto* node = iter->first;
282     if (!IsConstant(*node) && !IsEnter(*node) && iter->second > 0) {
283       auto& consumers = node_map_->GetOutputs(node->name());
284       for (auto* consumer : consumers) {
285         if (!invariant_nodes_.count(consumer)) {
286           for (const auto& input : consumer->input()) {
287             if (IsControlInput(input) && NodeName(input) == node->name()) {
288               reverted_nodes.push_back(node);
289               invariant_nodes_.erase(iter++);
290               erased = true;
291               break;
292             }
293           }
294           if (erased) break;
295         }
296       }
297     }
298     if (!erased) ++iter;
299   }
300   while (!reverted_nodes.empty()) {
301     const auto* node = reverted_nodes.front();
302     reverted_nodes.pop_front();
303     std::set<NodeDef*> producers;
304     for (const auto& input : node->input()) {
305       auto* producer = node_map_->GetNode(input);
306       auto iter = invariant_nodes_.find(producer);
307       if (iter != invariant_nodes_.end()) {
308         if (IsControlInput(input) && !IsConstant(*producer) &&
309             !IsEnter(*producer)) {
310           reverted_nodes.push_back(producer);
311           invariant_nodes_.erase(iter);
312         } else {
313           producers.insert(producer);
314         }
315       }
316     }
317     for (auto* producer : producers) {
318       auto iter = invariant_nodes_.find(producer);
319       if (iter != invariant_nodes_.end()) {
320         ++iter->second;
321       }
322     }
323     for (auto* consumer : node_map_->GetOutputs(node->name())) {
324       auto iter = invariant_nodes_.find(consumer);
325       if (iter != invariant_nodes_.end()) {
326         reverted_nodes.push_back(consumer);
327         invariant_nodes_.erase(iter);
328       }
329     }
330   }
331   return Status::OK();
332 }
333 
FindInvariantNodes(NodeDef * start_node)334 Status LoopInvariantNodeMotionOptimizer::FindInvariantNodes(
335     NodeDef* start_node) {
336   std::vector<NodeDef*> stack;
337   stack.reserve(32);
338   stack.push_back(start_node);
339   while (!stack.empty()) {
340     NodeDef* node = stack.back();
341     stack.pop_back();
342     auto consumers = node_map_->GetOutputs(node->name());
343     invariant_nodes_.emplace(node, consumers.size());
344     for (auto* consumer : consumers) {
345       if (invariant_nodes_.count(consumer) || ModifiesFrameInfo(*consumer)) {
346         continue;
347       }
348       bool is_invariant = true;
349       for (const auto& input : consumer->input()) {
350         if (!IsControlInput(input)) {
351           const string name = NodeName(input);
352           auto* producer = node_map_->GetNode(name);
353           if (!invariant_nodes_.count(producer)) {
354             if (IsConstant(*producer)) {
355               invariant_nodes_.insert(
356                   std::make_pair(producer, node_map_->GetOutputs(name).size()));
357             } else {
358               is_invariant = false;
359               break;
360             }
361           }
362         }
363       }
364       if (is_invariant) {
365         std::set<NodeDef*> producers;
366         for (const auto& input : consumer->input()) {
367           auto* producer = node_map_->GetNode(input);
368           producers.insert(producer);
369         }
370         for (auto* producer : producers) {
371           auto iter = invariant_nodes_.find(producer);
372           if (iter != invariant_nodes_.end()) {
373             --iter->second;
374           }
375         }
376         stack.push_back(consumer);
377       }
378     }
379   }
380   return Status::OK();
381 }
382 
Optimize()383 Status LoopInvariantNodeMotionOptimizer::Optimize() {
384   node_map_.reset(new NodeMap(optimized_graph_));
385   FrameView frame_view;
386   // TODO(ezhulenev): Use GraphView when migrated from NodeMap.
387   TF_RETURN_IF_ERROR(frame_view.InferFromGraph(*optimized_graph_));
388 
389   std::deque<int> worklist;
390   for (const NodeDef& node : optimized_graph_->node()) {
391     const std::vector<int>& frame_ids = frame_view.Frames(node);
392 
393     if (frame_ids.size() >= 3) {
394       for (unsigned int i = 1; i < frame_ids.size() - 1; ++i) {
395         frame_parent_[frame_ids[i]] = frame_ids[i - 1];
396         frame_children_[frame_ids[i]].insert(frame_ids[i + 1]);
397       }
398     }
399     if (frame_ids.size() >= 2) {
400       frame_children_[frame_ids[0]].insert(frame_ids[1]);
401       frame_parent_[frame_ids.back()] = frame_ids[frame_ids.size() - 2];
402     }
403     if (!frame_ids.empty()) {
404       frame_children_.insert(std::make_pair(frame_ids.back(), empty_set_));
405       if (node.op() == "LoopCond") {
406         if (loop_cond_.count(frame_ids.back())) {
407           return errors::InvalidArgument(
408               "Loop ", frame_ids.back(),
409               " has more than one LoopCond node: ", node.name(), " and ",
410               loop_cond_[frame_ids.back()]->name());
411         }
412         loop_cond_[frame_ids.back()] = &node;
413       }
414       if (IsEnter(node) && node.attr().at("is_constant").b()) {
415         invariant_enters_[frame_ids.back()].push_back(
416             const_cast<NodeDef*>(&node));
417       }
418     }
419   }
420 
421   for (auto it = frame_children_.begin(); it != frame_children_.end(); ++it) {
422     if (it->second.empty()) {
423       worklist.push_back(it->first);
424     }
425   }
426 
427   while (!worklist.empty()) {
428     int frame_id = worklist.front();
429     new_enter_id_ = 0;
430     worklist.pop_front();
431     auto parent_it = frame_parent_.find(frame_id);
432     if (parent_it != frame_parent_.end()) {
433       int parent_id = parent_it->second;
434       frame_children_[parent_id].erase(frame_id);
435       if (frame_children_[parent_id].empty()) {
436         worklist.push_back(parent_id);
437       }
438     }
439 
440     if (invariant_enters_[frame_id].empty()) {
441       continue;
442     }
443     invariant_nodes_.clear();
444     for (auto* enter : invariant_enters_[frame_id]) {
445       TF_RETURN_IF_ERROR(FindInvariantNodes(enter));
446     }
447 
448     // revert invariant nodes that have control outputs to variant nodes
449     TF_RETURN_IF_ERROR(RevertInvariantNodes());
450 
451     TF_RETURN_IF_ERROR(MoveInvariantNodes(frame_id));
452   }
453   return Status::OK();
454 }
455 
GetStackPushNodesToConvert(const GraphTopologyView & graph_view,const std::unordered_set<string> & nodes_to_preserve,int stack_node_idx)456 std::vector<int> GetStackPushNodesToConvert(
457     const GraphTopologyView& graph_view,
458     const std::unordered_set<string>& nodes_to_preserve, int stack_node_idx) {
459   VLOG(1) << "Stack node: " << graph_view.graph()->node(stack_node_idx).name();
460 
461   const std::unordered_set<string> op_types_to_traverse(
462       {"Stack", "StackV2", "Enter", "RefEnter", "Switch", "RefSwitch",
463        "Identity", "RefIdentity"});
464   const auto is_op_to_traverse = [&](const NodeDef* node) -> bool {
465     return op_types_to_traverse.find(node->op()) != op_types_to_traverse.end();
466   };
467 
468   std::vector<int> nodes_to_convert;
469   std::vector<int> fanouts;
470 
471   DfsTraversal(graph_view, {graph_view.GetNode(stack_node_idx)},
472                TraversalDirection::kFollowOutputs,
473                DfsPredicates::Advance(is_op_to_traverse),
474                DfsCallbacks::PreOrder([&](const NodeDef* node) {
475                  const absl::optional<int> idx = graph_view.GetNodeIndex(*node);
476                  fanouts.push_back(idx.value());
477                }));
478 
479   for (int fanout_idx : fanouts) {
480     const NodeDef& fanout_node = graph_view.graph()->node(fanout_idx);
481     VLOG(1) << "Fanout " << fanout_idx << " : " << fanout_node.name();
482     if (IsStackPushOp(fanout_node)) {
483       // Check that the stack itself is not a node we want to preserve. This can
484       // happen when the graph we have contains only the forward pass for a loop
485       // (as when the forward and backward passes are split across different
486       // functions).
487       if (graph_view.HasNode(fanout_node.input(0))) {
488         const NodeDef* stack_node = graph_view.GetNode(fanout_node.input(0));
489         while (stack_node->op() != "Stack" && stack_node->op() != "StackV2" &&
490                stack_node->input_size() > 0 &&
491                graph_view.HasNode(stack_node->input(0))) {
492           stack_node = graph_view.GetNode(stack_node->input(0));
493         }
494         if (nodes_to_preserve.find(stack_node->name()) ==
495             nodes_to_preserve.end()) {
496           nodes_to_convert.push_back(fanout_idx);
497         }
498       } else {
499         nodes_to_convert.push_back(fanout_idx);
500       }
501     } else if (IsStackOp(fanout_node) || IsStackCloseOp(fanout_node) ||
502                op_types_to_traverse.find(fanout_node.op()) !=
503                    op_types_to_traverse.end()) {
504       continue;
505     } else if (!IsStackPopOp(fanout_node) ||
506                (!graph_view.GetFanout(fanout_idx).empty() ||
507                 nodes_to_preserve.find(fanout_node.name()) !=
508                     nodes_to_preserve.end())) {
509       // The node is either a stack pop with consumers or something unexpected
510       // so we leave the graph alone.
511       nodes_to_convert.clear();
512       break;
513     }
514   }
515 
516   return nodes_to_convert;
517 }
518 
RemoveStackOps(const std::unordered_set<string> & nodes_to_preserve,GraphDef * optimized_graph)519 Status RemoveStackOps(const std::unordered_set<string>& nodes_to_preserve,
520                       GraphDef* optimized_graph) {
521   NodeMap node_map(optimized_graph);
522   GraphTopologyView graph_view;
523   TF_RETURN_IF_ERROR(graph_view.InitializeFromGraph(*optimized_graph));
524 
525   for (int node_idx = 0; node_idx < optimized_graph->node_size(); ++node_idx) {
526     if (IsStackOp(optimized_graph->node(node_idx))) {
527       for (int push_node_idx : GetStackPushNodesToConvert(
528                graph_view, nodes_to_preserve, node_idx)) {
529         // We found push nodes without corresponding pops. Convert them to
530         // Identity passing the data through and add a control dependency from
531         // the op supplying the stack handle.
532         NodeDef* push_node = optimized_graph->mutable_node(push_node_idx);
533         VLOG(1) << "Converting " << push_node_idx << " : "
534                 << push_node->DebugString();
535         if (push_node->attr().count("swap_memory") != 0) {
536           push_node->mutable_attr()->erase("swap_memory");
537         }
538         push_node->set_op("Identity");
539         push_node->mutable_input()->SwapElements(0, 1);
540         const string ctrl_dep = ConstantFolding::AddControlDependency(
541             push_node->input(1), optimized_graph, &node_map);
542         push_node->set_input(1, ctrl_dep);
543         VLOG(1) << "After converting: " << push_node->DebugString();
544       }
545     }
546   }
547   return Status::OK();
548 }
549 
IsSimpleBinaryOperator(const NodeDef & node)550 bool IsSimpleBinaryOperator(const NodeDef& node) {
551   return (IsLess(node) || IsLessEqual(node) || IsGreater(node) ||
552           IsGreaterEqual(node) || IsEqual(node));
553 }
554 
EvaluateBoolOpForConstantOperands(const NodeDef & op_node,const NodeDef & constant_operand_0,const NodeDef & constant_operand_1,DeviceBase * cpu_device,ResourceMgr * resource_mgr,bool * value)555 Status EvaluateBoolOpForConstantOperands(const NodeDef& op_node,
556                                          const NodeDef& constant_operand_0,
557                                          const NodeDef& constant_operand_1,
558                                          DeviceBase* cpu_device,
559                                          ResourceMgr* resource_mgr,
560                                          bool* value) {
561   TensorVector inputs;
562 
563   const TensorProto& raw_val_0 = constant_operand_0.attr().at("value").tensor();
564   Tensor value_0(raw_val_0.dtype(), raw_val_0.tensor_shape());
565   CHECK(value_0.FromProto(raw_val_0));
566   inputs.emplace_back(&value_0);
567   const TensorProto& raw_val_1 = constant_operand_1.attr().at("value").tensor();
568   Tensor value_1(raw_val_1.dtype(), raw_val_1.tensor_shape());
569   CHECK(value_1.FromProto(raw_val_1));
570   inputs.emplace_back(&value_1);
571 
572   TensorVector outputs;
573   TF_RETURN_IF_ERROR(
574       EvaluateNode(op_node, inputs, cpu_device, resource_mgr, &outputs));
575 
576   if (outputs.size() != 1 || outputs[0].tensor == nullptr) {
577     return Status(error::INVALID_ARGUMENT, "Expected one output.");
578   }
579   *value = outputs[0].tensor->scalar<bool>()();
580   delete outputs[0].tensor;
581 
582   return Status::OK();
583 }
584 
585 // TODO(lyandy): Consolidate with ConstantFolding implementation.
IsReallyConstant(const NodeDef & node,const absl::flat_hash_set<string> & feed_nodes)586 bool IsReallyConstant(const NodeDef& node,
587                       const absl::flat_hash_set<string>& feed_nodes) {
588   if (!IsConstant(node)) {
589     return false;
590   }
591   // If the node is fed it's not constant anymore.
592   return feed_nodes.find(node.name()) == feed_nodes.end();
593 }
594 
CheckForDeadFanout(const MutableGraphView & view,const NodeDef & switch_node,const NodeMap & node_map,const absl::flat_hash_set<string> & feed_nodes,DeviceBase * cpu_device,ResourceMgr * resource_mgr,bool * has_dead_fanout,int * dead_fanout)595 Status CheckForDeadFanout(const MutableGraphView& view,
596                           const NodeDef& switch_node, const NodeMap& node_map,
597                           const absl::flat_hash_set<string>& feed_nodes,
598                           DeviceBase* cpu_device, ResourceMgr* resource_mgr,
599                           bool* has_dead_fanout, int* dead_fanout) {
600   *has_dead_fanout = false;
601   GraphView::InputPort switch_loopcond_port(&switch_node, 1);
602   const NodeDef* switch_predicate =
603       view.GetRegularFanin(switch_loopcond_port).node;
604 
605   // CASE 1: Control is a constant.
606   if (IsReallyConstant(*switch_predicate, feed_nodes)) {
607     Tensor selector;
608     CHECK(selector.FromProto(switch_predicate->attr().at("value").tensor()));
609     *has_dead_fanout = true;
610     *dead_fanout = selector.scalar<bool>()() ? 0 : 1;
611   }
612 
613   GraphView::InputPort switch_input_port(&switch_node, 0);
614   const NodeDef* switch_input = view.GetRegularFanin(switch_input_port).node;
615 
616   // CASE 2: Zero-iteration while loop.
617   // We check if its a while loop such that the condition is a simple binary
618   // operator which returns false for the initialization value.
619   // TODO(srjoglekar): Improve to work with arbitrary predicate subgraphs.
620   if (!IsMerge(*switch_input)) {
621     return Status::OK();
622   }
623 
624   // Find the boolean Op from predicate node.
625   NodeDef* switch_ctrl_node = nullptr;
626   for (int i = 0; i < switch_predicate->input().size(); ++i) {
627     NodeDef* node = node_map.GetNode(switch_predicate->input(i));
628     if (IsSimpleBinaryOperator(*node)) {
629       switch_ctrl_node = node;
630     }
631   }
632   if (switch_ctrl_node == nullptr) {
633     return Status::OK();
634   }
635   // Find the Merge node & the Constant Operand to the condition node, if
636   // available.
637   NodeDef* merge_node = nullptr;
638   NodeDef* constant_ctrl_input = nullptr;
639   int constant_index = 0;
640   for (int i = 0; i < switch_ctrl_node->input().size(); ++i) {
641     NodeDef* node = node_map.GetNode(switch_ctrl_node->input(i));
642     if (IsMerge(*node)) {
643       merge_node = node;
644     }
645     if (IsReallyConstant(*node, feed_nodes)) {
646       constant_ctrl_input = node;
647       constant_index = i;
648     }
649   }
650   if (merge_node == nullptr || constant_ctrl_input == nullptr) {
651     return Status::OK();
652   }
653   // Find the initialization constant (via Enter, if one exists).
654   NodeDef* enter_node = nullptr;
655   NodeDef* constant_init_node = nullptr;
656   for (const auto& input : merge_node->input()) {
657     NodeDef* node = node_map.GetNode(input);
658     if (IsEnter(*node)) {
659       enter_node = node;
660     }
661     if (IsReallyConstant(*node, feed_nodes)) {
662       constant_init_node = node;
663     }
664   }
665   if (enter_node != nullptr) {
666     if (constant_init_node != nullptr) return Status::OK();
667     for (const auto& input : enter_node->input()) {
668       NodeDef* node = node_map.GetNode(input);
669       if (IsReallyConstant(*node, feed_nodes)) {
670         constant_init_node = node;
671       }
672     }
673   }
674   if (constant_init_node == nullptr) {
675     return Status::OK();
676   }
677 
678   // Check if there will be 0 iterations. This will only happen if the condition
679   // evaluates to false with respect to the initialization value.
680   NodeDef* operand_0 =
681       constant_index ? constant_init_node : constant_ctrl_input;
682   NodeDef* operand_1 =
683       constant_index ? constant_ctrl_input : constant_init_node;
684   bool constant_switch_value;
685   TF_RETURN_IF_ERROR(EvaluateBoolOpForConstantOperands(
686       *switch_ctrl_node, *operand_0, *operand_1, cpu_device, resource_mgr,
687       &constant_switch_value));
688   if (constant_switch_value == false) {
689     *has_dead_fanout = true;
690     *dead_fanout = 1;
691   }
692   return Status::OK();
693 }
694 
695 }  // namespace
696 
LoopOptimizer()697 LoopOptimizer::LoopOptimizer()
698     : opt_level_(RewriterConfig::ON),
699       cpu_device_(nullptr),
700       options_(LoopOptimizerOptions::Default(RewriterConfig::ON)) {}
701 
LoopOptimizer(RewriterConfig::Toggle opt_level,DeviceBase * cpu_device)702 LoopOptimizer::LoopOptimizer(RewriterConfig::Toggle opt_level,
703                              DeviceBase* cpu_device)
704     : opt_level_(opt_level),
705       cpu_device_(cpu_device),
706       options_(LoopOptimizerOptions::Default(RewriterConfig::ON)) {
707   resource_mgr_.reset(new ResourceMgr());
708 }
709 
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph)710 Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
711                                GraphDef* optimized_graph) {
712   *optimized_graph = item.graph;
713   // Set up helper data structures.
714   if (options_.enable_loop_invariant_node_motion) {
715     LoopInvariantNodeMotionOptimizer linm_optimizer(optimized_graph);
716     TF_RETURN_IF_ERROR(linm_optimizer.Optimize());
717   }
718   if (options_.enable_stack_push_removal) {
719     TF_RETURN_IF_ERROR(RemoveStackOps(item.NodesToPreserve(), optimized_graph));
720   }
721   if (options_.enable_dead_branch_removal) {
722     // TODO(srjoglekar): Figure out if we can optimize NodeMap creations across
723     // optimizer passes.
724     NodeMap node_map(optimized_graph);
725     absl::flat_hash_set<string> feed_nodes;
726     for (const auto& feed : item.feed) {
727       feed_nodes.insert(NodeName(feed.first));
728     }
729     TF_RETURN_IF_ERROR(RemoveDeadBranches(item.NodesToPreserve(), node_map,
730                                           feed_nodes, optimized_graph));
731   }
732 
733   return Status::OK();
734 }
735 
RemoveDeadBranches(const std::unordered_set<string> & nodes_to_preserve,const NodeMap & node_map,const absl::flat_hash_set<string> & feed_nodes,GraphDef * optimized_graph)736 Status LoopOptimizer::RemoveDeadBranches(
737     const std::unordered_set<string>& nodes_to_preserve,
738     const NodeMap& node_map, const absl::flat_hash_set<string>& feed_nodes,
739     GraphDef* optimized_graph) {
740   std::unordered_set<const NodeDef*> dead_nodes;
741   std::unordered_map<NodeDef*, std::set<int>> dead_merge_inputs;
742   // TODO(bsteiner): also rewrite switches as identity. For now we just record
743   // them
744   absl::flat_hash_set<GraphView::OutputPort> identity_switches;
745 
746   MutableGraphView view(optimized_graph);
747   for (const NodeDef& node : optimized_graph->node()) {
748     if (!IsSwitch(node)) {
749       continue;
750     }
751     if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) {
752       continue;
753     }
754 
755     int dead_fanout;
756     bool has_dead_fanout;
757     TF_RETURN_IF_ERROR(CheckForDeadFanout(view, node, node_map, feed_nodes,
758                                           cpu_device_, resource_mgr_.get(),
759                                           &has_dead_fanout, &dead_fanout));
760     if (!has_dead_fanout) {
761       continue;
762     }
763     GraphView::OutputPort dead(&node, dead_fanout);
764     identity_switches.insert(dead);
765 
766     SetVector<MutableGraphView::InputPort, absl::Hash<MutableGraphView::Port>>
767         zombie_inputs;
768     for (const MutableGraphView::InputPort& port : view.GetFanout(dead)) {
769       if (dead_nodes.find(port.node) == dead_nodes.end()) {
770         zombie_inputs.PushBack(port);
771       }
772     }
773     // If we encounter a single node that must be preserved in the fanout of the
774     // switch node we need to preserve the entire switch fanout: we therefore
775     // work on a local copy that only gets committed to the master copy once the
776     // whole fanout has been explored.
777     std::unordered_set<const NodeDef*> local_dead_nodes = dead_nodes;
778     std::unordered_map<NodeDef*, std::set<int>> local_dead_merge_inputs =
779         dead_merge_inputs;
780     bool found_node_to_preserve = false;
781     while (!found_node_to_preserve && !zombie_inputs.Empty()) {
782       MutableGraphView::InputPort dead = zombie_inputs.PopBack();
783       if (nodes_to_preserve.find(dead.node->name()) !=
784           nodes_to_preserve.end()) {
785         found_node_to_preserve = true;
786         break;
787       }
788 
789       if (local_dead_nodes.find(dead.node) != local_dead_nodes.end()) {
790         continue;
791       }
792 
793       if (IsMerge(*dead.node)) {
794         const int num_data_inputs = dead.node->attr().at("N").i();
795         if (num_data_inputs > 2) {
796           // This never happens in practice, so we'll just skip these to
797           // simplify the code for now.
798           found_node_to_preserve = true;
799           break;
800         }
801         MutableGraphView::OutputPort value_index(dead.node, 1);
802         const absl::flat_hash_set<MutableGraphView::InputPort>& index_fanout =
803             view.GetFanout(value_index);
804         if (!index_fanout.empty()) {
805           // The 2nd output (that indicates which input is propagated) is
806           // connected. This never happens in practice, so we'll just skip this
807           // case to simplify the code for now.
808           found_node_to_preserve = true;
809           break;
810         }
811 
812         bool fully_dead = false;
813         // Merge node can become real dead only if all data inputs are dead.
814         // Merge always waits for all control edges, but they do not
815         // change the node deadness.
816         if (dead.port_id >= 0) {
817           local_dead_merge_inputs[dead.node].insert(dead.port_id);
818           if (local_dead_merge_inputs[dead.node].size() == num_data_inputs) {
819             fully_dead = true;
820           }
821         } else {
822           // Keep track of all Merge nodes, even if they do not have dead data
823           // inputs. We'll need to cleanup dead control edges for them later.
824           local_dead_merge_inputs.insert({dead.node, {}});
825         }
826         if (fully_dead) {
827           local_dead_merge_inputs.erase(dead.node);
828           local_dead_nodes.insert(dead.node);
829           for (const MutableGraphView::InputPort& port :
830                view.GetFanouts(*dead.node, true)) {
831             zombie_inputs.PushBack(port);
832           }
833         }
834       } else if (dead.node->op() == "ControlTrigger") {
835         // Control trigger have different semantic, so don't touch them
836         found_node_to_preserve = true;
837         break;
838       } else {
839         if (local_dead_nodes.insert(dead.node).second) {
840           for (const MutableGraphView::InputPort& dead_fanout :
841                view.GetFanouts(*dead.node, true)) {
842             zombie_inputs.PushBack(dead_fanout);
843           }
844         }
845       }
846     }
847     if (!found_node_to_preserve) {
848       std::swap(dead_nodes, local_dead_nodes);
849       std::swap(dead_merge_inputs, local_dead_merge_inputs);
850     }
851   }
852 
853   std::vector<int> nodes_idx_to_delete;
854   nodes_idx_to_delete.reserve(dead_nodes.size());
855   for (int i = 0; i < optimized_graph->node_size(); ++i) {
856     if (dead_nodes.count(&optimized_graph->node(i)))
857       nodes_idx_to_delete.push_back(i);
858   }
859 
860   // Names of the nodes that were removed from the graph.
861   absl::flat_hash_set<absl::string_view> dead_node_names;
862   dead_node_names.reserve(dead_nodes.size());
863   for (const NodeDef* dead_node : dead_nodes)
864     dead_node_names.insert(dead_node->name());
865 
866   // Remove dead inputs from Merge nodes that were not pruned from the graph.
867   for (const auto& itr : dead_merge_inputs) {
868     NodeDef* dead_node = itr.first;
869     if (dead_nodes.find(dead_node) != dead_nodes.end()) {
870       // The node has been pruned since all its inputs are dead.
871       continue;
872     }
873     // Remove dead data input.
874     const std::set<int>& dead_inputs = itr.second;
875     for (int index : dead_inputs) {
876       dead_node->mutable_input()->DeleteSubrange(index, 1);
877     }
878     // Turn Merge into Identity only if we deleted data inputs.
879     if (!dead_inputs.empty()) {
880       dead_node->set_op("Identity");
881       dead_node->mutable_attr()->erase("N");
882     }
883     // Remove control inputs from dead nodes.
884     int pos = 0;
885     while (pos < dead_node->input_size()) {
886       TensorId tensor = ParseTensorName(dead_node->input(pos));
887       if (tensor.index() == Graph::kControlSlot &&
888           dead_node_names.contains(tensor.node())) {
889         auto* inputs = dead_node->mutable_input();
890         inputs->SwapElements(pos, dead_node->input_size() - 1);
891         inputs->RemoveLast();
892       } else {
893         ++pos;
894       }
895     }
896   }
897 
898   EraseNodesFromGraph(std::move(nodes_idx_to_delete), optimized_graph);
899 
900   return Status::OK();
901 }
902 
Feedback(Cluster *,const GrapplerItem &,const GraphDef &,double)903 void LoopOptimizer::Feedback(Cluster* /*cluster*/, const GrapplerItem& /*item*/,
904                              const GraphDef& /*optimized_graph*/,
905                              double /*result*/) {
906   // Nothing to do for LoopOptimizer.
907 }
908 
909 }  // end namespace grappler
910 }  // end namespace tensorflow
911