1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
17 
18 #include <algorithm>
19 #include <deque>
20 #include <limits>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <vector>
24 
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "tensorflow/core/framework/attr_value.pb.h"
28 #include "tensorflow/core/framework/attr_value_util.h"
29 #include "tensorflow/core/framework/node_def.pb.h"
30 #include "tensorflow/core/framework/node_def_util.h"
31 #include "tensorflow/core/framework/op.h"
32 #include "tensorflow/core/framework/tensor.pb.h"
33 #include "tensorflow/core/framework/tensor_shape.pb.h"
34 #include "tensorflow/core/framework/types.h"
35 #include "tensorflow/core/grappler/costs/graph_properties.h"
36 #include "tensorflow/core/grappler/graph_topology_view.h"
37 #include "tensorflow/core/grappler/grappler_item.h"
38 #include "tensorflow/core/grappler/op_types.h"
39 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
40 #include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h"
41 #include "tensorflow/core/grappler/utils.h"
42 #include "tensorflow/core/grappler/utils/symbolic_shapes.h"
43 #include "tensorflow/core/grappler/utils/topological_sort.h"
44 #include "tensorflow/core/grappler/utils/traversal.h"
45 #include "tensorflow/core/lib/core/errors.h"
46 #include "tensorflow/core/lib/core/stringpiece.h"
47 #include "tensorflow/core/lib/hash/hash.h"
48 #include "tensorflow/core/lib/strings/str_util.h"
49 #include "tensorflow/core/lib/strings/strcat.h"
50 #include "tensorflow/core/platform/tensor_coding.h"
51 #include "tensorflow/core/util/device_name_utils.h"
52 #include "tensorflow/core/util/saved_tensor_slice_util.h"
53 #include "tensorflow/core/util/strided_slice_op.h"
54 
55 using tensorflow::str_util::StringReplace;
56 using tensorflow::strings::StrCat;
57 
58 namespace tensorflow {
59 namespace grappler {
60 namespace {
61 
// Node attribute names used as tags. A stage adds its tag to the nodes it
// creates or optimizes, so that those nodes can be recognized later.

// Tag owned by AddOpsRewriteStage.
constexpr char kAddOpsRewriteTag[] =
    "_grappler:ArithmeticOptimizer:AddOpsRewriteStage";

// Tag owned by the MinimizeBroadcasts stage.
constexpr char kMinimizeBroadcastsTag[] =
    "_grappler:ArithmeticOptimizer:MinimizeBroadcasts";
67 
68 // Extract values from a Const op to `values`. Returns true if succeeds.
69 template <typename T>
ValuesFromConstNode(const NodeDef & node,std::vector<T> * values)70 bool ValuesFromConstNode(const NodeDef& node, std::vector<T>* values) {
71   if (node.op() != "Const") {
72     return false;
73   }
74 
75   if (node.attr().count("dtype") == 0 || node.attr().count("value") == 0 ||
76       node.attr().at("dtype").type() != DataTypeToEnum<T>::value) {
77     return false;
78   }
79 
80   // TensorProto represents the content of the tensor in either <type>_val or
81   // tensor_content.
82   const TensorProto& tensor = node.attr().at("value").tensor();
83   typename checkpoint::SaveTypeTraits<T>::RepeatedField* tensor_values =
84       checkpoint::MutableTensorProtoData<T>(const_cast<TensorProto*>(&tensor));
85 
86   if (!tensor_values->empty() && tensor.has_tensor_shape()) {
87     // When tensor_shape is set, theoretically the representation of the data
88     // could be compressed. So, before copying values to the returned vector,
89     // make sure no compression happens.
90     const TensorShapeProto& shape = tensor.tensor_shape();
91     if (shape.dim_size() == 1 && shape.dim(0).size() == tensor_values->size()) {
92       values->insert(values->end(), tensor_values->begin(),
93                      tensor_values->end());
94       return true;
95     }
96   }
97 
98   const auto tensor_content_size = tensor.tensor_content().size();
99   if (tensor_content_size > 0) {
100     CHECK_EQ(0, tensor_content_size % sizeof(T))
101         << "tensor_content_size (" << tensor_content_size
102         << ") is not a multiple of " << sizeof(T);
103     values->resize(tensor_content_size / sizeof(T));
104     port::CopyToArray(tensor.tensor_content(),
105                       reinterpret_cast<char*>(values->data()));
106     return true;
107   }
108 
109   return false;
110 }
111 
MaybeAddControlInput(const string & new_input,NodeDef * node,GraphDef * graph,NodeMap * node_map)112 bool MaybeAddControlInput(const string& new_input, NodeDef* node,
113                           GraphDef* graph, NodeMap* node_map) {
114   bool already_exists = false;
115   for (const string& input : node->input()) {
116     if (input == new_input || AsControlDependency(input) == new_input) {
117       already_exists = true;
118       break;
119     }
120   }
121   if (!already_exists) {
122     const string ctrl_dep =
123         ConstantFolding::AddControlDependency(new_input, graph, node_map);
124     node->add_input(ctrl_dep);
125     node_map->AddOutput(NodeName(new_input), node->name());
126   }
127   return !already_exists;
128 }
129 
SetDataTypeToAttr(DataType dtype,const string & attr_name,NodeDef * node)130 void SetDataTypeToAttr(DataType dtype, const string& attr_name, NodeDef* node) {
131   (*node->mutable_attr())[attr_name].set_type(dtype);
132 }
133 
GetTailOfValuePreservingChain(const NodeDef & node,const NodeMap & node_map,const std::unordered_set<string> & nodes_to_preserve)134 NodeDef* GetTailOfValuePreservingChain(
135     const NodeDef& node, const NodeMap& node_map,
136     const std::unordered_set<string>& nodes_to_preserve) {
137   auto is_value_preserving_non_branching = [&](const NodeDef& node) {
138     return nodes_to_preserve.find(node.name()) == nodes_to_preserve.end() &&
139            IsValuePreserving(node) && NumNonControlOutputs(node, node_map) == 1;
140   };
141   return GetTailOfChain(node, node_map, /*follow_control_input=*/false,
142                         is_value_preserving_non_branching);
143 }
144 
GetTailOfIdempotentChain(const NodeDef & node,const NodeMap & node_map,const std::unordered_set<string> & nodes_to_preserve)145 NodeDef* GetTailOfIdempotentChain(
146     const NodeDef& node, const NodeMap& node_map,
147     const std::unordered_set<string>& nodes_to_preserve) {
148   auto is_idempotent_non_branching = [&](const NodeDef& node) {
149     return nodes_to_preserve.find(node.name()) == nodes_to_preserve.end() &&
150            IsIdempotent(node) && NumNonControlOutputs(node, node_map) == 1;
151   };
152   return GetTailOfChain(node, node_map, /*follow_control_input=*/false,
153                         is_idempotent_non_branching);
154 }
155 
156 // GetElementUnexhaustive tries to get the value of an element in a tensor and
157 // turn it into complex128 type. It only check for a limited number of data
158 // types, so it's unexhaustive.
GetElementUnexhaustive(const Tensor & t,int i,const std::set<int> & dtypes,complex128 * element)159 bool GetElementUnexhaustive(const Tensor& t, int i, const std::set<int>& dtypes,
160                             complex128* element) {
161   if (dtypes.find(t.dtype()) == dtypes.end()) return false;
162   switch (t.dtype()) {
163     case DT_BFLOAT16:
164       *element = complex128(t.flat<bfloat16>()(i));
165       return true;
166     case DT_HALF:
167       *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
168       return true;
169     case DT_INT32:
170       *element = complex128(t.flat<int32>()(i));
171       return true;
172     case DT_INT64:
173       *element = complex128(t.flat<int64>()(i));
174       return true;
175     case DT_FLOAT:
176       *element = complex128(t.flat<float>()(i));
177       return true;
178     case DT_DOUBLE:
179       *element = complex128(t.flat<double>()(i));
180       return true;
181     case DT_COMPLEX64:
182       *element = complex128(t.flat<complex64>()(i));
183       return true;
184     case DT_COMPLEX128:
185       *element = t.flat<complex128>()(i);
186       return true;
187     default:
188       return false;
189   }
190 }
191 
192 // Graph optimizer context extension specific to ArithmeticOptimizer.
193 struct ArithmeticOptimizerContext {
ArithmeticOptimizerContexttensorflow::grappler::__anon327bfa1e0111::ArithmeticOptimizerContext194   explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify)
195       : nodes_to_simplify(nodes_to_simplify) {}
196   SetVector<NodeDef*>* nodes_to_simplify;
197 };
198 
199 // Base class for single arithmetic optimization: e.g. Bitcast optimization,
200 // AddOps optimization, etc...
201 class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
202  public:
ArithmeticOptimizerStage(const string & name,const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext ctx_ext)203   explicit ArithmeticOptimizerStage(const string& name,
204                                     const GraphOptimizerContext& ctx,
205                                     const ArithmeticOptimizerContext ctx_ext)
206       : GraphOptimizerStage("ArithmeticOptimizer", name, ctx),
207         ctx_ext_(ctx_ext) {}
208   ~ArithmeticOptimizerStage() override = default;
209 
210  protected:
211   // Simplification graph rewrite can create additional nodes that are inputs
212   // to final simplified node, they can be also added to the arithmetic
213   // optimizer queue for further optimization.
AddToOptimizationQueue(NodeDef * node)214   void AddToOptimizationQueue(NodeDef* node) {
215     ctx_ext_.nodes_to_simplify->PushBack(node);
216   }
217 
218   // TODO(ezhulenev): remove this method from ArithmeticOptimizer when all
219   // optimizations will be migrated to stages
ForwardControlDependencies(NodeDef * target_node,const std::vector<const NodeDef * > & src_nodes)220   void ForwardControlDependencies(
221       NodeDef* target_node, const std::vector<const NodeDef*>& src_nodes) {
222     for (const auto& src : src_nodes) {
223       for (int i = src->input_size() - 1; i >= 0; --i) {
224         if (IsControlInput(src->input(i))) {
225           *target_node->add_input() = src->input(i);
226           ctx().node_map->AddOutput(NodeName(src->input(i)),
227                                     target_node->name());
228         } else {
229           break;
230         }
231       }
232     }
233     DedupControlInputs(target_node);
234   }
235 
IsInPreserveSet(const NodeDef & node) const236   bool IsInPreserveSet(const NodeDef& node) const {
237     return ctx().nodes_to_preserve->find(node.name()) !=
238            ctx().nodes_to_preserve->end();
239   }
240 
241   // TODO(ezhulenev): move to GraphOptimizerStage?
IsDrivenByControlDependency(const NodeDef & node) const242   bool IsDrivenByControlDependency(const NodeDef& node) const {
243     return std::any_of(
244         node.input().begin(), node.input().end(),
245         [](const string& input) { return IsControlInput(input); });
246   }
247 
248   // TODO(ezhulenev): move to GraphOptimizerStage?
DrivesControlDependency(const NodeDef & node) const249   bool DrivesControlDependency(const NodeDef& node) const {
250     for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) {
251       for (int i = 0; i < output->input_size(); ++i) {
252         const TensorId tensor = ParseTensorName(output->input(i));
253         if (tensor.node() == node.name() && tensor.index() < 0) {
254           return true;
255         }
256       }
257     }
258     return false;
259   }
260 
261  private:
262   // Extended context required for ArithmeticOptimizer.
263   const ArithmeticOptimizerContext ctx_ext_;
264 };
265 
266 // Subtype of ArithmeticOptimizerStage that does optimization by rewriting a
267 // group of nodes from the optimized graph.
268 //
269 // * AddOpsRewrite:
270 //   Rewrite a group of Add/AddN with compact Add/AddN tree
271 //
272 // * MinimizeBroadcasts:
273 //   Rewrite a group of binary associative ops, reordering
274 //   inputs, to minimize the cost of broadcast
275 class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage {
276  public:
ArithmeticNodesGroupOptimizerStage(const string & name,const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext ctx_ext)277   explicit ArithmeticNodesGroupOptimizerStage(
278       const string& name, const GraphOptimizerContext& ctx,
279       const ArithmeticOptimizerContext ctx_ext)
280       : ArithmeticOptimizerStage(name, ctx, ctx_ext) {}
281   ~ArithmeticNodesGroupOptimizerStage() override = default;
282 
283   // Input name with a statically inferred shape from GraphProperties
284   struct InputAndShape {
InputAndShapetensorflow::grappler::__anon327bfa1e0111::ArithmeticNodesGroupOptimizerStage::InputAndShape285     InputAndShape(const string& input, const TensorShapeProto& shape)
286         : input(input), shape(shape) {}
287     string input;
288     TensorShapeProto shape;
289   };
290 
291   // Subgraph (subtree) of nodes, that we want to optimize in "one shot" (e.g.
292   // all the Add nodes that we plan to rewrite with a single AddN). Subgraph is
293   // obtained by graph traversal, starting from a root node.
294   struct OptimizedNodesGroup {
295     NodeDef* root_node;
296     TensorShapeProto root_shape;
297     // Optimized nodes that will be updated or removed by rewrite
298     std::vector<NodeDef*> optimized_nodes;
299     // Inputs to optimized nodes
300     std::vector<InputAndShape> inputs;
301   };
302 
TrySimplify(NodeDef * node,string * simplified_node_name)303   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
304     TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
305 
306     OptimizedNodesGroup group;
307     TF_RETURN_IF_ERROR(CreateOptimizedNodesGroup(node, &group));
308 
309     if (!group.optimized_nodes.empty()) {
310       *simplified_node_name = RewriteOptimizedNodesGroup(group);
311     }
312 
313     return Status::OK();
314   }
315 
316  protected:
317   // Modify the optimized graph after nodes group was successfully identified
318   virtual string RewriteOptimizedNodesGroup(
319       const OptimizedNodesGroup& group) = 0;
320 
321   // Check if input can become a part of current optimized nodes group.
322   virtual bool IsAbsorbableByOptimizedNodesGroup(
323       const OptimizedNodesGroup& group, const NodeDef& node) const = 0;
324 
AbsorbInputByOptimizedNodesGroup(const string & input,OptimizedNodesGroup * group) const325   Status AbsorbInputByOptimizedNodesGroup(const string& input,
326                                           OptimizedNodesGroup* group) const {
327     std::deque<const string*> input_tensors;
328     input_tensors.push_front(&input);
329 
330     while (!input_tensors.empty()) {
331       const string* input_tensor = input_tensors.front();
332       input_tensors.pop_front();
333 
334       // Get a node for the input tensor.
335       NodeDef* input_node;
336       TF_RETURN_IF_ERROR(GetInputNode(*input_tensor, &input_node));
337 
338       if (IsAbsorbableByOptimizedNodesGroup(*group, *input_node)) {
339         group->optimized_nodes.push_back(input_node);
340         for (int i = input_node->input_size() - 1; i >= 0; --i) {
341           const string& absorbed_node_input = input_node->input(i);
342           // TODO(ezhulenev): support control inputs
343           if (IsControlInput(absorbed_node_input)) continue;
344           input_tensors.push_front(&absorbed_node_input);
345         }
346       } else {
347         // If input node can't be absorbed, add it to OptimizedNodesGroup input.
348         OpInfo::TensorProperties properties;
349         TF_RETURN_IF_ERROR(GetTensorProperties(*input_tensor, &properties));
350         group->inputs.emplace_back(*input_tensor, properties.shape());
351       }
352     }
353 
354     return Status::OK();
355   }
356 
CreateOptimizedNodesGroup(NodeDef * root_node,OptimizedNodesGroup * group) const357   Status CreateOptimizedNodesGroup(NodeDef* root_node,
358                                    OptimizedNodesGroup* group) const {
359     OpInfo::TensorProperties root_node_output_properties;
360     TF_RETURN_IF_ERROR(
361         GetTensorProperties(root_node->name(), &root_node_output_properties));
362 
363     group->root_node = root_node;
364     group->root_shape = root_node_output_properties.shape();
365 
366     group->optimized_nodes.reserve(root_node->input_size());
367     for (int i = 0; i < root_node->input_size(); ++i) {
368       const string& input_i = root_node->input(i);
369       // TODO(ezhulenev): add support for control inputs
370       if (IsControlInput(input_i)) continue;
371       TF_RETURN_IF_ERROR(AbsorbInputByOptimizedNodesGroup(input_i, group));
372     }
373 
374     return Status::OK();
375   }
376 
377   // Check if all inputs can be broadcasted to the same shape
378   // TODO(ezhulenev): move to GraphOptimizerStage?
HasAllInputsBroadcastableToShape(const NodeDef & node,const OpInfo::TensorProperties & properties) const379   bool HasAllInputsBroadcastableToShape(
380       const NodeDef& node, const OpInfo::TensorProperties& properties) const {
381     auto is_broadcastable = [this, &properties](const string& input) {
382       OpInfo::TensorProperties input_props;
383       Status has_input_properties = GetTensorProperties(input, &input_props);
384       return has_input_properties.ok() &&
385              ShapesBroadcastable(properties, input_props);
386     };
387     return std::all_of(node.input().begin(), node.input().end(),
388                        is_broadcastable);
389   }
390 
ShapeSignature(const TensorShapeProto & shape) const391   string ShapeSignature(const TensorShapeProto& shape) const {
392     string signature = strings::StrCat("rank:", shape.dim_size(), ":dim");
393     for (int i = 0; i < shape.dim_size(); ++i)
394       strings::StrAppend(&signature, ":", shape.dim(i).size());
395     return signature;
396   }
397 
MarkWithTag(const StringPiece tag,NodeDef * node)398   void MarkWithTag(const StringPiece tag, NodeDef* node) {
399     AddNodeAttr(tag, true, node);
400   }
401 
MarkAllMembersWithTag(const OptimizedNodesGroup & group,const StringPiece tag) const402   void MarkAllMembersWithTag(const OptimizedNodesGroup& group,
403                              const StringPiece tag) const {
404     AddNodeAttr(tag, true, group.root_node);
405     for (NodeDef* optimized_node : group.optimized_nodes) {
406       AddNodeAttr(tag, true, optimized_node);
407     }
408   }
409 
IsOnTheSameDevice(const OptimizedNodesGroup & group,const NodeDef & node) const410   bool IsOnTheSameDevice(const OptimizedNodesGroup& group,
411                          const NodeDef& node) const {
412     return group.root_node->device() == node.device();
413   }
414 
IsInPreserveSet(const NodeDef & node) const415   bool IsInPreserveSet(const NodeDef& node) const {
416     return ctx().nodes_to_preserve->find(node.name()) !=
417            ctx().nodes_to_preserve->end();
418   }
419 
IsMarkedWithTag(const NodeDef & node,const StringPiece tag) const420   bool IsMarkedWithTag(const NodeDef& node, const StringPiece tag) const {
421     return HasNodeAttr(node, tag);
422   }
423 
IsMarkedWithAnyTag(const NodeDef & node,const StringPiece tag1,const StringPiece tag2) const424   bool IsMarkedWithAnyTag(const NodeDef& node, const StringPiece tag1,
425                           const StringPiece tag2) const {
426     return IsMarkedWithTag(node, tag1) || IsMarkedWithTag(node, tag2);
427   }
428 };
429 
430 // Rewrite a tree of Add/AddN with a single AddN operation, consuming all the
431 // original inputs of absorbed nodes.
432 //
433 // 1) All nodes must have the same device placement.
434 //
435 // 2) If All nodes in a Add/AddN subgraph have symbolically equal shape, tree is
436 //    optimized to a single AddN node.
437 //
438 //                AddN_1
439 //             /    |    \
440 //          Add_1   z   Add_2       -> AddN(x, y, z, w, q, e)
441 //          /  \        /  \
442 //         x    y      w    Add_3
443 //                          / \
444 //                         q   e
445 //
446 // 3) If some nodes have different shape (it needs to be broadcastable to the
447 //    shape of a "root), tree is optimized to AddNs for symbolically equal
448 //    shapes, and a tree of Add ops, that minimize broadcasts.
449 //
450 //                AddN_1                                 Add
451 //             /    |    \                              /  \
452 //          Add_1   z   Add_2       ->               Add    w
453 //          /  \        /  \                        /   \
454 //         x    y      w    Add_3      AddN(x, y, q, e)  z
455 //                          / \
456 //                         q   e
457 class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage {
458  public:
AddOpsRewriteStage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)459   explicit AddOpsRewriteStage(const GraphOptimizerContext& ctx,
460                               const ArithmeticOptimizerContext& ctx_ext)
461       : ArithmeticNodesGroupOptimizerStage("AddOpsRewrite", ctx, ctx_ext) {}
462   ~AddOpsRewriteStage() override = default;
463 
464   // Check if a node can become a root of AddOpsGroup
IsSupported(const NodeDef * node) const465   bool IsSupported(const NodeDef* node) const override {
466     if (!CanOptimize(*node)) return false;
467 
468     // shape must be symbolically defined and all inputs compatible with it
469     OpInfo::TensorProperties properties;
470     Status has_properties = GetTensorProperties(node->name(), &properties);
471     return has_properties.ok() && ShapeIsSymbolicallyDefined(properties) &&
472            HasAllInputsBroadcastableToShape(*node, properties);
473   }
474 
475  protected:
476   // Check if a node can be absorbed by current OptimizedNodesGroup
IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup & group,const NodeDef & node) const477   bool IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup& group,
478                                          const NodeDef& node) const override {
479     if (!CanOptimize(node)) return false;
480 
481     if (!IsOnTheSameDevice(group, node)) {
482       return false;
483     }
484     // with a single output data consumer (presumably if we reach this node from
485     // previously absorbed or a root node, it means that this node is not used
486     // as an input to any other op, outside of the group)
487     if (NumNonControlDataOutputs(node, *ctx().node_map) != 1) {
488       return false;
489     }
490     // All input shapes must be broadcastable to the node shape
491     OpInfo::TensorProperties properties;
492     Status has_properties = GetTensorProperties(node.name(), &properties);
493     return has_properties.ok() &&
494            HasAllInputsBroadcastableToShape(node, properties);
495   }
496 
497   // Node requirements both for a root node and an absorbed node
CanOptimize(const NodeDef & node) const498   bool CanOptimize(const NodeDef& node) const {
499     // TODO(ezhulenev): check if AccumulateNV2 can be supported too
500     if (!IsAdd(node) && !IsAddN(node)) {
501       return false;
502     }
503     if (IsInPreserveSet(node) || IsMarkedWithTag(node, kAddOpsRewriteTag)) {
504       return false;
505     }
506     // TODO(ezhulenev): relax this condition for root node
507     return !(IsDrivenByControlDependency(node) ||
508              DrivesControlDependency(node));
509   }
510 
511   // Rewrite a group of add ops into a single AddN if all input shapes are
512   // symbolically equal. If not, create AddN for equal shapes first, and then
513   // build an Add tree, minimizing the cost of broadcasts.
RewriteOptimizedNodesGroup(const OptimizedNodesGroup & group)514   string RewriteOptimizedNodesGroup(const OptimizedNodesGroup& group) override {
515     VLOG(2) << "Collapse Add/AddN: root=" << group.root_node->name()
516             << " op=" << group.root_node->op()
517             << " num_optimized_nodes=" << group.optimized_nodes.size()
518             << " num_inputs=" << group.inputs.size();
519 
520     // Do not optimize any of the nodes that are part of this group.
521     MarkAllMembersWithTag(group, kAddOpsRewriteTag);
522 
523     // All new nodes will be placed under the scope of a root node.
524     auto root_scope_and_name = ParseNodeScopeAndName(group.root_node->name());
525 
526     // Find what shapes are present in the inputs of absorbed nodes.
527     std::unordered_map<string, std::vector<InputAndShape>> shape_sig_to_inputs;
528     for (const auto& input : group.inputs) {
529       shape_sig_to_inputs[ShapeSignature(input.shape)].push_back(input);
530     }
531 
532     using SigKV = decltype(shape_sig_to_inputs)::value_type;
533     VLOG(3) << "Add/AddN group has " << shape_sig_to_inputs.size()
534             << " unique shapes: "
535             << str_util::Join(shape_sig_to_inputs, ", ",
536                               [](string* out, SigKV p) {
537                                 strings::StrAppend(out, p.first);
538                               });
539 
540     // Collect all the shapes from representative elements.
541     std::vector<TensorShapeProto> shapes;
542     shapes.reserve(shape_sig_to_inputs.size());
543     for (const auto& el : shape_sig_to_inputs)
544       shapes.push_back(el.second[0].shape);
545 
546     // If all inputs have the same shape, rewrite whole group with a single AddN
547     if (shapes.size() == 1) {
548       string node_name = UniqueOptimizedNodeName(root_scope_and_name);
549       AddInputsOfSymbolicallyEqualShape(*group.root_node, node_name,
550                                         group.inputs);
551       return node_name;
552     }
553 
554     // For inputs of different shapes:
555     // 1. Rewrite inputs of the same shape using AddN (leaf nodes)
556     // 2. Build a tree of Add nodes, minimizing cost of broadcast
557     std::sort(shapes.begin(), shapes.end(),
558               [](const TensorShapeProto& left, const TensorShapeProto& right) {
559                 return CompareSymbolicallyShapedTensorSizes(left, right);
560               });
561 
562     // optimized name for leaf AddN nodes
563     auto leaf_node_name = [&root_scope_and_name, this](int i) {
564       return UniqueOptimizedNodeName(root_scope_and_name,
565                                      strings::StrCat("Leaf_", i));
566     };
567     // optimized name for internal nodes of a tree built up from AddN leaves
568     auto internal_node_name = [&root_scope_and_name, this](int i) {
569       return UniqueOptimizedNodeName(root_scope_and_name,
570                                      strings::StrCat("Internal_", i));
571     };
572 
573     // Add/AddN nodes that must be added to the tree
574     std::deque<InputAndShape> add_ops;
575 
576     // Prepare leaf AddN nodes for inputs of equal shape
577     for (int i = 0; i < shapes.size(); ++i) {
578       const auto node_name = leaf_node_name(i);
579       const auto& inputs = shape_sig_to_inputs[ShapeSignature(shapes[i])];
580       add_ops.push_back(AddInputsOfSymbolicallyEqualShape(*group.root_node,
581                                                           node_name, inputs));
582     }
583 
584     // Build up a tree of Add ops
585     int internal_nodes = 0;
586     do {
587       const InputAndShape lhs = add_ops.front();
588       add_ops.pop_front();
589       const InputAndShape rhs = add_ops.front();
590       add_ops.pop_front();
591       string name = add_ops.empty()
592                         ? UniqueOptimizedNodeName(root_scope_and_name)
593                         : internal_node_name(internal_nodes++);
594       InputAndShape add = AddAggregatedInputs(*group.root_node, name, lhs, rhs);
595       add_ops.push_front(add);
596     } while (add_ops.size() > 1);
597 
598     InputAndShape optimized_root_node = add_ops.front();
599     return optimized_root_node.input;
600   }
601 
602   // Add 'AddN' node to aggregate inputs of symbolically equal shape
AddInputsOfSymbolicallyEqualShape(const NodeDef & root_node,const string & node_name,const std::vector<InputAndShape> & inputs)603   InputAndShape AddInputsOfSymbolicallyEqualShape(
604       const NodeDef& root_node, const string& node_name,
605       const std::vector<InputAndShape>& inputs) {
606     CHECK(!inputs.empty()) << "Inputs must be non-empty";
607 
608     // Do not create redundant AddN nodes
609     if (inputs.size() == 1 || root_node.attr().count("T") == 0) {
610       return inputs[0];
611     }
612 
613     // get shape from representative element
614     auto shape = inputs[0].shape;
615 
616     // copy attributes from a root node
617     DataType dtype = root_node.attr().at("T").type();
618 
619     // add new AddN node
620     NodeDef* node = AddEmptyNode(node_name);
621     node->set_op("AddN");
622     node->set_device(root_node.device());
623     (*node->mutable_attr())["T"].set_type(dtype);
624     (*node->mutable_attr())["N"].set_i(inputs.size());
625 
626     for (const auto& inputAndShape : inputs) {
627       ctx().node_map->AddOutput(inputAndShape.input, node_name);
628       node->add_input(inputAndShape.input);
629     }
630 
631     MarkWithTag(kAddOpsRewriteTag, node);
632     return InputAndShape(node_name, shape);
633   }
634 
635   // Add a single 'Add' node to sum two inputs
AddAggregatedInputs(const NodeDef & root_node,const string & node_name,const InputAndShape & left,const InputAndShape & right)636   InputAndShape AddAggregatedInputs(const NodeDef& root_node,
637                                     const string& node_name,
638                                     const InputAndShape& left,
639                                     const InputAndShape& right) {
640     // copy attributes from a root node
641     DataType dtype = root_node.attr().at("T").type();
642 
643     // add new Add node
644     NodeDef* node = AddEmptyNode(node_name);
645     node->set_op("Add");
646     node->set_device(root_node.device());
647     (*node->mutable_attr())["T"].set_type(dtype);
648     node->add_input(left.input);
649     node->add_input(right.input);
650 
651     ctx().node_map->AddOutput(left.input, node_name);
652     ctx().node_map->AddOutput(right.input, node_name);
653 
654     MarkWithTag(kAddOpsRewriteTag, node);
655     return InputAndShape(
656         node_name, TensorShapeProto());  // shape is not important at this point
657   }
658 };
659 
660 // Use the distributive property of multiplication and division over addition,
661 // along with commutativity of the former, to hoist common factors/denominators
662 // out of aggregate nodes where ALL the inputs are Mul/Div nodes.
663 // This pattern occurs frequently in regularization terms for the gradients
664 // during training.
665 //
666 // For example, we can rewrite an expression of the form:
667 //   AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn))
668 // to the following:
669 //   Mul(x, AddN(y1, y2, y3, ... yn))
670 // For division, we can rewrite
671 //   AddN(Div(y1, x), Div(y2, x), Div(y3, x), ... Div(yn, x))
672 // to:
673 //   Div(AddN(y1, y2, y3, ... yn), x)
674 class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
675  public:
HoistCommonFactorOutOfAggregation(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)676   explicit HoistCommonFactorOutOfAggregation(
677       const GraphOptimizerContext& ctx,
678       const ArithmeticOptimizerContext& ctx_ext)
679       : ArithmeticOptimizerStage("HoistCommonFactor", ctx, ctx_ext) {}
680   ~HoistCommonFactorOutOfAggregation() override = default;
681 
IsSupported(const NodeDef * node) const682   bool IsSupported(const NodeDef* node) const override {
683     return IsAggregate(*node) && NumNonControlInputs(*node) > 1 &&
684            !IsRewritten(node);
685   }
686 
TrySimplify(NodeDef * node,string * simplified_node_name)687   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
688     TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
689 
690     bool common_factor_is_denominator = false;
691     std::set<string> common_factors;
692     std::vector<string> ctrl_deps;
693     TF_RETURN_IF_ERROR(GetCommonFactors(
694         node, &common_factors, &common_factor_is_denominator, &ctrl_deps));
695 
696     if (common_factors.size() == 1) {
697       const string& common_factor = *common_factors.begin();
698 
699       // Gather up the non-shared factors
700       bool shapes_match = true;
701       std::vector<string> unique_factors;
702       TF_RETURN_IF_ERROR(GetUniqueFactors(node, common_factor,
703                                           common_factor_is_denominator,
704                                           &shapes_match, &unique_factors));
705 
706       if (shapes_match) {
707         NodeDef* input_0;
708         TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input_0));
709 
710         // Use a copy of the first node for the outer multiplication/division.
711         NodeDef* new_outer_node = AddCopyNode(
712             OuterNodeName(node, common_factor_is_denominator), input_0);
713         // And a copy of aggregation node as one of the inner operands
714         NodeDef* new_add_node = AddCopyNode(InnerAddNodeName(node), node);
715 
716         new_outer_node->set_device(node->device());
717         if (common_factor_is_denominator) {
718           new_outer_node->set_input(0, new_add_node->name());
719           new_outer_node->set_input(1, common_factor);
720         } else {
721           new_outer_node->set_input(0, common_factor);
722           new_outer_node->set_input(1, new_add_node->name());
723         }
724 
725         ctx().node_map->AddOutput(common_factor, new_outer_node->name());
726         ctx().node_map->AddOutput(new_add_node->name(), new_outer_node->name());
727 
728         // Hoist non-shared factors up into the new AddN node.
729         for (int i = 0; i < unique_factors.size(); ++i) {
730           const string& unique_factor_i = unique_factors[i];
731           new_add_node->set_input(i, unique_factor_i);
732           ctx().node_map->AddOutput(unique_factor_i, new_add_node->name());
733         }
734 
735         // Add control deps on add node
736         for (const string& ctrl_dep : ctrl_deps) {
737           *new_add_node->add_input() = ctrl_dep;
738           ctx().node_map->AddOutput(NodeName(ctrl_dep), new_add_node->name());
739         }
740 
741         // optimize new inner aggregation node
742         AddToOptimizationQueue(new_add_node);
743         // do not optimize the same node twice
744         rewritten_nodes_.insert(node->name());
745         *simplified_node_name = new_outer_node->name();
746       }
747     }
748     return Status::OK();
749   }
750 
751  private:
752   // Get a name for new outer node
OuterNodeName(const NodeDef * node,bool is_div) const753   string OuterNodeName(const NodeDef* node, bool is_div) const {
754     auto scope_and_name = ParseNodeScopeAndName(node->name());
755     return is_div ? OptimizedNodeName(scope_and_name, "Div")
756                   : OptimizedNodeName(scope_and_name, "Mul");
757   }
758 
759   // Get a name new inner Add node
InnerAddNodeName(const NodeDef * node) const760   string InnerAddNodeName(const NodeDef* node) const {
761     auto scope_and_name = ParseNodeScopeAndName(node->name());
762     return OptimizedNodeName(scope_and_name, "Add");
763   }
764 
765   // Determine the set of common factors if the input nodes are all Mul or
766   // Div nodes.
GetCommonFactors(const NodeDef * node,std::set<string> * common_factors,bool * common_factor_is_denominator,std::vector<string> * ctrl_deps) const767   Status GetCommonFactors(const NodeDef* node, std::set<string>* common_factors,
768                           bool* common_factor_is_denominator,
769                           std::vector<string>* ctrl_deps) const {
770     CHECK(common_factors->empty());
771     CHECK_NOTNULL(common_factor_is_denominator);
772     *common_factor_is_denominator = false;
773 
774     bool has_mul = false;
775     bool has_div = false;
776     for (int i = 0; i < node->input_size(); ++i) {
777       if (i > 0 && common_factors->empty()) break;
778       if (IsControlInput(node->input(i))) {
779         ctrl_deps->push_back(node->input(i));
780         continue;
781       }
782       NodeDef* input;
783       TF_RETURN_IF_ERROR(GetInputNode(node->input(i), &input));
784 
785       if ((!IsMul(*input) && !IsAnyDiv(*input)) || (IsMul(*input) && has_div) ||
786           (IsAnyDiv(*input) && has_mul)) {
787         // Break if input is neither a Mul or Div, or if there are both Mul &
788         // Div Ops.
789         common_factors->clear();
790         break;
791       } else if (IsAnyDiv(*input)) {
792         has_div = true;
793         // In case of possible common dividers, we avoid hoisting out if any
794         // input is not float/double, since integer division is not distributive
795         // over addition.
796         OpInfo::TensorProperties properties0, properties1;
797         TF_RETURN_IF_ERROR(GetTensorProperties(input->input(0), &properties0));
798         TF_RETURN_IF_ERROR(GetTensorProperties(input->input(1), &properties1));
799         if (properties0.dtype() != DT_FLOAT &&
800             properties0.dtype() != DT_DOUBLE &&
801             properties1.dtype() != DT_FLOAT &&
802             properties1.dtype() != DT_DOUBLE) {
803           common_factors->clear();
804           break;
805         }
806       } else if (IsMul(*input)) {
807         has_mul = true;
808       }
809 
810       // We only focus on common factors from denominators if any Op is a
811       // Div.
812       std::set<string> factors_i =
813           has_mul ? std::set<string>{input->input(0), input->input(1)}
814                   : std::set<string>{input->input(1)};
815       if (i == 0) {
816         std::swap(*common_factors, factors_i);
817       } else {
818         std::set<string> intersection;
819         std::set_intersection(
820             factors_i.begin(), factors_i.end(), common_factors->begin(),
821             common_factors->end(),
822             std::inserter(intersection, intersection.begin()));
823         std::swap(*common_factors, intersection);
824       }
825       for (int i = 2; i < input->input_size(); ++i) {
826         ctrl_deps->push_back(input->input(i));
827       }
828     }
829 
830     *common_factor_is_denominator = has_div;
831     return Status::OK();
832   }
833 
834   // Gather up the non-shared factors (the y's in the example).
835   // Unless the aggregation is Add, we have to make sure that all the y's
836   // have the same shape since the other aggregation ops do not support
837   // broadcasting.
GetUniqueFactors(const NodeDef * node,const string & common_factor,const bool common_factor_is_denominator,bool * shapes_match,std::vector<string> * unique_factors) const838   Status GetUniqueFactors(const NodeDef* node, const string& common_factor,
839                           const bool common_factor_is_denominator,
840                           bool* shapes_match,
841                           std::vector<string>* unique_factors) const {
842     *shapes_match = true;
843     unique_factors->reserve(node->input_size());
844 
845     for (int i = 0; i < node->input_size() && shapes_match; ++i) {
846       const string& input = node->input(i);
847       if (IsControlInput(input)) {
848         break;
849       }
850       NodeDef* inner_node;
851       TF_RETURN_IF_ERROR(GetInputNode(input, &inner_node));
852       const int unique_factor_index =
853           common_factor_is_denominator
854               ? 0
855               : (inner_node->input(0) == common_factor ? 1 : 0);
856       unique_factors->push_back(inner_node->input(unique_factor_index));
857       if (i > 0 && !IsAdd(*node)) {
858         OpInfo::TensorProperties lhs;
859         OpInfo::TensorProperties rhs;
860         TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->front(), &lhs));
861         TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->back(), &rhs));
862         *shapes_match = ShapesSymbolicallyEqual(lhs, rhs);
863       }
864     }
865     return Status::OK();
866   }
867 
IsRewritten(const NodeDef * node) const868   bool IsRewritten(const NodeDef* node) const {
869     // if graph rewrite happens in multiple passes without graph pruning between
870     // them, it's possible that rewritten node already exists in a graph
871     return rewritten_nodes_.find(node->name()) != rewritten_nodes_.end() ||
872            ctx().node_map->NodeExists(OuterNodeName(node, false)) ||
873            ctx().node_map->NodeExists(OuterNodeName(node, true));
874   }
875 
876   // keep names of the nodes that were optimized by this stage
877   std::unordered_set<string> rewritten_nodes_;
878 };
879 
880 // Binary associative ops can be re-ordered to minimize the number of broadcasts
881 // and the size of a temporary tensors.
882 //
883 // Example: [a, c] - scalars, [b, d] - matrices
884 //   @ - binary associative op (Add or Mul)
885 //   @* - broadcast
886 //
887 //           @                      @*
888 //        /     \                /      \
889 //      @*       @*      ->     @        @
890 //    /   \    /   \          /   \    /   \
891 //   a     b  c     d        a     c  b     d
892 class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage {
893  public:
MinimizeBroadcasts(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)894   explicit MinimizeBroadcasts(const GraphOptimizerContext& ctx,
895                               const ArithmeticOptimizerContext& ctx_ext)
896       : ArithmeticNodesGroupOptimizerStage("MinimizeBroadcasts", ctx, ctx_ext) {
897   }
898   ~MinimizeBroadcasts() override = default;
899 
IsSupported(const NodeDef * node) const900   bool IsSupported(const NodeDef* node) const override {
901     if (!IsBinaryAssociative(*node)) return false;
902 
903     if (IsMarkedWithAnyTag(*node, kMinimizeBroadcastsTag, kAddOpsRewriteTag))
904       return false;
905 
906     // has a symbolically defined shape with broadcastable inputs
907     OpInfo::TensorProperties properties;
908     Status has_properties = GetTensorProperties(node->name(), &properties);
909     return has_properties.ok() && ShapeIsSymbolicallyDefined(properties) &&
910            HasAllInputsBroadcastableToShape(*node, properties);
911   }
912 
913  protected:
IsBinaryAssociative(const NodeDef & node) const914   bool IsBinaryAssociative(const NodeDef& node) const {
915     return IsMul(node) || IsAdd(node);
916   }
917 
IsSameOp(const OptimizedNodesGroup & group,const NodeDef & node) const918   bool IsSameOp(const OptimizedNodesGroup& group, const NodeDef& node) const {
919     return group.root_node->op() == node.op();
920   }
921 
922   // Check if a node can be absorbed by current OptimizedNodesGroup
IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup & group,const NodeDef & node) const923   bool IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup& group,
924                                          const NodeDef& node) const override {
925     if (!IsSameOp(group, node)) {
926       return false;
927     }
928     if (IsInPreserveSet(node)) {
929       return false;
930     }
931     // Nodes optimized by AddOpsRewrite already have optimal broadcasts.
932     if (IsMarkedWithAnyTag(node, kMinimizeBroadcastsTag, kAddOpsRewriteTag)) {
933       return false;
934     }
935     if (IsDrivenByControlDependency(node) || DrivesControlDependency(node)) {
936       return false;
937     }
938     if (!IsOnTheSameDevice(group, node)) {
939       return false;
940     }
941     // Optimized nodes updated in place, and that would break the graph, if the
942     // node has multiple output consumers
943     if (NumNonControlOutputs(node, *ctx().node_map) != 1) {
944       return false;
945     }
946     // All input shapes must be broadcastable to the node shape
947     OpInfo::TensorProperties properties;
948     Status has_properties = GetTensorProperties(node.name(), &properties);
949     return has_properties.ok() &&
950            HasAllInputsBroadcastableToShape(node, properties);
951   }
952 
CountUniqueShapes(const std::vector<InputAndShape> & inputs)953   std::size_t CountUniqueShapes(const std::vector<InputAndShape>& inputs) {
954     std::set<string> sigs;
955     for (const auto& ias : inputs) {
956       sigs.insert(ShapeSignature(ias.shape));
957     }
958     return sigs.size();
959   }
960 
RewriteOptimizedNodesGroup(const OptimizedNodesGroup & group)961   string RewriteOptimizedNodesGroup(const OptimizedNodesGroup& group) override {
962     VLOG(2) << "Minimize broadcast: root=" << group.root_node->name()
963             << " op=" << group.root_node->op()
964             << " num_optimized_nodes=" << group.optimized_nodes.size();
965 
966     // Do not optimize any of the nodes that are part of this group.
967     MarkAllMembersWithTag(group, kMinimizeBroadcastsTag);
968 
969     if (CountUniqueShapes(group.inputs) <= 1) {
970       VLOG(3) << "Skip min-bcast group with single unique shape";
971       // nothing to optimize when all shapes are the same
972       return group.root_node->name();
973     }
974 
975     auto num_nodes = /*root*/ 1 + group.optimized_nodes.size();
976     auto num_inputs = group.inputs.size();
977     CHECK_EQ(num_nodes, num_inputs - 1)
978         << "Can't build a tree with " << num_inputs << " inputs, using "
979         << num_nodes << "binary op nodes.";
980 
981     std::deque<InputAndShape> add_ops(group.inputs.begin(), group.inputs.end());
982     std::deque<NodeDef*> optimized_nodes(group.optimized_nodes.begin(),
983                                          group.optimized_nodes.end());
984 
985     // sort inputs by it's shape from smallest to largest
986     std::stable_sort(add_ops.begin(), add_ops.end(),
987                      [](const InputAndShape& lhs, const InputAndShape& rhs) {
988                        return CompareSymbolicallyShapedTensorSizes(lhs.shape,
989                                                                    rhs.shape);
990                      });
991 
992     // If there is an odd number of inputs, last one is the largest, and we want
993     // to attach it to the root node, to build a well balanced tree.
994     std::deque<InputAndShape> add_ops_leftover;
995     if (add_ops.size() % 2 != 0) {
996       add_ops_leftover.push_back(add_ops.back());
997       add_ops.pop_back();
998     }
999 
1000     // At this point it's guaranteed that add_ops have even number of inputs.
1001     do {
1002       const InputAndShape lhs = add_ops.front();
1003       add_ops.pop_front();
1004       const InputAndShape rhs = add_ops.front();
1005       add_ops.pop_front();
1006 
1007       NodeDef* node;
1008       if (!optimized_nodes.empty()) {
1009         // re-purpose optimized nodes to build a new tree
1010         node = optimized_nodes.back();
1011         optimized_nodes.pop_back();
1012       } else {
1013         // or use root node if none optimized nodes left
1014         node = group.root_node;
1015       }
1016       InputAndShape updated_node = UpdateInputs(lhs.input, rhs.input, node);
1017 
1018       // Pushing updated node to the back of a deque will create a wide and
1019       // short tree, pushing to the front will create a tall tree. We prefer to
1020       // get a wide tree, it minimizes the potential number of temporary tensors
1021       // required to keep in memory, though sometimes we can go up to prevent
1022       // propagating a brodcast from leaves to the root. Example:
1023       //
1024       // inputs: [s, s, s, M] (s - scalar, M - matrix)
1025       // @* - op with broadcast
1026       //
1027       //  (only push_back)           @*     (push_front first op)
1028       //                            /  \
1029       //       @*                  @    M
1030       //     /   \                / \
1031       //    @     @*      ->     @   s
1032       //   / \   / \            / \
1033       //  s   s s   M          s   s
1034       if (add_ops.size() >= 2 &&
1035           CompareSymbolicallyShapedTensorSizes(add_ops.at(0).shape,
1036                                                add_ops.at(1).shape)) {
1037         add_ops.push_front(updated_node);
1038       } else {
1039         add_ops.push_back(updated_node);
1040       }
1041     } while (add_ops.size() > 1);
1042     CHECK_EQ(1, add_ops.size());
1043 
1044     // attach the largest tensor to the root op
1045     if (!add_ops_leftover.empty()) {
1046       const InputAndShape lhs = add_ops.front();
1047       add_ops.pop_front();
1048       const InputAndShape rhs = add_ops_leftover.front();
1049       InputAndShape updated_node =
1050           UpdateInputs(lhs.input, rhs.input, group.root_node);
1051       add_ops.push_back(updated_node);
1052     }
1053 
1054     return add_ops.front().input;
1055   }
1056 
UpdateInputs(const string & input_0,const string & input_1,NodeDef * node)1057   InputAndShape UpdateInputs(const string& input_0, const string& input_1,
1058                              NodeDef* node) {
1059     string old_input_0 = node->input(0);
1060     string old_input_1 = node->input(1);
1061 
1062     // Update inputs only if they changed
1063     if (old_input_0 != input_0 || old_input_1 != input_1) {
1064       node->set_input(0, input_0);
1065       node->set_input(1, input_1);
1066       // Invalidate node properties (shape)
1067       ctx().graph_properties->ClearOutputProperties(node->name());
1068       ctx().graph_properties->ClearInputProperties(node->name());
1069       // Update the node map
1070       ctx().node_map->RemoveOutput(NodeName(old_input_0), node->name());
1071       ctx().node_map->RemoveOutput(NodeName(old_input_1), node->name());
1072       ctx().node_map->AddOutput(NodeName(input_0), node->name());
1073       ctx().node_map->AddOutput(NodeName(input_1), node->name());
1074       // Add updated node to optimization queue
1075       AddToOptimizationQueue(node);
1076     }
1077 
1078     TensorShapeProto shape;  // shape is not important at this point
1079     return InputAndShape(node->name(), shape);
1080   }
1081 };
1082 
1083 // Removes inverse transpose nodes
1084 class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
1085  public:
RemoveIdentityTranspose(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)1086   explicit RemoveIdentityTranspose(const GraphOptimizerContext& ctx,
1087                                    const ArithmeticOptimizerContext& ctx_ext)
1088       : ArithmeticOptimizerStage("RemoveIdentityTranspose", ctx, ctx_ext) {}
1089   ~RemoveIdentityTranspose() override = default;
1090 
IsSupported(const NodeDef * node) const1091   bool IsSupported(const NodeDef* node) const override {
1092     return IsTranspose(*node) || IsConjugateTranspose(*node);
1093   }
1094 
TrySimplify(NodeDef * node,string * simplified_node_name)1095   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1096     TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
1097     NodeDef* tail = node;
1098     tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
1099                                     *ctx().nodes_to_preserve);
1100     NodeDef* first_transpose;
1101     TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose));
1102 
1103     NodeDef* node_perm;
1104     TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &node_perm));
1105     if (!IsConstant(*node_perm)) {
1106       return Status::OK();
1107     }
1108     std::vector<int64> node_perm_values;
1109     TF_RETURN_IF_ERROR(GetPermutation(*node_perm, &node_perm_values));
1110     if (first_transpose->op() == node->op()) {
1111       // Remove pairs of transposes that cancel each other.
1112       NodeDef* first_transpose_perm;
1113       TF_RETURN_IF_ERROR(
1114           GetInputNode(first_transpose->input(1), &first_transpose_perm));
1115       if (!IsConstant(*first_transpose_perm)) {
1116         return Status::OK();
1117       }
1118       std::vector<int64> first_transpose_perm_values;
1119       TF_RETURN_IF_ERROR(
1120           GetPermutation(*first_transpose_perm, &first_transpose_perm_values));
1121       if (AreInversePermutations(node_perm_values,
1122                                  first_transpose_perm_values)) {
1123         if (tail == node) {
1124           // Bypass adjacent pair.
1125           *simplified_node_name = first_transpose->input(0);
1126         } else {
1127           // Bypass pair connected through chain.
1128           tail->set_input(0, first_transpose->input(0));
1129           ctx().node_map->UpdateInput(tail->name(), first_transpose->name(),
1130                                       first_transpose->input(0));
1131           ForwardControlDependencies(tail, {first_transpose});
1132           *simplified_node_name = node->input(0);
1133         }
1134       }
1135     } else {
1136       // Remove simple identity transposes.
1137       if (IsIdentityPermutation(node_perm_values)) {
1138         *simplified_node_name = node->input(0);
1139       }
1140     }
1141     return Status::OK();
1142   }
1143 
1144  private:
GetPermutation(const NodeDef & node_perm,std::vector<int64> * perm64) const1145   Status GetPermutation(const NodeDef& node_perm,
1146                         std::vector<int64>* perm64) const {
1147     std::vector<int> perm32;
1148     if (ValuesFromConstNode(node_perm, &perm32)) {
1149       perm64->reserve(perm32.size());
1150       for (int val : perm32) {
1151         perm64->push_back(static_cast<int64>(val));
1152       }
1153       return Status::OK();
1154     }
1155     if (ValuesFromConstNode(node_perm, perm64)) {
1156       return Status::OK();
1157     }
1158     return errors::InvalidArgument("Couldn't extract permutation from ",
1159                                    node_perm.name());
1160   }
1161 
AreInversePermutations(const std::vector<int64> & a,const std::vector<int64> & b)1162   bool AreInversePermutations(const std::vector<int64>& a,
1163                               const std::vector<int64>& b) {
1164     if (a.size() != b.size()) {
1165       return false;
1166     }
1167     for (int i = 0; i < a.size(); ++i) {
1168       if (a[b[i]] != i) {
1169         return false;
1170       }
1171     }
1172     return true;
1173   }
1174 
IsIdentityPermutation(const std::vector<int64> & perm)1175   bool IsIdentityPermutation(const std::vector<int64>& perm) {
1176     for (int64 i = 0; i < perm.size(); ++i) {
1177       if (i != perm[i]) {
1178         return false;
1179       }
1180     }
1181     return true;
1182   }
1183 };
1184 
1185 // An involution is an element-wise function f(x) that is its own inverse,
1186 // i.e. f(f(x)) = x. If we can find a chain of ops
1187 //   f->op1->op2->...opn->f
1188 // where op1 through opn preserve the values of their inputs, we can remove
1189 // the two instances of the involution from the graph, since they cancel
1190 // each other.
1191 class RemoveInvolution : public ArithmeticOptimizerStage {
1192  public:
RemoveInvolution(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)1193   explicit RemoveInvolution(const GraphOptimizerContext& ctx,
1194                             const ArithmeticOptimizerContext& ctx_ext)
1195       : ArithmeticOptimizerStage("RemoveInvolution", ctx, ctx_ext) {}
1196   ~RemoveInvolution() override = default;
1197 
IsSupported(const NodeDef * node) const1198   bool IsSupported(const NodeDef* node) const override {
1199     return IsInvolution(*node);
1200   }
1201 
TrySimplify(NodeDef * node,string * simplified_node_name)1202   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1203     NodeDef* tail = GetTailOfValuePreservingChain(*node, *ctx().node_map,
1204                                                   *ctx().nodes_to_preserve);
1205 
1206     NodeDef* involution;
1207     TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &involution));
1208 
1209     if (involution->op() == node->op()) {
1210       // Skip both *node and *involution since they cancel each other.
1211       if (tail == node) {
1212         // The two nodes to eliminate are adjacent.
1213         *simplified_node_name = involution->input(0);
1214       } else {
1215         tail->set_input(0, involution->input(0));
1216         ctx().node_map->UpdateInput(tail->name(), involution->name(),
1217                                     involution->input(0));
1218         *simplified_node_name = node->input(0);
1219       }
1220     }
1221 
1222     return Status::OK();
1223   }
1224 };
1225 
1226 // Remove redundant Bitcasts.
1227 // 1) Remove Bitcast whose source type and destination type are equal
1228 // 2) Rewrite Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
1229 class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage {
1230  public:
RemoveRedundantBitcastStage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)1231   explicit RemoveRedundantBitcastStage(
1232       const GraphOptimizerContext& ctx,
1233       const ArithmeticOptimizerContext& ctx_ext)
1234       : ArithmeticOptimizerStage("RemoveRedundantBitcast", ctx, ctx_ext) {}
1235   ~RemoveRedundantBitcastStage() override = default;
1236 
IsSupported(const NodeDef * node) const1237   bool IsSupported(const NodeDef* node) const override {
1238     return IsBitcast(*node);
1239   }
1240 
TrySimplify(NodeDef * node,string * simplified_node_name)1241   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1242     TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
1243 
1244     // Bypass Bitcast whose source type and destination type are equal.
1245     AttrSlice attrs(*node);
1246     DataType input_type;
1247     TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &input_type));
1248     DataType output_type;
1249     TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "type", &output_type));
1250     if (input_type == output_type) {
1251       *simplified_node_name = node->input(0);
1252       return Status::OK();
1253     }
1254 
1255     NodeDef* bitcast;
1256     TF_RETURN_IF_ERROR(GetInputNode(node->name(), &bitcast));
1257     NodeDef* operand;
1258     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &operand));
1259 
1260     if (IsBitcast(*operand)) {
1261       AttrSlice operand_attrs(*operand);
1262       DataType operand_input_type;
1263       TF_RETURN_IF_ERROR(GetNodeAttr(operand_attrs, "T", &operand_input_type));
1264       // Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
1265       bitcast->set_input(0, operand->input(0));
1266       SetDataTypeToAttr(operand_input_type, "T", bitcast);
1267       ctx().node_map->UpdateInput(bitcast->name(), bitcast->input(0),
1268                                   operand->input(0));
1269       AddToOptimizationQueue(bitcast);
1270       *simplified_node_name = bitcast->name();
1271     }
1272 
1273     return Status::OK();
1274   }
1275 };
1276 
1277 // Remove Casts whose source type and destination type are equal.
1278 class RemoveRedundantCastStage : public ArithmeticOptimizerStage {
1279  public:
RemoveRedundantCastStage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)1280   explicit RemoveRedundantCastStage(const GraphOptimizerContext& ctx,
1281                                     const ArithmeticOptimizerContext& ctx_ext)
1282       : ArithmeticOptimizerStage("RemoveRedundantCast", ctx, ctx_ext) {}
1283   ~RemoveRedundantCastStage() override = default;
1284 
IsSupported(const NodeDef * node) const1285   bool IsSupported(const NodeDef* node) const override { return IsCast(*node); }
1286 
TrySimplify(NodeDef * node,string * simplified_node_name)1287   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1288     TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
1289 
1290     // Bypass Cast whose source type and destination type are equal.
1291     AttrSlice attrs(*node);
1292     DataType input_type;
1293     TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "SrcT", &input_type));
1294     DataType output_type;
1295     TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "DstT", &output_type));
1296     if (input_type == output_type) {
1297       *simplified_node_name = node->input(0);
1298     }
1299     return Status::OK();
1300   }
1301 };
1302 
1303 class RemoveNegationStage : public ArithmeticOptimizerStage {
1304  public:
RemoveNegationStage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)1305   explicit RemoveNegationStage(const GraphOptimizerContext& ctx,
1306                                const ArithmeticOptimizerContext& ctx_ext)
1307       : ArithmeticOptimizerStage("RemoveNegation", ctx, ctx_ext) {}
1308   ~RemoveNegationStage() override = default;
1309 
IsSupported(const NodeDef * node) const1310   bool IsSupported(const NodeDef* node) const override {
1311     return IsAdd(*node) || IsSub(*node);
1312   }
1313 
TrySimplify(NodeDef * node,string * simplified_node_name)1314   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1315     NodeDef* x;
1316     NodeDef* y;
1317     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
1318     TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
1319     bool updated = false;
1320     if (IsNeg(*y)) {
1321       // a - (-b) = a + b or  a + (-b) = a - b
1322       ForwardControlDependencies(node, {y});
1323       ctx().node_map->UpdateInput(node->name(), node->input(1), y->input(0));
1324       node->set_op(IsAdd(*node) ? "Sub" : "Add");
1325       node->set_input(1, y->input(0));
1326       updated = true;
1327     } else if (IsAdd(*node) && IsNeg(*x)) {
1328       // (-a) + b = b - a
1329       ForwardControlDependencies(node, {x});
1330       ctx().node_map->UpdateInput(node->name(), node->input(0), x->input(0));
1331       node->set_op("Sub");
1332       node->mutable_input()->SwapElements(0, 1);
1333       node->set_input(1, x->input(0));
1334       updated = true;
1335     }
1336     if (updated) {
1337       AddToOptimizationQueue(node);
1338     }
1339     return Status::OK();
1340   }
1341 };
1342 
1343 class RemoveLogicalNotStage : public ArithmeticOptimizerStage {
1344  public:
RemoveLogicalNotStage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)1345   explicit RemoveLogicalNotStage(const GraphOptimizerContext& ctx,
1346                                  const ArithmeticOptimizerContext& ctx_ext)
1347       : ArithmeticOptimizerStage("RemoveLogicalNot", ctx, ctx_ext) {}
1348   ~RemoveLogicalNotStage() override = default;
1349 
IsSupported(const NodeDef * node) const1350   bool IsSupported(const NodeDef* node) const override {
1351     return IsLogicalNot(*node) && !IsInPreserveSet(*node);
1352   }
1353 
TrySimplify(NodeDef * node,string * simplified_node_name)1354   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1355     const string node_name = node->name();
1356     NodeDef* input;
1357     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
1358     if (IsInPreserveSet(*input) ||
1359         NumNonControlOutputs(*input, *ctx().node_map) > 1) {
1360       return Status::OK();
1361     }
1362     string new_op;
1363     if (IsEqual(*input)) {
1364       new_op = "NotEqual";
1365     } else if (IsNotEqual(*input)) {
1366       new_op = "Equal";
1367     } else if (IsLess(*input)) {
1368       new_op = "GreaterEqual";
1369     } else if (IsLessEqual(*input)) {
1370       new_op = "Greater";
1371     } else if (IsGreater(*input)) {
1372       new_op = "LessEqual";
1373     } else if (IsGreaterEqual(*input)) {
1374       new_op = "Less";
1375     }
1376     if (!new_op.empty()) {
1377       input->set_op(new_op);
1378       *simplified_node_name = input->name();
1379     }
1380     return Status::OK();
1381   }
1382 };
1383 
1384 // This optimization hoists the common prefix of unary ops of the inputs to
1385 // concat out of the concat, for example:
1386 //    Concat([Exp(Sin(x)), Exp(Sin(y)), Exp(Sin(z))])
1387 // becomes
1388 //    Exp(Sin(Concat([x, y, z]))).
1389 // Similarly, it will hoist the common postfix of unary ops into Split or
1390 // SplitV nodes, for example:
1391 //    [Exp(Sin(y)) for y in Split(x)]
1392 // becomes
1393 //    [y for y in Split(Exp(Sin(x))]
1394 //
1395 // TODO(rmlarsen): Support casting. We would have to change the type attribute
1396 // on the concat/split node.
1397 // TODO(rmlarsen): Handle Enter/Exit.
1398 class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
1399  public:
HoistCWiseUnaryChainsStage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)1400   explicit HoistCWiseUnaryChainsStage(const GraphOptimizerContext& ctx,
1401                                       const ArithmeticOptimizerContext& ctx_ext)
1402       : ArithmeticOptimizerStage("", ctx, ctx_ext) {}
1403 
1404   ~HoistCWiseUnaryChainsStage() override = default;
1405 
1406   struct ChainLink {
1407     ChainLink() = default;
ChainLinktensorflow::grappler::__anon327bfa1e0111::HoistCWiseUnaryChainsStage::ChainLink1408     ChainLink(NodeDef* _node, int _port_origin)
1409         : node(_node), port_origin(_port_origin) {}
1410     NodeDef* node;    // Node in a chain.
1411     int port_origin;  // Port on concat/split node from which this chain
1412                       // originates.
1413 
operator <tensorflow::grappler::__anon327bfa1e0111::HoistCWiseUnaryChainsStage::ChainLink1414     bool operator<(const ChainLink& other) const {
1415       if (port_origin < other.port_origin) {
1416         return true;
1417       } else if (port_origin > other.port_origin) {
1418         return false;
1419       } else {
1420         return node->name() < other.node->name();
1421       }
1422     }
1423   };
1424 
1425   // We use an ordinary set sorted on port and node name, so the order, and
1426   // hence the node name used for the hoisted chain, will be deterministic.
1427   using ChainLinkSet = std::set<ChainLink>;
1428 
IsSupported(const NodeDef * node) const1429   bool IsSupported(const NodeDef* node) const override {
1430     if (IsInPreserveSet(*node)) return false;
1431     if (IsConcat(*node) && node->attr().count("N") != 0) {
1432       const int n = node->attr().at("N").i();
1433       return n > 1;
1434     } else if ((IsSplit(*node) || IsSplitV(*node)) &&
1435                node->attr().count("num_split") != 0) {
1436       const int num_split = node->attr().at("num_split").i();
1437       if (NumNonControlOutputs(*node, *ctx().node_map) > num_split) {
1438         // TODO(rmlarsen): Remove this constraint when we have optimizations
1439         // in place for merging slices into splits.
1440         return false;
1441       }
1442       return num_split > 1 && !IsAlreadyOptimized(*node);
1443     }
1444     return false;
1445   }
1446 
TrySimplify(NodeDef * node,string * simplified_node_name)1447   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1448     node_is_concat_ = IsConcat(*node);
1449     int prefix_length;
1450     std::set<string> ctrl_inputs;
1451     ChainLinkSet tails;
1452     TF_RETURN_IF_ERROR(
1453         FindCommonUnaryOpChain(*node, &prefix_length, &tails, &ctrl_inputs));
1454     if (prefix_length > 0 && !tails.empty()) {
1455       TF_RETURN_IF_ERROR(
1456           HoistUnaryOpChain(prefix_length, tails, &ctrl_inputs, node));
1457     }
1458     return Status::OK();
1459   }
1460 
1461  private:
1462   // Returns the length of the common unary chain of ops that can be
1463   // hoisted to the other side of concat or split.
FindCommonUnaryOpChain(const NodeDef & root_node,int * prefix_length,ChainLinkSet * tails,std::set<string> * ctrl_inputs) const1464   Status FindCommonUnaryOpChain(const NodeDef& root_node, int* prefix_length,
1465                                 ChainLinkSet* tails,
1466                                 std::set<string>* ctrl_inputs) const {
1467     *prefix_length = 0;
1468     // Follow the chains starting at each concat input or split output as long
1469     // as all the following conditions hold:
1470     //   1. The ops in all chains are the same.
1471     //   2. The ops are unary elemenwise op.
1472     //   3. The op output has only a single consumer (concat only).
1473     ChainLinkSet cur_tails;
1474     TF_RETURN_IF_ERROR(InitializeChains(root_node, &cur_tails));
1475     if (cur_tails.size() < 2) {
1476       return Status::OK();
1477     }
1478     ctrl_inputs->clear();
1479     bool stop = false;
1480     while (!stop && !cur_tails.empty() &&
1481            OpsAreSafeToHoist(root_node, cur_tails)) {
1482       // We found one more link that can be hoisted.
1483       ++(*prefix_length);
1484       tails->swap(cur_tails);
1485       GatherControlInputs(ctrl_inputs, *tails);
1486 
1487       // Advance tail pointers to the next level.
1488       TF_RETURN_IF_ERROR(AdvanceTails(*tails, &cur_tails, &stop));
1489     }
1490     return Status::OK();
1491   }
1492 
1493   // Hoists the chains to the other side of concat or split and attaches the
1494   // control inputs gathered from them to the concat or split node.
HoistUnaryOpChain(const int prefix_length,const ChainLinkSet & tails,std::set<string> * ctrl_inputs,NodeDef * root_node)1495   Status HoistUnaryOpChain(const int prefix_length, const ChainLinkSet& tails,
1496                            std::set<string>* ctrl_inputs, NodeDef* root_node) {
1497     if (tails.empty()) {
1498       return Status::OK();
1499     }
1500     AddToOptimizationQueue(root_node);
1501     optimized_nodes_.insert(root_node->name());
1502     if (node_is_concat_) {
1503       AddControlInputs(ctrl_inputs, root_node);
1504       return HoistChainForConcat(prefix_length, tails, root_node);
1505     } else {
1506       return HoistChainForSplit(prefix_length, tails, ctrl_inputs, root_node);
1507     }
1508   }
1509 
GatherControlInputs(std::set<string> * ctrl_inputs,const ChainLinkSet & ops) const1510   void GatherControlInputs(std::set<string>* ctrl_inputs,
1511                            const ChainLinkSet& ops) const {
1512     for (const auto& link : ops) {
1513       const NodeDef* node = link.node;
1514       for (int i = node->input_size() - 1; i >= 0; --i) {
1515         const string& input = node->input(i);
1516         if (!IsControlInput(input)) break;
1517         ctrl_inputs->insert(input);
1518       }
1519     }
1520   }
1521 
AddControlInputs(std::set<string> * new_ctrl_inputs,NodeDef * node) const1522   void AddControlInputs(std::set<string>* new_ctrl_inputs,
1523                         NodeDef* node) const {
1524     for (int i = node->input_size() - 1; i >= 0; --i) {
1525       const string& existing_input = node->input(i);
1526       if (!IsControlInput(existing_input)) break;
1527       new_ctrl_inputs->erase(existing_input);
1528     }
1529     for (const string& new_input : *new_ctrl_inputs) {
1530       ctx().node_map->AddOutput(NodeName(new_input), node->name());
1531       node->add_input(new_input);
1532     }
1533   }
1534 
InitializeChains(const NodeDef & node,ChainLinkSet * tails) const1535   Status InitializeChains(const NodeDef& node, ChainLinkSet* tails) const {
1536     if (node_is_concat_) {
1537       // Handle concat nodes by looking backwards in the graph.
1538       TF_RETURN_IF_ERROR(CheckAttrExists(node, "N"));
1539       const int n = node.attr().at("N").i();
1540       const int start = node.op() == "Concat" ? 1 : 0;
1541       const int end = start + n;
1542       // Set up tail pointers to point to the immediate inputs to Concat.
1543       for (int input_port = start; input_port < end; ++input_port) {
1544         if (IsControlInput(node.input(input_port))) {
1545           return errors::FailedPrecondition(
1546               "Got control input ", node.input(input_port),
1547               " where normal input was expected.");
1548         }
1549         NodeDef* tail;
1550         TF_RETURN_IF_ERROR(GetInputNode(node.input(input_port), &tail));
1551         tails->insert(ChainLink(tail, input_port));
1552       }
1553       return Status::OK();
1554     } else {
1555       // Handle split nodes by looking forwards in the graph.
1556       const auto& outputs = ctx().node_map->GetOutputs(node.name());
1557       for (NodeDef* output : outputs) {
1558         if (IsControlInput(output->input(0))) continue;
1559         TensorId tensor_id = ParseTensorName(output->input(0));
1560         if (tensor_id.node() == node.name()) {
1561           tails->insert(ChainLink(output, tensor_id.index()));
1562         } else {
1563           // This output node has a non-control input other than the split node,
1564           // abort.
1565           tails->clear();
1566           return Status::OK();
1567         }
1568       }
1569     }
1570     return Status::OK();
1571   }
1572 
OpsAreSafeToHoist(const NodeDef & root_node,const ChainLinkSet & ops) const1573   bool OpsAreSafeToHoist(const NodeDef& root_node,
1574                          const ChainLinkSet& ops) const {
1575     if (ops.empty()) return true;
1576     const NodeDef* op0 = ops.begin()->node;
1577     if (ModifiesFrameInfo(*op0) || !IsUnaryElementWise(*op0)) return false;
1578     for (const auto& link : ops) {
1579       const NodeDef* op = link.node;
1580       if (op->device() != root_node.device() || op->op() != op0->op() ||
1581           IsInPreserveSet(*op)) {
1582         return false;
1583       }
1584       if (ctx().node_map->GetOutputs(op->name()).size() > 1) {
1585         // TODO(rmlarsen): Allow outgoing control edges.
1586         return false;
1587       }
1588     }
1589     return true;
1590   }
1591 
AdvanceTails(const ChainLinkSet & tails,ChainLinkSet * new_tails,bool * stop) const1592   Status AdvanceTails(const ChainLinkSet& tails, ChainLinkSet* new_tails,
1593                       bool* stop) const {
1594     *stop = true;
1595     new_tails->clear();
1596     for (const auto& link : tails) {
1597       const NodeDef* tail = link.node;
1598       if (node_is_concat_) {
1599         if (tail->input_size() == 0 || IsControlInput(tail->input(0))) {
1600           return Status::OK();
1601         }
1602         NodeDef* new_tail;
1603         TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &new_tail));
1604         // Remember original port.
1605         new_tails->insert(ChainLink(new_tail, link.port_origin));
1606       } else {
1607         for (NodeDef* new_tail : ctx().node_map->GetOutputs(tail->name())) {
1608           const TensorId tensor = ParseTensorName(new_tail->input(0));
1609           if (tensor.node() != tail->name()) {
1610             return Status::OK();
1611           }
1612           // Skip control outputs.
1613           if (tensor.index() >= 0) {
1614             // Remember original port.
1615             new_tails->insert(ChainLink(new_tail, link.port_origin));
1616           }
1617         }
1618       }
1619     }
1620     *stop = false;
1621     return Status::OK();
1622   }
1623 
HoistChainForConcat(const int prefix_length,const ChainLinkSet & tails,NodeDef * concat_node)1624   Status HoistChainForConcat(const int prefix_length, const ChainLinkSet& tails,
1625                              NodeDef* concat_node) {
1626     const string& concat_name = concat_node->name();
1627     const int first_input = concat_node->op() == "Concat" ? 1 : 0;
1628     for (const auto& link : tails) {
1629       NodeDef* tail = CHECK_NOTNULL(link.node);
1630       const int concat_port = link.port_origin;
1631       CHECK_GE(concat_port, 0);
1632       CHECK_LT(concat_port, concat_node->input_size());
1633       const string concat_input = concat_node->input(concat_port);
1634       // Hook the node following tail directly into the concat node.
1635       const string tail_input = tail->input(0);
1636       concat_node->set_input(concat_port, tail_input);
1637       ctx().node_map->UpdateInput(concat_name, concat_input, tail_input);
1638 
1639       if (concat_port == first_input) {
1640         // Update the consumers of concat to consume the end of the chain
1641         // instead.
1642         UpdateConsumers(concat_node, concat_input);
1643         // Reuse nodes in the first chain to process output of concat.
1644         tail->set_input(0, concat_name);
1645         ctx().node_map->UpdateInput(tail->name(), tail_input, concat_name);
1646       }
1647     }
1648     return Status::OK();
1649   }
1650 
HoistChainForSplit(const int prefix_length,const ChainLinkSet & tails,std::set<string> * ctrl_inputs,NodeDef * split_node)1651   Status HoistChainForSplit(const int prefix_length, const ChainLinkSet& tails,
1652                             std::set<string>* ctrl_inputs,
1653                             NodeDef* split_node) {
1654     // Create a new chain before the split node to process the input tensor.
1655     const string& split_name = split_node->name();
1656     auto root_scope_and_name = ParseNodeScopeAndName(split_name);
1657 
1658     // We use the first tail node in the set as a template to get the list of
1659     // ops to apply (starting from the end).
1660     NodeDef* cur_tail = tails.begin()->node;
1661     NodeDef* cur_copy = AddCopyNode(
1662         OptimizedNodeName(root_scope_and_name, cur_tail->name()), cur_tail);
1663     cur_copy->clear_input();
1664 
1665     // Update the split to take its input from the tail of the new chain.
1666     const int value_slot = split_node->op() == "SplitV" ? 0 : 1;
1667     const string orig_input = split_node->input(value_slot);
1668     split_node->set_input(value_slot, cur_copy->name());
1669     ctx().node_map->UpdateInput(split_node->name(), orig_input,
1670                                 cur_copy->name());
1671     TF_RETURN_IF_ERROR(GetInputNode(cur_tail->input(0), &cur_tail));
1672 
1673     // Now walk backwards creating the rest of the chain.
1674     while (cur_tail != split_node) {
1675       NodeDef* new_copy = AddCopyNode(
1676           OptimizedNodeName(root_scope_and_name, cur_tail->name()), cur_tail);
1677       new_copy->clear_input();
1678       cur_copy->add_input(new_copy->name());
1679       ctx().node_map->AddOutput(new_copy->name(), cur_copy->name());
1680       cur_copy = new_copy;
1681       TF_RETURN_IF_ERROR(GetInputNode(cur_tail->input(0), &cur_tail));
1682     }
1683     // Connect the original input to the head of the new chain.
1684     cur_copy->add_input(orig_input);
1685     ctx().node_map->UpdateOutput(NodeName(orig_input), split_name,
1686                                  cur_copy->name());
1687     // Make sure all the control inputs are satisfied before running the first
1688     // node in the new chain.
1689     AddControlInputs(ctrl_inputs, cur_copy);
1690 
1691     // Connect all consumers of the tail nodes directly to the
1692     // output port of Split from which the chain started.
1693     for (const auto& link : tails) {
1694       UpdateConsumers(link.node,
1695                       link.port_origin == 0
1696                           ? split_name
1697                           : strings::StrCat(split_name, ":", link.port_origin));
1698     }
1699     return Status::OK();
1700   }
1701 
1702   // Update consumers of node to take new_input as input instead.
UpdateConsumers(NodeDef * node,const string & new_input)1703   void UpdateConsumers(NodeDef* node, const string& new_input) {
1704     const string& node_name = node->name();
1705     const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
1706     for (NodeDef* consumer : consumers) {
1707       for (int i = 0; i < consumer->input_size(); ++i) {
1708         if (consumer->input(i) == node_name) {
1709           consumer->set_input(i, new_input);
1710           ctx().node_map->UpdateInput(consumer->name(), node_name, new_input);
1711         }
1712       }
1713       AddToOptimizationQueue(consumer);
1714     }
1715   }
1716 
IsAlreadyOptimized(const NodeDef & node) const1717   bool IsAlreadyOptimized(const NodeDef& node) const {
1718     return optimized_nodes_.find(node.name()) != optimized_nodes_.end();
1719   }
1720 
1721  private:
1722   bool node_is_concat_;
1723   std::unordered_set<string> optimized_nodes_;
1724 };
1725 
1726 class RemoveIdempotentStage : public ArithmeticOptimizerStage {
1727  public:
RemoveIdempotentStage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)1728   explicit RemoveIdempotentStage(const GraphOptimizerContext& ctx,
1729                                  const ArithmeticOptimizerContext& ctx_ext)
1730       : ArithmeticOptimizerStage("RemoveIdempotent", ctx, ctx_ext) {}
1731   ~RemoveIdempotentStage() override = default;
1732 
IsSupported(const NodeDef * node) const1733   bool IsSupported(const NodeDef* node) const override {
1734     return node->input_size() == 1 && IsIdempotent(*node) &&
1735            !IsInPreserveSet(*node);
1736   }
1737 
TrySimplify(NodeDef * node,string * simplified_node_name)1738   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1739     NodeDef* input;
1740     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
1741     if (input->op() == node->op() && input->device() == node->device()) {
1742       *simplified_node_name = node->input(0);
1743     }
1744     return Status::OK();
1745   }
1746 };
1747 
1748 // Performs the conversion:
1749 // Div(x, Sqrt(y)) => Mul(x, Rsqrt(y))
1750 // TODO(srjoglekar): Generalize to optimize cases like (x / pow(y, z)).
1751 class SqrtDivToRsqrtMulStage : public ArithmeticOptimizerStage {
1752  public:
SqrtDivToRsqrtMulStage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)1753   explicit SqrtDivToRsqrtMulStage(const GraphOptimizerContext& ctx,
1754                                   const ArithmeticOptimizerContext& ctx_ext)
1755       : ArithmeticOptimizerStage("SqrtDivToRsqrtMul", ctx, ctx_ext) {}
1756   ~SqrtDivToRsqrtMulStage() override = default;
1757 
IsSupported(const NodeDef * node) const1758   bool IsSupported(const NodeDef* node) const override {
1759     return IsAnyDiv(*node);
1760   }
1761 
TrySimplify(NodeDef * node,string * simplified_node_name)1762   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1763     NodeDef* y;
1764     TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
1765     // Optimize only if divisor is a Sqrt whose output is not being consumed
1766     // elsewhere.
1767     if (IsSqrt(*y) && !IsInPreserveSet(*y) &&
1768         (NumNonControlOutputs(*y, *ctx().node_map) == 1)) {
1769       // a / sqrt(b) = a * rsqrt(b)
1770       node->set_op("Mul");
1771       y->set_op("Rsqrt");
1772       AddToOptimizationQueue(node);
1773       AddToOptimizationQueue(y);
1774     }
1775     return Status::OK();
1776   }
1777 };
1778 
1779 // Performs the conversion:
1780 // Square(Sub(x, y)) => Identity(SquaredDifference(x, y))
1781 class FuseSquaredDiffStage : public ArithmeticOptimizerStage {
1782  public:
FuseSquaredDiffStage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)1783   explicit FuseSquaredDiffStage(const GraphOptimizerContext& ctx,
1784                                 const ArithmeticOptimizerContext& ctx_ext)
1785       : ArithmeticOptimizerStage("FuseSquaredDiffStage", ctx, ctx_ext) {}
1786   ~FuseSquaredDiffStage() override = default;
1787 
IsSupported(const NodeDef * node) const1788   bool IsSupported(const NodeDef* node) const override {
1789     return IsSquare(*node);
1790   }
1791 
TrySimplify(NodeDef * node,string * simplified_node_name)1792   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1793     NodeDef* b;
1794     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &b));
1795     // Optimize only if base is a Sub whose output is not being consumed
1796     // elsewhere.
1797     if (IsSub(*b) && !IsInPreserveSet(*b) &&
1798         (NumNonControlOutputs(*b, *ctx().node_map) == 1)) {
1799       node->set_op("Identity");
1800       b->set_op("SquaredDifference");
1801       AddToOptimizationQueue(node);
1802       AddToOptimizationQueue(b);
1803     }
1804     return Status::OK();
1805   }
1806 };
1807 
1808 // Performs the conversion:
1809 // Log(Softmax(x)) => LogSoftmax(x)
1810 class LogSoftmaxStage : public ArithmeticOptimizerStage {
1811  public:
LogSoftmaxStage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)1812   explicit LogSoftmaxStage(const GraphOptimizerContext& ctx,
1813                            const ArithmeticOptimizerContext& ctx_ext)
1814       : ArithmeticOptimizerStage("LogSoftmaxStage", ctx, ctx_ext) {}
1815   ~LogSoftmaxStage() override = default;
1816 
IsSupported(const NodeDef * node) const1817   bool IsSupported(const NodeDef* node) const override { return IsLog(*node); }
1818 
TrySimplify(NodeDef * node,string * simplified_node_name)1819   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1820     NodeDef* x;
1821     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
1822     // Optimize only if arg is a Softmax whose output is not being consumed
1823     // elsewhere.
1824     if (IsSoftmax(*x) && !IsInPreserveSet(*x) &&
1825         (NumNonControlOutputs(*x, *ctx().node_map) == 1)) {
1826       // Log(Softmax(x)) => LogSoftmax(Identity(x))
1827       node->set_op("LogSoftmax");
1828       x->set_op("Identity");
1829       AddToOptimizationQueue(node);
1830       AddToOptimizationQueue(x);
1831     }
1832     return Status::OK();
1833   }
1834 };
1835 
1836 // Bypass redundant reshape nodes:
1837 //
1838 //   Reshape                    Reshape  <-+
1839 //      ^                                  |
1840 //      |                                  |
1841 //   Reshape       becomes      Reshape    |
1842 //      ^                                  |
1843 //      |                                  |
1844 //    input                      input  ---+
1845 class RemoveRedundantReshape : public ArithmeticOptimizerStage {
1846  public:
RemoveRedundantReshape(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)1847   explicit RemoveRedundantReshape(const GraphOptimizerContext& ctx,
1848                                   const ArithmeticOptimizerContext& ctx_ext)
1849       : ArithmeticOptimizerStage("RemoveRedundantReshape", ctx, ctx_ext) {}
1850   ~RemoveRedundantReshape() override = default;
1851 
IsSupported(const NodeDef * node) const1852   bool IsSupported(const NodeDef* node) const override {
1853     return IsReshape(*node);
1854   }
1855 
TrySimplify(NodeDef * node,string * simplified_node_name)1856   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1857     NodeDef* input;
1858     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
1859 
1860     // 1. Bypass reshape followed by reshape.
1861     if (IsReshape(*input) && !HasControlInputs(*input)) {
1862       node->set_input(0, input->input(0));
1863       ctx().node_map->UpdateInput(node->name(), input->name(), input->input(0));
1864       *simplified_node_name = node->name();
1865       AddToOptimizationQueue(node);
1866       return Status::OK();
1867     }
1868 
1869     // 2. If the reshape is a no-op, forward its input to its consumers, unless
1870     // it anchors a control dependency since we want to make sure that control
1871     // dependency is triggered.
1872     if (ReshapeIsIdentity(*node) && !HasControlInputs(*node)) {
1873       *simplified_node_name = node->input(0);
1874       return Status::OK();
1875     }
1876 
1877     return Status::OK();
1878   }
1879 
1880  private:
1881   // Returns whether `reshape` is an identity op.
ReshapeIsIdentity(const NodeDef & reshape)1882   bool ReshapeIsIdentity(const NodeDef& reshape) {
1883     OpInfo::TensorProperties reshape_props;
1884     OpInfo::TensorProperties input_props;
1885 
1886     if (!GetTensorProperties(reshape.name(), &reshape_props).ok() ||
1887         !GetTensorProperties(reshape.input(0), &input_props).ok()) {
1888       return false;
1889     }
1890 
1891     return ShapesSymbolicallyEqual(input_props.shape(), reshape_props.shape());
1892   }
1893 };
1894 
1895 // Reorder casting and value-preserving ops if beneficial.
1896 //
1897 // Original motivation: A common pattern after the layout optimizer is
1898 // casting an uint8 NHWC image to float before transposing it to NCHW. It
1899 // is beneficial to reorder the cast and the transpose to make the transpose
1900 // process smaller amount of data. More generally, this optimization converts
1901 //   Op(Cast(tensor, dst_type))
1902 // to
1903 //   Cast(Op(tensor), dst_type)
1904 // when sizeof(tensor.type) < sizeof(dst_type), and Op is any value-preserving
1905 // Op, i.e. an op that only reorders the elements in its first input. Similarly,
1906 // this optimization converts
1907 //   Cast(Op(tensor), dst_type)
1908 // to
1909 //   Op(Cast(tensor, dst_type))
1910 // when sizeof(tensor.type) > sizeof(dst_type)
1911 //
1912 class ReorderCastLikeAndValuePreserving : public ArithmeticOptimizerStage {
1913  public:
ReorderCastLikeAndValuePreserving(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)1914   explicit ReorderCastLikeAndValuePreserving(
1915       const GraphOptimizerContext& ctx,
1916       const ArithmeticOptimizerContext& ctx_ext)
1917       : ArithmeticOptimizerStage("ReorderCastLikeAndValuePreserving", ctx,
1918                                  ctx_ext) {}
1919   ~ReorderCastLikeAndValuePreserving() override = default;
1920 
IsSupported(const NodeDef * node) const1921   bool IsSupported(const NodeDef* node) const override {
1922     return (IsValuePreserving(*node) || IsCastLike(*node)) &&
1923            !IsCheckNumerics(*node) && NodeIsOnCpuOrGpu(node) &&
1924            !IsControlFlow(*node) && !IsInPreserveSet(*node);
1925   }
1926 
TrySimplify(NodeDef * consumer,string * simplified_node_name)1927   Status TrySimplify(NodeDef* consumer, string* simplified_node_name) override {
1928     NodeDef* producer;
1929     TF_RETURN_IF_ERROR(GetInputNode(consumer->input(0), &producer));
1930     const bool producer_is_cast = IsCastLike(*producer);
1931     const bool can_optimize =
1932         !IsCheckNumerics(*producer) &&
1933         ((producer_is_cast && IsValuePreserving(*consumer)) ||
1934          (IsValuePreserving(*producer) && IsCastLike(*consumer)));
1935     if (!can_optimize || IsControlFlow(*producer) ||
1936         producer->device() != consumer->device()) {
1937       return Status::OK();
1938     }
1939 
1940     const NodeDef* cast_like_node = producer_is_cast ? producer : consumer;
1941     const OpDef* cast_like_op_def = nullptr;
1942     TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(cast_like_node->op(),
1943                                                          &cast_like_op_def));
1944     DataType cast_src_type;
1945     TF_RETURN_IF_ERROR(InputTypeForNode(*cast_like_node, *cast_like_op_def, 0,
1946                                         &cast_src_type));
1947     DataType cast_dst_type;
1948     TF_RETURN_IF_ERROR(OutputTypeForNode(*cast_like_node, *cast_like_op_def, 0,
1949                                          &cast_dst_type));
1950     if (!IsFixedSizeType(cast_src_type) || !IsFixedSizeType(cast_dst_type)) {
1951       return Status::OK();
1952     } else if (producer_is_cast &&
1953                DataTypeSize(cast_dst_type) <= DataTypeSize(cast_src_type)) {
1954       return Status::OK();
1955     } else if (!producer_is_cast &&
1956                DataTypeSize(cast_dst_type) >= DataTypeSize(cast_src_type)) {
1957       return Status::OK();
1958     }
1959 
1960     // Check that nodes were not already optimized.
1961     const string optimized_producer_name = OptimizedNodeName(
1962         ParseNodeScopeAndName(producer->name()), DataTypeString(cast_dst_type));
1963     const string optimized_consumer_name = OptimizedNodeName(
1964         ParseNodeScopeAndName(consumer->name()), DataTypeString(cast_src_type));
1965     const bool is_already_optimized =
1966         ctx().node_map->NodeExists(optimized_consumer_name) ||
1967         ctx().node_map->NodeExists(optimized_producer_name);
1968     if (is_already_optimized) {
1969       return Status::OK();
1970     }
1971 
1972     // Add copies of consumer and producer in reverse order.
1973     NodeDef* input;
1974     TF_RETURN_IF_ERROR(GetInputNode(producer->input(0), &input));
1975     // Create new producer node.
1976     NodeDef* new_producer = AddCopyNode(optimized_consumer_name, consumer);
1977     new_producer->set_input(0, producer->input(0));
1978     ctx().node_map->AddOutput(input->name(), new_producer->name());
1979 
1980     // Create new consumer node.
1981     NodeDef* new_consumer = AddCopyNode(optimized_producer_name, producer);
1982     new_consumer->set_input(0, new_producer->name());
1983 
1984     NodeDef* new_value_preserving =
1985         producer_is_cast ? new_producer : new_consumer;
1986     const DataType new_input_type =
1987         producer_is_cast ? cast_src_type : cast_dst_type;
1988     // Update the input type of the value-preserving node. The input and
1989     // output types of the cast-like nodes remain the same.
1990     TF_RETURN_IF_ERROR(SetInputType(new_input_type, new_value_preserving));
1991     // Make sure there is a kernel registered for the value preserving op
1992     // with the new input type.
1993     TF_RETURN_IF_ERROR(IsKernelRegisteredForNode(*new_value_preserving));
1994     ctx().node_map->AddOutput(new_producer->name(), new_consumer->name());
1995 
1996     AddToOptimizationQueue(new_producer);
1997     *simplified_node_name = new_consumer->name();
1998 
1999     return Status::OK();
2000   }
2001 
2002  private:
2003   // Sets the type of the first input to dtype.
SetInputType(DataType dtype,NodeDef * node)2004   Status SetInputType(DataType dtype, NodeDef* node) {
2005     const OpDef* op_def = nullptr;
2006     TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node->op(), &op_def));
2007     const OpDef::ArgDef& input_arg = op_def->input_arg(0);
2008     const string& type_attr_name = input_arg.type_attr();
2009     if (type_attr_name.empty()) {
2010       if (input_arg.type() == DT_INVALID || input_arg.type() != dtype) {
2011         return errors::InvalidArgument("Could not set input type of ",
2012                                        node->op(), " op to ",
2013                                        DataTypeString(dtype));
2014       } else {
2015         // Op has fixed input type that already matches dtype.
2016         return Status::OK();
2017       }
2018     }
2019     SetDataTypeToAttr(dtype, type_attr_name, node);
2020     return Status::OK();
2021   }
2022   // This optimization can be dangerous on devices other than CPU and
2023   // GPU. The transpose might not be implemented for image.type, or
2024   // might be slower with image.type than with cast_dst_type.
NodeIsOnCpuOrGpu(const NodeDef * node) const2025   bool NodeIsOnCpuOrGpu(const NodeDef* node) const {
2026     using str_util::StrContains;
2027 
2028     string task;
2029     string device;
2030 
2031     return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
2032            (StrContains(device, DEVICE_CPU) || StrContains(device, DEVICE_GPU));
2033   }
2034 
IsFixedSizeType(DataType dtype)2035   bool IsFixedSizeType(DataType dtype) {
2036     return dtype != DT_STRING && dtype != DT_VARIANT && dtype != DT_RESOURCE &&
2037            !kQuantizedTypes.Contains(dtype);
2038   }
2039 };
2040 
2041 // Fold a multiply of a scalar into the following convolution. This folding
2042 // can jump across nodes that merely reorders data (such as reshape and
2043 // transpose). For example, we can optimize
2044 //
2045 //
2046 //         Conv2D                             Conv2D
2047 //        /      \                           /      \
2048 //    Transpose  weights*       ->     Transpose    Mul
2049 //       |                                |        /   \
2050 //      Mul                               |    weights  scale
2051 //     /   \                              |
2052 //   input  scale**                     input
2053 //
2054 //  *) weights must be a const
2055 // **) scale must be a const scalar
2056 //
2057 // When `weights` and `scale` are constant, `Mul` in the optimized graph can be
2058 // constant-folded, also weights tend to be smaller than the activations.
2059 //
2060 // TODO(jingyue): Fold scalar multiplies to Conv?DBackpropFilter and
2061 // Conv?DBackpropInput.
2062 class FoldMultiplyIntoConv : public ArithmeticOptimizerStage {
2063  public:
FoldMultiplyIntoConv(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)2064   explicit FoldMultiplyIntoConv(const GraphOptimizerContext& ctx,
2065                                 const ArithmeticOptimizerContext& ctx_ext)
2066       : ArithmeticOptimizerStage("FoldMultiplyIntoConv", ctx, ctx_ext) {}
2067   ~FoldMultiplyIntoConv() override = default;
2068 
IsSupported(const NodeDef * node) const2069   bool IsSupported(const NodeDef* node) const override {
2070     return IsConv2D(*node) || IsConv3D(*node);
2071   }
2072 
TrySimplify(NodeDef * node,string * simplified_node_name)2073   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2074 #define TF_RETURN_IF_TRUE(...) \
2075   if ((__VA_ARGS__)) return Status::OK()
2076 
2077     NodeDef* conv = node;
2078 
2079     NodeDef* weights;
2080     TF_RETURN_IF_ERROR(GetInputNode(conv->input(1), &weights));
2081 
2082     // Fold the multiply to conv only when the weights are constant, so the
2083     // multiply can be constant-folded.
2084     //
2085     // TODO(jingyue): When the weights aren't constant, this should also help
2086     // performance a bit and memory usage a lot, since the weights tend to be
2087     // smaller than the activations.
2088     TF_RETURN_IF_TRUE(!IsConstant(*weights));
2089 
2090     // Verify that this node was not already optimized.
2091     const string scaled_weights_node_name =
2092         OptimizedNodeName(ParseNodeScopeAndName(weights->name()),
2093                           strings::StrCat("scaled", "_", conv->name()));
2094 
2095     TF_RETURN_IF_TRUE(ctx().node_map->NodeExists(scaled_weights_node_name));
2096 
2097     // Find the tail of value preserving chain entering the Conv node.
2098     NodeDef* tail = GetTailOfValuePreservingChain(*conv, *ctx().node_map,
2099                                                   *ctx().nodes_to_preserve);
2100 
2101     NodeDef* source;
2102     TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &source));
2103 
2104     // Check that value preserving chain is the only consumer of the Mul output.
2105     TF_RETURN_IF_TRUE(!IsMul(*source));
2106     TF_RETURN_IF_TRUE(NumNonControlOutputs(*source, *ctx().node_map) != 1);
2107 
2108     const NodeDef* mul = source;
2109 
2110     // TODO(jingyue): handle the case where `scale` is 0-th operand.
2111     NodeDef* scale;  // scalar multiplier fot the input tensor
2112     NodeDef* input;
2113     TF_RETURN_IF_ERROR(GetInputNode(mul->input(1), &scale));
2114     TF_RETURN_IF_ERROR(GetInputNode(mul->input(0), &input));
2115 
2116     // Check that 'scale * weight' can be const folded.
2117     TF_RETURN_IF_TRUE(!IsConstant(*scale));
2118     TF_RETURN_IF_ERROR(CheckAttrsExist(*scale, {"dtype", "value"}));
2119     TF_RETURN_IF_ERROR(CheckAttrExists(*weights, "dtype"));
2120     TF_RETURN_IF_TRUE(scale->attr().at("dtype").type() !=
2121                       weights->attr().at("dtype").type());
2122 
2123     // Check that `scale` is a scalar.
2124     const TensorProto& scale_tensor = scale->attr().at("value").tensor();
2125     bool scale_is_a_scalar = scale_tensor.has_tensor_shape() &&
2126                              scale_tensor.tensor_shape().dim_size() == 0;
2127     TF_RETURN_IF_TRUE(!scale_is_a_scalar);
2128 
2129     // At this point all preconditions are met, and we safely do the rewrite.
2130     VLOG(3) << "Fold multiply into conv: conv=" << conv->name()
2131             << " mul=" << mul->name() << " weights=" << weights->name();
2132 
2133     // Create new node `scaled_weights`.
2134     NodeDef* scaled_weights = AddEmptyNode(scaled_weights_node_name);
2135     scaled_weights->set_op("Mul");
2136     scaled_weights->set_device(weights->device());
2137     (*scaled_weights->mutable_attr())["T"] = weights->attr().at("dtype");
2138     AddToOptimizationQueue(scaled_weights);
2139 
2140     // Link in its inputs.
2141     scaled_weights->add_input(conv->input(1));
2142     ctx().node_map->AddOutput(weights->name(), scaled_weights->name());
2143     scaled_weights->add_input(mul->input(1));
2144     ctx().node_map->AddOutput(scale->name(), scaled_weights->name());
2145     ForwardControlDependencies(scaled_weights, {source});
2146 
2147     // Update `conv`'s weights to `scaled_weights`.
2148     conv->set_input(1, scaled_weights->name());
2149     ctx().node_map->UpdateInput(conv->name(), weights->name(),
2150                                 scaled_weights->name());
2151     AddToOptimizationQueue(conv);
2152 
2153     // Update `tail` node to bypass `mul` because it's folded to the weights.
2154     tail->set_input(0, mul->input(0));
2155     ctx().node_map->UpdateInput(tail->name(), mul->name(), input->name());
2156     AddToOptimizationQueue(tail);
2157     *simplified_node_name = conv->name();
2158 
2159     return Status::OK();
2160 #undef TF_RETURN_IF_TRUE
2161   }
2162 };
2163 
2164 // Fold Transpose into matrix multiplication.
2165 class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage {
2166  public:
FoldTransposeIntoMatMul(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)2167   explicit FoldTransposeIntoMatMul(const GraphOptimizerContext& ctx,
2168                                    const ArithmeticOptimizerContext& ctx_ext)
2169       : ArithmeticOptimizerStage("FoldTransposeIntoMatMul", ctx, ctx_ext) {}
2170   ~FoldTransposeIntoMatMul() override = default;
2171 
IsSupported(const NodeDef * node) const2172   bool IsSupported(const NodeDef* node) const override {
2173     return IsMatMul(*node);
2174   }
2175 
TrySimplify(NodeDef * node,string * simplified_node_name)2176   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2177     const NodeScopeAndName matmul = ParseNodeScopeAndName(node->name());
2178     const string optimized_node_name = OptimizedNodeName(matmul);
2179     if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();
2180 
2181     NodeDef* a;
2182     NodeDef* b;
2183     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &a));
2184     TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &b));
2185 
2186     bool is_complex = false;
2187     if (node->op() != "SparseMatMul") {
2188       const DataType type = GetDataTypeFromAttr(*node, "T");
2189       is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
2190     }
2191 
2192     const std::set<string> foldable_transpose_ops =
2193         !is_complex ? std::set<string>{"ConjugateTranspose", "Transpose"}
2194                     : (node->op() == "BatchMatMul"
2195                            ? std::set<string>{"ConjugateTranspose"}
2196                            : std::set<string>{"Transpose"});
2197 
2198     const bool a_is_foldable = foldable_transpose_ops.count(a->op()) > 0 &&
2199                                IsInnerMatrixTransposeNode(*a, ctx().node_map);
2200     const bool b_is_foldable = foldable_transpose_ops.count(b->op()) > 0 &&
2201                                IsInnerMatrixTransposeNode(*b, ctx().node_map);
2202     if (!a_is_foldable && !b_is_foldable) return Status::OK();
2203 
2204     NodeDef* new_op = AddCopyNode(optimized_node_name, node);
2205 
2206     if (a_is_foldable) {
2207       const string attr_a =
2208           node->op() == "BatchMatMul" ? "adj_x" : "transpose_a";
2209       FlipBooleanAttr(attr_a, new_op);
2210       new_op->set_input(0, a->input(0));
2211       ctx().node_map->UpdateInput(new_op->name(), a->name(), a->input(0));
2212     }
2213 
2214     if (b_is_foldable) {
2215       const string attr_b =
2216           node->op() == "BatchMatMul" ? "adj_y" : "transpose_b";
2217       FlipBooleanAttr(attr_b, new_op);
2218       new_op->set_input(1, b->input(0));
2219       ctx().node_map->UpdateInput(new_op->name(), b->name(), b->input(0));
2220     }
2221 
2222     std::vector<const NodeDef*> deps_to_forward = {node};
2223     if (a_is_foldable) deps_to_forward.push_back(a);
2224     if (b_is_foldable) deps_to_forward.push_back(b);
2225     ForwardControlDependencies(new_op, deps_to_forward);
2226 
2227     return Status::OK();
2228   }
2229 
2230  private:
FlipBooleanAttr(const string & attr_name,NodeDef * node)2231   void FlipBooleanAttr(const string& attr_name, NodeDef* node) {
2232     const bool old_value =
2233         !node->attr().count(attr_name) ? false : node->attr().at(attr_name).b();
2234     (*node->mutable_attr())[attr_name].set_b(!old_value);
2235   }
2236 
2237   template <typename T>
IsInnerMatrixTranspose(const std::vector<T> & perm)2238   bool IsInnerMatrixTranspose(const std::vector<T>& perm) {
2239     const T n = perm.size();
2240     if (n < 2) {
2241       return false;
2242     }
2243     for (T i = 0; i < n - 2; ++i) {
2244       if (perm[i] != i) {
2245         return false;
2246       }
2247     }
2248     return perm[n - 1] == n - 2 && perm[n - 2] == n - 1;
2249   }
2250 
IsInnerMatrixTransposeNode(const NodeDef & transpose_node,const NodeMap * node_map)2251   bool IsInnerMatrixTransposeNode(const NodeDef& transpose_node,
2252                                   const NodeMap* node_map) {
2253     if (transpose_node.op() != "Transpose" &&
2254         transpose_node.op() != "ConjugateTranspose") {
2255       return false;
2256     }
2257     const NodeDef* perm_node = node_map->GetNode(transpose_node.input(1));
2258     std::vector<int> perm32;
2259     if (ValuesFromConstNode(*perm_node, &perm32)) {
2260       return IsInnerMatrixTranspose(perm32);
2261     }
2262     std::vector<int64> perm64;
2263     if (ValuesFromConstNode(*perm_node, &perm64)) {
2264       return IsInnerMatrixTranspose(perm64);
2265     }
2266     return false;
2267   }
2268 };
2269 
// Fold Conj into an adjacent Transpose or ConjugateTranspose.
2271 class FoldConjugateIntoTranspose : public ArithmeticOptimizerStage {
2272  public:
FoldConjugateIntoTranspose(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)2273   explicit FoldConjugateIntoTranspose(const GraphOptimizerContext& ctx,
2274                                       const ArithmeticOptimizerContext& ctx_ext)
2275       : ArithmeticOptimizerStage("FoldConjugateIntoTranspose", ctx, ctx_ext) {}
2276   ~FoldConjugateIntoTranspose() override = default;
2277 
IsSupported(const NodeDef * node) const2278   bool IsSupported(const NodeDef* node) const override {
2279     return IsConj(*node) || IsTranspose(*node) || IsConjugateTranspose(*node);
2280   }
2281 
TrySimplify(NodeDef * node,string * simplified_node_name)2282   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2283     const NodeScopeAndName matmul = ParseNodeScopeAndName(node->name());
2284     const string optimized_node_name = OptimizedNodeName(matmul);
2285     if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();
2286 
2287     NodeDef* input;
2288     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
2289 
2290     const NodeDef* transpose_op = node->op() == "Conj" ? input : node;
2291     const NodeDef* conj_op = node->op() == "Conj" ? node : input;
2292 
2293     if ((IsTranspose(*transpose_op) || IsConjugateTranspose(*transpose_op)) &&
2294         IsConj(*conj_op)) {
2295       NodeDef* new_op = AddCopyNode(optimized_node_name, transpose_op);
2296 
2297       // Flip the type of transpose op to absorb the conjugation.
2298       new_op->set_op(transpose_op->op() == "Transpose" ? "ConjugateTranspose"
2299                                                        : "Transpose");
2300       new_op->set_input(0, input->input(0));
2301       ctx().node_map->UpdateInput(new_op->name(), node->name(),
2302                                   input->input(0));
2303       ForwardControlDependencies(new_op, {node, input});
2304       *simplified_node_name = new_op->name();
2305     }
2306 
2307     return Status::OK();
2308   }
2309 };
2310 
2311 // Replace Mul node with identical inputs with a Square.
2312 class ReplaceMulWithSquare : public ArithmeticOptimizerStage {
2313  public:
ReplaceMulWithSquare(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)2314   explicit ReplaceMulWithSquare(const GraphOptimizerContext& ctx,
2315                                 const ArithmeticOptimizerContext& ctx_ext)
2316       : ArithmeticOptimizerStage("ReplaceMulWithSquare", ctx, ctx_ext) {}
2317   ~ReplaceMulWithSquare() override = default;
2318 
IsSupported(const NodeDef * node) const2319   bool IsSupported(const NodeDef* node) const override {
2320     return IsMul(*node) && node->input(0) == node->input(1);
2321   }
2322 
TrySimplify(NodeDef * node,string * simplified_node_name)2323   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2324     const NodeScopeAndName mul = ParseNodeScopeAndName(node->name());
2325     const string optimized_node_name = OptimizedNodeName(mul);
2326     if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();
2327 
2328     const DataType type = GetDataTypeFromAttr(*node, "T");
2329     bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
2330 
2331     string task;
2332     string device;
2333     bool is_on_cpu =
2334         DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
2335         str_util::StrContains(device, DEVICE_CPU);
2336 
2337     if (!is_complex || is_on_cpu) {
2338       NodeDef* new_square_node = AddCopyNode(optimized_node_name, node);
2339       new_square_node->set_op("Square");
2340       for (int i = 1; i < new_square_node->input_size(); ++i) {
2341         new_square_node->set_input(i - 1, new_square_node->input(i));
2342       }
2343       new_square_node->mutable_input()->RemoveLast();
2344       for (const string& input : new_square_node->input()) {
2345         ctx().node_map->AddOutput(NodeName(input), new_square_node->name());
2346       }
2347       *simplified_node_name = new_square_node->name();
2348     }
2349 
2350     return Status::OK();
2351   }
2352 };
2353 
2354 // Simplify aggregation (e.g. AddN) nodes:
2355 //
2356 // 1. Discard aggregate nodes with a single input and no control dependencies.
2357 //
2358 // 2. Try to rewrite aggregations of N >= 2 identical terms (possibly due to
2359 //    deduping or other rewrites) so we can get rid of the sum entirely.
2360 //
2361 //    The expression (using AddN as an example of an aggregate op):
2362 //      AddN(x, x, x, ... ,x)
2363 //           <-- N terms -->
2364 //    can be rewritten to:
//      Mul(Const(N), x)
2366 //
2367 class SimplifyAggregation : public ArithmeticOptimizerStage {
2368  public:
SimplifyAggregation(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)2369   explicit SimplifyAggregation(const GraphOptimizerContext& ctx,
2370                                const ArithmeticOptimizerContext& ctx_ext)
2371       : ArithmeticOptimizerStage("SimplifyAggregation", ctx, ctx_ext) {}
2372   ~SimplifyAggregation() override = default;
2373 
IsSupported(const NodeDef * node) const2374   bool IsSupported(const NodeDef* node) const override {
2375     return IsAggregate(*node) && NumNonControlInputs(*node) > 0 &&
2376            GetDataTypeFromAttr(*node, "T") !=
2377                DT_VARIANT;  // TODO(b/119787146): Enable for variants.
2378   }
2379 
TrySimplify(NodeDef * node,string * simplified_node_name)2380   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2381     // 1. Discard aggregate nodes with a single input and no control deps.
2382     if (node->input_size() == 1) {
2383       *simplified_node_name = node->input(0);
2384       return Status::OK();
2385     }
2386 
2387     // 2. Rewrite aggregations of N >= 2 identical terms.
2388 
2389     // All non-control inputs must be identical.
2390     bool all_equal = true;
2391     int num_inputs = 1;
2392     for (int i = 1; i < node->input_size(); ++i) {
2393       if (IsControlInput(node->input(i))) break;
2394       ++num_inputs;
2395       if (node->input(i) != node->input(0)) {
2396         all_equal = false;
2397         break;
2398       }
2399     }
2400     if (!all_equal) return Status::OK();
2401 
2402     // And node should not be optimized earlier.
2403     const NodeScopeAndName node_scope_and_name =
2404         ParseNodeScopeAndName(node->name());
2405     const string optimized_const_name =
2406         OptimizedNodeName(node_scope_and_name, "Const");
2407     const string optimized_mul_name =
2408         OptimizedNodeName(node_scope_and_name, "Mul");
2409 
2410     bool is_already_optimized =
2411         ctx().node_map->NodeExists(optimized_const_name) ||
2412         ctx().node_map->NodeExists(optimized_mul_name);
2413 
2414     if (is_already_optimized) return Status::OK();
2415 
2416     // At this point all preconditions are met, and we safely do the rewrite.
2417     VLOG(3) << "Simplify aggregation with identical inputs: node="
2418             << node->name() << " num_inputs=" << num_inputs;
2419 
2420     // 1. Create constant node with value N.
2421     const auto type = GetDataTypeFromAttr(*node, "T");
2422     Tensor t(type, TensorShape({}));
2423     Status status = SetTensorValue(type, num_inputs, &t);
2424     if (!status.ok()) {
2425       return errors::Internal("Failed to create const node: ",
2426                               status.error_message());
2427     }
2428 
2429     TensorValue value(&t);
2430     NodeDef* new_const_node = AddEmptyNode(optimized_const_name);
2431     status = ConstantFolding::CreateNodeDef(new_const_node->name(), value,
2432                                             new_const_node);
2433     if (!status.ok()) {
2434       return errors::Internal("Failed to create const node: ",
2435                               status.error_message());
2436     }
2437     new_const_node->set_device(node->device());
2438     MaybeAddControlInput(NodeName(node->input(0)), new_const_node,
2439                          ctx().optimized_graph, ctx().node_map);
2440     AddToOptimizationQueue(new_const_node);
2441 
2442     // 2. Replace the aggregate node with Mul(Const(N), x).
2443     NodeDef* new_mul_node = AddEmptyNode(optimized_mul_name);
2444     new_mul_node->set_op("Mul");
2445     new_mul_node->set_device(node->device());
2446     SetDataTypeToAttr(type, "T", new_mul_node);
2447     new_mul_node->add_input(new_const_node->name());
2448     ctx().node_map->AddOutput(new_const_node->name(), new_mul_node->name());
2449     new_mul_node->add_input(node->input(0));
2450     ctx().node_map->AddOutput(node->input(0), new_mul_node->name());
2451 
2452     ForwardControlDependencies(new_mul_node, {node});
2453     *simplified_node_name = new_mul_node->name();
2454 
2455     return Status::OK();
2456   }
2457 };
2458 
2459 class ConvertPowStage : public ArithmeticOptimizerStage {
2460  public:
ConvertPowStage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)2461   explicit ConvertPowStage(const GraphOptimizerContext& ctx,
2462                            const ArithmeticOptimizerContext& ctx_ext)
2463       : ArithmeticOptimizerStage("ConvertPow", ctx, ctx_ext) {}
2464 
IsSupported(const NodeDef * node) const2465   bool IsSupported(const NodeDef* node) const override {
2466     return IsPow(*node) &&
2467            ctx().graph_properties->GetInputProperties(node->name()).size() == 2;
2468   }
2469 
TrySimplify(NodeDef * node,string * simplified_node_name)2470   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2471     const auto& pow_props =
2472         ctx().graph_properties->GetInputProperties(node->name())[1];
2473     PartialTensorShape shape(pow_props.shape());
2474     if (!shape.IsFullyDefined()) {
2475       // skip if p is not fully defined.
2476       return Status::OK();
2477     }
2478     if (TensorShape::IsValid(pow_props.shape()) && pow_props.has_value()) {
2479       Tensor pow(pow_props.dtype(), pow_props.shape());
2480       if (!pow.FromProto(pow_props.value())) {
2481         return errors::InvalidArgument("Cannot parse tensor from proto: ",
2482                                        pow_props.value().DebugString());
2483       }
2484 
2485       complex128 prev, curr;
2486       for (int i = 0; i < pow.NumElements(); ++i) {
2487         if (!GetElementUnexhaustive(pow, i, {pow_props.dtype()}, &curr)) {
2488           // input data type is not supported by Pow. Skip.
2489           return Status::OK();
2490         }
2491         if (i != 0 && curr != prev) {
2492           // pow has different values on different elements. Skip.
2493           return Status::OK();
2494         }
2495         prev = curr;
2496       }
2497       NodeDef *x, *y;
2498       TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
2499       TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
2500       const auto& value_props =
2501           ctx().graph_properties->GetInputProperties(node->name())[0];
2502       const TensorShapeProto& output_shape =
2503           ctx().graph_properties->GetOutputProperties(node->name())[0].shape();
2504       if (curr == complex128(2, 0)) {
2505         node->set_op("Square");
2506         node->set_input(1, AsControlDependency(y->name()));
2507         AddToOptimizationQueue(node);
2508         AddToOptimizationQueue(y);
2509       } else if (curr == complex128(1, 0) &&
2510                  ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
2511         // Pow could be used to broadcast, so make sure the shapes of the two
2512         // arguments are identical before replacing Pow with Identity.
2513         node->set_op("Identity");
2514         node->set_input(1, AsControlDependency(y->name()));
2515         AddToOptimizationQueue(node);
2516         AddToOptimizationQueue(y);
2517       } else if (curr == complex128(0.5, 0)) {
2518         node->set_op("Sqrt");
2519         node->set_input(1, AsControlDependency(y->name()));
2520         AddToOptimizationQueue(node);
2521         AddToOptimizationQueue(y);
2522       } else if (curr == complex128(0, 0) &&
2523                  ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
2524         PartialTensorShape shape(value_props.shape());
2525         if (!shape.IsFullyDefined()) {
2526           // skip if b is not fully defined.
2527           return Status::OK();
2528         }
2529         if (TensorShape::IsValid(value_props.shape()) &&
2530             value_props.has_value()) {
2531           Tensor base(value_props.dtype(), value_props.shape());
2532           if (!base.FromProto(value_props.value())) {
2533             return errors::InvalidArgument("Cannot parse tensor from proto: ",
2534                                            value_props.value().DebugString());
2535           }
2536           node->set_op("Const");
2537           Tensor c(base.dtype(), base.shape());
2538           for (int i = 0; i < c.NumElements(); ++i) {
2539             TF_RETURN_IF_ERROR(SetElementToOne(i, &c));
2540           }
2541           (*node->mutable_attr())["dtype"].set_type(base.dtype());
2542           c.AsProtoTensorContent(
2543               (*node->mutable_attr())["value"].mutable_tensor());
2544           node->mutable_attr()->erase("T");
2545           node->set_input(0, AsControlDependency(x->name()));
2546           node->set_input(1, AsControlDependency(y->name()));
2547           AddToOptimizationQueue(node);
2548           AddToOptimizationQueue(x);
2549           AddToOptimizationQueue(y);
2550         }
2551       } else if (curr == complex128(-0.5, 0)) {
2552         node->set_op("Rsqrt");
2553         node->set_input(1, AsControlDependency(y->name()));
2554         AddToOptimizationQueue(node);
2555         AddToOptimizationQueue(y);
2556       } else if (curr == complex128(-1, 0)) {
2557         node->set_op("Reciprocal");
2558         node->set_input(1, AsControlDependency(y->name()));
2559         AddToOptimizationQueue(node);
2560         AddToOptimizationQueue(y);
2561       }
2562     }
2563     return Status::OK();
2564   }
2565 
2566  private:
SetElementToOne(int i,Tensor * t)2567   Status SetElementToOne(int i, Tensor* t) {
2568     switch (t->dtype()) {
2569       case DT_INT32:
2570         t->flat<int32>()(i) = 1;
2571         return Status::OK();
2572       case DT_INT64:
2573         t->flat<int64>()(i) = 1L;
2574         return Status::OK();
2575       case DT_FLOAT:
2576         t->flat<float>()(i) = 1.0f;
2577         return Status::OK();
2578       case DT_DOUBLE:
2579         t->flat<double>()(i) = 1.0;
2580         return Status::OK();
2581       case DT_COMPLEX64:
2582         t->flat<complex64>()(i) = complex64(1);
2583         return Status::OK();
2584       case DT_COMPLEX128:
2585         t->flat<complex128>()(i) = complex128(1);
2586         return Status::OK();
2587       default:
2588         return errors::InvalidArgument("Invalid data type: ", t->dtype());
2589     }
2590   }
2591 };
2592 
2593 class ConvertLog1pStage : public ArithmeticOptimizerStage {
2594  public:
ConvertLog1pStage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)2595   explicit ConvertLog1pStage(const GraphOptimizerContext& ctx,
2596                              const ArithmeticOptimizerContext& ctx_ext)
2597       : ArithmeticOptimizerStage("ConvertLog1p", ctx, ctx_ext) {}
2598   ~ConvertLog1pStage() override = default;
2599 
IsSupported(const NodeDef * node) const2600   bool IsSupported(const NodeDef* node) const override { return IsLog(*node); }
2601 
TrySimplify(NodeDef * node,string * simplified_node_name)2602   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2603     NodeDef* input;
2604     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
2605     if (!IsAdd(*input)) {
2606       return Status::OK();
2607     }
2608 
2609     if (ctx().graph_properties->GetInputProperties(input->name()).size() < 2) {
2610       return Status::OK();
2611     }
2612 
2613     bool modified = false;
2614     TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 0, 1, &modified));
2615     if (!modified) {
2616       TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 1, 0, &modified));
2617     }
2618     if (modified) {
2619       *simplified_node_name = node->name();
2620     }
2621     return Status::OK();
2622   }
2623 
2624  private:
TrySimplifyInternal(NodeDef * node,NodeDef * input,int i,int j,bool * modified)2625   Status TrySimplifyInternal(NodeDef* node, NodeDef* input, int i, int j,
2626                              bool* modified) {
2627     const auto& t =
2628         ctx().graph_properties->GetInputProperties(input->name())[i];
2629     const auto& c =
2630         ctx().graph_properties->GetInputProperties(input->name())[j];
2631     for (int k = 0; k < c.shape().dim_size(); ++k) {
2632       // Skip if c shape is not fully determined.
2633       if (c.shape().dim(k).size() < 0) {
2634         return Status::OK();
2635       }
2636     }
2637     TensorShapeProto broadcast_shape;
2638     if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
2639       return Status::OK();
2640     }
2641     if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
2642       // skip if the non-constant tensor doesn't have the same shape after
2643       // broadcast.
2644       return Status::OK();
2645     }
2646     if (TensorShape::IsValid(c.shape()) && c.has_value()) {
2647       Tensor constant(c.dtype(), c.shape());
2648       if (!constant.FromProto(c.value())) {
2649         return errors::InvalidArgument("Cannot parse tensor from proto: ",
2650                                        c.value().DebugString());
2651       }
2652       complex128 element;
2653       for (int k = 0; k < constant.NumElements(); ++k) {
2654         if (!GetElementUnexhaustive(constant, k,
2655                                     {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
2656                                      DT_COMPLEX64, DT_COMPLEX128},
2657                                     &element)) {
2658           // input data type is not supported by log1p. Skip.
2659           return Status::OK();
2660         }
2661         if (element != complex128(1)) {
2662           // current element is not 1. Skip.
2663           return Status::OK();
2664         }
2665       }
2666       NodeDef *x, *y;
2667       TF_RETURN_IF_ERROR(GetInputNode(input->input(i), &x));
2668       TF_RETURN_IF_ERROR(GetInputNode(input->input(j), &y));
2669       node->set_op("Log1p");
2670       node->set_input(0, input->input(i));
2671       node->add_input(AsControlDependency(y->name()));
2672       ForwardControlDependencies(node, {input});
2673 
2674       AddToOptimizationQueue(node);
2675       AddToOptimizationQueue(input);
2676       AddToOptimizationQueue(x);
2677       AddToOptimizationQueue(y);
2678       *modified = true;
2679     }
2680     return Status::OK();
2681   }
2682 };
2683 
2684 class ConvertExpm1Stage : public ArithmeticOptimizerStage {
2685  public:
ConvertExpm1Stage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)2686   explicit ConvertExpm1Stage(const GraphOptimizerContext& ctx,
2687                              const ArithmeticOptimizerContext& ctx_ext)
2688       : ArithmeticOptimizerStage("ConvertExpm1", ctx, ctx_ext) {}
2689   ~ConvertExpm1Stage() override = default;
2690 
IsSupported(const NodeDef * node) const2691   bool IsSupported(const NodeDef* node) const override {
2692     if (!IsSub(*node)) return false;
2693 
2694     NodeDef* input;
2695     if (!GetInputNode(node->input(0), &input).ok()) return false;
2696 
2697     return IsExp(*input);
2698   }
2699 
TrySimplify(NodeDef * node,string * simplified_node_name)2700   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2701     if (ctx().graph_properties->GetInputProperties(node->name()).size() < 2) {
2702       return Status::OK();
2703     }
2704 
2705     NodeDef* exp;
2706     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &exp));
2707     if (!IsExp(*exp)) {
2708       return Status::OK();
2709     }
2710 
2711     if (ctx().graph_properties->GetInputProperties(exp->name()).empty()) {
2712       return Status::OK();
2713     }
2714 
2715     const auto& t = ctx().graph_properties->GetInputProperties(exp->name())[0];
2716     const auto& c = ctx().graph_properties->GetInputProperties(node->name())[1];
2717     for (int k = 0; k < c.shape().dim_size(); ++k) {
2718       // Skip if c shape is not fully determined.
2719       if (c.shape().dim(k).size() < 0) {
2720         return Status::OK();
2721       }
2722     }
2723     TensorShapeProto broadcast_shape;
2724     if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
2725       return Status::OK();
2726     }
2727     if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
2728       // skip if the non-constant tensor doesn't have the same shape after
2729       // broadcast.
2730       return Status::OK();
2731     }
2732     if (TensorShape::IsValid(c.shape()) && c.has_value()) {
2733       Tensor constant(c.dtype(), c.shape());
2734       if (!constant.FromProto(c.value())) {
2735         return errors::InvalidArgument("Cannot parse tensor from proto: ",
2736                                        c.value().DebugString());
2737       }
2738       complex128 element;
2739       for (int k = 0; k < constant.NumElements(); ++k) {
2740         if (!GetElementUnexhaustive(constant, k,
2741                                     {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
2742                                      DT_COMPLEX64, DT_COMPLEX128},
2743                                     &element)) {
2744           // input data type is not supported by expm1. Skip.
2745           return Status::OK();
2746         }
2747         if (element != complex128(1)) {
2748           // current element is not 1. Skip.
2749           return Status::OK();
2750         }
2751       }
2752       NodeDef *exp_input, *ones;
2753       TF_RETURN_IF_ERROR(GetInputNode(exp->input(0), &exp_input));
2754       TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &ones));
2755       node->set_op("Expm1");
2756       node->set_input(0, exp->input(0));
2757       node->set_input(1, AsControlDependency(ones->name()));
2758       ForwardControlDependencies(node, {exp});
2759 
2760       AddToOptimizationQueue(node);
2761       AddToOptimizationQueue(exp);
2762       AddToOptimizationQueue(exp_input);
2763       AddToOptimizationQueue(ones);
2764     }
2765     return Status::OK();
2766   }
2767 };
2768 
2769 // Performs conversions like:
2770 // Max(Sqrt(x)) => Sqrt(Max(x))
2771 // Checks for a max/min reduction over element-wise monotonic functions, such
2772 // as Sqrt, Sigmoid, Tanh, etc.
2773 class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
2774  public:
OptimizeMaxOrMinOfMonotonicStage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)2775   explicit OptimizeMaxOrMinOfMonotonicStage(
2776       const GraphOptimizerContext& ctx,
2777       const ArithmeticOptimizerContext& ctx_ext)
2778       : ArithmeticOptimizerStage("OptimizeMaxOrMinOfMonotonicStage", ctx,
2779                                  ctx_ext) {}
2780   ~OptimizeMaxOrMinOfMonotonicStage() override = default;
2781 
IsSupported(const NodeDef * node) const2782   bool IsSupported(const NodeDef* node) const override {
2783     return IsAnyMax(*node) || IsAnyMin(*node) || IsAnyMaxPool(*node) ||
2784            IsArgMax(*node) || IsArgMin(*node);
2785   }
2786 
TrySimplify(NodeDef * reduction_node,string * simplified_node_name)2787   Status TrySimplify(NodeDef* reduction_node,
2788                      string* simplified_node_name) override {
2789     if (IsInPreserveSet(*reduction_node)) {
2790       return Status::OK();
2791     }
2792     NodeDef* inner_function;
2793     TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function));
2794     // Optimize only if:
2795     // 0. inner_function is not in the preserve set,
2796     // 1. inner_function's Op is element-wise monotonic
2797     // 2. inner_function's output is not being consumed elsewhere.
2798     // 3. is monotonic increasing if reduction_node is a pooling operation
2799     //    since we don't have MinPool operations.
2800     bool is_non_decreasing = false;
2801     if (!IsInPreserveSet(*inner_function) &&
2802         IsElementWiseMonotonic(*inner_function, &is_non_decreasing) &&
2803         ctx().node_map->GetOutputs(inner_function->name()).size() == 1 &&
2804         (is_non_decreasing || !IsAnyMaxPool(*reduction_node))) {
2805       // Swap the first inputs of the inner function Op & the reduction Op.
2806       NodeDef* inner_input;
2807       TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input));
2808       reduction_node->set_input(0, inner_input->name());
2809       ctx().node_map->UpdateInput(reduction_node->name(),
2810                                   inner_function->name(), inner_input->name());
2811       inner_function->set_input(0, reduction_node->name());
2812       UpdateConsumers(reduction_node, inner_function->name());
2813       ctx().node_map->UpdateInput(inner_function->name(), inner_input->name(),
2814                                   reduction_node->name());
2815       if (!is_non_decreasing) {
2816         // Flip Min<->Max if the function is non-increasing, e.g.
2817         // Max(Neg(x)) = Neg(Min(x)).
2818         const string opposite = FlipMinMax(*reduction_node);
2819         reduction_node->set_op(opposite);
2820       }
2821 
2822       if (IsArgMax(*reduction_node) || IsArgMin(*reduction_node)) {
2823         // ArgMax(Sqrt(x)) = ArgMax(x)
2824         inner_function->set_op("Identity");
2825       }
2826 
2827       AddToOptimizationQueue(reduction_node);
2828       AddToOptimizationQueue(inner_function);
2829       AddToOptimizationQueue(inner_input);
2830     }
2831     return Status::OK();
2832   }
2833 
UpdateConsumers(NodeDef * node,const string & new_input)2834   void UpdateConsumers(NodeDef* node, const string& new_input) {
2835     const string& node_name = node->name();
2836     const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
2837     for (NodeDef* consumer : consumers) {
2838       for (int i = 0; i < consumer->input_size(); ++i) {
2839         if (consumer->input(i) == node_name && consumer->name() != new_input) {
2840           consumer->set_input(i, new_input);
2841           ctx().node_map->UpdateInput(consumer->name(), node_name, new_input);
2842         }
2843       }
2844       AddToOptimizationQueue(consumer);
2845     }
2846   }
2847 
2848  private:
FlipMinMax(const NodeDef & node)2849   string FlipMinMax(const NodeDef& node) {
2850     const string& op = node.op();
2851     if (IsAnyMax(node) || IsArgMax(node)) {
2852       return str_util::StringReplace(op, "Max", "Min", false);
2853     } else {
2854       return str_util::StringReplace(op, "Min", "Max", false);
2855     }
2856   }
2857 };
2858 
2859 // Replace a chain of type&shape preserving unary ops with a
2860 // '_UnaryOpsComposition' node.
2861 // TODO(ezhulenev): It should be a part of remapper optimizer because it doesn't
2862 // have to do much with arithmetic (together with FoldMultiplyIntoConv stage?).
2863 class UnaryOpsComposition : public ArithmeticOptimizerStage {
2864  public:
UnaryOpsComposition(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)2865   explicit UnaryOpsComposition(const GraphOptimizerContext& ctx,
2866                                const ArithmeticOptimizerContext& ctx_ext)
2867       : ArithmeticOptimizerStage("UnaryOpsComposition", ctx, ctx_ext) {
2868     // WARN: This should be consistent with unary_ops_composition.cc.
2869     // clang-format off
2870     supported_ops_ = {// Ops defined via Eigen scalar ops.
2871                       {"Abs",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2872                       {"Acos",       {DT_FLOAT,          DT_DOUBLE}},
2873                       {"Acosh",      {DT_FLOAT,          DT_DOUBLE}},
2874                       {"Asin",       {DT_FLOAT,          DT_DOUBLE}},
2875                       {"Asinh",      {DT_FLOAT,          DT_DOUBLE}},
2876                       {"Atan",       {DT_FLOAT,          DT_DOUBLE}},
2877                       {"Atanh",      {DT_FLOAT,          DT_DOUBLE}},
2878                       {"Ceil",       {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2879                       {"Cos",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2880                       {"Cosh",       {DT_FLOAT,          DT_DOUBLE}},
2881                       {"Expm1",      {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2882                       {"Exp",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2883                       {"Floor",      {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2884                       {"Inv",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2885                       {"Log",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2886                       {"Log1p",      {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2887                       {"Neg",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2888                       {"Reciprocal", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2889                       {"Rint",       {DT_FLOAT,          DT_DOUBLE}},
2890                       {"Round",      {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2891                       {"Rsqrt",      {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2892                       {"Sigmoid",    {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2893                       {"Sin",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2894                       {"Sinh",       {DT_FLOAT,          DT_DOUBLE}},
2895                       {"Sqrt",       {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2896                       {"Square",     {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2897                       {"Tan",        {DT_FLOAT,          DT_DOUBLE}},
2898                       {"Tanh",       {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2899                       // Additional ops that are not part of the Eigen.
2900                       {"Elu",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2901                       {"Relu",       {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2902                       {"Relu6",      {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2903                       {"Selu",       {DT_FLOAT, DT_HALF, DT_DOUBLE}}};
2904     // clang-format on
2905   }
2906   ~UnaryOpsComposition() override = default;
2907 
IsSupported(const NodeDef * node) const2908   bool IsSupported(const NodeDef* node) const override {
2909     return CanOptimize(*node) &&
2910            // Check that this node was not already a root of a fused chain. If
2911            // graph optimization runs twice without pruning in between,
2912            // fused_nodes_ will not have this information.
2913            !ctx().node_map->NodeExists(OptimizedNodeName(*node));
2914   }
2915 
TrySimplify(NodeDef * root,string * simplified_node_name)2916   Status TrySimplify(NodeDef* root, string* simplified_node_name) override {
2917     TF_RETURN_IF_ERROR(CheckAttrExists(*root, "T"));
2918     DataType dtype = root->attr().at("T").type();
2919 
2920     // Keep a trace of all supported input nodes that can be fused together.
2921     std::vector<string> op_nodes = {root->name()};
2922     std::vector<string> op_names = {root->op()};
2923 
2924     // Check if we should follow input(0) while building an op composition.
2925     const auto predicate_fn = [&](const NodeDef& input) {
2926       if (input.name() == root->name()) return true;
2927 
2928       bool follow_input_node =
2929           dtype == GetDataTypeFromAttr(input, "T") &&
2930           NumNonControlDataOutputs(input, *ctx().node_map) == 1 &&
2931           CanOptimize(input);
2932 
2933       if (follow_input_node) {
2934         op_nodes.push_back(input.name());
2935         op_names.push_back(input.op());
2936       }
2937 
2938       return follow_input_node;
2939     };
2940 
2941     NodeDef* last_op = GetTailOfChain(
2942         *root, *ctx().node_map, /*follow_control_input*/ false, predicate_fn);
2943 
2944     // We were not able to find a chain that can be replaced.
2945     if (op_names.size() == 1) return Status::OK();
2946 
2947     // Do not add fused nodes to any other chain.
2948     std::for_each(op_nodes.begin(), op_nodes.end(),
2949                   [this](const string& name) { AddToFusedNodes(name); });
2950 
2951     // Reverse the trace to get correct composition computation order.
2952     std::reverse(op_names.begin(), op_names.end());
2953 
2954     VLOG(2) << "Fuse unary ops: root=" << root->name() << " op_names=["
2955             << str_util::Join(op_names, ", ") << "]";
2956 
2957     NodeDef* composition_node = ctx().optimized_graph->add_node();
2958     composition_node->set_name(OptimizedNodeName(*root));
2959     composition_node->set_op("_UnaryOpsComposition");
2960     composition_node->add_input(last_op->input(0));
2961     composition_node->set_device(root->device());
2962 
2963     auto attr = composition_node->mutable_attr();
2964     SetAttrValue(dtype, &(*attr)["T"]);
2965     SetAttrValue(op_names, &(*attr)["op_names"]);
2966 
2967     ctx().node_map->AddNode(composition_node->name(), composition_node);
2968     ctx().node_map->AddOutput(NodeName(last_op->input(0)),
2969                               composition_node->name());
2970 
2971     *simplified_node_name = composition_node->name();
2972 
2973     return Status::OK();
2974   }
2975 
2976  private:
CanOptimize(const NodeDef & node) const2977   bool CanOptimize(const NodeDef& node) const {
2978     DataType dtype = GetDataTypeFromAttr(node, "T");
2979     if (!IsSupported(node.op(), dtype)) {
2980       return false;
2981     }
2982     if (IsInPreserveSet(node)) {
2983       return false;
2984     }
2985     if (!NodeIsOnCpu(node)) {
2986       return false;
2987     }
2988     if (NodeIsAlreadyFused(node)) {
2989       return false;
2990     }
2991     return !(IsDrivenByControlDependency(node) ||
2992              DrivesControlDependency(node));
2993   }
2994 
2995   // UnaryOpsComposition is defined only for CPU.
NodeIsOnCpu(const NodeDef & node) const2996   bool NodeIsOnCpu(const NodeDef& node) const {
2997     using str_util::StartsWith;
2998 
2999     string task;
3000     string device;
3001 
3002     return DeviceNameUtils::SplitDeviceName(node.device(), &task, &device) &&
3003            StartsWith(device, DEVICE_CPU);
3004   }
3005 
NodeIsAlreadyFused(const NodeDef & node) const3006   bool NodeIsAlreadyFused(const NodeDef& node) const {
3007     return fused_nodes_.count(node.name()) > 0;
3008   }
3009 
OptimizedNodeName(const NodeDef & node) const3010   string OptimizedNodeName(const NodeDef& node) const {
3011     return strings::StrCat(node.name(), "/unary_ops_composition");
3012   }
3013 
AddToFusedNodes(const string & name)3014   void AddToFusedNodes(const string& name) { fused_nodes_.insert(name); }
3015 
3016   // Check if an op is supported by the _UnaryOpsComposition for the given type.
IsSupported(const string & op_name,DataType dtype) const3017   bool IsSupported(const string& op_name, DataType dtype) const {
3018     const auto it = supported_ops_.find(op_name);
3019     return it != supported_ops_.end() && it->second.count(dtype) > 0;
3020   }
3021 
3022   std::unordered_map<string, std::set<DataType>> supported_ops_;
3023   std::unordered_set<string> fused_nodes_;
3024 };
3025 
3026 // Replace operations of the form:
3027 //    x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i,...]
3028 // with
3029 //    a_i
3030 // when the strided slice index `i` is applied in the k'th axis.
3031 //
3032 // Similarly, replace operations of the form:
3033 //    x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i:i+1,...]
3034 // with
3035 //    expand_dims(a_i, axis=k)
3036 //
3037 // TODO(ebrevdo): Extend to also replace operations of the form
3038 //    concat((a_0, a_1, ..., ), axis=k)[:, ..., s_i:s_{i+1}, ...]
3039 // with
3040 //    a_i,
3041 // when
3042 //    s_i = cumsum(shape(a)[k] for a in (a_0, ...,))[i]
3043 // and slicing is in the k'th axis.
3044 class RemoveStackStridedSliceSameAxis : public ArithmeticOptimizerStage {
3045  public:
RemoveStackStridedSliceSameAxis(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)3046   explicit RemoveStackStridedSliceSameAxis(
3047       const GraphOptimizerContext& ctx,
3048       const ArithmeticOptimizerContext& ctx_ext)
3049       : ArithmeticOptimizerStage("RemoveStackStridedSliceSameAxis", ctx,
3050                                  ctx_ext) {}
3051   ~RemoveStackStridedSliceSameAxis() override = default;
3052 
IsSupported(const NodeDef * node) const3053   bool IsSupported(const NodeDef* node) const override {
3054     return IsStridedSlice(*node);
3055   }
3056 
TrySimplify(NodeDef * node,string * simplified_node_name)3057   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
3058     // *node is a StridedSlice NodeDef.
3059     NodeDef* pack;
3060 
3061     // Get the input and see if it's a Pack op.
3062     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &pack));
3063     if (!IsPack(*pack)) return Status::OK();
3064 
3065     bool return_early;
3066     PartialTensorShape pack_output_shape;
3067     int pack_axis;
3068     TF_RETURN_IF_ERROR(
3069         CheckInputs(node, pack, &pack_output_shape, &pack_axis, &return_early));
3070     if (return_early) return Status::OK();
3071 
3072     int slice_start_value;
3073     bool found;
3074     TF_RETURN_IF_ERROR(GetSliceAxis(node, pack, pack_output_shape, pack_axis,
3075                                     &slice_start_value, &found));
3076     if (!found) return Status::OK();
3077 
3078     return RewriteGraph(node, pack, slice_start_value, pack_axis,
3079                         simplified_node_name);
3080   }
3081 
3082  protected:
IsReallyConstant(const NodeDef & node) const3083   bool IsReallyConstant(const NodeDef& node) const {
3084     if (!IsConstant(node)) {
3085       return false;
3086     }
3087     // If the node is fed it's not constant anymore.
3088     return ctx().feed_nodes->find(node.name()) == ctx().feed_nodes->end();
3089   }
3090 
GetConstantAsInt64(const NodeDef & node,DataType dtype,std::vector<int64> * values)3091   bool GetConstantAsInt64(const NodeDef& node, DataType dtype,
3092                           std::vector<int64>* values) {
3093     if (dtype == DT_INT32) {
3094       std::vector<int32> values_int32;
3095       if (!ValuesFromConstNode(node, &values_int32)) {
3096         return false;
3097       }
3098       std::copy(values_int32.begin(), values_int32.end(),
3099                 std::inserter(*values, values->begin()));
3100       return true;
3101     } else {
3102       return ValuesFromConstNode(node, values);
3103     }
3104   }
3105 
CheckInputs(const NodeDef * node,const NodeDef * pack,PartialTensorShape * pack_output_shape,int * pack_axis,bool * return_early)3106   Status CheckInputs(const NodeDef* node, const NodeDef* pack,
3107                      PartialTensorShape* pack_output_shape, int* pack_axis,
3108                      bool* return_early) {
3109     *return_early = true;
3110     TF_RETURN_IF_ERROR(CheckAttrExists(*pack, "axis"));
3111 
3112     *pack_axis = pack->attr().at("axis").i();
3113     auto slice_properties =
3114         ctx().graph_properties->GetInputProperties(node->name());
3115     if (slice_properties.empty() ||
3116         slice_properties[0].shape().unknown_rank()) {
3117       return Status::OK();
3118     }
3119     *pack_output_shape = slice_properties[0].shape();
3120     const int pack_input_rank = pack_output_shape->dims() - 1;
3121     if (*pack_axis < 0) {
3122       // The ndims of any input into Pack op is its output ndims - 1.
3123       *pack_axis += pack_input_rank;
3124     }
3125     if (*pack_axis < 0 || *pack_axis >= pack_input_rank) {
3126       return errors::InvalidArgument(
3127           "Pack node (", pack->name(),
3128           ") axis attribute is out of bounds: ", pack->attr().at("axis").i());
3129     }
3130     *return_early = false;
3131     return Status::OK();
3132   }
3133 
GetSliceAxis(const NodeDef * node,const NodeDef * pack,const PartialTensorShape & pack_output_shape,int pack_axis,int * slice_start_value,bool * found)3134   Status GetSliceAxis(const NodeDef* node, const NodeDef* pack,
3135                       const PartialTensorShape& pack_output_shape,
3136                       int pack_axis, int* slice_start_value, bool* found) {
3137     *found = false;
3138     TF_RETURN_IF_ERROR(
3139         CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask",
3140                                 "new_axis_mask", "shrink_axis_mask"}));
3141 
3142     const int begin_mask = node->attr().at("begin_mask").i();
3143     const int end_mask = node->attr().at("end_mask").i();
3144     const int ellipsis_mask = node->attr().at("ellipsis_mask").i();
3145     const int new_axis_mask = node->attr().at("new_axis_mask").i();
3146     const int shrink_axis_mask = node->attr().at("shrink_axis_mask").i();
3147 
3148     // Check that the StridedSlice is one of these at pack_axis:
3149     //   [..., i, ...]
3150     //   [..., i:i+1, ...]
3151     //   [..., :1, ...]
3152     //   [..., -1:, ...]
3153     ///  [..., s_{pack_axis}-1:, ...]
3154     NodeDef* slice_begin;
3155     NodeDef* slice_end;
3156     NodeDef* slice_strides;
3157     TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &slice_begin));
3158     TF_RETURN_IF_ERROR(GetInputNode(node->input(2), &slice_end));
3159     TF_RETURN_IF_ERROR(GetInputNode(node->input(3), &slice_strides));
3160 
3161     for (const auto* n : {slice_begin, slice_end, slice_strides}) {
3162       if (!IsReallyConstant(*n)) return Status::OK();
3163     }
3164 
3165     Tensor slice_begin_t;
3166     Tensor slice_end_t;
3167     Tensor slice_strides_t;
3168 
3169     TF_RETURN_IF_ERROR(CheckAttrExists(*slice_begin, "value"));
3170     if (!slice_begin_t.FromProto(slice_begin->attr().at("value").tensor())) {
3171       return Status::OK();
3172     }
3173     TF_RETURN_IF_ERROR(CheckAttrExists(*slice_end, "value"));
3174     if (!slice_end_t.FromProto(slice_end->attr().at("value").tensor())) {
3175       return Status::OK();
3176     }
3177     TF_RETURN_IF_ERROR(CheckAttrExists(*slice_strides, "value"));
3178     if (!slice_strides_t.FromProto(
3179             slice_strides->attr().at("value").tensor())) {
3180       return Status::OK();
3181     }
3182     TensorShape processing_shape;
3183     TensorShape final_shape;
3184     bool is_identity;
3185     bool is_simple_slice;
3186     bool slice_dim0;
3187     gtl::InlinedVector<int64, 4> slice_begin_vec;
3188     gtl::InlinedVector<int64, 4> slice_end_vec;
3189     gtl::InlinedVector<int64, 4> slice_strides_vec;
3190     TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
3191         &slice_begin_t, &slice_end_t, slice_strides_t, pack_output_shape,
3192         begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask,
3193         &processing_shape, &final_shape, &is_identity, &is_simple_slice,
3194         &slice_dim0, &slice_begin_vec, &slice_end_vec, &slice_strides_vec));
3195 
3196     if (!is_simple_slice) return Status::OK();
3197 
3198     int begin_index = -1;
3199     int64 begin_value = 0;
3200     for (int i = 0; i < slice_begin_vec.size(); ++i) {
3201       const int64 v = slice_begin_vec[i];
3202       if (v != 0) {
3203         if (begin_index != -1) {
3204           // At least two start values that are nonzero.
3205           return Status::OK();
3206         }
3207         begin_index = i;
3208         begin_value = v;
3209       }
3210     }
3211 
3212     int end_index = -1;
3213     int64 end_value = 0;
3214     for (int i = 0; i < slice_end_vec.size(); ++i) {
3215       const int64 v = slice_end_vec[i];
3216       if (v != pack_output_shape.dim_size(i)) {
3217         if (end_index != -1) {
3218           // At least two end values that are nonzero.
3219           return Status::OK();
3220         }
3221         end_index = i;
3222         end_value = v;
3223       }
3224     }
3225 
3226     if (begin_index == -1 && end_index == -1) return Status::OK();
3227     if (begin_index != -1 && end_index != -1 && begin_index != end_index) {
3228       // Somehow received different axes for begin/end slicing
3229       return Status::OK();
3230     }
3231     const int slice_axis = (begin_index == -1) ? end_index : begin_index;
3232     if (slice_axis != pack_axis) {
3233       // Not slicing on the same axis as the Pack op.
3234       return Status::OK();
3235     }
3236     *slice_start_value = (begin_index == -1) ? 0 : begin_value;
3237     const int64 slice_end_value =
3238         (end_index == -1) ? pack_output_shape.dim_size(slice_axis) : end_value;
3239     if (slice_end_value != *slice_start_value + 1) {
3240       // Not slicing a single value out.
3241       return Status::OK();
3242     }
3243 
3244     if (*slice_start_value < 0 || *slice_start_value >= pack->input_size()) {
3245       return errors::InvalidArgument(
3246           "Node ", node->name(), " requested invalid slice index ",
3247           *slice_start_value, " on axis ", slice_axis,
3248           " from tensor of shape: ", pack_output_shape.DebugString());
3249     }
3250 
3251     *found = true;  // slice_start_value is valid.
3252     return Status::OK();
3253   }
3254 
RewriteGraph(const NodeDef * node,const NodeDef * pack,int slice_start_value,int pack_axis,string * simplified_node_name)3255   Status RewriteGraph(const NodeDef* node, const NodeDef* pack,
3256                       int slice_start_value, int pack_axis,
3257                       string* simplified_node_name) {
3258     OpInfo::TensorProperties input_slice_properties;
3259     NodeDef* input_slice;
3260     TF_RETURN_IF_ERROR(
3261         GetInputNode(pack->input(slice_start_value), &input_slice));
3262     TF_RETURN_IF_ERROR(GetTensorProperties(pack->input(slice_start_value),
3263                                            &input_slice_properties));
3264     PartialTensorShape input_slice_shape(input_slice_properties.shape());
3265 
3266     OpInfo::TensorProperties output_properties;
3267     TF_RETURN_IF_ERROR(GetTensorProperties(
3268         strings::StrCat(node->name(), ":", 0), &output_properties));
3269     PartialTensorShape output_shape(output_properties.shape());
3270     NodeDef* output =
3271         AddEmptyNode(OptimizedNodeName(ParseNodeScopeAndName(node->name())));
3272     if (input_slice_shape.IsCompatibleWith(output_shape)) {
3273       output->set_op("Identity");
3274       output->set_device(node->device());
3275       SetDataTypeToAttr(output_properties.dtype(), "T", output);
3276       output->add_input(input_slice->name());
3277     } else {
3278       NodeDef* axis = AddEmptyNode(
3279           OptimizedNodeName(ParseNodeScopeAndName(node->name()), "Axis"));
3280       axis->set_op("Const");
3281       axis->set_device(node->device());
3282       auto axis_attr = axis->mutable_attr();
3283       SetDataTypeToAttr(DT_INT32, "dtype", axis);
3284       auto* axis_t = (*axis_attr)["value"].mutable_tensor();
3285       axis_t->set_dtype(DT_INT32);
3286       axis_t->add_int_val(pack_axis);
3287       AddToOptimizationQueue(axis);
3288       output->set_op("ExpandDims");
3289       output->set_device(node->device());
3290       SetDataTypeToAttr(output_properties.dtype(), "T", output);
3291       output->add_input(input_slice->name());
3292       output->add_input(axis->name());
3293     }
3294 
3295     // Copy dependencies over.
3296     ForwardControlDependencies(output, {node, pack});
3297     AddToOptimizationQueue(output);
3298     *simplified_node_name = output->name();
3299 
3300     return Status::OK();
3301   }
3302 };
3303 
3304 }  // namespace
3305 
3306 class UniqueNodes {
3307  public:
FindOrAddRepresentative(NodeDef * node)3308   NodeDef* FindOrAddRepresentative(NodeDef* node) {
3309     uint64 sig = ComputeSignature(*node);
3310     std::vector<NodeDef*>& candidates = rep_[sig];
3311     for (auto& candidate : candidates) {
3312       if (SameNode(*candidate, *node)) {
3313         return candidate;
3314       }
3315     }
3316     candidates.push_back(node);
3317     return node;
3318   }
3319 
3320  private:
3321   uint64 ComputeSignature(const NodeDef& node);
3322   bool SameNode(const NodeDef& node1, const NodeDef& node2) const;
3323 
3324   absl::flat_hash_map<uint64, std::vector<NodeDef*>> rep_;
3325   absl::flat_hash_map<const NodeDef*, uint64> memoized_signatures_;
3326 };
3327 
ComputeSignature(const NodeDef & node)3328 uint64 UniqueNodes::ComputeSignature(const NodeDef& node) {
3329   auto it = memoized_signatures_.find(&node);
3330   if (it != memoized_signatures_.end()) return it->second;
3331 
3332   uint64 h = Hash64(node.op());
3333   h = Hash64Combine(Hash64(node.device()), h);
3334 
3335   for (const auto& input : node.input()) {
3336     const TensorId input_tensor = ParseTensorName(input);
3337     h = Hash64CombineUnordered(
3338         Hash64(input_tensor.node().data(), input_tensor.node().size()), h);
3339     h = Hash64CombineUnordered(std::hash<int>()(input_tensor.index()), h);
3340   }
3341   for (const auto& attr : node.attr()) {
3342     h = Hash64CombineUnordered(Hash64(attr.first), h);
3343     h = Hash64CombineUnordered(FastAttrValueHash(attr.second), h);
3344   }
3345   memoized_signatures_.emplace(&node, h);
3346   return h;
3347 }
3348 
SameNode(const NodeDef & node1,const NodeDef & node2) const3349 bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const {
3350   if (node1.op() != node2.op()) {
3351     return false;
3352   }
3353   if (node1.device() != node2.device()) {
3354     return false;
3355   }
3356   if (node1.input_size() != node2.input_size()) {
3357     return false;
3358   }
3359   if (node1.attr_size() != node2.attr_size()) {
3360     return false;
3361   }
3362 
3363   // Compare inputs.
3364   if (IsCommutative(node1)) {
3365     std::vector<string> inputs1(node1.input().begin(), node1.input().end());
3366     std::sort(inputs1.begin(), inputs1.end());
3367     std::vector<string> inputs2(node2.input().begin(), node2.input().end());
3368     std::sort(inputs2.begin(), inputs2.end());
3369     return inputs1 == inputs2;
3370   } else {
3371     // The order or ordinary inputs matters.
3372     int index = 0;
3373     for (; index < node1.input_size(); ++index) {
3374       if (IsControlInput(node1.input(index))) {
3375         break;
3376       } else if (node1.input(index) != node2.input(index)) {
3377         return false;
3378       }
3379     }
3380     // The order of control inputs does not matter.
3381     if (index < node1.input_size()) {
3382       std::vector<string> ctrl_inputs1(node1.input().begin() + index,
3383                                        node1.input().end());
3384       std::sort(ctrl_inputs1.begin(), ctrl_inputs1.end());
3385       std::vector<string> ctrl_inputs2(node2.input().begin() + index,
3386                                        node2.input().end());
3387       std::sort(ctrl_inputs2.begin(), ctrl_inputs2.end());
3388       return ctrl_inputs1 != ctrl_inputs2;
3389     }
3390   }
3391 
3392   // Compare attributes.
3393   if (node1.attr().size() != node2.attr().size()) {
3394     return false;
3395   }
3396   for (const auto& attr1 : node1.attr()) {
3397     auto it = node2.attr().find(attr1.first);
3398     if (it == node2.attr().end()) return false;
3399     if (!FastAreAttrValuesEqual(attr1.second, it->second)) return false;
3400   }
3401 
3402   return true;
3403 }
3404 
CanDedup(const NodeDef & node) const3405 bool ArithmeticOptimizer::CanDedup(const NodeDef& node) const {
3406   if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
3407     return false;
3408   }
3409   if (IsEnter(node) || IsExit(node)) {
3410     return false;
3411   }
3412   if (node.device().find("SPU") != string::npos) {
3413     return false;
3414   }
3415   // Workaround for Assert and Print mistakenly being labeled as stateful.
3416   if (IsAssert(node) || IsPrint(node)) {
3417     return true;
3418   }
3419   return IsFreeOfSideEffect(node);
3420 }
3421 
DedupComputations()3422 void ArithmeticOptimizer::DedupComputations() {
3423   GraphTopologyView graph_view;
3424   if (!graph_view.InitializeFromGraph(*optimized_graph_).ok()) {
3425     LOG(WARNING) << "Failed to initialize GraphTopologyView.";
3426     return;
3427   }
3428 
3429   const absl::flat_hash_set<string> ops_to_traverse = {
3430       "Identity", "IdentityN", "Reshape", "ExpandDims",
3431       "Enter",    "Switch",    "Merge"};
3432 
3433   // Populate feed_inplace_op;
3434   absl::flat_hash_set<const NodeDef*> feeds_inplace_op;
3435 
3436   for (const NodeDef& root : optimized_graph_->node()) {
3437     if (feeds_inplace_op.find(&root) != feeds_inplace_op.end()) continue;
3438 
3439     if (ModifiesInputsInPlace(root)) {
3440       const auto is_continue_traversal = [&](const NodeDef* node) -> bool {
3441         return node->op() == root.op() || ops_to_traverse.count(node->op()) > 0;
3442       };
3443 
3444       DfsTraversal(graph_view, {&root}, TraversalDirection::kFollowInputs,
3445                    DfsPredicates::Advance(is_continue_traversal),
3446                    DfsCallbacks::PreOrder([&](const NodeDef* node) {
3447                      feeds_inplace_op.insert(node);
3448                    }));
3449     }
3450   }
3451 
3452   bool stop = true;
3453   std::set<int> duplicates;
3454   UniqueNodes nodes;
3455   do {
3456     stop = true;
3457     for (int i = 0; i < optimized_graph_->node_size(); ++i) {
3458       if (duplicates.find(i) != duplicates.end()) {
3459         continue;
3460       }
3461       NodeDef* node = optimized_graph_->mutable_node(i);
3462       if (!CanDedup(*node) ||
3463           feeds_inplace_op.find(node) != feeds_inplace_op.end()) {
3464         continue;
3465       }
3466       NodeDef* rep = nodes.FindOrAddRepresentative(node);
3467       if (rep == node) {
3468         continue;
3469       }
3470       // If either node or rep feeds an inplace op, deduping them may cause data
3471       // races. For example: If we dedup nodes initializing two independent
3472       // inplace accumulations, they will write to the same buffer, clobbering
3473       // each other's results.
3474       if (feeds_inplace_op.find(rep) != feeds_inplace_op.end()) {
3475         continue;
3476       }
3477       VLOG(3) << "Remove duplicated node: node=" << node->name()
3478               << " representative=" << rep->name();
3479       const std::set<NodeDef*>& tmp = node_map_->GetOutputs(node->name());
3480       std::vector<NodeDef*> fanouts(tmp.begin(), tmp.end());
3481       for (NodeDef* fanout : fanouts) {
3482         for (int i = 0; i < fanout->input_size(); ++i) {
3483           string* fanout_input = fanout->mutable_input(i);
3484           const int position =
3485               NodePositionIfSameNode(*fanout_input, node->name());
3486           // Update name in-place.
3487           if (position < -1) {
3488             continue;
3489           } else if (position > 0) {
3490             *fanout_input = StrCat(rep->name(), ":", position);
3491           } else if (position == 0) {
3492             *fanout_input = rep->name();
3493           } else {
3494             *fanout_input = StrCat("^", rep->name());
3495           }
3496           node_map_->AddOutput(rep->name(), fanout->name());
3497         }
3498       }
3499       duplicates.insert(i);
3500       stop = false;
3501     }
3502   } while (!stop);
3503 
3504   // Delete duplicates
3505   if (fetch_nodes_known_ && !duplicates.empty()) {
3506     EraseNodesFromGraph(duplicates, optimized_graph_);
3507     // Rebuild the NodeMap which was invalidated by the node  swapping above.
3508     node_map_.reset(new NodeMap(optimized_graph_));
3509   }
3510 }
3511 
ForwardControlDependencies(NodeDef * target_node,const std::vector<const NodeDef * > & src_nodes)3512 void ArithmeticOptimizer::ForwardControlDependencies(
3513     NodeDef* target_node, const std::vector<const NodeDef*>& src_nodes) {
3514   for (const auto& src : src_nodes) {
3515     for (int i = src->input_size() - 1; i >= 0; --i) {
3516       if (IsControlInput(src->input(i))) {
3517         *target_node->add_input() = src->input(i);
3518         node_map_->AddOutput(NodeName(src->input(i)), target_node->name());
3519       } else {
3520         break;
3521       }
3522     }
3523   }
3524   DedupControlInputs(target_node);
3525 }
3526 
SimplifyArithmeticOps(bool can_use_shapes)3527 Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
3528   SetVector<NodeDef*> nodes_to_simplify;
3529   nodes_to_simplify.Reserve(optimized_graph_->node_size());
3530   for (int i = 0; i < optimized_graph_->node_size(); ++i) {
3531     nodes_to_simplify.PushBack(optimized_graph_->mutable_node(i));
3532   }
3533 
3534   const GraphOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
3535                                   graph_properties_.get(), node_map_.get(),
3536                                   &feed_nodes_, opt_level_);
3537   const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify);
3538 
3539   // Stop pipeline after first stage returning non-empty simplified tensor name.
3540   const auto stop = [](const string& result) { return !result.empty(); };
3541   GraphOptimizerStagePipeline<string> pipeline(stop);
3542 
3543   if (options_.combine_add_to_addn && can_use_shapes)
3544     pipeline.AddStage<AddOpsRewriteStage>(ctx, ctx_ext);
3545   if (options_.fold_conjugate_into_transpose)
3546     pipeline.AddStage<FoldConjugateIntoTranspose>(ctx, ctx_ext);
3547   if (options_.fold_multiply_into_conv)
3548     pipeline.AddStage<FoldMultiplyIntoConv>(ctx, ctx_ext);
3549   if (options_.fold_transpose_into_matmul)
3550     pipeline.AddStage<FoldTransposeIntoMatMul>(ctx, ctx_ext);
3551   if (options_.hoist_common_factor_out_of_aggregation && can_use_shapes)
3552     pipeline.AddStage<HoistCommonFactorOutOfAggregation>(ctx, ctx_ext);
3553   if (options_.minimize_broadcasts && can_use_shapes)
3554     pipeline.AddStage<MinimizeBroadcasts>(ctx, ctx_ext);
3555   if (options_.remove_identity_transpose && can_use_shapes)
3556     pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext);
3557   if (options_.remove_involution)
3558     pipeline.AddStage<RemoveInvolution>(ctx, ctx_ext);
3559   if (options_.remove_redundant_bitcast)
3560     pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext);
3561   if (options_.remove_redundant_cast)
3562     pipeline.AddStage<RemoveRedundantCastStage>(ctx, ctx_ext);
3563   if (options_.remove_redundant_reshape)
3564     pipeline.AddStage<RemoveRedundantReshape>(ctx, ctx_ext);
3565   if (options_.remove_negation)
3566     pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
3567   if (options_.replace_mul_with_square)
3568     pipeline.AddStage<ReplaceMulWithSquare>(ctx, ctx_ext);
3569   if (options_.remove_logical_not)
3570     pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext);
3571   if (options_.reorder_cast_like_and_value_preserving)
3572     pipeline.AddStage<ReorderCastLikeAndValuePreserving>(ctx, ctx_ext);
3573   if (options_.simplify_aggregation)
3574     pipeline.AddStage<SimplifyAggregation>(ctx, ctx_ext);
3575   if (options_.hoist_cwise_unary_chains)
3576     pipeline.AddStage<HoistCWiseUnaryChainsStage>(ctx, ctx_ext);
3577   if (options_.convert_sqrt_div_to_rsqrt_mul)
3578     pipeline.AddStage<SqrtDivToRsqrtMulStage>(ctx, ctx_ext);
3579   if (options_.remove_idempotent)
3580     pipeline.AddStage<RemoveIdempotentStage>(ctx, ctx_ext);
3581   if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext);
3582   if (options_.convert_log1p)
3583     pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
3584   if (options_.convert_log_softmax)
3585     pipeline.AddStage<LogSoftmaxStage>(ctx, ctx_ext);
3586   if (options_.optimize_max_or_min_of_monotonic)
3587     pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext);
3588   if (options_.convert_expm1)
3589     pipeline.AddStage<ConvertExpm1Stage>(ctx, ctx_ext);
3590   if (options_.unary_ops_composition)
3591     pipeline.AddStage<UnaryOpsComposition>(ctx, ctx_ext);
3592   if (options_.remove_stack_strided_slice_same_axis)
3593     pipeline.AddStage<RemoveStackStridedSliceSameAxis>(ctx, ctx_ext);
3594   if (options_.fuse_squared_diff)
3595     pipeline.AddStage<FuseSquaredDiffStage>(ctx, ctx_ext);
3596 
3597   VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
3598           << str_util::Join(pipeline.StageNames(), ", ");
3599 
3600   while (!nodes_to_simplify.Empty()) {
3601     GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
3602     NodeDef* node = nodes_to_simplify.PopBack();
3603 
3604     string simplified_tensor = "";
3605     bool optimized = pipeline.PassThroughAllStages(node, &simplified_tensor);
3606 
3607     // If the node was not optimized by any of the stages, go to the next one.
3608     if (!optimized) continue;
3609 
3610     // re-wire consumers of an old node to the new one
3611     if (NodeName(simplified_tensor) != node->name()) {
3612       // Always consider simplified_tensor for further optimizations.
3613       NodeDef* simplified_node = node_map_->GetNode(simplified_tensor);
3614       if (simplified_node != nullptr) {
3615         nodes_to_simplify.PushBack(simplified_node);
3616       }
3617       // When `node` is simplified to another node rather than in-place, the
3618       // consumers of `node` are already redirected to `simplified_tensor`.
3619       // Re-push the consumers into `nodes_to_simplify` for further
3620       // optimizations.
3621       const std::set<NodeDef*> outputs = node_map_->GetOutputs(node->name());
3622       std::vector<NodeDef*> consumers(outputs.begin(), outputs.end());
3623       std::sort(consumers.begin(), consumers.end(),
3624                 [](const NodeDef* n1, const NodeDef* n2) {
3625                   return n1->name() < n2->name();
3626                 });
3627       for (NodeDef* consumer : consumers) {
3628         // Update `consumer`'s use of `node` to `input`'s operand.
3629         for (int i = 0; i < consumer->input_size(); ++i) {
3630           int operand_pos;
3631           string operand_node_name =
3632               ParseNodeName(consumer->input(i), &operand_pos);
3633           if (operand_node_name == node->name()) {
3634             *consumer->mutable_input(i) =
3635                 (operand_pos < 0
3636                      ? AsControlDependency(NodeName(simplified_tensor))
3637                      : simplified_tensor);
3638           }
3639         }
3640         node_map_->UpdateInput(consumer->name(), node->name(),
3641                                simplified_tensor);
3642         nodes_to_simplify.PushBack(consumer);
3643       }
3644     }
3645   }
3646   return Status::OK();
3647 }
3648 
Optimize(Cluster *,const GrapplerItem & item,GraphDef * optimized_graph)3649 Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
3650                                      const GrapplerItem& item,
3651                                      GraphDef* optimized_graph) {
3652   // Set up helper data structures.
3653   nodes_to_preserve_ = item.NodesToPreserve();
3654   fetch_nodes_known_ = !item.fetch.empty();
3655   GrapplerItem optimized_item(item);
3656   optimized_graph_ = &optimized_item.graph;
3657   node_map_.reset(new NodeMap(optimized_graph_));
3658 
3659   for (const auto& feed : item.feed) {
3660     feed_nodes_.insert(NodeName(feed.first));
3661   }
3662 
3663   // Disable restricted graph rewrites.
3664   options_.unary_ops_composition &=
3665       item.optimization_options().allow_non_differentiable_rewrites;
3666 
3667   if (options_.dedup_computations) {
3668     DedupComputations();
3669   }
3670   GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
3671 
3672   // Perform topological sort on the graph in order to help AddOpsRewrite to
3673   // optimize larger subgraphs starting from the roots with more inputs.
3674   TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph_));
3675 
3676   graph_properties_.reset(new GraphProperties(optimized_item));
3677   const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
3678   const Status status = graph_properties_->InferStatically(assume_valid_feeds);
3679   const bool can_use_shapes = status.ok();
3680   if (!can_use_shapes) {
3681     VLOG(1) << "Shape inference failed." << status.error_message();
3682   }
3683 
3684   // Perform the optimizations.
3685   TF_RETURN_IF_ERROR(SimplifyArithmeticOps(can_use_shapes));
3686 
3687   optimized_graph->Swap(optimized_graph_);
3688   return Status::OK();
3689 }
3690 
Feedback(Cluster *,const GrapplerItem &,const GraphDef &,double)3691 void ArithmeticOptimizer::Feedback(Cluster* /*cluster*/,
3692                                    const GrapplerItem& /*item*/,
3693                                    const GraphDef& /*optimized_graph*/,
3694                                    double /*result*/) {
3695   // Nothing to do for ArithmeticOptimizer.
3696 }
3697 
3698 }  // namespace grappler
3699 }  // namespace tensorflow
3700