/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"

#include <algorithm>
#include <deque>
#include <limits>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/graph_topology_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/canonicalizer.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/grappler/utils/traversal.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/strided_slice_op.h"

using tensorflow::strings::StrCat;

namespace tensorflow {
namespace grappler {
namespace {

// Mark nodes created or optimized by a stage with a tag.
constexpr char kAddOpsRewriteTag[] =
    "_grappler_ArithmeticOptimizer_AddOpsRewriteStage";
constexpr char kMinimizeBroadcastsTag[] =
    "_grappler_ArithmeticOptimizer_MinimizeBroadcasts";

// Extracts values from a Const op into `values`. Returns true on success.
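// For example (illustrative): for a Const node holding the int32 values
// {1, 2, 3}:
//   std::vector<int32> v;
//   if (ValuesFromConstNode(node, &v)) { /* v == {1, 2, 3} */ }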
template <typename T>
bool ValuesFromConstNode(const NodeDef& node, std::vector<T>* values) {
  if (node.op() != "Const") {
    return false;
  }

  if (node.attr().count("dtype") == 0 || node.attr().count("value") == 0 ||
      node.attr().at("dtype").type() != DataTypeToEnum<T>::value) {
    return false;
  }

  // TensorProto represents the content of the tensor in either <type>_val or
  // tensor_content.
  const TensorProto& tensor = node.attr().at("value").tensor();
  typename checkpoint::SaveTypeTraits<T>::RepeatedField* tensor_values =
      checkpoint::MutableTensorProtoData<T>(const_cast<TensorProto*>(&tensor));

  if (!tensor_values->empty() && tensor.has_tensor_shape()) {
    // When tensor_shape is set, theoretically the representation of the data
    // could be compressed. So, before copying values to the returned vector,
    // make sure no compression happens.
    const TensorShapeProto& shape = tensor.tensor_shape();
    if (shape.dim_size() == 1 && shape.dim(0).size() == tensor_values->size()) {
      values->insert(values->end(), tensor_values->begin(),
                     tensor_values->end());
      return true;
    }
  }

  const auto tensor_content_size = tensor.tensor_content().size();
  if (tensor_content_size > 0) {
    CHECK_EQ(0, tensor_content_size % sizeof(T))
        << "tensor_content_size (" << tensor_content_size
        << ") is not a multiple of " << sizeof(T);
    values->resize(tensor_content_size / sizeof(T));
    port::CopyToArray(tensor.tensor_content(),
                      reinterpret_cast<char*>(values->data()));
    return true;
  }

  return false;
}

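// Adds `new_input` as a control dependency of `node`, unless an equivalent
// data or control input already exists. Returns true if an input was added.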
bool MaybeAddControlInput(const string& new_input, NodeDef* node,
                          GraphDef* graph, NodeMap* node_map) {
  bool already_exists = false;
  for (const string& input : node->input()) {
    if (input == new_input || AsControlDependency(input) == new_input) {
      already_exists = true;
      break;
    }
  }
  if (!already_exists) {
    const string ctrl_dep =
        ConstantFolding::AddControlDependency(new_input, graph, node_map);
    node->add_input(ctrl_dep);
    node_map->AddOutput(NodeName(new_input), node->name());
  }
  return !already_exists;
}

void SetDataTypeToAttr(DataType dtype, const string& attr_name, NodeDef* node) {
  (*node->mutable_attr())[attr_name].set_type(dtype);
}

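// Returns the tail of the longest chain of value-preserving, single-output
// nodes ending at `node`; the chain stops at preserved nodes and at nodes
// with multiple non-control outputs.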
NodeDef* GetTailOfValuePreservingChain(
    const NodeDef& node, const NodeMap& node_map,
    const std::unordered_set<string>& nodes_to_preserve) {
  auto is_value_preserving_non_branching = [&](const NodeDef& node) {
    return nodes_to_preserve.find(node.name()) == nodes_to_preserve.end() &&
           IsValuePreserving(node) && NumNonControlOutputs(node, node_map) == 1;
  };
  return GetTailOfChain(node, node_map, /*follow_control_input=*/false,
                        is_value_preserving_non_branching);
}

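// Same as above, but the chain may only consist of idempotent ops, i.e. ops
// with f(f(x)) = f(x).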
NodeDef* GetTailOfIdempotentChain(
    const NodeDef& node, const NodeMap& node_map,
    const std::unordered_set<string>& nodes_to_preserve) {
  auto is_idempotent_non_branching = [&](const NodeDef& node) {
    return nodes_to_preserve.find(node.name()) == nodes_to_preserve.end() &&
           IsIdempotent(node) && NumNonControlOutputs(node, node_map) == 1;
  };
  return GetTailOfChain(node, node_map, /*follow_control_input=*/false,
                        is_idempotent_non_branching);
}

// GetElementUnexhaustive tries to get the value of an element in a tensor and
// turn it into complex128. It only checks a limited number of data types, so
// it is not exhaustive.
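// For example, a DT_FLOAT element with value 2.5f is returned as
// complex128(2.5, 0).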
bool GetElementUnexhaustive(const Tensor& t, int i, const std::set<int>& dtypes,
                            complex128* element) {
  if (dtypes.find(t.dtype()) == dtypes.end()) return false;
  switch (t.dtype()) {
    case DT_BFLOAT16:
      *element = complex128(t.flat<bfloat16>()(i));
      return true;
    case DT_HALF:
      *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
      return true;
    case DT_INT32:
      *element = complex128(t.flat<int32>()(i));
      return true;
    case DT_INT64:
      *element = complex128(t.flat<int64>()(i));
      return true;
    case DT_FLOAT:
      *element = complex128(t.flat<float>()(i));
      return true;
    case DT_DOUBLE:
      *element = complex128(t.flat<double>()(i));
      return true;
    case DT_COMPLEX64:
      *element = complex128(t.flat<complex64>()(i));
      return true;
    case DT_COMPLEX128:
      *element = t.flat<complex128>()(i);
      return true;
    default:
      return false;
  }
}

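// Best-effort check that `node` is placed on a CPU device, based on its
// assigned device string.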
bool NodeIsOnCpu(const NodeDef& node) {
  string task;
  string device;
  return DeviceNameUtils::SplitDeviceName(node.device(), &task, &device) &&
         absl::StrContains(device, DEVICE_CPU);
}

// Graph optimizer context extension specific to ArithmeticOptimizer.
struct ArithmeticOptimizerContext {
  explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify)
      : nodes_to_simplify(nodes_to_simplify) {}
  SetVector<NodeDef*>* nodes_to_simplify;
};

// Base class for a single arithmetic optimization: e.g. Bitcast optimization,
// AddOps optimization, etc.
class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
 public:
  explicit ArithmeticOptimizerStage(const string& name,
                                    const GraphOptimizerContext& ctx,
                                    const ArithmeticOptimizerContext ctx_ext)
      : GraphOptimizerStage("ArithmeticOptimizer", name, ctx),
        ctx_ext_(ctx_ext) {}
  ~ArithmeticOptimizerStage() override = default;

 protected:
  // A simplification graph rewrite can create additional nodes that are inputs
  // to the final simplified node; they can also be added to the arithmetic
  // optimizer queue for further optimization.
  void AddToOptimizationQueue(NodeDef* node) {
    ctx_ext_.nodes_to_simplify->PushBack(node);
  }

  // TODO(ezhulenev): remove this method from ArithmeticOptimizer once all
  // optimizations have been migrated to stages
  void ForwardControlDependencies(
      NodeDef* target_node, const std::vector<const NodeDef*>& src_nodes) {
    for (const auto& src : src_nodes) {
      for (int i = src->input_size() - 1; i >= 0; --i) {
        if (IsControlInput(src->input(i))) {
          *target_node->add_input() = src->input(i);
          ctx().node_map->AddOutput(NodeName(src->input(i)),
                                    target_node->name());
        } else {
          break;
        }
      }
    }
    DedupControlInputs(target_node);
  }

  bool IsReallyConstant(const NodeDef& node) const {
    if (!IsConstant(node)) {
      return false;
    }
    // If the node is fed it's not constant anymore.
    return ctx().feed_nodes->find(node.name()) == ctx().feed_nodes->end();
  }

  bool IsInPreserveSet(const NodeDef& node) const {
    return ctx().nodes_to_preserve->find(node.name()) !=
           ctx().nodes_to_preserve->end();
  }

  // TODO(ezhulenev): move to GraphOptimizerStage?
  bool IsDrivenByControlDependency(const NodeDef& node) const {
    return std::any_of(
        node.input().begin(), node.input().end(),
        [](const string& input) { return IsControlInput(input); });
  }

  // TODO(ezhulenev): move to GraphOptimizerStage?
  bool DrivesControlDependency(const NodeDef& node) const {
    for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) {
      for (int i = 0; i < output->input_size(); ++i) {
        const TensorId tensor = ParseTensorName(output->input(i));
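        // A negative tensor index denotes a control input (e.g. "^node"), so
        // `node` drives a control dependency of `output`.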
        if (tensor.node() == node.name() && tensor.index() < 0) {
          return true;
        }
      }
    }
    return false;
  }

  bool GetTensorFromConstNode(const string& node_name_or_input,
                              Tensor* tensor) {
    const NodeDef* node = ctx().node_map->GetNode(node_name_or_input);
    return node != nullptr && IsReallyConstant(*node) &&
           CheckAttrExists(*node, "value").ok() &&
           tensor->FromProto(node->attr().at("value").tensor());
  }

 private:
  // Extended context required for ArithmeticOptimizer.
  const ArithmeticOptimizerContext ctx_ext_;
};

// Subtype of ArithmeticOptimizerStage that does optimization by rewriting a
// group of nodes from the optimized graph.
//
// * AddOpsRewrite:
//   Rewrite a group of Add/AddN ops with a compact Add/AddN tree
//
// * MinimizeBroadcasts:
//   Rewrite a group of binary associative ops, reordering
//   inputs, to minimize the cost of broadcasts
class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage {
 public:
  explicit ArithmeticNodesGroupOptimizerStage(
      const string& name, const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext ctx_ext)
      : ArithmeticOptimizerStage(name, ctx, ctx_ext) {}
  ~ArithmeticNodesGroupOptimizerStage() override = default;

  // Input name with a statically inferred shape from GraphProperties
  struct InputAndShape {
    InputAndShape(const string& input, const TensorShapeProto& shape)
        : input(input), shape(shape) {}
    string input;
    TensorShapeProto shape;
  };

  // Subgraph (subtree) of nodes that we want to optimize in "one shot" (e.g.
  // all the Add nodes that we plan to rewrite with a single AddN). The
  // subgraph is obtained by graph traversal, starting from a root node.
  struct OptimizedNodesGroup {
    NodeDef* root_node;
    TensorShapeProto root_shape;
    // Optimized nodes that will be updated or removed by the rewrite
    std::vector<NodeDef*> optimized_nodes;
    // Inputs to the optimized nodes
    std::vector<InputAndShape> inputs;
  };

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));

    OptimizedNodesGroup group;
    TF_RETURN_IF_ERROR(CreateOptimizedNodesGroup(node, &group));

    if (!group.optimized_nodes.empty()) {
      *simplified_node_name = RewriteOptimizedNodesGroup(group);
    }

    return Status::OK();
  }

 protected:
  // Modify the optimized graph after the nodes group was successfully
  // identified
  virtual string RewriteOptimizedNodesGroup(
      const OptimizedNodesGroup& group) = 0;

  // Check if an input can become a part of the current optimized nodes group.
  virtual bool IsAbsorbableByOptimizedNodesGroup(
      const OptimizedNodesGroup& group, const NodeDef& node) const = 0;

  Status AbsorbInputByOptimizedNodesGroup(const string& input,
                                          OptimizedNodesGroup* group) const {
    std::deque<const string*> input_tensors;
    input_tensors.push_front(&input);

    while (!input_tensors.empty()) {
      const string* input_tensor = input_tensors.front();
      input_tensors.pop_front();

      // Get a node for the input tensor.
      NodeDef* input_node;
      TF_RETURN_IF_ERROR(GetInputNode(*input_tensor, &input_node));

      if (IsAbsorbableByOptimizedNodesGroup(*group, *input_node)) {
        group->optimized_nodes.push_back(input_node);
        for (int i = input_node->input_size() - 1; i >= 0; --i) {
          const string& absorbed_node_input = input_node->input(i);
          // TODO(ezhulenev): support control inputs
          if (IsControlInput(absorbed_node_input)) continue;
          input_tensors.push_front(&absorbed_node_input);
        }
      } else {
        // If the input node can't be absorbed, add it to the
        // OptimizedNodesGroup inputs.
        const OpInfo::TensorProperties* properties;
        TF_RETURN_IF_ERROR(GetTensorProperties(*input_tensor, &properties));
        group->inputs.emplace_back(*input_tensor, properties->shape());
      }
    }

    return Status::OK();
  }

  Status CreateOptimizedNodesGroup(NodeDef* root_node,
                                   OptimizedNodesGroup* group) const {
    const OpInfo::TensorProperties* root_node_output_properties;
    TF_RETURN_IF_ERROR(
        GetTensorProperties(root_node->name(), &root_node_output_properties));

    group->root_node = root_node;
    group->root_shape = root_node_output_properties->shape();

    group->optimized_nodes.reserve(root_node->input_size());
    for (int i = 0; i < root_node->input_size(); ++i) {
      const string& input_i = root_node->input(i);
      // TODO(ezhulenev): add support for control inputs
      if (IsControlInput(input_i)) continue;
      TF_RETURN_IF_ERROR(AbsorbInputByOptimizedNodesGroup(input_i, group));
    }

    return Status::OK();
  }

  // Check if all inputs can be broadcasted to the same shape
  // TODO(ezhulenev): move to GraphOptimizerStage?
  bool HasAllInputsBroadcastableToShape(
      const NodeDef& node, const OpInfo::TensorProperties& properties) const {
    auto is_broadcastable = [this, &properties](const string& input) {
      const OpInfo::TensorProperties* input_props;
      Status has_input_properties = GetTensorProperties(input, &input_props);
      return has_input_properties.ok() &&
             ShapesBroadcastable(properties, *input_props);
    };
    return std::all_of(node.input().begin(), node.input().end(),
                       is_broadcastable);
  }

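  // Builds a compact string signature for a shape. For example (illustrative),
  // a tensor of shape [2, 3] gets the signature "rank:2:dim:2:3".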
  string ShapeSignature(const TensorShapeProto& shape) const {
    string signature = strings::StrCat("rank:", shape.dim_size(), ":dim");
    for (int i = 0; i < shape.dim_size(); ++i)
      strings::StrAppend(&signature, ":", shape.dim(i).size());
    return signature;
  }

  void MarkWithTag(const StringPiece tag, NodeDef* node) {
    AddNodeAttr(tag, true, node);
  }

  void MarkAllMembersWithTag(const OptimizedNodesGroup& group,
                             const StringPiece tag) const {
    AddNodeAttr(tag, true, group.root_node);
    for (NodeDef* optimized_node : group.optimized_nodes) {
      AddNodeAttr(tag, true, optimized_node);
    }
  }

  bool IsOnTheSameDevice(const OptimizedNodesGroup& group,
                         const NodeDef& node) const {
    return group.root_node->device() == node.device();
  }

  bool IsInPreserveSet(const NodeDef& node) const {
    return ctx().nodes_to_preserve->find(node.name()) !=
           ctx().nodes_to_preserve->end();
  }

  bool IsMarkedWithTag(const NodeDef& node, const StringPiece tag) const {
    return HasNodeAttr(node, tag);
  }

  bool IsMarkedWithAnyTag(const NodeDef& node, const StringPiece tag1,
                          const StringPiece tag2) const {
    return IsMarkedWithTag(node, tag1) || IsMarkedWithTag(node, tag2);
  }
};


// Rewrite a tree of Add/AddN ops with a single AddN operation, consuming all
// the original inputs of absorbed nodes.
//
// 1) All nodes must have the same device placement.
//
// 2) If all nodes in an Add/AddN subgraph have symbolically equal shape, the
//    tree is optimized to a single AddN node.
//
//                AddN_1
//             /    |    \
//          Add_1   z   Add_2       -> AddN(x, y, z, w, q, e)
//          /  \        /  \
//         x    y      w    Add_3
//                          / \
//                         q   e
//
// 3) If some nodes have a different shape (it must be broadcastable to the
//    shape of the "root"), the tree is optimized to AddN ops for symbolically
//    equal shapes, plus a tree of Add ops that minimizes broadcasts.
//
//                AddN_1                                 Add
//             /    |    \                              /  \
//          Add_1   z   Add_2       ->               Add    w
//          /  \        /  \                        /   \
//         x    y      w    Add_3      AddN(x, y, q, e)  z
//                          / \
//                         q   e
class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage {
 public:
  explicit AddOpsRewriteStage(const GraphOptimizerContext& ctx,
                              const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticNodesGroupOptimizerStage("AddOpsRewrite", ctx, ctx_ext) {}
  ~AddOpsRewriteStage() override = default;

  // Check if a node can become the root of an AddOpsGroup
  bool IsSupported(const NodeDef* node) const override {
    if (!CanOptimize(*node)) return false;

    // shape must be symbolically defined and all inputs compatible with it
    const OpInfo::TensorProperties* properties;
    Status has_properties = GetTensorProperties(node->name(), &properties);
    return has_properties.ok() && ShapeIsSymbolicallyDefined(*properties) &&
           HasAllInputsBroadcastableToShape(*node, *properties);
  }

 protected:
  // Check if a node can be absorbed by the current OptimizedNodesGroup
  bool IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup& group,
                                         const NodeDef& node) const override {
    if (!CanOptimize(node)) return false;

    if (!IsOnTheSameDevice(group, node)) {
      return false;
    }
    // The node must have a single output data consumer (presumably, if we
    // reached this node from a previously absorbed node or the root node, it
    // means that this node is not used as an input to any other op outside of
    // the group).
    if (NumNonControlDataOutputs(node, *ctx().node_map) != 1) {
      return false;
    }
    // All input shapes must be broadcastable to the node shape
    const OpInfo::TensorProperties* properties;
    Status has_properties = GetTensorProperties(node.name(), &properties);
    return has_properties.ok() &&
           HasAllInputsBroadcastableToShape(node, *properties);
  }

  // Node requirements both for a root node and an absorbed node
  bool CanOptimize(const NodeDef& node) const {
    // TODO(ezhulenev): check if AccumulateNV2 can be supported too
    if (!IsAdd(node) && !IsAddN(node)) {
      return false;
    }
    if (IsInPreserveSet(node) || IsMarkedWithTag(node, kAddOpsRewriteTag)) {
      return false;
    }
    // TODO(ezhulenev): relax this condition for root node
    return !(IsDrivenByControlDependency(node) ||
             DrivesControlDependency(node));
  }

  // Rewrite a group of add ops into a single AddN if all input shapes are
  // symbolically equal. If not, create AddN ops for equal shapes first, and
  // then build an Add tree, minimizing the cost of broadcasts.
  string RewriteOptimizedNodesGroup(const OptimizedNodesGroup& group) override {
    VLOG(2) << "Collapse Add/AddN: root=" << group.root_node->name()
            << " op=" << group.root_node->op()
            << " num_optimized_nodes=" << group.optimized_nodes.size()
            << " num_inputs=" << group.inputs.size();

    // Do not optimize any of the nodes that are part of this group.
    MarkAllMembersWithTag(group, kAddOpsRewriteTag);

    // All new nodes will be placed under the scope of the root node.
    auto root_scope_and_name = ParseNodeScopeAndName(group.root_node->name());

    // Find what shapes are present in the inputs of absorbed nodes.
    std::unordered_map<string, std::vector<InputAndShape>> shape_sig_to_inputs;
    for (const auto& input : group.inputs) {
      shape_sig_to_inputs[ShapeSignature(input.shape)].push_back(input);
    }

    using SigKV = decltype(shape_sig_to_inputs)::value_type;
    VLOG(3) << "Add/AddN group has " << shape_sig_to_inputs.size()
            << " unique shapes: "
            << absl::StrJoin(shape_sig_to_inputs, ", ",
                             [](string* out, SigKV p) {
                               strings::StrAppend(out, p.first);
                             });

    // Collect all the shapes from representative elements.
    std::vector<TensorShapeProto> shapes;
    shapes.reserve(shape_sig_to_inputs.size());
    for (const auto& el : shape_sig_to_inputs)
      shapes.push_back(el.second[0].shape);

    // If all inputs have the same shape, rewrite the whole group with a
    // single AddN.
    if (shapes.size() == 1) {
      string node_name = UniqueOptimizedNodeName(root_scope_and_name);
      AddInputsOfSymbolicallyEqualShape(*group.root_node, node_name,
                                        group.inputs);
      return node_name;
    }

    // For inputs of different shapes:
    // 1. Rewrite inputs of the same shape using AddN (leaf nodes)
    // 2. Build a tree of Add nodes, minimizing cost of broadcast
    std::sort(shapes.begin(), shapes.end(),
              [](const TensorShapeProto& left, const TensorShapeProto& right) {
                return CompareSymbolicallyShapedTensorSizes(left, right);
              });

    // optimized name for leaf AddN nodes
    auto leaf_node_name = [&root_scope_and_name, this](int i) {
      return UniqueOptimizedNodeName(root_scope_and_name,
                                     strings::StrCat("Leaf_", i));
    };
    // optimized name for internal nodes of a tree built up from AddN leaves
    auto internal_node_name = [&root_scope_and_name, this](int i) {
      return UniqueOptimizedNodeName(root_scope_and_name,
                                     strings::StrCat("Internal_", i));
    };

    // Add/AddN nodes that must be added to the tree
    std::deque<InputAndShape> add_ops;

    // Prepare leaf AddN nodes for inputs of equal shape
    for (int i = 0, end = shapes.size(); i < end; ++i) {
      const auto node_name = leaf_node_name(i);
      const auto& inputs = shape_sig_to_inputs[ShapeSignature(shapes[i])];
      add_ops.push_back(AddInputsOfSymbolicallyEqualShape(*group.root_node,
                                                          node_name, inputs));
    }

    // Build up a tree of Add ops
    int internal_nodes = 0;
    do {
      const InputAndShape lhs = add_ops.front();
      add_ops.pop_front();
      const InputAndShape rhs = add_ops.front();
      add_ops.pop_front();
      string name = add_ops.empty()
                        ? UniqueOptimizedNodeName(root_scope_and_name)
                        : internal_node_name(internal_nodes++);
      InputAndShape add = AddAggregatedInputs(*group.root_node, name, lhs, rhs);
      add_ops.push_front(add);
    } while (add_ops.size() > 1);

    InputAndShape optimized_root_node = add_ops.front();
    return optimized_root_node.input;
  }

  // Add an 'AddN' node to aggregate inputs of symbolically equal shape
  InputAndShape AddInputsOfSymbolicallyEqualShape(
      const NodeDef& root_node, const string& node_name,
      const std::vector<InputAndShape>& inputs) {
    CHECK(!inputs.empty()) << "Inputs must be non-empty";

    // Do not create redundant AddN nodes
    if (inputs.size() == 1 || root_node.attr().count("T") == 0) {
      return inputs[0];
    }

    // get shape from representative element
    auto shape = inputs[0].shape;

    // copy attributes from the root node
    DataType dtype = root_node.attr().at("T").type();

    // add new AddN node
    NodeDef* node = AddEmptyNode(node_name);
    node->set_op("AddN");
    node->set_device(root_node.device());
    (*node->mutable_attr())["T"].set_type(dtype);
    (*node->mutable_attr())["N"].set_i(inputs.size());

    for (const auto& inputAndShape : inputs) {
      ctx().node_map->AddOutput(inputAndShape.input, node_name);
      node->add_input(inputAndShape.input);
    }

    MarkWithTag(kAddOpsRewriteTag, node);
    return InputAndShape(node_name, shape);
  }

  // Add a single 'Add' node to sum two inputs
  InputAndShape AddAggregatedInputs(const NodeDef& root_node,
                                    const string& node_name,
                                    const InputAndShape& left,
                                    const InputAndShape& right) {
    // copy attributes from the root node
    DataType dtype = root_node.attr().at("T").type();

    // add new Add node
    NodeDef* node = AddEmptyNode(node_name);
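    // AddV2 does not support string types, so fall back to the older Add op
    // for string inputs.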
    node->set_op((dtype == DT_STRING || dtype == DT_STRING_REF) ? "Add"
                                                                : "AddV2");
    node->set_device(root_node.device());
    (*node->mutable_attr())["T"].set_type(dtype);
    node->add_input(left.input);
    node->add_input(right.input);

    ctx().node_map->AddOutput(left.input, node_name);
    ctx().node_map->AddOutput(right.input, node_name);

    MarkWithTag(kAddOpsRewriteTag, node);
    return InputAndShape(
        node_name, TensorShapeProto());  // shape is not important at this point
  }
};


// Use the distributive property of multiplication and division over addition,
// along with commutativity of the former, to hoist common factors/denominators
// out of aggregate nodes where ALL the inputs are Mul/Div nodes.
// This pattern occurs frequently in regularization terms for the gradients
// during training.
//
// For example, we can rewrite an expression of the form:
//   AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn))
// to the following:
//   Mul(x, AddN(y1, y2, y3, ... yn))
// For division, we can rewrite
//   AddN(Div(y1, x), Div(y2, x), Div(y3, x), ... Div(yn, x))
// to:
//   Div(AddN(y1, y2, y3, ... yn), x)
class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
 public:
  explicit HoistCommonFactorOutOfAggregation(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("HoistCommonFactor", ctx, ctx_ext) {}
  ~HoistCommonFactorOutOfAggregation() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsAggregate(*node) && NumNonControlInputs(*node) > 1 &&
           !IsRewritten(node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));

    bool common_factor_is_denominator = false;
    std::set<string> common_factors;
    std::vector<string> ctrl_deps;
    TF_RETURN_IF_ERROR(GetCommonFactors(
        node, &common_factors, &common_factor_is_denominator, &ctrl_deps));

    if (common_factors.size() == 1) {
      const string& common_factor = *common_factors.begin();

      // Gather up the non-shared factors
      bool shapes_match = true;
      std::vector<string> unique_factors;
      TF_RETURN_IF_ERROR(GetUniqueFactors(node, common_factor,
                                          common_factor_is_denominator,
                                          &shapes_match, &unique_factors));

      if (shapes_match) {
        NodeDef* input_0;
        TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input_0));

        // Use a copy of the first node for the outer multiplication/division.
        NodeDef* new_outer_node = AddCopyNode(
            OuterNodeName(node, common_factor_is_denominator), input_0);
        // And a copy of the aggregation node as one of the inner operands.
        NodeDef* new_add_node = AddCopyNode(InnerAddNodeName(node), node);

        new_outer_node->set_device(node->device());
        if (common_factor_is_denominator) {
          new_outer_node->set_input(0, new_add_node->name());
          new_outer_node->set_input(1, common_factor);
        } else {
          new_outer_node->set_input(0, common_factor);
          new_outer_node->set_input(1, new_add_node->name());
        }

        ctx().node_map->AddOutput(common_factor, new_outer_node->name());
        ctx().node_map->AddOutput(new_add_node->name(), new_outer_node->name());

        // Hoist non-shared factors up into the new AddN node.
        for (int i = 0, end = unique_factors.size(); i < end; ++i) {
          const string& unique_factor_i = unique_factors[i];
          new_add_node->set_input(i, unique_factor_i);
          ctx().node_map->AddOutput(unique_factor_i, new_add_node->name());
        }

        // Add control deps on the add node
        for (const string& ctrl_dep : ctrl_deps) {
          *new_add_node->add_input() = ctrl_dep;
          ctx().node_map->AddOutput(NodeName(ctrl_dep), new_add_node->name());
        }

        // optimize the new inner aggregation node
        AddToOptimizationQueue(new_add_node);
        // do not optimize the same node twice
        rewritten_nodes_.insert(node->name());
        *simplified_node_name = new_outer_node->name();
      }
    }
    return Status::OK();
  }

 private:
  // Get a name for the new outer node
  string OuterNodeName(const NodeDef* node, bool is_div) const {
    auto scope_and_name = ParseNodeScopeAndName(node->name());
    return is_div ? OptimizedNodeName(scope_and_name, "Div")
                  : OptimizedNodeName(scope_and_name, "Mul");
  }

  // Get a name for the new inner Add node
  string InnerAddNodeName(const NodeDef* node) const {
    auto scope_and_name = ParseNodeScopeAndName(node->name());
    return OptimizedNodeName(scope_and_name, "AddV2");
  }

  // Determine the set of common factors if the input nodes are all Mul or
  // Div nodes.
  Status GetCommonFactors(const NodeDef* node, std::set<string>* common_factors,
                          bool* common_factor_is_denominator,
                          std::vector<string>* ctrl_deps) const {
    CHECK(common_factors->empty());
    CHECK_NOTNULL(common_factor_is_denominator);
    *common_factor_is_denominator = false;

    bool has_mul = false;
    bool has_div = false;
    for (int i = 0; i < node->input_size(); ++i) {
      if (i > 0 && common_factors->empty()) break;
      if (IsControlInput(node->input(i))) {
        ctrl_deps->push_back(node->input(i));
        continue;
      }
      NodeDef* input;
      TF_RETURN_IF_ERROR(GetInputNode(node->input(i), &input));

      if ((!IsMul(*input) && !IsAnyDiv(*input)) || (IsMul(*input) && has_div) ||
          (IsAnyDiv(*input) && has_mul)) {
        // Break if the input is neither a Mul nor a Div, or if there are both
        // Mul & Div ops.
        common_factors->clear();
        break;
      } else if (IsAnyDiv(*input)) {
        has_div = true;
        // In case of possible common dividers, we avoid hoisting out if any
        // input is not float/double, since integer division is not distributive
        // over addition.
        const OpInfo::TensorProperties* properties0;
        const OpInfo::TensorProperties* properties1;
        TF_RETURN_IF_ERROR(GetTensorProperties(input->input(0), &properties0));
        TF_RETURN_IF_ERROR(GetTensorProperties(input->input(1), &properties1));
        if (properties0->dtype() != DT_FLOAT &&
            properties0->dtype() != DT_DOUBLE &&
            properties1->dtype() != DT_FLOAT &&
            properties1->dtype() != DT_DOUBLE) {
          common_factors->clear();
          break;
        }
      } else if (IsMul(*input)) {
        has_mul = true;
      }

      // We only focus on common factors from denominators if any op is a
      // Div.
      std::set<string> factors_i =
          has_mul ? std::set<string>{input->input(0), input->input(1)}
                  : std::set<string>{input->input(1)};
      if (i == 0) {
        std::swap(*common_factors, factors_i);
      } else {
        std::set<string> intersection;
        std::set_intersection(
            factors_i.begin(), factors_i.end(), common_factors->begin(),
            common_factors->end(),
            std::inserter(intersection, intersection.begin()));
        std::swap(*common_factors, intersection);
      }
      for (int i = 2; i < input->input_size(); ++i) {
        ctrl_deps->push_back(input->input(i));
      }
    }

    *common_factor_is_denominator = has_div;
    return Status::OK();
  }

  // Gather up the non-shared factors (the y's in the example).
  // Unless the aggregation is Add, we have to make sure that all the y's
  // have the same shape since the other aggregation ops do not support
  // broadcasting.
  Status GetUniqueFactors(const NodeDef* node, const string& common_factor,
                          const bool common_factor_is_denominator,
                          bool* shapes_match,
                          std::vector<string>* unique_factors) const {
    *shapes_match = true;
    unique_factors->reserve(node->input_size());

    for (int i = 0; i < node->input_size() && *shapes_match; ++i) {
      const string& input = node->input(i);
      if (IsControlInput(input)) {
        break;
      }
      NodeDef* inner_node;
      TF_RETURN_IF_ERROR(GetInputNode(input, &inner_node));
      const int unique_factor_index =
          common_factor_is_denominator
              ? 0
              : (inner_node->input(0) == common_factor ? 1 : 0);
      unique_factors->push_back(inner_node->input(unique_factor_index));
      if (i > 0 && !IsAdd(*node)) {
        const OpInfo::TensorProperties* lhs;
        const OpInfo::TensorProperties* rhs;
        TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->front(), &lhs));
        TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->back(), &rhs));
        *shapes_match = ShapesSymbolicallyEqual(*lhs, *rhs);
      }
    }
    return Status::OK();
  }

  bool IsRewritten(const NodeDef* node) const {
    // If the graph rewrite happens in multiple passes without graph pruning
    // between them, it's possible that the rewritten node already exists in
    // the graph.
    return rewritten_nodes_.find(node->name()) != rewritten_nodes_.end() ||
           ctx().node_map->NodeExists(OuterNodeName(node, false)) ||
           ctx().node_map->NodeExists(OuterNodeName(node, true)) ||
           ctx().node_map->NodeExists(InnerAddNodeName(node));
  }

  // Names of the nodes that were optimized by this stage.
  std::unordered_set<string> rewritten_nodes_;
};

// Binary associative ops can be re-ordered to minimize the number of
// broadcasts and the size of temporary tensors.
//
// Example: [a, c] - scalars, [b, d] - matrices
//   @ - binary associative op (Add or Mul)
//   @* - broadcast
//
//           @                      @*
//        /     \                /      \
//      @*       @*      ->     @        @
//    /   \    /   \          /   \    /   \
//   a     b  c     d        a     c  b     d
class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage {
 public:
  explicit MinimizeBroadcasts(const GraphOptimizerContext& ctx,
                              const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticNodesGroupOptimizerStage("MinimizeBroadcasts", ctx, ctx_ext) {
  }
  ~MinimizeBroadcasts() override = default;

  bool IsSupported(const NodeDef* node) const override {
    if (!IsBinaryAssociative(*node)) return false;

    if (IsMarkedWithAnyTag(*node, kMinimizeBroadcastsTag, kAddOpsRewriteTag))
      return false;

    // has a symbolically defined shape with broadcastable inputs
    const OpInfo::TensorProperties* properties;
    Status has_properties = GetTensorProperties(node->name(), &properties);
    return has_properties.ok() && ShapeIsSymbolicallyDefined(*properties) &&
           HasAllInputsBroadcastableToShape(*node, *properties);
  }

 protected:
  bool IsBinaryAssociative(const NodeDef& node) const {
    return IsMul(node) || IsAdd(node);
  }

  bool IsSameOp(const OptimizedNodesGroup& group, const NodeDef& node) const {
    return group.root_node->op() == node.op();
  }

  // Check if a node can be absorbed by the current OptimizedNodesGroup
  bool IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup& group,
                                         const NodeDef& node) const override {
    if (!IsSameOp(group, node)) {
      return false;
    }
    if (IsInPreserveSet(node)) {
      return false;
    }
    // Nodes optimized by AddOpsRewrite already have optimal broadcasts.
    if (IsMarkedWithAnyTag(node, kMinimizeBroadcastsTag, kAddOpsRewriteTag)) {
      return false;
    }
    if (IsDrivenByControlDependency(node) || DrivesControlDependency(node)) {
      return false;
    }
    if (!IsOnTheSameDevice(group, node)) {
      return false;
    }
    // Optimized nodes are updated in place, which would break the graph if
    // the node has multiple output consumers.
    if (NumNonControlOutputs(node, *ctx().node_map) != 1) {
      return false;
    }
    // All input shapes must be broadcastable to the node shape
    const OpInfo::TensorProperties* properties;
    Status has_properties = GetTensorProperties(node.name(), &properties);
    return has_properties.ok() &&
           HasAllInputsBroadcastableToShape(node, *properties);
  }

  std::size_t CountUniqueShapes(const std::vector<InputAndShape>& inputs) {
    std::set<string> sigs;
    for (const auto& ias : inputs) {
      sigs.insert(ShapeSignature(ias.shape));
    }
    return sigs.size();
  }

  string RewriteOptimizedNodesGroup(const OptimizedNodesGroup& group) override {
    VLOG(2) << "Minimize broadcast: root=" << group.root_node->name()
            << " op=" << group.root_node->op()
            << " num_optimized_nodes=" << group.optimized_nodes.size();

    // Do not optimize any of the nodes that are part of this group.
    MarkAllMembersWithTag(group, kMinimizeBroadcastsTag);

    if (CountUniqueShapes(group.inputs) <= 1) {
      VLOG(3) << "Skip min-bcast group with single unique shape";
      // nothing to optimize when all shapes are the same
      return group.root_node->name();
    }

    auto num_nodes = /*root*/ 1 + group.optimized_nodes.size();
    auto num_inputs = group.inputs.size();
    CHECK_EQ(num_nodes, num_inputs - 1)
        << "Can't build a tree with " << num_inputs << " inputs, using "
        << num_nodes << " binary op nodes.";

    std::deque<InputAndShape> add_ops(group.inputs.begin(), group.inputs.end());
    std::deque<NodeDef*> optimized_nodes(group.optimized_nodes.begin(),
                                         group.optimized_nodes.end());

    // Sort inputs by shape, from smallest to largest.
    std::stable_sort(add_ops.begin(), add_ops.end(),
                     [](const InputAndShape& lhs, const InputAndShape& rhs) {
                       return CompareSymbolicallyShapedTensorSizes(lhs.shape,
                                                                   rhs.shape);
                     });

    // If there is an odd number of inputs, the last one is the largest, and we
    // want to attach it to the root node to build a well balanced tree.
    std::deque<InputAndShape> add_ops_leftover;
    if (add_ops.size() % 2 != 0) {
      add_ops_leftover.push_back(add_ops.back());
      add_ops.pop_back();
    }

    // At this point it's guaranteed that add_ops has an even number of inputs.
    do {
      const InputAndShape lhs = add_ops.front();
      add_ops.pop_front();
      const InputAndShape rhs = add_ops.front();
      add_ops.pop_front();

      NodeDef* node;
      if (!optimized_nodes.empty()) {
        // re-purpose optimized nodes to build a new tree
        node = optimized_nodes.back();
        optimized_nodes.pop_back();
      } else {
        // or use the root node if no optimized nodes are left
        node = group.root_node;
      }
      InputAndShape updated_node = UpdateInputs(lhs.input, rhs.input, node);

      // Pushing the updated node to the back of the deque will create a wide
      // and short tree; pushing to the front will create a tall tree. We
      // prefer a wide tree because it minimizes the potential number of
      // temporary tensors that must be kept in memory, though sometimes we can
      // go up to prevent propagating a broadcast from leaves to the root.
      // Example:
      //
      // inputs: [s, s, s, M] (s - scalar, M - matrix)
      // @* - op with broadcast
      //
      //  (only push_back)           @*     (push_front first op)
      //                            /  \
      //       @*                  @    M
      //     /   \                / \
      //    @     @*      ->     @   s
      //   / \   / \            / \
      //  s   s s   M          s   s
      if (add_ops.size() >= 2 &&
          CompareSymbolicallyShapedTensorSizes(add_ops.at(0).shape,
                                               add_ops.at(1).shape)) {
        add_ops.push_front(updated_node);
      } else {
        add_ops.push_back(updated_node);
      }
    } while (add_ops.size() > 1);
    CHECK_EQ(1, add_ops.size());

    // attach the largest tensor to the root op
    if (!add_ops_leftover.empty()) {
      const InputAndShape lhs = add_ops.front();
      add_ops.pop_front();
      const InputAndShape rhs = add_ops_leftover.front();
      InputAndShape updated_node =
          UpdateInputs(lhs.input, rhs.input, group.root_node);
      add_ops.push_back(updated_node);
    }

    return add_ops.front().input;
  }

  InputAndShape UpdateInputs(const string& input_0, const string& input_1,
                             NodeDef* node) {
    string old_input_0 = node->input(0);
    string old_input_1 = node->input(1);

    // Update inputs only if they changed
    if (old_input_0 != input_0 || old_input_1 != input_1) {
      node->set_input(0, input_0);
      node->set_input(1, input_1);
      // Invalidate node properties (shape)
      ctx().graph_properties->ClearOutputProperties(node->name());
      ctx().graph_properties->ClearInputProperties(node->name());
      // Update the node map
      ctx().node_map->RemoveOutput(NodeName(old_input_0), node->name());
      ctx().node_map->RemoveOutput(NodeName(old_input_1), node->name());
      ctx().node_map->AddOutput(NodeName(input_0), node->name());
      ctx().node_map->AddOutput(NodeName(input_1), node->name());
      // Add updated node to optimization queue
      AddToOptimizationQueue(node);
    }

    TensorShapeProto shape;  // shape is not important at this point
    return InputAndShape(node->name(), shape);
  }
};

// Removes inverse transpose nodes
class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
 public:
  explicit RemoveIdentityTranspose(const GraphOptimizerContext& ctx,
                                   const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveIdentityTranspose", ctx, ctx_ext) {}
  ~RemoveIdentityTranspose() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsTranspose(*node) || IsConjugateTranspose(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
    NodeDef* tail = node;
    tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
                                    *ctx().nodes_to_preserve);
    NodeDef* first_transpose;
    TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose));

    NodeDef* node_perm;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &node_perm));
    if (!IsConstant(*node_perm)) {
      return Status::OK();
    }
    std::vector<int64> node_perm_values;
    TF_RETURN_IF_ERROR(GetPermutation(*node_perm, &node_perm_values));
    if (first_transpose->op() == node->op()) {
      // Remove pairs of transposes that cancel each other.
      NodeDef* first_transpose_perm;
      TF_RETURN_IF_ERROR(
          GetInputNode(first_transpose->input(1), &first_transpose_perm));
      if (!IsConstant(*first_transpose_perm)) {
        return Status::OK();
      }
      std::vector<int64> first_transpose_perm_values;
      TF_RETURN_IF_ERROR(
          GetPermutation(*first_transpose_perm, &first_transpose_perm_values));
      if (AreInversePermutations(node_perm_values,
                                 first_transpose_perm_values)) {
        if (tail == node) {
          // Bypass adjacent pair.
          *simplified_node_name = first_transpose->input(0);
        } else {
          // Bypass pair connected through chain.
          tail->set_input(0, first_transpose->input(0));
          ctx().node_map->UpdateInput(tail->name(), first_transpose->name(),
                                      first_transpose->input(0));
          ForwardControlDependencies(tail, {first_transpose});
          *simplified_node_name = node->input(0);
        }
      }
    } else {
      // Remove simple identity transposes.
      if (IsIdentityPermutation(node_perm_values)) {
        if (IsConjugateTranspose(*node)) {
          const NodeScopeAndName transpose =
              ParseNodeScopeAndName(node->name());
          const string optimized_node_name = OptimizedNodeName(transpose);
          NodeDef* new_op = AddCopyNode(optimized_node_name, node);
          new_op->set_op("Conj");
          new_op->mutable_input()->RemoveLast();
          new_op->mutable_attr()->erase("Tperm");
          ForwardControlDependencies(new_op, {node});
          *simplified_node_name = new_op->name();
        } else {
          *simplified_node_name = node->input(0);
        }
      }
    }
    return Status::OK();
  }

 private:
  Status GetPermutation(const NodeDef& node_perm,
                        std::vector<int64>* perm64) const {
    std::vector<int> perm32;
    if (ValuesFromConstNode(node_perm, &perm32)) {
      perm64->reserve(perm32.size());
      for (int val : perm32) {
        perm64->push_back(static_cast<int64>(val));
      }
      return Status::OK();
    }
    if (ValuesFromConstNode(node_perm, perm64)) {
      return Status::OK();
    }
    return errors::InvalidArgument("Couldn't extract permutation from ",
                                   node_perm.name());
  }

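  // For example (illustrative): a = {2, 0, 1} and b = {1, 2, 0} are inverse
  // permutations, since a[b[i]] == i for every i.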
  bool AreInversePermutations(const std::vector<int64>& a,
                              const std::vector<int64>& b) {
    if (a.size() != b.size()) {
      return false;
    }
    for (int i = 0, end = a.size(); i < end; ++i) {
      if (a[b[i]] != i) {
        return false;
      }
    }
    return true;
  }

  bool IsIdentityPermutation(const std::vector<int64>& perm) {
    for (int64 i = 0, end = perm.size(); i < end; ++i) {
      if (i != perm[i]) {
        return false;
      }
    }
    return true;
  }
};

// An involution is an element-wise function f(x) that is its own inverse,
// i.e. f(f(x)) = x. If we can find a chain of ops
//   f->op1->op2->...opn->f
// where op1 through opn preserve the values of their inputs, we can remove
// the two instances of the involution from the graph, since they cancel
// each other.
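// For example, since Neg is an involution and Reshape preserves values,
//   Neg(Reshape(Neg(x))) => Reshape(x).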
class RemoveInvolution : public ArithmeticOptimizerStage {
 public:
  explicit RemoveInvolution(const GraphOptimizerContext& ctx,
                            const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveInvolution", ctx, ctx_ext) {}
  ~RemoveInvolution() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsInvolution(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* tail = GetTailOfValuePreservingChain(*node, *ctx().node_map,
                                                  *ctx().nodes_to_preserve);

    NodeDef* involution;
    TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &involution));

    if (involution->op() == node->op()) {
      // Skip both *node and *involution since they cancel each other.
      if (tail == node) {
        // The two nodes to eliminate are adjacent.
        *simplified_node_name = involution->input(0);
      } else {
        tail->set_input(0, involution->input(0));
        ctx().node_map->UpdateInput(tail->name(), involution->name(),
                                    involution->input(0));
        *simplified_node_name = node->input(0);
      }
    }

    return Status::OK();
  }
};

// Remove redundant Bitcasts.
// 1) Remove Bitcast whose source type and destination type are equal
// 2) Rewrite Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
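// For example, with x : int32,
//   Bitcast[T=float, type=int32](Bitcast[T=int32, type=float](x))
// is rewritten by rule 2) to Bitcast[T=int32, type=int32](x), which rule 1)
// then removes entirely when the stage revisits the node.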
class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage {
 public:
  explicit RemoveRedundantBitcastStage(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveRedundantBitcast", ctx, ctx_ext) {}
  ~RemoveRedundantBitcastStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsBitcast(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));

    // Bypass Bitcast whose source type and destination type are equal.
    AttrSlice attrs(*node);
    DataType input_type;
    TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &input_type));
    DataType output_type;
    TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "type", &output_type));
    if ((input_type == output_type) && !IsInPreserveSet(*node)) {
      *simplified_node_name = node->input(0);
      return Status::OK();
    }

    NodeDef* bitcast;
    TF_RETURN_IF_ERROR(GetInputNode(node->name(), &bitcast));
    NodeDef* operand;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &operand));

    if (IsBitcast(*operand) && !IsInPreserveSet(*operand)) {
      AttrSlice operand_attrs(*operand);
      DataType operand_input_type;
      TF_RETURN_IF_ERROR(GetNodeAttr(operand_attrs, "T", &operand_input_type));
      // Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
      bitcast->set_input(0, operand->input(0));
      SetDataTypeToAttr(operand_input_type, "T", bitcast);
      ctx().node_map->UpdateInput(bitcast->name(), bitcast->input(0),
                                  operand->input(0));
      AddToOptimizationQueue(bitcast);
      *simplified_node_name = bitcast->name();
    }

    return Status::OK();
  }
};

// Remove Casts whose source type and destination type are equal.
class RemoveRedundantCastStage : public ArithmeticOptimizerStage {
 public:
  explicit RemoveRedundantCastStage(const GraphOptimizerContext& ctx,
                                    const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveRedundantCast", ctx, ctx_ext) {}
  ~RemoveRedundantCastStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsCast(*node) && !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));

    // Bypass Cast whose source type and destination type are equal.
    AttrSlice attrs(*node);
    DataType input_type;
    TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "SrcT", &input_type));
    DataType output_type;
    TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "DstT", &output_type));
    if (input_type == output_type) {
      *simplified_node_name = node->input(0);
    }
    return Status::OK();
  }
};

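// Rewrite Add/Sub with a negated operand to drop the extra Neg:
//   a + (-b) => a - b,  a - (-b) => a + b,  (-a) + b => b - a.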
class RemoveNegationStage : public ArithmeticOptimizerStage {
 public:
  explicit RemoveNegationStage(const GraphOptimizerContext& ctx,
                               const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveNegation", ctx, ctx_ext) {}
  ~RemoveNegationStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return (IsAdd(*node) || IsSub(*node)) && !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* x;
    NodeDef* y;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
    bool updated = false;
    if (IsNeg(*y)) {
      // a - (-b) = a + b or a + (-b) = a - b
      ForwardControlDependencies(node, {y});
      ctx().node_map->UpdateInput(node->name(), node->input(1), y->input(0));
      node->set_op(IsAdd(*node) ? "Sub" : "AddV2");
      node->set_input(1, y->input(0));
      updated = true;
    } else if (IsAdd(*node) && IsNeg(*x)) {
      // (-a) + b = b - a
      ForwardControlDependencies(node, {x});
      ctx().node_map->UpdateInput(node->name(), node->input(0), x->input(0));
      node->set_op("Sub");
      node->mutable_input()->SwapElements(0, 1);
      node->set_input(1, x->input(0));
      updated = true;
    }
    if (updated) {
      AddToOptimizationQueue(node);
    }
    return Status::OK();
  }
};

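// Remove a LogicalNot that negates a comparison by inverting the comparison
// op itself, e.g. LogicalNot(Less(x, y)) => GreaterEqual(x, y).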
class RemoveLogicalNotStage : public ArithmeticOptimizerStage {
 public:
  explicit RemoveLogicalNotStage(const GraphOptimizerContext& ctx,
                                 const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveLogicalNot", ctx, ctx_ext) {}
  ~RemoveLogicalNotStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsLogicalNot(*node) && !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    const string node_name = node->name();
    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
    if (IsInPreserveSet(*input) ||
        NumNonControlOutputs(*input, *ctx().node_map) > 1) {
      return Status::OK();
    }
    string new_op;
    if (IsEqual(*input)) {
      new_op = "NotEqual";
    } else if (IsNotEqual(*input)) {
      new_op = "Equal";
    } else if (IsLess(*input)) {
      new_op = "GreaterEqual";
    } else if (IsLessEqual(*input)) {
      new_op = "Greater";
    } else if (IsGreater(*input)) {
      new_op = "LessEqual";
    } else if (IsGreaterEqual(*input)) {
      new_op = "Less";
    }
    if (!new_op.empty()) {
      input->set_op(new_op);
      *simplified_node_name = input->name();
    }
    return Status::OK();
  }
};

// This optimization hoists the common prefix of unary ops of the inputs to
// concat out of the concat, for example:
//    Concat([Exp(Sin(x)), Exp(Sin(y)), Exp(Sin(z))])
// becomes
//    Exp(Sin(Concat([x, y, z]))).
// Similarly, it will hoist the common postfix of unary ops into Split or
// SplitV nodes, for example:
//    [Exp(Sin(y)) for y in Split(x)]
// becomes
//    [y for y in Split(Exp(Sin(x)))]
//
// TODO(rmlarsen): Support casting. We would have to change the type attribute
// on the concat/split node.
// TODO(rmlarsen): Handle Enter/Exit.
class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
 public:
  explicit HoistCWiseUnaryChainsStage(const GraphOptimizerContext& ctx,
                                      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("", ctx, ctx_ext) {}

  ~HoistCWiseUnaryChainsStage() override = default;

  struct ChainLink {
    ChainLink() = default;
    ChainLink(NodeDef* _node, int _port_origin)
        : node(_node), port_origin(_port_origin) {}
    NodeDef* node;    // Node in a chain.
    int port_origin;  // Port on concat/split node from which this chain
                      // originates.

    bool operator<(const ChainLink& other) const {
      if (port_origin < other.port_origin) {
        return true;
      } else if (port_origin > other.port_origin) {
        return false;
      } else {
        return node->name() < other.node->name();
      }
    }
  };

  // We use an ordinary set sorted on port and node name, so the order, and
  // hence the node name used for the hoisted chain, will be deterministic.
  using ChainLinkSet = std::set<ChainLink>;

  bool IsSupported(const NodeDef* node) const override {
    if (IsInPreserveSet(*node)) return false;
    if (IsConcat(*node) && node->attr().count("N") != 0) {
      const int n = node->attr().at("N").i();
      return n > 1 && FirstNInputsAreUnique(*node, n);
    } else if ((IsSplit(*node) || IsSplitV(*node)) &&
               node->attr().count("num_split") != 0) {
      const int num_split = node->attr().at("num_split").i();
      if (NumNonControlOutputs(*node, *ctx().node_map) > num_split) {
        // TODO(rmlarsen): Remove this constraint when we have optimizations
        // in place for merging slices into splits.
        return false;
      }
      if (NumControlOutputs(*node, *ctx().node_map) > 0) {
        // TODO(ezhulenev): Unary ops after Split might have a control path to
        // the Split node, and we currently do not properly handle cycles.
        return false;
      }
      return num_split > 1 && !IsAlreadyOptimized(*node);
    }
    return false;
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    node_is_concat_ = IsConcat(*node);
    int prefix_length;
    std::set<string> ctrl_inputs;
    ChainLinkSet tails;
    TF_RETURN_IF_ERROR(
        FindCommonUnaryOpChain(*node, &prefix_length, &tails, &ctrl_inputs));
    if (prefix_length > 0 && !tails.empty()) {
      TF_RETURN_IF_ERROR(
          HoistUnaryOpChain(prefix_length, tails, &ctrl_inputs, node));
    }
    return Status::OK();
  }

 private:
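  // Returns true if the first n data inputs to `node` are distinct. For
  // "Concat" the axis is input 0, so data inputs start at index 1; for
  // "ConcatV2" the axis is the last input, so data inputs start at index 0.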
  bool FirstNInputsAreUnique(const NodeDef& node, int n) const {
    if (n > node.input_size()) return false;
    absl::flat_hash_set<string> unique_inputs;
    const int start = node.op() == "Concat" ? 1 : 0;
    const int end = start + n;
    for (int i = start; i < end; ++i) {
      unique_inputs.insert(node.input(i));
    }
    int unique_input_size = unique_inputs.size();
    return unique_input_size == n;
  }

  // Returns the length of the common unary chain of ops that can be
  // hoisted to the other side of concat or split.
  Status FindCommonUnaryOpChain(const NodeDef& root_node, int* prefix_length,
                                ChainLinkSet* tails,
                                std::set<string>* ctrl_inputs) const {
    *prefix_length = 0;
    // Follow the chains starting at each concat input or split output as long
    // as all the following conditions hold:
    //   1. The ops in all chains are the same.
    //   2. The ops are unary elementwise ops.
    //   3. The op output has only a single consumer (concat only).
    ChainLinkSet cur_tails;
    TF_RETURN_IF_ERROR(InitializeChains(root_node, &cur_tails));
    if (cur_tails.size() < 2) {
      return Status::OK();
    }
    ctrl_inputs->clear();
    bool stop = false;
    while (!stop && !cur_tails.empty() &&
           OpsAreSafeToHoist(root_node, cur_tails)) {
      // We found one more link that can be hoisted.
      ++(*prefix_length);
      tails->swap(cur_tails);
      GatherControlInputs(ctrl_inputs, *tails);

      // Advance tail pointers to the next level.
      TF_RETURN_IF_ERROR(AdvanceTails(*tails, &cur_tails, &stop));
    }
    return Status::OK();
  }

  // Hoists the chains to the other side of concat or split and attaches the
  // control inputs gathered from them to the concat or split node.
  Status HoistUnaryOpChain(const int prefix_length, const ChainLinkSet& tails,
                           std::set<string>* ctrl_inputs, NodeDef* root_node) {
    VLOG(3) << "Hoist unary op chain:"
            << " root=" << root_node->DebugString()
            << " prefix_length=" << prefix_length << " ctrl_inputs=["
            << absl::StrJoin(*ctrl_inputs, ", ") << "]";

    if (tails.empty()) {
      return Status::OK();
    }
    AddToOptimizationQueue(root_node);
    optimized_nodes_.insert(root_node->name());
    if (node_is_concat_) {
      AddControlInputs(ctrl_inputs, root_node);
      return HoistChainForConcat(prefix_length, tails, root_node);
    } else {
      return HoistChainForSplit(prefix_length, tails, ctrl_inputs, root_node);
    }
  }

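  // Collects the control inputs of all chain nodes. Control inputs are
  // always the trailing entries of a NodeDef's input list, so each scan
  // walks backwards and stops at the first regular input.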
  void GatherControlInputs(std::set<string>* ctrl_inputs,
                           const ChainLinkSet& ops) const {
    for (const auto& link : ops) {
      const NodeDef* node = link.node;
      for (int i = node->input_size() - 1; i >= 0; --i) {
        const string& input = node->input(i);
        if (!IsControlInput(input)) break;
        ctrl_inputs->insert(input);
      }
    }
  }

  void AddControlInputs(std::set<string>* new_ctrl_inputs,
                        NodeDef* node) const {
    for (int i = node->input_size() - 1; i >= 0; --i) {
      const string& existing_input = node->input(i);
      if (!IsControlInput(existing_input)) break;
      new_ctrl_inputs->erase(existing_input);
    }
    for (const string& new_input : *new_ctrl_inputs) {
      ctx().node_map->AddOutput(NodeName(new_input), node->name());
      node->add_input(new_input);
    }
  }

  Status InitializeChains(const NodeDef& node, ChainLinkSet* tails) const {
    if (node_is_concat_) {
      // Handle concat nodes by looking backwards in the graph.
      TF_RETURN_IF_ERROR(CheckAttrExists(node, "N"));
      const int n = node.attr().at("N").i();
      const int start = node.op() == "Concat" ? 1 : 0;
      const int end = start + n;
      if (end > node.input_size()) {
        return errors::FailedPrecondition("Got attr N=", n,
                                          " without enough inputs.");
      }
      // Set up tail pointers to point to the immediate inputs to Concat.
      for (int input_port = start; input_port < end; ++input_port) {
        if (IsControlInput(node.input(input_port))) {
          return errors::FailedPrecondition(
              "Got control input ", node.input(input_port),
              " where normal input was expected.");
        }
        NodeDef* tail;
        TF_RETURN_IF_ERROR(GetInputNode(node.input(input_port), &tail));
        tails->insert(ChainLink(tail, input_port));
      }
      return Status::OK();
    } else {
      // Handle split nodes by looking forwards in the graph.
      const auto& outputs = ctx().node_map->GetOutputs(node.name());
      for (NodeDef* output : outputs) {
        if (IsControlInput(output->input(0))) continue;
        TensorId tensor_id = ParseTensorName(output->input(0));
        if (tensor_id.node() == node.name()) {
          tails->insert(ChainLink(output, tensor_id.index()));
        } else {
          // This output node has a non-control input other than the split
          // node, abort.
          tails->clear();
          return Status::OK();
        }
      }
    }
    return Status::OK();
  }

  bool OpsAreSafeToHoist(const NodeDef& root_node,
                         const ChainLinkSet& ops) const {
    if (ops.empty()) return true;
    const NodeDef* op0 = ops.begin()->node;
    if (ModifiesFrameInfo(*op0) || !IsUnaryElementWise(*op0)) return false;
    for (const auto& link : ops) {
      const NodeDef* op = link.node;
      if (op->device() != root_node.device() || op->op() != op0->op() ||
          IsInPreserveSet(*op)) {
        return false;
      }
      if (ctx().node_map->GetOutputs(op->name()).size() > 1) {
        // TODO(rmlarsen): Allow outgoing control edges.
        return false;
      }
      // Do not hoist Relu if it can be fused with its predecessors. This is
      // important because remapping runs after the arithmetic optimizer.
      if (IsRelu(*op) || IsRelu6(*op)) {
        NodeDef* operand = nullptr;
        if (!GetInputNode(op->input(0), &operand).ok()) {
          return false;
        }
        if (IsFusedBatchNorm(*operand) || IsBiasAdd(*operand)) {
          return false;
        }
      }
    }
    return true;
  }

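  // Advances each tail pointer one step along its chain: backwards through
  // the input edge for Concat, forwards through the consumers for Split.
  // Leaves *stop == true when any chain cannot be extended further.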
  Status AdvanceTails(const ChainLinkSet& tails, ChainLinkSet* new_tails,
                      bool* stop) const {
    *stop = true;
    new_tails->clear();
    for (const auto& link : tails) {
      const NodeDef* tail = link.node;
      if (node_is_concat_) {
        if (tail->input_size() == 0 || IsControlInput(tail->input(0))) {
          return Status::OK();
        }
        NodeDef* new_tail;
        TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &new_tail));
        // Remember original port.
        new_tails->insert(ChainLink(new_tail, link.port_origin));
      } else {
        for (NodeDef* new_tail : ctx().node_map->GetOutputs(tail->name())) {
          const TensorId tensor = ParseTensorName(new_tail->input(0));
          if (tensor.node() != tail->name()) {
            return Status::OK();
          }
          // Skip control outputs.
          if (tensor.index() >= 0) {
            // Remember original port.
            new_tails->insert(ChainLink(new_tail, link.port_origin));
          }
        }
      }
    }
    *stop = false;
    return Status::OK();
  }

  Status HoistChainForConcat(const int prefix_length, const ChainLinkSet& tails,
                             NodeDef* concat_node) {
    const string& concat_name = concat_node->name();
    const int first_input = concat_node->op() == "Concat" ? 1 : 0;
    for (const auto& link : tails) {
      NodeDef* tail = CHECK_NOTNULL(link.node);
      const int concat_port = link.port_origin;
      CHECK_GE(concat_port, 0);
      CHECK_LT(concat_port, concat_node->input_size());
      const string concat_input = concat_node->input(concat_port);
      // Hook the node following tail directly into the concat node.
      const string tail_input = tail->input(0);
      concat_node->set_input(concat_port, tail_input);
      ctx().node_map->UpdateInput(concat_name, concat_input, tail_input);

      if (concat_port == first_input) {
        // Update the consumers of concat to consume the end of the chain
        // instead.
        UpdateConsumers(concat_node, concat_input);
        // Reuse nodes in the first chain to process output of concat.
        tail->set_input(0, concat_name);
        ctx().node_map->UpdateInput(tail->name(), tail_input, concat_name);
      }
    }
    return Status::OK();
  }

  Status HoistChainForSplit(const int prefix_length, const ChainLinkSet& tails,
                            std::set<string>* ctrl_inputs,
                            NodeDef* split_node) {
    // Create a new chain before the split node to process the input tensor.
    const string& split_name = split_node->name();
    auto root_scope_and_name = ParseNodeScopeAndName(split_name);

    // We use the first tail node in the set as a template to get the list of
    // ops to apply (starting from the end).
    NodeDef* cur_tail = tails.begin()->node;
    NodeDef* cur_copy = AddCopyNode(
        OptimizedNodeName(root_scope_and_name, cur_tail->name()), cur_tail);
    cur_copy->clear_input();

    // Update the split to take its input from the tail of the new chain.
    const int value_slot = split_node->op() == "SplitV" ? 0 : 1;
    const string orig_input = split_node->input(value_slot);
    split_node->set_input(value_slot, cur_copy->name());
    ctx().node_map->UpdateInput(split_node->name(), orig_input,
                                cur_copy->name());
    TF_RETURN_IF_ERROR(GetInputNode(cur_tail->input(0), &cur_tail));

    // Now walk backwards creating the rest of the chain.
    while (cur_tail != split_node) {
      NodeDef* new_copy = AddCopyNode(
          OptimizedNodeName(root_scope_and_name, cur_tail->name()), cur_tail);
      new_copy->clear_input();
      cur_copy->add_input(new_copy->name());
      ctx().node_map->AddOutput(new_copy->name(), cur_copy->name());
      cur_copy = new_copy;
      TF_RETURN_IF_ERROR(GetInputNode(cur_tail->input(0), &cur_tail));
    }
    // Connect the original input to the head of the new chain.
    cur_copy->add_input(orig_input);
    ctx().node_map->UpdateOutput(NodeName(orig_input), split_name,
                                 cur_copy->name());
    // Make sure all the control inputs are satisfied before running the first
    // node in the new chain.
    AddControlInputs(ctrl_inputs, cur_copy);

    // Connect all consumers of the tail nodes directly to the
    // output port of Split from which the chain started.
    for (const auto& link : tails) {
      UpdateConsumers(link.node,
                      link.port_origin == 0
                          ? split_name
                          : strings::StrCat(split_name, ":", link.port_origin));
    }
    return Status::OK();
  }

  // Update consumers of node to take new_input as input instead.
  void UpdateConsumers(NodeDef* node, const string& new_input) {
    const string& node_name = node->name();
    const auto consumers = ctx().node_map->GetOutputs(node_name);
    for (NodeDef* consumer : consumers) {
      for (int i = 0; i < consumer->input_size(); ++i) {
        if (consumer->input(i) == node_name &&
            consumer->name() != NodeName(new_input)) {
          consumer->set_input(i, new_input);
          ctx().node_map->UpdateInput(consumer->name(), node_name, new_input);
        }
      }
      AddToOptimizationQueue(consumer);
    }
  }

  bool IsAlreadyOptimized(const NodeDef& node) const {
    return optimized_nodes_.find(node.name()) != optimized_nodes_.end();
  }

 private:
  bool node_is_concat_;
  std::unordered_set<string> optimized_nodes_;
};

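// Remove one of two directly stacked instances of an idempotent op f, i.e.
// an op for which f(f(x)) == f(x), when both copies run on the same device.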
class RemoveIdempotentStage : public ArithmeticOptimizerStage {
 public:
  explicit RemoveIdempotentStage(const GraphOptimizerContext& ctx,
                                 const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveIdempotent", ctx, ctx_ext) {}
  ~RemoveIdempotentStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return node->input_size() == 1 && IsIdempotent(*node) &&
           !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
    if (input->op() == node->op() && input->device() == node->device()) {
      *simplified_node_name = node->input(0);
    }
    return Status::OK();
  }
};

// Performs the conversion:
// Div(x, Sqrt(y)) => Mul(x, Rsqrt(y))
// TODO(srjoglekar): Generalize to optimize cases like (x / pow(y, z)).
class SqrtDivToRsqrtMulStage : public ArithmeticOptimizerStage {
 public:
  explicit SqrtDivToRsqrtMulStage(const GraphOptimizerContext& ctx,
                                  const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("SqrtDivToRsqrtMul", ctx, ctx_ext) {}
  ~SqrtDivToRsqrtMulStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    // Note: div_no_nan(a, sqrt(b)) => mul_no_nan(a, rsqrt(b))
    // for b == 0 would result in a / Inf instead of 0.
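    // FloorDiv is likewise excluded: rewriting floor(a / sqrt(b)) as
    // a * rsqrt(b) would drop the flooring.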
    return IsAnyDiv(*node) && !IsDivNoNan(*node) && !IsFloorDiv(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* y;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
    // Optimize only if divisor is a Sqrt whose output is not being consumed
    // elsewhere.
    if (IsSqrt(*y) && !IsInPreserveSet(*y) &&
        (NumNonControlOutputs(*y, *ctx().node_map) == 1)) {
      if (IsXdivy(*node)) {
        // xdivy(a, sqrt(b)) => mul_no_nan(rsqrt(b), a)
        node->set_op("MulNoNan");
        node->mutable_input()->SwapElements(0, 1);
      } else {
        // div(a, sqrt(b)) => mul(a, rsqrt(b))
        node->set_op("Mul");
      }
      y->set_op("Rsqrt");
      AddToOptimizationQueue(node);
      AddToOptimizationQueue(y);
    }
    return Status::OK();
  }
};

// Performs the following conversion for real types:
//   Square(Sub(x, y)) => Identity(SquaredDifference(x, y))
class FuseSquaredDiffStage : public ArithmeticOptimizerStage {
 public:
  explicit FuseSquaredDiffStage(const GraphOptimizerContext& ctx,
                                const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("FuseSquaredDiffStage", ctx, ctx_ext) {}
  ~FuseSquaredDiffStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsSquare(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* b;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &b));
    // Optimize only if base is a Sub whose output is not being consumed
    // elsewhere.
    if (IsSub(*b) && !IsInPreserveSet(*b) &&
        (NumNonControlOutputs(*b, *ctx().node_map) == 1)) {
      // For complex, SquaredDiff computes conj(x-y)*(x-y), so this rewrite is
      // invalid.
      const DataType type = GetDataTypeFromAttr(*b, "T");
      if ((type == DT_COMPLEX64) || (type == DT_COMPLEX128))
        return Status::OK();
      node->set_op("Identity");
      b->set_op("SquaredDifference");
      AddToOptimizationQueue(node);
      AddToOptimizationQueue(b);
    }
    return Status::OK();
  }
};

// Performs the conversion:
// Log(Softmax(x)) => LogSoftmax(x)
class LogSoftmaxStage : public ArithmeticOptimizerStage {
 public:
  explicit LogSoftmaxStage(const GraphOptimizerContext& ctx,
                           const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("LogSoftmaxStage", ctx, ctx_ext) {}
  ~LogSoftmaxStage() override = default;

  bool IsSupported(const NodeDef* node) const override { return IsLog(*node); }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* x;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
    // Optimize only if arg is a Softmax whose output is not being consumed
    // elsewhere.
    if (IsSoftmax(*x) && !IsInPreserveSet(*x) &&
        (NumNonControlOutputs(*x, *ctx().node_map) == 1)) {
      // Log(Softmax(x)) => LogSoftmax(Identity(x))
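      // Rewriting the Softmax in place to Identity keeps all existing edges
      // valid, so nothing has to be rewired here; a later pass (e.g. the
      // dependency optimizer) can prune the Identity.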
      node->set_op("LogSoftmax");
      x->set_op("Identity");
      AddToOptimizationQueue(node);
      AddToOptimizationQueue(x);
    }
    return Status::OK();
  }
};

// Bypass redundant reshape nodes:
//
//   Reshape                    Reshape  <-+
//      ^                                  |
//      |                                  |
//   Reshape       becomes      Reshape    |
//      ^                                  |
//      |                                  |
//    input                      input  ---+
//
// Additionally, Reshape and BroadcastTo nodes where the
// input and target shapes are equal are bypassed.
//
class RemoveRedundantReshapeOrBroadcastTo : public ArithmeticOptimizerStage {
 public:
  explicit RemoveRedundantReshapeOrBroadcastTo(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveRedundantReshapeOrBroadcastTo", ctx,
                                 ctx_ext) {}
  ~RemoveRedundantReshapeOrBroadcastTo() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return (IsReshape(*node) || IsBroadcastTo(*node)) &&
           !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));

    // 1. Bypass reshape followed by reshape.
    if (IsReshape(*node) && IsReshape(*input) && !IsInPreserveSet(*input)) {
      ForwardControlDependencies(node, {input});
      node->set_input(0, input->input(0));
      ctx().node_map->UpdateInput(node->name(), input->name(), input->input(0));
      *simplified_node_name = node->name();
      AddToOptimizationQueue(node);
      return Status::OK();
    }

    // 2. If the reshape is a no-op, forward its input to its consumers, unless
    // it anchors a control dependency since we want to make sure that control
    // dependency is triggered.
    if (InputMatchesTargetShape(*node) && !HasControlInputs(*node)) {
      *simplified_node_name = node->input(0);
      return Status::OK();
    }

    return Status::OK();
  }

 private:
  // Returns whether `reshape` is an identity op.
  bool InputMatchesTargetShape(const NodeDef& reshape) {
    const OpInfo::TensorProperties* reshape_props;
    const OpInfo::TensorProperties* input_props;

    if (!GetTensorProperties(reshape.name(), &reshape_props).ok() ||
        !GetTensorProperties(reshape.input(0), &input_props).ok()) {
      return false;
    }

    return ShapesSymbolicallyEqual(input_props->shape(),
                                   reshape_props->shape());
  }
};

// Reorder casting and value-preserving ops if beneficial.
//
// Original motivation: A common pattern after the layout optimizer is
// casting an uint8 NHWC image to float before transposing it to NCHW. It
// is beneficial to reorder the cast and the transpose to make the transpose
// process a smaller amount of data. More generally, this optimization converts
//   Op(Cast(tensor, dst_type))
// to
//   Cast(Op(tensor), dst_type)
// when sizeof(tensor.type) < sizeof(dst_type), and Op is any value-preserving
// Op, i.e. an op that only reorders the elements in its first input. Similarly,
// this optimization converts
//   Cast(Op(tensor), dst_type)
// to
//   Op(Cast(tensor, dst_type))
// when sizeof(tensor.type) > sizeof(dst_type)
//
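// For example, with image : uint8 (1 byte) cast to float (4 bytes),
//   Transpose(Cast(image, float), perm) => Cast(Transpose(image, perm), float)
// lets the Transpose move 4x less data.
//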
class ReorderCastLikeAndValuePreserving : public ArithmeticOptimizerStage {
 public:
  explicit ReorderCastLikeAndValuePreserving(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ReorderCastLikeAndValuePreserving", ctx,
                                 ctx_ext) {}
  ~ReorderCastLikeAndValuePreserving() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return (IsValuePreserving(*node) || IsCastLike(*node)) &&
           !IsCheckNumerics(*node) && NodeIsOnCpuOrGpu(node) &&
           !IsControlFlow(*node) && !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* consumer, string* simplified_node_name) override {
    NodeDef* producer;
    TF_RETURN_IF_ERROR(GetInputNode(consumer->input(0), &producer));
    const bool producer_is_cast = IsCastLike(*producer);
    const bool can_optimize =
        !IsCheckNumerics(*producer) &&
        ((producer_is_cast && IsValuePreserving(*consumer)) ||
         (IsValuePreserving(*producer) && IsCastLike(*consumer)));
    if (!can_optimize || IsControlFlow(*producer) ||
        IsInPreserveSet(*producer) ||
        producer->device() != consumer->device()) {
      return Status::OK();
    }

    const NodeDef* cast_like_node = producer_is_cast ? producer : consumer;
    const OpDef* cast_like_op_def = nullptr;
    TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(cast_like_node->op(),
                                                         &cast_like_op_def));
    DataType cast_src_type;
    TF_RETURN_IF_ERROR(InputTypeForNode(*cast_like_node, *cast_like_op_def, 0,
                                        &cast_src_type));
    DataType cast_dst_type;
    TF_RETURN_IF_ERROR(OutputTypeForNode(*cast_like_node, *cast_like_op_def, 0,
                                         &cast_dst_type));
    if (!IsFixedSizeType(cast_src_type) || !IsFixedSizeType(cast_dst_type)) {
      return Status::OK();
    } else if (producer_is_cast &&
               DataTypeSize(cast_dst_type) <= DataTypeSize(cast_src_type)) {
      return Status::OK();
    } else if (!producer_is_cast &&
               DataTypeSize(cast_dst_type) >= DataTypeSize(cast_src_type)) {
      return Status::OK();
    }

    // Check that nodes were not already optimized.
    const string optimized_producer_name = OptimizedNodeName(
        ParseNodeScopeAndName(producer->name()), DataTypeString(cast_dst_type));
    const string optimized_consumer_name = OptimizedNodeName(
        ParseNodeScopeAndName(consumer->name()), DataTypeString(cast_src_type));
    const bool is_already_optimized =
        ctx().node_map->NodeExists(optimized_consumer_name) ||
        ctx().node_map->NodeExists(optimized_producer_name);
    if (is_already_optimized) {
      return Status::OK();
    }

    // Add copies of consumer and producer in reverse order.
    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(producer->input(0), &input));
    // Create new producer node.
    NodeDef* new_producer = AddCopyNode(optimized_consumer_name, consumer);
    new_producer->set_input(0, producer->input(0));
    ctx().node_map->AddOutput(input->name(), new_producer->name());

    // Create new consumer node.
    NodeDef* new_consumer = AddCopyNode(optimized_producer_name, producer);
    new_consumer->set_input(0, new_producer->name());

    NodeDef* new_value_preserving =
        producer_is_cast ? new_producer : new_consumer;
    const DataType new_input_type =
        producer_is_cast ? cast_src_type : cast_dst_type;
    // Update the input type of the value-preserving node. The input and
    // output types of the cast-like nodes remain the same.
    TF_RETURN_IF_ERROR(SetInputType(new_input_type, new_value_preserving));
    // Make sure there is a kernel registered for the value preserving op
    // with the new input type.
    TF_RETURN_IF_ERROR(IsKernelRegisteredForNode(*new_value_preserving));
    ctx().node_map->AddOutput(new_producer->name(), new_consumer->name());

    AddToOptimizationQueue(new_producer);
    *simplified_node_name = new_consumer->name();

    return Status::OK();
  }

 private:
  // Sets the type of the first input to dtype.
  Status SetInputType(DataType dtype, NodeDef* node) {
    const OpDef* op_def = nullptr;
    TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node->op(), &op_def));
    const OpDef::ArgDef& input_arg = op_def->input_arg(0);
    const string& type_attr_name = input_arg.type_attr();
    if (type_attr_name.empty()) {
      if (input_arg.type() == DT_INVALID || input_arg.type() != dtype) {
        return errors::InvalidArgument("Could not set input type of ",
                                       node->op(), " op to ",
                                       DataTypeString(dtype));
      } else {
        // Op has fixed input type that already matches dtype.
        return Status::OK();
      }
    }
    SetDataTypeToAttr(dtype, type_attr_name, node);
    return Status::OK();
  }
  // This optimization can be dangerous on devices other than CPU and
  // GPU. The transpose might not be implemented for image.type, or
  // might be slower with image.type than with cast_dst_type.
  bool NodeIsOnCpuOrGpu(const NodeDef* node) const {
    using absl::StrContains;

    string task;
    string device;

    return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
           (StrContains(device, DEVICE_CPU) || StrContains(device, DEVICE_GPU));
  }

  bool IsFixedSizeType(DataType dtype) {
    return dtype != DT_STRING && dtype != DT_VARIANT && dtype != DT_RESOURCE &&
           !kQuantizedTypes.Contains(dtype);
  }
};

// Fold a multiply of a scalar into the following convolution. This folding
// can jump across nodes that merely reorder data (such as Reshape and
// Transpose). For example, we can optimize
//
//
//         Conv2D                             Conv2D
//        /      \                           /      \
//    Transpose  weights*       ->     Transpose    Mul
//       |                                |        /   \
//      Mul                               |    weights  scale
//     /   \                              |
//   input  scale**                     input
//
//  *) weights must be a const
// **) scale must be a const scalar
//
// When `weights` and `scale` are constant, `Mul` in the optimized graph can be
// constant-folded; also, weights tend to be smaller than the activations.
//
// TODO(jingyue): Fold scalar multiplies to Conv?DBackpropFilter and
// Conv?DBackpropInput.
class FoldMultiplyIntoConv : public ArithmeticOptimizerStage {
 public:
  explicit FoldMultiplyIntoConv(const GraphOptimizerContext& ctx,
                                const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("FoldMultiplyIntoConv", ctx, ctx_ext) {}
  ~FoldMultiplyIntoConv() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsConv2D(*node) || IsConv3D(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
#define TF_RETURN_IF_TRUE(...) \
  if ((__VA_ARGS__)) return Status::OK()
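    // This macro bails out of the rewrite, returning OK (i.e. "no
    // simplification performed"), whenever a precondition below fails.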

    NodeDef* conv = node;

    NodeDef* weights;
    TF_RETURN_IF_ERROR(GetInputNode(conv->input(1), &weights));

    // Fold the multiply to conv only when the weights are constant, so the
    // multiply can be constant-folded.
    //
    // TODO(jingyue): When the weights aren't constant, this should also help
    // performance a bit and memory usage a lot, since the weights tend to be
    // smaller than the activations.
    TF_RETURN_IF_TRUE(!IsConstant(*weights));

    // Verify that this node was not already optimized.
    const string scaled_weights_node_name =
        OptimizedNodeName(ParseNodeScopeAndName(weights->name()),
                          strings::StrCat("scaled", "_", conv->name()));

    TF_RETURN_IF_TRUE(ctx().node_map->NodeExists(scaled_weights_node_name));

    // Find the tail of value preserving chain entering the Conv node.
    NodeDef* tail = GetTailOfValuePreservingChain(*conv, *ctx().node_map,
                                                  *ctx().nodes_to_preserve);

    NodeDef* source;
    TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &source));

    // Check that the value-preserving chain is the only consumer of the Mul
    // output.
    TF_RETURN_IF_TRUE(!IsAnyMul(*source));
    TF_RETURN_IF_TRUE(NumNonControlOutputs(*source, *ctx().node_map) != 1);
    // And that Mul is not in the preserve set.
    TF_RETURN_IF_TRUE(IsInPreserveSet(*source));

    const NodeDef* mul = source;
    int input_idx = 0;
    int scale_idx = 1;
    NodeDef* scale;  // scalar multiplier for the input tensor
    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(mul->input(scale_idx), &scale));
    TF_RETURN_IF_ERROR(GetInputNode(mul->input(input_idx), &input));
    if (!IsConstant(*scale) && IsConstant(*input)) {
      VLOG(3) << "Swapped inputs to mul";
      std::swap(scale_idx, input_idx);
      std::swap(scale, input);
    }
    TF_RETURN_IF_TRUE(!IsConstant(*scale));

    // Check that one of the inputs to mul is a constant scalar.
    const TensorProto& scale_tensor = scale->attr().at("value").tensor();
    bool scale_is_a_scalar = scale_tensor.has_tensor_shape() &&
                             scale_tensor.tensor_shape().dim_size() == 0;
    TF_RETURN_IF_TRUE(!scale_is_a_scalar);

    // Check that 'scale * weight' can be const folded.
    TF_RETURN_IF_TRUE(!IsConstant(*scale));
    TF_RETURN_IF_ERROR(CheckAttrsExist(*scale, {"dtype"}));
    TF_RETURN_IF_ERROR(CheckAttrExists(*weights, "dtype"));
    TF_RETURN_IF_TRUE(scale->attr().at("dtype").type() !=
                      weights->attr().at("dtype").type());

    // At this point all preconditions are met, and we safely do the rewrite.
    VLOG(3) << "Fold multiply into conv: conv=" << conv->name()
            << " mul=" << mul->name() << " weights=" << weights->name();

    // Create new node `scaled_weights`.
    NodeDef* scaled_weights = AddEmptyNode(scaled_weights_node_name);
    scaled_weights->set_op(source->op());
    scaled_weights->set_device(weights->device());
    (*scaled_weights->mutable_attr())["T"] = weights->attr().at("dtype");
    AddToOptimizationQueue(scaled_weights);

    // Link in its inputs.
    scaled_weights->add_input(conv->input(1));
    ctx().node_map->AddOutput(weights->name(), scaled_weights->name());
    scaled_weights->add_input(mul->input(scale_idx));
    ctx().node_map->AddOutput(scale->name(), scaled_weights->name());
    ForwardControlDependencies(scaled_weights, {source});

    // Update `conv`'s weights to `scaled_weights`.
    conv->set_input(1, scaled_weights->name());
    ctx().node_map->UpdateInput(conv->name(), weights->name(),
                                scaled_weights->name());
    AddToOptimizationQueue(conv);

    // Update `tail` node to bypass `mul` because it's folded to the weights.
    tail->set_input(0, mul->input(input_idx));
    ctx().node_map->UpdateInput(tail->name(), mul->name(), input->name());
    AddToOptimizationQueue(tail);
    *simplified_node_name = conv->name();

    return Status::OK();
#undef TF_RETURN_IF_TRUE
  }
};

// Fold Transpose into matrix multiplication.
class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage {
 public:
  explicit FoldTransposeIntoMatMul(const GraphOptimizerContext& ctx,
                                   const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("FoldTransposeIntoMatMul", ctx, ctx_ext) {}
  ~FoldTransposeIntoMatMul() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsAnyMatMul(*node) && !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    const NodeScopeAndName matmul = ParseNodeScopeAndName(node->name());
    const string optimized_node_name = OptimizedNodeName(matmul);
    if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();

    NodeDef* a;
    NodeDef* b;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &a));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &b));

    bool is_complex = false;
    if (node->op() != "SparseMatMul") {
      const DataType type = GetDataTypeFromAttr(*node, "T");
      is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
    }

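    // For complex types, transpose and adjoint are not interchangeable:
    // BatchMatMul's adj_x/adj_y conjugate as well as transpose, so it can
    // only absorb a ConjugateTranspose, while MatMul's transpose_a/b do not
    // conjugate, so it can only absorb a plain Transpose.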
    const std::set<string> foldable_transpose_ops =
        !is_complex
            ? std::set<string>{"ConjugateTranspose", "Transpose"}
            : (IsAnyBatchMatMul(*node) ? std::set<string>{"ConjugateTranspose"}
                                       : std::set<string>{"Transpose"});

    const bool a_is_foldable = foldable_transpose_ops.count(a->op()) > 0 &&
                               IsInnerMatrixTransposeNode(*a, ctx().node_map);
    const bool b_is_foldable = foldable_transpose_ops.count(b->op()) > 0 &&
                               IsInnerMatrixTransposeNode(*b, ctx().node_map);
    if (!a_is_foldable && !b_is_foldable) return Status::OK();

    NodeDef* new_op = AddCopyNode(optimized_node_name, node);

    if (a_is_foldable) {
      const string attr_a = IsAnyBatchMatMul(*node) ? "adj_x" : "transpose_a";
      FlipBooleanAttr(attr_a, new_op);
      new_op->set_input(0, a->input(0));
      ctx().node_map->UpdateInput(new_op->name(), a->name(), a->input(0));
    } else {
      ctx().node_map->UpdateOutput(a->name(), node->name(), new_op->name());
    }

    if (b_is_foldable) {
      const string attr_b = IsAnyBatchMatMul(*node) ? "adj_y" : "transpose_b";
      FlipBooleanAttr(attr_b, new_op);
      new_op->set_input(1, b->input(0));
      ctx().node_map->UpdateInput(new_op->name(), b->name(), b->input(0));
    } else {
      ctx().node_map->UpdateOutput(b->name(), node->name(), new_op->name());
    }

    std::vector<const NodeDef*> deps_to_forward = {node};
    if (a_is_foldable) deps_to_forward.push_back(a);
    if (b_is_foldable) deps_to_forward.push_back(b);
    ForwardControlDependencies(new_op, deps_to_forward);
    *simplified_node_name = new_op->name();

    return Status::OK();
  }

 private:
  void FlipBooleanAttr(const string& attr_name, NodeDef* node) {
    const bool old_value =
        !node->attr().count(attr_name) ? false : node->attr().at(attr_name).b();
    (*node->mutable_attr())[attr_name].set_b(!old_value);
  }

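  // Returns true if `perm` only swaps the two innermost dimensions (e.g.
  // {0, 1, 3, 2}), which is the only kind of transpose that MatMul and
  // BatchMatMul can absorb through their transpose/adjoint attributes.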
  template <typename T>
  bool IsInnerMatrixTranspose(const std::vector<T>& perm) {
    const T n = perm.size();
    if (n < 2) {
      return false;
    }
    for (T i = 0; i < n - 2; ++i) {
      if (perm[i] != i) {
        return false;
      }
    }
    return perm[n - 1] == n - 2 && perm[n - 2] == n - 1;
  }

  bool IsInnerMatrixTransposeNode(const NodeDef& transpose_node,
                                  const NodeMap* node_map) {
    if (transpose_node.op() != "Transpose" &&
        transpose_node.op() != "ConjugateTranspose") {
      return false;
    }
    const NodeDef* perm_node = node_map->GetNode(transpose_node.input(1));
    std::vector<int> perm32;
    if (ValuesFromConstNode(*perm_node, &perm32)) {
      return IsInnerMatrixTranspose(perm32);
    }
    std::vector<int64> perm64;
    if (ValuesFromConstNode(*perm_node, &perm64)) {
      return IsInnerMatrixTranspose(perm64);
    }
    return false;
  }
};
2383 
class FoldConjugateIntoTranspose : public ArithmeticOptimizerStage {
 public:
  explicit FoldConjugateIntoTranspose(const GraphOptimizerContext& ctx,
                                      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("FoldConjugateIntoTranspose", ctx, ctx_ext) {}
  ~FoldConjugateIntoTranspose() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsConj(*node) || IsTranspose(*node) || IsConjugateTranspose(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    const NodeScopeAndName matmul = ParseNodeScopeAndName(node->name());
    const string optimized_node_name = OptimizedNodeName(matmul);
    if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();

    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));

    const NodeDef* transpose_op = node->op() == "Conj" ? input : node;
    const NodeDef* conj_op = node->op() == "Conj" ? node : input;

    if ((IsTranspose(*transpose_op) || IsConjugateTranspose(*transpose_op)) &&
        IsConj(*conj_op)) {
      NodeDef* new_op = AddCopyNode(optimized_node_name, transpose_op);

      // Flip the type of transpose op to absorb the conjugation.
      new_op->set_op(transpose_op->op() == "Transpose" ? "ConjugateTranspose"
                                                       : "Transpose");
      new_op->set_input(0, input->input(0));
      ctx().node_map->UpdateInput(new_op->name(), node->name(),
                                  input->input(0));
      ForwardControlDependencies(new_op, {node, input});
      *simplified_node_name = new_op->name();
    }

    return Status::OK();
  }
};

// Replace a Mul node whose two inputs are identical with a Square.
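// For example: Mul(x, x) => Square(x). Complex-valued products are rewritten
// only when the node is placed on CPU.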
class ReplaceMulWithSquare : public ArithmeticOptimizerStage {
 public:
  explicit ReplaceMulWithSquare(const GraphOptimizerContext& ctx,
                                const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ReplaceMulWithSquare", ctx, ctx_ext) {}
  ~ReplaceMulWithSquare() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsAnyMul(*node) && node->input(0) == node->input(1);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    const NodeScopeAndName mul = ParseNodeScopeAndName(node->name());
    const string optimized_node_name = OptimizedNodeName(mul);
    if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();

    const DataType type = GetDataTypeFromAttr(*node, "T");
    bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);

    if (!is_complex || NodeIsOnCpu(*node)) {
      NodeDef* new_square_node = AddCopyNode(optimized_node_name, node);
      new_square_node->set_op("Square");
      for (int i = 1; i < new_square_node->input_size(); ++i) {
        new_square_node->set_input(i - 1, new_square_node->input(i));
      }
      new_square_node->mutable_input()->RemoveLast();
      for (const string& input : new_square_node->input()) {
        ctx().node_map->AddOutput(NodeName(input), new_square_node->name());
      }
      *simplified_node_name = new_square_node->name();
    }

    return Status::OK();
  }
};

// Simplify aggregation (e.g. AddN) nodes:
//
// 1. Discard aggregate nodes with a single input and no control dependencies.
//
// 2. Try to rewrite aggregations of N >= 2 identical terms (possibly due to
//    deduping or other rewrites) so we can get rid of the sum entirely.
//
//    The expression (using AddN as an example of an aggregate op):
//      AddN(x, x, x, ... ,x)
//           <-- N terms -->
//    can be rewritten to:
//      Mul(Const(N), x)
//
class SimplifyAggregation : public ArithmeticOptimizerStage {
 public:
  explicit SimplifyAggregation(const GraphOptimizerContext& ctx,
                               const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("SimplifyAggregation", ctx, ctx_ext) {}
  ~SimplifyAggregation() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsAggregate(*node) && HasRegularInputs(*node) &&
           GetDataTypeFromAttr(*node, "T") !=
               DT_VARIANT;  // TODO(b/119787146): Enable for variants.
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    // 1. Discard aggregate nodes with a single input and no control deps.
    if (node->input_size() == 1) {
      *simplified_node_name = node->input(0);
      return Status::OK();
    }

    // 2. Rewrite aggregations of N >= 2 identical terms.

    // All non-control inputs must be identical.
    bool all_equal = true;
    int num_inputs = 1;
    for (int i = 1; i < node->input_size(); ++i) {
      if (IsControlInput(node->input(i))) break;
      ++num_inputs;
      if (node->input(i) != node->input(0)) {
        all_equal = false;
        break;
      }
    }
    if (!all_equal) return Status::OK();
    // And the node must not have been optimized earlier.
    const NodeScopeAndName node_scope_and_name =
        ParseNodeScopeAndName(node->name());
    const string optimized_const_name =
        OptimizedNodeName(node_scope_and_name, "Const");
    const string optimized_mul_name =
        OptimizedNodeName(node_scope_and_name, "Mul");

    bool is_already_optimized =
        ctx().node_map->NodeExists(optimized_const_name) ||
        ctx().node_map->NodeExists(optimized_mul_name);

    if (is_already_optimized) return Status::OK();

    // At this point all preconditions are met, and we safely do the rewrite.
    VLOG(3) << "Simplify aggregation with identical inputs: node="
            << node->name() << " num_inputs=" << num_inputs;

    // 1. Create constant node with value N.
    const auto type = GetDataTypeFromAttr(*node, "T");
    Tensor t(type, TensorShape({}));
    Status status = SetTensorValue(type, num_inputs, &t);
    if (!status.ok()) {
      return errors::Internal("Failed to create const node: ",
                              status.error_message());
    }

    TensorValue value(&t);
    NodeDef* new_const_node = AddEmptyNode(optimized_const_name);
    status = ConstantFolding::CreateNodeDef(new_const_node->name(), value,
                                            new_const_node);
    if (!status.ok()) {
      return errors::Internal("Failed to create const node: ",
                              status.error_message());
    }
    new_const_node->set_device(node->device());
    MaybeAddControlInput(NodeName(node->input(0)), new_const_node,
                         ctx().optimized_graph, ctx().node_map);
    AddToOptimizationQueue(new_const_node);

    // 2. Replace the aggregate node with Mul(Const(N), x).
    NodeDef* new_mul_node = AddEmptyNode(optimized_mul_name);
    new_mul_node->set_op("Mul");
    new_mul_node->set_device(node->device());
    SetDataTypeToAttr(type, "T", new_mul_node);
    new_mul_node->add_input(new_const_node->name());
    ctx().node_map->AddOutput(new_const_node->name(), new_mul_node->name());
    new_mul_node->add_input(node->input(0));
    ctx().node_map->AddOutput(node->input(0), new_mul_node->name());

    ForwardControlDependencies(new_mul_node, {node});
    *simplified_node_name = new_mul_node->name();

    return Status::OK();
  }
};

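// Convert Pow into a cheaper equivalent when the exponent is a constant whose
// elements all hold the same value, e.g.:
//   Pow(x, 2)    => Square(x)
//   Pow(x, 3)    => Mul(x, Square(x))   (on CPU only)
//   Pow(x, 1)    => Identity(x)
//   Pow(x, 0.5)  => Sqrt(x)
//   Pow(x, 0)    => an all-ones Const   (when the output shape is fully defined)
//   Pow(x, -0.5) => Rsqrt(x)
//   Pow(x, -1)   => Reciprocal(x)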
class ConvertPowStage : public ArithmeticOptimizerStage {
 public:
  explicit ConvertPowStage(const GraphOptimizerContext& ctx,
                           const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ConvertPow", ctx, ctx_ext) {}

  bool IsSupported(const NodeDef* node) const override {
    return IsPow(*node) &&
           ctx().graph_properties->HasOutputProperties(node->name()) &&
           ctx().graph_properties->HasInputProperties(node->name());
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    Tensor pow;
    if (!GetTensorFromConstNode(node->input(1), &pow)) return Status::OK();
    complex128 prev, curr;
    for (int i = 0; i < pow.NumElements(); ++i) {
      if (!GetElementUnexhaustive(pow, i, {pow.dtype()}, &curr)) {
        // input data type is not supported by Pow. Skip.
        return Status::OK();
      }
      if (i != 0 && curr != prev) {
        // pow has different values on different elements. Skip.
        return Status::OK();
      }
      prev = curr;
    }
    NodeDef *x, *y;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
    const auto& value_props =
        ctx().graph_properties->GetInputProperties(node->name())[0];
    const TensorShapeProto& output_shape =
        ctx().graph_properties->GetOutputProperties(node->name())[0].shape();
    if (curr == complex128(2, 0)) {
      node->set_op("Square");
      node->set_input(1, AsControlDependency(y->name()));
      AddToOptimizationQueue(node);
      AddToOptimizationQueue(y);
    } else if (curr == complex128(3, 0)) {
      // TODO(courbet): Use 'Cube' when it's added to TF ops.
      if (NodeIsOnCpu(*node)) {
        // We create an inner square node: inner_square = square(x)
        const NodeScopeAndName scope_and_name =
            ParseNodeScopeAndName(node->name());
        const string inner_square_name =
            OptimizedNodeName(scope_and_name, "_inner");
        NodeDef* inner_square_node = ctx().node_map->GetNode(inner_square_name);
        if (inner_square_node == nullptr) {
          inner_square_node = AddCopyNode(inner_square_name, node);
          inner_square_node->set_op("Square");
          inner_square_node->mutable_input()->RemoveLast();
        }
        ctx().node_map->AddOutput(x->name(), inner_square_node->name());
        // We modify `node`: node = mul(x, inner_square);
        node->set_op("Mul");
        node->set_input(1, inner_square_node->name());
        node->add_input(AsControlDependency(y->name()));

        AddToOptimizationQueue(node);
        AddToOptimizationQueue(inner_square_node);
        AddToOptimizationQueue(y);
      }
    } else if (curr == complex128(1, 0) &&
               ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
      // Pow could be used to broadcast, so make sure the shapes of the two
      // arguments are identical before replacing Pow with Identity.
      node->set_op("Identity");
      node->set_input(1, AsControlDependency(y->name()));
      AddToOptimizationQueue(node);
      AddToOptimizationQueue(y);
    } else if (curr == complex128(0.5, 0)) {
      node->set_op("Sqrt");
      node->set_input(1, AsControlDependency(y->name()));
      AddToOptimizationQueue(node);
      AddToOptimizationQueue(y);
    } else if (curr == complex128(0, 0) &&
               ShapesSymbolicallyEqual(value_props.shape(), output_shape) &&
               PartialTensorShape(output_shape).IsFullyDefined()) {
      const auto dtype = node->attr().at("T").type();
      Tensor ones(dtype, output_shape);
      for (int i = 0; i < ones.NumElements(); ++i) {
        TF_RETURN_IF_ERROR(SetElementToOne(i, &ones));
      }
      node->set_op("Const");
      (*node->mutable_attr())["dtype"].set_type(dtype);
      node->mutable_attr()->erase("T");
      ones.AsProtoTensorContent(
          (*node->mutable_attr())["value"].mutable_tensor());
      node->set_input(0, AsControlDependency(x->name()));
      node->set_input(1, AsControlDependency(y->name()));
      AddToOptimizationQueue(node);
      AddToOptimizationQueue(x);
      AddToOptimizationQueue(y);
    } else if (curr == complex128(-0.5, 0)) {
      node->set_op("Rsqrt");
      node->set_input(1, AsControlDependency(y->name()));
      AddToOptimizationQueue(node);
      AddToOptimizationQueue(y);
    } else if (curr == complex128(-1, 0)) {
      node->set_op("Reciprocal");
      node->set_input(1, AsControlDependency(y->name()));
      AddToOptimizationQueue(node);
      AddToOptimizationQueue(y);
    }
    return Status::OK();
  }

 private:
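  // Sets element `i` of tensor `t` to one; used to materialize the all-ones
  // constant for the Pow(x, 0) rewrite above.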
  Status SetElementToOne(int i, Tensor* t) {
    switch (t->dtype()) {
      case DT_INT32:
        t->flat<int32>()(i) = 1;
        return Status::OK();
      case DT_INT64:
        t->flat<int64>()(i) = 1L;
        return Status::OK();
      case DT_FLOAT:
        t->flat<float>()(i) = 1.0f;
        return Status::OK();
      case DT_DOUBLE:
        t->flat<double>()(i) = 1.0;
        return Status::OK();
      case DT_COMPLEX64:
        t->flat<complex64>()(i) = complex64(1);
        return Status::OK();
      case DT_COMPLEX128:
        t->flat<complex128>()(i) = complex128(1);
        return Status::OK();
      default:
        return errors::InvalidArgument("Invalid data type: ", t->dtype());
    }
  }
};

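// Convert Log of an Add with a unit constant into Log1p, e.g.
//   Log(Add(x, 1)) => Log1p(x)    (or Log(Add(1, x)))
// provided broadcasting the constant does not change the shape of `x`.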
class ConvertLog1pStage : public ArithmeticOptimizerStage {
 public:
  explicit ConvertLog1pStage(const GraphOptimizerContext& ctx,
                             const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ConvertLog1p", ctx, ctx_ext) {}
  ~ConvertLog1pStage() override = default;

  bool IsSupported(const NodeDef* node) const override { return IsLog(*node); }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
    if (!IsAdd(*input)) {
      return Status::OK();
    }

    if (ctx().graph_properties->GetInputProperties(input->name()).size() < 2) {
      return Status::OK();
    }

    bool modified = false;
    TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 0, 1, &modified));
    if (!modified) {
      TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 1, 0, &modified));
    }
    if (modified) {
      *simplified_node_name = node->name();
    }
    return Status::OK();
  }

 private:
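  // Attempts the rewrite with input `i` of `add_node` as the non-constant
  // term and input `j` as the candidate unit constant.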
  Status TrySimplifyInternal(NodeDef* node, NodeDef* add_node, int i, int j,
                             bool* modified) {
    const auto& t =
        ctx().graph_properties->GetInputProperties(add_node->name())[i];
    const auto& c =
        ctx().graph_properties->GetInputProperties(add_node->name())[j];
    for (int k = 0; k < c.shape().dim_size(); ++k) {
      // Skip if c shape is not fully determined.
      if (c.shape().dim(k).size() < 0) {
        return Status::OK();
      }
    }
    TensorShapeProto broadcast_shape;
    if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
      return Status::OK();
    }
    if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
      // Skip if the non-constant tensor doesn't have the same shape after
      // broadcast.
      return Status::OK();
    }
    Tensor constant;
    if (GetTensorFromConstNode(add_node->input(j), &constant)) {
      complex128 element;
      // TODO(rmlarsen): Refactor the more general IsOnes from
      // constant_folding.cc and use it here. Perhaps also convert log(x - (-1))
      // or (preferably) add a pass to canonicalize Sub(x, -1) to Add(x, 1),
      // and Neg(-1) to 1.
      for (int k = 0; k < constant.NumElements(); ++k) {
        if (!GetElementUnexhaustive(constant, k,
                                    {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
                                     DT_COMPLEX64, DT_COMPLEX128},
                                    &element)) {
          // Input data type is not supported by log1p. Skip.
          return Status::OK();
        }
        if (element != complex128(1)) {
          // Current element is not 1. Skip.
          return Status::OK();
        }
      }
      NodeDef *x, *y;
      TF_RETURN_IF_ERROR(GetInputNode(add_node->input(i), &x));
      TF_RETURN_IF_ERROR(GetInputNode(add_node->input(j), &y));
      node->set_op("Log1p");
      node->set_input(0, add_node->input(i));
      node->add_input(AsControlDependency(y->name()));
      ForwardControlDependencies(node, {add_node});

      AddToOptimizationQueue(node);
      AddToOptimizationQueue(add_node);
      AddToOptimizationQueue(x);
      AddToOptimizationQueue(y);
      *modified = true;
    }
    return Status::OK();
  }
};

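// Convert Sub of an Exp and a unit constant into Expm1, e.g.
//   Sub(Exp(x), 1) => Expm1(x)
// provided broadcasting the constant does not change the shape of Exp(x).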
class ConvertExpm1Stage : public ArithmeticOptimizerStage {
 public:
  explicit ConvertExpm1Stage(const GraphOptimizerContext& ctx,
                             const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ConvertExpm1", ctx, ctx_ext) {}
  ~ConvertExpm1Stage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    if (!IsSub(*node)) return false;

    NodeDef* input;
    if (!GetInputNode(node->input(0), &input).ok()) return false;

    return IsExp(*input);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    if (ctx().graph_properties->GetInputProperties(node->name()).size() < 2) {
      return Status::OK();
    }
    const auto& t = ctx().graph_properties->GetInputProperties(node->name())[0];
    const auto& c = ctx().graph_properties->GetInputProperties(node->name())[1];
    TensorShapeProto broadcast_shape;
    if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
      return Status::OK();
    }
    if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
      // Skip if the non-constant tensor doesn't have the same shape after
      // broadcast.
      return Status::OK();
    }
    Tensor constant;
    if (!GetTensorFromConstNode(node->input(1), &constant)) return Status::OK();
    // TODO(rmlarsen): Use the more general IsOnes helper here.
    complex128 element;
    for (int k = 0; k < constant.NumElements(); ++k) {
      if (!GetElementUnexhaustive(constant, k,
                                  {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
                                   DT_COMPLEX64, DT_COMPLEX128},
                                  &element)) {
        // Input data type is not supported by expm1. Skip.
        return Status::OK();
      }
      if (element != complex128(1)) {
        // Current element is not 1. Skip.
        return Status::OK();
      }
    }
    NodeDef* exp;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &exp));
    NodeDef *exp_input, *ones;
    TF_RETURN_IF_ERROR(GetInputNode(exp->input(0), &exp_input));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &ones));
    node->set_op("Expm1");
    node->set_input(0, exp->input(0));
    node->set_input(1, AsControlDependency(ones->name()));
    ForwardControlDependencies(node, {exp});

    AddToOptimizationQueue(node);
    AddToOptimizationQueue(exp);
    AddToOptimizationQueue(exp_input);
    AddToOptimizationQueue(ones);
    *simplified_node_name = node->name();
    return Status::OK();
  }
};

// Performs conversions like:
//   Max(Sqrt(x)) => Sqrt(Max(x))
// Checks for a max/min reduction over element-wise monotonic functions, such
// as Sqrt, Sigmoid, Tanh, etc.
class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
 public:
  explicit OptimizeMaxOrMinOfMonotonicStage(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("OptimizeMaxOrMinOfMonotonicStage", ctx,
                                 ctx_ext) {}
  ~OptimizeMaxOrMinOfMonotonicStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsAnyMax(*node) || IsAnyMin(*node) || IsAnyMaxPool(*node) ||
           IsArgMax(*node) || IsArgMin(*node);
  }

  Status TrySimplify(NodeDef* reduction_node,
                     string* simplified_node_name) override {
    if (IsInPreserveSet(*reduction_node)) {
      return Status::OK();
    }

    NodeDef* inner_function;
    TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function));

    NodeDef* inner_function_input = nullptr;
    if (inner_function->input_size() > 0) {
      TF_RETURN_IF_ERROR(
          GetInputNode(inner_function->input(0), &inner_function_input));
    }

    // Optimize only if:
    // 0. inner_function is not in the preserve set,
    // 1. inner_function's Op is element-wise monotonic,
    // 2. inner_function's output is not being consumed elsewhere,
    // 3. inner_function is monotonically increasing if reduction_node is a
    //    pooling operation, since we don't have MinPool operations,
    // 4. inner_function is not a Relu node with an input from FusedBatchNorm
    //    or BiasAdd. This pattern will be fused later by remapper.
    auto can_be_fused_by_remapper = [](const NodeDef& consumer,
                                       const NodeDef& producer) -> bool {
      if (IsRelu(consumer) || IsRelu6(consumer)) {
        if (IsFusedBatchNorm(producer) || IsBiasAdd(producer)) {
          return true;
        }
      }
      return false;
    };
    bool is_non_decreasing = false;
    if (!IsInPreserveSet(*inner_function) &&
        IsElementWiseMonotonic(*inner_function, &is_non_decreasing) &&
        ctx().node_map->GetOutputs(inner_function->name()).size() == 1 &&
        (is_non_decreasing || !IsAnyMaxPool(*reduction_node)) &&
        !can_be_fused_by_remapper(*inner_function, *inner_function_input)) {
      // Swap the first inputs of the inner function Op & the reduction Op.
      NodeDef* inner_input;
      TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input));
      reduction_node->set_input(0, inner_input->name());
      ctx().node_map->UpdateInput(reduction_node->name(),
                                  inner_function->name(), inner_input->name());
      inner_function->set_input(0, reduction_node->name());
      UpdateConsumers(reduction_node, inner_function->name());
      ctx().node_map->UpdateInput(inner_function->name(), inner_input->name(),
                                  reduction_node->name());
      if (!is_non_decreasing) {
        // Flip Min<->Max if the function is non-increasing, e.g.
        // Max(Neg(x)) = Neg(Min(x)).
        const string opposite = FlipMinMax(*reduction_node);
        reduction_node->set_op(opposite);
      }

      if (IsArgMax(*reduction_node) || IsArgMin(*reduction_node)) {
        // ArgMax(Sqrt(x)) = ArgMax(x)
        inner_function->set_op("Identity");
      }

      AddToOptimizationQueue(reduction_node);
      AddToOptimizationQueue(inner_function);
      AddToOptimizationQueue(inner_input);
    }
    return Status::OK();
  }

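  // Re-points all consumers of `node` at `new_input`, skipping the node that
  // produces `new_input` so that no cycle is introduced.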
  void UpdateConsumers(NodeDef* node, const string& new_input) {
    const string& node_name = node->name();
    const auto consumers = ctx().node_map->GetOutputs(node_name);
    for (NodeDef* consumer : consumers) {
      for (int i = 0; i < consumer->input_size(); ++i) {
        if (consumer->input(i) == node_name &&
            consumer->name() != NodeName(new_input)) {
          consumer->set_input(i, new_input);
          ctx().node_map->UpdateInput(consumer->name(), node_name, new_input);
        }
      }
      AddToOptimizationQueue(consumer);
    }
  }

 private:
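  // Returns the op name with "Max" and "Min" swapped, e.g. "ArgMax" =>
  // "ArgMin".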
  string FlipMinMax(const NodeDef& node) {
    const string& op = node.op();
    if (IsAnyMax(node) || IsArgMax(node)) {
      return str_util::StringReplace(op, "Max", "Min", false);
    } else {
      return str_util::StringReplace(op, "Min", "Max", false);
    }
  }
};

// Replace a chain of type&shape preserving unary ops with a
// '_UnaryOpsComposition' node.
// TODO(ezhulenev): It should be a part of remapper optimizer because it doesn't
// have much to do with arithmetic (together with FoldMultiplyIntoConv stage?).
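// For example, a CPU-resident chain Sin(Sqrt(Exp(x))) becomes a single
// _UnaryOpsComposition(x) node with op_names = ["Exp", "Sqrt", "Sin"].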
class UnaryOpsComposition : public ArithmeticOptimizerStage {
 public:
  explicit UnaryOpsComposition(const GraphOptimizerContext& ctx,
                               const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("UnaryOpsComposition", ctx, ctx_ext) {
    // WARN: This should be consistent with unary_ops_composition.cc.
    // clang-format off
    supported_ops_ = {// Ops defined via Eigen scalar ops.
                      {"Abs",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Acos",       {DT_FLOAT,          DT_DOUBLE}},
                      {"Acosh",      {DT_FLOAT,          DT_DOUBLE}},
                      {"Asin",       {DT_FLOAT,          DT_DOUBLE}},
                      {"Asinh",      {DT_FLOAT,          DT_DOUBLE}},
                      {"Atan",       {DT_FLOAT,          DT_DOUBLE}},
                      {"Atanh",      {DT_FLOAT,          DT_DOUBLE}},
                      {"Ceil",       {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Cos",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Cosh",       {DT_FLOAT,          DT_DOUBLE}},
                      {"Expm1",      {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Exp",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Floor",      {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Inv",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Log",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Log1p",      {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Neg",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Reciprocal", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Rint",       {DT_FLOAT,          DT_DOUBLE}},
                      {"Round",      {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Rsqrt",      {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Sigmoid",    {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Sin",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Sinh",       {DT_FLOAT,          DT_DOUBLE}},
                      {"Sqrt",       {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Square",     {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Tan",        {DT_FLOAT,          DT_DOUBLE}},
                      {"Tanh",       {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      // Additional ops that are not part of Eigen.
                      {"Elu",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Relu",       {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Relu6",      {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Selu",       {DT_FLOAT, DT_HALF, DT_DOUBLE}}};
    // clang-format on
  }
  ~UnaryOpsComposition() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return CanOptimize(*node) &&
           // Check that this node was not already a root of a fused chain. If
           // graph optimization runs twice without pruning in between,
           // fused_nodes_ will not have this information.
           !ctx().node_map->NodeExists(OptimizedNodeName(*node));
  }

  Status TrySimplify(NodeDef* root, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(CheckAttrExists(*root, "T"));
    DataType dtype = root->attr().at("T").type();

    // Keep a trace of all supported input nodes that can be fused together.
    std::vector<string> op_nodes = {root->name()};
    std::vector<string> op_names = {root->op()};

    // Check if we should follow input(0) while building an op composition.
    const auto predicate_fn = [&](const NodeDef& input) {
      if (input.name() == root->name()) return true;

      bool follow_input_node =
          dtype == GetDataTypeFromAttr(input, "T") &&
          NumNonControlDataOutputs(input, *ctx().node_map) == 1 &&
          CanOptimize(input);

      if (follow_input_node) {
        op_nodes.push_back(input.name());
        op_names.push_back(input.op());
      }

      return follow_input_node;
    };

    NodeDef* last_op = GetTailOfChain(
        *root, *ctx().node_map, /*follow_control_input*/ false, predicate_fn);

    // We were not able to find a chain that can be replaced.
    if (op_names.size() == 1) return Status::OK();

    // Do not add fused nodes to any other chain.
    std::for_each(op_nodes.begin(), op_nodes.end(),
                  [this](const string& name) { AddToFusedNodes(name); });

    // Reverse the trace to get correct composition computation order.
    std::reverse(op_names.begin(), op_names.end());

    VLOG(2) << "Fuse unary ops: root=" << root->name() << " op_names=["
            << absl::StrJoin(op_names, ", ") << "]";

    NodeDef* composition_node = ctx().optimized_graph->add_node();
    composition_node->set_name(OptimizedNodeName(*root));
    composition_node->set_op("_UnaryOpsComposition");
    composition_node->add_input(last_op->input(0));
    composition_node->set_device(root->device());

    auto attr = composition_node->mutable_attr();
    SetAttrValue(dtype, &(*attr)["T"]);
    SetAttrValue(op_names, &(*attr)["op_names"]);

    ctx().node_map->AddNode(composition_node->name(), composition_node);
    ctx().node_map->AddOutput(NodeName(last_op->input(0)),
                              composition_node->name());

    *simplified_node_name = composition_node->name();

    return Status::OK();
  }

 private:
  bool CanOptimize(const NodeDef& node) const {
    DataType dtype = GetDataTypeFromAttr(node, "T");
    if (!IsSupported(node.op(), dtype)) {
      return false;
    }
    if (IsInPreserveSet(node)) {
      return false;
    }
    if (!NodeIsOnCpu(node)) {
      return false;
    }
    if (NodeIsAlreadyFused(node)) {
      return false;
    }
    return !(IsDrivenByControlDependency(node) ||
             DrivesControlDependency(node));
  }

  bool NodeIsAlreadyFused(const NodeDef& node) const {
    return fused_nodes_.count(node.name()) > 0;
  }

  string OptimizedNodeName(const NodeDef& node) const {
    return strings::StrCat(node.name(), "/unary_ops_composition");
  }

  void AddToFusedNodes(const string& name) { fused_nodes_.insert(name); }

  // Check if an op is supported by the _UnaryOpsComposition for the given type.
  bool IsSupported(const string& op_name, DataType dtype) const {
    const auto it = supported_ops_.find(op_name);
    return it != supported_ops_.end() && it->second.count(dtype) > 0;
  }

  std::unordered_map<string, std::set<DataType>> supported_ops_;
  std::unordered_set<string> fused_nodes_;
};

// Replace operations of the form:
//    x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i,...]
// with
//    a_i
// when the strided slice index `i` is applied in the k'th axis.
//
// Similarly, replace operations of the form:
//    x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i:i+1,...]
// with
//    expand_dims(a_i, axis=k)
// where the slice operator can be StridedSlice or Slice.
//
// TODO(ebrevdo): Extend to also replace operations of the form
//    concat((a_0, a_1, ..., ), axis=k)[:, ..., s_i:s_{i+1}, ...]
// with
//    a_i,
// when
//    s_i = cumsum(shape(a)[k] for a in (a_0, ...,))[i]
// and slicing is in the k'th axis.
class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage {
 public:
  explicit RemoveStackSliceSameAxis(const GraphOptimizerContext& ctx,
                                    const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveStackStridedSliceSameAxis", ctx,
                                 ctx_ext) {}
  ~RemoveStackSliceSameAxis() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return (IsStridedSlice(*node) || IsSlice(*node)) && !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    // *node is a StridedSlice NodeDef.
    NodeDef* pack;

    // Get the input and see if it's a Pack op.
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &pack));
    if (!IsPack(*pack)) return Status::OK();

    bool return_early;
    PartialTensorShape pack_output_shape;
    int pack_axis;
    TF_RETURN_IF_ERROR(
        CheckInputs(node, pack, &pack_output_shape, &pack_axis, &return_early));
    if (return_early) return Status::OK();

    int64 slice_start_value;
    bool found;
    bool must_expand_dims;
    TF_RETURN_IF_ERROR(GetSliceAxis(node, pack, pack_output_shape, pack_axis,
                                    &slice_start_value, &found,
                                    &must_expand_dims));
    if (!found) return Status::OK();

    return RewriteGraph(node, pack, slice_start_value, pack_axis,
                        must_expand_dims, simplified_node_name);
  }

 protected:
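  // Resolves the Pack op's axis and output shape; sets `*return_early` when
  // shape information is insufficient for the rewrite.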
  Status CheckInputs(const NodeDef* node, const NodeDef* pack,
                     PartialTensorShape* pack_output_shape, int* pack_axis,
                     bool* return_early) {
    *return_early = true;
    TF_RETURN_IF_ERROR(CheckAttrExists(*pack, "axis"));

    *pack_axis = pack->attr().at("axis").i();
    auto slice_properties =
        ctx().graph_properties->GetInputProperties(node->name());
    if (slice_properties.empty() ||
        slice_properties[0].shape().unknown_rank()) {
      return Status::OK();
    }
    *pack_output_shape = slice_properties[0].shape();
    const int pack_output_rank = pack_output_shape->dims();
    if (*pack_axis < 0) {
      *pack_axis += pack_output_rank;
    }
    if (*pack_axis < 0 || *pack_axis >= pack_output_rank) {
      return errors::InvalidArgument(
          "Pack node (", pack->name(),
          ") axis attribute is out of bounds: ", pack->attr().at("axis").i());
    }
    *return_early = false;
    return Status::OK();
  }

  Status GetSliceAxis(const NodeDef* node, const NodeDef* pack,
                      const PartialTensorShape& pack_output_shape,
                      int pack_axis, int64* slice_start_value, bool* found,
                      bool* must_expand_dims) {
    *found = false;
    if (IsSlice(*node)) {
      *must_expand_dims = true;
      return GetSimpleSliceAxis(node, pack, pack_output_shape, pack_axis,
                                slice_start_value, found);
    } else {
      return GetStridedSliceAxis(node, pack, pack_output_shape, pack_axis,
                                 slice_start_value, found, must_expand_dims);
    }
  }

  Status GetSimpleSliceAxis(const NodeDef* node, const NodeDef* pack,
                            const PartialTensorShape& pack_output_shape,
                            int pack_axis, int64* slice_start_value,
                            bool* found) {
    NodeDef* slice_begin;
    NodeDef* slice_size;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &slice_begin));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(2), &slice_size));
    for (const auto* n : {slice_begin, slice_size}) {
      if (!IsReallyConstant(*n)) return Status::OK();
    }

    Tensor slice_begin_t;
    Tensor slice_size_t;
    TF_RETURN_IF_ERROR(CheckAttrExists(*slice_begin, "value"));
    if (!slice_begin_t.FromProto(slice_begin->attr().at("value").tensor())) {
      return Status::OK();
    }
    TF_RETURN_IF_ERROR(CheckAttrExists(*slice_size, "value"));
    if (!slice_size_t.FromProto(slice_size->attr().at("value").tensor())) {
      return Status::OK();
    }

    auto copy_tensor_values_to_vector =
        [node](const Tensor& t, gtl::InlinedVector<int64, 4>* vec) {
          if (t.dtype() == DT_INT32) {
            auto t_flat = t.flat<int32>();
            vec->assign(&t_flat(0), &t_flat(t.NumElements()));
          } else if (t.dtype() == DT_INT64) {
            auto t_flat = t.flat<int64>();
            vec->assign(&t_flat(0), &t_flat(t.NumElements()));
          } else {
            return errors::InvalidArgument("Node ", node->name(),
                                           " has invalid type for Index attr: ",
                                           DataTypeString(t.dtype()));
          }
          return Status::OK();
        };

    gtl::InlinedVector<int64, 4> slice_begin_vec;
    gtl::InlinedVector<int64, 4> slice_size_vec;
    TF_RETURN_IF_ERROR(
        copy_tensor_values_to_vector(slice_begin_t, &slice_begin_vec));
    TF_RETURN_IF_ERROR(
        copy_tensor_values_to_vector(slice_size_t, &slice_size_vec));

    if (slice_begin_vec.size() != slice_size_vec.size()) {
      return errors::InvalidArgument("Node ", node->name(),
                                     " has mismatched lengths for begin (",
                                     slice_begin_vec.size(), ") and size (",
                                     slice_size_vec.size(), ") vectors.");
    }
    int slice_begin_vec_size = slice_begin_vec.size();
    if (!pack_output_shape.unknown_rank() &&
        slice_begin_vec_size != pack_output_shape.dims()) {
      return Status::OK();
    }
    if (pack_axis >= slice_begin_vec_size) {
      return errors::InvalidArgument(
          "Input to node ", node->name(), " had pack_axis ", pack_axis,
          " but rank was ", slice_begin_vec_size, ".");
    }

    *slice_start_value = slice_begin_vec[pack_axis];
    if (slice_size_vec[pack_axis] != 1) {
      // Not slicing a single value out.
      return Status::OK();
    }

    for (int i = 0; i < slice_begin_vec_size; ++i) {
      if (i != pack_axis) {
        if (slice_begin_vec[i] != 0 ||
            !(slice_size_vec[i] == -1 ||
              slice_size_vec[i] == pack_output_shape.dim_size(i))) {
          // Not slicing on the same axis as the Pack op.
          return Status::OK();
        }
      }
    }

    if (*slice_start_value < 0 || *slice_start_value >= pack->input_size()) {
      return errors::InvalidArgument(
          "Node ", node->name(), " requested invalid slice index ",
          *slice_start_value, " on axis ", pack_axis,
          " from tensor of shape: ", pack_output_shape.DebugString());
    }

    *found = true;  // slice_start_value is valid.
    return Status::OK();
  }

  Status GetStridedSliceAxis(const NodeDef* node, const NodeDef* pack,
                             const PartialTensorShape& pack_output_shape,
                             int pack_axis, int64* slice_start_value,
                             bool* found, bool* must_expand_dims) {
    TF_RETURN_IF_ERROR(
        CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask",
                                "new_axis_mask", "shrink_axis_mask"}));

    const int begin_mask = node->attr().at("begin_mask").i();
    const int end_mask = node->attr().at("end_mask").i();
    const int ellipsis_mask = node->attr().at("ellipsis_mask").i();
    const int new_axis_mask = node->attr().at("new_axis_mask").i();
    const int shrink_axis_mask = node->attr().at("shrink_axis_mask").i();

    // Check that the StridedSlice is one of these at pack_axis:
    //   [..., i, ...]
    //   [..., i:i+1, ...]
    //   [..., :1, ...]
    //   [..., -1:, ...]
    //   [..., s_{pack_axis}-1:, ...]
    NodeDef* slice_begin;
    NodeDef* slice_end;
    NodeDef* slice_strides;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &slice_begin));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(2), &slice_end));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(3), &slice_strides));

    for (const auto* n : {slice_begin, slice_end, slice_strides}) {
      if (!IsReallyConstant(*n)) return Status::OK();
    }

    Tensor slice_begin_t;
    Tensor slice_end_t;
    Tensor slice_strides_t;

    TF_RETURN_IF_ERROR(CheckAttrExists(*slice_begin, "value"));
    if (!slice_begin_t.FromProto(slice_begin->attr().at("value").tensor())) {
      return Status::OK();
    }
    TF_RETURN_IF_ERROR(CheckAttrExists(*slice_end, "value"));
    if (!slice_end_t.FromProto(slice_end->attr().at("value").tensor())) {
      return Status::OK();
    }
    TF_RETURN_IF_ERROR(CheckAttrExists(*slice_strides, "value"));
    if (!slice_strides_t.FromProto(
            slice_strides->attr().at("value").tensor())) {
      return Status::OK();
    }
    TensorShape processing_shape;
    TensorShape final_shape;
    bool is_identity;
    bool is_simple_slice;
    bool slice_dim0;
    gtl::InlinedVector<int64, 4> slice_begin_vec;
    gtl::InlinedVector<int64, 4> slice_end_vec;
    gtl::InlinedVector<int64, 4> slice_strides_vec;
    TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
        &slice_begin_t, &slice_end_t, slice_strides_t, pack_output_shape,
        begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask,
        &processing_shape, &final_shape, &is_identity, &is_simple_slice,
        &slice_dim0, &slice_begin_vec, &slice_end_vec, &slice_strides_vec));

    if (!is_simple_slice) return Status::OK();

    int begin_index = -1;
    int64 begin_value = 0;
    for (int i = 0, end = slice_begin_vec.size(); i < end; ++i) {
      const int64 v = slice_begin_vec[i];
      if (v != 0) {
        if (begin_index != -1) {
          // At least two start values that are nonzero.
          return Status::OK();
        }
        begin_index = i;
        begin_value = v;
      }
    }

    int end_index = -1;
    int64 end_value = 0;
    for (int i = 0, end = slice_end_vec.size(); i < end; ++i) {
      const int64 v = slice_end_vec[i];
      if (v != pack_output_shape.dim_size(i)) {
        if (end_index != -1) {
          // At least two end values that differ from the dim size.
          return Status::OK();
        }
        end_index = i;
        end_value = v;
      }
    }

    if (begin_index == -1 && end_index == -1) return Status::OK();
    if (begin_index != -1 && end_index != -1 && begin_index != end_index) {
      // Somehow received different axes for begin/end slicing
      return Status::OK();
    }
    const int slice_axis = (begin_index == -1) ? end_index : begin_index;
    if (slice_axis != pack_axis) {
      // Not slicing on the same axis as the Pack op.
      return Status::OK();
    }
    *slice_start_value = (begin_index == -1) ? 0 : begin_value;
    const int64 slice_end_value =
        (end_index == -1) ? pack_output_shape.dim_size(slice_axis) : end_value;
    if (slice_end_value != *slice_start_value + 1) {
      // Not slicing a single value out.
      return Status::OK();
    }

    if (*slice_start_value < 0 || *slice_start_value >= pack->input_size()) {
      return errors::InvalidArgument(
          "Node ", node->name(), " requested invalid slice index ",
          *slice_start_value, " on axis ", slice_axis,
          " from tensor of shape: ", pack_output_shape.DebugString());
    }

    if (shrink_axis_mask == 0) {
      *must_expand_dims = true;
    } else if (shrink_axis_mask == (1 << slice_axis)) {
      *must_expand_dims = false;
    } else {
      // Shrinking on a different axis from the one that we are slicing on.
      return Status::OK();
    }

    *found = true;  // slice_start_value is valid.
    return Status::OK();
  }

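  // Replaces the slice with either Identity(a_i) or ExpandDims(a_i, pack_axis),
  // where a_i is input `slice_start_value` of the Pack node.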
  Status RewriteGraph(const NodeDef* node, const NodeDef* pack,
                      int64 slice_start_value, int pack_axis,
                      bool must_expand_dims, string* simplified_node_name) {
    const string& input_slice = pack->input(slice_start_value);

    const OpInfo::TensorProperties* input_slice_properties;
    TF_RETURN_IF_ERROR(GetTensorProperties(pack->input(slice_start_value),
                                           &input_slice_properties));
    PartialTensorShape input_slice_shape(input_slice_properties->shape());

    const OpInfo::TensorProperties* output_properties;
    TF_RETURN_IF_ERROR(GetTensorProperties(
        strings::StrCat(node->name(), ":", 0), &output_properties));
    PartialTensorShape output_shape(output_properties->shape());
    NodeDef* output =
        AddEmptyNode(OptimizedNodeName(ParseNodeScopeAndName(node->name())));
    if (!must_expand_dims) {
      output->set_op("Identity");
      output->set_device(node->device());
      SetDataTypeToAttr(output_properties->dtype(), "T", output);
      output->add_input(input_slice);
    } else {
      NodeDef* axis = AddEmptyNode(
          OptimizedNodeName(ParseNodeScopeAndName(node->name()), "Axis"));
      axis->set_op("Const");
      axis->set_device(node->device());
      // We need to add a control edge from input slice to guarantee that axis
      // constant will be executed in the same frame as `input_slice`, otherwise
      // ExpandDims might have mismatched input frames.
      axis->add_input(absl::StrCat("^", ParseTensorName(input_slice).node()));
      auto axis_attr = axis->mutable_attr();
      SetDataTypeToAttr(DT_INT32, "dtype", axis);
      auto* axis_t = (*axis_attr)["value"].mutable_tensor();
      axis_t->set_dtype(DT_INT32);
      axis_t->add_int_val(pack_axis);
      AddToOptimizationQueue(axis);
      output->set_op("ExpandDims");
      output->set_device(node->device());
      SetDataTypeToAttr(output_properties->dtype(), "T", output);
      SetDataTypeToAttr(DT_INT32, "Tdim", output);
      output->add_input(input_slice);
      output->add_input(axis->name());
    }

    // Copy dependencies over.
    ForwardControlDependencies(output, {node, pack});
    AddToOptimizationQueue(output);
    *simplified_node_name = output->name();

    return Status::OK();
  }
};

// Eliminates unnecessary copies during sparse embedding lookup operations.
//
// For non-partitioned variables, the `tf.nn.embedding_lookup_sparse()` function
// generates code of the form:
//
//     embeddings = <a 2D Tensor>
//     sparse_ids = <a tf.int64 SparseTensor>
//     segment_ids = sparse_ids.indices[:, 0]
//     ids, idx = tf.unique(sparse_ids.values)
//     gathered_rows = tf.gather(embeddings, ids)
//     result = tf.sparse.segment_<combiner>(gathered_rows, idx, segment_ids)
//
// In this case, all of the work in `tf.unique()` and `tf.gather()`
// can be avoided by passing the full embeddings to
// `tf.sparse.segment_<combiner>()` and performing the same amount of
// computation (but fewer copies and allocations) as follows:
//
//     embeddings = <a 2D Tensor>
//     sparse_ids = <a tf.int64 SparseTensor>
//     segment_ids = sparse_ids.indices[:, 0]
//     result = tf.sparse.segment_<combiner>(
//          embeddings, sparse_ids.values, segment_ids)
class SimplifyEmbeddingLookupStage : public ArithmeticOptimizerStage {
 public:
  explicit SimplifyEmbeddingLookupStage(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("SimplifyEmbeddingLookupStage", ctx, ctx_ext) {
  }
  ~SimplifyEmbeddingLookupStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsAnySparseSegmentReduction(*node);
  }

  Status TrySimplify(NodeDef* reduction_node,
                     string* simplified_node_name) override {
    if (IsInPreserveSet(*reduction_node)) return Status::OK();

    // Input 0 (data) of the reduction node must be a tf.gather() on the 0th
    // axis.
    NodeDef* gather_node = nullptr;
    TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &gather_node));
    if (!IsGather(*gather_node) || IsInPreserveSet(*gather_node) ||
        gather_node->device() != reduction_node->device())
      return Status::OK();
    if (gather_node->op() == "GatherV2" && !IsAxis0(*gather_node, 2))
      return Status::OK();

    // Input 1 (indices) of the gather node must be a tf.unique() on the 0th
    // axis.
    NodeDef* unique_node = nullptr;
    TF_RETURN_IF_ERROR(GetInputNode(gather_node->input(1), &unique_node));
    if (!IsUnique(*unique_node) || IsInPreserveSet(*unique_node) ||
        unique_node->device() != gather_node->device())
      return Status::OK();
    if (unique_node->op() == "UniqueV2" && !IsAxis0(*unique_node, 1))
      return Status::OK();

    DataType unique_element_type;
    TF_RETURN_IF_ERROR(GetNodeAttr(*unique_node, "T", &unique_element_type));

    // Input 1 (indices) of the reduction node must be output 1 of the unique
    // node.
    const TensorId idx_tensor = ParseTensorName(reduction_node->input(1));
    if (idx_tensor != TensorId(unique_node->name(), 1)) return Status::OK();

    // Input 0 (data) of the reduction node is replaced by input 0 (params) of
    // the gather node, bypassing the gather. Capture the old input before
    // overwriting it so the node map is updated against the pre-rewrite edge.
    const string old_data_input = reduction_node->input(0);
    reduction_node->set_input(0, gather_node->input(0));
    ctx().node_map->UpdateInput(reduction_node->name(), old_data_input,
                                gather_node->input(0));

    // Input 1 (indices) of the reduction node is replaced by input 0 (x) of
    // the unique node, bypassing the unique.
    const string old_indices_input = reduction_node->input(1);
    reduction_node->set_input(1, unique_node->input(0));
    ctx().node_map->UpdateInput(reduction_node->name(), old_indices_input,
                                unique_node->input(0));
    SetDataTypeToAttr(unique_element_type, "Tidx", reduction_node);

    *simplified_node_name = reduction_node->name();
    return Status::OK();
  }

 private:
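  // Returns true if input `axis_input` of `node` is a scalar constant equal
  // to zero.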
  bool IsAxis0(const NodeDef& node, int axis_input) {
    Tensor axis_tensor;
    if (!GetTensorFromConstNode(node.input(axis_input), &axis_tensor))
      return false;
    if (axis_tensor.NumElements() != 1) return false;
    if (axis_tensor.dtype() == DT_INT32) {
      return axis_tensor.flat<int32>()(0) == 0;
    } else if (axis_tensor.dtype() == DT_INT64) {
      return axis_tensor.flat<int64>()(0) == 0;
    } else {
      return false;
    }
  }
};

// Eliminates unnecessary casts before sparse segment reduction operations.
//
// Existing graphs and library code would often insert a cast from DT_INT64 to
// DT_INT32 on the indices and/or segment_ids inputs to "SparseSegment*" ops.
// Support for DT_INT64 indices and/or segment_ids now exists, so we can
// pass the input directly without a cast.
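//
// For example (a sketch; op and tensor names are illustrative),
//     SparseSegmentSum(data, Cast<int32>(indices), Cast<int32>(segment_ids))
// simplifies to
//     SparseSegmentSum(data, indices, segment_ids)
// with the Tidx and Tsegmentids attributes set to the casts' source types.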
class RemoveCastIntoSegmentReductionStage : public ArithmeticOptimizerStage {
 public:
  explicit RemoveCastIntoSegmentReductionStage(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveCastIntoSegmentReductionStage", ctx,
                                 ctx_ext) {}
  ~RemoveCastIntoSegmentReductionStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsAnySparseSegmentReduction(*node);
  }

  Status TrySimplify(NodeDef* reduction_node,
                     string* simplified_node_name) override {
    if (IsInPreserveSet(*reduction_node)) return Status::OK();

    bool optimized = false;

    // Inputs 1 (indices) and 2 (segment_ids) can be either DT_INT32 or
    // DT_INT64.
    std::array<std::pair<int, string>, 2> input_details = {
        std::make_pair(1, "Tidx"), std::make_pair(2, "Tsegmentids")};

    for (const auto& input : input_details) {
      int input_index = input.first;
      const string& type_attr_name = input.second;
      NodeDef* cast_node = nullptr;
      TF_RETURN_IF_ERROR(
          GetInputNode(reduction_node->input(input_index), &cast_node));
      DataType original_index_type;
      if (IsCastFromSupportedType(*cast_node, &original_index_type)) {
        // Bypass the cast. Record the original input first so the node map
        // drops the edge from the cast node at this input index, not from a
        // hard-coded input.
        const string original_input = reduction_node->input(input_index);
        reduction_node->set_input(input_index, cast_node->input(0));
        ctx().node_map->UpdateInput(reduction_node->name(), original_input,
                                    cast_node->input(0));
        SetDataTypeToAttr(original_index_type, type_attr_name, reduction_node);
        optimized = true;
      }
    }

    if (optimized) *simplified_node_name = reduction_node->name();
    return Status::OK();
  }

 private:
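  // Returns true iff `node` is a Cast from DT_INT32 or DT_INT64. On success,
  // sets `*out_input_type` to the cast's source type.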
  bool IsCastFromSupportedType(const NodeDef& node, DataType* out_input_type) {
    if (!IsCast(node)) return false;
    if (!GetNodeAttr(node, "SrcT", out_input_type).ok()) return false;
    return *out_input_type == DT_INT32 || *out_input_type == DT_INT64;
  }
};

}  // namespace

Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
  SetVector<NodeDef*> nodes_to_simplify;
  nodes_to_simplify.Reserve(optimized_graph_->node_size());
  for (int i = 0; i < optimized_graph_->node_size(); ++i) {
    nodes_to_simplify.PushBack(optimized_graph_->mutable_node(i));
  }

  const GraphOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
                                  graph_properties_.get(), node_map_.get(),
                                  &feed_nodes_, opt_level_);
  const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify);
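  // `ctx_ext` exposes the `nodes_to_simplify` queue to the stages, letting a
  // stage enqueue nodes it creates or rewrites for further simplification.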

  // Stop pipeline after first stage returning non-empty simplified tensor
  // name.
  const auto stop = [](const string& result) { return !result.empty(); };
  GraphOptimizerStagePipeline<string> pipeline(stop);

  if (options_.combine_add_to_addn && can_use_shapes)
    pipeline.AddStage<AddOpsRewriteStage>(ctx, ctx_ext);
  if (options_.fold_conjugate_into_transpose)
    pipeline.AddStage<FoldConjugateIntoTranspose>(ctx, ctx_ext);
  if (options_.fold_multiply_into_conv)
    pipeline.AddStage<FoldMultiplyIntoConv>(ctx, ctx_ext);
  if (options_.fold_transpose_into_matmul)
    pipeline.AddStage<FoldTransposeIntoMatMul>(ctx, ctx_ext);
  if (options_.hoist_common_factor_out_of_aggregation && can_use_shapes)
    pipeline.AddStage<HoistCommonFactorOutOfAggregation>(ctx, ctx_ext);
  if (options_.minimize_broadcasts && can_use_shapes)
    pipeline.AddStage<MinimizeBroadcasts>(ctx, ctx_ext);
  if (options_.remove_identity_transpose && can_use_shapes)
    pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext);
  if (options_.remove_involution)
    pipeline.AddStage<RemoveInvolution>(ctx, ctx_ext);
  if (options_.remove_redundant_bitcast)
    pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext);
  if (options_.remove_redundant_cast)
    pipeline.AddStage<RemoveRedundantCastStage>(ctx, ctx_ext);
  if (options_.remove_redundant_reshape)
    pipeline.AddStage<RemoveRedundantReshapeOrBroadcastTo>(ctx, ctx_ext);
  if (options_.remove_negation)
    pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
  if (options_.replace_mul_with_square)
    pipeline.AddStage<ReplaceMulWithSquare>(ctx, ctx_ext);
  if (options_.remove_logical_not)
    pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext);
  if (options_.reorder_cast_like_and_value_preserving)
    pipeline.AddStage<ReorderCastLikeAndValuePreserving>(ctx, ctx_ext);
  if (options_.simplify_aggregation)
    pipeline.AddStage<SimplifyAggregation>(ctx, ctx_ext);
  if (options_.hoist_cwise_unary_chains)
    pipeline.AddStage<HoistCWiseUnaryChainsStage>(ctx, ctx_ext);
  if (options_.convert_sqrt_div_to_rsqrt_mul)
    pipeline.AddStage<SqrtDivToRsqrtMulStage>(ctx, ctx_ext);
  if (options_.remove_idempotent)
    pipeline.AddStage<RemoveIdempotentStage>(ctx, ctx_ext);
  if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext);
  if (options_.convert_log1p)
    pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
  if (options_.convert_log_softmax)
    pipeline.AddStage<LogSoftmaxStage>(ctx, ctx_ext);
  if (options_.optimize_max_or_min_of_monotonic)
    pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext);
  if (options_.convert_expm1)
    pipeline.AddStage<ConvertExpm1Stage>(ctx, ctx_ext);
  if (options_.unary_ops_composition)
    pipeline.AddStage<UnaryOpsComposition>(ctx, ctx_ext);
  if (options_.remove_stack_slice_same_axis)
    pipeline.AddStage<RemoveStackSliceSameAxis>(ctx, ctx_ext);
  if (options_.simplify_embedding_lookup)
    pipeline.AddStage<SimplifyEmbeddingLookupStage>(ctx, ctx_ext);
  if (options_.remove_cast_into_segment_reduction)
    pipeline.AddStage<RemoveCastIntoSegmentReductionStage>(ctx, ctx_ext);
  if (options_.fuse_squared_diff)
    pipeline.AddStage<FuseSquaredDiffStage>(ctx, ctx_ext);

  VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
          << absl::StrJoin(pipeline.StageNames(), ", ");

  while (!nodes_to_simplify.Empty()) {
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
    NodeDef* node = nodes_to_simplify.PopBack();

    string simplified_tensor = "";
    bool optimized = pipeline.PassThroughAllStages(node, &simplified_tensor);

    // If the node was not optimized by any of the stages, go to the next one.
    if (!optimized) continue;

    // Re-wire consumers of the old node to the new one.
    if (NodeName(simplified_tensor) != node->name()) {
      // Always consider simplified_tensor for further optimizations.
      NodeDef* simplified_node = node_map_->GetNode(simplified_tensor);
      if (simplified_node != nullptr) {
        nodes_to_simplify.PushBack(simplified_node);
      }
      // When `node` is simplified to another node rather than in-place,
      // redirect every remaining use of `node` to `simplified_tensor` and
      // re-push the consumers into `nodes_to_simplify` for further
      // optimizations.
      const std::vector<NodeDef*> consumers =
          node_map_->GetOutputsOrderedByNodeName(node->name());
      for (NodeDef* consumer : consumers) {
        // Update `consumer`'s uses of `node` to point at `simplified_tensor`.
        for (int i = 0; i < consumer->input_size(); ++i) {
          int operand_pos;
          string operand_node_name =
              ParseNodeName(consumer->input(i), &operand_pos);
          if (operand_node_name == node->name()) {
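            // A negative operand position marks a control input ("^name"),
            // which must keep its control-dependency form after rewiring.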
            *consumer->mutable_input(i) =
                (operand_pos < 0
                     ? AsControlDependency(NodeName(simplified_tensor))
                     : simplified_tensor);
          }
        }
        node_map_->UpdateInput(consumer->name(), node->name(),
                               simplified_tensor);
        nodes_to_simplify.PushBack(consumer);
      }
    }
  }
  return Status::OK();
}

Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
                                     const GrapplerItem& item,
                                     GraphDef* optimized_graph) {
  // Set up helper data structures.
  nodes_to_preserve_ = item.NodesToPreserve();
  fetch_nodes_known_ = !item.fetch.empty();
  GrapplerItem optimized_item(item);
  optimized_graph_ = &optimized_item.graph;

  node_map_.reset(new NodeMap(optimized_graph_));
  for (const auto& feed : item.feed) {
    feed_nodes_.insert(NodeName(feed.first));
  }

  // Disable restricted graph rewrites.
  options_.unary_ops_composition &=
      item.optimization_options().allow_non_differentiable_rewrites;

  // Perform topological sort on the graph in order to help DedupComputations
  // and AddOpsRewrite to optimize larger subgraphs starting from the roots
  // with more inputs.
  TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph_));
  GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();

  graph_properties_.reset(new GraphProperties(optimized_item));
  const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
  const Status status =
      graph_properties_->InferStatically(assume_valid_feeds,
                                         /*aggressive_shape_inference=*/false,
                                         /*include_tensor_values=*/false);
  const bool can_use_shapes = status.ok();
  if (!can_use_shapes) {
    VLOG(1) << "Shape inference failed: " << status.error_message();
  }

  // Perform the optimizations.
  TF_RETURN_IF_ERROR(SimplifyArithmeticOps(can_use_shapes));

  optimized_graph->Swap(optimized_graph_);
  return Status::OK();
}

void ArithmeticOptimizer::Feedback(Cluster* /*cluster*/,
                                   const GrapplerItem& /*item*/,
                                   const GraphDef& /*optimized_graph*/,
                                   double /*result*/) {
  // Nothing to do for ArithmeticOptimizer.
}

}  // namespace grappler
}  // namespace tensorflow