1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
18 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
19 
#include <cmath>
#include <cstring>
21 
22 #include "absl/strings/string_view.h"
23 #include "absl/strings/substitute.h"
24 #include "tensorflow/core/framework/allocator.h"
25 #include "tensorflow/core/framework/attr_value.pb.h"
26 #include "tensorflow/core/framework/function.pb.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/op.h"
29 #include "tensorflow/core/framework/op_def.pb.h"
30 #include "tensorflow/core/framework/tensor.pb.h"
31 #include "tensorflow/core/framework/tensor_shape.pb.h"
32 #include "tensorflow/core/framework/tensor_util.h"
33 #include "tensorflow/core/framework/types.h"
34 #include "tensorflow/core/framework/versions.pb.h"
35 #include "tensorflow/core/grappler/clusters/cluster.h"
36 #include "tensorflow/core/grappler/costs/graph_properties.h"
37 #include "tensorflow/core/grappler/grappler_item.h"
38 #include "tensorflow/core/grappler/op_types.h"
39 #include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
40 #include "tensorflow/core/grappler/utils.h"
41 #include "tensorflow/core/grappler/utils/symbolic_shapes.h"
42 #include "tensorflow/core/lib/core/errors.h"
43 #include "tensorflow/core/lib/core/stringpiece.h"
44 #include "tensorflow/core/lib/gtl/cleanup.h"
45 #include "tensorflow/core/lib/gtl/inlined_vector.h"
46 #include "tensorflow/core/lib/strings/numbers.h"
47 #include "tensorflow/core/lib/strings/strcat.h"
48 #include "tensorflow/core/platform/cpu_info.h"
49 #include "tensorflow/core/platform/denormal.h"
50 #include "tensorflow/core/platform/env.h"
51 #include "tensorflow/core/platform/setround.h"
52 #include "tensorflow/core/platform/tensor_coding.h"
53 #include "tensorflow/core/public/version.h"
54 #include "tensorflow/core/util/bcast.h"
55 #include "tensorflow/core/util/saved_tensor_slice_util.h"
56 
57 namespace tensorflow {
58 namespace grappler {
// Vector of TensorValue with inline storage for 4 elements before it
// allocates. Presumably used for node inputs/outputs during evaluation; it is
// not referenced in this part of the file.
using TensorVector = gtl::InlinedVector<TensorValue, 4>;
60 
61 namespace {
62 class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
63  public:
EigenThreadPoolWrapper(thread::ThreadPool * pool)64   explicit EigenThreadPoolWrapper(thread::ThreadPool* pool) : pool_(pool) {}
~EigenThreadPoolWrapper()65   ~EigenThreadPoolWrapper() override {}
Schedule(std::function<void ()> fn)66   void Schedule(std::function<void()> fn) override {
67     auto wrapped = [=]() {
68       // TensorFlow flushes denormals to zero and rounds to nearest, so we do
69       // the same here.
70       port::ScopedFlushDenormal flush;
71       port::ScopedSetRound round(FE_TONEAREST);
72       fn();
73     };
74     pool_->Schedule(std::move(wrapped));
75   }
NumThreads() const76   int NumThreads() const override { return pool_->NumThreads(); }
CurrentThreadId() const77   int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
78 
79  private:
80   thread::ThreadPool* pool_ = nullptr;
81 };
82 
83 template <typename T>
AllValuesAre(const TensorProto & proto,const T & value)84 bool AllValuesAre(const TensorProto& proto, const T& value) {
85   Tensor tensor;
86   if (!tensor.FromProto(proto)) {
87     return false;
88   }
89   auto values = tensor.flat<T>();
90   for (int i = 0; i < tensor.NumElements(); ++i) {
91     if (values(i) != value) {
92       return false;
93     }
94   }
95   return true;
96 }
97 
98 // Add new_input as a control input to node if it does not already depend on it.
99 // TODO(rmlarsen): Move the following two utility functions to utils.{h,cc} and
100 // clean up code that should be using them.
MaybeAddControlInput(const string & ctrl_input,NodeDef * node,GraphDef * graph,NodeMap * node_map)101 bool MaybeAddControlInput(const string& ctrl_input, NodeDef* node,
102                           GraphDef* graph, NodeMap* node_map) {
103   bool already_exists = false;
104   for (const string& input : node->input()) {
105     if (input == ctrl_input || AsControlDependency(input) == ctrl_input) {
106       already_exists = true;
107       break;
108     }
109   }
110   if (!already_exists) {
111     const string ctrl_dep =
112         ConstantFolding::AddControlDependency(ctrl_input, graph, node_map);
113     node->add_input(ctrl_dep);
114     node_map->AddOutput(NodeName(ctrl_input), node->name());
115   }
116   return !already_exists;
117 }
118 
119 // Remove old_input as a control input to node.
MaybeRemoveControlInput(const string & old_input,NodeDef * node,GraphDef * graph,NodeMap * node_map)120 bool MaybeRemoveControlInput(const string& old_input, NodeDef* node,
121                              GraphDef* graph, NodeMap* node_map) {
122   bool removed_input = false;
123   bool update_node_map = true;
124   const string old_input_ctrl_dep = AsControlDependency(NodeName(old_input));
125   for (int i = 0; i < node->input_size(); ++i) {
126     const string& input = node->input(i);
127     if (old_input_ctrl_dep == input) {
128       if (IsControlInput(input)) {
129         node->mutable_input()->SwapElements(i, node->input_size() - 1);
130         node->mutable_input()->RemoveLast();
131         removed_input = true;
132       } else {
133         // There is a non-control input from the same node.
134         // Don't remove the output from the NodeMap.
135         update_node_map = false;
136       }
137     }
138   }
139   if (update_node_map) {
140     node_map->RemoveOutput(NodeName(old_input), node->name());
141   }
142   return removed_input;
143 }
144 
GetConcatAxis(const GraphProperties & properties,NodeDef * node,int * axis)145 bool GetConcatAxis(const GraphProperties& properties, NodeDef* node,
146                    int* axis) {
147   if (node->op() != "ConcatV2" ||
148       properties.GetInputProperties(node->name()).empty()) {
149     return false;
150   }
151   const auto& axis_input = properties.GetInputProperties(node->name()).back();
152   if (!TensorShape::IsValid(axis_input.shape()) || !axis_input.has_value()) {
153     return false;
154   }
155 
156   Tensor axis_tensor(axis_input.dtype(), axis_input.shape());
157   if (!axis_tensor.FromProto(axis_input.value())) {
158     return false;
159   }
160   *axis = axis_input.dtype() == DT_INT64
161               ? static_cast<int>(axis_tensor.scalar<int64>()())
162               : axis_tensor.scalar<int32>()();
163   return true;
164 }
165 
HasTPUAttributes(const NodeDef & node)166 bool HasTPUAttributes(const NodeDef& node) {
167   AttrSlice attrs(node);
168   for (auto attr : attrs) {
169     if (attr.first.find("_tpu_") != attr.first.npos) {
170       return true;
171     }
172   }
173   return false;
174 }
175 
// Returns true if `a` and `b` differ. Generic types use operator!=.
template <typename T>
bool PackedValuesNotEqual(T a, T b) {
  return a != b;
}

// Floating-point values are compared by bit pattern, not numerically: 0.0f and
// -0.0f are "not equal", while identical NaN encodings are "equal". The bits
// are copied out with memcpy. Reading a float through an int32_t reference
// (reinterpret_cast) violates strict aliasing and is undefined behavior.
template <>
bool PackedValuesNotEqual(float a, float b) {
  static_assert(sizeof(float) == sizeof(int32_t), "float must be 32 bits");
  int32_t a_bits;
  int32_t b_bits;
  std::memcpy(&a_bits, &a, sizeof(a_bits));
  std::memcpy(&b_bits, &b, sizeof(b_bits));
  return a_bits != b_bits;
}

// Bitwise comparison for double; same semantics as the float overload.
template <>
bool PackedValuesNotEqual(double a, double b) {
  static_assert(sizeof(double) == sizeof(int64_t), "double must be 64 bits");
  int64_t a_bits;
  int64_t b_bits;
  std::memcpy(&a_bits, &a, sizeof(a_bits));
  std::memcpy(&b_bits, &b, sizeof(b_bits));
  return a_bits != b_bits;
}
190 
QuantizedTypeMinAsFloat(DataType data_type)191 float QuantizedTypeMinAsFloat(DataType data_type) {
192   switch (data_type) {
193     case DT_QINT8:
194       return Eigen::NumTraits<qint8>::lowest();
195     case DT_QUINT8:
196       return Eigen::NumTraits<quint8>::lowest();
197     case DT_QINT16:
198       return Eigen::NumTraits<qint16>::lowest();
199     case DT_QUINT16:
200       return Eigen::NumTraits<quint16>::lowest();
201     case DT_QINT32:
202       return Eigen::NumTraits<qint32>::lowest();
203     default:
204       return 0.0f;
205   }
206 }
207 
QuantizedTypeMaxAsFloat(DataType data_type)208 float QuantizedTypeMaxAsFloat(DataType data_type) {
209   switch (data_type) {
210     case DT_QINT8:
211       return Eigen::NumTraits<qint8>::highest();
212     case DT_QUINT8:
213       return Eigen::NumTraits<quint8>::highest();
214     case DT_QINT16:
215       return Eigen::NumTraits<qint16>::highest();
216     case DT_QUINT16:
217       return Eigen::NumTraits<quint16>::highest();
218     case DT_QINT32:
219       return Eigen::NumTraits<qint32>::highest();
220     default:
221       return 0.0f;
222   }
223 }
224 
225 }  // namespace
226 
// Creates the optimizer at optimization level `opt_level`, storing
// `cpu_device` in cpu_device_ and allocating a fresh ResourceMgr for this
// instance.
ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level,
                                 DeviceBase* cpu_device)
    : opt_level_(opt_level), cpu_device_(cpu_device) {
  resource_mgr_.reset(new ResourceMgr());
}

// Convenience constructor: same as passing RewriterConfig::ON as the level.
ConstantFolding::ConstantFolding(DeviceBase* cpu_device)
    : ConstantFolding(RewriterConfig::ON, cpu_device) {}
