1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/op_types.h"
17 
18 #include "tensorflow/core/framework/attr_value.pb.h"
19 #include "tensorflow/core/framework/op.h"
20 #include "tensorflow/core/framework/types.h"
21 #include "tensorflow/core/grappler/utils.h"
22 #include "tensorflow/core/lib/core/status.h"
23 #include "tensorflow/core/lib/gtl/flatset.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25 #include "tensorflow/core/platform/logging.h"
26 
27 namespace tensorflow {
28 namespace grappler {
29 
IsAdd(const NodeDef & node)30 bool IsAdd(const NodeDef& node) {
31   if (node.op() == "AddV2") {
32     return true;
33   }
34   if (node.op() == "Add") {
35     DataType type = node.attr().at("T").type();
36     return type != DT_STRING;
37   }
38   return false;
39 }
40 
IsAddN(const NodeDef & node)41 bool IsAddN(const NodeDef& node) { return node.op() == "AddN"; }
42 
IsAll(const NodeDef & node)43 bool IsAll(const NodeDef& node) { return node.op() == "All"; }
44 
IsAngle(const NodeDef & node)45 bool IsAngle(const NodeDef& node) { return node.op() == "Angle"; }
46 
IsAny(const NodeDef & node)47 bool IsAny(const NodeDef& node) { return node.op() == "Any"; }
48 
IsAnyDiv(const NodeDef & node)49 bool IsAnyDiv(const NodeDef& node) {
50   return node.op() == "RealDiv" || node.op() == "Div" || node.op() == "Xdivy" ||
51          node.op() == "FloorDiv" || node.op() == "TruncateDiv";
52 }
53 
IsAnyBatchMatMul(const NodeDef & node)54 bool IsAnyBatchMatMul(const NodeDef& node) {
55   return node.op() == "BatchMatMul" || node.op() == "BatchMatMulV2";
56 }
57 
IsAnyMatMul(const NodeDef & node)58 bool IsAnyMatMul(const NodeDef& node) {
59   return node.op() == "MatMul" || node.op() == "SparseMatMul" ||
60          IsAnyBatchMatMul(node) || IsQuantizedMatMul(node);
61 }
62 
IsAnyMax(const NodeDef & node)63 bool IsAnyMax(const NodeDef& node) {
64   const auto& op = node.op();
65   return op == "Max" || op == "SegmentMax" || op == "UnsortedSegmentMax";
66 }
67 
IsAnyMaxPool(const NodeDef & node)68 bool IsAnyMaxPool(const NodeDef& node) {
69   const auto& op = node.op();
70   return op == "MaxPool" || op == "MaxPoolV2" || op == "MaxPool3D" ||
71          op == "MaxPoolWithArgmax" || op == "FractionalMaxPool";
72 }
73 
IsAnyMin(const NodeDef & node)74 bool IsAnyMin(const NodeDef& node) {
75   const auto& op = node.op();
76   return op == "Min" || op == "SegmentMin" || op == "UnsortedSegmentMin";
77 }
78 
IsAnySparseSegmentReduction(const NodeDef & node)79 bool IsAnySparseSegmentReduction(const NodeDef& node) {
80   const auto& op = node.op();
81   return op == "SparseSegmentSum" || op == "SparseSegmentSumWithNumSegments" ||
82          op == "SparseSegmentMean" ||
83          op == "SparseSegmentMeanWithNumSegments" ||
84          op == "SparseSegmentSqrtN" ||
85          op == "SparseSegmentSqrtNWithNumSegments";
86 }
87 
IsApproximateEqual(const NodeDef & node)88 bool IsApproximateEqual(const NodeDef& node) {
89   return node.op() == "ApproximateEqual";
90 }
91 
IsArg(const NodeDef & node)92 bool IsArg(const NodeDef& node) {
93   return node.op() == "_Arg" || node.op() == "_DeviceArg";
94 }
95 
IsArgMax(const NodeDef & node)96 bool IsArgMax(const NodeDef& node) { return node.op() == "ArgMax"; }
97 
IsArgMin(const NodeDef & node)98 bool IsArgMin(const NodeDef& node) { return node.op() == "ArgMin"; }
99 
IsAvgPoolGrad(const NodeDef & node)100 bool IsAvgPoolGrad(const NodeDef& node) { return node.op() == "AvgPoolGrad"; }
101 
IsAssign(const NodeDef & node)102 bool IsAssign(const NodeDef& node) {
103   return node.op() == "Assign" || node.op() == "AssignVariableOp";
104 }
105 
IsAssert(const NodeDef & node)106 bool IsAssert(const NodeDef& node) { return node.op() == "Assert"; }
107 
IsAtan2(const NodeDef & node)108 bool IsAtan2(const NodeDef& node) { return node.op() == "Atan2"; }
109 
IsBetainc(const NodeDef & node)110 bool IsBetainc(const NodeDef& node) { return node.op() == "Betainc"; }
111 
IsBiasAdd(const NodeDef & node)112 bool IsBiasAdd(const NodeDef& node) {
113   return node.op() == "BiasAdd" || node.op() == "BiasAddV1";
114 }
115 
IsBiasAddV2(const NodeDef & node)116 bool IsBiasAddV2(const NodeDef& node) { return node.op() == "BiasAdd"; }
117 
IsBiasAddGrad(const NodeDef & node)118 bool IsBiasAddGrad(const NodeDef& node) { return node.op() == "BiasAddGrad"; }
119 
IsBitcast(const NodeDef & node)120 bool IsBitcast(const NodeDef& node) { return node.op() == "Bitcast"; }
121 
IsBroadcastTo(const NodeDef & node)122 bool IsBroadcastTo(const NodeDef& node) { return node.op() == "BroadcastTo"; }
123 
IsCast(const NodeDef & node)124 bool IsCast(const NodeDef& node) { return node.op() == "Cast"; }
125 
IsCastLike(const NodeDef & node)126 bool IsCastLike(const NodeDef& node) {
127   static const gtl::FlatSet<string>* const kCastLikeOps =
128       CHECK_NOTNULL((new gtl::FlatSet<string>{
129           "Angle", "Bucketize", "Cast", "CompareAndBitpack", "Dequantize",
130           "HistogramFixedWidth", "Imag", "IsFinite", "IsInf", "IsNan",
131           "Quantize", "QuantizeDownAndShrinkRange", "QuantizeV2",
132           "QuantizedInstanceNorm", "QuantizedRelu", "QuantizedRelu6",
133           "QuantizedReluX", "Real", "Requantize"}));
134   return kCastLikeOps->count(node.op()) > 0;
135 }
136 
IsCheckNumerics(const NodeDef & node)137 bool IsCheckNumerics(const NodeDef& node) {
138   return node.op() == "CheckNumerics";
139 }
140 
IsCollective(const NodeDef & node)141 bool IsCollective(const NodeDef& node) {
142   return node.op() == "CollectiveReduce" ||
143          node.op() == "CollectiveBcastSend" ||
144          node.op() == "CollectiveBcastRecv";
145 }
146 
IsComplex(const NodeDef & node)147 bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; }
148 
IsComplexAbs(const NodeDef & node)149 bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; }
150 
IsConcat(const NodeDef & node)151 bool IsConcat(const NodeDef& node) {
152   return node.op() == "Concat" || node.op() == "ConcatV2";
153 }
154 
IsConcatOffset(const NodeDef & node)155 bool IsConcatOffset(const NodeDef& node) { return node.op() == "ConcatOffset"; }
156 
IsConstant(const NodeDef & node)157 bool IsConstant(const NodeDef& node) { return node.op() == "Const"; }
158 
IsConj(const NodeDef & node)159 bool IsConj(const NodeDef& node) { return node.op() == "Conj"; }
160 
IsConjugateTranspose(const NodeDef & node)161 bool IsConjugateTranspose(const NodeDef& node) {
162   return node.op() == "ConjugateTranspose";
163 }
164 
IsControlFlow(const NodeDef & node)165 bool IsControlFlow(const NodeDef& node) {
166   // clang-format off
167   return node.op() == "ControlTrigger" ||
168          node.op() == "Enter" ||
169          node.op() == "Exit" ||
170          node.op() == "LoopCond" ||
171          node.op() == "Merge" ||
172          node.op() == "_XlaMerge" ||
173          node.op() == "NextIteration" ||
174          node.op() == "Switch" ||
175          node.op() == "_SwitchN";
176   // clang-format on
177 }
178 
IsConv2D(const NodeDef & node)179 bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
180 
IsConv2DBackpropFilter(const NodeDef & node)181 bool IsConv2DBackpropFilter(const NodeDef& node) {
182   return node.op() == "Conv2DBackpropFilter";
183 }
184 
IsConv2DBackpropInput(const NodeDef & node)185 bool IsConv2DBackpropInput(const NodeDef& node) {
186   return node.op() == "Conv2DBackpropInput";
187 }
188 
IsConv3D(const NodeDef & node)189 bool IsConv3D(const NodeDef& node) { return node.op() == "Conv3D"; }
190 
IsConv3DBackpropFilterV2(const NodeDef & node)191 bool IsConv3DBackpropFilterV2(const NodeDef& node) {
192   return node.op() == "Conv3DBackpropFilterV2";
193 }
194 
IsConv3DBackpropInputV2(const NodeDef & node)195 bool IsConv3DBackpropInputV2(const NodeDef& node) {
196   return node.op() == "Conv3DBackpropInputV2";
197 }
198 
IsDepthwiseConv2dNative(const NodeDef & node)199 bool IsDepthwiseConv2dNative(const NodeDef& node) {
200   return node.op() == "DepthwiseConv2dNative";
201 }
202 
IsDepthwiseConv2dNativeBackpropFilter(const NodeDef & node)203 bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node) {
204   return node.op() == "DepthwiseConv2dNativeBackpropFilter";
205 }
206 
IsDepthwiseConv2dNativeBackpropInput(const NodeDef & node)207 bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node) {
208   return node.op() == "DepthwiseConv2dNativeBackpropInput";
209 }
210 
IsDequeueOp(const NodeDef & node)211 bool IsDequeueOp(const NodeDef& node) {
212   const auto& op = node.op();
213   return op == "QueueDequeueManyV2" || op == "QueueDequeueMany" ||
214          op == "QueueDequeueV2" || op == "QueueDequeue" ||
215          op == "QueueDequeueUpToV2" || op == "QueueDequeueUpTo";
216 }
217 
IsDiv(const NodeDef & node)218 bool IsDiv(const NodeDef& node) { return node.op() == "Div"; }
219 
IsDivNoNan(const NodeDef & node)220 bool IsDivNoNan(const NodeDef& node) { return node.op() == "DivNoNan"; }
221 
222 // Returns true if node represents a unary elementwise function that is
223 // monotonic. If *is_non_decreasing is true, the function is non-decreasing,
224 // e.g. sqrt, exp. *is_non_decreasing is false, the function is non-increasing,
225 // e.g. inv.
IsElementWiseMonotonic(const NodeDef & node,bool * is_non_decreasing)226 bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) {
227   static const gtl::FlatSet<string>* const kMonotonicNonDecreasingOps =
228       CHECK_NOTNULL((new gtl::FlatSet<string>{
229           "Acosh", "Asin", "Asinh",    "Atan",     "Atanh", "Ceil",
230           "Elu",   "Erf",  "Exp",      "Expm1",    "Floor", "Log",
231           "Log1p", "Relu", "Relu6",    "Rint",     "Selu",  "Sigmoid",
232           "Sign",  "Sinh", "Softsign", "Softplus", "Sqrt",  "Tanh",
233       }));
234   static const gtl::FlatSet<string>* const kMonotonicNonIncreasingOps =
235       CHECK_NOTNULL((new gtl::FlatSet<string>{"Acos", "Erfc", "Neg", "Rsqrt"}));
236   if (kMonotonicNonDecreasingOps->count(node.op()) > 0) {
237     if (is_non_decreasing) {
238       *is_non_decreasing = true;
239     }
240     return true;
241   } else if (kMonotonicNonIncreasingOps->count(node.op()) > 0) {
242     if (is_non_decreasing) {
243       *is_non_decreasing = false;
244     }
245     return true;
246   }
247   return false;
248 }
249 
IsElu(const NodeDef & node)250 bool IsElu(const NodeDef& node) { return node.op() == "Elu"; }
251 
IsEluGrad(const NodeDef & node)252 bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; }
253 
IsQuantizationEmulation(const NodeDef & node)254 bool IsQuantizationEmulation(const NodeDef& node) {
255   const auto& op = node.op();
256   return absl::StartsWith(op, "QuantizeAndDequantize") ||
257          absl::StartsWith(op, "FakeQuantWithMinMax");
258 }
259 
IsEnter(const NodeDef & node)260 bool IsEnter(const NodeDef& node) {
261   const auto& op = node.op();
262   return op == "Enter" || op == "RefEnter";
263 }
264 
IsEqual(const NodeDef & node)265 bool IsEqual(const NodeDef& node) { return node.op() == "Equal"; }
266 
IsExit(const NodeDef & node)267 bool IsExit(const NodeDef& node) {
268   const auto& op = node.op();
269   return op == "Exit" || op == "RefExit";
270 }
271 
IsExp(const NodeDef & node)272 bool IsExp(const NodeDef& node) { return node.op() == "Exp"; }
273 
IsFakeParam(const NodeDef & node)274 bool IsFakeParam(const NodeDef& node) { return node.op() == "FakeParam"; }
275 
IsFill(const NodeDef & node)276 bool IsFill(const NodeDef& node) { return node.op() == "Fill"; }
277 
IsFloorDiv(const NodeDef & node)278 bool IsFloorDiv(const NodeDef& node) { return node.op() == "FloorDiv"; }
279 
IsFloorMod(const NodeDef & node)280 bool IsFloorMod(const NodeDef& node) { return node.op() == "FloorMod"; }
281 
IsFusedBatchNorm(const NodeDef & node)282 bool IsFusedBatchNorm(const NodeDef& node) {
283   const auto& op = node.op();
284   return op == "FusedBatchNorm" || op == "FusedBatchNormV2" ||
285          op == "FusedBatchNormV3";
286 }
287 
IsFusedBatchNormEx(const NodeDef & node)288 bool IsFusedBatchNormEx(const NodeDef& node) {
289   return node.op() == "_FusedBatchNormEx";
290 }
291 
IsFusedBatchNormGrad(const NodeDef & node)292 bool IsFusedBatchNormGrad(const NodeDef& node) {
293   const auto& op = node.op();
294   return op == "FusedBatchNormGrad" || op == "FusedBatchNormGradV2" ||
295          op == "FusedBatchNormGradV3";
296 }
297 
IsGather(const NodeDef & node)298 bool IsGather(const NodeDef& node) {
299   const auto& op = node.op();
300   return op == "Gather" || op == "GatherV2";
301 }
302 
IsGreater(const NodeDef & node)303 bool IsGreater(const NodeDef& node) { return node.op() == "Greater"; }
304 
IsGreaterEqual(const NodeDef & node)305 bool IsGreaterEqual(const NodeDef& node) { return node.op() == "GreaterEqual"; }
306 
IsHostConstant(const NodeDef & node)307 bool IsHostConstant(const NodeDef& node) { return node.op() == "HostConst"; }
308 
IsHistogramSummary(const NodeDef & node)309 bool IsHistogramSummary(const NodeDef& node) {
310   return node.op() == "HistogramSummary";
311 }
312 
IsIdentity(const NodeDef & node)313 bool IsIdentity(const NodeDef& node) {
314   const auto& op = node.op();
315   return op == "Identity" || op == "RefIdentity";
316 }
317 
IsIdentityN(const NodeDef & node)318 bool IsIdentityN(const NodeDef& node) {
319   const auto& op = node.op();
320   return op == "IdentityN";
321 }
322 
IsIdentityNSingleInput(const NodeDef & node)323 bool IsIdentityNSingleInput(const NodeDef& node) {
324   return IsIdentityN(node) && node.attr().count("T") != 0 &&
325          node.attr().at("T").list().type_size() == 1;
326 }
327 
IsIf(const NodeDef & node)328 bool IsIf(const NodeDef& node) {
329   const auto& op = node.op();
330   return op == "If" || op == "StatelessIf";
331 }
332 
IsIgamma(const NodeDef & node)333 bool IsIgamma(const NodeDef& node) { return node.op() == "Igamma"; }
334 
IsIgammac(const NodeDef & node)335 bool IsIgammac(const NodeDef& node) { return node.op() == "Igammac"; }
336 
IsImag(const NodeDef & node)337 bool IsImag(const NodeDef& node) { return node.op() == "Imag"; }
338 
IsImmutableConst(const NodeDef & node)339 bool IsImmutableConst(const NodeDef& node) {
340   return node.op() == "ImmutableConst";
341 }
342 
IsInvGrad(const NodeDef & node)343 bool IsInvGrad(const NodeDef& node) { return node.op() == "InvGrad"; }
344 
IsLeakyRelu(const NodeDef & node)345 bool IsLeakyRelu(const NodeDef& node) { return node.op() == "LeakyRelu"; }
346 
IsLeakyReluGrad(const NodeDef & node)347 bool IsLeakyReluGrad(const NodeDef& node) {
348   return node.op() == "LeakyReluGrad";
349 }
350 
IsLess(const NodeDef & node)351 bool IsLess(const NodeDef& node) { return node.op() == "Less"; }
352 
IsLessEqual(const NodeDef & node)353 bool IsLessEqual(const NodeDef& node) { return node.op() == "LessEqual"; }
354 
IsLog(const NodeDef & node)355 bool IsLog(const NodeDef& node) { return node.op() == "Log"; }
356 
IsLogicalAnd(const NodeDef & node)357 bool IsLogicalAnd(const NodeDef& node) { return node.op() == "LogicalAnd"; }
358 
IsLogicalNot(const NodeDef & node)359 bool IsLogicalNot(const NodeDef& node) { return node.op() == "LogicalNot"; }
360 
IsLogicalOr(const NodeDef & node)361 bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr"; }
362 
IsLoopCond(const NodeDef & node)363 bool IsLoopCond(const NodeDef& node) { return node.op() == "LoopCond"; }
364 
IsMatMul(const NodeDef & node)365 bool IsMatMul(const NodeDef& node) { return node.op() == "MatMul"; }
366 
IsMax(const NodeDef & node)367 bool IsMax(const NodeDef& node) { return node.op() == "Max"; }
368 
IsMaximum(const NodeDef & node)369 bool IsMaximum(const NodeDef& node) { return node.op() == "Maximum"; }
370 
IsMaxPoolGrad(const NodeDef & node)371 bool IsMaxPoolGrad(const NodeDef& node) { return node.op() == "MaxPoolGrad"; }
372 
IsMean(const NodeDef & node)373 bool IsMean(const NodeDef& node) { return node.op() == "Mean"; }
374 
IsMerge(const NodeDef & node)375 bool IsMerge(const NodeDef& node) {
376   const auto& op = node.op();
377   return op == "Merge" || op == "RefMerge" || op == "_XlaMerge";
378 }
379 
IsMin(const NodeDef & node)380 bool IsMin(const NodeDef& node) { return node.op() == "Min"; }
381 
IsMinimum(const NodeDef & node)382 bool IsMinimum(const NodeDef& node) { return node.op() == "Minimum"; }
383 
IsMirrorPad(const NodeDef & node)384 bool IsMirrorPad(const NodeDef& node) { return node.op() == "MirrorPad"; }
385 
IsMirrorPadGrad(const NodeDef & node)386 bool IsMirrorPadGrad(const NodeDef& node) {
387   return node.op() == "MirrorPadGrad";
388 }
389 
IsMod(const NodeDef & node)390 bool IsMod(const NodeDef& node) { return node.op() == "Mod"; }
391 
IsMul(const NodeDef & node)392 bool IsMul(const NodeDef& node) { return node.op() == "Mul"; }
IsMulNoNan(const NodeDef & node)393 bool IsMulNoNan(const NodeDef& node) { return node.op() == "MulNoNan"; }
IsAnyMul(const NodeDef & node)394 bool IsAnyMul(const NodeDef& node) { return IsMul(node) || IsMulNoNan(node); }
395 
IsNeg(const NodeDef & node)396 bool IsNeg(const NodeDef& node) { return node.op() == "Neg"; }
397 
IsNoOp(const NodeDef & node)398 bool IsNoOp(const NodeDef& node) { return node.op() == "NoOp"; }
399 
IsNotEqual(const NodeDef & node)400 bool IsNotEqual(const NodeDef& node) { return node.op() == "NotEqual"; }
401 
IsNextIteration(const NodeDef & node)402 bool IsNextIteration(const NodeDef& node) {
403   const auto& op = node.op();
404   return op == "NextIteration" || op == "RefNextIteration";
405 }
406 
IsOnesLike(const NodeDef & node)407 bool IsOnesLike(const NodeDef& node) { return node.op() == "OnesLike"; }
408 
IsPack(const NodeDef & node)409 bool IsPack(const NodeDef& node) { return node.op() == "Pack"; }
410 
IsPad(const NodeDef & node)411 bool IsPad(const NodeDef& node) {
412   const auto& op = node.op();
413   return op == "Pad" || op == "PadV2";
414 }
415 
IsPartitionedCall(const NodeDef & node)416 bool IsPartitionedCall(const NodeDef& node) {
417   return node.op() == "PartitionedCall";
418 }
419 
IsPlaceholder(const NodeDef & node)420 bool IsPlaceholder(const NodeDef& node) {
421   const auto& op = node.op();
422   return op == "Placeholder" || op == "PlaceholderV2" ||
423          op == "PlaceholderWithDefault";
424 }
425 
IsPolygamma(const NodeDef & node)426 bool IsPolygamma(const NodeDef& node) { return node.op() == "Polygamma"; }
427 
IsPow(const NodeDef & node)428 bool IsPow(const NodeDef& node) { return node.op() == "Pow"; }
429 
IsPrint(const NodeDef & node)430 bool IsPrint(const NodeDef& node) {
431   return node.op() == "Print" || node.op() == "PrintV2";
432 }
433 
IsProd(const NodeDef & node)434 bool IsProd(const NodeDef& node) { return node.op() == "Prod"; }
435 
IsQuantizedMatMul(const NodeDef & node)436 bool IsQuantizedMatMul(const NodeDef& node) {
437   return node.op() == "QuantizedMatMul" || node.op() == "QuantizedMatMulV2";
438 }
439 
IsQueue(const NodeDef & node)440 bool IsQueue(const NodeDef& node) {
441   return str_util::EndsWith(node.op(), "QueueV2");
442 }
443 
IsRandomShuffle(const NodeDef & node)444 bool IsRandomShuffle(const NodeDef& node) {
445   return node.op() == "RandomShuffle";
446 }
447 
IsRank(const NodeDef & node)448 bool IsRank(const NodeDef& node) { return node.op() == "Rank"; }
449 
IsReadVariableOp(const NodeDef & node)450 bool IsReadVariableOp(const NodeDef& node) {
451   return node.op() == "ReadVariableOp";
452 }
453 
IsReadVariablesOp(const NodeDef & node)454 bool IsReadVariablesOp(const NodeDef& node) {
455   return node.op() == "_ReadVariablesOp";
456 }
457 
IsReal(const NodeDef & node)458 bool IsReal(const NodeDef& node) { return node.op() == "Real"; }
459 
IsRealDiv(const NodeDef & node)460 bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; }
461 
IsReciprocalGrad(const NodeDef & node)462 bool IsReciprocalGrad(const NodeDef& node) {
463   return node.op() == "ReciprocalGrad";
464 }
465 
IsRecv(const NodeDef & node)466 bool IsRecv(const NodeDef& node) {
467   return node.op() == "_Recv" || node.op() == "_HostRecv";
468 }
469 
IsReduction(const NodeDef & node)470 bool IsReduction(const NodeDef& node) {
471   const auto& op = node.op();
472   return op == "Sum" || op == "Prod" || op == "Min" || op == "Max" ||
473          op == "Mean" || op == "Any" || op == "All";
474 }
475 
IsRelu(const NodeDef & node)476 bool IsRelu(const NodeDef& node) { return node.op() == "Relu"; }
477 
IsRelu6(const NodeDef & node)478 bool IsRelu6(const NodeDef& node) { return node.op() == "Relu6"; }
479 
IsReluGrad(const NodeDef & node)480 bool IsReluGrad(const NodeDef& node) { return node.op() == "ReluGrad"; }
481 
IsRelu6Grad(const NodeDef & node)482 bool IsRelu6Grad(const NodeDef& node) { return node.op() == "Relu6Grad"; }
483 
IsReshape(const NodeDef & node)484 bool IsReshape(const NodeDef& node) { return (node.op() == "Reshape"); }
485 
IsRestore(const NodeDef & node)486 bool IsRestore(const NodeDef& node) {
487   return (node.op() == "Restore" || node.op() == "RestoreV2" ||
488           node.op() == "RestoreSlice");
489 }
490 
IsRetval(const NodeDef & node)491 bool IsRetval(const NodeDef& node) {
492   return node.op() == "_Retval" || node.op() == "_DeviceRetval";
493 }
494 
IsReverse(const NodeDef & node)495 bool IsReverse(const NodeDef& node) {
496   return node.op() == "Reverse" || node.op() == "ReverseV2";
497 }
498 
IsReverseV2(const NodeDef & node)499 bool IsReverseV2(const NodeDef& node) { return node.op() == "ReverseV2"; }
500 
IsRsqrt(const NodeDef & node)501 bool IsRsqrt(const NodeDef& node) { return node.op() == "Rsqrt"; }
502 
IsRsqrtGrad(const NodeDef & node)503 bool IsRsqrtGrad(const NodeDef& node) { return node.op() == "RsqrtGrad"; }
504 
IsSelect(const NodeDef & node)505 bool IsSelect(const NodeDef& node) {
506   return node.op() == "Select" || node.op() == "SelectV2";
507 }
508 
IsSeluGrad(const NodeDef & node)509 bool IsSeluGrad(const NodeDef& node) { return node.op() == "SeluGrad"; }
510 
IsSend(const NodeDef & node)511 bool IsSend(const NodeDef& node) {
512   return node.op() == "_Send" || node.op() == "_HostSend";
513 }
514 
IsShape(const NodeDef & node)515 bool IsShape(const NodeDef& node) { return node.op() == "Shape"; }
516 
IsShapeN(const NodeDef & node)517 bool IsShapeN(const NodeDef& node) { return node.op() == "ShapeN"; }
518 
IsShuffle(const NodeDef & node)519 bool IsShuffle(const NodeDef& node) { return node.op() == "Shuffle"; }
520 
IsSigmoidGrad(const NodeDef & node)521 bool IsSigmoidGrad(const NodeDef& node) { return node.op() == "SigmoidGrad"; }
522 
IsSize(const NodeDef & node)523 bool IsSize(const NodeDef& node) { return node.op() == "Size"; }
524 
IsSlice(const NodeDef & node)525 bool IsSlice(const NodeDef& node) { return node.op() == "Slice"; }
526 
IsSnapshot(const NodeDef & node)527 bool IsSnapshot(const NodeDef& node) { return node.op() == "Snapshot"; }
528 
IsSoftmax(const NodeDef & node)529 bool IsSoftmax(const NodeDef& node) { return node.op() == "Softmax"; }
530 
IsSoftplusGrad(const NodeDef & node)531 bool IsSoftplusGrad(const NodeDef& node) { return node.op() == "SoftplusGrad"; }
532 
IsSoftsignGrad(const NodeDef & node)533 bool IsSoftsignGrad(const NodeDef& node) { return node.op() == "SoftsignGrad"; }
534 
IsSplit(const NodeDef & node)535 bool IsSplit(const NodeDef& node) { return node.op() == "Split"; }
536 
IsSplitV(const NodeDef & node)537 bool IsSplitV(const NodeDef& node) { return node.op() == "SplitV"; }
538 
IsSqrt(const NodeDef & node)539 bool IsSqrt(const NodeDef& node) { return node.op() == "Sqrt"; }
540 
IsSqrtGrad(const NodeDef & node)541 bool IsSqrtGrad(const NodeDef& node) { return node.op() == "SqrtGrad"; }
542 
IsSquare(const NodeDef & node)543 bool IsSquare(const NodeDef& node) { return node.op() == "Square"; }
544 
IsSquaredDifference(const NodeDef & node)545 bool IsSquaredDifference(const NodeDef& node) {
546   return node.op() == "SquaredDifference";
547 }
548 
IsSqueeze(const NodeDef & node)549 bool IsSqueeze(const NodeDef& node) { return node.op() == "Squeeze"; }
550 
IsStackOp(const NodeDef & node)551 bool IsStackOp(const NodeDef& node) {
552   return node.op() == "Stack" || node.op() == "StackV2";
553 }
IsStackCloseOp(const NodeDef & node)554 bool IsStackCloseOp(const NodeDef& node) {
555   return node.op() == "StackClose" || node.op() == "StackCloseV2";
556 }
IsStackPushOp(const NodeDef & node)557 bool IsStackPushOp(const NodeDef& node) {
558   return node.op() == "StackPush" || node.op() == "StackPushV2";
559 }
IsStackPopOp(const NodeDef & node)560 bool IsStackPopOp(const NodeDef& node) {
561   return node.op() == "StackPop" || node.op() == "StackPopV2";
562 }
563 
IsStatefulPartitionedCall(const NodeDef & node)564 bool IsStatefulPartitionedCall(const NodeDef& node) {
565   return node.op() == "StatefulPartitionedCall";
566 }
567 
IsStopGradient(const NodeDef & node)568 bool IsStopGradient(const NodeDef& node) {
569   const auto& op = node.op();
570   return op == "StopGradient" || op == "PreventGradient";
571 }
572 
IsStridedSlice(const NodeDef & node)573 bool IsStridedSlice(const NodeDef& node) { return node.op() == "StridedSlice"; }
574 
IsStridedSliceGrad(const NodeDef & node)575 bool IsStridedSliceGrad(const NodeDef& node) {
576   return node.op() == "StridedSliceGrad";
577 }
578 
IsSub(const NodeDef & node)579 bool IsSub(const NodeDef& node) { return node.op() == "Sub"; }
580 
IsSum(const NodeDef & node)581 bool IsSum(const NodeDef& node) { return node.op() == "Sum"; }
582 
IsSwitch(const NodeDef & node)583 bool IsSwitch(const NodeDef& node) {
584   const auto& op = node.op();
585   return op == "_SwitchN" || op == "Switch" || op == "RefSwitch";
586 }
587 
IsSymbolicGradient(const NodeDef & node)588 bool IsSymbolicGradient(const NodeDef& node) {
589   return node.op() == "SymbolicGradient";
590 }
591 
IsTanh(const NodeDef & node)592 bool IsTanh(const NodeDef& node) { return node.op() == "Tanh"; }
593 
IsTanhGrad(const NodeDef & node)594 bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; }
595 
IsTensorArray(const NodeDef & node)596 bool IsTensorArray(const NodeDef& node) {
597   static const gtl::FlatSet<string>* const kTensorArrayOps =
598       CHECK_NOTNULL((new gtl::FlatSet<string>{
599           "TensorArray",
600           "TensorArrayV2",
601           "TensorArrayV3",
602           "TensorArrayGrad",
603           "TensorArrayGradV2",
604           "TensorArrayGradV3",
605           "TensorArrayGradWithShape",
606           "TensorArrayWrite",
607           "TensorArrayWriteV2",
608           "TensorArrayWriteV3",
609           "TensorArrayRead",
610           "TensorArrayReadV2",
611           "TensorArrayReadV3",
612           "TensorArrayConcat",
613           "TensorArrayConcatV2",
614           "TensorArrayConcatV3",
615           "TensorArraySplit",
616           "TensorArraySplitV2",
617           "TensorArraySplitV3",
618           "TensorArraySize",
619           "TensorArraySizeV2",
620           "TensorArraySizeV3",
621           "TensorArrayClose",
622           "TensorArrayCloseV2",
623           "TensorArrayCloseV3",
624       }));
625   return kTensorArrayOps->count(node.op()) > 0;
626 }
627 
IsTile(const NodeDef & node)628 bool IsTile(const NodeDef& node) { return node.op() == "Tile"; }
629 
IsTranspose(const NodeDef & node)630 bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; }
631 
IsTruncateDiv(const NodeDef & node)632 bool IsTruncateDiv(const NodeDef& node) { return node.op() == "TruncateDiv"; }
633 
IsTruncateMod(const NodeDef & node)634 bool IsTruncateMod(const NodeDef& node) { return node.op() == "TruncateMod"; }
635 
IsUnique(const NodeDef & node)636 bool IsUnique(const NodeDef& node) {
637   const auto& op = node.op();
638   return op == "Unique" || op == "UniqueV2";
639 }
640 
IsUnpack(const NodeDef & node)641 bool IsUnpack(const NodeDef& node) { return node.op() == "Unpack"; }
642 
IsVariable(const NodeDef & node)643 bool IsVariable(const NodeDef& node) {
644   const auto& op = node.op();
645   return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" ||
646          op == "VarHandleOp" || op == "ReadVariableOp" ||
647          op == "_VarHandlesOp" || op == "_ReadVariablesOp";
648 }
649 
IsWhile(const NodeDef & node)650 bool IsWhile(const NodeDef& node) {
651   const auto& op = node.op();
652   return op == "While" || op == "StatelessWhile";
653 }
654 
IsXdivy(const NodeDef & node)655 bool IsXdivy(const NodeDef& node) { return node.op() == "Xdivy"; }
656 
IsZerosLike(const NodeDef & node)657 bool IsZerosLike(const NodeDef& node) { return node.op() == "ZerosLike"; }
658 
IsZeta(const NodeDef & node)659 bool IsZeta(const NodeDef& node) { return node.op() == "Zeta"; }
660 
661 namespace {
GetBoolAttr(const NodeDef & node,const string & name)662 bool GetBoolAttr(const NodeDef& node, const string& name) {
663   return node.attr().count(name) > 0 && node.attr().at(name).b();
664 }
665 }  // namespace
666 
IsPersistent(const NodeDef & node)667 bool IsPersistent(const NodeDef& node) {
668   return IsConstant(node) || IsVariable(node) || IsHostConstant(node);
669 }
670 
HasRefInput(const NodeDef & node)671 bool HasRefInput(const NodeDef& node) {
672   const OpDef* op_def;
673   Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
674   if (!status.ok()) {
675     return false;
676   }
677   // Nodes such as Assign or AssignAdd modify one of their inputs.
678   for (const auto& input : op_def->input_arg()) {
679     if (input.is_ref()) {
680       return true;
681     }
682   }
683   return false;
684 }
685 
IsDataset(const NodeDef & node)686 bool IsDataset(const NodeDef& node) {
687   const string& op = node.op();
688   // See `GetNodeClassForOp` in core/graph/graph.cc.
689   return op == "IteratorGetNext" || op == "IteratorGetNextSync" ||
690          op == "DatasetToSingleElement" || op == "ReduceDataset";
691 }
692 
IsStateful(const NodeDef node,const OpRegistryInterface * op_registry)693 bool IsStateful(const NodeDef node, const OpRegistryInterface* op_registry) {
694   const OpDef* op_def = nullptr;
695   const string& op_name = node.op();
696   Status status = op_registry->LookUpOpDef(op_name, &op_def);
697   if (!status.ok()) {
698     LOG(WARNING) << "Failed to lookup OpDef for " << op_name
699                  << ". Error: " << status.error_message();
700     return false;
701   }
702   return op_def->is_stateful();
703 }
704 
IsStateful(const NodeDef node)705 bool IsStateful(const NodeDef node) {
706   return IsStateful(node, OpRegistry::Global());
707 }
708 
IsFreeOfSideEffect(const NodeDef & node,const OpRegistryInterface * op_registry)709 bool IsFreeOfSideEffect(const NodeDef& node,
710                         const OpRegistryInterface* op_registry) {
711   // Placeholders must be preserved to keep the graph feedable.
712   if (IsPlaceholder(node)) {
713     return false;
714   }
715   const OpDef* op_def = nullptr;
716   const string& op_name = node.op();
717   Status status = op_registry->LookUpOpDef(op_name, &op_def);
718   if (!status.ok()) {
719     return false;
720   }
721   if (op_def->is_stateful()) {
722     return false;
723   }
724   // Nodes such as Assign or AssignAdd modify one of their inputs.
725   for (const auto& input : op_def->input_arg()) {
726     if (input.is_ref()) {
727       return false;
728     }
729   }
730   // Queue ops modify the queue which is a side effect.
731   if (node.op().find("Queue") != string::npos) {
732     return false;
733   }
734   // Sending a tensor via a network is a side effect.
735   if (IsSend(node)) {
736     return false;
737   }
738   return !ModifiesInputsInPlace(node);
739 }
740 
IsFreeOfSideEffect(const NodeDef & node)741 bool IsFreeOfSideEffect(const NodeDef& node) {
742   return IsFreeOfSideEffect(node, OpRegistry::Global());
743 }
744 
ModifiesInputsInPlace(const NodeDef & node)745 bool ModifiesInputsInPlace(const NodeDef& node) {
746   // Some nodes do in-place updates on regular tensor inputs.
747   const string& op_name = node.op();
748 
749   // Ops that modify resource variables effectively modify one of their inputs.
750   if (op_name == "AssignVariableOp" || op_name == "AssignAddVariableOp" ||
751       op_name == "AssignSubVariableOp" || op_name == "ResourceScatterUpdate" ||
752       op_name == "ResourceScatterAdd" || op_name == "ResourceScatterSub" ||
753       op_name == "ResourceScatterMul" || op_name == "ResourceScatterDiv" ||
754       op_name == "ResourceScatterMin" || op_name == "ResourceScatterMax") {
755     return false;
756   }
757 
758   string lower_op_name = op_name;
759   std::transform(lower_op_name.begin(), lower_op_name.end(),
760                  lower_op_name.begin(), ::tolower);
761   if (absl::StrContains(lower_op_name, "inplace")) {
762     return true;
763   }
764   return GetBoolAttr(node, "in_place") || GetBoolAttr(node, "inplace");
765 }
766 
ModifiesFrameInfo(const NodeDef & node)767 bool ModifiesFrameInfo(const NodeDef& node) {
768   return IsEnter(node) || IsExit(node) || IsNextIteration(node);
769 }
770 
771 #define OPDEF_PROPERTY_HELPER(PROPERTY_CAP, PROPERTY)                      \
772   bool Is##PROPERTY_CAP(const NodeDef& node) {                             \
773     if (node.op() == "Add") {                                              \
774       /* Workaround for "Add" not being marked is_commutative and */       \
775       /* is_aggregate. (See cl/173915048). */                              \
776       const auto type = GetDataTypeFromAttr(node, "T");                    \
777       return type != DT_INVALID && type != DT_STRING;                      \
778     }                                                                      \
779     const OpDef* op_def = nullptr;                                         \
780     Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); \
781     return status.ok() && op_def->is_##PROPERTY();                         \
782   }
783 
OPDEF_PROPERTY_HELPER(Aggregate,aggregate)784 OPDEF_PROPERTY_HELPER(Aggregate, aggregate)
785 OPDEF_PROPERTY_HELPER(Commutative, commutative)
786 
787 bool IsInvolution(const NodeDef& node) {
788   static const gtl::FlatSet<string>* const kInvolutionOps =
789       CHECK_NOTNULL((new gtl::FlatSet<string>{"Conj", "Reciprocal", "Invert",
790                                               "Neg", "LogicalNot"}));
791   return kInvolutionOps->count(node.op()) > 0;
792 }
793 
IsValueAndOrderAndShapePreserving(const NodeDef & node)794 bool IsValueAndOrderAndShapePreserving(const NodeDef& node) {
795   if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
796     return true;
797   }
798   static const gtl::FlatSet<string>* const kValueAndOrderAndShapePreservingOps =
799       CHECK_NOTNULL((new const gtl::FlatSet<string>{
800           "CheckNumerics",
801           "DebugGradientIdentity",
802           "DeepCopy"
803           "Enter",
804           "Exit",
805           "PreventGradient",
806           "Print",
807           "Snapshot",
808           "StopGradient",
809       }));
810   return kValueAndOrderAndShapePreservingOps->count(node.op()) > 0 ||
811          IsIdentity(node);
812 }
813 
IsValueAndOrderPreserving(const NodeDef & node)814 bool IsValueAndOrderPreserving(const NodeDef& node) {
815   if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
816     return true;
817   }
818   static const gtl::FlatSet<string>* const kValueAndOrderPreservingOps =
819       CHECK_NOTNULL((new const gtl::FlatSet<string>{
820           "ExpandDims",
821           "Reshape",
822           "Squeeze",
823       }));
824   return kValueAndOrderPreservingOps->count(node.op()) > 0 ||
825          IsValueAndOrderAndShapePreserving(node);
826 }
827 
IsValuePreserving(const NodeDef & node)828 bool IsValuePreserving(const NodeDef& node) {
829   static const gtl::FlatSet<string>* const kValuePreservingOps =
830       CHECK_NOTNULL((new gtl::FlatSet<string>{
831           "InvertPermutation",
832           "Reverse",
833           "ReverseV2",
834           "Roll",
835           "Transpose",
836           "DepthToSpace",
837           "SpaceToDepth",
838           "BatchToSpace",
839           "BatchToSpaceND",
840           "SpaceToBatch",
841           "SpaceToBatchND",
842       }));
843   return IsValueAndOrderPreserving(node) ||
844          kValuePreservingOps->count(node.op()) > 0;
845 }
846 
IsUnaryElementWise(const NodeDef & node)847 bool IsUnaryElementWise(const NodeDef& node) {
848   static const gtl::FlatSet<string>* const kElementWiseOps =
849       CHECK_NOTNULL((new gtl::FlatSet<string>{
850           "Abs",
851           "Acos",
852           "Acosh",
853           "Asin",
854           "Asinh",
855           "Atan",
856           "Atanh",
857           "Ceil",
858           "ComplexAbs",
859           "Conj",
860           "Cos",
861           "Cosh",
862           "Digamma",
863           "Elu"
864           "Erf",
865           "Erfc",
866           "Exp",
867           "Expm1",
868           "Floor",
869           "Inv",
870           "Invert",
871           "Isinf",
872           "Isnan",
873           "Isfinite",
874           "Lgamma",
875           "Log",
876           "Log1p",
877           "LogicalNot",
878           "Neg",
879           "Reciprocal",
880           "Relu",
881           "Relu6",
882           "Rint",
883           "Round",
884           "Selu",
885           "Rsqrt",
886           "Sigmoid",
887           "Sign",
888           "Sin",
889           "SinH",
890           "Softplus",
891           "Softsign",
892           "Sqrt",
893           "Square",
894           "Tan"
895           "Tanh",
896       }));
897   return kElementWiseOps->count(node.op()) > 0 ||
898          IsValueAndOrderAndShapePreserving(node);
899 }
900 
HasOpDef(const NodeDef & node)901 bool HasOpDef(const NodeDef& node) {
902   const OpDef* op_def = nullptr;
903   return OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok();
904 }
905 
IsIdempotent(const NodeDef & node)906 bool IsIdempotent(const NodeDef& node) {
907   return IsValueAndOrderAndShapePreserving(node) && IsFreeOfSideEffect(node) &&
908          !ModifiesFrameInfo(node);
909 }
910 
NeverForwardsInputs(const NodeDef & node)911 bool NeverForwardsInputs(const NodeDef& node) {
912   static const gtl::FlatSet<string>* const kNonForwardingOps = CHECK_NOTNULL(
913       (new gtl::FlatSet<string>{"ArgMax",
914                                 "ArgMin",
915                                 "AudioSpectrogram",
916                                 "AvgPool",
917                                 "BatchMatMul",
918                                 "BatchMatMulV2",
919                                 "BatchNormWithGlobalNormalization",
920                                 "BatchToSpace",
921                                 "BatchToSpaceND",
922                                 "Bincount",
923                                 "BroadcastArgs",
924                                 "BroadcastGradientArgs",
925                                 "Bucketize",
926                                 "CTCBeamSearchDecoder",
927                                 "CTCGreedyDecoder",
928                                 "CTCLoss",
929                                 "CompareAndBitpack",
930                                 "ComplexAbs",
931                                 "Concat",
932                                 "ConcatOffset",
933                                 "ConcatV2",
934                                 "Conv2D",
935                                 "Copy",
936                                 "CopyHost",
937                                 "Cross",
938                                 "CudnnRNN",
939                                 "CudnnRNNBackprop",
940                                 "CudnnRNNBackpropV2",
941                                 "CudnnRNNBackpropV3",
942                                 "CudnnRNNCanonicalToParams",
943                                 "CudnnRNNCanonicalToParamsV2",
944                                 "CudnnRNNParamsSize",
945                                 "CudnnRNNParamsToCanonical",
946                                 "CudnnRNNParamsToCanonicalV2",
947                                 "CudnnRNNV2",
948                                 "CudnnRNNV3",
949                                 "CumProd",
950                                 "CumSum",
951                                 "DebugNanCount",
952                                 "DebugNumericSummary",
953                                 "DecodeProtoV2",
954                                 "DecodeWav",
955                                 "DeepCopy",
956                                 "DepthToSpace",
957                                 "Dequantize",
958                                 "Diag",
959                                 "DiagPart",
960                                 "EditDistance",
961                                 "Empty",
962                                 "EncodeProtoV2",
963                                 "EncodeWav",
964                                 "ExtractImagePatches",
965                                 "ExtractVolumePatches",
966                                 "Fill",
967                                 "Gather",
968                                 "GatherNd",
969                                 "GatherV2",
970                                 "HistogramFixedWidth",
971                                 "InvertPermutation",
972                                 "IsInf",
973                                 "IsNan",
974                                 "Isfinite",
975                                 "LinSpace",
976                                 "LowerBound",
977                                 "MatMul",
978                                 "MatrixDiag",
979                                 "MatrixDiagPart",
980                                 "MatrixDiagPartV2",
981                                 "MatrixDiagV2",
982                                 "Mfcc",
983                                 "Multinomial",
984                                 "OneHot",
985                                 "Pack",
986                                 "ParameterizedTruncatedNormal",
987                                 "PopulationCount",
988                                 "RandomGamma",
989                                 "RandomPoisson",
990                                 "RandomPoissonV2",
991                                 "RandomStandardNormal",
992                                 "RandomUniform",
993                                 "RandomUniformInt",
994                                 "Range",
995                                 "Rank",
996                                 "RequantizationRange",
997                                 "Requantize",
998                                 "ReverseSequence",
999                                 "Shape",
1000                                 "ShapeN",
1001                                 "Size",
1002                                 "SpaceToBatch",
1003                                 "SpaceToBatchND",
1004                                 "SpaceToDepth",
1005                                 "SparseMatMul",
1006                                 "Split",
1007                                 "SplitV",
1008                                 "TruncatedNormal",
1009                                 "Unique",
1010                                 "UniqueV2",
1011                                 "UniqueWithCounts",
1012                                 "UniqueWithCountsV2",
1013                                 "Unpack",
1014                                 "UnravelIndex",
1015                                 "UpperBound",
1016                                 "Where"}));
1017   const string& op_name = node.op();
1018   return kNonForwardingOps->count(op_name) > 0 ||
1019          absl::StrContains(op_name, "Segment") ||
1020          absl::StartsWith(op_name, "Quantize");
1021 }
1022 
IsXlaLaunch(const NodeDef & node)1023 bool IsXlaLaunch(const NodeDef& node) { return node.op() == "XlaLaunch"; }
1024 
1025 }  // namespace grappler
1026 }  // end namespace tensorflow
1027