1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/op_types.h"
17 #include "tensorflow/core/framework/attr_value.pb.h"
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/framework/types.h"
20 #include "tensorflow/core/grappler/utils.h"
21 #include "tensorflow/core/lib/core/status.h"
22 #include "tensorflow/core/lib/gtl/flatset.h"
23 #include "tensorflow/core/lib/strings/str_util.h"
24 #include "tensorflow/core/platform/logging.h"
25 
26 namespace tensorflow {
27 namespace grappler {
28 
IsAdd(const NodeDef & node)29 bool IsAdd(const NodeDef& node) {
30   if (node.op() == "AddV2" || node.op() == "Add") {
31     DataType type = node.attr().at("T").type();
32     return type != DT_STRING;
33   }
34   return false;
35 }
36 
IsAddN(const NodeDef & node)37 bool IsAddN(const NodeDef& node) { return node.op() == "AddN"; }
38 
IsAll(const NodeDef & node)39 bool IsAll(const NodeDef& node) { return node.op() == "All"; }
40 
IsAngle(const NodeDef & node)41 bool IsAngle(const NodeDef& node) { return node.op() == "Angle"; }
42 
IsAny(const NodeDef & node)43 bool IsAny(const NodeDef& node) { return node.op() == "Any"; }
44 
IsAnyDiv(const NodeDef & node)45 bool IsAnyDiv(const NodeDef& node) {
46   return node.op() == "RealDiv" || node.op() == "Div" ||
47          node.op() == "FloorDiv" || node.op() == "TruncateDiv";
48 }
49 
IsAnyMax(const NodeDef & node)50 bool IsAnyMax(const NodeDef& node) {
51   const auto& op = node.op();
52   return op == "Max" || op == "SegmentMax" || op == "UnsortedSegmentMax";
53 }
54 
IsAnyMaxPool(const NodeDef & node)55 bool IsAnyMaxPool(const NodeDef& node) {
56   const auto& op = node.op();
57   return op == "MaxPool" || op == "MaxPoolV2" || op == "MaxPool3D" ||
58          op == "MaxPoolWithArgmax" || op == "FractionalMaxPool";
59 }
60 
IsAnyMin(const NodeDef & node)61 bool IsAnyMin(const NodeDef& node) {
62   const auto& op = node.op();
63   return op == "Min" || op == "SegmentMin" || op == "UnsortedSegmentMin";
64 }
65 
IsApproximateEqual(const NodeDef & node)66 bool IsApproximateEqual(const NodeDef& node) {
67   return node.op() == "ApproximateEqual";
68 }
69 
IsArgMax(const NodeDef & node)70 bool IsArgMax(const NodeDef& node) { return node.op() == "ArgMax"; }
71 
IsArgMin(const NodeDef & node)72 bool IsArgMin(const NodeDef& node) { return node.op() == "ArgMin"; }
73 
IsAvgPoolGrad(const NodeDef & node)74 bool IsAvgPoolGrad(const NodeDef& node) { return node.op() == "AvgPoolGrad"; }
75 
IsAssign(const NodeDef & node)76 bool IsAssign(const NodeDef& node) {
77   return node.op() == "Assign" || node.op() == "AssignVariableOp";
78 }
79 
IsAssert(const NodeDef & node)80 bool IsAssert(const NodeDef& node) { return node.op() == "Assert"; }
81 
IsAtan2(const NodeDef & node)82 bool IsAtan2(const NodeDef& node) { return node.op() == "Atan2"; }
83 
IsBetainc(const NodeDef & node)84 bool IsBetainc(const NodeDef& node) { return node.op() == "Betainc"; }
85 
IsBiasAdd(const NodeDef & node)86 bool IsBiasAdd(const NodeDef& node) {
87   return node.op() == "BiasAdd" || node.op() == "BiasAddV1";
88 }
89 
IsBiasAddGrad(const NodeDef & node)90 bool IsBiasAddGrad(const NodeDef& node) { return node.op() == "BiasAddGrad"; }
91 
IsBitcast(const NodeDef & node)92 bool IsBitcast(const NodeDef& node) { return node.op() == "Bitcast"; }
93 
IsCast(const NodeDef & node)94 bool IsCast(const NodeDef& node) { return node.op() == "Cast"; }
95 
IsCastLike(const NodeDef & node)96 bool IsCastLike(const NodeDef& node) {
97   static const gtl::FlatSet<string>* const kCastLikeOps =
98       CHECK_NOTNULL((new gtl::FlatSet<string>{
99           "Angle", "Bucketize", "Cast", "CompareAndBitpack", "Dequantize",
100           "HistogramFixedWidth", "Imag", "IsFinite", "IsInf", "IsNan",
101           "Quantize", "QuantizeDownAndShrinkRange", "QuantizeV2",
102           "QuantizedInstanceNorm", "QuantizedRelu", "QuantizedRelu6",
103           "QuantizedReluX", "Real", "Requantize"}));
104   return kCastLikeOps->count(node.op()) > 0;
105 }
106 
IsCheckNumerics(const NodeDef & node)107 bool IsCheckNumerics(const NodeDef& node) {
108   return node.op() == "CheckNumerics";
109 }
110 
IsCollective(const NodeDef & node)111 bool IsCollective(const NodeDef& node) {
112   return node.op() == "CollectiveReduce" ||
113          node.op() == "CollectiveBcastSend" ||
114          node.op() == "CollectiveBcastRecv";
115 }
116 
IsComplex(const NodeDef & node)117 bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; }
118 
IsComplexAbs(const NodeDef & node)119 bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; }
120 
IsConcat(const NodeDef & node)121 bool IsConcat(const NodeDef& node) {
122   return node.op() == "Concat" || node.op() == "ConcatV2";
123 }
124 
IsConcatOffset(const NodeDef & node)125 bool IsConcatOffset(const NodeDef& node) { return node.op() == "ConcatOffset"; }
126 
IsConstant(const NodeDef & node)127 bool IsConstant(const NodeDef& node) { return node.op() == "Const"; }
128 
IsConj(const NodeDef & node)129 bool IsConj(const NodeDef& node) { return node.op() == "Conj"; }
130 
IsConjugateTranspose(const NodeDef & node)131 bool IsConjugateTranspose(const NodeDef& node) {
132   return node.op() == "ConjugateTranspose";
133 }
134 
IsControlFlow(const NodeDef & node)135 bool IsControlFlow(const NodeDef& node) {
136   // clang-format off
137   return node.op() == "ControlTrigger" ||
138          node.op() == "Enter" ||
139          node.op() == "Exit" ||
140          node.op() == "LoopCond" ||
141          node.op() == "Merge" ||
142          node.op() == "NextIteration" ||
143          node.op() == "Switch";
144   // clang-format on
145 }
146 
IsConv2D(const NodeDef & node)147 bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
148 
IsConv2DBackpropFilter(const NodeDef & node)149 bool IsConv2DBackpropFilter(const NodeDef& node) {
150   return node.op() == "Conv2DBackpropFilter";
151 }
152 
IsConv2DBackpropInput(const NodeDef & node)153 bool IsConv2DBackpropInput(const NodeDef& node) {
154   return node.op() == "Conv2DBackpropInput";
155 }
156 
IsConv3D(const NodeDef & node)157 bool IsConv3D(const NodeDef& node) { return node.op() == "Conv3D"; }
158 
IsDepthwiseConv2dNative(const NodeDef & node)159 bool IsDepthwiseConv2dNative(const NodeDef& node) {
160   return node.op() == "DepthwiseConv2dNative";
161 }
162 
IsDepthwiseConv2dNativeBackpropFilter(const NodeDef & node)163 bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node) {
164   return node.op() == "DepthwiseConv2dNativeBackpropFilter";
165 }
166 
IsDepthwiseConv2dNativeBackpropInput(const NodeDef & node)167 bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node) {
168   return node.op() == "DepthwiseConv2dNativeBackpropInput";
169 }
170 
IsDequeueOp(const NodeDef & node)171 bool IsDequeueOp(const NodeDef& node) {
172   const auto& op = node.op();
173   return op == "QueueDequeueManyV2" || op == "QueueDequeueMany" ||
174          op == "QueueDequeueV2" || op == "QueueDequeue" ||
175          op == "QueueDequeueUpToV2" || op == "QueueDequeueUpTo";
176 }
177 
IsDiv(const NodeDef & node)178 bool IsDiv(const NodeDef& node) { return node.op() == "Div"; }
179 
180 // Returns true if node represents a unary elementwise function that is
181 // monotonic. If *is_non_decreasing is true, the function is non-decreasing,
182 // e.g. sqrt, exp. *is_non_decreasing is false, the function is non-increasing,
183 // e.g. inv.
IsElementWiseMonotonic(const NodeDef & node,bool * is_non_decreasing)184 bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) {
185   static const gtl::FlatSet<string>* const kMonotonicNonDecreasingOps =
186       CHECK_NOTNULL((new gtl::FlatSet<string>{
187           "Acosh", "Asin", "Asinh",    "Atan",     "Atanh", "Ceil",
188           "Elu",   "Erf",  "Exp",      "Expm1",    "Floor", "Log",
189           "Log1p", "Relu", "Relu6",    "Rint",     "Selu",  "Sigmoid",
190           "Sign",  "Sinh", "Softsign", "Softplus", "Sqrt",  "Tanh",
191       }));
192   static const gtl::FlatSet<string>* const kMonotonicNonIncreasingOps =
193       CHECK_NOTNULL((new gtl::FlatSet<string>{"Acos", "Erfc", "Neg", "Rsqrt"}));
194   if (kMonotonicNonDecreasingOps->count(node.op()) > 0) {
195     if (is_non_decreasing) {
196       *is_non_decreasing = true;
197     }
198     return true;
199   } else if (kMonotonicNonIncreasingOps->count(node.op()) > 0) {
200     if (is_non_decreasing) {
201       *is_non_decreasing = false;
202     }
203     return true;
204   }
205   return false;
206 }
207 
IsEluGrad(const NodeDef & node)208 bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; }
209 
IsEnter(const NodeDef & node)210 bool IsEnter(const NodeDef& node) {
211   const auto& op = node.op();
212   return op == "Enter" || op == "RefEnter";
213 }
214 
IsEqual(const NodeDef & node)215 bool IsEqual(const NodeDef& node) { return node.op() == "Equal"; }
216 
IsExit(const NodeDef & node)217 bool IsExit(const NodeDef& node) {
218   const auto& op = node.op();
219   return op == "Exit" || op == "RefExit";
220 }
221 
IsExp(const NodeDef & node)222 bool IsExp(const NodeDef& node) { return node.op() == "Exp"; }
223 
IsFakeParam(const NodeDef & node)224 bool IsFakeParam(const NodeDef& node) { return node.op() == "FakeParam"; }
225 
IsFill(const NodeDef & node)226 bool IsFill(const NodeDef& node) { return node.op() == "Fill"; }
227 
IsFloorDiv(const NodeDef & node)228 bool IsFloorDiv(const NodeDef& node) { return node.op() == "FloorDiv"; }
229 
IsFloorMod(const NodeDef & node)230 bool IsFloorMod(const NodeDef& node) { return node.op() == "FloorMod"; }
231 
IsFusedBatchNorm(const NodeDef & node)232 bool IsFusedBatchNorm(const NodeDef& node) {
233   const auto& op = node.op();
234   return op == "FusedBatchNorm" || op == "FusedBatchNormV2";
235 }
236 
IsFusedBatchNormGrad(const NodeDef & node)237 bool IsFusedBatchNormGrad(const NodeDef& node) {
238   const auto& op = node.op();
239   return op == "FusedBatchNormGrad" || op == "FusedBatchNormGradV2";
240 }
241 
IsGreater(const NodeDef & node)242 bool IsGreater(const NodeDef& node) { return node.op() == "Greater"; }
243 
IsGreaterEqual(const NodeDef & node)244 bool IsGreaterEqual(const NodeDef& node) { return node.op() == "GreaterEqual"; }
245 
IsHostConstant(const NodeDef & node)246 bool IsHostConstant(const NodeDef& node) { return node.op() == "HostConst"; }
247 
IsHistogramSummary(const NodeDef & node)248 bool IsHistogramSummary(const NodeDef& node) {
249   return node.op() == "HistogramSummary";
250 }
251 
IsIdentity(const NodeDef & node)252 bool IsIdentity(const NodeDef& node) {
253   const auto& op = node.op();
254   return op == "Identity" || op == "RefIdentity";
255 }
256 
IsIdentityN(const NodeDef & node)257 bool IsIdentityN(const NodeDef& node) {
258   const auto& op = node.op();
259   return op == "IdentityN";
260 }
261 
IsIdentityNSingleInput(const NodeDef & node)262 bool IsIdentityNSingleInput(const NodeDef& node) {
263   return IsIdentityN(node) && node.attr().count("T") != 0 &&
264          node.attr().at("T").list().type_size() == 1;
265 }
266 
IsIf(const NodeDef & node)267 bool IsIf(const NodeDef& node) {
268   const auto& op = node.op();
269   return op == "If" || op == "StatelessIf";
270 }
271 
IsIgamma(const NodeDef & node)272 bool IsIgamma(const NodeDef& node) { return node.op() == "Igamma"; }
273 
IsIgammac(const NodeDef & node)274 bool IsIgammac(const NodeDef& node) { return node.op() == "Igammac"; }
275 
IsImag(const NodeDef & node)276 bool IsImag(const NodeDef& node) { return node.op() == "Imag"; }
277 
IsImmutableConst(const NodeDef & node)278 bool IsImmutableConst(const NodeDef& node) {
279   return node.op() == "ImmutableConst";
280 }
281 
IsInvGrad(const NodeDef & node)282 bool IsInvGrad(const NodeDef& node) { return node.op() == "InvGrad"; }
283 
IsLess(const NodeDef & node)284 bool IsLess(const NodeDef& node) { return node.op() == "Less"; }
285 
IsLessEqual(const NodeDef & node)286 bool IsLessEqual(const NodeDef& node) { return node.op() == "LessEqual"; }
287 
IsLog(const NodeDef & node)288 bool IsLog(const NodeDef& node) { return node.op() == "Log"; }
289 
IsLogicalAnd(const NodeDef & node)290 bool IsLogicalAnd(const NodeDef& node) { return node.op() == "LogicalAnd"; }
291 
IsLogicalNot(const NodeDef & node)292 bool IsLogicalNot(const NodeDef& node) { return node.op() == "LogicalNot"; }
293 
IsLogicalOr(const NodeDef & node)294 bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr"; }
295 
IsMatMul(const NodeDef & node)296 bool IsMatMul(const NodeDef& node) {
297   const auto& op = node.op();
298   return op == "MatMul" || op == "BatchMatMul" || op == "SparseMatMul" ||
299          IsQuantizedMatMul(node);
300 }
301 
IsMax(const NodeDef & node)302 bool IsMax(const NodeDef& node) { return node.op() == "Max"; }
303 
IsMaximum(const NodeDef & node)304 bool IsMaximum(const NodeDef& node) { return node.op() == "Maximum"; }
305 
IsMaxPoolGrad(const NodeDef & node)306 bool IsMaxPoolGrad(const NodeDef& node) { return node.op() == "MaxPoolGrad"; }
307 
IsMean(const NodeDef & node)308 bool IsMean(const NodeDef& node) { return node.op() == "Mean"; }
309 
IsMerge(const NodeDef & node)310 bool IsMerge(const NodeDef& node) {
311   const auto& op = node.op();
312   return op == "Merge" || op == "RefMerge";
313 }
314 
IsMin(const NodeDef & node)315 bool IsMin(const NodeDef& node) { return node.op() == "Min"; }
316 
IsMinimum(const NodeDef & node)317 bool IsMinimum(const NodeDef& node) { return node.op() == "Minimum"; }
318 
IsMirrorPad(const NodeDef & node)319 bool IsMirrorPad(const NodeDef& node) { return node.op() == "MirrorPad"; }
320 
IsMirrorPadGrad(const NodeDef & node)321 bool IsMirrorPadGrad(const NodeDef& node) {
322   return node.op() == "MirrorPadGrad";
323 }
324 
IsMod(const NodeDef & node)325 bool IsMod(const NodeDef& node) { return node.op() == "Mod"; }
326 
IsMul(const NodeDef & node)327 bool IsMul(const NodeDef& node) { return node.op() == "Mul"; }
328 
IsNeg(const NodeDef & node)329 bool IsNeg(const NodeDef& node) { return node.op() == "Neg"; }
330 
IsNoOp(const NodeDef & node)331 bool IsNoOp(const NodeDef& node) { return node.op() == "NoOp"; }
332 
IsNotEqual(const NodeDef & node)333 bool IsNotEqual(const NodeDef& node) { return node.op() == "NotEqual"; }
334 
IsNextIteration(const NodeDef & node)335 bool IsNextIteration(const NodeDef& node) {
336   const auto& op = node.op();
337   return op == "NextIteration" || op == "RefNextIteration";
338 }
339 
IsOnesLike(const NodeDef & node)340 bool IsOnesLike(const NodeDef& node) { return node.op() == "OnesLike"; }
341 
IsPack(const NodeDef & node)342 bool IsPack(const NodeDef& node) { return node.op() == "Pack"; }
343 
IsPad(const NodeDef & node)344 bool IsPad(const NodeDef& node) {
345   const auto& op = node.op();
346   return op == "Pad" || op == "PadV2";
347 }
348 
IsPartitionedCall(const NodeDef & node)349 bool IsPartitionedCall(const NodeDef& node) {
350   return node.op() == "PartitionedCall";
351 }
352 
IsPlaceholder(const NodeDef & node)353 bool IsPlaceholder(const NodeDef& node) {
354   const auto& op = node.op();
355   return op == "Placeholder" || op == "PlaceholderV2" ||
356          op == "PlaceholderWithDefault";
357 }
358 
IsPolygamma(const NodeDef & node)359 bool IsPolygamma(const NodeDef& node) { return node.op() == "Polygamma"; }
360 
IsPow(const NodeDef & node)361 bool IsPow(const NodeDef& node) { return node.op() == "Pow"; }
362 
IsPrint(const NodeDef & node)363 bool IsPrint(const NodeDef& node) {
364   return node.op() == "Print" || node.op() == "PrintV2";
365 }
366 
IsProd(const NodeDef & node)367 bool IsProd(const NodeDef& node) { return node.op() == "Prod"; }
368 
IsQuantizedMatMul(const NodeDef & node)369 bool IsQuantizedMatMul(const NodeDef& node) {
370   return node.op() == "QuantizedMatMul" || node.op() == "QuantizedMatMulV2";
371 }
372 
IsQueue(const NodeDef & node)373 bool IsQueue(const NodeDef& node) {
374   return str_util::EndsWith(node.op(), "QueueV2");
375 }
376 
IsRandomShuffle(const NodeDef & node)377 bool IsRandomShuffle(const NodeDef& node) {
378   return node.op() == "RandomShuffle";
379 }
380 
IsRank(const NodeDef & node)381 bool IsRank(const NodeDef& node) { return node.op() == "Rank"; }
382 
IsReadVariableOp(const NodeDef & node)383 bool IsReadVariableOp(const NodeDef& node) {
384   return node.op() == "ReadVariableOp";
385 }
386 
IsReal(const NodeDef & node)387 bool IsReal(const NodeDef& node) { return node.op() == "Real"; }
388 
IsRealDiv(const NodeDef & node)389 bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; }
390 
IsReciprocalGrad(const NodeDef & node)391 bool IsReciprocalGrad(const NodeDef& node) {
392   return node.op() == "ReciprocalGrad";
393 }
394 
IsRecv(const NodeDef & node)395 bool IsRecv(const NodeDef& node) {
396   return node.op() == "_Recv" || node.op() == "_HostRecv";
397 }
398 
IsReduction(const NodeDef & node)399 bool IsReduction(const NodeDef& node) {
400   const auto& op = node.op();
401   return op == "Sum" || op == "Prod" || op == "Min" || op == "Max" ||
402          op == "Mean" || op == "Any" || op == "All";
403 }
404 
IsRelu(const NodeDef & node)405 bool IsRelu(const NodeDef& node) { return node.op() == "Relu"; }
406 
IsReluGrad(const NodeDef & node)407 bool IsReluGrad(const NodeDef& node) { return node.op() == "ReluGrad"; }
408 
IsRelu6Grad(const NodeDef & node)409 bool IsRelu6Grad(const NodeDef& node) { return node.op() == "Relu6Grad"; }
410 
IsReshape(const NodeDef & node)411 bool IsReshape(const NodeDef& node) { return (node.op() == "Reshape"); }
412 
IsRestore(const NodeDef & node)413 bool IsRestore(const NodeDef& node) {
414   return (node.op() == "Restore" || node.op() == "RestoreV2" ||
415           node.op() == "RestoreSlice");
416 }
417 
IsReverse(const NodeDef & node)418 bool IsReverse(const NodeDef& node) {
419   return node.op() == "Reverse" || node.op() == "ReverseV2";
420 }
421 
IsReverseV2(const NodeDef & node)422 bool IsReverseV2(const NodeDef& node) { return node.op() == "ReverseV2"; }
423 
IsRsqrt(const NodeDef & node)424 bool IsRsqrt(const NodeDef& node) { return node.op() == "Rsqrt"; }
425 
IsRsqrtGrad(const NodeDef & node)426 bool IsRsqrtGrad(const NodeDef& node) { return node.op() == "RsqrtGrad"; }
427 
IsSelect(const NodeDef & node)428 bool IsSelect(const NodeDef& node) { return node.op() == "Select"; }
429 
IsSeluGrad(const NodeDef & node)430 bool IsSeluGrad(const NodeDef& node) { return node.op() == "SeluGrad"; }
431 
IsSend(const NodeDef & node)432 bool IsSend(const NodeDef& node) {
433   return node.op() == "_Send" || node.op() == "_HostSend";
434 }
435 
IsShape(const NodeDef & node)436 bool IsShape(const NodeDef& node) { return node.op() == "Shape"; }
437 
IsShapeN(const NodeDef & node)438 bool IsShapeN(const NodeDef& node) { return node.op() == "ShapeN"; }
439 
IsShuffle(const NodeDef & node)440 bool IsShuffle(const NodeDef& node) { return node.op() == "Shuffle"; }
441 
IsSigmoidGrad(const NodeDef & node)442 bool IsSigmoidGrad(const NodeDef& node) { return node.op() == "SigmoidGrad"; }
443 
IsSize(const NodeDef & node)444 bool IsSize(const NodeDef& node) { return node.op() == "Size"; }
445 
IsSlice(const NodeDef & node)446 bool IsSlice(const NodeDef& node) { return node.op() == "Slice"; }
447 
IsSnapshot(const NodeDef & node)448 bool IsSnapshot(const NodeDef& node) { return node.op() == "Snapshot"; }
449 
IsSoftmax(const NodeDef & node)450 bool IsSoftmax(const NodeDef& node) { return node.op() == "Softmax"; }
451 
IsSoftplusGrad(const NodeDef & node)452 bool IsSoftplusGrad(const NodeDef& node) { return node.op() == "SoftplusGrad"; }
453 
IsSoftsignGrad(const NodeDef & node)454 bool IsSoftsignGrad(const NodeDef& node) { return node.op() == "SoftsignGrad"; }
455 
IsSplit(const NodeDef & node)456 bool IsSplit(const NodeDef& node) { return node.op() == "Split"; }
457 
IsSplitV(const NodeDef & node)458 bool IsSplitV(const NodeDef& node) { return node.op() == "SplitV"; }
459 
IsSqrt(const NodeDef & node)460 bool IsSqrt(const NodeDef& node) { return node.op() == "Sqrt"; }
461 
IsSqrtGrad(const NodeDef & node)462 bool IsSqrtGrad(const NodeDef& node) { return node.op() == "SqrtGrad"; }
463 
IsSquare(const NodeDef & node)464 bool IsSquare(const NodeDef& node) { return node.op() == "Square"; }
465 
IsSquaredDifference(const NodeDef & node)466 bool IsSquaredDifference(const NodeDef& node) {
467   return node.op() == "SquaredDifference";
468 }
469 
IsSqueeze(const NodeDef & node)470 bool IsSqueeze(const NodeDef& node) { return node.op() == "Squeeze"; }
471 
IsStackOp(const NodeDef & node)472 bool IsStackOp(const NodeDef& node) {
473   return node.op() == "Stack" || node.op() == "StackV2";
474 }
IsStackCloseOp(const NodeDef & node)475 bool IsStackCloseOp(const NodeDef& node) {
476   return node.op() == "StackClose" || node.op() == "StackCloseV2";
477 }
IsStackPushOp(const NodeDef & node)478 bool IsStackPushOp(const NodeDef& node) {
479   return node.op() == "StackPush" || node.op() == "StackPushV2";
480 }
IsStackPopOp(const NodeDef & node)481 bool IsStackPopOp(const NodeDef& node) {
482   return node.op() == "StackPop" || node.op() == "StackPopV2";
483 }
484 
IsStatefulPartitionedCall(const NodeDef & node)485 bool IsStatefulPartitionedCall(const NodeDef& node) {
486   return node.op() == "StatefulPartitionedCall";
487 }
488 
IsStopGradient(const NodeDef & node)489 bool IsStopGradient(const NodeDef& node) {
490   const auto& op = node.op();
491   return op == "StopGradient" || op == "PreventGradient";
492 }
493 
IsStridedSlice(const NodeDef & node)494 bool IsStridedSlice(const NodeDef& node) { return node.op() == "StridedSlice"; }
495 
IsStridedSliceGrad(const NodeDef & node)496 bool IsStridedSliceGrad(const NodeDef& node) {
497   return node.op() == "StridedSliceGrad";
498 }
499 
IsSub(const NodeDef & node)500 bool IsSub(const NodeDef& node) { return node.op() == "Sub"; }
501 
IsSum(const NodeDef & node)502 bool IsSum(const NodeDef& node) { return node.op() == "Sum"; }
503 
IsSwitch(const NodeDef & node)504 bool IsSwitch(const NodeDef& node) {
505   const auto& op = node.op();
506   return op == "Switch" || op == "RefSwitch";
507 }
508 
IsSymbolicGradient(const NodeDef & node)509 bool IsSymbolicGradient(const NodeDef& node) {
510   return node.op() == "SymbolicGradient";
511 }
512 
IsTanhGrad(const NodeDef & node)513 bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; }
514 
IsTensorArray(const NodeDef & node)515 bool IsTensorArray(const NodeDef& node) {
516   static const gtl::FlatSet<string>* const kTensorArrayOps =
517       CHECK_NOTNULL((new gtl::FlatSet<string>{
518           "TensorArray",
519           "TensorArrayV2",
520           "TensorArrayV3",
521           "TensorArrayGrad",
522           "TensorArrayGradV2",
523           "TensorArrayGradV3",
524           "TensorArrayGradWithShape",
525           "TensorArrayWrite",
526           "TensorArrayWriteV2",
527           "TensorArrayWriteV3",
528           "TensorArrayRead",
529           "TensorArrayReadV2",
530           "TensorArrayReadV3",
531           "TensorArrayConcat",
532           "TensorArrayConcatV2",
533           "TensorArrayConcatV3",
534           "TensorArraySplit",
535           "TensorArraySplitV2",
536           "TensorArraySplitV3",
537           "TensorArraySize",
538           "TensorArraySizeV2",
539           "TensorArraySizeV3",
540           "TensorArrayClose",
541           "TensorArrayCloseV2",
542           "TensorArrayCloseV3",
543       }));
544   return kTensorArrayOps->count(node.op()) > 0;
545 }
546 
IsTile(const NodeDef & node)547 bool IsTile(const NodeDef& node) { return node.op() == "Tile"; }
548 
IsTranspose(const NodeDef & node)549 bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; }
550 
IsTruncateDiv(const NodeDef & node)551 bool IsTruncateDiv(const NodeDef& node) { return node.op() == "TruncateDiv"; }
552 
IsTruncateMod(const NodeDef & node)553 bool IsTruncateMod(const NodeDef& node) { return node.op() == "TruncateMod"; }
554 
IsUnpack(const NodeDef & node)555 bool IsUnpack(const NodeDef& node) { return node.op() == "Unpack"; }
556 
IsVariable(const NodeDef & node)557 bool IsVariable(const NodeDef& node) {
558   const auto& op = node.op();
559   return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" ||
560          op == "VarHandleOp" || op == "ReadVariableOp";
561 }
562 
IsWhile(const NodeDef & node)563 bool IsWhile(const NodeDef& node) {
564   const auto& op = node.op();
565   return op == "While" || op == "StatelessWhile";
566 }
567 
IsZerosLike(const NodeDef & node)568 bool IsZerosLike(const NodeDef& node) { return node.op() == "ZerosLike"; }
569 
IsZeta(const NodeDef & node)570 bool IsZeta(const NodeDef& node) { return node.op() == "Zeta"; }
571 
572 namespace {
GetBoolAttr(const NodeDef & node,const string & name)573 bool GetBoolAttr(const NodeDef& node, const string& name) {
574   return node.attr().count(name) > 0 && node.attr().at(name).b();
575 }
576 }  // namespace
577 
IsPersistent(const NodeDef & node)578 bool IsPersistent(const NodeDef& node) {
579   return IsConstant(node) || IsVariable(node) || IsHostConstant(node);
580 }
581 
MaybeHasRefInput(const NodeDef & node)582 bool MaybeHasRefInput(const NodeDef& node) {
583   const OpDef* op_def;
584   Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
585   if (!status.ok()) {
586     return true;
587   }
588   // Nodes such as Assign or AssignAdd modify one of their inputs.
589   for (const auto& input : op_def->input_arg()) {
590     if (input.is_ref()) {
591       return true;
592     }
593   }
594   return false;
595 }
596 
IsDataset(const NodeDef & node)597 bool IsDataset(const NodeDef& node) {
598   const string& op = node.op();
599   // See `GetNodeClassForOp` in core/graph/graph.cc.
600   return op == "IteratorGetNext" || op == "IteratorGetNextSync" ||
601          op == "DatasetToSingleElement" || op == "ReduceDataset";
602 }
603 
IsStateful(const NodeDef node,const OpRegistryInterface * op_registry)604 bool IsStateful(const NodeDef node, const OpRegistryInterface* op_registry) {
605   const OpDef* op_def = nullptr;
606   const string& op_name = node.op();
607   Status status = op_registry->LookUpOpDef(op_name, &op_def);
608   if (!status.ok()) {
609     LOG(WARNING) << "Failed to lookup OpDef for " << op_name
610                  << ". Error: " << status.error_message();
611     return false;
612   }
613   return op_def->is_stateful();
614 }
615 
IsStateful(const NodeDef node)616 bool IsStateful(const NodeDef node) {
617   return IsStateful(node, OpRegistry::Global());
618 }
619 
IsFreeOfSideEffect(const NodeDef & node,const OpRegistryInterface * op_registry)620 bool IsFreeOfSideEffect(const NodeDef& node,
621                         const OpRegistryInterface* op_registry) {
622   // Placeholders must be preserved to keep the graph feedable.
623   if (IsPlaceholder(node)) {
624     return false;
625   }
626   const OpDef* op_def = nullptr;
627   const string& op_name = node.op();
628   Status status = op_registry->LookUpOpDef(op_name, &op_def);
629   if (!status.ok()) {
630     return false;
631   }
632   if (op_def->is_stateful()) {
633     return false;
634   }
635   // Nodes such as Assign or AssignAdd modify one of their inputs.
636   for (const auto& input : op_def->input_arg()) {
637     if (input.is_ref()) {
638       return false;
639     }
640   }
641   // Queue ops modify the queue which is a side effect.
642   if (node.op().find("Queue") != string::npos) {
643     return false;
644   }
645   // Sending a tensor via a network is a side effect.
646   if (IsSend(node)) {
647     return false;
648   }
649   return !ModifiesInputsInPlace(node);
650 }
651 
IsFreeOfSideEffect(const NodeDef & node)652 bool IsFreeOfSideEffect(const NodeDef& node) {
653   return IsFreeOfSideEffect(node, OpRegistry::Global());
654 }
655 
ModifiesInputsInPlace(const NodeDef & node)656 bool ModifiesInputsInPlace(const NodeDef& node) {
657   // Some nodes do in-place updates on regular tensor inputs.
658   string op_name = node.op();
659 
660   // Ops that modify resource variables effectively modify one of their inputs.
661   if (op_name == "AssignVariableOp" || op_name == "AssignAddVariableOp" ||
662       op_name == "AssignSubVariableOp" || op_name == "ResourceScatterUpdate" ||
663       op_name == "ResourceScatterAdd" || op_name == "ResourceScatterSub" ||
664       op_name == "ResourceScatterMul" || op_name == "ResourceScatterDiv" ||
665       op_name == "ResourceScatterMin" || op_name == "ResourceScatterMax") {
666     return false;
667   }
668 
669   std::transform(op_name.begin(), op_name.end(), op_name.begin(), ::tolower);
670   if (str_util::StrContains(op_name, "inplace")) {
671     return true;
672   }
673   return GetBoolAttr(node, "in_place") || GetBoolAttr(node, "inplace");
674 }
675 
ModifiesFrameInfo(const NodeDef & node)676 bool ModifiesFrameInfo(const NodeDef& node) {
677   return IsEnter(node) || IsExit(node) || IsNextIteration(node);
678 }
679 
680 #define OPDEF_PROPERTY_HELPER(PROPERTY_CAP, PROPERTY)                      \
681   bool Is##PROPERTY_CAP(const NodeDef& node) {                             \
682     if (node.op() == "Add") {                                              \
683       /* Workaround for "Add" not being marked is_commutative and */       \
684       /* is_aggregate. (See cl/173915048). */                              \
685       const auto type = GetDataTypeFromAttr(node, "T");                    \
686       return type != DT_INVALID && type != DT_STRING;                      \
687     }                                                                      \
688     const OpDef* op_def = nullptr;                                         \
689     Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); \
690     return status.ok() && op_def->is_##PROPERTY();                         \
691   }
692 
OPDEF_PROPERTY_HELPER(Aggregate,aggregate)693 OPDEF_PROPERTY_HELPER(Aggregate, aggregate)
694 OPDEF_PROPERTY_HELPER(Commutative, commutative)
695 
696 bool IsInvolution(const NodeDef& node) {
697   static const gtl::FlatSet<string>* const kInvolutionOps =
698       CHECK_NOTNULL((new gtl::FlatSet<string>{"Conj", "Reciprocal", "Invert",
699                                               "Neg", "LogicalNot"}));
700   return kInvolutionOps->count(node.op()) > 0;
701 }
702 
IsValueAndOrderAndShapePreserving(const NodeDef & node)703 bool IsValueAndOrderAndShapePreserving(const NodeDef& node) {
704   if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
705     return true;
706   }
707   static const gtl::FlatSet<string>* const kValueAndOrderAndShapePreservingOps =
708       CHECK_NOTNULL((new const gtl::FlatSet<string>{
709           "CheckNumerics",
710           "DebugGradientIdentity",
711           "DeepCopy"
712           "Enter",
713           "Exit",
714           "PreventGradient",
715           "Print",
716           "Snapshot",
717           "StopGradient",
718       }));
719   return kValueAndOrderAndShapePreservingOps->count(node.op()) > 0 ||
720          IsIdentity(node);
721 }
722 
IsValueAndOrderPreserving(const NodeDef & node)723 bool IsValueAndOrderPreserving(const NodeDef& node) {
724   if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
725     return true;
726   }
727   static const gtl::FlatSet<string>* const kValueAndOrderPreservingOps =
728       CHECK_NOTNULL((new const gtl::FlatSet<string>{
729           "ExpandDims",
730           "Reshape",
731           "Squeeze",
732       }));
733   return kValueAndOrderPreservingOps->count(node.op()) > 0 ||
734          IsValueAndOrderAndShapePreserving(node);
735 }
736 
IsValuePreserving(const NodeDef & node)737 bool IsValuePreserving(const NodeDef& node) {
738   static const gtl::FlatSet<string>* const kValuePreservingOps =
739       CHECK_NOTNULL((new gtl::FlatSet<string>{
740           "InvertPermutation",
741           "Reverse",
742           "ReverseV2",
743           "Roll",
744           "Transpose",
745           "DepthToSpace",
746           "SpaceToDepth",
747           "BatchToSpace",
748           "BatchToSpaceND",
749           "SpaceToBatch",
750           "SpaceToBatchND",
751       }));
752   return IsValueAndOrderPreserving(node) ||
753          kValuePreservingOps->count(node.op()) > 0;
754 }
755 
IsUnaryElementWise(const NodeDef & node)756 bool IsUnaryElementWise(const NodeDef& node) {
757   static const gtl::FlatSet<string>* const kElementWiseOps =
758       CHECK_NOTNULL((new gtl::FlatSet<string>{
759           "Abs",
760           "Acos",
761           "Acosh",
762           "Asin",
763           "Asinh",
764           "Atan",
765           "Atanh",
766           "Ceil",
767           "ComplexAbs",
768           "Conj",
769           "Cos",
770           "Cosh",
771           "Digamma",
772           "Elu"
773           "Erf",
774           "Erfc",
775           "Exp",
776           "Expm1",
777           "Floor",
778           "Inv",
779           "Invert",
780           "Isinf",
781           "Isnan",
782           "Isfinite",
783           "Lgamma",
784           "Log",
785           "Log1p",
786           "LogicalNot",
787           "Neg",
788           "Reciprocal",
789           "Relu",
790           "Relu6",
791           "Rint",
792           "Round",
793           "Selu",
794           "Rsqrt",
795           "Sigmoid",
796           "Sign",
797           "Sin",
798           "SinH",
799           "Softplus",
800           "Softsign",
801           "Sqrt",
802           "Square",
803           "Tan"
804           "Tanh",
805       }));
806   return kElementWiseOps->count(node.op()) > 0 ||
807          IsValueAndOrderAndShapePreserving(node);
808 }
809 
HasOpDef(const NodeDef & node)810 bool HasOpDef(const NodeDef& node) {
811   const OpDef* op_def = nullptr;
812   return OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok();
813 }
814 
IsIdempotent(const NodeDef & node)815 bool IsIdempotent(const NodeDef& node) {
816   return IsValueAndOrderAndShapePreserving(node) && IsFreeOfSideEffect(node) &&
817          !ModifiesFrameInfo(node);
818 }
819 
NeverForwardsInputs(const NodeDef & node)820 bool NeverForwardsInputs(const NodeDef& node) {
821   static const gtl::FlatSet<string>* const kNonForwardingOps = CHECK_NOTNULL(
822       (new gtl::FlatSet<string>{"ArgMax",
823                                 "ArgMin",
824                                 "AudioSpectrogram",
825                                 "BatchMatMul",
826                                 "BatchToSpace",
827                                 "BatchToSpaceND",
828                                 "Bincount",
829                                 "BroadcastArgs",
830                                 "BroadcastGradientArgs",
831                                 "CTCBeamSearchDecoder",
832                                 "CTCGreedyDecoder",
833                                 "CTCLoss",
834                                 "ComplexAbs",
835                                 "Concat",
836                                 "ConcatOffset",
837                                 "ConcatV2",
838                                 "Copy",
839                                 "CopyHost",
840                                 "Cross",
841                                 "CudnnRNN",
842                                 "CudnnRNNBackprop",
843                                 "CudnnRNNBackpropV2",
844                                 "CudnnRNNBackpropV3",
845                                 "CudnnRNNCanonicalToParams",
846                                 "CudnnRNNParamsSize",
847                                 "CudnnRNNParamsToCanonical",
848                                 "CudnnRNNV2",
849                                 "CudnnRNNV3",
850                                 "CumSum",
851                                 "CumProd",
852                                 "DebugNanCount",
853                                 "DebugNumericSummary",
854                                 "DecodeProtoV2",
855                                 "DecodeWav",
856                                 "DeepCopy",
857                                 "DepthToSpace",
858                                 "Dequantize",
859                                 "Diag",
860                                 "DiagPart",
861                                 "EditDistance",
862                                 "Empty",
863                                 "EncodeProtoV2",
864                                 "EncodeWav",
865                                 "ExtractImagePatches",
866                                 "ExtractVolumePatches",
867                                 "Fill",
868                                 "Gather",
869                                 "GatherNd",
870                                 "GatherV2",
871                                 "HistogramFixedWidth",
872                                 "InvertPermutation",
873                                 "IsInf",
874                                 "IsNan",
875                                 "Isfinite",
876                                 "LinSpace",
877                                 "LowerBound",
878                                 "MatMul",
879                                 "MatrixDiag",
880                                 "MatrixDiagPart",
881                                 "Mfcc",
882                                 "OneHot",
883                                 "Pack",
884                                 "PopulationCount",
885                                 "Range",
886                                 "Rank",
887                                 "ReverseSequence",
888                                 "Shape",
889                                 "ShapeN",
890                                 "Size",
891                                 "SpaceToBatch",
892                                 "SpaceToBatchND",
893                                 "SpaceToDepth",
894                                 "SparseMatMul",
895                                 "Split",
896                                 "SplitV",
897                                 "Unique",
898                                 "UniqueV2",
899                                 "UniqueWithCounts",
900                                 "UniqueWithCountsV2",
901                                 "Unpack",
902                                 "UnravelIndex",
903                                 "UpperBound",
904                                 "Where",
905                                 "CompareAndBitpack",
906                                 "Requantize",
907                                 "RequantizationRange",
908                                 "Bucketize",
909                                 "AvgPool",
910                                 "BatchNormWithGlobalNormalization",
911                                 "FusedBatchNorm",
912                                 "FusedBatchNormV2",
913                                 "Conv2D",
914                                 "RandomUniform",
915                                 "RandomUniformInt",
916                                 "RandomStandardNormal",
917                                 "ParameterizedTruncatedNormal",
918                                 "TruncatedNormal",
919                                 "Multinomial",
920                                 "RandomGamma",
921                                 "RandomPoisson",
922                                 "RandomPoissonV2"}));
923   const string& op_name = node.op();
924   return kNonForwardingOps->count(op_name) > 0 ||
925          str_util::StrContains(op_name, "Segment") ||
926          str_util::StartsWith(op_name, "Quantize");
927 }
928 
929 }  // namespace grappler
930 }  // end namespace tensorflow
931