235 
236 // static
AddControlDependency(const string & input_name,GraphDef * graph,NodeMap * node_map)237 string ConstantFolding::AddControlDependency(const string& input_name,
238                                              GraphDef* graph,
239                                              NodeMap* node_map) {
240   if (IsControlInput(input_name)) {
241     return input_name;
242   }
243   const NodeDef* node = node_map->GetNode(input_name);
244   if (!IsSwitch(*node)) {
245     return AsControlDependency(*node);
246   } else {
247     // We can't anchor control dependencies directly on the switch node: unlike
248     // other nodes only one of the outputs of the switch node will be generated
249     // when the switch node is executed, and we need to make sure the control
250     // dependency is only triggered when the corresponding output is triggered.
251     // We start by looking for an identity node connected to the output of the
252     // switch node, and use it to anchor the control dependency.
253     auto outputs = node_map->GetOutputs(node->name());
254     for (const NodeDef* output : outputs) {
255       if (IsIdentity(*output) || IsIdentityNSingleInput(*output)) {
256         if (IsSameInput(node->input(0), input_name)) {
257           return AsControlDependency(*output);
258         }
259       }
260     }
261     // We haven't found an existing node where we can anchor the control
262     // dependency: add a new identity node.
263     int port = 0;
264     string ctrl_dep_name = ParseNodeName(input_name, &port);
265     strings::StrAppend(&ctrl_dep_name, "_", port);
266     ctrl_dep_name = AddPrefixToNodeName(ctrl_dep_name, kConstantFoldingCtrl);
267     const DataType output_type = node->attr().at("T").type();
268 
269     NodeDef* added_node = node_map->GetNode(ctrl_dep_name);
270     if (added_node == nullptr) {
271       added_node = graph->add_node();
272       added_node->set_name(ctrl_dep_name);
273       added_node->set_op("Identity");
274       added_node->set_device(node->device());
275 
276       (*added_node->mutable_attr())["T"].set_type(output_type);
277       *added_node->add_input() = input_name;
278       node_map->AddNode(added_node->name(), added_node);
279       node_map->AddOutput(node->name(), added_node->name());
280     }
281     return AsControlDependency(*added_node);
282   }
283 }
284 
285 // Puts the given value into the tensor at the given "flat" index.
PutValueIntoTensor(const int64 value,const DataType & type,const int index,Tensor * tensor)286 static Status PutValueIntoTensor(const int64 value, const DataType& type,
287                                  const int index, Tensor* tensor) {
288   if (type == DT_INT32) {
289     if (value >= INT_MAX) {
290       return Status(error::INVALID_ARGUMENT, "int32 overflow");
291     }
292     tensor->flat<int32>()(index) = static_cast<int32>(value);
293   } else {
294     tensor->flat<int64>()(index) = value;
295   }
296   return Status::OK();
297 }
298 
299 // Writes the given tensor shape into the given tensor.
300 // Op is assumed to be Shape, ShapeN, Size or Rank.
ConvertShapeToConstant(const string & op,const DataType & type,const PartialTensorShape & shp,Tensor * tensor)301 static Status ConvertShapeToConstant(const string& op, const DataType& type,
302                                      const PartialTensorShape& shp,
303                                      Tensor* tensor) {
304   if (op == "Shape" || op == "ShapeN") {
305     *tensor = Tensor(type, TensorShape({shp.dims()}));
306     for (int i = 0; i < shp.dims(); ++i) {
307       TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dim_size(i), type, i, tensor));
308     }
309   } else if (op == "Size") {
310     int64 size = 1;
311     for (int i = 0; i < shp.dims(); ++i) {
312       size *= shp.dim_size(i);
313     }
314     *tensor = Tensor(type, TensorShape({}));
315     TF_RETURN_IF_ERROR(PutValueIntoTensor(size, type, 0, tensor));
316   } else {
317     CHECK_EQ(op, "Rank");
318     *tensor = Tensor(type, TensorShape({}));
319     TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dims(), type, 0, tensor));
320   }
321   return Status::OK();
322 }
323 
324 // TODO(rmlarsen): Perhaps we should move this to the GraphOptimizer base class.
OptimizedNodeExists(const NodeDef & node,StringPiece suffix) const325 bool ConstantFolding::OptimizedNodeExists(const NodeDef& node,
326                                           StringPiece suffix) const {
327   return node_map_->NodeExists(OptimizedNodeName(node, suffix));
328 }
329 
OptimizedNodeName(const NodeDef & node,StringPiece suffix) const330 string ConstantFolding::OptimizedNodeName(const NodeDef& node,
331                                           StringPiece suffix) const {
332   return AddPrefixToNodeName(strings::StrCat(node.name(), suffix),
333                              kConstantFoldingConst);
334 }
335 
IsReallyConstant(const NodeDef & node) const336 bool ConstantFolding::IsReallyConstant(const NodeDef& node) const {
337   if (!IsConstant(node)) {
338     return false;
339   }
340   // If the node is fed it's not constant anymore.
341   return feed_nodes_.find(node.name()) == feed_nodes_.end();
342 }
343 
344 // Materialize the shapes using constants whenever possible.
MaterializeShapes(const GraphProperties & properties)345 Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
346   // We may add some nodes to the graph to encode control dependencies and hold
347   // the materialized shapes: there is no need to process these added nodes, so
348   // only iterate over the nodes of the input graph.
349   const int node_count = graph_->node_size();
350   for (int node_idx = 0; node_idx < node_count; ++node_idx) {
351     NodeDef* node = graph_->mutable_node(node_idx);
352     const string op = node->op();
353     if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN" &&
354         op != "TensorArraySizeV3") {
355       continue;
356     }
357 
358     const std::vector<OpInfo::TensorProperties>& output =
359         properties.GetOutputProperties(node->name());
360     const std::vector<OpInfo::TensorProperties>& input =
361         properties.GetInputProperties(node->name());
362     if (input.empty() || output.empty()) {
363       continue;
364     }
365 
366     if (op == "Shape" || op == "Size" || op == "Rank") {
367       CHECK_EQ(1, output.size());
368       CHECK_EQ(1, input.size());
369 
370       const DataType type = output[0].dtype();
371       CHECK(type == DT_INT32 || type == DT_INT64);
372       const PartialTensorShape shape(input[0].shape());
373 
374       if ((op != "Rank" && !shape.IsFullyDefined()) ||
375           (op == "Rank" && shape.unknown_rank())) {
376         continue;
377       }
378 
379       Tensor constant_value(type);
380       if (!ConvertShapeToConstant(op, type, shape, &constant_value).ok()) {
381         continue;
382       }
383 
384       // Repurpose the existing node to be the constant.
385       // Device placement is preserved.
386       node->set_op("Const");
387       node->clear_attr();
388       (*node->mutable_attr())["dtype"].set_type(type);
389       constant_value.AsProtoTensorContent(
390           (*node->mutable_attr())["value"].mutable_tensor());
391 
392       // Turn the data input into a control dependency: this is needed to
393       // ensure that the constant value will only be run in the
394       // cases where the shape/rank/size would have been run in
395       // the original graph.
396       string ctrl_dep =
397           AddControlDependency(node->input(0), graph_, node_map_.get());
398       node->set_input(0, ctrl_dep);
399       node_map_->AddOutput(NodeName(ctrl_dep), node->name());
400 
401       // Done with the Shape/Size/Rank node, move to the next node.
402       continue;
403     }
404 
405     if (op == "TensorArraySizeV3") {
406       const NodeDef* array = CHECK_NOTNULL(node_map_->GetNode(node->input(0)));
407       if (array->input_size() == 0 ||
408           (array->attr().count("dynamic_size") != 0 &&
409            array->attr().at("dynamic_size").b())) {
410         continue;
411       }
412       const NodeDef* array_size =
413           CHECK_NOTNULL(node_map_->GetNode(array->input(0)));
414       if (IsReallyConstant(*array_size)) {
415         // Don't materialize 0 sizes to avoid triggering incorrect static
416         // checks. A 0 sized array that can't grow isn't useful anyway.
417         if (array_size->attr().count("value") == 0) {
418           continue;
419         }
420         const TensorProto& raw_val = array_size->attr().at("value").tensor();
421         if (raw_val.dtype() != DT_INT32) {
422           continue;
423         }
424         Tensor value(raw_val.dtype(), raw_val.tensor_shape());
425         if (!value.FromProto(raw_val)) {
426           continue;
427         }
428         if (value.flat<int32>()(0) == 0) {
429           continue;
430         }
431 
432         node->set_op("Const");
433         *node->mutable_attr() = array_size->attr();
434         node->set_input(0, AsControlDependency(NodeName(node->input(0))));
435         node->set_input(1, AddControlDependency(NodeName(node->input(1)),
436                                                 graph_, node_map_.get()));
437       }
438       continue;
439     }
440 
441     // Handle ShapeN materialization case.
442     // It's possible that not all input tensors have known shapes.
443     CHECK_EQ(op, "ShapeN");
444     CHECK_EQ(input.size(), output.size());
445     const NodeDef* const shape_n_node = node;
446     for (int port_idx = 0; port_idx < output.size(); ++port_idx) {
447       const DataType type = output[port_idx].dtype();
448       CHECK(type == DT_INT32 || type == DT_INT64);
449       const PartialTensorShape shape(input[port_idx].shape());
450       if (!shape.IsFullyDefined()) {
451         continue;
452       }
453       Tensor constant_value(type);
454       auto status = ConvertShapeToConstant(op, type, shape, &constant_value);
455       if (!status.ok()) {
456         continue;
457       }
458 
459       // Find all nodes consuming this shape and connect them through the new
460       // constant node instead.
461       auto outputs = node_map_->GetOutputs(shape_n_node->name());
462       for (NodeDef* output : outputs) {
463         // Track whether there are any direct edges left between shape_n_node
464         // and this output node after the transformation.
465         bool direct_edges_exist = false;
466         for (int k = 0; k < output->input_size(); ++k) {
467           int port;
468           const string node_name = ParseNodeName(output->input(k), &port);
469           if (node_name == shape_n_node->name() && port == port_idx) {
470             // Create a const node as ShapeN's output if not already.
471             const string const_name = OptimizedNodeName(
472                 *shape_n_node, strings::StrCat("-matshapes-", port_idx));
473             if (node_map_->GetNode(const_name) == nullptr) {
474               NodeDef* added_node = graph_->add_node();
475               added_node->set_name(const_name);
476               added_node->set_op("Const");
477               added_node->set_device(shape_n_node->device());
478               node_map_->AddNode(added_node->name(), added_node);
479               (*added_node->mutable_attr())["dtype"].set_type(type);
480               constant_value.AsProtoTensorContent(
481                   (*added_node->mutable_attr())["value"].mutable_tensor());
482               // We add a control dependency to the original ShapeN node,
483               // so that the node will only be run if all inputs of the
484               // original ShapeN node are run.
485               string ctrl_dep = AddControlDependency(shape_n_node->name(),
486                                                      graph_, node_map_.get());
487               *added_node->add_input() = ctrl_dep;
488               node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
489             }
490             *output->mutable_input(k) = const_name;
491             node_map_->AddOutput(const_name, output->name());
492           }
493           if (node_name == shape_n_node->name() && port != port_idx) {
494             direct_edges_exist = true;
495           }
496         }
497         if (!direct_edges_exist) {
498           node_map_->RemoveOutput(node->name(), output->name());
499         }
500       }
501     }
502   }
503 
504   return Status::OK();
505 }
506 
507 namespace {
ExtractShape(const NodeDef & shape_node,const GraphProperties & properties,BCast::Vec * shape,int64 * min_id)508 bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties,
509                   BCast::Vec* shape, int64* min_id) {
510   if (shape_node.op() == "Shape") {
511     const std::vector<OpInfo::TensorProperties>& prop1 =
512         properties.GetInputProperties(shape_node.name());
513     if (prop1.size() != 1) {
514       return false;
515     }
516     const TensorShapeProto& shp = prop1[0].shape();
517     if (shp.unknown_rank()) {
518       return false;
519     }
520     for (const auto& dim : shp.dim()) {
521       shape->push_back(dim.size());
522       *min_id = std::min<int64>(*min_id, dim.size());
523     }
524   } else {
525     if (shape_node.attr().count("value") == 0) {
526       return false;
527     }
528     const TensorProto& raw_val = shape_node.attr().at("value").tensor();
529     if (raw_val.dtype() != DT_INT64 && raw_val.dtype() != DT_INT32) {
530       return false;
531     }
532     Tensor value(raw_val.dtype(), raw_val.tensor_shape());
533     if (!value.FromProto(raw_val)) {
534       return false;
535     }
536     for (int j = 0; j < value.NumElements(); ++j) {
537       if (raw_val.dtype() == DT_INT64) {
538         shape->push_back(value.vec<int64>()(j));
539       } else {
540         shape->push_back(value.vec<int>()(j));
541       }
542     }
543   }
544   return true;
545 }
546 }  // namespace
547 
MaterializeBroadcastGradientArgs(const NodeDef & node,const GraphProperties & properties)548 Status ConstantFolding::MaterializeBroadcastGradientArgs(
549     const NodeDef& node, const GraphProperties& properties) {
550   const NodeDef* shape_node1 = node_map_->GetNode(node.input(0));
551   const NodeDef* shape_node2 = node_map_->GetNode(node.input(1));
552   if (shape_node1 == nullptr ||
553       (shape_node1->op() != "Shape" && !IsReallyConstant(*shape_node1)) ||
554       shape_node2 == nullptr ||
555       (shape_node2->op() != "Shape" && !IsReallyConstant(*shape_node2))) {
556     return Status::OK();
557   }
558 
559   // Don't optimize this again if it was already optimized and folded.
560   if (OptimizedNodeExists(node, "-folded-1") ||
561       OptimizedNodeExists(node, "-folded-2")) {
562     return Status::OK();
563   }
564   int64 min_id = 0;
565   BCast::Vec shape1;
566   if (!ExtractShape(*shape_node1, properties, &shape1, &min_id)) {
567     return Status::OK();
568   }
569   BCast::Vec shape2;
570   if (!ExtractShape(*shape_node2, properties, &shape2, &min_id)) {
571     return Status::OK();
572   }
573   // A value of -1 means we don't known anything about the dimension. Replace
574   // the -1 values with unique dimension ids since we don't want two '-1'
575   // dimensions to be considered equal.
576   for (auto& id : shape1) {
577     if (id == -1) {
578       id = --min_id;
579     }
580   }
581   for (auto& id : shape2) {
582     if (id == -1) {
583       id = --min_id;
584     }
585   }
586 
587   // Beware: the reduction dimensions computed by the BCast class are valid iff
588   // we assume that two distinct symbolic dimensions can't be equal and a
589   // symbolic dimension can't be equal to 1. This is often but not always true,
590   // so to make this optimization safe we filter out these cases.
591   const int common_dims = std::min(shape1.size(), shape2.size());
592   for (int i = 0; i < common_dims; ++i) {
593     if (shape1[i] >= 0 && shape2[i] >= 0) {
594       continue;
595     }
596     if (shape1[i] != shape2[i]) {
597       // We're either dealing with 2 different symbolic dimensions or a symbolic
598       // and a know dimensions. We can't be sure whether both are equal or not,
599       // so we can't be sure whether we'll be broadcasting or not.
600       return Status::OK();
601     }
602   }
603   // These extra dims could be equal to 1, in which case there is no
604   // broadcasting. It could also be greater than 1, in which case there would
605   // be broadcasting. Since we don't know, we'll just punt.
606   for (int i = common_dims; i < shape1.size(); ++i) {
607     if (shape1[i] < 0) {
608       return Status::OK();
609     }
610   }
611   for (int i = common_dims; i < shape2.size(); ++i) {
612     if (shape2[i] < 0) {
613       return Status::OK();
614     }
615   }
616 
617   BCast bcast(shape1, shape2);
618   if (!bcast.IsValid()) {
619     return Status::OK();
620   }
621 
622   BCast::Vec reduce_dims[2];
623   reduce_dims[0] = bcast.grad_x_reduce_idx();
624   reduce_dims[1] = bcast.grad_y_reduce_idx();
625 
626   TF_RETURN_IF_ERROR(CheckAttrExists(node, "T"));
627   const DataType type = node.attr().at("T").type();
628   NodeDef* out[2];
629   for (int j = 0; j < 2; ++j) {
630     int reduction_indices = reduce_dims[j].size();
631     Tensor value(type, TensorShape({reduction_indices}));
632     for (int i = 0; i < reduction_indices; ++i) {
633       if (type == DT_INT32) {
634         value.vec<int32>()(i) = reduce_dims[j][i];
635       } else {
636         value.vec<int64>()(i) = reduce_dims[j][i];
637       }
638     }
639     string const_name =
640         OptimizedNodeName(node, strings::StrCat("-bcastargs-", j));
641     out[j] = node_map_->GetNode(const_name);
642     if (out[j] == nullptr) {
643       out[j] = graph_->add_node();
644       TF_RETURN_IF_ERROR(
645           CreateNodeDef(const_name, TensorValue(&value), out[j]));
646       out[j]->set_device(node.device());
647       node_map_->AddNode(const_name, out[j]);
648       string ctrl_dep =
649           AddControlDependency(node.name(), graph_, node_map_.get());
650       *out[j]->add_input() = ctrl_dep;
651       node_map_->AddOutput(NodeName(ctrl_dep), const_name);
652     }
653   }
654 
655   const std::set<NodeDef*> outputs = node_map_->GetOutputs(node.name());
656   for (NodeDef* output : outputs) {
657     for (int k = 0; k < output->input_size(); ++k) {
658       int port;
659       string node_name = ParseNodeName(output->input(k), &port);
660       if (node_name == node.name() && port >= 0 && port < 2 && out[port]) {
661         *output->mutable_input(k) = out[port]->name();
662         node_map_->UpdateInput(output->name(), node_name, out[port]->name());
663       }
664     }
665   }
666 
667   return Status::OK();
668 }
669 
// If `node` has a non-constant reduction-indices input (input 1) that can be
// shown to cover every dimension of input 0, replaces input 1 with a new Const
// node (suffix "-reduction_indices") holding [0, 1, ..., rank-1]. The Const
// keeps a control dependency on the original indices input. Returns OK
// without changes whenever the reduction cannot be proven to be full.
//
// A reduction is treated as full when (a) the output rank is known to be 0,
// (b) the number of reduction indices equals the input rank, or (c) `node` has
// at least one fanout and every fanout is a Reshape whose single output is
// known to hold exactly one element.
// NOTE(review): (b) assumes the indices contain no duplicates. For example,
// [0, 0] on a rank-2 input would be treated as a full reduction. Confirm that
// the reduction kernels reject duplicate axes.
Status ConstantFolding::MaterializeReductionIndices(
    NodeDef* node, const GraphProperties& properties) {
  if (node->input_size() < 2) {
    return Status::OK();
  }
  const NodeDef* indices = node_map_->GetNode(node->input(1));
  if (!indices || IsReallyConstant(*indices)) {
    // The reduction indices are already constant, there's nothing to do.
    // (This branch is also taken when the indices producer is missing.)
    return Status::OK();
  }

  const std::vector<OpInfo::TensorProperties>& input_props =
      properties.GetInputProperties(node->name());
  if (input_props.size() != 2) {
    return Status::OK();
  }
  const OpInfo::TensorProperties& input_prop = input_props[0];
  if (input_prop.shape().unknown_rank()) {
    // We can't do anything if we don't know the rank of the input.
    return Status::OK();
  }
  const int input_rank = input_prop.shape().dim_size();
  if (input_rank < 1) {
    // Unexpected graph, don't try to change it.
    return Status::OK();
  }
  const OpInfo::TensorProperties& reduction_indices_prop = input_props[1];
  DataType dtype = reduction_indices_prop.dtype();
  if (dtype != DT_INT32 && dtype != DT_INT64) {
    return Status::OK();
  }
  PartialTensorShape reduction_indices_shape(reduction_indices_prop.shape());
  // NOTE(review): presumably -1 when the indices shape is not fully known, in
  // which case condition (b) cannot match.
  const int num_reduction_indices = reduction_indices_shape.num_elements();

  const std::vector<OpInfo::TensorProperties>& output_props =
      properties.GetOutputProperties(node->name());
  if (output_props.size() != 1) {
    return Status::OK();
  }
  const OpInfo::TensorProperties& output_prop = output_props[0];
  // -1 encodes "output rank unknown".
  const int output_rank =
      output_prop.shape().unknown_rank() ? -1 : output_prop.shape().dim_size();

  bool full_reduction = output_rank == 0 || num_reduction_indices == input_rank;
  if (!full_reduction) {
    // A full reduction will generate a tensor of one of the shapes
    // [], [1], [1, 1], [1, 1, ...]. Even if we do not know the number of
    // elements in the output of the reduction, we may deduce it from reshape
    // nodes following it.
    for (const NodeDef* fanout : node_map_->GetOutputs(node->name())) {
      full_reduction = false;
      if (!IsReshape(*fanout)) {
        return Status::OK();
      }
      const std::vector<OpInfo::TensorProperties>& reshape_props =
          properties.GetOutputProperties(fanout->name());
      if (reshape_props.size() != 1) {
        return Status::OK();
      }
      const OpInfo::TensorProperties& reshape_prop = reshape_props[0];
      PartialTensorShape shape(reshape_prop.shape());
      if (shape.num_elements() != 1) {
        return Status::OK();
      } else {
        full_reduction = true;
      }
    }
    // Still false here when `node` has no fanouts at all.
    if (!full_reduction) {
      return Status::OK();
    }
  }

  // We know it's a full reduction. We can generate the full set of indices to
  // reduce as a constant node.
  string const_name = OptimizedNodeName(*node, "-reduction_indices");
  if (node_map_->GetNode(const_name)) {
    return Status::OK();
  }
  NodeDef* reduction_indices = graph_->add_node();
  Tensor value(dtype, TensorShape({input_rank}));
  for (int i = 0; i < input_rank; ++i) {
    if (dtype == DT_INT32) {
      value.vec<int32>()(i) = i;
    } else {
      value.vec<int64>()(i) = i;
    }
  }
  TF_RETURN_IF_ERROR(
      CreateNodeDef(const_name, TensorValue(&value), reduction_indices));

  reduction_indices->set_device(node->device());
  // Keep the original indices as a control dependency so the new constant
  // only runs where the original indices would have been produced.
  string ctrl_dep =
      AddControlDependency(node->input(1), graph_, node_map_.get());
  *reduction_indices->add_input() = ctrl_dep;
  node_map_->AddNode(const_name, reduction_indices);
  node_map_->AddOutput(NodeName(ctrl_dep), const_name);

  node->set_input(1, reduction_indices->name());
  node_map_->UpdateInput(node->name(), indices->name(),
                         reduction_indices->name());

  return Status::OK();
}
773 
MaterializeConstantValuedNode(NodeDef * node,const GraphProperties & properties)774 Status ConstantFolding::MaterializeConstantValuedNode(
775     NodeDef* node, const GraphProperties& properties) {
776   // Nodes that generate constant-valued outputs can be represented compactly in
777   // compressed format, regardless of their shape.
778   const std::vector<OpInfo::TensorProperties>& output_props =
779       properties.GetOutputProperties(node->name());
780   if (output_props.size() != 1) return Status::OK();
781   const auto& output_shape = output_props[0].shape();
782   if (!PartialTensorShape(output_shape).IsFullyDefined()) {
783     return Status::OK();
784   }
785   if (IsFill(*node)) {
786     const auto output_dtype = output_props[0].dtype();
787     NodeDef* input_node = nullptr;
788     for (int i = 0; i < 2; ++i) {
789       input_node = node_map_->GetNode(NodeName(node->input(i)));
790       if (input_node == nullptr || !IsReallyConstant(*input_node)) {
791         return Status::OK();
792       }
793     }
794     TF_RETURN_IF_ERROR(CheckAttrExists(*input_node, "value"));
795 
796     // Copy the input tensor to the fill node, set the output shape and data
797     // type, and change the node type to Const.
798     TensorProto* tensor = (*node->mutable_attr())["value"].mutable_tensor();
799     const TensorProto& input_tensor = input_node->attr().at("value").tensor();
800     if (!input_tensor.tensor_content().empty()) {
801       // Convert the value to repeated field format, so we can use the
802       // decompression mechanism to store only a single value in the constant
803       // node, even if the shape specified in the original Fill is large.
804       Tensor t;
805       if (!t.FromProto(input_tensor)) {
806         return errors::InvalidArgument(
807             "Could not construct Tensor form TensorProto in node: ",
808             input_node->name());
809       }
810       tensor->clear_tensor_content();
811       t.AsProtoField(tensor);
812     } else {
813       *tensor = input_tensor;
814     }
815     *(tensor->mutable_tensor_shape()) = output_shape;
816     (*node->mutable_attr())["dtype"].set_type(output_dtype);
817     node->mutable_attr()->erase("T");
818     node->mutable_attr()->erase("index_type");
819     node->set_op("Const");
820     for (int i = 0; i < 2; i++) {
821       // Change inputs to a control inputs.
822       const string ctrl_dep = AsControlDependency(node->input(i));
823       node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
824       node->set_input(i, ctrl_dep);
825     }
826     graph_modified_ = true;
827   } else {
828     double value =
829         (IsZerosLike(*node) ? 0.0 : (IsOnesLike(*node) ? 1.0 : -1.0));
830     bool success = false;
831     if (value >= 0) {
832       TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
833           value, properties, output_shape, node, graph_, &success));
834     }
835   }
836   return Status::OK();
837 }
838 
MaterializeConstants(const GraphProperties & properties)839 Status ConstantFolding::MaterializeConstants(
840     const GraphProperties& properties) {
841   const int node_count = graph_->node_size();
842   for (int i = 0; i < node_count; ++i) {
843     NodeDef& node = *graph_->mutable_node(i);
844     const string& op = node.op();
845     if (op == "BroadcastGradientArgs") {
846       TF_RETURN_IF_ERROR(MaterializeBroadcastGradientArgs(node, properties));
847     } else if (IsReduction(node)) {
848       TF_RETURN_IF_ERROR(MaterializeReductionIndices(&node, properties));
849     } else if (IsFill(node) || IsZerosLike(node) || IsOnesLike(node)) {
850       TF_RETURN_IF_ERROR(MaterializeConstantValuedNode(&node, properties));
851     }
852   }
853   return Status::OK();
854 }
855 
IsFoldable(const NodeDef & node) const856 bool ConstantFolding::IsFoldable(const NodeDef& node) const {
857   // Folding not applicable to ops with no inputs.
858   if (node.input().empty()) {
859     return false;
860   }
861   // Skips nodes that must be preserved except whitelisted nodes.
862   if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end() &&
863       nodes_whitelist_.find(node.name()) == nodes_whitelist_.end()) {
864     return false;
865   }
866   // `FakeParam` op is used as a placeholder in If branch function. It doesn't
867   // have a valid output when executed.
868   if (IsFakeParam(node)) {
869     return false;
870   }
871 
872   // Skip control flow nodes, they can't be folded.
873   if (ModifiesFrameInfo(node)) {
874     return false;
875   }
876 
877   // Removing LoopCond nodes can screw up the partitioner.
878   if (node.op() == "LoopCond") {
879     return false;
880   }
881 
882   // Skip constants, they're already folded
883   if (IsConstant(node)) {
884     return false;
885   }
886 
887   // Don't fold stateful ops such as TruncatedNormal.
888   if (!IsFreeOfSideEffect(node)) {
889     return false;
890   }
891 
892   // Skips ops that don't benefit from folding.
893   if (IsPlaceholder(node)) {
894     return false;
895   }
896   const string& op = node.op();
897   if (op.find("Save") != string::npos || op.find("Restore") != string::npos ||
898       op.find("Reader") != string::npos) {
899     return false;
900   }
901   if (op.find("Quantized") != string::npos || op.find("Sparse") == 0) {
902     return false;
903   }
904 
905   // Don't fold nodes that contain TPU attributes.
906   // TODO(rmlarsen): We should be able to fold many of these nodes as long as we
907   // properly forward custom attributes, b/119051778.
908   if (HasTPUAttributes(node)) {
909     return false;
910   }
911 
912   const OpDef* op_def = nullptr;
913   Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
914   if (!status.ok()) {
915     return false;
916   }
917   // Don't fold ops without outputs.
918   if (op_def->output_arg_size() == 0) {
919     return false;
920   }
921 
922   // No need to (and don't) fold nodes that have no outgoing edges except
923   // whitelisted nodes. Such nodes could be introduced by an earlier constant
924   // folding pass and are preserved in case users want to fetch their values;
925   // re-processing them would lead to an error of adding a duplicated node
926   // to graph.
927   auto outputs = node_map_->GetOutputs(node.name());
928   if (outputs.empty() &&
929       nodes_whitelist_.find(node.name()) == nodes_whitelist_.end()) {
930     return false;
931   }
932 
933   // We can only fold nodes if all their inputs are known statically, except in
934   // the case of a merge node that propagate the first inputs that becomes
935   // available, and therefore only requires a single constant input to be
936   // foldable.
937   bool merge_has_constant_input = false;
938   const bool is_merge = IsMerge(node);
939   for (const auto& input : node.input()) {
940     if (IsControlInput(input)) {
941       continue;
942     }
943     const NodeDef* input_node = node_map_->GetNode(input);
944     if (!input_node) {
945       return false;
946     }
947     bool is_const = IsReallyConstant(*input_node);
948     if (is_const) {
949       // Don't fold strings constants for now since this causes problems with
950       // checkpointing.
951       if (input_node->attr().count("dtype") == 0 ||
952           input_node->attr().at("dtype").type() == DT_STRING) {
953         return false;
954       }
955       // Special case: If a Merge node has at least one constant input that
956       // does not depend on a control input, we can fold it.
957       merge_has_constant_input |= !HasControlInputs(*input_node);
958     } else if (!is_merge) {
959       return false;
960     }
961   }
962   return !is_merge || merge_has_constant_input;
963 }
964 
965 namespace {
966 
967 #define SET_TENSOR_VAL_CASE(DTYPE, TYPE, NAME)     \
968   case DTYPE:                                      \
969     t->add_##NAME##_val(static_cast<TYPE>(value)); \
970     break;
971 
CreateConstantTensorAttrValue(DataType type,double value,const TensorShapeProto & shape,AttrValue * attr_tensor)972 Status CreateConstantTensorAttrValue(DataType type, double value,
973                                      const TensorShapeProto& shape,
974                                      AttrValue* attr_tensor) {
975   TensorProto* t = attr_tensor->mutable_tensor();
976   t->set_dtype(type);
977   *t->mutable_tensor_shape() = shape;
978   switch (type) {
979     case DT_HALF:
980       t->add_half_val(static_cast<Eigen::half>(value).x);
981       break;
982     case DT_BFLOAT16:
983       t->add_half_val(static_cast<bfloat16>(value).value);
984       break;
985       SET_TENSOR_VAL_CASE(DT_FLOAT, float, float);
986       SET_TENSOR_VAL_CASE(DT_DOUBLE, double, double);
987       SET_TENSOR_VAL_CASE(DT_INT64, int64, int64);
988       SET_TENSOR_VAL_CASE(DT_UINT64, int64, int64);
989       SET_TENSOR_VAL_CASE(DT_INT32, int32, int);
990       SET_TENSOR_VAL_CASE(DT_UINT32, int32, int);
991       SET_TENSOR_VAL_CASE(DT_INT16, int32, int);
992       SET_TENSOR_VAL_CASE(DT_UINT16, int32, int);
993       SET_TENSOR_VAL_CASE(DT_INT8, int32, int);
994       SET_TENSOR_VAL_CASE(DT_UINT8, int32, int);
995       SET_TENSOR_VAL_CASE(DT_QINT32, int32, int);
996       SET_TENSOR_VAL_CASE(DT_QINT16, int32, int);
997       SET_TENSOR_VAL_CASE(DT_QUINT16, int32, int);
998       SET_TENSOR_VAL_CASE(DT_QINT8, int32, int);
999       SET_TENSOR_VAL_CASE(DT_QUINT8, int32, int);
1000       SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool);
1001     default:
1002       return errors::InvalidArgument("Unsupported type: ", type);
1003   }
1004   return Status::OK();
1005 }
1006 
1007 #undef SET_TENSOR_CAL_CASE
1008 
GetDataTypeFromNodeOrProps(const NodeDef & node,const GraphProperties & properties)1009 DataType GetDataTypeFromNodeOrProps(const NodeDef& node,
1010                                     const GraphProperties& properties) {
1011   DataType dtype = DT_INVALID;
1012   if (node.attr().count("T") == 1) {
1013     dtype = node.attr().at("T").type();
1014   } else if (node.attr().count("dtype") == 1) {
1015     dtype = node.attr().at("dtype").type();
1016   } else if (IsLogicalOr(node) || IsLogicalAnd(node)) {
1017     dtype = DT_BOOL;
1018   } else {
1019     auto output_props = properties.GetOutputProperties(node.name());
1020     if (!output_props.empty()) {
1021       dtype = output_props[0].dtype();
1022     }
1023   }
1024   return dtype;
1025 }
1026 
1027 // Checks whether the shape of the const input of the Mul op is valid to perform
1028 // the MulConvPushDown optimization.
IsValidConstShapeForMulConvPushDown(const string & data_format,const TensorShapeProto & filter_shape,const TensorShapeProto & mul_const_input_shape)1029 bool IsValidConstShapeForMulConvPushDown(
1030     const string& data_format, const TensorShapeProto& filter_shape,
1031     const TensorShapeProto& mul_const_input_shape) {
1032   // If the const is a scalar, or it has fewer or same number of dimensions
1033   // than the filter and it only has single element, the optimization should
1034   // work.
1035   if (mul_const_input_shape.dim_size() <= data_format.size() &&
1036       TensorShape(mul_const_input_shape).num_elements() == 1) {
1037     return true;
1038   }
1039 
1040   // Otherwise, check the eligibility according to data format.
1041   if (data_format == "NHWC" || data_format == "NDHWC") {
1042     TensorShapeProto new_filter_shape;
1043     if (!ShapeAfterBroadcast(filter_shape, mul_const_input_shape,
1044                              &new_filter_shape)) {
1045       return false;
1046     }
1047     if (!ShapesSymbolicallyEqual(filter_shape, new_filter_shape)) {
1048       return false;
1049     }
1050     // Only the last dimension could be larger than one, since broadcasting over
1051     // the last dimension (the output channel) will result in invalid filter.
1052     for (int i = 0; i < mul_const_input_shape.dim_size() - 1; ++i) {
1053       if (mul_const_input_shape.dim(i).size() > 1) return false;
1054     }
1055     return true;
1056   } else if (data_format == "NCHW" || data_format == "NCDHW") {
1057     // TODO(laigd): support NCHW and NCDHW (b/111214513).
1058     return false;
1059   }
1060   return false;
1061 }
1062 
1063 }  // namespace
1064 
1065 // static
CreateNodeDef(const string & name,const TensorValue & tensor,NodeDef * node,size_t original_size)1066 Status ConstantFolding::CreateNodeDef(const string& name,
1067                                       const TensorValue& tensor, NodeDef* node,
1068                                       size_t original_size) {
1069   node->set_name(name);
1070   node->set_op("Const");
1071 
1072   AttrValue attr_type;
1073   attr_type.set_type(tensor->dtype());
1074   node->mutable_attr()->insert({"dtype", attr_type});
1075 
1076   AttrValue attr_tensor;
1077   TensorProto* t = attr_tensor.mutable_tensor();
1078   bool optimized = false;
1079   size_t encoded_size;
1080   // Use the packed representation whenever possible to avoid generating large
1081   // graphdefs. Moreover, avoid repeating the last values if they're equal.
1082   if (tensor->NumElements() > 4) {
1083 #define POPULATE_TENSOR_PROTO(tensor, t, TYPE, NAME)                      \
1084   {                                                                       \
1085     const auto* val_ptr = tensor->flat<TYPE>().data();                    \
1086     auto last = *val_ptr;                                                 \
1087     int64 last_index = 0;                                                 \
1088     for (int64 i = 0; i < tensor->NumElements(); ++i) {                   \
1089       TYPE cur = *val_ptr++;                                              \
1090       if (PackedValuesNotEqual(cur, last)) {                              \
1091         last = cur;                                                       \
1092         last_index = i;                                                   \
1093       }                                                                   \
1094     }                                                                     \
1095     if (last_index < kint32max) {                                         \
1096       optimized = true;                                                   \
1097       encoded_size = (last_index + 1) * sizeof(NAME);                     \
1098       t->mutable_##NAME##_val()->Reserve(last_index + 1);                 \
1099       const auto* src_ptr = tensor->flat<TYPE>().data();                  \
1100       auto* dst_ptr =                                                     \
1101           t->mutable_##NAME##_val()->AddNAlreadyReserved(last_index + 1); \
1102       std::copy(src_ptr, src_ptr + last_index + 1, dst_ptr);              \
1103     }                                                                     \
1104   }                                                                       \
1105   break
1106 
1107     switch (tensor->dtype()) {
1108       case DT_FLOAT:
1109         POPULATE_TENSOR_PROTO(tensor, t, float, float);
1110       case DT_DOUBLE:
1111         POPULATE_TENSOR_PROTO(tensor, t, double, double);
1112       case DT_INT64:
1113         POPULATE_TENSOR_PROTO(tensor, t, int64, int64);
1114       case DT_UINT64:
1115         POPULATE_TENSOR_PROTO(tensor, t, uint64, int64);
1116       case DT_INT32:
1117         POPULATE_TENSOR_PROTO(tensor, t, int32, int);
1118       case DT_UINT32:
1119         POPULATE_TENSOR_PROTO(tensor, t, uint32, int);
1120       case DT_INT16:
1121         POPULATE_TENSOR_PROTO(tensor, t, int16, int);
1122       case DT_UINT16:
1123         POPULATE_TENSOR_PROTO(tensor, t, uint16, int);
1124       case DT_INT8:
1125         POPULATE_TENSOR_PROTO(tensor, t, int8, int);
1126       case DT_UINT8:
1127         POPULATE_TENSOR_PROTO(tensor, t, uint8, int);
1128       case DT_BOOL:
1129         POPULATE_TENSOR_PROTO(tensor, t, bool, bool);
1130       default:
1131         /* Do nothing. */
1132         break;
1133     }
1134   }
1135   if (optimized) {
1136     // Also specify type and shape.
1137     t->set_dtype(tensor->dtype());
1138     tensor->shape().AsProto(t->mutable_tensor_shape());
1139   } else {
1140     // DT_HALF, DT_BFLOAT16, DT_QINT32, DT_QINT16, DT_QUINT16, DT_QINT8,
1141     // DT_QUINT8
1142     tensor->AsProtoTensorContent(t);
1143     encoded_size = t->tensor_content().size();
1144   }
1145   node->mutable_attr()->insert({"value", attr_tensor});
1146 
1147   if (encoded_size > original_size && encoded_size >= 10 * 1024 * 1024) {
1148     return errors::InvalidArgument(
1149         strings::StrCat("Can't fold ", name, ", its size would be too large (",
1150                         encoded_size, " >= ", 10 * 1024 * 1024, " bytes)"));
1151   }
1152   return Status::OK();
1153 }
1154 
EvaluateNode(const NodeDef & node,const TensorVector & inputs,TensorVector * output) const1155 Status ConstantFolding::EvaluateNode(const NodeDef& node,
1156                                      const TensorVector& inputs,
1157                                      TensorVector* output) const {
1158   return ::tensorflow::grappler::EvaluateNode(node, inputs, cpu_device_,
1159                                               resource_mgr_.get(), output);
1160 }
1161 
EvaluateOneFoldable(const NodeDef & node,std::vector<NodeDef> * outputs,bool * result_too_large)1162 Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
1163                                             std::vector<NodeDef>* outputs,
1164                                             bool* result_too_large) {
1165   TensorVector inputs;
1166   TensorVector output_tensors;
1167   auto inputs_cleanup = gtl::MakeCleanup([&inputs, &output_tensors] {
1168     for (const auto& input : inputs) {
1169       delete input.tensor;
1170     }
1171     for (const auto& output : output_tensors) {
1172       if (output.tensor) {
1173         delete output.tensor;
1174       }
1175     }
1176   });
1177 
1178   size_t total_inputs_size = 0;
1179   for (const auto& input : node.input()) {
1180     const TensorId input_tensor = ParseTensorName(input);
1181     if (input_tensor.index() < 0) {
1182       // Control dependency
1183       break;
1184     }
1185     const NodeDef* input_node = node_map_->GetNode(input);
1186     if (!IsReallyConstant(*input_node)) {
1187       return Status(error::INVALID_ARGUMENT,
1188                     strings::StrCat("Can't fold ", node.name(), ", its ", input,
1189                                     " isn't constant"));
1190     }
1191     TF_RETURN_IF_ERROR(CheckAttrExists(*input_node, "value"));
1192     const TensorProto& raw_val = input_node->attr().at("value").tensor();
1193     Tensor* value = new Tensor(raw_val.dtype(), raw_val.tensor_shape());
1194     CHECK(value->FromProto(raw_val));
1195     inputs.emplace_back(value);
1196     total_inputs_size += value->TotalBytes();
1197   }
1198 
1199   TF_RETURN_IF_ERROR(EvaluateNode(node, inputs, &output_tensors));
1200   if (output_tensors.empty()) {
1201     return Status(error::INVALID_ARGUMENT, "Expected at least one output.");
1202   }
1203 
1204   outputs->resize(output_tensors.size());
1205   for (size_t i = 0; i < output_tensors.size(); i++) {
1206     string node_name = OptimizedNodeName(node, "-folded");
1207     if (output_tensors.size() > 1) {
1208       node_name = strings::StrCat(node_name, "-", i);
1209     }
1210     if (output_tensors[i].tensor) {
1211       Status s = CreateNodeDef(node_name, output_tensors[i], &outputs->at(i),
1212                                total_inputs_size);
1213       if (!s.ok()) {
1214         *result_too_large = true;
1215         return s;
1216       }
1217     } else {
1218       // Create an empty NodeDef to identify dead outputs (e.g. the output of a
1219       // switch that's not selected by the switch predicate).
1220       outputs->at(i) = NodeDef();
1221     }
1222   }
1223   return Status::OK();
1224 }
1225 
FoldMergeNode(NodeDef * node,GraphDef * output_graph)1226 Status ConstantFolding::FoldMergeNode(NodeDef* node, GraphDef* output_graph) {
1227   // Merge nodes are special, in the sense that they execute as soon as one of
1228   // their input is ready. We can therefore fold a merge node iff it has at
1229   // least one constant input without control dependency.
1230   // We still need to ensure that the nodes in the fanin of the merge node are
1231   // scheduled. We'll therefore add a control dependency from the merge node
1232   // to the folded constant. We end up with:
1233   //  * the merge node and its inputs are preserved as is
1234   //  * a new constant node C1, driven by the merge node through a control
1235   //  dependency, initialized to the value of the folded input
1236   //  * a new constant node C2, driven by the merge node through a control
1237   //  dependency, initialized to the index of the folded input
1238   //  * the fanout of the merge nodes is rewired to be driven by either C1 or
1239   //  C2.
1240   for (int input_index = 0; input_index < node->input_size(); ++input_index) {
1241     const auto& input = node->input(input_index);
1242     if (IsControlInput(input)) {
1243       // Try the next input.
1244       continue;
1245     }
1246     NodeDef* input_node = node_map_->GetNode(input);
1247     if (!IsReallyConstant(*input_node)) {
1248       continue;
1249     }
1250     bool valid_input = true;
1251     for (const string& fanin_of_input : input_node->input()) {
1252       if (IsControlInput(fanin_of_input)) {
1253         valid_input = false;
1254         break;
1255       }
1256     }
1257     if (!valid_input) {
1258       // Try the next input
1259       continue;
1260     }
1261 
1262     string const_out_name = OptimizedNodeName(*node, "_const");
1263     string const_index_name = OptimizedNodeName(*node, "_index");
1264     if (node_map_->GetNode(const_out_name) ||
1265         node_map_->GetNode(const_index_name)) {
1266       // Intended name already exists.
1267       return errors::AlreadyExists(
1268           strings::StrCat(const_out_name, " or ", const_index_name,
1269                           " already present in the graph"));
1270     }
1271 
1272     NodeDef* const_out = output_graph->add_node();
1273     *const_out = *input_node;
1274     const_out->set_name(const_out_name);
1275     const_out->set_device(node->device());
1276     *const_out->add_input() = AsControlDependency(*node);
1277     node_map_->AddNode(const_out->name(), const_out);
1278     node_map_->AddOutput(node->name(), const_out->name());
1279 
1280     NodeDef* const_index = output_graph->add_node();
1281     const_index->set_op("Const");
1282     Tensor index(DT_INT32, TensorShape({}));
1283     index.flat<int32>()(0) = input_index;
1284     (*const_index->mutable_attr())["dtype"].set_type(DT_INT32);
1285     index.AsProtoTensorContent(
1286         (*const_index->mutable_attr())["value"].mutable_tensor());
1287     const_index->set_name(const_index_name);
1288     const_index->set_device(node->device());
1289     *const_index->add_input() = AsControlDependency(*node);
1290     node_map_->AddNode(const_index->name(), const_index);
1291     node_map_->AddOutput(node->name(), const_index->name());
1292 
1293     auto outputs = node_map_->GetOutputs(node->name());
1294     for (NodeDef* output : outputs) {
1295       for (int i = 0; i < output->input_size(); i++) {
1296         int port;
1297         string node_name = ParseNodeName(output->input(i), &port);
1298         if (node_name == node->name()) {
1299           if (port == 0) {
1300             *output->mutable_input(i) = const_out->name();
1301             node_map_->AddOutput(const_out->name(), output->name());
1302           } else if (port == 1) {
1303             *output->mutable_input(i) = const_index->name();
1304             node_map_->AddOutput(const_index->name(), output->name());
1305           } else {
1306             // This is a control dependency (or an invalid edge since the
1307             // merge node has only 2 inputs): preserve them.
1308           }
1309         }
1310       }
1311     }
1312     return Status::OK();
1313   }
1314   return Status::OK();
1315 }
1316 
FoldNode(NodeDef * node,GraphDef * output_graph,bool * result_too_large)1317 Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph,
1318                                  bool* result_too_large) {
1319   *result_too_large = false;
1320   if (IsMerge(*node)) {
1321     return FoldMergeNode(node, output_graph);
1322   }
1323 
1324   std::vector<NodeDef> const_nodes;
1325   TF_RETURN_IF_ERROR(
1326       EvaluateOneFoldable(*node, &const_nodes, result_too_large));
1327   VLOG(1) << "Folded node:\n" << node->DebugString();
1328 
1329   NodeDef* constant_output = nullptr;
1330   for (int i = 0; i < const_nodes.size(); i++) {
1331     NodeDef* const_node = &const_nodes[i];
1332     VLOG(1) << "Generated constant node:\n" << const_node->DebugString();
1333     if (const_node->name().empty()) {
1334       // Dead output: we can't create a constant to encode its value, so we'll
1335       // just skip it. We'll preserve the edges that originate from that
1336       // output below to preserve the overall behavior of the graph wrt dead
1337       // edges.
1338       continue;
1339     }
1340 
1341     // Forward control dependencies.
1342     for (const auto& input : node->input()) {
1343       if (IsControlInput(input) &&
1344           std::find(const_node->input().begin(), const_node->input().end(),
1345                     input) == const_node->input().end()) {
1346         *const_node->add_input() = input;
1347       } else {
1348         NodeDef* input_node = node_map_->GetNode(input);
1349         for (const auto& fanin_of_input : input_node->input()) {
1350           if (IsControlInput(fanin_of_input) &&
1351               std::find(const_node->input().begin(), const_node->input().end(),
1352                         fanin_of_input) == const_node->input().end()) {
1353             *const_node->add_input() = fanin_of_input;
1354           }
1355         }
1356       }
1357     }
1358 
1359     // We rewrite the existing node if it only has a single output, and
1360     // create new nodes otherwise.
1361     if (const_nodes.size() == 1) {
1362       node->set_op("Const");
1363       // Note we need to clear the inputs in NodeMap before we clear the inputs
1364       // in the node, otherwise NodeMap would see empty inputs and effectively
1365       // does nothing.
1366       node_map_->RemoveInputs(node->name());
1367       node->clear_input();
1368       *node->mutable_input() = const_node->input();
1369       for (const auto& input : node->input()) {
1370         node_map_->AddOutput(NodeName(input), node->name());
1371       }
1372       *node->mutable_attr() = const_node->attr();
1373       break;
1374     } else {
1375       if (node_map_->GetNode(const_node->name())) {
1376         // Intended name already exists.
1377         return errors::AlreadyExists(strings::StrCat(
1378             const_node->name(), " already present in the graph"));
1379       }
1380       NodeDef* added_node = output_graph->add_node();
1381       *added_node = *const_node;
1382       added_node->set_device(node->device());
1383       node_map_->AddNode(added_node->name(), added_node);
1384       for (const auto& input : added_node->input()) {
1385         node_map_->AddOutput(NodeName(input), added_node->name());
1386       }
1387       // All the constant nodes encoding output values have the same control
1388       // dependencies (since these are the control dependencies of the node
1389       // we're trying to fold). Record one such constant node.
1390       constant_output = added_node;
1391     }
1392   }
1393 
1394   if (const_nodes.size() > 1) {
1395     auto outputs = node_map_->GetOutputs(node->name());
1396     for (NodeDef* output : outputs) {
1397       for (int i = 0; i < output->input_size(); i++) {
1398         int port;
1399         string node_name = ParseNodeName(output->input(i), &port);
1400         if (node_name == node->name()) {
1401           if (port < 0) {
1402             // Propagate control dependencies if possible. If not, we'll just
1403             // preserve the existing control dependencies.
1404             if (constant_output != nullptr) {
1405               node_map_->UpdateInput(node_name, NodeName(output->input(i)),
1406                                      constant_output->name());
1407               *output->mutable_input(i) = AsControlDependency(*constant_output);
1408             }
1409           } else if (port < const_nodes.size() &&
1410                      !const_nodes[port].name().empty()) {
1411             // Replace alive outputs with the corresponding constant.
1412             node_map_->UpdateInput(output->name(), NodeName(output->input(i)),
1413                                    const_nodes[port].name());
1414             *output->mutable_input(i) = const_nodes[port].name();
1415           } else {
1416             // Leave this edge alone.
1417             VLOG(1) << "Preserving edge from " << node->name() << ":" << port
1418                     << "[" << node->op() << "] to " << output->name() << ":"
1419                     << i << "[" << output->op() << "]";
1420           }
1421         }
1422       }
1423     }
1424     outputs = node_map_->GetOutputs(node->name());
1425     if (outputs.empty() && has_fetch_ &&
1426         nodes_to_preserve_.find(node->name()) == nodes_to_preserve_.end()) {
1427       node_map_->RemoveInputs(node->name());
1428       node->clear_input();
1429     }
1430   }
1431   return Status::OK();
1432 }
1433 
FoldGraph(GraphDef * output,absl::flat_hash_set<string> * nodes_to_not_simplify)1434 Status ConstantFolding::FoldGraph(
1435     GraphDef* output, absl::flat_hash_set<string>* nodes_to_not_simplify) {
1436   std::unordered_set<string> processed_nodes;
1437   std::deque<NodeDef*> queue;
1438   for (int i = 0; i < graph_->node_size(); i++) {
1439     if (IsFoldable(graph_->node(i))) {
1440       queue.push_back(graph_->mutable_node(i));
1441     }
1442   }
1443   while (!queue.empty()) {
1444     NodeDef* node = queue.front();
1445     queue.pop_front();
1446     if (processed_nodes.count(node->name())) {
1447       continue;
1448     }
1449     // We need to record a copy of output nodes before FoldNode() modifies it.
1450     // We also need to ensure that the fanout is sorted deterministically.
1451     const std::set<NodeDef*>& outputs = node_map_->GetOutputs(node->name());
1452     std::vector<NodeDef*> fanout(outputs.begin(), outputs.end());
1453     std::sort(fanout.begin(), fanout.end(),
1454               [](const NodeDef* n1, const NodeDef* n2) {
1455                 return n1->name() < n2->name();
1456               });
1457 
1458     bool result_too_large = false;
1459     Status s = FoldNode(node, output, &result_too_large);
1460     processed_nodes.insert(node->name());
1461     if (!s.ok()) {
1462       VLOG(1) << "Failed to fold node " << node->DebugString()
1463               << "\nError message: " << s;
1464       if (result_too_large) {
1465         nodes_to_not_simplify->emplace(node->name());
1466       }
1467     } else {
1468       for (auto& output : fanout) {
1469         if (IsFoldable(*output)) {
1470           queue.push_back(output);
1471         }
1472       }
1473     }
1474   }
1475 
1476   // Delete the newly created nodes that don't feed anything.
1477   std::vector<int> nodes_to_delete;
1478   for (int i = 0; i < output->node_size(); i++) {
1479     auto fanout = node_map_->GetOutputs(output->node(i).name());
1480     if (fanout.empty()) nodes_to_delete.push_back(i);
1481   }
1482   EraseNodesFromGraph(std::move(nodes_to_delete), output);
1483 
1484   for (const auto& node : graph_->node()) {
1485     // If no fetch nodes is provided, we conservatively
1486     // keep all nodes in the original graph in case users need to fetch
1487     // their values.
1488     auto fanout = node_map_->GetOutputs(node.name());
1489     if (!fanout.empty() || !has_fetch_ ||
1490         nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
1491       auto added_node = output->add_node();
1492       *added_node = node;
1493     }
1494   }
1495   return Status::OK();
1496 }
1497 
IsSimplifiableReshape(const NodeDef & node,const GraphProperties & properties) const1498 bool ConstantFolding::IsSimplifiableReshape(
1499     const NodeDef& node, const GraphProperties& properties) const {
1500   if (!IsReshape(node)) {
1501     return false;
1502   }
1503   CHECK_LE(2, node.input_size());
1504   const NodeDef* new_shape = node_map_->GetNode(node.input(1));
1505   if (!IsReallyConstant(*new_shape)) {
1506     return false;
1507   }
1508   TensorVector outputs;
1509   auto outputs_cleanup = gtl::MakeCleanup([&outputs] {
1510     for (const auto& output : outputs) {
1511       delete output.tensor;
1512     }
1513   });
1514 
1515   Status s = EvaluateNode(*new_shape, TensorVector(), &outputs);
1516   if (!s.ok()) {
1517     return false;
1518   }
1519   CHECK_EQ(1, outputs.size());
1520 
1521   const std::vector<OpInfo::TensorProperties>& props =
1522       properties.GetInputProperties(node.name());
1523   if (props.empty()) {
1524     return false;
1525   }
1526   const OpInfo::TensorProperties& prop = props[0];
1527   if (prop.dtype() == DT_INVALID) {
1528     return false;
1529   }
1530   const PartialTensorShape shape(prop.shape());
1531   if (!shape.IsFullyDefined()) {
1532     return false;
1533   }
1534 
1535   PartialTensorShape new_dims;
1536   if (outputs[0]->dtype() == DT_INT32) {
1537     std::vector<int32> shp;
1538     for (int i = 0; i < outputs[0]->NumElements(); ++i) {
1539       int32 dim = outputs[0]->flat<int32>()(i);
1540       shp.push_back(dim);
1541     }
1542     TF_CHECK_OK(TensorShapeUtils::MakeShape(shp, &new_dims));
1543   } else {
1544     std::vector<int64> shp;
1545     for (int i = 0; i < outputs[0]->NumElements(); ++i) {
1546       int64 dim = outputs[0]->flat<int64>()(i);
1547       shp.push_back(dim);
1548     }
1549     TF_CHECK_OK(TensorShapeUtils::MakeShape(shp, &new_dims));
1550   }
1551 
1552   return shape.IsCompatibleWith(new_dims);
1553 }
1554 
// IS_VALUE_CASE(DTYPE, VALUE) expands to a `case DTYPE:` label for a switch
// over DataType. The label returns AllValuesAre<T>() on the "value" attribute
// tensor of `node`, where T is the C++ type mapped to DTYPE. VALUE is
// converted to T before the comparison. The expansion refers to a variable
// named `node` that must be in scope at the use site. Callers supply the
// trailing semicolon.
#define IS_VALUE_CASE(DTYPE, VALUE)                   \
  case DTYPE:                                         \
    return AllValuesAre<EnumToDataType<DTYPE>::Type>( \
        node.attr().at("value").tensor(),             \
        EnumToDataType<DTYPE>::Type(VALUE))

