1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
17
18 #include <algorithm>
19 #include <deque>
20 #include <limits>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <vector>
24
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "tensorflow/core/framework/attr_value.pb.h"
28 #include "tensorflow/core/framework/attr_value_util.h"
29 #include "tensorflow/core/framework/node_def.pb.h"
30 #include "tensorflow/core/framework/node_def_util.h"
31 #include "tensorflow/core/framework/op.h"
32 #include "tensorflow/core/framework/tensor.pb.h"
33 #include "tensorflow/core/framework/tensor_shape.pb.h"
34 #include "tensorflow/core/framework/types.h"
35 #include "tensorflow/core/framework/types.pb.h"
36 #include "tensorflow/core/grappler/costs/graph_properties.h"
37 #include "tensorflow/core/grappler/graph_topology_view.h"
38 #include "tensorflow/core/grappler/grappler_item.h"
39 #include "tensorflow/core/grappler/op_types.h"
40 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
41 #include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h"
42 #include "tensorflow/core/grappler/utils.h"
43 #include "tensorflow/core/grappler/utils/canonicalizer.h"
44 #include "tensorflow/core/grappler/utils/symbolic_shapes.h"
45 #include "tensorflow/core/grappler/utils/topological_sort.h"
46 #include "tensorflow/core/grappler/utils/traversal.h"
47 #include "tensorflow/core/lib/core/errors.h"
48 #include "tensorflow/core/lib/core/stringpiece.h"
49 #include "tensorflow/core/lib/hash/hash.h"
50 #include "tensorflow/core/lib/strings/str_util.h"
51 #include "tensorflow/core/lib/strings/strcat.h"
52 #include "tensorflow/core/platform/errors.h"
53 #include "tensorflow/core/platform/tensor_coding.h"
54 #include "tensorflow/core/util/device_name_utils.h"
55 #include "tensorflow/core/util/saved_tensor_slice_util.h"
56 #include "tensorflow/core/util/strided_slice_op.h"
57
58 using tensorflow::strings::StrCat;
59
60 namespace tensorflow {
61 namespace grappler {
62 namespace {
63
64 // Mark nodes created or optimized by a stage with a tag.
65 constexpr char kAddOpsRewriteTag[] =
66 "_grappler_ArithmeticOptimizer_AddOpsRewriteStage";
67 constexpr char kMinimizeBroadcastsTag[] =
68 "_grappler_ArithmeticOptimizer_MinimizeBroadcasts";
69
// Extract values from a Const op to `values`. Returns true on success.
71 template <typename T>
bool ValuesFromConstNode(const NodeDef& node, std::vector<T>* values) {
73 if (node.op() != "Const") {
74 return false;
75 }
76
77 if (node.attr().count("dtype") == 0 || node.attr().count("value") == 0 ||
78 node.attr().at("dtype").type() != DataTypeToEnum<T>::value) {
79 return false;
80 }
81
82 // TensorProto represents the content of the tensor in either <type>_val or
83 // tensor_content.
84 const TensorProto& tensor = node.attr().at("value").tensor();
85 typename checkpoint::SaveTypeTraits<T>::RepeatedField* tensor_values =
86 checkpoint::MutableTensorProtoData<T>(const_cast<TensorProto*>(&tensor));
87
88 if (!tensor_values->empty() && tensor.has_tensor_shape()) {
89 // When tensor_shape is set, theoretically the representation of the data
90 // could be compressed. So, before copying values to the returned vector,
91 // make sure no compression happens.
92 const TensorShapeProto& shape = tensor.tensor_shape();
93 if (shape.dim_size() == 1 && shape.dim(0).size() == tensor_values->size()) {
94 values->insert(values->end(), tensor_values->begin(),
95 tensor_values->end());
96 return true;
97 }
98 }
99
100 const auto tensor_content_size = tensor.tensor_content().size();
101 if (tensor_content_size > 0) {
102 CHECK_EQ(0, tensor_content_size % sizeof(T))
103 << "tensor_content_size (" << tensor_content_size
104 << ") is not a multiple of " << sizeof(T);
105 values->resize(tensor_content_size / sizeof(T));
106 port::CopyToArray(tensor.tensor_content(),
107 reinterpret_cast<char*>(values->data()));
108 return true;
109 }
110
111 return false;
112 }
113
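// Adds a control dependency on `new_input` to `node`, unless an equivalent
// (regular or control) input is already present. Returns true if a new control
// input was added.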
bool MaybeAddControlInput(const string& new_input, NodeDef* node,
                          GraphDef* graph, NodeMap* node_map) {
116 bool already_exists = false;
117 for (const string& input : node->input()) {
118 if (input == new_input || AsControlDependency(input) == new_input) {
119 already_exists = true;
120 break;
121 }
122 }
123 if (!already_exists) {
124 const string ctrl_dep =
125 ConstantFolding::AddControlDependency(new_input, graph, node_map);
126 node->add_input(ctrl_dep);
127 node_map->AddOutput(NodeName(new_input), node->name());
128 }
129 return !already_exists;
130 }
131
void SetDataTypeToAttr(DataType dtype, const string& attr_name, NodeDef* node) {
133 (*node->mutable_attr())[attr_name].set_type(dtype);
134 }
135
NodeDef* GetTailOfValuePreservingChain(
    const NodeDef& node, const NodeMap& node_map,
    const std::unordered_set<string>& nodes_to_preserve) {
139 auto is_value_preserving_non_branching = [&](const NodeDef& node) {
140 return nodes_to_preserve.find(node.name()) == nodes_to_preserve.end() &&
141 IsValuePreserving(node) && NumNonControlOutputs(node, node_map) == 1;
142 };
143 return GetTailOfChain(node, node_map, /*follow_control_input=*/false,
144 is_value_preserving_non_branching);
145 }
146
NodeDef* GetTailOfIdempotentChain(
    const NodeDef& node, const NodeMap& node_map,
    const std::unordered_set<string>& nodes_to_preserve) {
150 auto is_idempotent_non_branching = [&](const NodeDef& node) {
151 return nodes_to_preserve.find(node.name()) == nodes_to_preserve.end() &&
152 IsIdempotent(node) && NumNonControlOutputs(node, node_map) == 1;
153 };
154 return GetTailOfChain(node, node_map, /*follow_control_input=*/false,
155 is_idempotent_non_branching);
156 }
157
// GetElementUnexhaustive tries to get the value of an element in a tensor and
// turn it into complex128 type. It only checks for a limited number of data
// types, so it's not exhaustive.
bool GetElementUnexhaustive(const Tensor& t, int i, const std::set<int>& dtypes,
                            complex128* element) {
163 if (dtypes.find(t.dtype()) == dtypes.end()) return false;
164 switch (t.dtype()) {
165 case DT_BFLOAT16:
166 *element = complex128(t.flat<bfloat16>()(i));
167 return true;
168 case DT_HALF:
169 *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
170 return true;
171 case DT_INT32:
172 *element = complex128(t.flat<int32>()(i));
173 return true;
174 case DT_INT64:
175 *element = complex128(t.flat<int64>()(i));
176 return true;
177 case DT_FLOAT:
178 *element = complex128(t.flat<float>()(i));
179 return true;
180 case DT_DOUBLE:
181 *element = complex128(t.flat<double>()(i));
182 return true;
183 case DT_COMPLEX64:
184 *element = complex128(t.flat<complex64>()(i));
185 return true;
186 case DT_COMPLEX128:
187 *element = t.flat<complex128>()(i);
188 return true;
189 default:
190 return false;
191 }
192 }
193
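// Returns true if `node` is assigned to a CPU device.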
bool NodeIsOnCpu(const NodeDef& node) {
195 string task;
196 string device;
197 return DeviceNameUtils::SplitDeviceName(node.device(), &task, &device) &&
198 absl::StrContains(device, DEVICE_CPU);
199 }
200
201 // Graph optimizer context extension specific to ArithmeticOptimizer.
202 struct ArithmeticOptimizerContext {
  explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify)
      : nodes_to_simplify(nodes_to_simplify) {}
205 SetVector<NodeDef*>* nodes_to_simplify;
206 };
207
208 // Base class for single arithmetic optimization: e.g. Bitcast optimization,
209 // AddOps optimization, etc...
210 class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
211 public:
  explicit ArithmeticOptimizerStage(const string& name,
                                    const GraphOptimizerContext& ctx,
                                    const ArithmeticOptimizerContext ctx_ext)
      : GraphOptimizerStage("ArithmeticOptimizer", name, ctx),
        ctx_ext_(ctx_ext) {}
217 ~ArithmeticOptimizerStage() override = default;
218
219 protected:
  // A simplification graph rewrite can create additional nodes that are inputs
  // to the final simplified node; they can also be added to the arithmetic
  // optimizer queue for further optimization.
  void AddToOptimizationQueue(NodeDef* node) {
224 ctx_ext_.nodes_to_simplify->PushBack(node);
225 }
226
  // TODO(ezhulenev): remove this method from ArithmeticOptimizer when all
  // optimizations have been migrated to stages
  void ForwardControlDependencies(
      NodeDef* target_node, const std::vector<const NodeDef*>& src_nodes) {
231 for (const auto& src : src_nodes) {
232 for (int i = src->input_size() - 1; i >= 0; --i) {
233 if (IsControlInput(src->input(i))) {
234 *target_node->add_input() = src->input(i);
235 ctx().node_map->AddOutput(NodeName(src->input(i)),
236 target_node->name());
237 } else {
238 break;
239 }
240 }
241 }
242 DedupControlInputs(target_node);
243 }
244
  bool IsReallyConstant(const NodeDef& node) const {
246 if (!IsConstant(node)) {
247 return false;
248 }
249 // If the node is fed it's not constant anymore.
250 return ctx().feed_nodes->find(node.name()) == ctx().feed_nodes->end();
251 }
252
  bool IsInPreserveSet(const NodeDef& node) const {
254 return ctx().nodes_to_preserve->find(node.name()) !=
255 ctx().nodes_to_preserve->end();
256 }
257
258 // TODO(ezhulenev): move to GraphOptimizerStage?
  bool IsDrivenByControlDependency(const NodeDef& node) const {
260 return std::any_of(
261 node.input().begin(), node.input().end(),
262 [](const string& input) { return IsControlInput(input); });
263 }
264
265 // TODO(ezhulenev): move to GraphOptimizerStage?
  bool DrivesControlDependency(const NodeDef& node) const {
267 for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) {
268 for (int i = 0; i < output->input_size(); ++i) {
269 const TensorId tensor = ParseTensorName(output->input(i));
270 if (tensor.node() == node.name() && tensor.index() < 0) {
271 return true;
272 }
273 }
274 }
275 return false;
276 }
277
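  // If `node_name_or_input` refers to a Const node that is not fed, copies its
  // value into `tensor` and returns true.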
  bool GetTensorFromConstNode(const string& node_name_or_input,
                              Tensor* tensor) {
280 const NodeDef* node = ctx().node_map->GetNode(node_name_or_input);
281 return node != nullptr && IsReallyConstant(*node) &&
282 CheckAttrExists(*node, "value").ok() &&
283 tensor->FromProto(node->attr().at("value").tensor());
284 }
285
286 private:
287 // Extended context required for ArithmeticOptimizer.
288 const ArithmeticOptimizerContext ctx_ext_;
289 };
290
291 // Subtype of ArithmeticOptimizerStage that does optimization by rewriting a
292 // group of nodes from the optimized graph.
293 //
294 // * AddOpsRewrite:
295 // Rewrite a group of Add/AddN with compact Add/AddN tree
296 //
297 // * MinimizeBroadcasts:
298 // Rewrite a group of binary associative ops, reordering
299 // inputs, to minimize the cost of broadcast
300 class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage {
301 public:
  explicit ArithmeticNodesGroupOptimizerStage(
      const string& name, const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext ctx_ext)
      : ArithmeticOptimizerStage(name, ctx, ctx_ext) {}
306 ~ArithmeticNodesGroupOptimizerStage() override = default;
307
308 // Input name with a statically inferred shape from GraphProperties
309 struct InputAndShape {
    InputAndShape(const string& input, const TensorShapeProto& shape)
        : input(input), shape(shape) {}
312 string input;
313 TensorShapeProto shape;
314 };
315
  // Subgraph (subtree) of nodes that we want to optimize in "one shot" (e.g.
  // all the Add nodes that we plan to rewrite with a single AddN). The
  // subgraph is obtained by graph traversal, starting from a root node.
319 struct OptimizedNodesGroup {
320 NodeDef* root_node;
321 TensorShapeProto root_shape;
322 // Optimized nodes that will be updated or removed by rewrite
323 std::vector<NodeDef*> optimized_nodes;
324 // Inputs to optimized nodes
325 std::vector<InputAndShape> inputs;
326 };
327
  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
329 TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
330
331 OptimizedNodesGroup group;
332 TF_RETURN_IF_ERROR(CreateOptimizedNodesGroup(node, &group));
333
334 if (!group.optimized_nodes.empty()) {
335 *simplified_node_name = RewriteOptimizedNodesGroup(group);
336 }
337
338 return Status::OK();
339 }
340
341 protected:
  // Modify the optimized graph after the nodes group was successfully
  // identified.
343 virtual string RewriteOptimizedNodesGroup(
344 const OptimizedNodesGroup& group) = 0;
345
346 // Check if input can become a part of current optimized nodes group.
347 virtual bool IsAbsorbableByOptimizedNodesGroup(
348 const OptimizedNodesGroup& group, const NodeDef& node) const = 0;
349
  Status AbsorbInputByOptimizedNodesGroup(const string& input,
                                          OptimizedNodesGroup* group) const {
352 std::deque<const string*> input_tensors;
353 input_tensors.push_front(&input);
354
355 while (!input_tensors.empty()) {
356 const string* input_tensor = input_tensors.front();
357 input_tensors.pop_front();
358
359 // Get a node for the input tensor.
360 NodeDef* input_node;
361 TF_RETURN_IF_ERROR(GetInputNode(*input_tensor, &input_node));
362
363 if (IsAbsorbableByOptimizedNodesGroup(*group, *input_node)) {
364 group->optimized_nodes.push_back(input_node);
365 for (int i = input_node->input_size() - 1; i >= 0; --i) {
366 const string& absorbed_node_input = input_node->input(i);
367 // TODO(ezhulenev): support control inputs
368 if (IsControlInput(absorbed_node_input)) continue;
369 input_tensors.push_front(&absorbed_node_input);
370 }
371 } else {
372 // If input node can't be absorbed, add it to OptimizedNodesGroup input.
373 const OpInfo::TensorProperties* properties;
374 TF_RETURN_IF_ERROR(GetTensorProperties(*input_tensor, &properties));
375 group->inputs.emplace_back(*input_tensor, properties->shape());
376 }
377 }
378
379 return Status::OK();
380 }
381
  Status CreateOptimizedNodesGroup(NodeDef* root_node,
                                   OptimizedNodesGroup* group) const {
384 const OpInfo::TensorProperties* root_node_output_properties;
385 TF_RETURN_IF_ERROR(
386 GetTensorProperties(root_node->name(), &root_node_output_properties));
387
388 group->root_node = root_node;
389 group->root_shape = root_node_output_properties->shape();
390
391 group->optimized_nodes.reserve(root_node->input_size());
392 for (int i = 0; i < root_node->input_size(); ++i) {
393 const string& input_i = root_node->input(i);
394 // TODO(ezhulenev): add support for control inputs
395 if (IsControlInput(input_i)) continue;
396 TF_RETURN_IF_ERROR(AbsorbInputByOptimizedNodesGroup(input_i, group));
397 }
398
399 return Status::OK();
400 }
401
402 // Check if all inputs can be broadcasted to the same shape
403 // TODO(ezhulenev): move to GraphOptimizerStage?
  bool HasAllInputsBroadcastableToShape(
      const NodeDef& node, const OpInfo::TensorProperties& properties) const {
406 auto is_broadcastable = [this, &properties](const string& input) {
407 const OpInfo::TensorProperties* input_props;
408 Status has_input_properties = GetTensorProperties(input, &input_props);
409 return has_input_properties.ok() &&
410 ShapesBroadcastable(properties, *input_props);
411 };
412 return std::all_of(node.input().begin(), node.input().end(),
413 is_broadcastable);
414 }
415
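  // Returns a string signature encoding the rank and dimensions of `shape`,
  // used to group together inputs with identical shapes.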
  string ShapeSignature(const TensorShapeProto& shape) const {
417 string signature = strings::StrCat("rank:", shape.dim_size(), ":dim");
418 for (int i = 0; i < shape.dim_size(); ++i)
419 strings::StrAppend(&signature, ":", shape.dim(i).size());
420 return signature;
421 }
422
  void MarkWithTag(const StringPiece tag, NodeDef* node) {
424 AddNodeAttr(tag, true, node);
425 }
426
  void MarkAllMembersWithTag(const OptimizedNodesGroup& group,
                             const StringPiece tag) const {
429 AddNodeAttr(tag, true, group.root_node);
430 for (NodeDef* optimized_node : group.optimized_nodes) {
431 AddNodeAttr(tag, true, optimized_node);
432 }
433 }
434
  bool IsOnTheSameDevice(const OptimizedNodesGroup& group,
                         const NodeDef& node) const {
437 return group.root_node->device() == node.device();
438 }
439
  bool IsInPreserveSet(const NodeDef& node) const {
441 return ctx().nodes_to_preserve->find(node.name()) !=
442 ctx().nodes_to_preserve->end();
443 }
444
  bool IsMarkedWithTag(const NodeDef& node, const StringPiece tag) const {
446 return HasNodeAttr(node, tag);
447 }
448
  bool IsMarkedWithAnyTag(const NodeDef& node, const StringPiece tag1,
                          const StringPiece tag2) const {
451 return IsMarkedWithTag(node, tag1) || IsMarkedWithTag(node, tag2);
452 }
453 };
454
455 // Rewrite a tree of Add/AddN with a single AddN operation, consuming all the
456 // original inputs of absorbed nodes.
457 //
458 // 1) All nodes must have the same device placement.
459 //
// 2) If all nodes in an Add/AddN subgraph have symbolically equal shape, the
//    tree is optimized to a single AddN node.
462 //
463 // AddN_1
464 // / | \
465 // Add_1 z Add_2 -> AddN(x, y, z, w, q, e)
466 // / \ / \
467 // x y w Add_3
468 // / \
469 // q e
470 //
// 3) If some nodes have a different shape (it needs to be broadcastable to the
//    shape of the "root"), the tree is optimized to AddNs for symbolically
//    equal shapes, and a tree of Add ops that minimizes broadcasts.
474 //
475 // AddN_1 Add
476 // / | \ / \
477 // Add_1 z Add_2 -> Add w
478 // / \ / \ / \
479 // x y w Add_3 AddN(x, y, q, e) z
480 // / \
481 // q e
482 class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage {
483 public:
  explicit AddOpsRewriteStage(const GraphOptimizerContext& ctx,
                              const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticNodesGroupOptimizerStage("AddOpsRewrite", ctx, ctx_ext) {}
487 ~AddOpsRewriteStage() override = default;
488
489 // Check if a node can become a root of AddOpsGroup
  bool IsSupported(const NodeDef* node) const override {
491 if (!CanOptimize(*node)) return false;
492
493 // shape must be symbolically defined and all inputs compatible with it
494 const OpInfo::TensorProperties* properties;
495 Status has_properties = GetTensorProperties(node->name(), &properties);
496 return has_properties.ok() && ShapeIsSymbolicallyDefined(*properties) &&
497 HasAllInputsBroadcastableToShape(*node, *properties);
498 }
499
500 protected:
501 // Check if a node can be absorbed by current OptimizedNodesGroup
  bool IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup& group,
                                         const NodeDef& node) const override {
504 if (!CanOptimize(node)) return false;
505
506 if (!IsOnTheSameDevice(group, node)) {
507 return false;
508 }
    // The node must have a single output data consumer (presumably, if we
    // reach this node from a previously absorbed or a root node, it means that
    // this node is not used as an input to any other op outside of the group).
512 if (NumNonControlDataOutputs(node, *ctx().node_map) != 1) {
513 return false;
514 }
515 // All input shapes must be broadcastable to the node shape
516 const OpInfo::TensorProperties* properties;
517 Status has_properties = GetTensorProperties(node.name(), &properties);
518 return has_properties.ok() &&
519 HasAllInputsBroadcastableToShape(node, *properties);
520 }
521
522 // Node requirements both for a root node and an absorbed node
  bool CanOptimize(const NodeDef& node) const {
524 // TODO(ezhulenev): check if AccumulateNV2 can be supported too
525 if (!IsAdd(node) && !IsAddN(node)) {
526 return false;
527 }
528 if (IsInPreserveSet(node) || IsMarkedWithTag(node, kAddOpsRewriteTag)) {
529 return false;
530 }
531 // TODO(ezhulenev): relax this condition for root node
532 return !(IsDrivenByControlDependency(node) ||
533 DrivesControlDependency(node));
534 }
535
536 // Rewrite a group of add ops into a single AddN if all input shapes are
537 // symbolically equal. If not, create AddN for equal shapes first, and then
538 // build an Add tree, minimizing the cost of broadcasts.
  string RewriteOptimizedNodesGroup(const OptimizedNodesGroup& group) override {
540 VLOG(2) << "Collapse Add/AddN: root=" << group.root_node->name()
541 << " op=" << group.root_node->op()
542 << " num_optimized_nodes=" << group.optimized_nodes.size()
543 << " num_inputs=" << group.inputs.size();
544
545 // Do not optimize any of the nodes that are part of this group.
546 MarkAllMembersWithTag(group, kAddOpsRewriteTag);
547
548 // All new nodes will be placed under the scope of a root node.
549 auto root_scope_and_name = ParseNodeScopeAndName(group.root_node->name());
550
551 // Find what shapes are present in the inputs of absorbed nodes.
552 std::unordered_map<string, std::vector<InputAndShape>> shape_sig_to_inputs;
553 for (const auto& input : group.inputs) {
554 shape_sig_to_inputs[ShapeSignature(input.shape)].push_back(input);
555 }
556
557 using SigKV = decltype(shape_sig_to_inputs)::value_type;
558 VLOG(3) << "Add/AddN group has " << shape_sig_to_inputs.size()
559 << " unique shapes: "
560 << absl::StrJoin(shape_sig_to_inputs, ", ",
561 [](string* out, SigKV p) {
562 strings::StrAppend(out, p.first);
563 });
564
565 // Collect all the shapes from representative elements.
566 std::vector<TensorShapeProto> shapes;
567 shapes.reserve(shape_sig_to_inputs.size());
568 for (const auto& el : shape_sig_to_inputs)
569 shapes.push_back(el.second[0].shape);
570
571 // If all inputs have the same shape, rewrite whole group with a single AddN
572 if (shapes.size() == 1) {
573 string node_name = UniqueOptimizedNodeName(root_scope_and_name);
574 AddInputsOfSymbolicallyEqualShape(*group.root_node, node_name,
575 group.inputs);
576 return node_name;
577 }
578
579 // For inputs of different shapes:
580 // 1. Rewrite inputs of the same shape using AddN (leaf nodes)
581 // 2. Build a tree of Add nodes, minimizing cost of broadcast
582 std::sort(shapes.begin(), shapes.end(),
583 [](const TensorShapeProto& left, const TensorShapeProto& right) {
584 return CompareSymbolicallyShapedTensorSizes(left, right);
585 });
586
587 // optimized name for leaf AddN nodes
588 auto leaf_node_name = [&root_scope_and_name, this](int i) {
589 return UniqueOptimizedNodeName(root_scope_and_name,
590 strings::StrCat("Leaf_", i));
591 };
592 // optimized name for internal nodes of a tree built up from AddN leaves
593 auto internal_node_name = [&root_scope_and_name, this](int i) {
594 return UniqueOptimizedNodeName(root_scope_and_name,
595 strings::StrCat("Internal_", i));
596 };
597
598 // Add/AddN nodes that must be added to the tree
599 std::deque<InputAndShape> add_ops;
600
601 // Prepare leaf AddN nodes for inputs of equal shape
602 for (int i = 0, end = shapes.size(); i < end; ++i) {
603 const auto node_name = leaf_node_name(i);
604 const auto& inputs = shape_sig_to_inputs[ShapeSignature(shapes[i])];
605 add_ops.push_back(AddInputsOfSymbolicallyEqualShape(*group.root_node,
606 node_name, inputs));
607 }
608
609 // Build up a tree of Add ops
610 int internal_nodes = 0;
611 do {
612 const InputAndShape lhs = add_ops.front();
613 add_ops.pop_front();
614 const InputAndShape rhs = add_ops.front();
615 add_ops.pop_front();
616 string name = add_ops.empty()
617 ? UniqueOptimizedNodeName(root_scope_and_name)
618 : internal_node_name(internal_nodes++);
619 InputAndShape add = AddAggregatedInputs(*group.root_node, name, lhs, rhs);
620 add_ops.push_front(add);
621 } while (add_ops.size() > 1);
622
623 InputAndShape optimized_root_node = add_ops.front();
624 return optimized_root_node.input;
625 }
626
627 // Add 'AddN' node to aggregate inputs of symbolically equal shape
  InputAndShape AddInputsOfSymbolicallyEqualShape(
      const NodeDef& root_node, const string& node_name,
      const std::vector<InputAndShape>& inputs) {
631 CHECK(!inputs.empty()) << "Inputs must be non-empty";
632
633 // Do not create redundant AddN nodes
634 if (inputs.size() == 1 || root_node.attr().count("T") == 0) {
635 return inputs[0];
636 }
637
638 // get shape from representative element
639 auto shape = inputs[0].shape;
640
641 // copy attributes from a root node
642 DataType dtype = root_node.attr().at("T").type();
643
644 // add new AddN node
645 NodeDef* node = AddEmptyNode(node_name);
646 node->set_op("AddN");
647 node->set_device(root_node.device());
648 (*node->mutable_attr())["T"].set_type(dtype);
649 (*node->mutable_attr())["N"].set_i(inputs.size());
650
651 for (const auto& inputAndShape : inputs) {
652 ctx().node_map->AddOutput(inputAndShape.input, node_name);
653 node->add_input(inputAndShape.input);
654 }
655
656 MarkWithTag(kAddOpsRewriteTag, node);
657 return InputAndShape(node_name, shape);
658 }
659
660 // Add a single 'Add' node to sum two inputs
  InputAndShape AddAggregatedInputs(const NodeDef& root_node,
                                    const string& node_name,
                                    const InputAndShape& left,
                                    const InputAndShape& right) {
665 // copy attributes from a root node
666 DataType dtype = root_node.attr().at("T").type();
667
668 // add new Add node
669 NodeDef* node = AddEmptyNode(node_name);
670 node->set_op((dtype == DT_STRING || dtype == DT_STRING_REF) ? "Add"
671 : "AddV2");
672 node->set_device(root_node.device());
673 (*node->mutable_attr())["T"].set_type(dtype);
674 node->add_input(left.input);
675 node->add_input(right.input);
676
677 ctx().node_map->AddOutput(left.input, node_name);
678 ctx().node_map->AddOutput(right.input, node_name);
679
680 MarkWithTag(kAddOpsRewriteTag, node);
681 return InputAndShape(
682 node_name, TensorShapeProto()); // shape is not important at this point
683 }
684 };
685
686 // Use the distributive property of multiplication and division over addition,
687 // along with commutativity of the former, to hoist common factors/denominators
688 // out of aggregate nodes where ALL the inputs are Mul/Div nodes.
689 // This pattern occurs frequently in regularization terms for the gradients
690 // during training.
691 //
692 // For example, we can rewrite an expression of the form:
693 // AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn))
694 // to the following:
695 // Mul(x, AddN(y1, y2, y3, ... yn))
696 // For division, we can rewrite
697 // AddN(Div(y1, x), Div(y2, x), Div(y3, x), ... Div(yn, x))
698 // to:
699 // Div(AddN(y1, y2, y3, ... yn), x)
700 class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
701 public:
  explicit HoistCommonFactorOutOfAggregation(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("HoistCommonFactor", ctx, ctx_ext) {}
706 ~HoistCommonFactorOutOfAggregation() override = default;
707
  bool IsSupported(const NodeDef* node) const override {
709 return IsAggregate(*node) && NumNonControlInputs(*node) > 1 &&
710 !IsRewritten(node);
711 }
712
  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
714 TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
715
716 bool common_factor_is_denominator = false;
717 std::set<string> common_factors;
718 std::vector<string> ctrl_deps;
719 TF_RETURN_IF_ERROR(GetCommonFactors(
720 node, &common_factors, &common_factor_is_denominator, &ctrl_deps));
721
722 if (common_factors.size() == 1) {
723 const string& common_factor = *common_factors.begin();
724
725 // Gather up the non-shared factors
726 bool shapes_match = true;
727 std::vector<string> unique_factors;
728 TF_RETURN_IF_ERROR(GetUniqueFactors(node, common_factor,
729 common_factor_is_denominator,
730 &shapes_match, &unique_factors));
731
732 if (shapes_match) {
733 NodeDef* input_0;
734 TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input_0));
735
736 // Use a copy of the first node for the outer multiplication/division.
737 NodeDef* new_outer_node = AddCopyNode(
738 OuterNodeName(node, common_factor_is_denominator), input_0);
739 // And a copy of aggregation node as one of the inner operands
740 NodeDef* new_add_node = AddCopyNode(InnerAddNodeName(node), node);
741
742 new_outer_node->set_device(node->device());
743 if (common_factor_is_denominator) {
744 new_outer_node->set_input(0, new_add_node->name());
745 new_outer_node->set_input(1, common_factor);
746 } else {
747 new_outer_node->set_input(0, common_factor);
748 new_outer_node->set_input(1, new_add_node->name());
749 }
750
751 ctx().node_map->AddOutput(common_factor, new_outer_node->name());
752 ctx().node_map->AddOutput(new_add_node->name(), new_outer_node->name());
753
754 // Hoist non-shared factors up into the new AddN node.
755 for (int i = 0, end = unique_factors.size(); i < end; ++i) {
756 const string& unique_factor_i = unique_factors[i];
757 new_add_node->set_input(i, unique_factor_i);
758 ctx().node_map->AddOutput(unique_factor_i, new_add_node->name());
759 }
760
761 // Add control deps on add node
762 for (const string& ctrl_dep : ctrl_deps) {
763 *new_add_node->add_input() = ctrl_dep;
764 ctx().node_map->AddOutput(NodeName(ctrl_dep), new_add_node->name());
765 }
766
767 // optimize new inner aggregation node
768 AddToOptimizationQueue(new_add_node);
769 // do not optimize the same node twice
770 rewritten_nodes_.insert(node->name());
771 *simplified_node_name = new_outer_node->name();
772 }
773 }
774 return Status::OK();
775 }
776
777 private:
778 // Get a name for new outer node
  string OuterNodeName(const NodeDef* node, bool is_div) const {
780 auto scope_and_name = ParseNodeScopeAndName(node->name());
781 return is_div ? OptimizedNodeName(scope_and_name, "Div")
782 : OptimizedNodeName(scope_and_name, "Mul");
783 }
784
  // Get a name for the new inner Add node
  string InnerAddNodeName(const NodeDef* node) const {
787 auto scope_and_name = ParseNodeScopeAndName(node->name());
788 return OptimizedNodeName(scope_and_name, "AddV2");
789 }
790
791 // Determine the set of common factors if the input nodes are all Mul or
792 // Div nodes.
  Status GetCommonFactors(const NodeDef* node, std::set<string>* common_factors,
                          bool* common_factor_is_denominator,
                          std::vector<string>* ctrl_deps) const {
796 CHECK(common_factors->empty());
797 CHECK_NOTNULL(common_factor_is_denominator);
798 *common_factor_is_denominator = false;
799
800 bool has_mul = false;
801 bool has_div = false;
802 for (int i = 0; i < node->input_size(); ++i) {
803 if (i > 0 && common_factors->empty()) break;
804 if (IsControlInput(node->input(i))) {
805 ctrl_deps->push_back(node->input(i));
806 continue;
807 }
808 NodeDef* input;
809 TF_RETURN_IF_ERROR(GetInputNode(node->input(i), &input));
810
811 if ((!IsMul(*input) && !IsAnyDiv(*input)) || (IsMul(*input) && has_div) ||
812 (IsAnyDiv(*input) && has_mul)) {
        // Break if the input is neither a Mul nor a Div, or if there are both
        // Mul and Div ops.
815 common_factors->clear();
816 break;
817 } else if (IsAnyDiv(*input)) {
818 has_div = true;
819 // In case of possible common dividers, we avoid hoisting out if any
820 // input is not float/double, since integer division is not distributive
821 // over addition.
822 const OpInfo::TensorProperties* properties0;
823 const OpInfo::TensorProperties* properties1;
824 TF_RETURN_IF_ERROR(GetTensorProperties(input->input(0), &properties0));
825 TF_RETURN_IF_ERROR(GetTensorProperties(input->input(1), &properties1));
826 if (properties0->dtype() != DT_FLOAT &&
827 properties0->dtype() != DT_DOUBLE &&
828 properties1->dtype() != DT_FLOAT &&
829 properties1->dtype() != DT_DOUBLE) {
830 common_factors->clear();
831 break;
832 }
833 } else if (IsMul(*input)) {
834 has_mul = true;
835 }
836
837 // We only focus on common factors from denominators if any Op is a
838 // Div.
839 std::set<string> factors_i =
840 has_mul ? std::set<string>{input->input(0), input->input(1)}
841 : std::set<string>{input->input(1)};
842 if (i == 0) {
843 std::swap(*common_factors, factors_i);
844 } else {
845 std::set<string> intersection;
846 std::set_intersection(
847 factors_i.begin(), factors_i.end(), common_factors->begin(),
848 common_factors->end(),
849 std::inserter(intersection, intersection.begin()));
850 std::swap(*common_factors, intersection);
851 }
852 for (int i = 2; i < input->input_size(); ++i) {
853 ctrl_deps->push_back(input->input(i));
854 }
855 }
856
857 *common_factor_is_denominator = has_div;
858 return Status::OK();
859 }
860
861 // Gather up the non-shared factors (the y's in the example).
862 // Unless the aggregation is Add, we have to make sure that all the y's
863 // have the same shape since the other aggregation ops do not support
864 // broadcasting.
  Status GetUniqueFactors(const NodeDef* node, const string& common_factor,
                          const bool common_factor_is_denominator,
                          bool* shapes_match,
                          std::vector<string>* unique_factors) const {
869 *shapes_match = true;
870 unique_factors->reserve(node->input_size());
871
872 for (int i = 0; i < node->input_size() && *shapes_match; ++i) {
873 const string& input = node->input(i);
874 if (IsControlInput(input)) {
875 break;
876 }
877 NodeDef* inner_node;
878 TF_RETURN_IF_ERROR(GetInputNode(input, &inner_node));
879 const int unique_factor_index =
880 common_factor_is_denominator
881 ? 0
882 : (inner_node->input(0) == common_factor ? 1 : 0);
883 unique_factors->push_back(inner_node->input(unique_factor_index));
884 if (i > 0 && !IsAdd(*node)) {
885 const OpInfo::TensorProperties* lhs;
886 const OpInfo::TensorProperties* rhs;
887 TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->front(), &lhs));
888 TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->back(), &rhs));
889 *shapes_match = ShapesSymbolicallyEqual(*lhs, *rhs);
890 }
891 }
892 return Status::OK();
893 }
894
  bool IsRewritten(const NodeDef* node) const {
    // If the graph rewrite happens in multiple passes without graph pruning
    // between them, it's possible that the rewritten node already exists in
    // the graph.
898 return rewritten_nodes_.find(node->name()) != rewritten_nodes_.end() ||
899 ctx().node_map->NodeExists(OuterNodeName(node, false)) ||
900 ctx().node_map->NodeExists(OuterNodeName(node, true)) ||
901 ctx().node_map->NodeExists(InnerAddNodeName(node));
902 }
903
904 // keep names of the nodes that were optimized by this stage
905 std::unordered_set<string> rewritten_nodes_;
906 };
907
// Binary associative ops can be re-ordered to minimize the number of
// broadcasts and the size of temporary tensors.
910 //
911 // Example: [a, c] - scalars, [b, d] - matrices
912 // @ - binary associative op (Add or Mul)
913 // @* - broadcast
914 //
915 // @ @*
916 // / \ / \
917 // @* @* -> @ @
918 // / \ / \ / \ / \
919 // a b c d a c b d
920 class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage {
921 public:
  explicit MinimizeBroadcasts(const GraphOptimizerContext& ctx,
                              const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticNodesGroupOptimizerStage("MinimizeBroadcasts", ctx, ctx_ext) {
  }
926 ~MinimizeBroadcasts() override = default;
927
  bool IsSupported(const NodeDef* node) const override {
929 if (!IsBinaryAssociative(*node)) return false;
930
931 if (IsMarkedWithAnyTag(*node, kMinimizeBroadcastsTag, kAddOpsRewriteTag))
932 return false;
933
934 // has a symbolically defined shape with broadcastable inputs
935 const OpInfo::TensorProperties* properties;
936 Status has_properties = GetTensorProperties(node->name(), &properties);
937 return has_properties.ok() && ShapeIsSymbolicallyDefined(*properties) &&
938 HasAllInputsBroadcastableToShape(*node, *properties);
939 }
940
941 protected:
  bool IsBinaryAssociative(const NodeDef& node) const {
943 return IsMul(node) || IsAdd(node);
944 }
945
  bool IsSameOp(const OptimizedNodesGroup& group, const NodeDef& node) const {
947 return group.root_node->op() == node.op();
948 }
949
950 // Check if a node can be absorbed by current OptimizedNodesGroup
  bool IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup& group,
                                         const NodeDef& node) const override {
953 if (!IsSameOp(group, node)) {
954 return false;
955 }
956 if (IsInPreserveSet(node)) {
957 return false;
958 }
959 // Nodes optimized by AddOpsRewrite already have optimal broadcasts.
960 if (IsMarkedWithAnyTag(node, kMinimizeBroadcastsTag, kAddOpsRewriteTag)) {
961 return false;
962 }
963 if (IsDrivenByControlDependency(node) || DrivesControlDependency(node)) {
964 return false;
965 }
966 if (!IsOnTheSameDevice(group, node)) {
967 return false;
968 }
    // Optimized nodes are updated in place, and that would break the graph if
    // the node has multiple output consumers.
971 if (NumNonControlOutputs(node, *ctx().node_map) != 1) {
972 return false;
973 }
974 // All input shapes must be broadcastable to the node shape
975 const OpInfo::TensorProperties* properties;
976 Status has_properties = GetTensorProperties(node.name(), &properties);
977 return has_properties.ok() &&
978 HasAllInputsBroadcastableToShape(node, *properties);
979 }
980
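  // Counts the number of distinct shape signatures among `inputs`.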
  std::size_t CountUniqueShapes(const std::vector<InputAndShape>& inputs) {
982 std::set<string> sigs;
983 for (const auto& ias : inputs) {
984 sigs.insert(ShapeSignature(ias.shape));
985 }
986 return sigs.size();
987 }
988
  string RewriteOptimizedNodesGroup(const OptimizedNodesGroup& group) override {
990 VLOG(2) << "Minimize broadcast: root=" << group.root_node->name()
991 << " op=" << group.root_node->op()
992 << " num_optimized_nodes=" << group.optimized_nodes.size();
993
994 // Do not optimize any of the nodes that are part of this group.
995 MarkAllMembersWithTag(group, kMinimizeBroadcastsTag);
996
997 if (CountUniqueShapes(group.inputs) <= 1) {
998 VLOG(3) << "Skip min-bcast group with single unique shape";
999 // nothing to optimize when all shapes are the same
1000 return group.root_node->name();
1001 }
1002
1003 auto num_nodes = /*root*/ 1 + group.optimized_nodes.size();
1004 auto num_inputs = group.inputs.size();
1005 CHECK_EQ(num_nodes, num_inputs - 1)
1006 << "Can't build a tree with " << num_inputs << " inputs, using "
        << num_nodes << " binary op nodes.";
1008
1009 std::deque<InputAndShape> add_ops(group.inputs.begin(), group.inputs.end());
1010 std::deque<NodeDef*> optimized_nodes(group.optimized_nodes.begin(),
1011 group.optimized_nodes.end());
1012
    // sort inputs by their shape from smallest to largest
1014 std::stable_sort(add_ops.begin(), add_ops.end(),
1015 [](const InputAndShape& lhs, const InputAndShape& rhs) {
1016 return CompareSymbolicallyShapedTensorSizes(lhs.shape,
1017 rhs.shape);
1018 });
1019
    // If there is an odd number of inputs, the last one is the largest, and we
    // want to attach it to the root node, to build a well-balanced tree.
1022 std::deque<InputAndShape> add_ops_leftover;
1023 if (add_ops.size() % 2 != 0) {
1024 add_ops_leftover.push_back(add_ops.back());
1025 add_ops.pop_back();
1026 }
1027
    // At this point it's guaranteed that add_ops has an even number of inputs.
1029 do {
1030 const InputAndShape lhs = add_ops.front();
1031 add_ops.pop_front();
1032 const InputAndShape rhs = add_ops.front();
1033 add_ops.pop_front();
1034
1035 NodeDef* node;
1036 if (!optimized_nodes.empty()) {
1037 // re-purpose optimized nodes to build a new tree
1038 node = optimized_nodes.back();
1039 optimized_nodes.pop_back();
1040 } else {
        // or use the root node if no optimized nodes are left
1042 node = group.root_node;
1043 }
1044 InputAndShape updated_node = UpdateInputs(lhs.input, rhs.input, node);
1045
      // Pushing the updated node to the back of the deque will create a wide
      // and short tree; pushing to the front will create a tall tree. We
      // prefer a wide tree because it minimizes the potential number of
      // temporary tensors required to keep in memory, though sometimes we go
      // taller to prevent propagating a broadcast from leaves to the root.
      // Example:
1051 //
1052 // inputs: [s, s, s, M] (s - scalar, M - matrix)
1053 // @* - op with broadcast
1054 //
1055 // (only push_back) @* (push_front first op)
1056 // / \
1057 // @* @ M
1058 // / \ / \
1059 // @ @* -> @ s
1060 // / \ / \ / \
1061 // s s s M s s
1062 if (add_ops.size() >= 2 &&
1063 CompareSymbolicallyShapedTensorSizes(add_ops.at(0).shape,
1064 add_ops.at(1).shape)) {
1065 add_ops.push_front(updated_node);
1066 } else {
1067 add_ops.push_back(updated_node);
1068 }
1069 } while (add_ops.size() > 1);
1070 CHECK_EQ(1, add_ops.size());
1071
1072 // attach the largest tensor to the root op
1073 if (!add_ops_leftover.empty()) {
1074 const InputAndShape lhs = add_ops.front();
1075 add_ops.pop_front();
1076 const InputAndShape rhs = add_ops_leftover.front();
1077 InputAndShape updated_node =
1078 UpdateInputs(lhs.input, rhs.input, group.root_node);
1079 add_ops.push_back(updated_node);
1080 }
1081
1082 return add_ops.front().input;
1083 }
1084
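  // Rewrites `node` to take (`input_0`, `input_1`) as its inputs, invalidating
  // its cached shape properties and updating the node map if the inputs
  // actually changed.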
  InputAndShape UpdateInputs(const string& input_0, const string& input_1,
                             NodeDef* node) {
1087 string old_input_0 = node->input(0);
1088 string old_input_1 = node->input(1);
1089
1090 // Update inputs only if they changed
1091 if (old_input_0 != input_0 || old_input_1 != input_1) {
1092 node->set_input(0, input_0);
1093 node->set_input(1, input_1);
1094 // Invalidate node properties (shape)
1095 ctx().graph_properties->ClearOutputProperties(node->name());
1096 ctx().graph_properties->ClearInputProperties(node->name());
1097 // Update the node map
1098 ctx().node_map->RemoveOutput(NodeName(old_input_0), node->name());
1099 ctx().node_map->RemoveOutput(NodeName(old_input_1), node->name());
1100 ctx().node_map->AddOutput(NodeName(input_0), node->name());
1101 ctx().node_map->AddOutput(NodeName(input_1), node->name());
1102 // Add updated node to optimization queue
1103 AddToOptimizationQueue(node);
1104 }
1105
1106 TensorShapeProto shape; // shape is not important at this point
1107 return InputAndShape(node->name(), shape);
1108 }
1109 };
1110
1111 // Removes inverse transpose nodes
1112 class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
1113 public:
  explicit RemoveIdentityTranspose(const GraphOptimizerContext& ctx,
                                   const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveIdentityTranspose", ctx, ctx_ext) {}
1117 ~RemoveIdentityTranspose() override = default;
1118
  bool IsSupported(const NodeDef* node) const override {
1120 return IsTranspose(*node) || IsConjugateTranspose(*node);
1121 }
1122
  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1124 TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
1125 NodeDef* tail = node;
1126 tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
1127 *ctx().nodes_to_preserve);
1128 NodeDef* first_transpose;
1129 TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose));
1130
1131 NodeDef* node_perm;
1132 TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &node_perm));
1133 if (!IsConstant(*node_perm)) {
1134 return Status::OK();
1135 }
1136 std::vector<int64> node_perm_values;
1137 TF_RETURN_IF_ERROR(GetPermutation(*node_perm, &node_perm_values));
1138 if (first_transpose->op() == node->op()) {
1139 // Remove pairs of transposes that cancel each other.
1140 NodeDef* first_transpose_perm;
1141 TF_RETURN_IF_ERROR(
1142 GetInputNode(first_transpose->input(1), &first_transpose_perm));
1143 if (!IsConstant(*first_transpose_perm)) {
1144 return Status::OK();
1145 }
1146 std::vector<int64> first_transpose_perm_values;
1147 TF_RETURN_IF_ERROR(
1148 GetPermutation(*first_transpose_perm, &first_transpose_perm_values));
1149 if (AreInversePermutations(node_perm_values,
1150 first_transpose_perm_values)) {
1151 if (tail == node) {
1152 // Bypass adjacent pair.
1153 *simplified_node_name = first_transpose->input(0);
1154 } else {
1155 // Bypass pair connected through chain.
1156 tail->set_input(0, first_transpose->input(0));
1157 ctx().node_map->UpdateInput(tail->name(), first_transpose->name(),
1158 first_transpose->input(0));
1159 ForwardControlDependencies(tail, {first_transpose});
1160 *simplified_node_name = node->input(0);
1161 }
1162 }
1163 } else {
1164 // Remove simple identity transposes.
1165 if (IsIdentityPermutation(node_perm_values)) {
1166 if (IsConjugateTranspose(*node)) {
1167 const NodeScopeAndName transpose =
1168 ParseNodeScopeAndName(node->name());
1169 const string optimized_node_name = OptimizedNodeName(transpose);
1170 NodeDef* new_op = AddCopyNode(optimized_node_name, node);
1171 new_op->set_op("Conj");
1172 new_op->mutable_input()->RemoveLast();
1173 new_op->mutable_attr()->erase("Tperm");
1174 ForwardControlDependencies(new_op, {node});
1175 *simplified_node_name = new_op->name();
1176 } else {
1177 *simplified_node_name = node->input(0);
1178 }
1179 }
1180 }
1181 return Status::OK();
1182 }
1183
1184 private:
  Status GetPermutation(const NodeDef& node_perm,
                        std::vector<int64>* perm64) const {
1187 std::vector<int> perm32;
1188 if (ValuesFromConstNode(node_perm, &perm32)) {
1189 perm64->reserve(perm32.size());
1190 for (int val : perm32) {
1191 perm64->push_back(static_cast<int64>(val));
1192 }
1193 return Status::OK();
1194 }
1195 if (ValuesFromConstNode(node_perm, perm64)) {
1196 return Status::OK();
1197 }
1198 return errors::InvalidArgument("Couldn't extract permutation from ",
1199 node_perm.name());
1200 }
1201
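  // Returns true if permutations `a` and `b` are inverses of each other.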
  bool AreInversePermutations(const std::vector<int64>& a,
                              const std::vector<int64>& b) {
1204 if (a.size() != b.size()) {
1205 return false;
1206 }
1207 for (int i = 0, end = a.size(); i < end; ++i) {
1208 if (a[b[i]] != i) {
1209 return false;
1210 }
1211 }
1212 return true;
1213 }
1214
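  // Returns true if `perm` is the identity permutation [0, 1, ..., n-1].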
  bool IsIdentityPermutation(const std::vector<int64>& perm) {
1216 for (int64 i = 0, end = perm.size(); i < end; ++i) {
1217 if (i != perm[i]) {
1218 return false;
1219 }
1220 }
1221 return true;
1222 }
1223 };
1224
1225 // An involution is an element-wise function f(x) that is its own inverse,
1226 // i.e. f(f(x)) = x. If we can find a chain of ops
1227 // f->op1->op2->...opn->f
1228 // where op1 through opn preserve the values of their inputs, we can remove
1229 // the two instances of the involution from the graph, since they cancel
1230 // each other.
1231 class RemoveInvolution : public ArithmeticOptimizerStage {
1232 public:
  explicit RemoveInvolution(const GraphOptimizerContext& ctx,
                            const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveInvolution", ctx, ctx_ext) {}
1236 ~RemoveInvolution() override = default;
1237
  bool IsSupported(const NodeDef* node) const override {
1239 return IsInvolution(*node);
1240 }
1241
  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1243 NodeDef* tail = GetTailOfValuePreservingChain(*node, *ctx().node_map,
1244 *ctx().nodes_to_preserve);
1245
1246 NodeDef* involution;
1247 TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &involution));
1248
1249 if (involution->op() == node->op()) {
1250 // Skip both *node and *involution since they cancel each other.
1251 if (tail == node) {
1252 // The two nodes to eliminate are adjacent.
1253 *simplified_node_name = involution->input(0);
1254 } else {
1255 tail->set_input(0, involution->input(0));
1256 ctx().node_map->UpdateInput(tail->name(), involution->name(),
1257 involution->input(0));
1258 *simplified_node_name = node->input(0);
1259 }
1260 }
1261
1262 return Status::OK();
1263 }
1264 };
1265
1266 // Remove redundant Bitcasts.
1267 // 1) Remove Bitcast whose source type and destination type are equal
1268 // 2) Rewrite Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
1269 class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage {
1270 public:
  explicit RemoveRedundantBitcastStage(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveRedundantBitcast", ctx, ctx_ext) {}
1275 ~RemoveRedundantBitcastStage() override = default;
1276
  bool IsSupported(const NodeDef* node) const override {
1278 return IsBitcast(*node);
1279 }
1280
  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1282 TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
1283
1284 // Bypass Bitcast whose source type and destination type are equal.
1285 AttrSlice attrs(*node);
1286 DataType input_type;
1287 TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &input_type));
1288 DataType output_type;
1289 TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "type", &output_type));
1290 if ((input_type == output_type) && !IsInPreserveSet(*node)) {
1291 *simplified_node_name = node->input(0);
1292 return Status::OK();
1293 }
1294
1295 NodeDef* bitcast;
1296 TF_RETURN_IF_ERROR(GetInputNode(node->name(), &bitcast));
1297 NodeDef* operand;
1298 TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &operand));
1299
1300 if (IsBitcast(*operand) && !IsInPreserveSet(*operand)) {
1301 AttrSlice operand_attrs(*operand);
1302 DataType operand_input_type;
1303 TF_RETURN_IF_ERROR(GetNodeAttr(operand_attrs, "T", &operand_input_type));
1304 // Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
1305 bitcast->set_input(0, operand->input(0));
1306 SetDataTypeToAttr(operand_input_type, "T", bitcast);
1307 ctx().node_map->UpdateInput(bitcast->name(), bitcast->input(0),
1308 operand->input(0));
1309 AddToOptimizationQueue(bitcast);
1310 *simplified_node_name = bitcast->name();
1311 }
1312
1313 return Status::OK();
1314 }
1315 };
1316
1317 // Remove Casts whose source type and destination type are equal.
1318 class RemoveRedundantCastStage : public ArithmeticOptimizerStage {
1319 public:
  explicit RemoveRedundantCastStage(const GraphOptimizerContext& ctx,
                                    const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveRedundantCast", ctx, ctx_ext) {}
1323 ~RemoveRedundantCastStage() override = default;
1324
  bool IsSupported(const NodeDef* node) const override {
1326 return IsCast(*node) && !IsInPreserveSet(*node);
1327 }
1328
  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1330 TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
1331
1332 // Bypass Cast whose source type and destination type are equal.
1333 AttrSlice attrs(*node);
1334 DataType input_type;
1335 TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "SrcT", &input_type));
1336 DataType output_type;
1337 TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "DstT", &output_type));
1338 if (input_type == output_type) {
1339 *simplified_node_name = node->input(0);
1340 }
1341 return Status::OK();
1342 }
1343 };
1344
1345 class RemoveNegationStage : public ArithmeticOptimizerStage {
1346 public:
  explicit RemoveNegationStage(const GraphOptimizerContext& ctx,
                               const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveNegation", ctx, ctx_ext) {}
1350 ~RemoveNegationStage() override = default;
1351
  bool IsSupported(const NodeDef* node) const override {
1353 return (IsAdd(*node) || IsSub(*node)) && !IsInPreserveSet(*node);
1354 }
1355
  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1357 NodeDef* x;
1358 NodeDef* y;
1359 TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
1360 TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
1361 bool updated = false;
1362 if (IsNeg(*y)) {
1363 // a - (-b) = a + b or a + (-b) = a - b
1364 ForwardControlDependencies(node, {y});
1365 ctx().node_map->UpdateInput(node->name(), node->input(1), y->input(0));
1366 node->set_op(IsAdd(*node) ? "Sub" : "AddV2");
1367 node->set_input(1, y->input(0));
1368 updated = true;
1369 } else if (IsAdd(*node) && IsNeg(*x)) {
1370 // (-a) + b = b - a
1371 ForwardControlDependencies(node, {x});
1372 ctx().node_map->UpdateInput(node->name(), node->input(0), x->input(0));
1373 node->set_op("Sub");
1374 node->mutable_input()->SwapElements(0, 1);
1375 node->set_input(1, x->input(0));
1376 updated = true;
1377 }
1378 if (updated) {
1379 AddToOptimizationQueue(node);
1380 }
1381 return Status::OK();
1382 }
1383 };
1384
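// Fold LogicalNot into its comparison input, for example:
//   LogicalNot(Equal(x, y))     => NotEqual(x, y)
//   LogicalNot(Less(x, y))      => GreaterEqual(x, y)
//   LogicalNot(LessEqual(x, y)) => Greater(x, y)
// The comparison must have a single non-control consumer and must not be in
// the preserve set.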
1385 class RemoveLogicalNotStage : public ArithmeticOptimizerStage {
1386 public:
1387 explicit RemoveLogicalNotStage(const GraphOptimizerContext& ctx,
1388 const ArithmeticOptimizerContext& ctx_ext)
1389 : ArithmeticOptimizerStage("RemoveLogicalNot", ctx, ctx_ext) {}
1390 ~RemoveLogicalNotStage() override = default;
1391
1392 bool IsSupported(const NodeDef* node) const override {
1393 return IsLogicalNot(*node) && !IsInPreserveSet(*node);
1394 }
1395
1396 Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1397 const string node_name = node->name();
1398 NodeDef* input;
1399 TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
1400 if (IsInPreserveSet(*input) ||
1401 NumNonControlOutputs(*input, *ctx().node_map) > 1) {
1402 return Status::OK();
1403 }
1404 string new_op;
1405 if (IsEqual(*input)) {
1406 new_op = "NotEqual";
1407 } else if (IsNotEqual(*input)) {
1408 new_op = "Equal";
1409 } else if (IsLess(*input)) {
1410 new_op = "GreaterEqual";
1411 } else if (IsLessEqual(*input)) {
1412 new_op = "Greater";
1413 } else if (IsGreater(*input)) {
1414 new_op = "LessEqual";
1415 } else if (IsGreaterEqual(*input)) {
1416 new_op = "Less";
1417 }
1418 if (!new_op.empty()) {
1419 input->set_op(new_op);
1420 *simplified_node_name = input->name();
1421 }
1422 return Status::OK();
1423 }
1424 };
1425
1426 // This optimization hoists the common prefix of unary ops of the inputs to
1427 // concat out of the concat, for example:
1428 // Concat([Exp(Sin(x)), Exp(Sin(y)), Exp(Sin(z))])
1429 // becomes
1430 // Exp(Sin(Concat([x, y, z]))).
1431 // Similarly, it will hoist the common postfix of unary ops into Split or
1432 // SplitV nodes, for example:
1433 // [Exp(Sin(y)) for y in Split(x)]
1434 // becomes
1435 // [y for y in Split(Exp(Sin(x)))]
1436 //
1437 // TODO(rmlarsen): Support casting. We would have to change the type attribute
1438 // on the concat/split node.
1439 // TODO(rmlarsen): Handle Enter/Exit.
1440 class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
1441 public:
1442 explicit HoistCWiseUnaryChainsStage(const GraphOptimizerContext& ctx,
1443 const ArithmeticOptimizerContext& ctx_ext)
1444 : ArithmeticOptimizerStage("", ctx, ctx_ext) {}
1445
1446 ~HoistCWiseUnaryChainsStage() override = default;
1447
1448 struct ChainLink {
1449 ChainLink() = default;
1450 ChainLink(NodeDef* _node, int _port_origin)
1451 : node(_node), port_origin(_port_origin) {}
1452 NodeDef* node; // Node in a chain.
1453 int port_origin; // Port on concat/split node from which this chain
1454 // originates.
1455
1456 bool operator<(const ChainLink& other) const {
1457 if (port_origin < other.port_origin) {
1458 return true;
1459 } else if (port_origin > other.port_origin) {
1460 return false;
1461 } else {
1462 return node->name() < other.node->name();
1463 }
1464 }
1465 };
1466
1467 // We use an ordinary set sorted on port and node name, so the order, and
1468 // hence the node name used for the hoisted chain, will be deterministic.
1469 using ChainLinkSet = std::set<ChainLink>;
1470
1471 bool IsSupported(const NodeDef* node) const override {
1472 if (IsInPreserveSet(*node)) return false;
1473 if (IsConcat(*node) && node->attr().count("N") != 0) {
1474 const int n = node->attr().at("N").i();
1475 return n > 1 && FirstNInputsAreUnique(*node, n);
1476 } else if ((IsSplit(*node) || IsSplitV(*node)) &&
1477 node->attr().count("num_split") != 0) {
1478 const int num_split = node->attr().at("num_split").i();
1479 if (NumNonControlOutputs(*node, *ctx().node_map) > num_split) {
1480 // TODO(rmlarsen): Remove this constraint when we have optimizations
1481 // in place for merging slices into splits.
1482 return false;
1483 }
1484 if (NumControlOutputs(*node, *ctx().node_map) > 0) {
1485 // TODO(ezhulenev): Unary ops after Split might have a control path to
1486 // the Split node, and we currently do not properly handle cycles.
1487 return false;
1488 }
1489 return num_split > 1 && !IsAlreadyOptimized(*node);
1490 }
1491 return false;
1492 }
1493
1494 Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1495 node_is_concat_ = IsConcat(*node);
1496 int prefix_length;
1497 std::set<string> ctrl_inputs;
1498 ChainLinkSet tails;
1499 TF_RETURN_IF_ERROR(
1500 FindCommonUnaryOpChain(*node, &prefix_length, &tails, &ctrl_inputs));
1501 if (prefix_length > 0 && !tails.empty()) {
1502 TF_RETURN_IF_ERROR(
1503 HoistUnaryOpChain(prefix_length, tails, &ctrl_inputs, node));
1504 }
1505 return Status::OK();
1506 }
1507
1508 private:
1509 bool FirstNInputsAreUnique(const NodeDef& node, int n) const {
1510 if (n > node.input_size()) return false;
1511 absl::flat_hash_set<string> unique_inputs;
1512 const int start = node.op() == "Concat" ? 1 : 0;
1513 const int end = start + n;
1514 for (int i = start; i < end; ++i) {
1515 unique_inputs.insert(node.input(i));
1516 }
1517 int unique_input_size = unique_inputs.size();
1518 return unique_input_size == n;
1519 }
1520
1521 // Returns the length of the common unary chain of ops that can be
1522 // hoisted to the other side of concat or split.
1523 Status FindCommonUnaryOpChain(const NodeDef& root_node, int* prefix_length,
1524 ChainLinkSet* tails,
1525 std::set<string>* ctrl_inputs) const {
1526 *prefix_length = 0;
1527 // Follow the chains starting at each concat input or split output as long
1528 // as all the following conditions hold:
1529 // 1. The ops in all chains are the same.
1530 // 2. The ops are unary elementwise ops.
1531 // 3. The op output has only a single consumer (concat only).
1532 ChainLinkSet cur_tails;
1533 TF_RETURN_IF_ERROR(InitializeChains(root_node, &cur_tails));
1534 if (cur_tails.size() < 2) {
1535 return Status::OK();
1536 }
1537 ctrl_inputs->clear();
1538 bool stop = false;
1539 while (!stop && !cur_tails.empty() &&
1540 OpsAreSafeToHoist(root_node, cur_tails)) {
1541 // We found one more link that can be hoisted.
1542 ++(*prefix_length);
1543 tails->swap(cur_tails);
1544 GatherControlInputs(ctrl_inputs, *tails);
1545
1546 // Advance tail pointers to the next level.
1547 TF_RETURN_IF_ERROR(AdvanceTails(*tails, &cur_tails, &stop));
1548 }
1549 return Status::OK();
1550 }
1551
1552 // Hoists the chains to the other side of concat or split and attaches the
1553 // control inputs gathered from them to the concat or split node.
1554 Status HoistUnaryOpChain(const int prefix_length, const ChainLinkSet& tails,
1555 std::set<string>* ctrl_inputs, NodeDef* root_node) {
1556 VLOG(3) << "Hoist unary op chain:"
1557 << " root=" << root_node->DebugString()
1558 << " prefix_length=" << prefix_length << " ctrl_inputs=["
1559 << absl::StrJoin(*ctrl_inputs, ", ") << "]";
1560
1561 if (tails.empty()) {
1562 return Status::OK();
1563 }
1564 AddToOptimizationQueue(root_node);
1565 optimized_nodes_.insert(root_node->name());
1566 if (node_is_concat_) {
1567 AddControlInputs(ctrl_inputs, root_node);
1568 return HoistChainForConcat(prefix_length, tails, root_node);
1569 } else {
1570 return HoistChainForSplit(prefix_length, tails, ctrl_inputs, root_node);
1571 }
1572 }
1573
1574 void GatherControlInputs(std::set<string>* ctrl_inputs,
1575 const ChainLinkSet& ops) const {
1576 for (const auto& link : ops) {
1577 const NodeDef* node = link.node;
1578 for (int i = node->input_size() - 1; i >= 0; --i) {
1579 const string& input = node->input(i);
1580 if (!IsControlInput(input)) break;
1581 ctrl_inputs->insert(input);
1582 }
1583 }
1584 }
1585
1586 void AddControlInputs(std::set<string>* new_ctrl_inputs,
1587 NodeDef* node) const {
1588 for (int i = node->input_size() - 1; i >= 0; --i) {
1589 const string& existing_input = node->input(i);
1590 if (!IsControlInput(existing_input)) break;
1591 new_ctrl_inputs->erase(existing_input);
1592 }
1593 for (const string& new_input : *new_ctrl_inputs) {
1594 ctx().node_map->AddOutput(NodeName(new_input), node->name());
1595 node->add_input(new_input);
1596 }
1597 }
1598
1599 Status InitializeChains(const NodeDef& node, ChainLinkSet* tails) const {
1600 if (node_is_concat_) {
1601 // Handle concat nodes by looking backwards in the graph.
1602 TF_RETURN_IF_ERROR(CheckAttrExists(node, "N"));
1603 const int n = node.attr().at("N").i();
1604 const int start = node.op() == "Concat" ? 1 : 0;
1605 const int end = start + n;
1606 if (end > node.input_size()) {
1607 return errors::FailedPrecondition("Got attr N=", n,
1608 " without enough inputs.");
1609 }
1610 // Set up tail pointers to point to the immediate inputs to Concat.
1611 for (int input_port = start; input_port < end; ++input_port) {
1612 if (IsControlInput(node.input(input_port))) {
1613 return errors::FailedPrecondition(
1614 "Got control input ", node.input(input_port),
1615 " where normal input was expected.");
1616 }
1617 NodeDef* tail;
1618 TF_RETURN_IF_ERROR(GetInputNode(node.input(input_port), &tail));
1619 tails->insert(ChainLink(tail, input_port));
1620 }
1621 return Status::OK();
1622 } else {
1623 // Handle split nodes by looking forwards in the graph.
1624 const auto& outputs = ctx().node_map->GetOutputs(node.name());
1625 for (NodeDef* output : outputs) {
1626 if (IsControlInput(output->input(0))) continue;
1627 TensorId tensor_id = ParseTensorName(output->input(0));
1628 if (tensor_id.node() == node.name()) {
1629 tails->insert(ChainLink(output, tensor_id.index()));
1630 } else {
1631 // This output node has a non-control input other than the split node,
1632 // abort.
1633 tails->clear();
1634 return Status::OK();
1635 }
1636 }
1637 }
1638 return Status::OK();
1639 }
1640
1641 bool OpsAreSafeToHoist(const NodeDef& root_node,
1642 const ChainLinkSet& ops) const {
1643 if (ops.empty()) return true;
1644 const NodeDef* op0 = ops.begin()->node;
1645 if (ModifiesFrameInfo(*op0) || !IsUnaryElementWise(*op0)) return false;
1646 for (const auto& link : ops) {
1647 const NodeDef* op = link.node;
1648 if (op->device() != root_node.device() || op->op() != op0->op() ||
1649 IsInPreserveSet(*op)) {
1650 return false;
1651 }
1652 if (ctx().node_map->GetOutputs(op->name()).size() > 1) {
1653 // TODO(rmlarsen): Allow outgoing control edges.
1654 return false;
1655 }
1656 // Do not hoist Relu if it can be fused with its predecessors. This is
1657 // important because remapping runs after the arithmetic optimizer.
1658 if (IsRelu(*op) || IsRelu6(*op)) {
1659 NodeDef* operand = nullptr;
1660 if (!GetInputNode(op->input(0), &operand).ok()) {
1661 return false;
1662 }
1663 if (IsFusedBatchNorm(*operand) || IsBiasAdd(*operand)) {
1664 return false;
1665 }
1666 }
1667 }
1668 return true;
1669 }
1670
1671 Status AdvanceTails(const ChainLinkSet& tails, ChainLinkSet* new_tails,
1672 bool* stop) const {
1673 *stop = true;
1674 new_tails->clear();
1675 for (const auto& link : tails) {
1676 const NodeDef* tail = link.node;
1677 if (node_is_concat_) {
1678 if (tail->input_size() == 0 || IsControlInput(tail->input(0))) {
1679 return Status::OK();
1680 }
1681 NodeDef* new_tail;
1682 TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &new_tail));
1683 // Remember original port.
1684 new_tails->insert(ChainLink(new_tail, link.port_origin));
1685 } else {
1686 for (NodeDef* new_tail : ctx().node_map->GetOutputs(tail->name())) {
1687 const TensorId tensor = ParseTensorName(new_tail->input(0));
1688 if (tensor.node() != tail->name()) {
1689 return Status::OK();
1690 }
1691 // Skip control outputs.
1692 if (tensor.index() >= 0) {
1693 // Remember original port.
1694 new_tails->insert(ChainLink(new_tail, link.port_origin));
1695 }
1696 }
1697 }
1698 }
1699 *stop = false;
1700 return Status::OK();
1701 }
1702
1703 Status HoistChainForConcat(const int prefix_length, const ChainLinkSet& tails,
1704 NodeDef* concat_node) {
1705 const string& concat_name = concat_node->name();
1706 const int first_input = concat_node->op() == "Concat" ? 1 : 0;
1707 for (const auto& link : tails) {
1708 NodeDef* tail = CHECK_NOTNULL(link.node);
1709 const int concat_port = link.port_origin;
1710 CHECK_GE(concat_port, 0);
1711 CHECK_LT(concat_port, concat_node->input_size());
1712 const string concat_input = concat_node->input(concat_port);
1713 // Hook the node following tail directly into the concat node.
1714 const string tail_input = tail->input(0);
1715 concat_node->set_input(concat_port, tail_input);
1716 ctx().node_map->UpdateInput(concat_name, concat_input, tail_input);
1717
1718 if (concat_port == first_input) {
1719 // Update the consumers of concat to consume the end of the chain
1720 // instead.
1721 UpdateConsumers(concat_node, concat_input);
1722 // Reuse nodes in the first chain to process output of concat.
1723 tail->set_input(0, concat_name);
1724 ctx().node_map->UpdateInput(tail->name(), tail_input, concat_name);
1725 }
1726 }
1727 return Status::OK();
1728 }
1729
1730 Status HoistChainForSplit(const int prefix_length, const ChainLinkSet& tails,
1731 std::set<string>* ctrl_inputs,
1732 NodeDef* split_node) {
1733 // Create a new chain before the split node to process the input tensor.
1734 const string& split_name = split_node->name();
1735 auto root_scope_and_name = ParseNodeScopeAndName(split_name);
1736
1737 // We use the first tail node in the set as a template to get the list of
1738 // ops to apply (starting from the end).
1739 NodeDef* cur_tail = tails.begin()->node;
1740 NodeDef* cur_copy = AddCopyNode(
1741 OptimizedNodeName(root_scope_and_name, cur_tail->name()), cur_tail);
1742 cur_copy->clear_input();
1743
1744 // Update the split to take its input from the tail of the new chain.
1745 const int value_slot = split_node->op() == "SplitV" ? 0 : 1;
1746 const string orig_input = split_node->input(value_slot);
1747 split_node->set_input(value_slot, cur_copy->name());
1748 ctx().node_map->UpdateInput(split_node->name(), orig_input,
1749 cur_copy->name());
1750 TF_RETURN_IF_ERROR(GetInputNode(cur_tail->input(0), &cur_tail));
1751
1752 // Now walk backwards creating the rest of the chain.
1753 while (cur_tail != split_node) {
1754 NodeDef* new_copy = AddCopyNode(
1755 OptimizedNodeName(root_scope_and_name, cur_tail->name()), cur_tail);
1756 new_copy->clear_input();
1757 cur_copy->add_input(new_copy->name());
1758 ctx().node_map->AddOutput(new_copy->name(), cur_copy->name());
1759 cur_copy = new_copy;
1760 TF_RETURN_IF_ERROR(GetInputNode(cur_tail->input(0), &cur_tail));
1761 }
1762 // Connect the original input to the head of the new chain.
1763 cur_copy->add_input(orig_input);
1764 ctx().node_map->UpdateOutput(NodeName(orig_input), split_name,
1765 cur_copy->name());
1766 // Make sure all the control inputs are satisfied before running the first
1767 // node in the new chain.
1768 AddControlInputs(ctrl_inputs, cur_copy);
1769
1770 // Connect all consumers of the tail nodes directly to the
1771 // output port of Split from which the chain started.
1772 for (const auto& link : tails) {
1773 UpdateConsumers(link.node,
1774 link.port_origin == 0
1775 ? split_name
1776 : strings::StrCat(split_name, ":", link.port_origin));
1777 }
1778 return Status::OK();
1779 }
1780
1781 // Update consumers of node to take new_input as input instead.
1782 void UpdateConsumers(NodeDef* node, const string& new_input) {
1783 const string& node_name = node->name();
1784 const auto consumers = ctx().node_map->GetOutputs(node_name);
1785 for (NodeDef* consumer : consumers) {
1786 for (int i = 0; i < consumer->input_size(); ++i) {
1787 if (consumer->input(i) == node_name &&
1788 consumer->name() != NodeName(new_input)) {
1789 consumer->set_input(i, new_input);
1790 ctx().node_map->UpdateInput(consumer->name(), node_name, new_input);
1791 }
1792 }
1793 AddToOptimizationQueue(consumer);
1794 }
1795 }
1796
1797 bool IsAlreadyOptimized(const NodeDef& node) const {
1798 return optimized_nodes_.find(node.name()) != optimized_nodes_.end();
1799 }
1800
1801 private:
1802 bool node_is_concat_;
1803 std::unordered_set<string> optimized_nodes_;
1804 };
1805
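// Remove the outer op of f(f(x)) when f is idempotent, i.e. when applying f
// twice yields the same result as applying it once. Both instances must run
// the same op on the same device.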
1806 class RemoveIdempotentStage : public ArithmeticOptimizerStage {
1807 public:
1808 explicit RemoveIdempotentStage(const GraphOptimizerContext& ctx,
1809 const ArithmeticOptimizerContext& ctx_ext)
1810 : ArithmeticOptimizerStage("RemoveIdempotent", ctx, ctx_ext) {}
1811 ~RemoveIdempotentStage() override = default;
1812
1813 bool IsSupported(const NodeDef* node) const override {
1814 return node->input_size() == 1 && IsIdempotent(*node) &&
1815 !IsInPreserveSet(*node);
1816 }
1817
1818 Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1819 NodeDef* input;
1820 TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
1821 if (input->op() == node->op() && input->device() == node->device()) {
1822 *simplified_node_name = node->input(0);
1823 }
1824 return Status::OK();
1825 }
1826 };
1827
1828 // Performs the conversion:
1829 // Div(x, Sqrt(y)) => Mul(x, Rsqrt(y))
1830 // TODO(srjoglekar): Generalize to optimize cases like (x / pow(y, z)).
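// The same rewrite is applied to Xdivy:
//   Xdivy(a, Sqrt(b)) => MulNoNan(Rsqrt(b), a)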
1831 class SqrtDivToRsqrtMulStage : public ArithmeticOptimizerStage {
1832 public:
1833 explicit SqrtDivToRsqrtMulStage(const GraphOptimizerContext& ctx,
1834 const ArithmeticOptimizerContext& ctx_ext)
1835 : ArithmeticOptimizerStage("SqrtDivToRsqrtMul", ctx, ctx_ext) {}
1836 ~SqrtDivToRsqrtMulStage() override = default;
1837
1838 bool IsSupported(const NodeDef* node) const override {
1839 // Note: we do not rewrite div_no_nan(a, sqrt(b)) => mul_no_nan(a, rsqrt(b)),
1840 // because for b == 0 it would produce a * Inf instead of 0.
1841 return IsAnyDiv(*node) && !IsDivNoNan(*node) && !IsFloorDiv(*node);
1842 }
1843
1844 Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1845 NodeDef* y;
1846 TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
1847 // Optimize only if divisor is a Sqrt whose output is not being consumed
1848 // elsewhere.
1849 if (IsSqrt(*y) && !IsInPreserveSet(*y) &&
1850 (NumNonControlOutputs(*y, *ctx().node_map) == 1)) {
1851 if (IsXdivy(*node)) {
1852 // xdivy(a, sqrt(b)) => mul_no_nan(rsqrt(b), a)
1853 node->set_op("MulNoNan");
1854 node->mutable_input()->SwapElements(0, 1);
1855 } else {
1856 // div(a, sqrt(b)) => mul(a, rsqrt(b))
1857 node->set_op("Mul");
1858 }
1859 y->set_op("Rsqrt");
1860 AddToOptimizationQueue(node);
1861 AddToOptimizationQueue(y);
1862 }
1863 return Status::OK();
1864 }
1865 };
1866
1867 // Performs the following conversion for real types:
1868 // Square(Sub(x, y)) => Identity(SquaredDifference(x, y))
1869 class FuseSquaredDiffStage : public ArithmeticOptimizerStage {
1870 public:
1871 explicit FuseSquaredDiffStage(const GraphOptimizerContext& ctx,
1872 const ArithmeticOptimizerContext& ctx_ext)
1873 : ArithmeticOptimizerStage("FuseSquaredDiffStage", ctx, ctx_ext) {}
1874 ~FuseSquaredDiffStage() override = default;
1875
1876 bool IsSupported(const NodeDef* node) const override {
1877 return IsSquare(*node);
1878 }
1879
1880 Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1881 NodeDef* b;
1882 TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &b));
1883 // Optimize only if base is a Sub whose output is not being consumed
1884 // elsewhere.
1885 if (IsSub(*b) && !IsInPreserveSet(*b) &&
1886 (NumNonControlOutputs(*b, *ctx().node_map) == 1)) {
1887 // For complex, SquaredDiff computes conj(x-y)*(x-y), so this rewrite is
1888 // invalid.
1889 const DataType type = GetDataTypeFromAttr(*b, "T");
1890 if ((type == DT_COMPLEX64) || (type == DT_COMPLEX128))
1891 return Status::OK();
1892 node->set_op("Identity");
1893 b->set_op("SquaredDifference");
1894 AddToOptimizationQueue(node);
1895 AddToOptimizationQueue(b);
1896 }
1897 return Status::OK();
1898 }
1899 };
1900
1901 // Performs the conversion:
1902 // Log(Softmax(x)) => LogSoftmax(x)
1903 class LogSoftmaxStage : public ArithmeticOptimizerStage {
1904 public:
1905 explicit LogSoftmaxStage(const GraphOptimizerContext& ctx,
1906 const ArithmeticOptimizerContext& ctx_ext)
1907 : ArithmeticOptimizerStage("LogSoftmaxStage", ctx, ctx_ext) {}
1908 ~LogSoftmaxStage() override = default;
1909
1910 bool IsSupported(const NodeDef* node) const override { return IsLog(*node); }
1911
1912 Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1913 NodeDef* x;
1914 TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
1915 // Optimize only if arg is a Softmax whose output is not being consumed
1916 // elsewhere.
1917 if (IsSoftmax(*x) && !IsInPreserveSet(*x) &&
1918 (NumNonControlOutputs(*x, *ctx().node_map) == 1)) {
1919 // Log(Softmax(x)) => LogSoftmax(Identity(x))
1920 node->set_op("LogSoftmax");
1921 x->set_op("Identity");
1922 AddToOptimizationQueue(node);
1923 AddToOptimizationQueue(x);
1924 }
1925 return Status::OK();
1926 }
1927 };
1928
1929 // Bypass redundant reshape nodes:
1930 //
1931 //   Reshape                    Reshape  <-+
1932 //      ^                                  |
1933 //      |                                  |
1934 //   Reshape       becomes      Reshape    |
1935 //      ^                                  |
1936 //      |                                  |
1937 //    input                      input  ---+
1938 //
1939 // Additionally, Reshape and BroadcastTo nodes where the
1940 // input and target shapes are equal are bypassed.
1941 //
1942 class RemoveRedundantReshapeOrBroadcastTo : public ArithmeticOptimizerStage {
1943 public:
1944 explicit RemoveRedundantReshapeOrBroadcastTo(
1945 const GraphOptimizerContext& ctx,
1946 const ArithmeticOptimizerContext& ctx_ext)
1947 : ArithmeticOptimizerStage("RemoveRedundantReshapeOrBroadcastTo", ctx,
1948 ctx_ext) {}
1949 ~RemoveRedundantReshapeOrBroadcastTo() override = default;
1950
1951 bool IsSupported(const NodeDef* node) const override {
1952 return (IsReshape(*node) || IsBroadcastTo(*node)) &&
1953 !IsInPreserveSet(*node);
1954 }
1955
1956 Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1957 NodeDef* input;
1958 TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
1959
1960 // 1. Bypass reshape followed by reshape.
1961 if (IsReshape(*node) && IsReshape(*input) && !IsInPreserveSet(*input)) {
1962 ForwardControlDependencies(node, {input});
1963 node->set_input(0, input->input(0));
1964 ctx().node_map->UpdateInput(node->name(), input->name(), input->input(0));
1965 *simplified_node_name = node->name();
1966 AddToOptimizationQueue(node);
1967 return Status::OK();
1968 }
1969
1970 // 2. If the reshape is a no-op, forward its input to its consumers, unless
1971 // it anchors a control dependency since we want to make sure that control
1972 // dependency is triggered.
1973 if (InputMatchesTargetShape(*node) && !HasControlInputs(*node)) {
1974 *simplified_node_name = node->input(0);
1975 return Status::OK();
1976 }
1977
1978 return Status::OK();
1979 }
1980
1981 private:
1982 // Returns whether `reshape` is an identity op.
1983 bool InputMatchesTargetShape(const NodeDef& reshape) {
1984 const OpInfo::TensorProperties* reshape_props;
1985 const OpInfo::TensorProperties* input_props;
1986
1987 if (!GetTensorProperties(reshape.name(), &reshape_props).ok() ||
1988 !GetTensorProperties(reshape.input(0), &input_props).ok()) {
1989 return false;
1990 }
1991
1992 return ShapesSymbolicallyEqual(input_props->shape(),
1993 reshape_props->shape());
1994 }
1995 };
1996
1997 // Reorder casting and value-preserving ops if beneficial.
1998 //
1999 // Original motivation: A common pattern after the layout optimizer is
2000 // casting an uint8 NHWC image to float before transposing it to NCHW. It
2001 // is beneficial to reorder the cast and the transpose to make the transpose
2002 // process smaller amount of data. More generally, this optimization converts
2003 // Op(Cast(tensor, dst_type))
2004 // to
2005 // Cast(Op(tensor), dst_type)
2006 // when sizeof(tensor.type) < sizeof(dst_type), and Op is any value-preserving
2007 // Op, i.e. an op that only reorders the elements in its first input. Similarly,
2008 // this optimization converts
2009 // Cast(Op(tensor), dst_type)
2010 // to
2011 // Op(Cast(tensor, dst_type))
2012 // when sizeof(tensor.type) > sizeof(dst_type).
2013 //
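// For example, assuming a uint8 input image:
//   Transpose(Cast(image, DT_FLOAT)) => Cast(Transpose(image), DT_FLOAT)
// so the Transpose moves 1-byte elements instead of 4-byte elements.
//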
2014 class ReorderCastLikeAndValuePreserving : public ArithmeticOptimizerStage {
2015 public:
2016 explicit ReorderCastLikeAndValuePreserving(
2017 const GraphOptimizerContext& ctx,
2018 const ArithmeticOptimizerContext& ctx_ext)
2019 : ArithmeticOptimizerStage("ReorderCastLikeAndValuePreserving", ctx,
2020 ctx_ext) {}
2021 ~ReorderCastLikeAndValuePreserving() override = default;
2022
2023 bool IsSupported(const NodeDef* node) const override {
2024 return (IsValuePreserving(*node) || IsCastLike(*node)) &&
2025 !IsCheckNumerics(*node) && NodeIsOnCpuOrGpu(node) &&
2026 !IsControlFlow(*node) && !IsInPreserveSet(*node);
2027 }
2028
2029 Status TrySimplify(NodeDef* consumer, string* simplified_node_name) override {
2030 NodeDef* producer;
2031 TF_RETURN_IF_ERROR(GetInputNode(consumer->input(0), &producer));
2032 const bool producer_is_cast = IsCastLike(*producer);
2033 const bool can_optimize =
2034 !IsCheckNumerics(*producer) &&
2035 ((producer_is_cast && IsValuePreserving(*consumer)) ||
2036 (IsValuePreserving(*producer) && IsCastLike(*consumer)));
2037 if (!can_optimize || IsControlFlow(*producer) ||
2038 IsInPreserveSet(*producer) ||
2039 producer->device() != consumer->device()) {
2040 return Status::OK();
2041 }
2042
2043 const NodeDef* cast_like_node = producer_is_cast ? producer : consumer;
2044 const OpDef* cast_like_op_def = nullptr;
2045 TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(cast_like_node->op(),
2046 &cast_like_op_def));
2047 DataType cast_src_type;
2048 TF_RETURN_IF_ERROR(InputTypeForNode(*cast_like_node, *cast_like_op_def, 0,
2049 &cast_src_type));
2050 DataType cast_dst_type;
2051 TF_RETURN_IF_ERROR(OutputTypeForNode(*cast_like_node, *cast_like_op_def, 0,
2052 &cast_dst_type));
2053 if (!IsFixedSizeType(cast_src_type) || !IsFixedSizeType(cast_dst_type)) {
2054 return Status::OK();
2055 } else if (producer_is_cast &&
2056 DataTypeSize(cast_dst_type) <= DataTypeSize(cast_src_type)) {
2057 return Status::OK();
2058 } else if (!producer_is_cast &&
2059 DataTypeSize(cast_dst_type) >= DataTypeSize(cast_src_type)) {
2060 return Status::OK();
2061 }
2062
2063 // Check that nodes were not already optimized.
2064 const string optimized_producer_name = OptimizedNodeName(
2065 ParseNodeScopeAndName(producer->name()), DataTypeString(cast_dst_type));
2066 const string optimized_consumer_name = OptimizedNodeName(
2067 ParseNodeScopeAndName(consumer->name()), DataTypeString(cast_src_type));
2068 const bool is_already_optimized =
2069 ctx().node_map->NodeExists(optimized_consumer_name) ||
2070 ctx().node_map->NodeExists(optimized_producer_name);
2071 if (is_already_optimized) {
2072 return Status::OK();
2073 }
2074
2075 // Add copies of consumer and producer in reverse order.
2076 NodeDef* input;
2077 TF_RETURN_IF_ERROR(GetInputNode(producer->input(0), &input));
2078 // Create new producer node.
2079 NodeDef* new_producer = AddCopyNode(optimized_consumer_name, consumer);
2080 new_producer->set_input(0, producer->input(0));
2081 ctx().node_map->AddOutput(input->name(), new_producer->name());
2082
2083 // Create new consumer node.
2084 NodeDef* new_consumer = AddCopyNode(optimized_producer_name, producer);
2085 new_consumer->set_input(0, new_producer->name());
2086
2087 NodeDef* new_value_preserving =
2088 producer_is_cast ? new_producer : new_consumer;
2089 const DataType new_input_type =
2090 producer_is_cast ? cast_src_type : cast_dst_type;
2091 // Update the input type of the value-preserving node. The input and
2092 // output types of the cast-like nodes remain the same.
2093 TF_RETURN_IF_ERROR(SetInputType(new_input_type, new_value_preserving));
2094 // Make sure there is a kernel registered for the value preserving op
2095 // with the new input type.
2096 TF_RETURN_IF_ERROR(IsKernelRegisteredForNode(*new_value_preserving));
2097 ctx().node_map->AddOutput(new_producer->name(), new_consumer->name());
2098
2099 AddToOptimizationQueue(new_producer);
2100 *simplified_node_name = new_consumer->name();
2101
2102 return Status::OK();
2103 }
2104
2105 private:
2106 // Sets the type of the first input to dtype.
2107 Status SetInputType(DataType dtype, NodeDef* node) {
2108 const OpDef* op_def = nullptr;
2109 TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node->op(), &op_def));
2110 const OpDef::ArgDef& input_arg = op_def->input_arg(0);
2111 const string& type_attr_name = input_arg.type_attr();
2112 if (type_attr_name.empty()) {
2113 if (input_arg.type() == DT_INVALID || input_arg.type() != dtype) {
2114 return errors::InvalidArgument("Could not set input type of ",
2115 node->op(), " op to ",
2116 DataTypeString(dtype));
2117 } else {
2118 // Op has fixed input type that already matches dtype.
2119 return Status::OK();
2120 }
2121 }
2122 SetDataTypeToAttr(dtype, type_attr_name, node);
2123 return Status::OK();
2124 }
2125 // This optimization can be dangerous on devices other than CPU and
2126 // GPU. The transpose might not be implemented for image.type, or
2127 // might be slower with image.type than with cast_dst_type.
2128 bool NodeIsOnCpuOrGpu(const NodeDef* node) const {
2129 using absl::StrContains;
2130
2131 string task;
2132 string device;
2133
2134 return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
2135 (StrContains(device, DEVICE_CPU) || StrContains(device, DEVICE_GPU));
2136 }
2137
2138 bool IsFixedSizeType(DataType dtype) {
2139 return dtype != DT_STRING && dtype != DT_VARIANT && dtype != DT_RESOURCE &&
2140 !kQuantizedTypes.Contains(dtype);
2141 }
2142 };
2143
2144 // Fold a multiply of a scalar into the following convolution. This folding
2145 // can jump across nodes that merely reorder data (such as reshape and
2146 // transpose). For example, we can optimize
2147 //
2148 //
2149 //         Conv2D                                Conv2D
2150 //        /      \                              /      \
2151 //    Transpose   weights*        ->        Transpose   Mul
2152 //        |                                     |       /  \
2153 //       Mul                                    |  weights  scale
2154 //      /   \                                   |
2155 //   input   scale**                          input
2156 //
2157 // *) weights must be a const
2158 // **) scale must be a const scalar
2159 //
2160 // When `weights` and `scale` are constant, `Mul` in the optimized graph can be
2161 // constant-folded; also, weights tend to be smaller than the activations.
2162 //
2163 // TODO(jingyue): Fold scalar multiplies to Conv?DBackpropFilter and
2164 // Conv?DBackpropInput.
2165 class FoldMultiplyIntoConv : public ArithmeticOptimizerStage {
2166 public:
2167 explicit FoldMultiplyIntoConv(const GraphOptimizerContext& ctx,
2168 const ArithmeticOptimizerContext& ctx_ext)
2169 : ArithmeticOptimizerStage("FoldMultiplyIntoConv", ctx, ctx_ext) {}
2170 ~FoldMultiplyIntoConv() override = default;
2171
2172 bool IsSupported(const NodeDef* node) const override {
2173 return IsConv2D(*node) || IsConv3D(*node);
2174 }
2175
2176 Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2177 #define TF_RETURN_IF_TRUE(...) \
2178 if ((__VA_ARGS__)) return Status::OK()
2179
2180 NodeDef* conv = node;
2181
2182 NodeDef* weights;
2183 TF_RETURN_IF_ERROR(GetInputNode(conv->input(1), &weights));
2184
2185 // Fold the multiply to conv only when the weights are constant, so the
2186 // multiply can be constant-folded.
2187 //
2188 // TODO(jingyue): When the weights aren't constant, this should also help
2189 // performance a bit and memory usage a lot, since the weights tend to be
2190 // smaller than the activations.
2191 TF_RETURN_IF_TRUE(!IsConstant(*weights));
2192
2193 // Verify that this node was not already optimized.
2194 const string scaled_weights_node_name =
2195 OptimizedNodeName(ParseNodeScopeAndName(weights->name()),
2196 strings::StrCat("scaled", "_", conv->name()));
2197
2198 TF_RETURN_IF_TRUE(ctx().node_map->NodeExists(scaled_weights_node_name));
2199
2200 // Find the tail of value preserving chain entering the Conv node.
2201 NodeDef* tail = GetTailOfValuePreservingChain(*conv, *ctx().node_map,
2202 *ctx().nodes_to_preserve);
2203
2204 NodeDef* source;
2205 TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &source));
2206
2207 // Check that value preserving chain is the only consumer of the Mul output.
2208 TF_RETURN_IF_TRUE(!IsAnyMul(*source));
2209 TF_RETURN_IF_TRUE(NumNonControlOutputs(*source, *ctx().node_map) != 1);
2210 // And that Mul is not in the preserve set.
2211 TF_RETURN_IF_TRUE(IsInPreserveSet(*source));
2212
2213 const NodeDef* mul = source;
2214 int input_idx = 0;
2215 int scale_idx = 1;
2216 NodeDef* scale; // scalar multiplier for the input tensor
2217 NodeDef* input;
2218 TF_RETURN_IF_ERROR(GetInputNode(mul->input(scale_idx), &scale));
2219 TF_RETURN_IF_ERROR(GetInputNode(mul->input(input_idx), &input));
2220 if (!IsConstant(*scale) && IsConstant(*input)) {
2221 VLOG(3) << "Swapped inputs to mul";
2222 std::swap(scale_idx, input_idx);
2223 std::swap(scale, input);
2224 }
2225 TF_RETURN_IF_TRUE(!IsConstant(*scale));
2226
2227 // Check that one of the inputs to mul is a constant scalar.
2228 const TensorProto& scale_tensor = scale->attr().at("value").tensor();
2229 bool scale_is_a_scalar = scale_tensor.has_tensor_shape() &&
2230 scale_tensor.tensor_shape().dim_size() == 0;
2231 TF_RETURN_IF_TRUE(!scale_is_a_scalar);
2232
2233 // Check that 'scale * weight' can be const folded.
2234 TF_RETURN_IF_TRUE(!IsConstant(*scale));
2235 TF_RETURN_IF_ERROR(CheckAttrsExist(*scale, {"dtype"}));
2236 TF_RETURN_IF_ERROR(CheckAttrExists(*weights, "dtype"));
2237 TF_RETURN_IF_TRUE(scale->attr().at("dtype").type() !=
2238 weights->attr().at("dtype").type());
2239
2240 // At this point all preconditions are met, and we safely do the rewrite.
2241 VLOG(3) << "Fold multiply into conv: conv=" << conv->name()
2242 << " mul=" << mul->name() << " weights=" << weights->name();
2243
2244 // Create new node `scaled_weights`.
2245 NodeDef* scaled_weights = AddEmptyNode(scaled_weights_node_name);
2246 scaled_weights->set_op(source->op());
2247 scaled_weights->set_device(weights->device());
2248 (*scaled_weights->mutable_attr())["T"] = weights->attr().at("dtype");
2249 AddToOptimizationQueue(scaled_weights);
2250
2251 // Link in its inputs.
2252 scaled_weights->add_input(conv->input(1));
2253 ctx().node_map->AddOutput(weights->name(), scaled_weights->name());
2254 scaled_weights->add_input(mul->input(scale_idx));
2255 ctx().node_map->AddOutput(scale->name(), scaled_weights->name());
2256 ForwardControlDependencies(scaled_weights, {source});
2257
2258 // Update `conv`'s weights to `scaled_weights`.
2259 conv->set_input(1, scaled_weights->name());
2260 ctx().node_map->UpdateInput(conv->name(), weights->name(),
2261 scaled_weights->name());
2262 AddToOptimizationQueue(conv);
2263
2264 // Update `tail` node to bypass `mul` because it's folded to the weights.
2265 tail->set_input(0, mul->input(input_idx));
2266 ctx().node_map->UpdateInput(tail->name(), mul->name(), input->name());
2267 AddToOptimizationQueue(tail);
2268 *simplified_node_name = conv->name();
2269
2270 return Status::OK();
2271 #undef TF_RETURN_IF_TRUE
2272 }
2273 };
2274
2275 // Fold Transpose into matrix multiplication.
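// For example:
//   MatMul(Transpose(x), y) => MatMul(x, y) with transpose_a = true
// when the Transpose permutes only the two innermost dimensions (adj_x/adj_y
// are used instead for batched matrix multiplication).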
2276 class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage {
2277 public:
2278 explicit FoldTransposeIntoMatMul(const GraphOptimizerContext& ctx,
2279 const ArithmeticOptimizerContext& ctx_ext)
2280 : ArithmeticOptimizerStage("FoldTransposeIntoMatMul", ctx, ctx_ext) {}
2281 ~FoldTransposeIntoMatMul() override = default;
2282
2283 bool IsSupported(const NodeDef* node) const override {
2284 return IsAnyMatMul(*node) && !IsInPreserveSet(*node);
2285 }
2286
2287 Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2288 const NodeScopeAndName matmul = ParseNodeScopeAndName(node->name());
2289 const string optimized_node_name = OptimizedNodeName(matmul);
2290 if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();
2291
2292 NodeDef* a;
2293 NodeDef* b;
2294 TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &a));
2295 TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &b));
2296
2297 bool is_complex = false;
2298 if (node->op() != "SparseMatMul") {
2299 const DataType type = GetDataTypeFromAttr(*node, "T");
2300 is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
2301 }
2302
2303 const std::set<string> foldable_transpose_ops =
2304 !is_complex
2305 ? std::set<string>{"ConjugateTranspose", "Transpose"}
2306 : (IsAnyBatchMatMul(*node) ? std::set<string>{"ConjugateTranspose"}
2307 : std::set<string>{"Transpose"});
2308
2309 const bool a_is_foldable = foldable_transpose_ops.count(a->op()) > 0 &&
2310 IsInnerMatrixTransposeNode(*a, ctx().node_map);
2311 const bool b_is_foldable = foldable_transpose_ops.count(b->op()) > 0 &&
2312 IsInnerMatrixTransposeNode(*b, ctx().node_map);
2313 if (!a_is_foldable && !b_is_foldable) return Status::OK();
2314
2315 NodeDef* new_op = AddCopyNode(optimized_node_name, node);
2316
2317 if (a_is_foldable) {
2318 const string attr_a = IsAnyBatchMatMul(*node) ? "adj_x" : "transpose_a";
2319 FlipBooleanAttr(attr_a, new_op);
2320 new_op->set_input(0, a->input(0));
2321 ctx().node_map->UpdateInput(new_op->name(), a->name(), a->input(0));
2322 } else {
2323 ctx().node_map->UpdateOutput(a->name(), node->name(), new_op->name());
2324 }
2325
2326 if (b_is_foldable) {
2327 const string attr_b = IsAnyBatchMatMul(*node) ? "adj_y" : "transpose_b";
2328 FlipBooleanAttr(attr_b, new_op);
2329 new_op->set_input(1, b->input(0));
2330 ctx().node_map->UpdateInput(new_op->name(), b->name(), b->input(0));
2331 } else {
2332 ctx().node_map->UpdateOutput(b->name(), node->name(), new_op->name());
2333 }
2334
2335 std::vector<const NodeDef*> deps_to_forward = {node};
2336 if (a_is_foldable) deps_to_forward.push_back(a);
2337 if (b_is_foldable) deps_to_forward.push_back(b);
2338 ForwardControlDependencies(new_op, deps_to_forward);
2339 *simplified_node_name = new_op->name();
2340
2341 return Status::OK();
2342 }
2343
2344 private:
2345 void FlipBooleanAttr(const string& attr_name, NodeDef* node) {
2346 const bool old_value =
2347 !node->attr().count(attr_name) ? false : node->attr().at(attr_name).b();
2348 (*node->mutable_attr())[attr_name].set_b(!old_value);
2349 }
2350
2351 template <typename T>
2352 bool IsInnerMatrixTranspose(const std::vector<T>& perm) {
2353 const T n = perm.size();
2354 if (n < 2) {
2355 return false;
2356 }
2357 for (T i = 0; i < n - 2; ++i) {
2358 if (perm[i] != i) {
2359 return false;
2360 }
2361 }
2362 return perm[n - 1] == n - 2 && perm[n - 2] == n - 1;
2363 }
2364
2365 bool IsInnerMatrixTransposeNode(const NodeDef& transpose_node,
2366 const NodeMap* node_map) {
2367 if (transpose_node.op() != "Transpose" &&
2368 transpose_node.op() != "ConjugateTranspose") {
2369 return false;
2370 }
2371 const NodeDef* perm_node = node_map->GetNode(transpose_node.input(1));
2372 std::vector<int> perm32;
2373 if (ValuesFromConstNode(*perm_node, &perm32)) {
2374 return IsInnerMatrixTranspose(perm32);
2375 }
2376 std::vector<int64> perm64;
2377 if (ValuesFromConstNode(*perm_node, &perm64)) {
2378 return IsInnerMatrixTranspose(perm64);
2379 }
2380 return false;
2381 }
2382 };
2383
2384 class FoldConjugateIntoTranspose : public ArithmeticOptimizerStage {
2385 public:
2386 explicit FoldConjugateIntoTranspose(const GraphOptimizerContext& ctx,
2387 const ArithmeticOptimizerContext& ctx_ext)
2388 : ArithmeticOptimizerStage("FoldConjugateIntoTranspose", ctx, ctx_ext) {}
2389 ~FoldConjugateIntoTranspose() override = default;
2390
2391 bool IsSupported(const NodeDef* node) const override {
2392 return IsConj(*node) || IsTranspose(*node) || IsConjugateTranspose(*node);
2393 }
2394
2395 Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2396 const NodeScopeAndName matmul = ParseNodeScopeAndName(node->name());
2397 const string optimized_node_name = OptimizedNodeName(matmul);
2398 if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();
2399
2400 NodeDef* input;
2401 TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
2402
2403 const NodeDef* transpose_op = node->op() == "Conj" ? input : node;
2404 const NodeDef* conj_op = node->op() == "Conj" ? node : input;
2405
2406 if ((IsTranspose(*transpose_op) || IsConjugateTranspose(*transpose_op)) &&
2407 IsConj(*conj_op)) {
2408 NodeDef* new_op = AddCopyNode(optimized_node_name, transpose_op);
2409
2410 // Flip the type of transpose op to absorb the conjugation.
2411 new_op->set_op(transpose_op->op() == "Transpose" ? "ConjugateTranspose"
2412 : "Transpose");
2413 new_op->set_input(0, input->input(0));
2414 ctx().node_map->UpdateInput(new_op->name(), node->name(),
2415 input->input(0));
2416 ForwardControlDependencies(new_op, {node, input});
2417 *simplified_node_name = new_op->name();
2418 }
2419
2420 return Status::OK();
2421 }
2422 };
2423
2424 // Replace Mul node with identical inputs with a Square.
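// For example: Mul(x, x) => Square(x). Complex-valued multiplies are only
// rewritten when the node is placed on CPU.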
2425 class ReplaceMulWithSquare : public ArithmeticOptimizerStage {
2426 public:
2427 explicit ReplaceMulWithSquare(const GraphOptimizerContext& ctx,
2428 const ArithmeticOptimizerContext& ctx_ext)
2429 : ArithmeticOptimizerStage("ReplaceMulWithSquare", ctx, ctx_ext) {}
2430 ~ReplaceMulWithSquare() override = default;
2431
2432 bool IsSupported(const NodeDef* node) const override {
2433 return IsAnyMul(*node) && node->input(0) == node->input(1);
2434 }
2435
2436 Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2437 const NodeScopeAndName mul = ParseNodeScopeAndName(node->name());
2438 const string optimized_node_name = OptimizedNodeName(mul);
2439 if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();
2440
2441 const DataType type = GetDataTypeFromAttr(*node, "T");
2442 bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
2443
2444 if (!is_complex || NodeIsOnCpu(*node)) {
2445 NodeDef* new_square_node = AddCopyNode(optimized_node_name, node);
2446 new_square_node->set_op("Square");
2447 for (int i = 1; i < new_square_node->input_size(); ++i) {
2448 new_square_node->set_input(i - 1, new_square_node->input(i));
2449 }
2450 new_square_node->mutable_input()->RemoveLast();
2451 for (const string& input : new_square_node->input()) {
2452 ctx().node_map->AddOutput(NodeName(input), new_square_node->name());
2453 }
2454 *simplified_node_name = new_square_node->name();
2455 }
2456
2457 return Status::OK();
2458 }
2459 };
2460
2461 // Simplify aggregation (e.g. AddN) nodes:
2462 //
2463 // 1. Discard aggregate nodes with a single input and no control dependencies.
2464 //
2465 // 2. Try to rewrite aggregations of N >= 2 identical terms (possibly due to
2466 // deduping or other rewrites) so we can get rid of the sum entirely.
2467 //
2468 // The expression (using AddN as an example of an aggregate op):
2469 //   AddN(x, x, x, ..., x)
2470 //        <-- N terms -->
2471 // can be rewritten to:
2472 //   Mul(Const(N), x)
2473 //
2474 class SimplifyAggregation : public ArithmeticOptimizerStage {
2475 public:
2476 explicit SimplifyAggregation(const GraphOptimizerContext& ctx,
2477 const ArithmeticOptimizerContext& ctx_ext)
2478 : ArithmeticOptimizerStage("SimplifyAggregation", ctx, ctx_ext) {}
2479 ~SimplifyAggregation() override = default;
2480
2481 bool IsSupported(const NodeDef* node) const override {
2482 return IsAggregate(*node) && HasRegularInputs(*node) &&
2483 GetDataTypeFromAttr(*node, "T") !=
2484 DT_VARIANT; // TODO(b/119787146): Enable for variants.
2485 }
2486
2487 Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2488 // 1. Discard aggregate nodes with a single input and no control deps.
2489 if (node->input_size() == 1) {
2490 *simplified_node_name = node->input(0);
2491 return Status::OK();
2492 }
2493
2494 // 2. Rewrite aggregations of N >= 2 identical terms.
2495
2496 // All non-control inputs must be identical.
2497 bool all_equal = true;
2498 int num_inputs = 1;
2499 for (int i = 1; i < node->input_size(); ++i) {
2500 if (IsControlInput(node->input(i))) break;
2501 ++num_inputs;
2502 if (node->input(i) != node->input(0)) {
2503 all_equal = false;
2504 break;
2505 }
2506 }
2507 if (!all_equal) return Status::OK();
2508
2509 // And the node should not have been optimized earlier.
2510 const NodeScopeAndName node_scope_and_name =
2511 ParseNodeScopeAndName(node->name());
2512 const string optimized_const_name =
2513 OptimizedNodeName(node_scope_and_name, "Const");
2514 const string optimized_mul_name =
2515 OptimizedNodeName(node_scope_and_name, "Mul");
2516
2517 bool is_already_optimized =
2518 ctx().node_map->NodeExists(optimized_const_name) ||
2519 ctx().node_map->NodeExists(optimized_mul_name);
2520
2521 if (is_already_optimized) return Status::OK();
2522
2523 // At this point all preconditions are met, and we safely do the rewrite.
2524 VLOG(3) << "Simplify aggregation with identical inputs: node="
2525 << node->name() << " num_inputs=" << num_inputs;
2526
2527 // 1. Create constant node with value N.
2528 const auto type = GetDataTypeFromAttr(*node, "T");
2529 Tensor t(type, TensorShape({}));
2530 Status status = SetTensorValue(type, num_inputs, &t);
2531 if (!status.ok()) {
2532 return errors::Internal("Failed to create const node: ",
2533 status.error_message());
2534 }
2535
2536 TensorValue value(&t);
2537 NodeDef* new_const_node = AddEmptyNode(optimized_const_name);
2538 status = ConstantFolding::CreateNodeDef(new_const_node->name(), value,
2539 new_const_node);
2540 if (!status.ok()) {
2541 return errors::Internal("Failed to create const node: ",
2542 status.error_message());
2543 }
2544 new_const_node->set_device(node->device());
2545 MaybeAddControlInput(NodeName(node->input(0)), new_const_node,
2546 ctx().optimized_graph, ctx().node_map);
2547 AddToOptimizationQueue(new_const_node);
2548
2549 // 2. Replace the aggregate node with Mul(Const(N), x).
2550 NodeDef* new_mul_node = AddEmptyNode(optimized_mul_name);
2551 new_mul_node->set_op("Mul");
2552 new_mul_node->set_device(node->device());
2553 SetDataTypeToAttr(type, "T", new_mul_node);
2554 new_mul_node->add_input(new_const_node->name());
2555 ctx().node_map->AddOutput(new_const_node->name(), new_mul_node->name());
2556 new_mul_node->add_input(node->input(0));
2557 ctx().node_map->AddOutput(node->input(0), new_mul_node->name());
2558
2559 ForwardControlDependencies(new_mul_node, {node});
2560 *simplified_node_name = new_mul_node->name();
2561
2562 return Status::OK();
2563 }
2564 };
2565
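// Convert Pow into cheaper ops when the exponent is a constant whose elements
// are all equal:
//   Pow(x, 2)    => Square(x)
//   Pow(x, 3)    => Mul(x, Square(x))   (CPU only)
//   Pow(x, 1)    => Identity(x)
//   Pow(x, 0.5)  => Sqrt(x)
//   Pow(x, 0)    => Const(1)            (when the output shape is fully defined)
//   Pow(x, -0.5) => Rsqrt(x)
//   Pow(x, -1)   => Reciprocal(x)
// The Identity and Const rewrites are only applied when Pow would not
// broadcast its first input.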
2566 class ConvertPowStage : public ArithmeticOptimizerStage {
2567 public:
2568 explicit ConvertPowStage(const GraphOptimizerContext& ctx,
2569 const ArithmeticOptimizerContext& ctx_ext)
2570 : ArithmeticOptimizerStage("ConvertPow", ctx, ctx_ext) {}
2571
2572 bool IsSupported(const NodeDef* node) const override {
2573 return IsPow(*node) &&
2574 ctx().graph_properties->HasOutputProperties(node->name()) &&
2575 ctx().graph_properties->HasInputProperties(node->name());
2576 }
2577
2578 Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2579 Tensor pow;
2580 if (!GetTensorFromConstNode(node->input(1), &pow)) return Status::OK();
2581 complex128 prev, curr;
2582 for (int i = 0; i < pow.NumElements(); ++i) {
2583 if (!GetElementUnexhaustive(pow, i, {pow.dtype()}, &curr)) {
2584 // input data type is not supported by Pow. Skip.
2585 return Status::OK();
2586 }
2587 if (i != 0 && curr != prev) {
2588 // pow has different values on different elements. Skip.
2589 return Status::OK();
2590 }
2591 prev = curr;
2592 }
2593 NodeDef *x, *y;
2594 TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
2595 TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
2596 const auto& value_props =
2597 ctx().graph_properties->GetInputProperties(node->name())[0];
2598 const TensorShapeProto& output_shape =
2599 ctx().graph_properties->GetOutputProperties(node->name())[0].shape();
2600 if (curr == complex128(2, 0)) {
2601 node->set_op("Square");
2602 node->set_input(1, AsControlDependency(y->name()));
2603 AddToOptimizationQueue(node);
2604 AddToOptimizationQueue(y);
2605 } else if (curr == complex128(3, 0)) {
2606 // TODO(courbet): Use 'Cube' when it's added to TF ops.
2607 if (NodeIsOnCpu(*node)) {
2608 // We create an inner square node: inner_square = square(x)
2609 const NodeScopeAndName scope_and_name =
2610 ParseNodeScopeAndName(node->name());
2611 const string inner_square_name =
2612 OptimizedNodeName(scope_and_name, "_inner");
2613 NodeDef* inner_square_node = ctx().node_map->GetNode(inner_square_name);
2614 if (inner_square_node == nullptr) {
2615 inner_square_node = AddCopyNode(inner_square_name, node);
2616 inner_square_node->set_op("Square");
2617 inner_square_node->mutable_input()->RemoveLast();
2618 }
2619 ctx().node_map->AddOutput(x->name(), inner_square_node->name());
2620 // We modify `node`: node = mul(x, inner_square);
2621 node->set_op("Mul");
2622 node->set_input(1, inner_square_node->name());
2623 node->add_input(AsControlDependency(y->name()));
2624
2625 AddToOptimizationQueue(node);
2626 AddToOptimizationQueue(inner_square_node);
2627 AddToOptimizationQueue(y);
2628 }
2629 } else if (curr == complex128(1, 0) &&
2630 ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
2631 // Pow could be used to broadcast, so make sure the shapes of the two
2632 // arguments are identical before replacing Pow with Identity.
2633 node->set_op("Identity");
2634 node->set_input(1, AsControlDependency(y->name()));
2635 AddToOptimizationQueue(node);
2636 AddToOptimizationQueue(y);
2637 } else if (curr == complex128(0.5, 0)) {
2638 node->set_op("Sqrt");
2639 node->set_input(1, AsControlDependency(y->name()));
2640 AddToOptimizationQueue(node);
2641 AddToOptimizationQueue(y);
2642 } else if (curr == complex128(0, 0) &&
2643 ShapesSymbolicallyEqual(value_props.shape(), output_shape) &&
2644 PartialTensorShape(output_shape).IsFullyDefined()) {
2645 const auto dtype = node->attr().at("T").type();
2646 Tensor ones(dtype, output_shape);
2647 for (int i = 0; i < ones.NumElements(); ++i) {
2648 TF_RETURN_IF_ERROR(SetElementToOne(i, &ones));
2649 }
2650 node->set_op("Const");
2651 (*node->mutable_attr())["dtype"].set_type(dtype);
2652 node->mutable_attr()->erase("T");
2653 ones.AsProtoTensorContent(
2654 (*node->mutable_attr())["value"].mutable_tensor());
2655 node->set_input(0, AsControlDependency(x->name()));
2656 node->set_input(1, AsControlDependency(y->name()));
2657 AddToOptimizationQueue(node);
2658 AddToOptimizationQueue(x);
2659 AddToOptimizationQueue(y);
2660 } else if (curr == complex128(-0.5, 0)) {
2661 node->set_op("Rsqrt");
2662 node->set_input(1, AsControlDependency(y->name()));
2663 AddToOptimizationQueue(node);
2664 AddToOptimizationQueue(y);
2665 } else if (curr == complex128(-1, 0)) {
2666 node->set_op("Reciprocal");
2667 node->set_input(1, AsControlDependency(y->name()));
2668 AddToOptimizationQueue(node);
2669 AddToOptimizationQueue(y);
2670 }
2671 return Status::OK();
2672 }
2673
2674 private:
2675 Status SetElementToOne(int i, Tensor* t) {
2676 switch (t->dtype()) {
2677 case DT_INT32:
2678 t->flat<int32>()(i) = 1;
2679 return Status::OK();
2680 case DT_INT64:
2681 t->flat<int64>()(i) = 1L;
2682 return Status::OK();
2683 case DT_FLOAT:
2684 t->flat<float>()(i) = 1.0f;
2685 return Status::OK();
2686 case DT_DOUBLE:
2687 t->flat<double>()(i) = 1.0;
2688 return Status::OK();
2689 case DT_COMPLEX64:
2690 t->flat<complex64>()(i) = complex64(1);
2691 return Status::OK();
2692 case DT_COMPLEX128:
2693 t->flat<complex128>()(i) = complex128(1);
2694 return Status::OK();
2695 default:
2696 return errors::InvalidArgument("Invalid data type: ", t->dtype());
2697 }
2698 }
2699 };
2700
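// Performs the conversion:
//   Log(Add(x, 1.0)) => Log1p(x)
// when the constant addend is exactly 1 and adding it does not broadcast x to
// a different shape.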
2701 class ConvertLog1pStage : public ArithmeticOptimizerStage {
2702 public:
2703 explicit ConvertLog1pStage(const GraphOptimizerContext& ctx,
2704 const ArithmeticOptimizerContext& ctx_ext)
2705 : ArithmeticOptimizerStage("ConvertLog1p", ctx, ctx_ext) {}
2706 ~ConvertLog1pStage() override = default;
2707
2708 bool IsSupported(const NodeDef* node) const override { return IsLog(*node); }
2709
2710 Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2711 NodeDef* input;
2712 TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
2713 if (!IsAdd(*input)) {
2714 return Status::OK();
2715 }
2716
2717 if (ctx().graph_properties->GetInputProperties(input->name()).size() < 2) {
2718 return Status::OK();
2719 }
2720
2721 bool modified = false;
2722 TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 0, 1, &modified));
2723 if (!modified) {
2724 TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 1, 0, &modified));
2725 }
2726 if (modified) {
2727 *simplified_node_name = node->name();
2728 }
2729 return Status::OK();
2730 }
2731
2732 private:
2733   Status TrySimplifyInternal(NodeDef* node, NodeDef* add_node, int i, int j,
2734 bool* modified) {
2735 const auto& t =
2736 ctx().graph_properties->GetInputProperties(add_node->name())[i];
2737 const auto& c =
2738 ctx().graph_properties->GetInputProperties(add_node->name())[j];
2739 for (int k = 0; k < c.shape().dim_size(); ++k) {
2740 // Skip if c shape is not fully determined.
2741 if (c.shape().dim(k).size() < 0) {
2742 return Status::OK();
2743 }
2744 }
2745 TensorShapeProto broadcast_shape;
2746 if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
2747 return Status::OK();
2748 }
2749 if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
2750 // skip if the non-constant tensor doesn't have the same shape after
2751 // broadcast.
2752 return Status::OK();
2753 }
2754 Tensor constant;
2755 if (GetTensorFromConstNode(add_node->input(j), &constant)) {
2756 complex128 element;
2757 // TODO(rmlarsen): Refactor the more general IsOnes from
2758 // constant_folding.cc and use it here. Perhaps also convert log(x - (-1))
2759       // or (preferably) add a pass to canonicalize Sub(x, -1) to Add(x, 1),
2760 // and Neg(-1) to 1.
2761 for (int k = 0; k < constant.NumElements(); ++k) {
2762 if (!GetElementUnexhaustive(constant, k,
2763 {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
2764 DT_COMPLEX64, DT_COMPLEX128},
2765 &element)) {
2766 // input data type is not supported by log1p. Skip.
2767 return Status::OK();
2768 }
2769 if (element != complex128(1)) {
2770 // current element is not 1. Skip.
2771 return Status::OK();
2772 }
2773 }
2774 NodeDef *x, *y;
2775 TF_RETURN_IF_ERROR(GetInputNode(add_node->input(i), &x));
2776 TF_RETURN_IF_ERROR(GetInputNode(add_node->input(j), &y));
2777 node->set_op("Log1p");
2778 node->set_input(0, add_node->input(i));
2779 node->add_input(AsControlDependency(y->name()));
2780 ForwardControlDependencies(node, {add_node});
2781
2782 AddToOptimizationQueue(node);
2783 AddToOptimizationQueue(add_node);
2784 AddToOptimizationQueue(x);
2785 AddToOptimizationQueue(y);
2786 *modified = true;
2787 }
2788 return Status::OK();
2789 }
2790 };
2791
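// Performs the conversion:
//   Sub(Exp(x), c) => Expm1(x)
// when `c` is a constant whose elements are all 1 and broadcasting `c` does
// not change the shape of Exp(x). A rough TF Python sketch of the rewrite
// (names are illustrative only):
//   y = tf.math.exp(x) - 1.0   ==>   y = tf.math.expm1(x)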
2792 class ConvertExpm1Stage : public ArithmeticOptimizerStage {
2793 public:
2794   explicit ConvertExpm1Stage(const GraphOptimizerContext& ctx,
2795 const ArithmeticOptimizerContext& ctx_ext)
2796 : ArithmeticOptimizerStage("ConvertExpm1", ctx, ctx_ext) {}
2797 ~ConvertExpm1Stage() override = default;
2798
2799   bool IsSupported(const NodeDef* node) const override {
2800 if (!IsSub(*node)) return false;
2801
2802 NodeDef* input;
2803 if (!GetInputNode(node->input(0), &input).ok()) return false;
2804
2805 return IsExp(*input);
2806 }
2807
2808   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2809 if (ctx().graph_properties->GetInputProperties(node->name()).size() < 2) {
2810 return Status::OK();
2811 }
2812 const auto& t = ctx().graph_properties->GetInputProperties(node->name())[0];
2813 const auto& c = ctx().graph_properties->GetInputProperties(node->name())[1];
2814 TensorShapeProto broadcast_shape;
2815 if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
2816 return Status::OK();
2817 }
2818 if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
2819 // skip if the non-constant tensor doesn't have the same shape after
2820 // broadcast.
2821 return Status::OK();
2822 }
2823 Tensor constant;
2824 if (!GetTensorFromConstNode(node->input(1), &constant)) return Status::OK();
2825 // TODO(rmlarsen): Use the more general IsOnes helper here.
2826 complex128 element;
2827 for (int k = 0; k < constant.NumElements(); ++k) {
2828 if (!GetElementUnexhaustive(constant, k,
2829 {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
2830 DT_COMPLEX64, DT_COMPLEX128},
2831 &element)) {
2832 // input data type is not supported by expm1. Skip.
2833 return Status::OK();
2834 }
2835 if (element != complex128(1)) {
2836 // current element is not 1. Skip.
2837 return Status::OK();
2838 }
2839 }
2840 NodeDef* exp;
2841 TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &exp));
2842 NodeDef *exp_input, *ones;
2843 TF_RETURN_IF_ERROR(GetInputNode(exp->input(0), &exp_input));
2844 TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &ones));
2845 node->set_op("Expm1");
2846 node->set_input(0, exp->input(0));
2847 node->set_input(1, AsControlDependency(ones->name()));
2848 ForwardControlDependencies(node, {exp});
2849
2850 AddToOptimizationQueue(node);
2851 AddToOptimizationQueue(exp);
2852 AddToOptimizationQueue(exp_input);
2853 AddToOptimizationQueue(ones);
2854 *simplified_node_name = node->name();
2855 return Status::OK();
2856 }
2857 };
2858
2859 // Performs conversions like:
2860 // Max(Sqrt(x)) => Sqrt(Max(x))
2861 // Checks for a max/min reduction over element-wise monotonic functions, such
2862 // as Sqrt, Sigmoid, Tanh, etc.
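// If the inner function is monotonically non-increasing, the reduction is
// flipped as well, e.g. Max(Neg(x)) => Neg(Min(x)).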
2863 class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
2864 public:
2865   explicit OptimizeMaxOrMinOfMonotonicStage(
2866 const GraphOptimizerContext& ctx,
2867 const ArithmeticOptimizerContext& ctx_ext)
2868 : ArithmeticOptimizerStage("OptimizeMaxOrMinOfMonotonicStage", ctx,
2869 ctx_ext) {}
2870 ~OptimizeMaxOrMinOfMonotonicStage() override = default;
2871
2872   bool IsSupported(const NodeDef* node) const override {
2873 return IsAnyMax(*node) || IsAnyMin(*node) || IsAnyMaxPool(*node) ||
2874 IsArgMax(*node) || IsArgMin(*node);
2875 }
2876
2877   Status TrySimplify(NodeDef* reduction_node,
2878 string* simplified_node_name) override {
2879 if (IsInPreserveSet(*reduction_node)) {
2880 return Status::OK();
2881 }
2882
2883 NodeDef* inner_function;
2884 TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function));
2885
2886 NodeDef* inner_function_input = nullptr;
2887 if (inner_function->input_size() > 0) {
2888 TF_RETURN_IF_ERROR(
2889 GetInputNode(inner_function->input(0), &inner_function_input));
2890 }
2891
2892 // Optimize only if:
2893     // 0. inner_function is not in the preserve set,
2894     // 1. inner_function's Op is element-wise monotonic,
2895     // 2. inner_function's output is not consumed elsewhere,
2896     // 3. inner_function is monotonically increasing if reduction_node is a
2897     //    pooling operation, since we don't have MinPool operations,
2898     // 4. inner_function is not a Relu node with an input from FusedBatchNorm
2899     //    or BiasAdd. This pattern will be fused later by the remapper.
2900 auto can_be_fused_by_remapper = [](const NodeDef& consumer,
2901 const NodeDef& producer) -> bool {
2902 if (IsRelu(consumer) || IsRelu6(consumer)) {
2903 if (IsFusedBatchNorm(producer) || IsBiasAdd(producer)) {
2904 return true;
2905 }
2906 }
2907 return false;
2908 };
2909 bool is_non_decreasing = false;
2910 if (!IsInPreserveSet(*inner_function) &&
2911 IsElementWiseMonotonic(*inner_function, &is_non_decreasing) &&
2912 ctx().node_map->GetOutputs(inner_function->name()).size() == 1 &&
2913 (is_non_decreasing || !IsAnyMaxPool(*reduction_node)) &&
2914 !can_be_fused_by_remapper(*inner_function, *inner_function_input)) {
2915 // Swap the first inputs of the inner function Op & the reduction Op.
2916 NodeDef* inner_input;
2917 TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input));
2918 reduction_node->set_input(0, inner_input->name());
2919 ctx().node_map->UpdateInput(reduction_node->name(),
2920 inner_function->name(), inner_input->name());
2921 inner_function->set_input(0, reduction_node->name());
2922 UpdateConsumers(reduction_node, inner_function->name());
2923 ctx().node_map->UpdateInput(inner_function->name(), inner_input->name(),
2924 reduction_node->name());
2925 if (!is_non_decreasing) {
2926 // Flip Min<->Max if the function is non-increasing, e.g.
2927 // Max(Neg(x)) = Neg(Min(x)).
2928 const string opposite = FlipMinMax(*reduction_node);
2929 reduction_node->set_op(opposite);
2930 }
2931
2932 if (IsArgMax(*reduction_node) || IsArgMin(*reduction_node)) {
2933 // ArgMax(Sqrt(x)) = ArgMax(x)
2934 inner_function->set_op("Identity");
2935 }
2936
2937 AddToOptimizationQueue(reduction_node);
2938 AddToOptimizationQueue(inner_function);
2939 AddToOptimizationQueue(inner_input);
2940 }
2941 return Status::OK();
2942 }
2943
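  // Forwards all consumers of `node` (other than the node named by `new_input`)
  // to read from `new_input` instead, and re-enqueues them for optimization.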
2944   void UpdateConsumers(NodeDef* node, const string& new_input) {
2945 const string& node_name = node->name();
2946 const auto consumers = ctx().node_map->GetOutputs(node_name);
2947 for (NodeDef* consumer : consumers) {
2948 for (int i = 0; i < consumer->input_size(); ++i) {
2949 if (consumer->input(i) == node_name &&
2950 consumer->name() != NodeName(new_input)) {
2951 consumer->set_input(i, new_input);
2952 ctx().node_map->UpdateInput(consumer->name(), node_name, new_input);
2953 }
2954 }
2955 AddToOptimizationQueue(consumer);
2956 }
2957 }
2958
2959 private:
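  // Returns `node`'s op name with "Max" replaced by "Min", or vice versa,
  // e.g. "ArgMax" -> "ArgMin".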
2960   string FlipMinMax(const NodeDef& node) {
2961 const string& op = node.op();
2962 if (IsAnyMax(node) || IsArgMax(node)) {
2963 return str_util::StringReplace(op, "Max", "Min", false);
2964 } else {
2965 return str_util::StringReplace(op, "Min", "Max", false);
2966 }
2967 }
2968 };
2969
2970 // Replace a chain of type&shape preserving unary ops with a
2971 // '_UnaryOpsComposition' node.
2972 // TODO(ezhulenev): It should be a part of remapper optimizer because it doesn't
2973 // have to do much with arithmetic (together with FoldMultiplyIntoConv stage?).
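// For example (an illustrative sketch, not an exact graph dump), a CPU chain
//   y = Tanh(Sigmoid(Exp(x)))
// whose intermediate results have no other consumers becomes a single node
//   y = _UnaryOpsComposition(x, op_names=["Exp", "Sigmoid", "Tanh"])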
2974 class UnaryOpsComposition : public ArithmeticOptimizerStage {
2975 public:
2976   explicit UnaryOpsComposition(const GraphOptimizerContext& ctx,
2977 const ArithmeticOptimizerContext& ctx_ext)
2978 : ArithmeticOptimizerStage("UnaryOpsComposition", ctx, ctx_ext) {
2979 // WARN: This should be consistent with unary_ops_composition.cc.
2980 // clang-format off
2981 supported_ops_ = {// Ops defined via Eigen scalar ops.
2982 {"Abs", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2983 {"Acos", {DT_FLOAT, DT_DOUBLE}},
2984 {"Acosh", {DT_FLOAT, DT_DOUBLE}},
2985 {"Asin", {DT_FLOAT, DT_DOUBLE}},
2986 {"Asinh", {DT_FLOAT, DT_DOUBLE}},
2987 {"Atan", {DT_FLOAT, DT_DOUBLE}},
2988 {"Atanh", {DT_FLOAT, DT_DOUBLE}},
2989 {"Ceil", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2990 {"Cos", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2991 {"Cosh", {DT_FLOAT, DT_DOUBLE}},
2992 {"Expm1", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2993 {"Exp", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2994 {"Floor", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2995 {"Inv", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2996 {"Log", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2997 {"Log1p", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2998 {"Neg", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
2999 {"Reciprocal", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3000 {"Rint", {DT_FLOAT, DT_DOUBLE}},
3001 {"Round", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3002 {"Rsqrt", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3003 {"Sigmoid", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3004 {"Sin", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3005 {"Sinh", {DT_FLOAT, DT_DOUBLE}},
3006 {"Sqrt", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3007 {"Square", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3008 {"Tan", {DT_FLOAT, DT_DOUBLE}},
3009 {"Tanh", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3010                     // Additional ops that are not part of Eigen.
3011 {"Elu", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3012 {"Relu", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3013 {"Relu6", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3014 {"Selu", {DT_FLOAT, DT_HALF, DT_DOUBLE}}};
3015 // clang-format on
3016 }
3017 ~UnaryOpsComposition() override = default;
3018
3019   bool IsSupported(const NodeDef* node) const override {
3020 return CanOptimize(*node) &&
3021 // Check that this node was not already a root of a fused chain. If
3022 // graph optimization runs twice without pruning in between,
3023 // fused_nodes_ will not have this information.
3024 !ctx().node_map->NodeExists(OptimizedNodeName(*node));
3025 }
3026
3027   Status TrySimplify(NodeDef* root, string* simplified_node_name) override {
3028 TF_RETURN_IF_ERROR(CheckAttrExists(*root, "T"));
3029 DataType dtype = root->attr().at("T").type();
3030
3031 // Keep a trace of all supported input nodes that can be fused together.
3032 std::vector<string> op_nodes = {root->name()};
3033 std::vector<string> op_names = {root->op()};
3034
3035 // Check if we should follow input(0) while building an op composition.
3036 const auto predicate_fn = [&](const NodeDef& input) {
3037 if (input.name() == root->name()) return true;
3038
3039 bool follow_input_node =
3040 dtype == GetDataTypeFromAttr(input, "T") &&
3041 NumNonControlDataOutputs(input, *ctx().node_map) == 1 &&
3042 CanOptimize(input);
3043
3044 if (follow_input_node) {
3045 op_nodes.push_back(input.name());
3046 op_names.push_back(input.op());
3047 }
3048
3049 return follow_input_node;
3050 };
3051
3052 NodeDef* last_op = GetTailOfChain(
3053 *root, *ctx().node_map, /*follow_control_input*/ false, predicate_fn);
3054
3055 // We were not able to find a chain that can be replaced.
3056 if (op_names.size() == 1) return Status::OK();
3057
3058 // Do not add fused nodes to any other chain.
3059 std::for_each(op_nodes.begin(), op_nodes.end(),
3060 [this](const string& name) { AddToFusedNodes(name); });
3061
3062 // Reverse the trace to get correct composition computation order.
3063 std::reverse(op_names.begin(), op_names.end());
3064
3065 VLOG(2) << "Fuse unary ops: root=" << root->name() << " op_names=["
3066 << absl::StrJoin(op_names, ", ") << "]";
3067
3068 NodeDef* composition_node = ctx().optimized_graph->add_node();
3069 composition_node->set_name(OptimizedNodeName(*root));
3070 composition_node->set_op("_UnaryOpsComposition");
3071 composition_node->add_input(last_op->input(0));
3072 composition_node->set_device(root->device());
3073
3074 auto attr = composition_node->mutable_attr();
3075 SetAttrValue(dtype, &(*attr)["T"]);
3076 SetAttrValue(op_names, &(*attr)["op_names"]);
3077
3078 ctx().node_map->AddNode(composition_node->name(), composition_node);
3079 ctx().node_map->AddOutput(NodeName(last_op->input(0)),
3080 composition_node->name());
3081
3082 *simplified_node_name = composition_node->name();
3083
3084 return Status::OK();
3085 }
3086
3087 private:
3088   bool CanOptimize(const NodeDef& node) const {
3089 DataType dtype = GetDataTypeFromAttr(node, "T");
3090 if (!IsSupported(node.op(), dtype)) {
3091 return false;
3092 }
3093 if (IsInPreserveSet(node)) {
3094 return false;
3095 }
3096 if (!NodeIsOnCpu(node)) {
3097 return false;
3098 }
3099 if (NodeIsAlreadyFused(node)) {
3100 return false;
3101 }
3102 return !(IsDrivenByControlDependency(node) ||
3103 DrivesControlDependency(node));
3104 }
3105
3106   bool NodeIsAlreadyFused(const NodeDef& node) const {
3107 return fused_nodes_.count(node.name()) > 0;
3108 }
3109
3110   string OptimizedNodeName(const NodeDef& node) const {
3111 return strings::StrCat(node.name(), "/unary_ops_composition");
3112 }
3113
3114   void AddToFusedNodes(const string& name) { fused_nodes_.insert(name); }
3115
3116   // Check if an op is supported by _UnaryOpsComposition for the given type.
3117   bool IsSupported(const string& op_name, DataType dtype) const {
3118 const auto it = supported_ops_.find(op_name);
3119 return it != supported_ops_.end() && it->second.count(dtype) > 0;
3120 }
3121
3122 std::unordered_map<string, std::set<DataType>> supported_ops_;
3123 std::unordered_set<string> fused_nodes_;
3124 };
3125
3126 // Replace operations of the form:
3127 // x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i,...]
3128 // with
3129 // a_i
3130 // when the strided slice index `i` is applied in the k'th axis.
3131 //
3132 // Similarly, replace operations of the form:
3133 // x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i:i+1,...]
3134 // with
3135 // expand_dims(a_i, axis=k)
3136 // where the slice operator can be StridedSlice or Slice.
3137 //
3138 // TODO(ebrevdo): Extend to also replace operations of the form
3139 // concat((a_0, a_1, ..., ), axis=k)[:, ..., s_i:s_{i+1}, ...]
3140 // with
3141 // a_i,
3142 // when
3143 // s_i = cumsum(shape(a)[k] for a in (a_0, ...,))[i]
3144 // and slicing is in the k'th axis.
3145 class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage {
3146 public:
3147   explicit RemoveStackSliceSameAxis(const GraphOptimizerContext& ctx,
3148 const ArithmeticOptimizerContext& ctx_ext)
3149 : ArithmeticOptimizerStage("RemoveStackStridedSliceSameAxis", ctx,
3150 ctx_ext) {}
3151 ~RemoveStackSliceSameAxis() override = default;
3152
3153   bool IsSupported(const NodeDef* node) const override {
3154 return (IsStridedSlice(*node) || IsSlice(*node)) && !IsInPreserveSet(*node);
3155 }
3156
3157   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
3158 // *node is a StridedSlice NodeDef.
3159 NodeDef* pack;
3160
3161 // Get the input and see if it's a Pack op.
3162 TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &pack));
3163 if (!IsPack(*pack)) return Status::OK();
3164
3165 bool return_early;
3166 PartialTensorShape pack_output_shape;
3167 int pack_axis;
3168 TF_RETURN_IF_ERROR(
3169 CheckInputs(node, pack, &pack_output_shape, &pack_axis, &return_early));
3170 if (return_early) return Status::OK();
3171
3172 int64 slice_start_value;
3173 bool found;
3174 bool must_expand_dims;
3175 TF_RETURN_IF_ERROR(GetSliceAxis(node, pack, pack_output_shape, pack_axis,
3176 &slice_start_value, &found,
3177 &must_expand_dims));
3178 if (!found) return Status::OK();
3179
3180 return RewriteGraph(node, pack, slice_start_value, pack_axis,
3181 must_expand_dims, simplified_node_name);
3182 }
3183
3184 protected:
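  // Reads the Pack node's axis attribute and the slice's input shape,
  // normalizing a negative axis. Sets *return_early when the shape is unknown
  // so the caller can bail out.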
3185   Status CheckInputs(const NodeDef* node, const NodeDef* pack,
3186 PartialTensorShape* pack_output_shape, int* pack_axis,
3187 bool* return_early) {
3188 *return_early = true;
3189 TF_RETURN_IF_ERROR(CheckAttrExists(*pack, "axis"));
3190
3191 *pack_axis = pack->attr().at("axis").i();
3192 auto slice_properties =
3193 ctx().graph_properties->GetInputProperties(node->name());
3194 if (slice_properties.empty() ||
3195 slice_properties[0].shape().unknown_rank()) {
3196 return Status::OK();
3197 }
3198 *pack_output_shape = slice_properties[0].shape();
3199 const int pack_output_rank = pack_output_shape->dims();
3200 if (*pack_axis < 0) {
3201 *pack_axis += pack_output_rank;
3202 }
3203 if (*pack_axis < 0 || *pack_axis >= pack_output_rank) {
3204 return errors::InvalidArgument(
3205 "Pack node (", pack->name(),
3206 ") axis attribute is out of bounds: ", pack->attr().at("axis").i());
3207 }
3208 *return_early = false;
3209 return Status::OK();
3210 }
3211
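  // Dispatches to GetSimpleSliceAxis for Slice nodes and to GetStridedSliceAxis
  // for StridedSlice nodes.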
3212   Status GetSliceAxis(const NodeDef* node, const NodeDef* pack,
3213 const PartialTensorShape& pack_output_shape,
3214 int pack_axis, int64* slice_start_value, bool* found,
3215 bool* must_expand_dims) {
3216 *found = false;
3217 if (IsSlice(*node)) {
3218 *must_expand_dims = true;
3219 return GetSimpleSliceAxis(node, pack, pack_output_shape, pack_axis,
3220 slice_start_value, found);
3221 } else {
3222 return GetStridedSliceAxis(node, pack, pack_output_shape, pack_axis,
3223 slice_start_value, found, must_expand_dims);
3224 }
3225 }
3226
3227   Status GetSimpleSliceAxis(const NodeDef* node, const NodeDef* pack,
3228 const PartialTensorShape& pack_output_shape,
3229 int pack_axis, int64* slice_start_value,
3230 bool* found) {
3231 NodeDef* slice_begin;
3232 NodeDef* slice_size;
3233 TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &slice_begin));
3234 TF_RETURN_IF_ERROR(GetInputNode(node->input(2), &slice_size));
3235 for (const auto* n : {slice_begin, slice_size}) {
3236 if (!IsReallyConstant(*n)) return Status::OK();
3237 }
3238
3239 Tensor slice_begin_t;
3240 Tensor slice_size_t;
3241 TF_RETURN_IF_ERROR(CheckAttrExists(*slice_begin, "value"));
3242 if (!slice_begin_t.FromProto(slice_begin->attr().at("value").tensor())) {
3243 return Status::OK();
3244 }
3245 TF_RETURN_IF_ERROR(CheckAttrExists(*slice_size, "value"));
3246 if (!slice_size_t.FromProto(slice_size->attr().at("value").tensor())) {
3247 return Status::OK();
3248 }
3249
3250 auto copy_tensor_values_to_vector =
3251 [node](const Tensor& t, gtl::InlinedVector<int64, 4>* vec) {
3252 if (t.dtype() == DT_INT32) {
3253 auto t_flat = t.flat<int32>();
3254 vec->assign(&t_flat(0), &t_flat(t.NumElements()));
3255 } else if (t.dtype() == DT_INT64) {
3256 auto t_flat = t.flat<int64>();
3257 vec->assign(&t_flat(0), &t_flat(t.NumElements()));
3258 } else {
3259 return errors::InvalidArgument("Node ", node->name(),
3260 " has invalid type for Index attr: ",
3261 DataTypeString(t.dtype()));
3262 }
3263 return Status::OK();
3264 };
3265
3266 gtl::InlinedVector<int64, 4> slice_begin_vec;
3267 gtl::InlinedVector<int64, 4> slice_size_vec;
3268 TF_RETURN_IF_ERROR(
3269 copy_tensor_values_to_vector(slice_begin_t, &slice_begin_vec));
3270 TF_RETURN_IF_ERROR(
3271 copy_tensor_values_to_vector(slice_size_t, &slice_size_vec));
3272
3273 if (slice_begin_vec.size() != slice_size_vec.size()) {
3274 return errors::InvalidArgument("Node ", node->name(),
3275 " has mismatched lengths for begin (",
3276 slice_begin_vec.size(), ") and size (",
3277 slice_size_vec.size(), ") vectors.");
3278 }
3279 int slice_begin_vec_size = slice_begin_vec.size();
3280 if (!pack_output_shape.unknown_rank() &&
3281 slice_begin_vec_size != pack_output_shape.dims()) {
3282 return Status::OK();
3283 }
3284 if (pack_axis >= slice_begin_vec_size) {
3285 return errors::InvalidArgument(
3286 "Input to node ", node->name(), " had pack_axis ", pack_axis,
3287 " but rank was ", slice_begin_vec_size, ".");
3288 }
3289
3290 *slice_start_value = slice_begin_vec[pack_axis];
3291 if (slice_size_vec[pack_axis] != 1) {
3292 // Not slicing a single value out.
3293 return Status::OK();
3294 }
3295
3296 for (int i = 0; i < slice_begin_vec_size; ++i) {
3297 if (i != pack_axis) {
3298 if (slice_begin_vec[i] != 0 ||
3299 !(slice_size_vec[i] == -1 ||
3300 slice_size_vec[i] == pack_output_shape.dim_size(i))) {
3301 // Not slicing on the same axis as the Pack op.
3302 return Status::OK();
3303 }
3304 }
3305 }
3306
3307 if (*slice_start_value < 0 || *slice_start_value >= pack->input_size()) {
3308 return errors::InvalidArgument(
3309 "Node ", node->name(), " requested invalid slice index ",
3310 *slice_start_value, " on axis ", pack_axis,
3311 " from tensor of shape: ", pack_output_shape.DebugString());
3312 }
3313
3314 *found = true; // slice_start_value is valid.
3315 return Status::OK();
3316 }
3317
3318   Status GetStridedSliceAxis(const NodeDef* node, const NodeDef* pack,
3319 const PartialTensorShape& pack_output_shape,
3320 int pack_axis, int64* slice_start_value,
3321 bool* found, bool* must_expand_dims) {
3322 TF_RETURN_IF_ERROR(
3323 CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask",
3324 "new_axis_mask", "shrink_axis_mask"}));
3325
3326 const int begin_mask = node->attr().at("begin_mask").i();
3327 const int end_mask = node->attr().at("end_mask").i();
3328 const int ellipsis_mask = node->attr().at("ellipsis_mask").i();
3329 const int new_axis_mask = node->attr().at("new_axis_mask").i();
3330 const int shrink_axis_mask = node->attr().at("shrink_axis_mask").i();
3331
3332 // Check that the StridedSlice is one of these at pack_axis:
3333 // [..., i, ...]
3334 // [..., i:i+1, ...]
3335 // [..., :1, ...]
3336 // [..., -1:, ...]
3337     //   [..., s_{pack_axis}-1:, ...]
3338 NodeDef* slice_begin;
3339 NodeDef* slice_end;
3340 NodeDef* slice_strides;
3341 TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &slice_begin));
3342 TF_RETURN_IF_ERROR(GetInputNode(node->input(2), &slice_end));
3343 TF_RETURN_IF_ERROR(GetInputNode(node->input(3), &slice_strides));
3344
3345 for (const auto* n : {slice_begin, slice_end, slice_strides}) {
3346 if (!IsReallyConstant(*n)) return Status::OK();
3347 }
3348
3349 Tensor slice_begin_t;
3350 Tensor slice_end_t;
3351 Tensor slice_strides_t;
3352
3353 TF_RETURN_IF_ERROR(CheckAttrExists(*slice_begin, "value"));
3354 if (!slice_begin_t.FromProto(slice_begin->attr().at("value").tensor())) {
3355 return Status::OK();
3356 }
3357 TF_RETURN_IF_ERROR(CheckAttrExists(*slice_end, "value"));
3358 if (!slice_end_t.FromProto(slice_end->attr().at("value").tensor())) {
3359 return Status::OK();
3360 }
3361 TF_RETURN_IF_ERROR(CheckAttrExists(*slice_strides, "value"));
3362 if (!slice_strides_t.FromProto(
3363 slice_strides->attr().at("value").tensor())) {
3364 return Status::OK();
3365 }
3366 TensorShape processing_shape;
3367 TensorShape final_shape;
3368 bool is_identity;
3369 bool is_simple_slice;
3370 bool slice_dim0;
3371 gtl::InlinedVector<int64, 4> slice_begin_vec;
3372 gtl::InlinedVector<int64, 4> slice_end_vec;
3373 gtl::InlinedVector<int64, 4> slice_strides_vec;
3374 TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
3375 &slice_begin_t, &slice_end_t, slice_strides_t, pack_output_shape,
3376 begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask,
3377 &processing_shape, &final_shape, &is_identity, &is_simple_slice,
3378 &slice_dim0, &slice_begin_vec, &slice_end_vec, &slice_strides_vec));
3379
3380 if (!is_simple_slice) return Status::OK();
3381
3382 int begin_index = -1;
3383 int64 begin_value = 0;
3384 for (int i = 0, end = slice_begin_vec.size(); i < end; ++i) {
3385 const int64 v = slice_begin_vec[i];
3386 if (v != 0) {
3387 if (begin_index != -1) {
3388 // At least two start values that are nonzero.
3389 return Status::OK();
3390 }
3391 begin_index = i;
3392 begin_value = v;
3393 }
3394 }
3395
3396 int end_index = -1;
3397 int64 end_value = 0;
3398 for (int i = 0, end = slice_begin_vec.size(); i < end; ++i) {
3399 const int64 v = slice_end_vec[i];
3400 if (v != pack_output_shape.dim_size(i)) {
3401 if (end_index != -1) {
3402           // At least two end values that differ from the dimension size.
3403 return Status::OK();
3404 }
3405 end_index = i;
3406 end_value = v;
3407 }
3408 }
3409
3410 if (begin_index == -1 && end_index == -1) return Status::OK();
3411 if (begin_index != -1 && end_index != -1 && begin_index != end_index) {
3412 // Somehow received different axes for begin/end slicing
3413 return Status::OK();
3414 }
3415 const int slice_axis = (begin_index == -1) ? end_index : begin_index;
3416 if (slice_axis != pack_axis) {
3417 // Not slicing on the same axis as the Pack op.
3418 return Status::OK();
3419 }
3420 *slice_start_value = (begin_index == -1) ? 0 : begin_value;
3421 const int64 slice_end_value =
3422 (end_index == -1) ? pack_output_shape.dim_size(slice_axis) : end_value;
3423 if (slice_end_value != *slice_start_value + 1) {
3424 // Not slicing a single value out.
3425 return Status::OK();
3426 }
3427
3428 if (*slice_start_value < 0 || *slice_start_value >= pack->input_size()) {
3429 return errors::InvalidArgument(
3430 "Node ", node->name(), " requested invalid slice index ",
3431 *slice_start_value, " on axis ", slice_axis,
3432 " from tensor of shape: ", pack_output_shape.DebugString());
3433 }
3434
3435 if (shrink_axis_mask == 0) {
3436 *must_expand_dims = true;
3437 } else if (shrink_axis_mask == (1 << slice_axis)) {
3438 *must_expand_dims = false;
3439 } else {
3440 // Shrinking on a different axis from the one that we are slicing on.
3441 return Status::OK();
3442 }
3443
3444 *found = true; // slice_start_value is valid.
3445 return Status::OK();
3446 }
3447
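  // Replaces the slice with either an Identity of the selected Pack input or,
  // when the sliced dimension must be preserved, an ExpandDims that re-inserts
  // it at `pack_axis`.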
3448   Status RewriteGraph(const NodeDef* node, const NodeDef* pack,
3449 int64 slice_start_value, int pack_axis,
3450 bool must_expand_dims, string* simplified_node_name) {
3451 const string& input_slice = pack->input(slice_start_value);
3452
3453 const OpInfo::TensorProperties* input_slice_properties;
3454 TF_RETURN_IF_ERROR(GetTensorProperties(pack->input(slice_start_value),
3455 &input_slice_properties));
3456 PartialTensorShape input_slice_shape(input_slice_properties->shape());
3457
3458 const OpInfo::TensorProperties* output_properties;
3459 TF_RETURN_IF_ERROR(GetTensorProperties(
3460 strings::StrCat(node->name(), ":", 0), &output_properties));
3461 PartialTensorShape output_shape(output_properties->shape());
3462 NodeDef* output =
3463 AddEmptyNode(OptimizedNodeName(ParseNodeScopeAndName(node->name())));
3464 if (!must_expand_dims) {
3465 output->set_op("Identity");
3466 output->set_device(node->device());
3467 SetDataTypeToAttr(output_properties->dtype(), "T", output);
3468 output->add_input(input_slice);
3469 } else {
3470 NodeDef* axis = AddEmptyNode(
3471 OptimizedNodeName(ParseNodeScopeAndName(node->name()), "Axis"));
3472 axis->set_op("Const");
3473 axis->set_device(node->device());
3474       // We need to add a control edge from the input slice to guarantee that
3475       // the axis constant will be executed in the same frame as `input_slice`;
3476       // otherwise ExpandDims might have mismatched input frames.
3477 axis->add_input(absl::StrCat("^", ParseTensorName(input_slice).node()));
3478 auto axis_attr = axis->mutable_attr();
3479 SetDataTypeToAttr(DT_INT32, "dtype", axis);
3480 auto* axis_t = (*axis_attr)["value"].mutable_tensor();
3481 axis_t->set_dtype(DT_INT32);
3482 axis_t->add_int_val(pack_axis);
3483 AddToOptimizationQueue(axis);
3484 output->set_op("ExpandDims");
3485 output->set_device(node->device());
3486 SetDataTypeToAttr(output_properties->dtype(), "T", output);
3487 SetDataTypeToAttr(DT_INT32, "Tdim", output);
3488 output->add_input(input_slice);
3489 output->add_input(axis->name());
3490 }
3491
3492 // Copy dependencies over.
3493 ForwardControlDependencies(output, {node, pack});
3494 AddToOptimizationQueue(output);
3495 *simplified_node_name = output->name();
3496
3497 return Status::OK();
3498 }
3499 };
3500
3501 // Eliminates unnecessary copies during sparse embedding lookup operations.
3502 //
3503 // For non-partitioned variables, the `tf.nn.embedding_lookup_sparse()` function
3504 // generates code of the form:
3505 //
3506 // embeddings = <a 2D Tensor>
3507 // sparse_ids = <a tf.int64 SparseTensor>
3508 // segment_ids = sparse_ids.indices[:, 0]
3509 // ids, idx = tf.unique(sparse_ids.values)
3510 //   gathered_rows = tf.gather(embeddings, ids)
3511 // result = tf.sparse.segment_<combiner>(gathered_rows, idx, segment_ids)
3512 //
3513 // In this case, all of the work in `tf.unique()` and `tf.gather()`
3514 // can be avoided by passing the full embeddings to
3515 // `tf.sparse.segment_<combiner>()` and performing the same amount of
3516 // computation (but fewer copies and allocations) as follows:
3517 //
3518 // embeddings = <a 2D Tensor>
3519 // sparse_ids = <a tf.int64 SparseTensor>
3520 // segment_ids = sparse_ids.indices[:, 0]
3521 // result = tf.sparse.segment_<combiner>(
3522 // embeddings, sparse_ids.values, segment_ids)
3523 class SimplifyEmbeddingLookupStage : public ArithmeticOptimizerStage {
3524 public:
3525   explicit SimplifyEmbeddingLookupStage(
3526 const GraphOptimizerContext& ctx,
3527 const ArithmeticOptimizerContext& ctx_ext)
3528 : ArithmeticOptimizerStage("SimplifyEmbeddingLookupStage", ctx, ctx_ext) {
3529 }
3530 ~SimplifyEmbeddingLookupStage() override = default;
3531
3532   bool IsSupported(const NodeDef* node) const override {
3533 return IsAnySparseSegmentReduction(*node);
3534 }
3535
3536   Status TrySimplify(NodeDef* reduction_node,
3537 string* simplified_node_name) override {
3538 if (IsInPreserveSet(*reduction_node)) return Status::OK();
3539
3540 // Input 0 (data) of the reduction node must be a tf.gather() on the 0th
3541 // axis.
3542 NodeDef* gather_node = nullptr;
3543 TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &gather_node));
3544 if (!IsGather(*gather_node) || IsInPreserveSet(*gather_node) ||
3545 gather_node->device() != reduction_node->device())
3546 return Status::OK();
3547 if (gather_node->op() == "GatherV2" && !IsAxis0(*gather_node, 2))
3548 return Status::OK();
3549
3550 // Input 1 (indices) of the gather node must be a tf.unique() on the 0th
3551 // axis.
3552 NodeDef* unique_node = nullptr;
3553 TF_RETURN_IF_ERROR(GetInputNode(gather_node->input(1), &unique_node));
3554 if (!IsUnique(*unique_node) || IsInPreserveSet(*unique_node) ||
3555 unique_node->device() != gather_node->device())
3556 return Status::OK();
3557 if (unique_node->op() == "UniqueV2" && !IsAxis0(*unique_node, 1))
3558 return Status::OK();
3559
3560 DataType unique_element_type;
3561 TF_RETURN_IF_ERROR(GetNodeAttr(*unique_node, "T", &unique_element_type));
3562
3563 // Input 1 (indices) of the reduction node must be output 1 of the unique
3564 // node.
3565 const TensorId idx_tensor = ParseTensorName(reduction_node->input(1));
3566 if (idx_tensor != TensorId(unique_node->name(), 1)) return Status::OK();
3567
3568     // Input 0 (data) of the reduction node is replaced by input 0 (params) of
3569     // the gather node.
3570 reduction_node->set_input(0, gather_node->input(0));
3571 ctx().node_map->UpdateInput(reduction_node->name(),
3572 reduction_node->input(0),
3573 gather_node->input(0));
3574
3575     // Input 1 (indices) of the reduction node is replaced by input 0 (x) of
3576     // the unique node.
3577 reduction_node->set_input(1, unique_node->input(0));
3578 ctx().node_map->UpdateInput(reduction_node->name(),
3579 reduction_node->input(1),
3580 unique_node->input(0));
3581 SetDataTypeToAttr(unique_element_type, "Tidx", reduction_node);
3582
3583 *simplified_node_name = reduction_node->name();
3584 return Status::OK();
3585 }
3586
3587 private:
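  // Returns true if input `axis_input` of `node` is a scalar constant equal to
  // zero.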
3588   bool IsAxis0(const NodeDef& node, int axis_input) {
3589 Tensor axis_tensor;
3590 if (!GetTensorFromConstNode(node.input(axis_input), &axis_tensor))
3591 return false;
3592 if (axis_tensor.NumElements() != 1) return false;
3593 if (axis_tensor.dtype() == DT_INT32) {
3594 return axis_tensor.flat<int32>()(0) == 0;
3595 } else if (axis_tensor.dtype() == DT_INT64) {
3596 return axis_tensor.flat<int64>()(0) == 0;
3597 } else {
3598 return false;
3599 }
3600 }
3601 };
3602
3603 // Eliminates unnecessary casts before sparse segment reduction operations.
3604 //
3605 // Existing graphs and library code would often insert a cast from DT_INT64 to
3606 // DT_INT32 on the indices and/or segment_ids inputs to "SparseSegment*" ops.
3607 // Support for DT_INT64 indices and/or segment_ids now exists, so we can
3608 // pass the input directly without a cast.
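// A rough TF Python sketch of the rewrite (names are illustrative only):
//   tf.sparse.segment_sum(data, tf.cast(indices, tf.int32), segment_ids)
// becomes
//   tf.sparse.segment_sum(data, indices, segment_ids)
// with the corresponding index type attribute updated on the reduction node.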
3609 class RemoveCastIntoSegmentReductionStage : public ArithmeticOptimizerStage {
3610 public:
3611   explicit RemoveCastIntoSegmentReductionStage(
3612 const GraphOptimizerContext& ctx,
3613 const ArithmeticOptimizerContext& ctx_ext)
3614 : ArithmeticOptimizerStage("RemoveCastIntoSegmentReductionStage", ctx,
3615 ctx_ext) {}
3616 ~RemoveCastIntoSegmentReductionStage() override = default;
3617
3618   bool IsSupported(const NodeDef* node) const override {
3619 return IsAnySparseSegmentReduction(*node);
3620 }
3621
3622   Status TrySimplify(NodeDef* reduction_node,
3623 string* simplified_node_name) override {
3624 if (IsInPreserveSet(*reduction_node)) return Status::OK();
3625
3626 bool optimized = false;
3627
3628 // Inputs 1 (indices) and 2 (segment_ids) can be either DT_INT32 or
3629 // DT_INT64.
3630 std::array<std::pair<int, string>, 2> input_details = {
3631 std::make_pair(1, "Tidx"), std::make_pair(2, "Tsegmentids")};
3632
3633 for (const auto& input : input_details) {
3634 int input_index = input.first;
3635 const string& type_attr_name = input.second;
3636 NodeDef* cast_node = nullptr;
3637 TF_RETURN_IF_ERROR(
3638 GetInputNode(reduction_node->input(input_index), &cast_node));
3639 DataType original_index_type;
3640 if (IsCastFromSupportedType(*cast_node, &original_index_type)) {
3641 reduction_node->set_input(input_index, cast_node->input(0));
3642 ctx().node_map->UpdateInput(reduction_node->name(),
3643 reduction_node->input(1),
3644 cast_node->input(0));
3645 SetDataTypeToAttr(original_index_type, type_attr_name, reduction_node);
3646 optimized = true;
3647 }
3648 }
3649
3650 if (optimized) *simplified_node_name = reduction_node->name();
3651 return Status::OK();
3652 }
3653
3654 private:
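  // Returns true if `node` is a Cast from DT_INT32 or DT_INT64. On success the
  // source type is written to `out_input_type`.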
3655   bool IsCastFromSupportedType(const NodeDef& node, DataType* out_input_type) {
3656 if (!IsCast(node)) return false;
3657 if (!GetNodeAttr(node, "SrcT", out_input_type).ok()) return false;
3658 return *out_input_type == DT_INT32 || *out_input_type == DT_INT64;
3659 }
3660 };
3661
3662 } // namespace
3663
3664 Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
3665 SetVector<NodeDef*> nodes_to_simplify;
3666 nodes_to_simplify.Reserve(optimized_graph_->node_size());
3667 for (int i = 0; i < optimized_graph_->node_size(); ++i) {
3668 nodes_to_simplify.PushBack(optimized_graph_->mutable_node(i));
3669 }
3670
3671 const GraphOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
3672 graph_properties_.get(), node_map_.get(),
3673 &feed_nodes_, opt_level_);
3674 const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify);
3675
3676 // Stop pipeline after first stage returning non-empty simplified tensor
3677 // name.
3678 const auto stop = [](const string& result) { return !result.empty(); };
3679 GraphOptimizerStagePipeline<string> pipeline(stop);
3680
3681 if (options_.combine_add_to_addn && can_use_shapes)
3682 pipeline.AddStage<AddOpsRewriteStage>(ctx, ctx_ext);
3683 if (options_.fold_conjugate_into_transpose)
3684 pipeline.AddStage<FoldConjugateIntoTranspose>(ctx, ctx_ext);
3685 if (options_.fold_multiply_into_conv)
3686 pipeline.AddStage<FoldMultiplyIntoConv>(ctx, ctx_ext);
3687 if (options_.fold_transpose_into_matmul)
3688 pipeline.AddStage<FoldTransposeIntoMatMul>(ctx, ctx_ext);
3689 if (options_.hoist_common_factor_out_of_aggregation && can_use_shapes)
3690 pipeline.AddStage<HoistCommonFactorOutOfAggregation>(ctx, ctx_ext);
3691 if (options_.minimize_broadcasts && can_use_shapes)
3692 pipeline.AddStage<MinimizeBroadcasts>(ctx, ctx_ext);
3693 if (options_.remove_identity_transpose && can_use_shapes)
3694 pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext);
3695 if (options_.remove_involution)
3696 pipeline.AddStage<RemoveInvolution>(ctx, ctx_ext);
3697 if (options_.remove_redundant_bitcast)
3698 pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext);
3699 if (options_.remove_redundant_cast)
3700 pipeline.AddStage<RemoveRedundantCastStage>(ctx, ctx_ext);
3701 if (options_.remove_redundant_reshape)
3702 pipeline.AddStage<RemoveRedundantReshapeOrBroadcastTo>(ctx, ctx_ext);
3703 if (options_.remove_negation)
3704 pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
3705 if (options_.replace_mul_with_square)
3706 pipeline.AddStage<ReplaceMulWithSquare>(ctx, ctx_ext);
3707 if (options_.remove_logical_not)
3708 pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext);
3709 if (options_.reorder_cast_like_and_value_preserving)
3710 pipeline.AddStage<ReorderCastLikeAndValuePreserving>(ctx, ctx_ext);
3711 if (options_.simplify_aggregation)
3712 pipeline.AddStage<SimplifyAggregation>(ctx, ctx_ext);
3713 if (options_.hoist_cwise_unary_chains)
3714 pipeline.AddStage<HoistCWiseUnaryChainsStage>(ctx, ctx_ext);
3715 if (options_.convert_sqrt_div_to_rsqrt_mul)
3716 pipeline.AddStage<SqrtDivToRsqrtMulStage>(ctx, ctx_ext);
3717 if (options_.remove_idempotent)
3718 pipeline.AddStage<RemoveIdempotentStage>(ctx, ctx_ext);
3719 if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext);
3720 if (options_.convert_log1p)
3721 pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
3722 if (options_.convert_log_softmax)
3723 pipeline.AddStage<LogSoftmaxStage>(ctx, ctx_ext);
3724 if (options_.optimize_max_or_min_of_monotonic)
3725 pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext);
3726 if (options_.convert_expm1)
3727 pipeline.AddStage<ConvertExpm1Stage>(ctx, ctx_ext);
3728 if (options_.unary_ops_composition)
3729 pipeline.AddStage<UnaryOpsComposition>(ctx, ctx_ext);
3730 if (options_.remove_stack_slice_same_axis)
3731 pipeline.AddStage<RemoveStackSliceSameAxis>(ctx, ctx_ext);
3732 if (options_.simplify_embedding_lookup)
3733 pipeline.AddStage<SimplifyEmbeddingLookupStage>(ctx, ctx_ext);
3734 if (options_.remove_cast_into_segment_reduction)
3735 pipeline.AddStage<RemoveCastIntoSegmentReductionStage>(ctx, ctx_ext);
3736 if (options_.fuse_squared_diff)
3737 pipeline.AddStage<FuseSquaredDiffStage>(ctx, ctx_ext);
3738
3739 VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
3740 << absl::StrJoin(pipeline.StageNames(), ", ");
3741
3742 while (!nodes_to_simplify.Empty()) {
3743 GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
3744 NodeDef* node = nodes_to_simplify.PopBack();
3745
3746 string simplified_tensor = "";
3747 bool optimized = pipeline.PassThroughAllStages(node, &simplified_tensor);
3748
3749 // If the node was not optimized by any of the stages, go to the next one.
3750 if (!optimized) continue;
3751
3752 // re-wire consumers of an old node to the new one
3753 if (NodeName(simplified_tensor) != node->name()) {
3754 // Always consider simplified_tensor for further optimizations.
3755 NodeDef* simplified_node = node_map_->GetNode(simplified_tensor);
3756 if (simplified_node != nullptr) {
3757 nodes_to_simplify.PushBack(simplified_node);
3758 }
3759 // When `node` is simplified to another node rather than in-place, the
3760 // consumers of `node` are already redirected to `simplified_tensor`.
3761 // Re-push the consumers into `nodes_to_simplify` for further
3762 // optimizations.
3763 const std::vector<NodeDef*> consumers =
3764 node_map_->GetOutputsOrderedByNodeName(node->name());
3765 for (NodeDef* consumer : consumers) {
3766         // Update `consumer`'s use of `node` to `simplified_tensor`.
3767 for (int i = 0; i < consumer->input_size(); ++i) {
3768 int operand_pos;
3769 string operand_node_name =
3770 ParseNodeName(consumer->input(i), &operand_pos);
3771 if (operand_node_name == node->name()) {
3772 *consumer->mutable_input(i) =
3773 (operand_pos < 0
3774 ? AsControlDependency(NodeName(simplified_tensor))
3775 : simplified_tensor);
3776 }
3777 }
3778 node_map_->UpdateInput(consumer->name(), node->name(),
3779 simplified_tensor);
3780 nodes_to_simplify.PushBack(consumer);
3781 }
3782 }
3783 }
3784 return Status::OK();
3785 }
3786
3787 Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
3788 const GrapplerItem& item,
3789 GraphDef* optimized_graph) {
3790 // Set up helper data structures.
3791 nodes_to_preserve_ = item.NodesToPreserve();
3792 fetch_nodes_known_ = !item.fetch.empty();
3793 GrapplerItem optimized_item(item);
3794 optimized_graph_ = &optimized_item.graph;
3795
3796 node_map_.reset(new NodeMap(optimized_graph_));
3797 for (const auto& feed : item.feed) {
3798 feed_nodes_.insert(NodeName(feed.first));
3799 }
3800
3801   // Disable restricted graph rewrites.
3802 options_.unary_ops_composition &=
3803 item.optimization_options().allow_non_differentiable_rewrites;
3804
3805 // Perform topological sort on the graph in order to help DedupComputations
3806 // and AddOpsRewrite to optimize larger subgraphs starting from the roots
3807 // with more inputs.
3808 TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph_));
3809 GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
3810
3811 graph_properties_.reset(new GraphProperties(optimized_item));
3812 const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
3813 const Status status =
3814 graph_properties_->InferStatically(assume_valid_feeds,
3815 /*aggressive_shape_inference=*/false,
3816 /*include_tensor_values=*/false);
3817 const bool can_use_shapes = status.ok();
3818 if (!can_use_shapes) {
3819     VLOG(1) << "Shape inference failed: " << status.error_message();
3820 }
3821
3822 // Perform the optimizations.
3823 TF_RETURN_IF_ERROR(SimplifyArithmeticOps(can_use_shapes));
3824
3825 optimized_graph->Swap(optimized_graph_);
3826 return Status::OK();
3827 }
3828
3829 void ArithmeticOptimizer::Feedback(Cluster* /*cluster*/,
3830 const GrapplerItem& /*item*/,
3831 const GraphDef& /*optimized_graph*/,
3832 double /*result*/) {
3833 // Nothing to do for ArithmeticOptimizer.
3834 }
3835
3836 } // namespace grappler
3837 } // namespace tensorflow
3838