// Shorthands that compare every element against zero and one, respectively.
#define IS_ZEROS_CASE(TYPE) IS_VALUE_CASE(TYPE, 0)
#define IS_ONES_CASE(TYPE) IS_VALUE_CASE(TYPE, 1)
1562 
IsOnes(const NodeDef & node) const1563 bool ConstantFolding::IsOnes(const NodeDef& node) const {
1564   if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
1565     return false;
1566   }
1567   if (IsOnesLike(node)) return true;
1568   if (IsZerosLike(node)) return false;
1569   if (node.op() == "Fill") {
1570     NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
1571     return values != nullptr && IsOnes(*values);
1572   }
1573   if (node.op() != "Const") return false;
1574   if (node.attr().count("dtype") == 0) return false;
1575   const auto dtype = node.attr().at("dtype").type();
1576   switch (dtype) {
1577     IS_ONES_CASE(DT_BOOL);
1578     IS_ONES_CASE(DT_HALF);
1579     IS_ONES_CASE(DT_BFLOAT16);
1580     IS_ONES_CASE(DT_FLOAT);
1581     IS_ONES_CASE(DT_DOUBLE);
1582     IS_ONES_CASE(DT_COMPLEX64);
1583     IS_ONES_CASE(DT_COMPLEX128);
1584     IS_ONES_CASE(DT_UINT8);
1585     IS_ONES_CASE(DT_INT8);
1586     IS_ONES_CASE(DT_UINT16);
1587     IS_ONES_CASE(DT_INT16);
1588     IS_ONES_CASE(DT_INT32);
1589     IS_ONES_CASE(DT_INT64);
1590     IS_ONES_CASE(DT_QINT32);
1591     IS_ONES_CASE(DT_QINT16);
1592     IS_ONES_CASE(DT_QUINT16);
1593     IS_ONES_CASE(DT_QINT8);
1594     IS_ONES_CASE(DT_QUINT8);
1595     default:
1596       VLOG(1) << "Unsupported type " << DataTypeString(dtype);
1597       return false;
1598   }
1599   return false;
1600 }
1601 
IsZeros(const NodeDef & node) const1602 bool ConstantFolding::IsZeros(const NodeDef& node) const {
1603   if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
1604     return false;
1605   }
1606   if (IsOnesLike(node)) return false;
1607   if (IsZerosLike(node)) return true;
1608   if (node.op() == "Fill") {
1609     NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
1610     return values != nullptr && IsZeros(*values);
1611   }
1612   if (!IsConstant(node)) return false;
1613   if (node.attr().count("dtype") == 0) return false;
1614   const auto dtype = node.attr().at("dtype").type();
1615   switch (dtype) {
1616     IS_ZEROS_CASE(DT_BOOL);
1617     IS_ZEROS_CASE(DT_HALF);
1618     IS_ZEROS_CASE(DT_BFLOAT16);
1619     IS_ZEROS_CASE(DT_FLOAT);
1620     IS_ZEROS_CASE(DT_DOUBLE);
1621     IS_ZEROS_CASE(DT_COMPLEX64);
1622     IS_ZEROS_CASE(DT_COMPLEX128);
1623     IS_ZEROS_CASE(DT_UINT8);
1624     IS_ZEROS_CASE(DT_INT8);
1625     IS_ZEROS_CASE(DT_UINT16);
1626     IS_ZEROS_CASE(DT_INT16);
1627     IS_ZEROS_CASE(DT_INT32);
1628     IS_ZEROS_CASE(DT_INT64);
1629     IS_ZEROS_CASE(DT_QINT32);
1630     IS_ZEROS_CASE(DT_QINT16);
1631     IS_ZEROS_CASE(DT_QUINT16);
1632     IS_ZEROS_CASE(DT_QINT8);
1633     IS_ZEROS_CASE(DT_QUINT8);
1634     default:
1635       VLOG(1) << "Unsupported type " << DataTypeString(dtype);
1636       return false;
1637   }
1638   return false;
1639 }
1640 
ReplaceOperationWithIdentity(int input_to_forward,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1641 void ConstantFolding::ReplaceOperationWithIdentity(
1642     int input_to_forward, const GraphProperties& properties, NodeDef* node,
1643     GraphDef* graph) {
1644   const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
1645   if (dtype == DT_INVALID) return;
1646 
1647   node->set_op("Identity");
1648   node->clear_attr();
1649   (*node->mutable_attr())["T"].set_type(dtype);
1650   // Propagate the designated input through the identity.
1651   node->mutable_input()->SwapElements(0, input_to_forward);
1652   // Add all other inputs as control dependencies.
1653   for (int i = 1; i < node->input_size(); ++i) {
1654     if (IsControlInput(node->input(i))) {
1655       break;
1656     }
1657     const string ctrl_dep =
1658         AddControlDependency(node->input(i), graph, node_map_.get());
1659     node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1660     node->set_input(i, ctrl_dep);
1661   }
1662   graph_modified_ = true;
1663 }
1664 
ReplaceOperationWithSnapshot(int input_to_forward,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1665 void ConstantFolding::ReplaceOperationWithSnapshot(
1666     int input_to_forward, const GraphProperties& properties, NodeDef* node,
1667     GraphDef* graph) {
1668   // If the graph contains no ops that mutate their inputs, we can
1669   // use Identity insted of Snapshot.
1670   if (!graph_contains_assign_or_inplace_op_) {
1671     ReplaceOperationWithIdentity(input_to_forward, properties, node, graph);
1672     return;
1673   }
1674 
1675   const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
1676   if (dtype == DT_INVALID) return;
1677 
1678   node->set_op("Snapshot");
1679   node->clear_attr();
1680   (*node->mutable_attr())["T"].set_type(dtype);
1681   // Propagate the designated input through the Snapshot.
1682   node->mutable_input()->SwapElements(0, input_to_forward);
1683   // Add all other inputs as control dependencies.
1684   for (int i = 1; i < node->input_size(); ++i) {
1685     if (IsControlInput(node->input(i))) {
1686       break;
1687     }
1688     const string ctrl_dep =
1689         AddControlDependency(node->input(i), graph, node_map_.get());
1690     node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1691     node->set_input(i, ctrl_dep);
1692   }
1693   graph_modified_ = true;
1694 }
1695 
ReplaceDivisionOfOnesByReciprocal(NodeDef * node,GraphDef * graph)1696 void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node,
1697                                                         GraphDef* graph) {
1698   node->set_op("Reciprocal");
1699   node->mutable_input()->SwapElements(0, 1);
1700   const string ctrl_dep =
1701       AddControlDependency(node->input(1), graph, node_map_.get());
1702   node_map_->UpdateInput(node->name(), node->input(1), ctrl_dep);
1703   node->set_input(1, ctrl_dep);
1704   graph_modified_ = true;
1705 }
1706 
ReplaceSubtractionFromZeroByNegation(NodeDef * node,GraphDef * graph)1707 void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node,
1708                                                            GraphDef* graph) {
1709   node->set_op("Neg");
1710   node->mutable_input()->SwapElements(0, 1);
1711   const string ctrl_dep =
1712       AddControlDependency(node->input(1), graph, node_map_.get());
1713   node_map_->UpdateInput(node->name(), node->input(1), ctrl_dep);
1714   node->set_input(1, ctrl_dep);
1715   graph_modified_ = true;
1716 }
1717 
ReplaceOperationWithConstant(double value,const GraphProperties & properties,const TensorShapeProto & shape,NodeDef * node,GraphDef * graph,bool * success)1718 Status ConstantFolding::ReplaceOperationWithConstant(
1719     double value, const GraphProperties& properties,
1720     const TensorShapeProto& shape, NodeDef* node, GraphDef* graph,
1721     bool* success) {
1722   const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
1723   if (dtype == DT_INVALID) {
1724     *success = false;
1725     return Status::OK();
1726   }
1727 
1728   AttrValue tensor_attr;
1729   TF_RETURN_IF_ERROR(
1730       CreateConstantTensorAttrValue(dtype, value, shape, &tensor_attr));
1731   node->set_op("Const");
1732   node->clear_attr();
1733   (*node->mutable_attr())["dtype"].set_type(dtype);
1734   node->mutable_attr()->insert({"value", tensor_attr});
1735   // Convert all inputs to control dependencies.
1736   for (int i = 0; i < node->input_size(); ++i) {
1737     if (IsControlInput(node->input(i))) {
1738       break;
1739     }
1740     const string ctrl_dep =
1741         AddControlDependency(node->input(i), graph, node_map_.get());
1742     node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1743     node->set_input(i, ctrl_dep);
1744   }
1745   *success = true;
1746   graph_modified_ = true;
1747   return Status::OK();
1748 }
1749 
SimplifyGraph(bool use_shape_info,GraphDef * optimized_graph,GraphProperties * properties,absl::flat_hash_set<string> * nodes_to_not_simplify)1750 Status ConstantFolding::SimplifyGraph(
1751     bool use_shape_info, GraphDef* optimized_graph, GraphProperties* properties,
1752     absl::flat_hash_set<string>* nodes_to_not_simplify) {
1753   for (int i = 0; i < optimized_graph->node_size(); ++i) {
1754     NodeDef* node = optimized_graph->mutable_node(i);
1755     // TODO(lyandy): Move nodes to not simplify check into SimplifyNode and
1756     // generalize to only restrict certain simplifications.
1757     if (nodes_to_not_simplify->find(node->name()) ==
1758         nodes_to_not_simplify->end()) {
1759       if (HasTPUAttributes(optimized_graph->node(i))) {
1760         nodes_to_not_simplify->insert(node->name());
1761         continue;
1762       }
1763       TF_RETURN_IF_ERROR(
1764           SimplifyNode(use_shape_info, node, optimized_graph, properties));
1765     }
1766   }
1767   return Status::OK();
1768 }
1769 
SimplifyNode(bool use_shape_info,NodeDef * node,GraphDef * optimized_graph,GraphProperties * properties)1770 Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
1771                                      GraphDef* optimized_graph,
1772                                      GraphProperties* properties) {
1773   if (RemoveSplitOrSplitV(*properties, optimized_graph, node)) {
1774     return Status::OK();
1775   }
1776 
1777   bool remove_shuffle_transpose_successful = false;
1778   Status remove_shuffle_transpose_status =
1779       RemoveShuffleOrTranspose(*properties, use_shape_info, optimized_graph,
1780                                node, &remove_shuffle_transpose_successful);
1781   if (!remove_shuffle_transpose_status.ok()) {
1782     return remove_shuffle_transpose_status;
1783   } else if (remove_shuffle_transpose_successful) {
1784     return Status::OK();
1785   }
1786 
1787   if (RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node)) {
1788     return Status::OK();
1789   }
1790 
1791   bool remove_reverse_successful = false;
1792   Status remove_reverse_status =
1793       RemoveReverse(*properties, use_shape_info, optimized_graph, node,
1794                     &remove_reverse_successful);
1795   if (!remove_reverse_status.ok()) {
1796     return remove_reverse_status;
1797   } else if (remove_reverse_successful) {
1798     return Status::OK();
1799   }
1800 
1801   bool simplify_slice_successful = false;
1802   Status simplify_slice_status =
1803       SimplifySlice(*properties, use_shape_info, optimized_graph, node,
1804                     &simplify_slice_successful);
1805   if (!simplify_slice_status.ok()) {
1806     return simplify_slice_status;
1807   } else if (simplify_slice_successful) {
1808     return Status::OK();
1809   }
1810 
1811   bool simplify_strided_slice_successful = false;
1812   Status simplify_strided_slice_status =
1813       SimplifyStridedSlice(*properties, use_shape_info, optimized_graph, node,
1814                            &simplify_strided_slice_successful);
1815   if (!simplify_strided_slice_status.ok()) {
1816     return simplify_strided_slice_status;
1817   } else if (simplify_strided_slice_successful) {
1818     return Status::OK();
1819   }
1820 
1821   bool simplify_tile_successful = false;
1822   Status simplify_tile_status =
1823       SimplifyTile(*properties, use_shape_info, optimized_graph, node,
1824                    &simplify_tile_successful);
1825   if (!simplify_tile_status.ok()) {
1826     return simplify_tile_status;
1827   } else if (simplify_tile_successful) {
1828     return Status::OK();
1829   }
1830 
1831   bool simplify_pad_successful = false;
1832   Status simplify_pad_status =
1833       SimplifyPad(*properties, use_shape_info, optimized_graph, node,
1834                   &simplify_pad_successful);
1835   if (!simplify_pad_status.ok()) {
1836     return simplify_pad_status;
1837   } else if (simplify_pad_successful) {
1838     return Status::OK();
1839   }
1840 
1841   if (SimplifySqueeze(*properties, use_shape_info, optimized_graph, node)) {
1842     return Status::OK();
1843   }
1844 
1845   if (SimplifyPack(optimized_graph, node)) {
1846     graph_modified_ = true;
1847     return Status::OK();
1848   }
1849 
1850   if (MoveConstantsPastEnter(optimized_graph, node)) {
1851     graph_modified_ = true;
1852     return Status::OK();
1853   }
1854 
1855   if (SimplifySwitch(optimized_graph, node)) {
1856     graph_modified_ = true;
1857     return Status::OK();
1858   }
1859 
1860   if (SimplifyReduction(optimized_graph, *properties, node)) {
1861     graph_modified_ = true;
1862     return Status::OK();
1863   }
1864 
1865   if (SimplifyReshape(*properties, use_shape_info, node)) {
1866     graph_modified_ = true;
1867     return Status::OK();
1868   }
1869 
1870   bool arithmetic_simplification_succeed = false;
1871   Status simplify_arithmetic_status =
1872       SimplifyArithmeticOperations(*properties, use_shape_info, optimized_graph,
1873                                    node, &arithmetic_simplification_succeed);
1874   if (!simplify_arithmetic_status.ok()) {
1875     return simplify_arithmetic_status;
1876   } else if (arithmetic_simplification_succeed) {
1877     graph_modified_ = true;
1878     return Status::OK();
1879   }
1880 
1881   if (ReduceDivToReciprocalMul(optimized_graph, node)) {
1882     graph_modified_ = true;
1883     return Status::OK();
1884   }
1885 
1886   if (ConstantPushDown(optimized_graph, node)) {
1887     graph_modified_ = true;
1888     return Status::OK();
1889   }
1890 
1891   if (MulConvPushDown(optimized_graph, node, *properties)) {
1892     graph_modified_ = true;
1893     return Status::OK();
1894   }
1895 
1896   if (PartialConstPropThroughIdentityN(node)) {
1897     graph_modified_ = true;
1898     return Status::OK();
1899   }
1900 
1901   if (PartialAssocOpConstFolding(optimized_graph, properties, node)) {
1902     graph_modified_ = true;
1903     return Status::OK();
1904   }
1905 
1906   if (PartialConcatConstFolding(optimized_graph, properties, node)) {
1907     graph_modified_ = true;
1908     return Status::OK();
1909   }
1910 
1911   if (MergeConcat(*properties, use_shape_info, optimized_graph, node)) {
1912     graph_modified_ = true;
1913     return Status::OK();
1914   }
1915 
1916   return Status::OK();
1917 }
1918 
RemoveSplitOrSplitV(const GraphProperties & properties,GraphDef * optimized_graph,NodeDef * node)1919 bool ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
1920                                           GraphDef* optimized_graph,
1921                                           NodeDef* node) {
1922   if (node->attr().count("num_split") == 0) return false;
1923   if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
1924     ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
1925     return true;
1926   }
1927   if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
1928     ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
1929     return true;
1930   }
1931   return false;
1932 }
1933 
RemoveShuffleOrTranspose(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node,bool * success)1934 Status ConstantFolding::RemoveShuffleOrTranspose(
1935     const GraphProperties& properties, bool use_shape_info,
1936     GraphDef* optimized_graph, NodeDef* node, bool* success) {
1937   if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) &&
1938       properties.GetInputProperties(node->name()).size() >= 2) {
1939     const auto& shape = properties.GetInputProperties(node->name())[0].shape();
1940     if (shape.unknown_rank()) {
1941       // Not optimizable.
1942       return Status::OK();
1943     }
1944     const auto& p = properties.GetInputProperties(node->name())[1];
1945     if (TensorShape::IsValid(p.shape()) && p.has_value()) {
1946       Tensor perm(p.dtype(), p.shape());
1947       if (!perm.FromProto(p.value())) {
1948         return errors::InvalidArgument("Cannot parse tensor from proto: ",
1949                                        p.value().DebugString());
1950       }
1951       std::vector<int> permutation;
1952       for (int j = 0; j < perm.NumElements(); ++j) {
1953         if (perm.dtype() == DT_INT64) {
1954           permutation.push_back(perm.vec<int64>()(j));
1955         } else {
1956           permutation.push_back(perm.vec<int>()(j));
1957         }
1958       }
1959       if (permutation.size() != shape.dim_size()) {
1960         // Number of elements in perm should be same as dim_size. Skip if not.
1961         return Status::OK();
1962       }
1963       // The node is replaceable iff
1964       // dim_size == 0 || all dims have size 1 ||
1965       // all dims with > 1 size are not permuted.
1966       bool replaceable = true;
1967       for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
1968         replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
1969       }
1970       if (replaceable) {
1971         ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
1972         *success = true;
1973         return Status::OK();
1974       }
1975     }
1976   }
1977   *success = false;
1978   return Status::OK();
1979 }
RemoveRandomShuffle(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)1980 bool ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
1981                                           bool use_shape_info,
1982                                           GraphDef* optimized_graph,
1983                                           NodeDef* node) {
1984   if (use_shape_info && IsRandomShuffle(*node) &&
1985       !properties.GetInputProperties(node->name()).empty()) {
1986     const auto& shape = properties.GetInputProperties(node->name())[0].shape();
1987     // The node is replaceable iff
1988     // unknown_rank == false && (dim_size == 0 || first dim is of size 1)
1989     if (!shape.unknown_rank() &&
1990         (shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
1991       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
1992       return true;
1993     }
1994   }
1995   return false;
1996 }
1997 
RemoveReverse(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node,bool * success)1998 Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
1999                                       bool use_shape_info,
2000                                       GraphDef* optimized_graph, NodeDef* node,
2001                                       bool* success) {
2002   if (use_shape_info && node->op() == "ReverseV2" &&
2003       properties.GetInputProperties(node->name()).size() >= 2) {
2004     const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2005     if (shape.unknown_rank()) {
2006       // Not optimizable.
2007       return Status::OK();
2008     }
2009     const auto& a = properties.GetInputProperties(node->name())[1];
2010     if (TensorShape::IsValid(a.shape()) && a.has_value()) {
2011       Tensor axis(a.dtype(), a.shape());
2012       if (!axis.FromProto(a.value())) {
2013         return errors::InvalidArgument("Cannot parse tensor from proto: ",
2014                                        a.value().DebugString());
2015       }
2016       std::set<int> target_axes;
2017       for (int j = 0; j < axis.NumElements(); ++j) {
2018         // value of axis can be negative.
2019         if (axis.dtype() == DT_INT64) {
2020           target_axes.insert((axis.vec<int64>()(j) + shape.dim_size()) %
2021                              shape.dim_size());
2022         } else {
2023           target_axes.insert((axis.vec<int>()(j) + shape.dim_size()) %
2024                              shape.dim_size());
2025         }
2026       }
2027 
2028       // The node is replaceable iff
2029       // unknown_rank == false &&
2030       // (dim_size == 0 || all dims have size 1 ||
2031       //  all dims with > 1 size are not in target_axes)
2032       bool replaceable = !shape.unknown_rank();
2033       for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
2034         replaceable &= shape.dim(j).size() == 1 ||
2035                        target_axes.find(j) == target_axes.end();
2036       }
2037       if (replaceable) {
2038         ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2039         *success = true;
2040         return Status::OK();
2041       }
2042     }
2043   }
2044   *success = false;
2045   return Status::OK();
2046 }
2047 
SimplifySlice(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node,bool * success)2048 Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
2049                                       bool use_shape_info,
2050                                       GraphDef* optimized_graph, NodeDef* node,
2051                                       bool* success) {
2052   if (use_shape_info && IsSlice(*node) &&
2053       properties.GetInputProperties(node->name()).size() == 3) {
2054     const auto& input = properties.GetInputProperties(node->name())[0];
2055     const auto& b = properties.GetInputProperties(node->name())[1];
2056     const auto& s = properties.GetInputProperties(node->name())[2];
2057     if (TensorShape::IsValid(b.shape()) && b.has_value() &&
2058         TensorShape::IsValid(s.shape()) && s.has_value()) {
2059       Tensor begin(b.dtype(), b.shape());
2060       if (!begin.FromProto(b.value())) {
2061         return errors::InvalidArgument("Cannot parse tensor from proto: ",
2062                                        b.value().DebugString());
2063       }
2064       Tensor size(s.dtype(), s.shape());
2065       if (!size.FromProto(s.value())) {
2066         return errors::InvalidArgument("Cannot parse tensor from proto: ",
2067                                        s.value().DebugString());
2068       }
2069       // The node is replaceable iff unknown_rank == false &&
2070       // begin == 0 && (size == -1 || size == input_shape) for all dimensions
2071       bool replaceable = !input.shape().unknown_rank();
2072       for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
2073         if (begin.dtype() == DT_INT32) {
2074           replaceable &= begin.vec<int>()(j) == 0;
2075         } else {
2076           replaceable &= begin.vec<int64>()(j) == 0;
2077         }
2078         if (size.dtype() == DT_INT32) {
2079           replaceable &= (size.vec<int>()(j) == -1 ||
2080                           size.vec<int>()(j) == input.shape().dim(j).size());
2081         } else {
2082           replaceable &= (size.vec<int64>()(j) == -1 ||
2083                           size.vec<int64>()(j) == input.shape().dim(j).size());
2084         }
2085       }
2086       if (replaceable) {
2087         ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2088         *success = true;
2089         return Status::OK();
2090       }
2091     }
2092   }
2093   *success = false;
2094   return Status::OK();
2095 }
2096 
SimplifyStridedSlice(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node,bool * success)2097 Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
2098                                              bool use_shape_info,
2099                                              GraphDef* optimized_graph,
2100                                              NodeDef* node, bool* success) {
2101   if (use_shape_info && IsStridedSlice(*node) &&
2102       properties.GetInputProperties(node->name()).size() == 4) {
2103     TF_RETURN_IF_ERROR(
2104         CheckAttrsExist(*node, {"new_axis_mask", "shrink_axis_mask"}));
2105     if (node->attr().at("new_axis_mask").i() != 0 ||
2106         node->attr().at("shrink_axis_mask").i() != 0) {
2107       // Skip nodes with new/shrink axis mask, since they involve dimension
2108       // changes.
2109       return Status::OK();
2110     }
2111     const auto& input = properties.GetInputProperties(node->name())[0];
2112     for (int j = 0; j < input.shape().dim_size(); ++j) {
2113       // Skip if input shape is not fully determined.
2114       if (input.shape().dim(j).size() < 0) {
2115         return Status::OK();
2116       }
2117     }
2118     const auto& b = properties.GetInputProperties(node->name())[1];
2119     const auto& e = properties.GetInputProperties(node->name())[2];
2120     const auto& s = properties.GetInputProperties(node->name())[3];
2121     if (TensorShape::IsValid(b.shape()) && b.has_value() &&
2122         TensorShape::IsValid(e.shape()) && e.has_value() &&
2123         TensorShape::IsValid(s.shape()) && s.has_value()) {
2124       Tensor begin(b.dtype(), b.shape());
2125       if (!begin.FromProto(b.value())) {
2126         return errors::InvalidArgument("Cannot parse tensor from proto: ",
2127                                        b.value().DebugString());
2128       }
2129       Tensor end(e.dtype(), e.shape());
2130       if (!end.FromProto(e.value())) {
2131         return errors::InvalidArgument("Cannot parse tensor from proto: ",
2132                                        e.value().DebugString());
2133       }
2134       Tensor strides(s.dtype(), s.shape());
2135       if (!strides.FromProto(s.value())) {
2136         return errors::InvalidArgument("Cannot parse tensor from proto: ",
2137                                        s.value().DebugString());
2138       }
2139       TF_RETURN_IF_ERROR(
2140           CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask"}));
2141       int begin_mask = node->attr().at("begin_mask").i();
2142       int end_mask = node->attr().at("end_mask").i();
2143       std::set<int> expanded_ellipsis_indices;
2144       int ellipsis_index = -1;
2145       for (int j = 0; j < input.shape().dim_size(); ++j) {
2146         // find the ellipsis_mask. If not found, insert one in the end if
2147         // necessary.
2148         if (node->attr().at("ellipsis_mask").i() & 1 << j ||
2149             (ellipsis_index == -1 && j >= strides.NumElements())) {
2150           ellipsis_index = j;
2151         }
2152         // insert the indices that are immediately after ellipsis_index if
2153         // necessary.
2154         if (ellipsis_index != -1 &&
2155             input.shape().dim_size() >
2156                 strides.NumElements() + j - ellipsis_index) {
2157           expanded_ellipsis_indices.insert(j);
2158         }
2159       }
2160 
2161       // The node is replaceable iff unknown_rank == false &&
2162       // ((begin_mask is set || begin == 0) && (end_mask is set || end == dim)
2163       //  && strides == 1) for all dimensions.
2164       bool replaceable = !input.shape().unknown_rank();
2165       for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
2166         if (expanded_ellipsis_indices.find(j) !=
2167             expanded_ellipsis_indices.end()) {
2168           // ellipsis_mask is effective on current dimension.
2169           continue;
2170         }
2171         // when we have ellipsis_mask in between, input.shape().dim_size() will
2172         // be greater than strides.NumElements(), since we will insert
2173         // as many as expanded_ellipsis_indices.size() axes during computation.
2174         // We need to subtract this number from j.
2175         int i = j;
2176         if (ellipsis_index != -1 &&
2177             j >= ellipsis_index + expanded_ellipsis_indices.size()) {
2178           i = j - expanded_ellipsis_indices.size();
2179         }
2180         int b = begin.dtype() == DT_INT32 ? begin.vec<int>()(i)
2181                                           : begin.vec<int64>()(i);
2182         int e =
2183             end.dtype() == DT_INT32 ? end.vec<int>()(i) : end.vec<int64>()(i);
2184         int s = strides.dtype() == DT_INT32 ? strides.vec<int>()(i)
2185                                             : strides.vec<int64>()(i);
2186         replaceable &=
2187             (begin_mask & 1 << i || b == 0) &&
2188             (end_mask & 1 << i || e == input.shape().dim(j).size()) && s == 1;
2189       }
2190       if (replaceable) {
2191         ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2192         *success = true;
2193         return Status::OK();
2194       }
2195     }
2196   }
2197   *success = false;
2198   return Status::OK();
2199 }
2200 
SimplifyTile(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node,bool * success)2201 Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
2202                                      bool use_shape_info,
2203                                      GraphDef* optimized_graph, NodeDef* node,
2204                                      bool* success) {
2205   if (use_shape_info && IsTile(*node) &&
2206       properties.GetInputProperties(node->name()).size() == 2) {
2207     const auto& m = properties.GetInputProperties(node->name())[1];
2208     if (TensorShape::IsValid(m.shape()) && m.has_value()) {
2209       Tensor multiplies(m.dtype(), m.shape());
2210       if (!multiplies.FromProto(m.value())) {
2211         return errors::InvalidArgument("Cannot parse tensor from proto: ",
2212                                        m.value().DebugString());
2213       }
2214       // The node is replaceable iff all values in multiplies are 1.
2215       bool replaceable = true;
2216       if (multiplies.dtype() == DT_INT32) {
2217         for (int j = 0; replaceable && j < multiplies.vec<int>().size(); ++j) {
2218           replaceable &= multiplies.vec<int>()(j) == 1;
2219         }
2220       } else {
2221         for (int j = 0; replaceable && j < multiplies.vec<int64>().size();
2222              ++j) {
2223           replaceable &= multiplies.vec<int64>()(j) == 1;
2224         }
2225       }
2226       if (replaceable) {
2227         ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2228         *success = true;
2229         return Status::OK();
2230       }
2231     }
2232   }
2233   *success = false;
2234   return Status::OK();
2235 }
2236 
SimplifyPad(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node,bool * success)2237 Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
2238                                     bool use_shape_info,
2239                                     GraphDef* optimized_graph, NodeDef* node,
2240                                     bool* success) {
2241   if (use_shape_info && IsPad(*node) &&
2242       properties.GetInputProperties(node->name()).size() >= 2) {
2243     const auto& p = properties.GetInputProperties(node->name())[1];
2244     if (TensorShape::IsValid(p.shape()) && p.has_value()) {
2245       Tensor paddings(p.dtype(), p.shape());
2246       if (!paddings.FromProto(p.value())) {
2247         return errors::InvalidArgument("Cannot parse tensor from proto: ",
2248                                        p.value().DebugString());
2249       }
2250       // The node is replaceable iff all values in paddings are 0.
2251       bool replaceable = true;
2252       // The operation requires it to be int32 value so we don't check for
2253       // 1nt64.
2254       const auto flatten = paddings.flat<int32>();
2255       for (int j = 0; replaceable && j < flatten.size(); ++j) {
2256         replaceable &= flatten(j) == 0;
2257       }
2258       if (replaceable) {
2259         ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2260         *success = true;
2261         return Status::OK();
2262       }
2263     }
2264   }
2265   *success = false;
2266   return Status::OK();
2267 }
2268 
SimplifySqueeze(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2269 bool ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
2270                                       bool use_shape_info,
2271                                       GraphDef* optimized_graph,
2272                                       NodeDef* node) {
2273   if (use_shape_info && IsSqueeze(*node) &&
2274       !properties.GetInputProperties(node->name()).empty()) {
2275     // https://www.tensorflow.org/api_docs/python/tf/squeeze mentions it's
2276     // error to squeeze a dimension that is not 1, so we only need to check
2277     // whether the input has > 1 size for each dimension.
2278     const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2279     // The node is replaceable iff
2280     // unknown_rank == false && (dim_size == 0 || all dims have size > 1)
2281     bool replaceable = !shape.unknown_rank();
2282     for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
2283       replaceable &= shape.dim(j).size() > 1;
2284     }
2285     if (replaceable) {
2286       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2287       return true;
2288     }
2289   }
2290   return false;
2291 }
2292 
SimplifyPack(GraphDef * optimized_graph,NodeDef * node)2293 bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
2294   if (IsPack(*node) && NumNonControlInputs(*node) == 1 &&
2295       !OptimizedNodeExists(*node, "_const_axis")) {
2296     // Create constant axis node.
2297     Tensor axis_t(DT_INT32, TensorShape({}));
2298     NodeDef* axis_node = optimized_graph->add_node();
2299     axis_node->set_name(OptimizedNodeName(*node, "_const_axis"));
2300     const int axis =
2301         node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i();
2302     if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
2303         !CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node)
2304              .ok()) {
2305       return false;
2306     }
2307     // Add a control dependency to make sure axis_node is in the right frame.
2308     const string ctrl_dep = ConstantFolding::AddControlDependency(
2309         node->input(0), optimized_graph, node_map_.get());
2310     axis_node->add_input(ctrl_dep);
2311     axis_node->set_device(node->device());
2312     node->set_op("ExpandDims");
2313     if (node->attr().count("axis") != 0) {
2314       node->mutable_attr()->erase("axis");
2315     }
2316     if (node->attr().count("N") != 0) {
2317       node->mutable_attr()->erase("N");
2318     }
2319     (*node->mutable_attr())["Tdim"].set_type(DT_INT32);
2320     node->add_input(axis_node->name());
2321     if (node->input_size() > 2) {
2322       node->mutable_input()->SwapElements(1, node->input_size() - 1);
2323     }
2324     return true;
2325   }
2326   return false;
2327 }
2328 
MoveConstantsPastEnter(GraphDef * optimized_graph,NodeDef * node)2329 bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
2330                                              NodeDef* node) {
2331   if (IsEnter(*node) && node->input_size() > 0) {
2332     if (node->attr().count("is_constant") == 0 ||
2333         !node->attr().at("is_constant").b()) {
2334       return false;
2335     }
2336     const string& node_name = node->name();
2337     const NodeDef* input = node_map_->GetNode(node->input(0));
2338     if (input != nullptr && IsReallyConstant(*input) &&
2339         !OptimizedNodeExists(*input, "_enter")) {
2340       auto fanouts = node_map_->GetOutputs(node_name);
2341       // Find non-constant nodes that consume the output of *node.
2342       std::vector<NodeDef*> consumers;
2343       for (NodeDef* fanout : fanouts) {
2344         if (!IsConstant(*fanout)) {
2345           for (int i = 0; i < fanout->input_size(); ++i) {
2346             if (fanout->input(i) == node_name) {
2347               consumers.push_back(fanout);
2348               break;
2349             }
2350           }
2351         }
2352       }
2353       if (!consumers.empty()) {
2354         NodeDef* new_node = optimized_graph->add_node();
2355         *new_node = *input;
2356         new_node->set_name(OptimizedNodeName(*input, "_enter"));
2357         new_node->set_device(node->device());
2358         new_node->clear_input();
2359         new_node->add_input(AsControlDependency(node_name));
2360         node_map_->AddNode(new_node->name(), new_node);
2361         node_map_->AddOutput(node_name, new_node->name());
2362         for (NodeDef* consumer : consumers) {
2363           for (int i = 0; i < consumer->input_size(); ++i) {
2364             if (NodeName(consumer->input(i)) == node_name) {
2365               node_map_->UpdateInput(consumer->name(), node_name,
2366                                      new_node->name());
2367               consumer->set_input(i, new_node->name());
2368             }
2369           }
2370         }
2371         return true;
2372       }
2373     }
2374   }
2375   return false;
2376 }
2377 
// Simplifies a Switch whose data input and predicate input are the same
// tensor. Consumers of port 0 (`false_port`) are re-pointed at a new boolean
// Const false, and consumers of port 1 (`true_port`) at a new Const true. Each
// constant is anchored by a control dependency on the corresponding Switch
// port, built via AddControlDependency.
//
// Returns true iff the graph was rewritten.
bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
  if (node->op() == "Switch" && node->input(0) == node->input(1) &&
      !OptimizedNodeExists(*node, "_const_false") &&
      !OptimizedNodeExists(*node, "_const_true")) {
    bool already_optimized = true;
    // If the optimization was already applied, the switch would have exactly
    // one Identity node consuming each of its outputs, each without any
    // non-control outputs.
    // NOTE(review): already_optimized is only cleared when there are exactly
    // two fanouts, so with any other fanout count the rewrite is skipped.
    // Confirm that this conservatism is intended.
    auto fanouts = node_map_->GetOutputs(node->name());
    if (fanouts.size() == 2) {
      for (NodeDef* fanout : fanouts) {
        if ((!IsIdentity(*fanout) && !IsIdentityNSingleInput(*fanout)) ||
            NumNonControlOutputs(*fanout, *node_map_) > 0) {
          already_optimized = false;
          break;
        }
      }
    }
    Tensor false_t(DT_BOOL, TensorShape({}));
    Tensor true_t(DT_BOOL, TensorShape({}));
    // Make sure we don't proceed if this switch node was already optimized.
    if (!already_optimized && SetTensorValue(DT_BOOL, true, &true_t).ok() &&
        SetTensorValue(DT_BOOL, false, &false_t).ok()) {
      // Copy the set of consumers of the switch as they will be manipulated
      // below. The snapshot also excludes any Identity anchors that
      // AddControlDependency may attach to the switch later on, so those are
      // not rewired.
      const std::set<NodeDef*>& consumer_set =
          node_map_->GetOutputs(node->name());
      std::vector<NodeDef*> consumers(consumer_set.begin(), consumer_set.end());
      // Sort by name so the rewrite order is deterministic (the set is keyed
      // by pointer).
      std::sort(consumers.begin(), consumers.end(),
                [](const NodeDef* n1, const NodeDef* n2) {
                  return n1->name() < n2->name();
                });
      // Create constant false & true nodes.
      // NOTE(review): if CreateNodeDef fails, the already-added node is left
      // in optimized_graph.
      NodeDef* false_node = optimized_graph->add_node();
      false_node->set_name(OptimizedNodeName(*node, "_const_false"));
      if (!CreateNodeDef(false_node->name(), TensorValue(&false_t), false_node)
               .ok()) {
        return false;
      }
      false_node->set_device(node->device());

      NodeDef* true_node = optimized_graph->add_node();
      true_node->set_name(OptimizedNodeName(*node, "_const_true"));
      if (!CreateNodeDef(true_node->name(), TensorValue(&true_t), true_node)
               .ok()) {
        return false;
      }
      true_node->set_device(node->device());

      // Add controls from the switch ports to the constants, and connect the
      // constants to the original switch outputs.
      // Port 0 is referenced by the bare node name. An explicit "name:0"
      // input would not match false_port below.
      const string false_port = node->name();
      const string true_port = strings::StrCat(node->name(), ":1");
      const string false_ctrl_dep =
          AddControlDependency(false_port, optimized_graph, node_map_.get());
      false_node->add_input(false_ctrl_dep);
      const string true_ctrl_dep =
          AddControlDependency(true_port, optimized_graph, node_map_.get());
      true_node->add_input(true_ctrl_dep);

      // Register the new constants and their control-dependency edges.
      node_map_->AddNode(false_node->name(), false_node);
      node_map_->AddNode(true_node->name(), true_node);
      node_map_->AddOutput(NodeName(false_ctrl_dep), false_node->name());
      node_map_->AddOutput(NodeName(true_ctrl_dep), true_node->name());

      // Re-point every exact reference to a switch port at the matching
      // constant, keeping the node map in sync.
      for (NodeDef* consumer : consumers) {
        for (int i = 0; i < consumer->input_size(); ++i) {
          const string& input = consumer->input(i);
          if (input == false_port) {
            consumer->set_input(i, false_node->name());
            node_map_->UpdateInput(consumer->name(), false_port,
                                   false_node->name());
          } else if (input == true_port) {
            consumer->set_input(i, true_node->name());
            node_map_->UpdateInput(consumer->name(), true_port,
                                   true_node->name());
          }
        }
      }
      return true;
    }
  }
  return false;
}
2462 
IsReductionCandidateForSimplification(const NodeDef & node,const GraphProperties & properties,TensorShapeProto * input_tensor_shape,TensorShapeProto * output_tensor_shape,bool * is_single_element_op) const2463 bool ConstantFolding::IsReductionCandidateForSimplification(
2464     const NodeDef& node, const GraphProperties& properties,
2465     TensorShapeProto* input_tensor_shape, TensorShapeProto* output_tensor_shape,
2466     bool* is_single_element_op) const {
2467   // Ensure its an appropriate Reduce node.
2468   if (!IsReduction(node) || node.input_size() < 2) {
2469     return false;
2470   }
2471   // Ensure that the axes to reduce by are constant.
2472   NodeDef* reductions_indices = node_map_->GetNode(node.input(1));
2473   if (!IsReallyConstant(*reductions_indices)) {
2474     return false;
2475   }
2476 
2477   // Get the properties of the input & output tensors and check if they both
2478   // contain a single element.
2479   if (!properties.HasInputProperties(node.name()) ||
2480       !properties.HasOutputProperties(node.name())) {
2481     return false;
2482   }
2483   const auto& input_props = properties.GetInputProperties(node.name())[0];
2484   const auto& output_props = properties.GetOutputProperties(node.name())[0];
2485   if (!input_props.has_shape() || input_props.shape().unknown_rank() ||
2486       !output_props.has_shape() || output_props.shape().unknown_rank()) {
2487     return false;
2488   }
2489   *input_tensor_shape = input_props.shape();
2490   *output_tensor_shape = output_props.shape();
2491   for (int i = 0; i < input_tensor_shape->dim_size(); ++i) {
2492     if (input_tensor_shape->dim(i).size() < 0) {
2493       return false;
2494     }
2495   }
2496   for (int i = 0; i < output_tensor_shape->dim_size(); ++i) {
2497     if (output_tensor_shape->dim(i).size() < 0) {
2498       return false;
2499     }
2500   }
2501   const int input_num_elements =
2502       TensorShape(*input_tensor_shape).num_elements();
2503   const int output_num_elements =
2504       TensorShape(*output_tensor_shape).num_elements();
2505   *is_single_element_op = input_num_elements == 1 && output_num_elements == 1;
2506 
2507   return true;
2508 }
2509 
IsReductionSimplifiableToIdentity(const NodeDef & node,const TensorShapeProto & input_shape,bool keep_dims,const TensorVector & reduction_indices_vector) const2510 bool ConstantFolding::IsReductionSimplifiableToIdentity(
2511     const NodeDef& node, const TensorShapeProto& input_shape, bool keep_dims,
2512     const TensorVector& reduction_indices_vector) const {
2513   int output_size = reduction_indices_vector[0]->NumElements();
2514   if (output_size == 0) {
2515     return true;
2516   }
2517 
2518   if (!keep_dims) {
2519     return false;
2520   }
2521   bool simplifiable = true;
2522   for (int i = 0; i < output_size; ++i) {
2523     int64 dim;
2524     if (reduction_indices_vector[0]->dtype() == DT_INT32) {
2525       dim = reduction_indices_vector[0]->flat<int32>()(i);
2526     } else {
2527       dim = reduction_indices_vector[0]->flat<int64>()(i);
2528     }
2529     if (dim < 0) {
2530       dim += input_shape.dim_size();
2531     }
2532     if (dim < 0 || dim >= input_shape.dim_size() ||
2533         input_shape.dim(dim).size() != 1) {
2534       simplifiable = false;
2535       break;
2536     }
2537   }
2538   return simplifiable;
2539 }
2540 
SimplifyReduction(GraphDef * optimized_graph,const GraphProperties & properties,NodeDef * node)2541 bool ConstantFolding::SimplifyReduction(GraphDef* optimized_graph,
2542                                         const GraphProperties& properties,
2543                                         NodeDef* node) {
2544   bool is_single_element_op = false;
2545   TensorShapeProto input_tensor_shape, output_tensor_shape;
2546   if (!IsReductionCandidateForSimplification(
2547           *node, properties, &input_tensor_shape, &output_tensor_shape,
2548           &is_single_element_op)) {
2549     return false;
2550   }
2551 
2552   // Get the reduction indices.
2553   string reduction_indices_input = node->input(1);
2554   NodeDef* reduction_indices = node_map_->GetNode(reduction_indices_input);
2555   TensorVector reduction_indices_vector;
2556   auto outputs_cleanup = gtl::MakeCleanup([&reduction_indices_vector] {
2557     for (const auto& out : reduction_indices_vector) {
2558       delete out.tensor;
2559     }
2560   });
2561   if (!EvaluateNode(*reduction_indices, TensorVector(),
2562                     &reduction_indices_vector)
2563            .ok() ||
2564       reduction_indices_vector.size() != 1) {
2565     return false;
2566   }
2567 
2568   bool keep_dims =
2569       node->attr().count("keep_dims") > 0 && node->attr().at("keep_dims").b();
2570   bool simplifiable_to_reshape =
2571       is_single_element_op && !keep_dims && (node->attr().count("T") > 0);
2572   bool simplifiable_to_identity = IsReductionSimplifiableToIdentity(
2573       *node, input_tensor_shape, keep_dims, reduction_indices_vector);
2574 
2575   if (simplifiable_to_reshape) {
2576     // Const node to output shape.
2577     const int new_num_dimensions = output_tensor_shape.dim_size();
2578     Tensor tensor(DT_INT32, TensorShape({new_num_dimensions}));
2579     for (int i = 0; i < new_num_dimensions; i++) {
2580       tensor.flat<int>()(i) = 1;
2581     }
2582     TensorValue shape_value(&tensor);
2583     NodeDef* shape_node = optimized_graph->add_node();
2584     if (!CreateNodeDef(OptimizedNodeName(*node, "_shape_const"), shape_value,
2585                        shape_node)
2586              .ok()) {
2587       return false;
2588     }
2589     shape_node->set_device(node->device());
2590     node_map_->AddNode(shape_node->name(), shape_node);
2591     // Control dependency to ensure shape_node is in the correct frame.
2592     shape_node->add_input(AsControlDependency(reduction_indices_input));
2593     node_map_->AddOutput(NodeName(reduction_indices_input), shape_node->name());
2594     // Optimize node to Reshape.
2595     node->set_op("Reshape");
2596     node_map_->UpdateInput(node->name(), node->input(1), shape_node->name());
2597     node->set_input(1, shape_node->name());
2598     node->mutable_attr()->erase("keep_dims");
2599     node->mutable_attr()->erase("Tidx");
2600     AttrValue attr_type_indices;
2601     attr_type_indices.set_type(DT_INT32);
2602     (*node->mutable_attr())["Tshape"] = attr_type_indices;
2603     return true;
2604   } else if (simplifiable_to_identity) {
2605     // Replace the reduction node with an identity node, that can be further
2606     // optimized by the model pruner.
2607     DataType output_type;
2608     if (node->attr().count("T") != 0) {
2609       output_type = node->attr().at("T").type();
2610     } else {
2611       // This is an 'any' or 'all' reduction. The output is always boolean.
2612       output_type = DT_BOOL;
2613     }
2614     node->set_op("Identity");
2615     node->clear_attr();
2616     (*node->mutable_attr())["T"].set_type(output_type);
2617     *node->mutable_input(1) = AsControlDependency(node->input(1));
2618     return true;
2619   }
2620   return false;
2621 }
2622 
SimplifyReshape(const GraphProperties & properties,bool use_shape_info,NodeDef * node)2623 bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
2624                                       bool use_shape_info, NodeDef* node) {
2625   if (!use_shape_info || node->attr().count("T") == 0 ||
2626       !IsSimplifiableReshape(*node, properties)) {
2627     return false;
2628   }
2629   DataType output_type = node->attr().at("T").type();
2630   node->set_op("Identity");
2631   node->clear_attr();
2632   (*node->mutable_attr())["T"].set_type(output_type);
2633   *node->mutable_input(1) = AsControlDependency(node->input(1));
2634   return true;
2635 }
2636 
// Rewrites arithmetic nodes with an operand that IsZeros/IsOnes reports as an
// all-zeros or all-ones constant:
//   1 * y, 0 + y           -> ReplaceOperationWithSnapshot(1, ...)
//   0 - y                  -> ReplaceSubtractionFromZeroByNegation
//   1 / y (float/complex)  -> ReplaceDivisionOfOnesByReciprocal
//   x * 1, x / 1, x +/- 0  -> ReplaceOperationWithSnapshot(0, ...)
//   x OR 1, 1 OR y         -> constant of value 1 (fully defined shape only)
//   0 * y, x * 0, MatMul with a zero operand, and (AGGRESSIVE only) 0 / y
//                          -> constant 0 (fully defined shape), else
//                             forward the zero operand via Identity.
// "*" includes LogicalAnd; "+" includes BiasAdd and LogicalOr.
// An operand is only forwarded when its shape symbolically equals the output
// shape, so broadcasting behavior is not lost.
//
// Params:
//   properties: inferred input/output shapes for the node.
//   use_shape_info: when false, no rewrite is attempted.
//   optimized_graph: graph handed to the Replace* helpers.
//   node: node rewritten in place.
//   success: set to true iff the node was rewritten.
// Returns InvalidArgument if either operand is missing from the node map,
// plus any error from CheckAttrExists or the Replace* helpers.
Status ConstantFolding::SimplifyArithmeticOperations(
    const GraphProperties& properties, bool use_shape_info,
    GraphDef* optimized_graph, NodeDef* node, bool* success) {
  *success = false;
  const bool is_mul = IsMul(*node) || IsLogicalAnd(*node);
  const bool is_matmul = IsMatMul(*node);
  const bool is_quantized_matmul = IsQuantizedMatMul(*node);
  const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
  const bool is_sub = IsSub(*node);
  const bool is_any_div = IsAnyDiv(*node);
  // Simplify arithmetic operations with ones or zeros.
  if (use_shape_info &&
      (is_mul || is_matmul || is_add || is_sub || is_any_div) &&
      properties.HasInputProperties(node->name()) &&
      properties.HasOutputProperties(node->name())) {
    const NodeDef* x = node_map_->GetNode(node->input(0));
    const NodeDef* y = node_map_->GetNode(node->input(1));
    if (x == nullptr || y == nullptr) {
      return errors::InvalidArgument("Invalid inputs to node: ",
                                     node->DebugString());
    }
    const TensorShapeProto& output_shape =
        properties.GetOutputProperties(node->name())[0].shape();

    // Simplify element-wise multiplication by ones or addition/subtraction
    // of zeros.
    const TensorShapeProto& y_shape =
        properties.GetInputProperties(node->name())[1].shape();
    const bool x_is_zero = IsZeros(*x);
    // IsOnes is only evaluated when x is not already known to be zeros.
    const bool x_is_one = x_is_zero ? false : IsOnes(*x);
    const bool y_matches_output_shape =
        ShapesSymbolicallyEqual(output_shape, y_shape);
    if (y_matches_output_shape &&
        ((is_mul && x_is_one) || (is_add && x_is_zero))) {
      // 1 * y = y or 0 + y = y.
      ReplaceOperationWithSnapshot(1, properties, node, optimized_graph);
      *success = true;
      return Status::OK();
    }

    if (y_matches_output_shape && (is_sub && x_is_zero)) {
      // Replace 0 - y with Neg(y).
      ReplaceSubtractionFromZeroByNegation(node, optimized_graph);
      *success = true;
      return Status::OK();
    }

    // Replace 1 / y with Reciprocal op.
    // Restricted to floating/complex types: integer 1 / y is not a reciprocal.
    if (y_matches_output_shape && is_any_div && x_is_one) {
      TF_RETURN_IF_ERROR(CheckAttrExists(*node, "T"));
      DataType type = node->attr().at("T").type();
      if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) {
        ReplaceDivisionOfOnesByReciprocal(node, optimized_graph);
        *success = true;
        return Status::OK();
      }
    }

    const TensorShapeProto& x_shape =
        properties.GetInputProperties(node->name())[0].shape();
    const bool y_is_zero = IsZeros(*y);
    const bool y_is_one = y_is_zero ? false : IsOnes(*y);
    const bool x_matches_output_shape =
        ShapesSymbolicallyEqual(output_shape, x_shape);
    if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
                                   ((is_add || is_sub) && y_is_zero))) {
      // x * 1 = x or x / 1 = x or x +/- 0 = x
      ReplaceOperationWithSnapshot(0, properties, node, optimized_graph);
      *success = true;
      return Status::OK();
    }

    // x OR true = true OR y = true.
    // On success this only records updated_graph; the function reports
    // success at the end, after the zero-operand checks below (which do not
    // apply to LogicalOr).
    bool updated_graph = false;
    const PartialTensorShape shp(output_shape);
    if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) {
      bool replace_succeed = false;
      Status replace_op_status = ReplaceOperationWithConstant(
          1, properties, output_shape, node, optimized_graph, &replace_succeed);
      if (!replace_op_status.ok()) {
        return replace_op_status;
      } else if (replace_succeed) {
        updated_graph = true;
      }
    }

    // Simplify multiplication and matmul by zeros.
    // Also optimize zeros divided by a tensor, but only if we are in
    // aggressive mode, since we might get rid of divisions by zero.
    const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
    bool optimize_zeros_divided_by_y = is_any_div && x_is_zero && is_aggressive;
    if ((x_is_zero || y_is_zero) &&
        (is_mul || is_matmul || optimize_zeros_divided_by_y)) {
      if (shp.IsFullyDefined()) {
        bool replace_succeed = false;
        Status replace_op_status =
            ReplaceOperationWithConstant(0, properties, output_shape, node,
                                         optimized_graph, &replace_succeed);
        if (!replace_op_status.ok()) {
          return replace_op_status;
        } else if (replace_succeed) {
          // A QuantizedMatMul rewritten to a constant also needs its
          // min/max outputs supplied by new constant nodes.
          if (is_quantized_matmul) {
            TF_RETURN_IF_ERROR(
                AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph));
          }
          *success = true;
          return Status::OK();
        }
      }
      // Even if an input shape is only partially known, we may know that it
      // matches the output shape and thus forward the corresponding zero
      // input.
      if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) {
        ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
        *success = true;
        return Status::OK();
      } else if (is_mul && y_is_zero && y_matches_output_shape) {
        ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
        *success = true;
        return Status::OK();
      }
    }
    if (updated_graph) {
      *success = true;
      return Status::OK();
    }
  }
  // Redundant with the assignment at the top; kept for clarity.
  *success = false;
  return Status::OK();
}
2767 
ReduceDivToReciprocalMul(GraphDef * optimized_graph,NodeDef * node)2768 bool ConstantFolding::ReduceDivToReciprocalMul(GraphDef* optimized_graph,
2769                                                NodeDef* node) {
2770   // Strength reduce floating point division by a constant Div(x, const) to
2771   // multiplication by the reciprocal Mul(x, Reciprocal(const)). This in turn
2772   // will be constant folded to Mul(x, 1.0/const).
2773   if (node->input_size() >= 2 && (IsRealDiv(*node) || IsDiv(*node))) {
2774     const string& const_input = node->input(1);
2775     const NodeDef* denom = node_map_->GetNode(const_input);
2776     CHECK(denom != nullptr);
2777     if (!IsReallyConstant(*denom)) {
2778       return false;
2779     }
2780     if (node->attr().count("T") == 0) {
2781       return false;
2782     }
2783     DataType type = node->attr().at("T").type();
2784     if (IsDiv(*node) &&
2785         !(DataTypeIsFloating(type) || DataTypeIsComplex(type))) {
2786       return false;
2787     }
2788     // Insert new reciprocal op and change node from Div to Mul.
2789     NodeDef* reciprocal_node = optimized_graph->add_node();
2790     reciprocal_node->set_name(OptimizedNodeName(*node, "_recip"));
2791     reciprocal_node->set_op("Reciprocal");
2792     reciprocal_node->set_device(node->device());
2793     node->set_op("Mul");
2794     // Re-wire inputs and outputs.
2795     reciprocal_node->add_input(const_input);
2796     (*reciprocal_node->mutable_attr())["T"].set_type(type);
2797     node->set_input(1, reciprocal_node->name());
2798     node_map_->AddNode(reciprocal_node->name(), reciprocal_node);
2799     node_map_->UpdateOutput(node->name(), const_input, reciprocal_node->name());
2800     return true;
2801   }
2802 
2803   return false;
2804 }
2805 
ConstantPushDown(GraphDef * optimized_graph,NodeDef * node)2806 bool ConstantFolding::ConstantPushDown(GraphDef* optimized_graph,
2807                                        NodeDef* node) {
2808   // Consider the transformation
2809   //
2810   //                      +                +       = parent
2811   //                     / \              / \
2812   //                    C   +    -- >    X   +     = children
2813   //                       / \              / \
2814   //                      X   Y            C   Y   = leaves
2815   //
2816   // where C is constant and X is non-constant, and '+' denotes an
2817   // associative and commutative operator like addition or multiplication.
2818   // This optimization pushes constants down in the tree to canonicalize it.
2819   // Moreoever, in cases where the child node has a second constant input Y
2820   // we will create a leaf node that can be folded, e.g.
2821   //
2822   //    Add(C1, Add(C2, X)) -> Add(X, Add(C1, C2)) -> Add(X, C1 + C2)
2823   //
2824   // TODO(rmlarsen): Handle non-associative/non-commutative operators like
2825   // subtraction and division, as well as mixed subtraction/addition,
2826   // division/multiplication.
2827   // Don't touch BiasAdd since they can't handle vectors as their first
2828   // inputs.
2829   if (has_fetch_ && (IsAdd(*node) || IsMul(*node)) &&
2830       NumNonControlInputs(*node) == 2) {
2831     NodeDef* left_child = node_map_->GetNode(node->input(0));
2832     NodeDef* right_child = node_map_->GetNode(node->input(1));
2833     // One child must be constant, and the other the same op as the parent.
2834     if (node->op() != left_child->op() && node->op() != right_child->op()) {
2835       return false;
2836     }
2837     const bool left_child_is_constant = IsReallyConstant(*left_child);
2838     const bool right_child_is_constant = IsReallyConstant(*right_child);
2839     if (!left_child_is_constant && !right_child_is_constant) {
2840       return false;
2841     }
2842     if (node->device() != left_child->device() ||
2843         node->device() != right_child->device()) {
2844       return false;
2845     }
2846     NodeDef* op_child_node = left_child_is_constant ? right_child : left_child;
2847     NodeDef* const_child_node =
2848         left_child_is_constant ? left_child : right_child;
2849     // Make sure that it is safe to change the value of the child node->
2850     if (op_child_node->input_size() < 2 ||
2851         nodes_to_preserve_.find(op_child_node->name()) !=
2852             nodes_to_preserve_.end() ||
2853         NumNonControlOutputs(*op_child_node, *node_map_) > 1) {
2854       return false;
2855     }
2856 
2857     // Identify the nodes to swap.
2858     NodeDef* left_leaf = node_map_->GetNode(op_child_node->input(0));
2859     NodeDef* right_leaf = node_map_->GetNode(op_child_node->input(1));
2860     const bool left_leaf_is_constant = IsReallyConstant(*left_leaf);
2861     const bool right_leaf_is_constant = IsReallyConstant(*right_leaf);
2862     if (left_leaf_is_constant && right_leaf_is_constant) {
2863       // Child is already foldable, leave it alone.
2864       return false;
2865     }
2866     const int non_const_leaf_input = left_leaf_is_constant ? 1 : 0;
2867     const int parent_const_input = left_child_is_constant ? 0 : 1;
2868     const auto& child_output = node_map_->GetOutputs(op_child_node->name());
2869     if (child_output.find(const_child_node) != child_output.end()) {
2870       // If there is a control edge from the child op to C, the transformation
2871       // would create a cycle in the graph. We know that it must be a control
2872       // edge. We can replace such a control edge with a control edge from A
2873       // to C.
2874       CHECK(MaybeRemoveControlInput(op_child_node->name(), const_child_node,
2875                                     optimized_graph, node_map_.get()));
2876       string other_leaf_input = left_leaf_is_constant ? op_child_node->input(0)
2877                                                       : op_child_node->input(1);
2878       MaybeAddControlInput(other_leaf_input, const_child_node, optimized_graph,
2879                            node_map_.get());
2880     }
2881 
2882     // Swap the constant child with a non-constant leaf node.
2883     node_map_->UpdateInput(node->name(), node->input(parent_const_input),
2884                            op_child_node->input(non_const_leaf_input));
2885     node_map_->UpdateInput(op_child_node->name(),
2886                            op_child_node->input(non_const_leaf_input),
2887                            node->input(parent_const_input));
2888     std::swap(*node->mutable_input(parent_const_input),
2889               *op_child_node->mutable_input(non_const_leaf_input));
2890     return true;
2891   }
2892   return false;
2893 }
2894 
// Rewrites Mul(ConvND(X, C1), C2) into ConvND(X, Mul(C1, C2)) so that the
// product of the two constants can later be constant folded.
//
// Preconditions checked below (any failure returns false, graph untouched):
//  - `node` is a Mul with exactly two non-control inputs, one of which is
//    really constant and the other a Conv2D/Conv3D on the same device;
//  - the Conv has a single non-control consumer, is not preserved, and has
//    exactly one constant input among its first two inputs;
//  - inferred output shapes of the Mul and the Conv are symbolically equal
//    and the constant's shape is accepted by
//    IsValidConstShapeForMulConvPushDown for the Conv's data_format;
//  - the "merged_input"-prefixed name for the new Mul is not already taken.
//
// The rewrite is done in place by swapping names: the Conv node takes over the
// Mul's name, and the Mul node is renamed and becomes the Conv's constant
// operand. Returns true iff the graph was modified.
bool ConstantFolding::MulConvPushDown(GraphDef* optimized_graph, NodeDef* node,
                                      const GraphProperties& properties) {
  // Push down multiplication on ConvND.
  //                       *                  ConvND
  //                     /   \                /    \
  //                 ConvND  C2    -- >      X      *
  //                  / \                          / \
  //                 X  C1                       C1  C2
  //
  // where C1 and C2 are constants and X is non-constant.
  if (!IsMul(*node) || NumNonControlInputs(*node) != 2) return false;

  // NOTE(review): both Mul inputs are dereferenced without a null check;
  // this assumes node_map_ always knows them.
  NodeDef* mul_left_child = node_map_->GetNode(node->input(0));
  NodeDef* mul_right_child = node_map_->GetNode(node->input(1));
  // One child must be constant, and the second must be Conv op.
  const bool left_child_is_constant = IsReallyConstant(*mul_left_child);
  const bool right_child_is_constant = IsReallyConstant(*mul_right_child);
  if (!left_child_is_constant && !right_child_is_constant) {
    return false;
  }
  NodeDef* conv_node =
      left_child_is_constant ? mul_right_child : mul_left_child;
  if (!IsConv2D(*conv_node) && !IsConv3D(*conv_node)) {
    return false;
  }
  // Only rewrite when all three nodes share a device.
  if (node->device() != mul_left_child->device() ||
      node->device() != mul_right_child->device()) {
    return false;
  }

  // Make sure that it is safe to change the value of the convolution
  // output.
  if (conv_node->input_size() < 2 ||
      NumNonControlOutputs(*conv_node, *node_map_) > 1 ||
      nodes_to_preserve_.find(conv_node->name()) != nodes_to_preserve_.end()) {
    return false;
  }

  // Identify the nodes to swap.
  NodeDef* conv_left_child = node_map_->GetNode(conv_node->input(0));
  NodeDef* conv_right_child = node_map_->GetNode(conv_node->input(1));
  const bool conv_left_is_constant = IsReallyConstant(*conv_left_child);
  const bool conv_right_is_constant = IsReallyConstant(*conv_right_child);
  if (!conv_left_is_constant && !conv_right_is_constant) {
    // At least one of the convolution inputs should be constant.
    return false;
  }
  if (conv_left_is_constant && conv_right_is_constant) {
    // Leverage regular constant folding to handle this.
    return false;
  }
  // Moving the multiply inside the Conv is only valid if it does not change
  // the output shape (i.e. the Mul does not broadcast the Conv output).
  const auto& mul_props = properties.GetOutputProperties(node->name());
  const auto& conv_props = properties.GetOutputProperties(conv_node->name());
  if (mul_props.empty() || conv_props.empty()) {
    return false;
  }
  const auto& mul_shape = mul_props[0].shape();
  const auto& conv_shape = conv_props[0].shape();
  if (!ShapesSymbolicallyEqual(mul_shape, conv_shape)) {
    return false;
  }

  // Shape of the Conv's second input, passed below as the filter shape.
  const auto& input_props = properties.GetInputProperties(conv_node->name());
  if (input_props.size() < 2) {
    return false;
  }
  const auto& filter_shape = input_props[1].shape();

  NodeDef* const_node =
      left_child_is_constant ? mul_left_child : mul_right_child;
  const auto& const_props = properties.GetOutputProperties(const_node->name());
  if (const_props.empty()) {
    return false;
  }
  const auto& const_shape = const_props[0].shape();
  // NOTE(review): attr().at() requires "data_format" to be present on the
  // Conv node; confirm default attrs are always populated here.
  if (!IsValidConstShapeForMulConvPushDown(
          conv_node->attr().at("data_format").s(), filter_shape, const_shape)) {
    return false;
  }

  string mul_new_name = AddPrefixToNodeName("merged_input", conv_node->name());
  if (node_map_->NodeExists(mul_new_name)) {
    return false;
  }
  // Make sure we don't introduce loops in the graph by removing control
  // dependencies from the conv2d node to c2.
  string conv_const_input =
      conv_left_is_constant ? conv_node->input(0) : conv_node->input(1);
  if (MaybeRemoveControlInput(conv_node->name(), const_node, optimized_graph,
                              node_map_.get())) {
    // Add a control dep from c1 to c2 to ensure c2 is in the right frame
    MaybeAddControlInput(conv_const_input, const_node, optimized_graph,
                         node_map_.get());
  }

  // Swap identities: the Conv now carries the Mul's name, so consumers that
  // referenced the Mul by name read the Conv output; the Mul node is renamed
  // and wired in as the Conv's constant operand.
  conv_node->set_name(node->name());
  node->set_name(mul_new_name);
  // NOTE(review): the "old input" passed to UpdateInput is taken from `node`
  // (the Mul), not from `conv_node`'s replaced input; confirm the resulting
  // node_map_ bookkeeping is intended.
  if (conv_left_is_constant) {
    node_map_->UpdateInput(conv_node->name(), node->input(0), mul_new_name);
    conv_node->set_input(0, mul_new_name);
  } else {
    node_map_->UpdateInput(conv_node->name(), node->input(1), mul_new_name);
    conv_node->set_input(1, mul_new_name);
  }
  // The Mul's non-constant input (previously the Conv) is replaced by the
  // Conv's former constant input C1, so the Mul now computes C1 * C2.
  NodeDef* conv_const_node =
      conv_left_is_constant ? conv_left_child : conv_right_child;
  if (left_child_is_constant) {
    node->set_input(1, conv_const_node->name());
  } else {
    node->set_input(0, conv_const_node->name());
  }
  // NOTE(review): only the new name is registered here; map entries keyed by
  // the old names are not re-pointed in this function.
  node_map_->AddNode(mul_new_name, node);

  return true;
}
3010 
PartialConstPropThroughIdentityN(NodeDef * node)3011 bool ConstantFolding::PartialConstPropThroughIdentityN(NodeDef* node) {
3012   // Partial constant propagation through IdentityN.
3013   if ((IsIdentityN(*node) || IsIdentityNSingleInput(*node)) &&
3014       NumNonControlInputs(*node) > 0) {
3015     const std::set<NodeDef*>& tmp = node_map_->GetOutputs(node->name());
3016     const std::vector<NodeDef*> consumers(tmp.begin(), tmp.end());
3017     bool updated_graph = false;
3018     for (int input_idx = 0; input_idx < node->input_size(); ++input_idx) {
3019       const string& input = node->input(input_idx);
3020       if (IsControlInput(input)) {
3021         break;
3022       }
3023       const NodeDef* input_node = node_map_->GetNode(NodeName(input));
3024       if (input_node == nullptr) {
3025         LOG(ERROR) << "Bad input: " << input;
3026         break;
3027       }
3028       // Forward constant inputs to outputs and add a control dependency on
3029       // the IdentityN node.
3030       if (IsReallyConstant(*input_node)) {
3031         // Update each consumer.
3032         for (NodeDef* consumer : consumers) {
3033           bool add_dep = false;
3034           for (int consumer_input_idx = 0;
3035                consumer_input_idx < consumer->input_size();
3036                ++consumer_input_idx) {
3037             const string& consumer_input = consumer->input(consumer_input_idx);
3038             if (IsControlInput(consumer_input)) {
3039               break;
3040             }
3041             int output_idx;
3042             const string input_node_name =
3043                 ParseNodeName(consumer_input, &output_idx);
3044             if (input_node_name == node->name() && output_idx == input_idx) {
3045               consumer->set_input(consumer_input_idx, input);
3046               // We will keep the input from IdentityN through a control
3047               // dependency, so we only need to add the consumer as an output
3048               // for the constant input node.
3049               node_map_->AddOutput(NodeName(input), consumer->name());
3050               add_dep = true;
3051             }
3052           }
3053           if (add_dep) {
3054             consumer->add_input(AsControlDependency(node->name()));
3055             updated_graph = true;
3056           }
3057         }
3058       }
3059     }
3060 
3061     if (updated_graph) {
3062       for (NodeDef* consumer : consumers) {
3063         DedupControlInputs(consumer);
3064       }
3065       return true;
3066     }
3067   }
3068   return false;
3069 }
3070 
PartialAssocOpConstFolding(GraphDef * optimized_graph,GraphProperties * properties,NodeDef * node)3071 bool ConstantFolding::PartialAssocOpConstFolding(GraphDef* optimized_graph,
3072                                                  GraphProperties* properties,
3073                                                  NodeDef* node) {
3074   // Partial constant folding for associative operators:
3075   // Split AddN/AccumulateNV2 to enable partial
3076   // folding of ops when more than one but not all inputs are constant.
3077   // For AddN and AccumulateNV2, we may furthermore reorder inputs, since
3078   // addition is commutative.
3079   const int num_non_control_inputs = NumNonControlInputs(*node);
3080   if (IsAggregate(*node) && IsCommutative(*node) &&
3081       num_non_control_inputs > 2) {
3082     const int num_control_inputs = node->input_size() - num_non_control_inputs;
3083     std::vector<int> const_inputs;
3084     std::vector<int> nonconst_inputs;
3085     for (int i = 0; i < node->input_size(); ++i) {
3086       const string& input = node->input(i);
3087       const NodeDef* input_node = node_map_->GetNode(NodeName(input));
3088       CHECK(input_node != nullptr) << input;
3089       if (!IsControlInput(input) && IsReallyConstant(*input_node)) {
3090         const_inputs.push_back(i);
3091       } else {
3092         // Non-const and control inputs.
3093         nonconst_inputs.push_back(i);
3094       }
3095     }
3096     // Promote AccumulateNV2 with all constant inputs to AddN, since it is
3097     // a fake node that cannot be constant folded by itself.
3098     if (const_inputs.size() == num_non_control_inputs &&
3099         node->op() == "AccumulateNV2") {
3100       node->set_op("AddN");
3101       node->mutable_attr()->erase("shape");
3102       return true;
3103     }
3104     const string new_node_name = OptimizedNodeName(
3105         *node, strings::StrCat("_partial_split_", const_inputs.size()));
3106     if (1 < const_inputs.size() &&
3107         const_inputs.size() < num_non_control_inputs &&
3108         !node_map_->NodeExists(new_node_name)) {
3109       NodeDef* added_node = optimized_graph->add_node();
3110       *added_node = *node;
3111       // Always use AddN for the constant node, since AccumulateNV2 is a fake
3112       // node that cannot be constant folded, since it does not have a kernel.
3113       added_node->set_op("AddN");
3114       added_node->mutable_attr()->erase("shape");
3115       added_node->set_name(new_node_name);
3116       node_map_->AddNode(added_node->name(), added_node);
3117       added_node->clear_input();
3118       for (int i : const_inputs) {
3119         added_node->add_input(node->input(i));
3120         node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
3121                                 added_node->name());
3122       }
3123 
3124       // Overwrite the first const input with the added node.
3125       node->set_input(const_inputs[0], added_node->name());
3126       node_map_->AddOutput(added_node->name(), node->name());
3127       nonconst_inputs.push_back(const_inputs[0]);
3128       // Compact the remaining inputs to the original node.
3129       std::sort(nonconst_inputs.begin(), nonconst_inputs.end());
3130       int idx = 0;
3131       for (int i : nonconst_inputs) {
3132         if (idx != i) {
3133           node->set_input(idx, node->input(i));
3134         }
3135         ++idx;
3136       }
3137       node->mutable_input()->DeleteSubrange(nonconst_inputs.size(),
3138                                             const_inputs.size() - 1);
3139       (*node->mutable_attr())["N"].set_i(node->input_size() -
3140                                          num_control_inputs);
3141       properties->ClearInputProperties(node->name());
3142       (*added_node->mutable_attr())["N"].set_i(const_inputs.size());
3143       return true;
3144     }
3145   }
3146   return false;
3147 }
3148 
PartialConcatConstFolding(GraphDef * optimized_graph,GraphProperties * properties,NodeDef * node)3149 bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
3150                                                 GraphProperties* properties,
3151                                                 NodeDef* node) {
3152   // Partial constant folding for Concat which is not commutative, so
3153   // we have to preserve order and can only push consecutive runs of constant
3154   // inputs into sub-nodes.
3155   const int num_non_control_inputs = NumNonControlInputs(*node);
3156   if (IsConcat(*node) && num_non_control_inputs > 3 &&
3157       node->name().rfind("_partial_split_") == string::npos) {
3158     int axis_arg = -1;
3159     int begin = 0;
3160     int end = num_non_control_inputs;
3161     if (node->op() == "Concat") {
3162       begin = 1;
3163       axis_arg = 0;
3164     } else if (node->op() == "ConcatV2") {
3165       end = num_non_control_inputs - 1;
3166       axis_arg = num_non_control_inputs - 1;
3167     } else {
3168       return false;
3169     }
3170 
3171     const NodeDef* axis_arg_node =
3172         node_map_->GetNode(NodeName(node->input(axis_arg)));
3173     if (axis_arg_node == nullptr || !IsReallyConstant(*axis_arg_node)) {
3174       // We cannot constant fold Concat unless we the axis argument is
3175       // constant. Skip node.
3176       return false;
3177     }
3178 
3179     // We search for consecutive runs of constant inputs in the range
3180     // [begin:end[ and push then down into child nodes.
3181     std::vector<std::pair<int, int>> constant_input_runs;
3182     int first = begin;
3183     int last = begin;
3184     while (last < end) {
3185       while (first < end && !IsReallyConstant(*node_map_->GetNode(
3186                                 NodeName(node->input(first))))) {
3187         ++first;
3188       }
3189       // Invariant: node[first] is constant || first >= end.
3190       last = first + 1;
3191       while (last < end && IsReallyConstant(*node_map_->GetNode(
3192                                NodeName(node->input(last))))) {
3193         ++last;
3194       }
3195       // Invariant: node[last] is not constant || last >= end
3196       // Discard intervals shorter than 2 elements.
3197       if (first < end && (last - first) > 1) {
3198         constant_input_runs.emplace_back(first, last);
3199       }
3200       first = last;
3201     }
3202 
3203     // Skip if all inputs are constant, and let constant folding take over.
3204     if (constant_input_runs.size() == 1 &&
3205         constant_input_runs[0].first == begin &&
3206         constant_input_runs[0].second == end) {
3207       return false;
3208     }
3209     std::set<int> inputs_to_delete;
3210     for (auto interval : constant_input_runs) {
3211       // Push the constant inputs in the interval to a child node than can be
3212       // constant folded.
3213       const string new_node_name = OptimizedNodeName(
3214           *node, strings::StrCat("_partial_split_", interval.first));
3215       if (node_map_->NodeExists(new_node_name)) {
3216         break;
3217       }
3218       NodeDef* added_node = optimized_graph->add_node();
3219       *added_node = *node;
3220       added_node->set_name(new_node_name);
3221       node_map_->AddNode(added_node->name(), added_node);
3222       added_node->clear_input();
3223       for (int i = interval.first; i < interval.second; ++i) {
3224         added_node->add_input(node->input(i));
3225         node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
3226                                 added_node->name());
3227         if (i != interval.first) {
3228           inputs_to_delete.insert(i);
3229         }
3230       }
3231       added_node->add_input(node->input(axis_arg));
3232       (*added_node->mutable_attr())["N"].set_i(interval.second -
3233                                                interval.first);
3234       node_map_->AddOutput(NodeName(node->input(axis_arg)), added_node->name());
3235 
3236       // Overwrite the first constant input with the result of the added
3237       // child node.
3238       node->set_input(interval.first, added_node->name());
3239       node_map_->AddOutput(added_node->name(), node->name());
3240     }
3241     if (!constant_input_runs.empty()) {
3242       if (!inputs_to_delete.empty()) {
3243         // Fix up the inputs to the original node.
3244         std::vector<string> tmp(node->input().begin(), node->input().end());
3245         node->clear_input();
3246         for (int i = 0; i < tmp.size(); ++i) {
3247           if (inputs_to_delete.find(i) == inputs_to_delete.end()) {
3248             node->add_input(tmp[i]);
3249           }
3250         }
3251         (*node->mutable_attr())["N"].set_i(node->input_size() - 1);
3252         properties->ClearInputProperties(node->name());
3253       }
3254       return true;
3255     }
3256   }
3257   return false;
3258 }
3259 
MergeConcat(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)3260 bool ConstantFolding::MergeConcat(const GraphProperties& properties,
3261                                   bool use_shape_info,
3262                                   GraphDef* optimized_graph, NodeDef* node) {
3263   // We only optimize for ConcatV2.
3264   int axis;
3265   if (!use_shape_info || !GetConcatAxis(properties, node, &axis) ||
3266       nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end() ||
3267       node_map_->GetOutputs(node->name()).size() != 1) {
3268     return false;
3269   }
3270 
3271   NodeDef* parent = *node_map_->GetOutputs(node->name()).begin();
3272   int parent_axis;
3273   if (!GetConcatAxis(properties, parent, &parent_axis) || axis != parent_axis) {
3274     return false;
3275   }
3276 
3277   const int index = NumNonControlInputs(*node) - 1;
3278   auto inputs = parent->input();
3279   parent->clear_input();
3280   for (int i = 0; i < inputs.size(); ++i) {
3281     if (IsSameInput(inputs.Get(i), node->name())) {
3282       for (int j = 0; j < node->input_size(); ++j) {
3283         if (j < index) {
3284           // Input tensors (non axis), add to input list of parent.
3285           parent->add_input(node->input(j));
3286           node_map_->RemoveOutput(node->input(j), node->name());
3287           node_map_->AddOutput(node->input(j), parent->name());
3288         }
3289         // Skip j == index, which means axis tensor.
3290         if (j > index) {
3291           // Control Dependencies, push back to inputs so they can be forwarded
3292           // to parent.
3293           *inputs.Add() = node->input(j);
3294         }
3295       }
3296     } else {
3297       parent->add_input(inputs.Get(i));
3298     }
3299   }
3300   node->clear_input();
3301   node->set_op("NoOp");
3302   node->clear_attr();
3303   node_map_->RemoveNode(node->name());
3304   (*parent->mutable_attr())["N"].set_i(NumNonControlInputs(*parent) - 1);
3305 
3306   return true;
3307 }
3308 
AddQuantizedMatMulMinMaxOutConstNodes(NodeDef * node,GraphDef * optimized_graph)3309 Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
3310     NodeDef* node, GraphDef* optimized_graph) {
3311   auto add_quantized_out = [this, node, optimized_graph](
3312                                const string& out_const_name, int index) {
3313     NodeDef* out_node = optimized_graph->add_node();
3314     Tensor value(DT_FLOAT, TensorShape({}));
3315     const bool is_min = index == 1;
3316     const DataType type_attr = node->attr().at("dtype").type();
3317 
3318     value.flat<float>()(0) = is_min ? QuantizedTypeMinAsFloat(type_attr)
3319                                     : QuantizedTypeMaxAsFloat(type_attr);
3320     TF_RETURN_IF_ERROR(
3321         CreateNodeDef(out_const_name, TensorValue(&value), out_node));
3322     node_map_->AddNode(out_const_name, out_node);
3323     out_node->set_device(node->device());
3324 
3325     // Copy all inputs from node.
3326     out_node->mutable_input()->CopyFrom(node->input());
3327     for (const string& input : out_node->input()) {
3328       node_map_->AddOutput(NodeName(input), out_const_name);
3329     }
3330 
3331     // Update output nodes consuming node:index to new const node.
3332     string old_input = absl::StrCat(node->name(), ":", index);
3333     int old_node_count = 0;
3334     auto outputs = node_map_->GetOutputs(node->name());
3335     for (const auto& output : outputs) {
3336       for (int i = 0; i < output->input_size(); ++i) {
3337         if (output->input(i) == old_input) {
3338           output->set_input(i, out_const_name);
3339           node_map_->AddOutput(out_const_name, output->name());
3340         } else if (NodeName(output->input(i)) == node->name()) {
3341           ++old_node_count;
3342         }
3343       }
3344       if (old_node_count == 0) {
3345         node_map_->RemoveOutput(node->name(), output->name());
3346       }
3347     }
3348 
3349     return Status::OK();
3350   };
3351   const string min_out_const_name =
3352       OptimizedNodeName(*node, "-quantized_matmul_min_out");
3353   const string max_out_const_name =
3354       OptimizedNodeName(*node, "-quantized_matmul_max_out");
3355   if (node_map_->GetNode(min_out_const_name) == nullptr &&
3356       node_map_->GetNode(max_out_const_name) == nullptr) {
3357     TF_RETURN_IF_ERROR(add_quantized_out(min_out_const_name, 1));
3358     TF_RETURN_IF_ERROR(add_quantized_out(max_out_const_name, 2));
3359   } else {
3360     return errors::Internal(absl::Substitute(
3361         "Can't create Const for QuantizedMatMul min_out/max_out of "
3362         "node '$0' because of node name conflict",
3363         node->name()));
3364   }
3365   return Status::OK();
3366 }
3367 
RunOptimizationPass(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph)3368 Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
3369                                             const GrapplerItem& item,
3370                                             GraphDef* optimized_graph) {
3371   node_map_.reset(new NodeMap(graph_));
3372   nodes_whitelist_.clear();
3373   // Fold fetch nodes iff it has a single fanout. Note that if a fetch node
3374   // has a single fanout, it would be rewritten as a constant with the same
3375   // node name, and therefore users are still able to fetch it. This is not
3376   // the case if the node has multiple fanouts, and constant folding would
3377   // replace the node with multiple constants (each for one fanout) with
3378   // new names, and as a result users would not be able to fetch the node any
3379   // more with the original node name.
3380   for (const auto& fetch : item.fetch) {
3381     const NodeDef* fetch_node = node_map_->GetNode(fetch);
3382     if (fetch_node && NumOutputs(*fetch_node, graph_) == 1) {
3383       nodes_whitelist_.insert(fetch_node->name());
3384     }
3385   }
3386 
3387   GraphProperties properties(item);
3388   // It's possible to feed a placeholder with a tensor of any shape: make sure
3389   // that the shape inference deals with this conservatively unless we're in
3390   // aggressive mode.
3391   const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
3392   Status s = properties.InferStatically(assume_valid_feeds);
3393   const bool can_use_shape_info = s.ok();
3394 
3395   if (can_use_shape_info) {
3396     TF_RETURN_IF_ERROR(MaterializeShapes(properties));
3397     TF_RETURN_IF_ERROR(MaterializeConstants(properties));
3398   }
3399   absl::flat_hash_set<string> nodes_to_not_simplify;
3400   TF_RETURN_IF_ERROR(FoldGraph(optimized_graph, &nodes_to_not_simplify));
3401   node_map_.reset(new NodeMap(optimized_graph));
3402   TF_RETURN_IF_ERROR(SimplifyGraph(can_use_shape_info, optimized_graph,
3403                                    &properties, &nodes_to_not_simplify));
3404 
3405   return Status::OK();
3406 }
3407 
3408 namespace {
CompressConstants(GraphDef * graph)3409 Status CompressConstants(GraphDef* graph) {
3410   for (int i = 0; i < graph->node_size(); ++i) {
3411     NodeDef* node = graph->mutable_node(i);
3412     if ((IsConstant(*node) || IsHostConstant(*node)) &&
3413         HasNodeAttr(*node, "value")) {
3414       AttrValue& attr_val = (*node->mutable_attr())["value"];
3415       tensor::CompressTensorProtoInPlace(attr_val.mutable_tensor());
3416     }
3417   }
3418   return Status::OK();
3419 }
3420 }  // namespace
3421 
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph)3422 Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
3423                                  GraphDef* optimized_graph) {
3424   // TensorFlow flushes denormals to zero and rounds to nearest, so we do
3425   // the same here.
3426   port::ScopedFlushDenormal flush;
3427   port::ScopedSetRound round(FE_TONEAREST);
3428   nodes_to_preserve_ = item.NodesToPreserve();
3429   for (const auto& feed : item.feed) {
3430     feed_nodes_.insert(NodeName(feed.first));
3431   }
3432 
3433   if (cpu_device_ == nullptr) {
3434     owned_device_.reset(new DeviceSimple());
3435     cpu_device_ = owned_device_.get();
3436   }
3437 
3438   graph_contains_assign_or_inplace_op_ = false;
3439   for (const NodeDef& node : item.graph.node()) {
3440     if (ModifiesInputsInPlace(node) || MaybeHasRefInput(node)) {
3441       graph_contains_assign_or_inplace_op_ = true;
3442       break;
3443     }
3444   }
3445 
3446   has_fetch_ = !item.fetch.empty();
3447   GrapplerItem item_to_optimize = item;
3448   *optimized_graph = item.graph;
3449   int64 node_count;
3450   do {
3451     GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
3452     graph_modified_ = false;
3453     item_to_optimize.graph.Swap(optimized_graph);
3454     graph_ = &item_to_optimize.graph;
3455     *optimized_graph = GraphDef();
3456     node_count = graph_->node_size();
3457     TF_RETURN_IF_ERROR(
3458         RunOptimizationPass(cluster, item_to_optimize, optimized_graph));
3459   } while (graph_modified_ || optimized_graph->node_size() != node_count);
3460   TF_RETURN_IF_ERROR(CompressConstants(optimized_graph));
3461   *optimized_graph->mutable_library() = item.graph.library();
3462   *optimized_graph->mutable_versions() = item.graph.versions();
3463 
3464   return Status::OK();
3465 }
3466 
Feedback(Cluster * cluster,const GrapplerItem & item,const GraphDef & optimize_output,double result)3467 void ConstantFolding::Feedback(Cluster* cluster, const GrapplerItem& item,
3468                                const GraphDef& optimize_output, double result) {
3469   // Nothing to do for ConstantFolding.
3470 }
3471 
3472 }  // namespace grappler
3473 }  // namespace tensorflow
3474