1 
2 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3 
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7 
8     http://www.apache.org/licenses/LICENSE-2.0
9 
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 ==============================================================================*/
16 
17 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
18 
19 #include "absl/strings/match.h"
20 #include "third_party/eigen3/Eigen/Core"
21 #include "tensorflow/core/framework/attr_value.pb.h"
22 #include "tensorflow/core/framework/attr_value_util.h"
23 #include "tensorflow/core/framework/tensor.pb.h"
24 #include "tensorflow/core/framework/tensor_shape.pb.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/grappler/clusters/utils.h"
27 #include "tensorflow/core/grappler/costs/op_context.h"
28 #include "tensorflow/core/grappler/costs/utils.h"
29 #include "tensorflow/core/platform/errors.h"
30 
31 namespace tensorflow {
32 namespace grappler {
33 
// TODO(dyoon): update op to Predict method map for TF ops with V2 or V3 suffix.
// Number of ops charged per multiply-accumulate; presumably one multiply plus
// one add — confirm against the Predict* functions that use it.
constexpr int kOpsPerMac = 2;
// Op type names (values of OpInfo::op()). Most of these are used as keys when
// the OpLevelCostEstimator constructor registers per-op cost functions.
constexpr char kGuaranteeConst[] = "GuaranteeConst";
constexpr char kAddN[] = "AddN";
constexpr char kBitCast[] = "BitCast";
constexpr char kConcatV2[] = "ConcatV2";
constexpr char kConv2d[] = "Conv2D";
constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter";
constexpr char kConv2dBackpropInput[] = "Conv2DBackpropInput";
constexpr char kFusedConv2dBiasActivation[] = "FusedConv2DBiasActivation";
constexpr char kDataFormatVecPermute[] = "DataFormatVecPermute";
constexpr char kDepthToSpace[] = "DepthToSpace";
constexpr char kDepthwiseConv2dNative[] = "DepthwiseConv2dNative";
constexpr char kDepthwiseConv2dNativeBackpropFilter[] =
    "DepthwiseConv2dNativeBackpropFilter";
constexpr char kDepthwiseConv2dNativeBackpropInput[] =
    "DepthwiseConv2dNativeBackpropInput";
constexpr char kMatMul[] = "MatMul";
constexpr char kXlaEinsum[] = "XlaEinsum";
constexpr char kEinsum[] = "Einsum";
constexpr char kExpandDims[] = "ExpandDims";
constexpr char kFill[] = "Fill";
constexpr char kSparseMatMul[] = "SparseMatMul";
constexpr char kSparseTensorDenseMatMul[] = "SparseTensorDenseMatMul";
constexpr char kPlaceholder[] = "Placeholder";
constexpr char kIdentity[] = "Identity";
constexpr char kIdentityN[] = "IdentityN";
constexpr char kRefIdentity[] = "RefIdentity";
constexpr char kNoOp[] = "NoOp";
constexpr char kReshape[] = "Reshape";
constexpr char kSplit[] = "Split";
constexpr char kSqueeze[] = "Squeeze";
constexpr char kRecv[] = "_Recv";
constexpr char kSend[] = "_Send";
constexpr char kBatchMatMul[] = "BatchMatMul";
constexpr char kBatchMatMulV2[] = "BatchMatMulV2";
constexpr char kOneHot[] = "OneHot";
constexpr char kPack[] = "Pack";
constexpr char kRank[] = "Rank";
constexpr char kRange[] = "Range";
constexpr char kShape[] = "Shape";
constexpr char kShapeN[] = "ShapeN";
constexpr char kSize[] = "Size";
constexpr char kStopGradient[] = "StopGradient";
constexpr char kPreventGradient[] = "PreventGradient";
constexpr char kGather[] = "Gather";
constexpr char kGatherNd[] = "GatherNd";
constexpr char kGatherV2[] = "GatherV2";
constexpr char kScatterAdd[] = "ScatterAdd";
constexpr char kScatterDiv[] = "ScatterDiv";
constexpr char kScatterMax[] = "ScatterMax";
constexpr char kScatterMin[] = "ScatterMin";
constexpr char kScatterMul[] = "ScatterMul";
constexpr char kScatterSub[] = "ScatterSub";
constexpr char kScatterUpdate[] = "ScatterUpdate";
constexpr char kSlice[] = "Slice";
constexpr char kStridedSlice[] = "StridedSlice";
constexpr char kSpaceToDepth[] = "SpaceToDepth";
constexpr char kTranspose[] = "Transpose";
constexpr char kTile[] = "Tile";
constexpr char kMaxPool[] = "MaxPool";
constexpr char kMaxPoolGrad[] = "MaxPoolGrad";
constexpr char kAvgPool[] = "AvgPool";
constexpr char kAvgPoolGrad[] = "AvgPoolGrad";
constexpr char kFusedBatchNorm[] = "FusedBatchNorm";
constexpr char kFusedBatchNormGrad[] = "FusedBatchNormGrad";
constexpr char kQuantizedMatMul[] = "QuantizedMatMul";
constexpr char kQuantizedMatMulV2[] = "QuantizedMatMulV2";
constexpr char kUnpack[] = "Unpack";
constexpr char kSoftmax[] = "Softmax";
constexpr char kResizeBilinear[] = "ResizeBilinear";
constexpr char kCropAndResize[] = "CropAndResize";
// Dynamic control flow ops.
constexpr char kSwitch[] = "Switch";
constexpr char kMerge[] = "Merge";
constexpr char kEnter[] = "Enter";
constexpr char kExit[] = "Exit";
constexpr char kNextIteration[] = "NextIteration";
// Persistent ops.
constexpr char kConst[] = "Const";
constexpr char kVariable[] = "Variable";
constexpr char kVariableV2[] = "VariableV2";
constexpr char kAutoReloadVariable[] = "AutoReloadVariable";
constexpr char kVarHandleOp[] = "VarHandleOp";
constexpr char kVarHandlesOp[] = "_VarHandlesOp";
constexpr char kReadVariableOp[] = "ReadVariableOp";
constexpr char kReadVariablesOp[] = "_ReadVariablesOp";
constexpr char kAssignVariableOp[] = "AssignVariableOp";
constexpr char kAssignAddVariableOp[] = "AssignAddVariableOp";
constexpr char kAssignSubVariableOp[] = "AssignSubVariableOp";

// Compute time assigned to ops that PredictCosts() treats as minimum-cost
// (it is written to both compute_time and execution_time there).
static const Costs::Duration kMinComputeTime(1);
// Minimum op count. NOTE(review): not referenced in the visible part of this
// file; presumably used by Predict* functions further down — confirm.
static const int64 kMinComputeOp = 1;
127 
128 namespace {
129 
GetDataFormat(const OpInfo & op_info)130 std::string GetDataFormat(const OpInfo& op_info) {
131   std::string data_format = "NHWC";  // Default format.
132   if (op_info.attr().find("data_format") != op_info.attr().end()) {
133     data_format = op_info.attr().at("data_format").s();
134   }
135   return data_format;
136 }
137 
GetFilterFormat(const OpInfo & op_info)138 std::string GetFilterFormat(const OpInfo& op_info) {
139   std::string filter_format = "HWIO";  // Default format.
140   if (op_info.attr().find("filter_format") != op_info.attr().end()) {
141     filter_format = op_info.attr().at("filter_format").s();
142   }
143   return filter_format;
144 }
145 
GetPadding(const OpInfo & op_info)146 Padding GetPadding(const OpInfo& op_info) {
147   if (op_info.attr().find("padding") != op_info.attr().end() &&
148       op_info.attr().at("padding").s() == "VALID") {
149     return Padding::VALID;
150   }
151   return Padding::SAME;  // Default padding.
152 }
153 
IsTraining(const OpInfo & op_info)154 bool IsTraining(const OpInfo& op_info) {
155   if (op_info.attr().find("is_training") != op_info.attr().end() &&
156       op_info.attr().at("is_training").b()) {
157     return true;
158   }
159   return false;
160 }
161 
162 // TODO(dyoon): support non-4D tensors in the cost functions of convolution
163 // related ops (Conv, Pool, BatchNorm, and their backprops) and the related
164 // helper functions.
GetStrides(const OpInfo & op_info)165 std::vector<int64> GetStrides(const OpInfo& op_info) {
166   if (op_info.attr().find("strides") != op_info.attr().end()) {
167     const auto strides = op_info.attr().at("strides").list().i();
168     DCHECK(strides.size() == 4)
169         << "Attr strides is not a length-4 vector: " << op_info.DebugString();
170     if (strides.size() != 4) return {1, 1, 1, 1};
171     return {strides[0], strides[1], strides[2], strides[3]};
172   }
173   return {1, 1, 1, 1};
174 }
175 
GetKernelSize(const OpInfo & op_info)176 std::vector<int64> GetKernelSize(const OpInfo& op_info) {
177   if (op_info.attr().find("ksize") != op_info.attr().end()) {
178     const auto ksize = op_info.attr().at("ksize").list().i();
179     DCHECK(ksize.size() == 4)
180         << "Attr ksize is not a length-4 vector: " << op_info.DebugString();
181     if (ksize.size() != 4) return {1, 1, 1, 1};
182     return {ksize[0], ksize[1], ksize[2], ksize[3]};
183   }
184   // Note that FusedBatchNorm doesn't have ksize attr, but GetKernelSize returns
185   // {1, 1, 1, 1} in that case.
186   return {1, 1, 1, 1};
187 }
188 
// Returns the output size along one spatial dimension of a windowed op.
// Logic for calculating output shape is from GetWindowedOutputSizeVerbose()
// function in third_party/tensorflow/core/framework/common_shape_fns.cc:
//   VALID: (input - filter + stride) / stride
//   SAME:  (input + stride - 1) / stride   (i.e. ceil(input / stride))
// A non-positive stride can only come from a malformed graph. It is treated as
// 1 so the estimate never performs an integer division by zero.
int64 GetOutputSize(const int64 input, const int64 filter, const int64 stride,
                    const Padding& padding) {
  const int64 safe_stride = stride > 0 ? stride : 1;
  if (padding == Padding::VALID) {
    return (input - filter + safe_stride) / safe_stride;
  }
  return (input + safe_stride - 1) / safe_stride;  // SAME.
}
199 
200 // Return the output element count of a multi-input element-wise op considering
201 // broadcasting.
CwiseOutputElementCount(const OpInfo & op_info)202 int64 CwiseOutputElementCount(const OpInfo& op_info) {
203   int max_rank = 1;
204   for (const OpInfo::TensorProperties& input_properties : op_info.inputs()) {
205     max_rank = std::max(max_rank, input_properties.shape().dim_size());
206   }
207 
208   TensorShapeProto output_shape;
209   output_shape.mutable_dim()->Reserve(max_rank);
210   for (int i = 0; i < max_rank; ++i) {
211     output_shape.add_dim();
212   }
213 
214   // Expand the shape of the output to follow the numpy-style broadcast rule
215   // which matches each input starting with the trailing dimensions and working
216   // its way forward. To do this, iterate through each input shape's dimensions
217   // in reverse order, and potentially increase the corresponding output
218   // dimension.
219   for (const OpInfo::TensorProperties& input_properties : op_info.inputs()) {
220     const TensorShapeProto& input_shape = input_properties.shape();
221     for (int i = input_shape.dim_size() - 1; i >= 0; --i) {
222       int output_shape_dim_index =
223           i + output_shape.dim_size() - input_shape.dim_size();
224       output_shape.mutable_dim(output_shape_dim_index)
225           ->set_size(std::max(output_shape.dim(output_shape_dim_index).size(),
226                               input_shape.dim(i).size()));
227     }
228   }
229 
230   int64 count = 1;
231   for (int i = 0; i < output_shape.dim_size(); i++) {
232     count *= output_shape.dim(i).size();
233   }
234   return count;
235 }
236 
237 // Helper function for determining whether there are repeated indices in the
238 // input Einsum equation.
CheckRepeatedDimensions(const absl::string_view dim_str)239 bool CheckRepeatedDimensions(const absl::string_view dim_str) {
240   int str_size = dim_str.size();
241   for (int idx = 0; idx < str_size - 1; idx++) {
242     if (dim_str.find(dim_str[idx], idx + 1) != std::string::npos) {
243       return true;
244     }
245   }
246   return false;
247 }
248 
249 // Auxiliary function for determining whether OpLevelCostEstimator is compatible
250 // with a given Einsum.
IsEinsumCorrectlyFormed(const OpContext & einsum_context)251 bool IsEinsumCorrectlyFormed(const OpContext& einsum_context) {
252   const auto& op_info = einsum_context.op_info;
253 
254   auto it = op_info.attr().find("equation");
255   if (it == op_info.attr().end()) return false;
256   const absl::string_view equation = it->second.s();
257   std::vector<std::string> equation_split = absl::StrSplit(equation, "->");
258 
259   if (equation_split.empty()) {
260     LOG(WARNING) << "Einsum with malformed equation";
261     return false;
262   }
263   std::vector<absl::string_view> input_split =
264       absl::StrSplit(equation_split[0], ',');
265 
266   // The current model covers Einsum operations with two operands and a RHS
267   if (op_info.inputs_size() != 2 || equation_split.size() != 2) {
268     VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
269     return false;
270   }
271   const auto& a_input = op_info.inputs(0);
272   const auto& b_input = op_info.inputs(1);
273   absl::string_view rhs_str = equation_split[1];
274   absl::string_view a_input_str = input_split[0];
275   absl::string_view b_input_str = input_split[1];
276 
277   // Ellipsis are not currently supported
278   if (absl::StrContains(a_input_str, "...") ||
279       absl::StrContains(b_input_str, "...")) {
280     VLOG(1) << "Missing accurate estimator for op: " << op_info.op()
281             << ", ellipsis not supported";
282     return false;
283   }
284 
285   constexpr int kMatrixRank = 2;
286 
287   bool a_input_shape_unknown = false;
288   bool b_input_shape_unknown = false;
289 
290   TensorShapeProto a_input_shape = MaybeGetMinimumShape(
291       a_input.shape(), std::max(kMatrixRank, a_input.shape().dim_size()),
292       &a_input_shape_unknown);
293   TensorShapeProto b_input_shape = MaybeGetMinimumShape(
294       b_input.shape(), std::max(kMatrixRank, b_input.shape().dim_size()),
295       &b_input_shape_unknown);
296 
297   if (a_input_str.size() != static_cast<size_t>(a_input_shape.dim_size()) ||
298       b_input_str.size() != static_cast<size_t>(b_input_shape.dim_size())) {
299     VLOG(1) << "Missing accurate estimator for op: " << op_info.op()
300             << ", equation subscripts don't match tensor rank.";
301     return false;
302   }
303 
304   // Subscripts where axis appears more than once for a single input are not yet
305   // supported
306   if (CheckRepeatedDimensions(a_input_str) ||
307       CheckRepeatedDimensions(b_input_str) ||
308       CheckRepeatedDimensions(rhs_str)) {
309     VLOG(1) << "Missing accurate estimator for op: " << op_info.op()
310             << ", Subscripts where axis appears more than once for a single "
311                "input are not yet supported";
312     return false;
313   }
314 
315   return true;
316 }
317 
318 }  // namespace
319 
320 // Return a minimum shape if the shape is unknown. If known, return the original
321 // shape.
MaybeGetMinimumShape(const TensorShapeProto & original_shape,int rank,bool * found_unknown_shapes)322 TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
323                                       int rank, bool* found_unknown_shapes) {
324   auto shape = original_shape;
325   bool is_scalar = !shape.unknown_rank() && shape.dim_size() == 0;
326 
327   if (shape.unknown_rank() || (!is_scalar && shape.dim_size() < rank)) {
328     *found_unknown_shapes = true;
329     VLOG(2) << "Use minimum shape because the rank is unknown.";
330     // The size of each dimension is at least 1, if unknown.
331     for (int i = shape.dim_size(); i < rank; i++) {
332       shape.add_dim()->set_size(1);
333     }
334   } else if (is_scalar) {
335     for (int i = 0; i < rank; i++) {
336       shape.add_dim()->set_size(1);
337     }
338   } else if (shape.dim_size() > rank) {
339     *found_unknown_shapes = true;
340     shape.clear_dim();
341     for (int i = 0; i < rank; i++) {
342       shape.add_dim()->set_size(original_shape.dim(i).size());
343     }
344   } else {
345     for (int i = 0; i < shape.dim_size(); i++) {
346       if (shape.dim(i).size() < 0) {
347         *found_unknown_shapes = true;
348         VLOG(2) << "Use minimum dim size 1 because the shape is unknown.";
349         // The size of each dimension is at least 1, if unknown.
350         shape.mutable_dim(i)->set_size(1);
351       }
352     }
353   }
354   return shape;
355 }
356 
// Builds the lookup tables used by the cost predictors:
//  - device_cost_impl_: op type name -> member Predict* function, wrapped as a
//    callable taking (const OpContext&, NodeCosts*) and returning Status;
//  - persistent_ops_: the names of ops in the "persistent" category;
//  - elementwise_ops_: element-wise op name -> per-element cost, taken from
//    Eigen's functor_traits<...>::Cost or a flat 1 where no functor is used.
// Finally sets compute_memory_overlap_ to false.
//
// NOTE(review): every closure produced by `wrap` captures `this`. If an
// OpLevelCostEstimator were copied or moved, the copy's tables would still
// call into the original object. Confirm the class is non-copyable (its
// header is not visible here).
// NOTE(review): assuming these members are std map types, emplace() keeps the
// first entry for a repeated key, so an accidental duplicate registration
// would be silently ignored.
OpLevelCostEstimator::OpLevelCostEstimator() {
  // Syntactic sugar: `wrap` turns a pointer to a Predict* member function into
  // a std::function that takes (const OpContext&, NodeCosts*) and returns that
  // member's Status, invoked on this estimator.
  typedef Status (OpLevelCostEstimator::*CostImpl)(const OpContext& op_context,
                                                   NodeCosts*) const;
  auto wrap = [this](CostImpl impl)
      -> std::function<Status(const OpContext&, NodeCosts*)> {
    return [this, impl](const OpContext& op_context, NodeCosts* node_costs) {
      return (this->*impl)(op_context, node_costs);
    };
  };

  // Convolutions and matrix multiplications.
  device_cost_impl_.emplace(kConv2d,
                            wrap(&OpLevelCostEstimator::PredictConv2D));
  device_cost_impl_.emplace(
      kConv2dBackpropFilter,
      wrap(&OpLevelCostEstimator::PredictConv2DBackpropFilter));
  device_cost_impl_.emplace(
      kConv2dBackpropInput,
      wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput));
  device_cost_impl_.emplace(
      kFusedConv2dBiasActivation,
      wrap(&OpLevelCostEstimator::PredictFusedConv2DBiasActivation));
  // reuse Conv2D for DepthwiseConv2dNative because the calculation is the
  // same although the actual meaning of the parameters are different. See
  // comments in PredictConv2D and related functions
  device_cost_impl_.emplace(kDepthwiseConv2dNative,
                            wrap(&OpLevelCostEstimator::PredictConv2D));
  device_cost_impl_.emplace(
      kDepthwiseConv2dNativeBackpropFilter,
      wrap(&OpLevelCostEstimator::PredictConv2DBackpropFilter));
  device_cost_impl_.emplace(
      kDepthwiseConv2dNativeBackpropInput,
      wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput));
  device_cost_impl_.emplace(kMatMul,
                            wrap(&OpLevelCostEstimator::PredictMatMul));
  device_cost_impl_.emplace(kSparseMatMul,
                            wrap(&OpLevelCostEstimator::PredictMatMul));
  device_cost_impl_.emplace(
      kSparseTensorDenseMatMul,
      wrap(&OpLevelCostEstimator::PredictSparseTensorDenseMatMul));
  device_cost_impl_.emplace(kBatchMatMul,
                            wrap(&OpLevelCostEstimator::PredictBatchMatMul));
  device_cost_impl_.emplace(kBatchMatMulV2,
                            wrap(&OpLevelCostEstimator::PredictBatchMatMul));
  device_cost_impl_.emplace(kQuantizedMatMul,
                            wrap(&OpLevelCostEstimator::PredictMatMul));
  device_cost_impl_.emplace(kQuantizedMatMulV2,
                            wrap(&OpLevelCostEstimator::PredictMatMul));
  device_cost_impl_.emplace(kXlaEinsum,
                            wrap(&OpLevelCostEstimator::PredictEinsum));
  device_cost_impl_.emplace(kEinsum,
                            wrap(&OpLevelCostEstimator::PredictEinsum));

  // Ops costed by PredictNoOp.
  device_cost_impl_.emplace(kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp));
  device_cost_impl_.emplace(kGuaranteeConst,
                            wrap(&OpLevelCostEstimator::PredictNoOp));

  // Gather and scatter ops.
  device_cost_impl_.emplace(kGather,
                            wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
  device_cost_impl_.emplace(kGatherNd,
                            wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
  device_cost_impl_.emplace(kGatherV2,
                            wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
  device_cost_impl_.emplace(kScatterAdd,
                            wrap(&OpLevelCostEstimator::PredictScatter));
  device_cost_impl_.emplace(kScatterDiv,
                            wrap(&OpLevelCostEstimator::PredictScatter));
  device_cost_impl_.emplace(kScatterMax,
                            wrap(&OpLevelCostEstimator::PredictScatter));
  device_cost_impl_.emplace(kScatterMin,
                            wrap(&OpLevelCostEstimator::PredictScatter));
  device_cost_impl_.emplace(kScatterMul,
                            wrap(&OpLevelCostEstimator::PredictScatter));
  device_cost_impl_.emplace(kScatterSub,
                            wrap(&OpLevelCostEstimator::PredictScatter));
  device_cost_impl_.emplace(kScatterUpdate,
                            wrap(&OpLevelCostEstimator::PredictScatter));

  // Slices share the gather cost function.
  device_cost_impl_.emplace(kSlice,
                            wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
  device_cost_impl_.emplace(kStridedSlice,
                            wrap(&OpLevelCostEstimator::PredictGatherOrSlice));

  // Ops costed by PredictIdentity, including send/recv and the dynamic
  // control flow ops.
  device_cost_impl_.emplace(kPlaceholder,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kIdentity,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kIdentityN,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kRefIdentity,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kStopGradient,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kPreventGradient,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kReshape,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kRecv,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kSend,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kSwitch,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kMerge,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kEnter,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kExit,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kNextIteration,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kBitCast,
                            wrap(&OpLevelCostEstimator::PredictIdentity));

  // Ops costed by PredictPureMemoryOp.
  device_cost_impl_.emplace(kConcatV2,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kDataFormatVecPermute,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kDepthToSpace,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kExpandDims,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kFill,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kOneHot,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kPack,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kRange,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kSpaceToDepth,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kSplit,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kSqueeze,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kTranspose,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kTile,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kUnpack,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));

  // Shape-metadata ops, then pooling, batch norm and other specialized ops.
  device_cost_impl_.emplace(kRank,
                            wrap(&OpLevelCostEstimator::PredictMetadata));
  device_cost_impl_.emplace(kShape,
                            wrap(&OpLevelCostEstimator::PredictMetadata));
  device_cost_impl_.emplace(kShapeN,
                            wrap(&OpLevelCostEstimator::PredictMetadata));
  device_cost_impl_.emplace(kSize,
                            wrap(&OpLevelCostEstimator::PredictMetadata));
  device_cost_impl_.emplace(kMaxPool,
                            wrap(&OpLevelCostEstimator::PredictMaxPool));
  device_cost_impl_.emplace(kMaxPoolGrad,
                            wrap(&OpLevelCostEstimator::PredictMaxPoolGrad));
  device_cost_impl_.emplace(kAvgPool,
                            wrap(&OpLevelCostEstimator::PredictAvgPool));
  device_cost_impl_.emplace(kAvgPoolGrad,
                            wrap(&OpLevelCostEstimator::PredictAvgPoolGrad));
  device_cost_impl_.emplace(kFusedBatchNorm,
                            wrap(&OpLevelCostEstimator::PredictFusedBatchNorm));
  device_cost_impl_.emplace(
      kFusedBatchNormGrad,
      wrap(&OpLevelCostEstimator::PredictFusedBatchNormGrad));
  device_cost_impl_.emplace(kSoftmax,
                            wrap(&OpLevelCostEstimator::PredictSoftmax));
  device_cost_impl_.emplace(kResizeBilinear,
                            wrap(&OpLevelCostEstimator::PredictResizeBilinear));
  device_cost_impl_.emplace(kCropAndResize,
                            wrap(&OpLevelCostEstimator::PredictCropAndResize));
  device_cost_impl_.emplace(
      kAssignVariableOp, wrap(&OpLevelCostEstimator::PredictAssignVariableOps));
  device_cost_impl_.emplace(
      kAssignAddVariableOp,
      wrap(&OpLevelCostEstimator::PredictAssignVariableOps));
  device_cost_impl_.emplace(
      kAssignSubVariableOp,
      wrap(&OpLevelCostEstimator::PredictAssignVariableOps));
  device_cost_impl_.emplace(kAddN, wrap(&OpLevelCostEstimator::PredictNaryOp));

  persistent_ops_ = {
      kConst,       kVariable,       kVariableV2,   kAutoReloadVariable,
      kVarHandleOp, kReadVariableOp, kVarHandlesOp, kReadVariablesOp};

// Shorthand for Eigen's per-element cost estimate of scalar functor X. The
// macro is #undef'd at the end of this constructor.
#define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost

  // Quantize = apply min and max bounds, multiply by scale factor and round.
  const int quantize_v2_cost =
      EIGEN_COST(scalar_product_op<float>) + EIGEN_COST(scalar_max_op<float>) +
      EIGEN_COST(scalar_min_op<float>) + EIGEN_COST(scalar_round_op<float>);
  const int quantize_and_dequantize_v2_cost =
      quantize_v2_cost + EIGEN_COST(scalar_product_op<float>);

  // Unary ops alphabetically sorted
  // NOTE(review): this list also holds ops that appear to be reductions or
  // selections (All, ArgMax, Max, Min, Prod, Sum, TopKV2) and the binary op
  // Atan2. They are all charged the per-element cost of the functor shown.
  elementwise_ops_.emplace("Acos", EIGEN_COST(scalar_acos_op<float>));
  elementwise_ops_.emplace("All", EIGEN_COST(scalar_boolean_and_op));
  elementwise_ops_.emplace("ArgMax", EIGEN_COST(scalar_max_op<float>));
  elementwise_ops_.emplace("Asin", EIGEN_COST(scalar_asin_op<float>));
  elementwise_ops_.emplace("Atan", EIGEN_COST(scalar_atan_op<float>));
  elementwise_ops_.emplace("Atan2", EIGEN_COST(scalar_quotient_op<float>) +
                                        EIGEN_COST(scalar_atan_op<float>));
  // For now, we use Eigen cost model for float to int16 cast as an example
  // case; Eigen cost model is zero when src and dst types are identical,
  // and it uses AddCost (1) when different. We may implement a separate
  // cost functions for cast ops, using the actual input and output types.
  elementwise_ops_.emplace(
      "Cast", Eigen::internal::functor_traits<
                  Eigen::internal::scalar_cast_op<float, int16>>::Cost);
  elementwise_ops_.emplace("Ceil", EIGEN_COST(scalar_ceil_op<float>));
  elementwise_ops_.emplace("Cos", EIGEN_COST(scalar_cos_op<float>));
  elementwise_ops_.emplace("Dequantize", EIGEN_COST(scalar_product_op<float>));
  elementwise_ops_.emplace("Erf", 1);
  elementwise_ops_.emplace("Erfc", 1);
  elementwise_ops_.emplace("Exp", EIGEN_COST(scalar_exp_op<float>));
  elementwise_ops_.emplace("Expm1", EIGEN_COST(scalar_expm1_op<float>));
  elementwise_ops_.emplace("Floor", EIGEN_COST(scalar_floor_op<float>));
  elementwise_ops_.emplace("Inv", EIGEN_COST(scalar_inverse_op<float>));
  elementwise_ops_.emplace("InvGrad", 1);
  elementwise_ops_.emplace("Lgamma", 1);
  elementwise_ops_.emplace("Log", EIGEN_COST(scalar_log_op<float>));
  elementwise_ops_.emplace("Log1p", EIGEN_COST(scalar_log1p_op<float>));
  elementwise_ops_.emplace("Max", EIGEN_COST(scalar_max_op<float>));
  elementwise_ops_.emplace("Min", EIGEN_COST(scalar_min_op<float>));
  elementwise_ops_.emplace("Neg", EIGEN_COST(scalar_opposite_op<float>));
  elementwise_ops_.emplace("Prod", EIGEN_COST(scalar_product_op<float>));
  elementwise_ops_.emplace("QuantizeAndDequantizeV2",
                           quantize_and_dequantize_v2_cost);
  elementwise_ops_.emplace("QuantizedSigmoid",
                           EIGEN_COST(scalar_logistic_op<float>));
  elementwise_ops_.emplace("QuantizeV2", quantize_v2_cost);
  elementwise_ops_.emplace("Reciprocal", EIGEN_COST(scalar_inverse_op<float>));
  elementwise_ops_.emplace("Relu", EIGEN_COST(scalar_max_op<float>));
  elementwise_ops_.emplace("Relu6", EIGEN_COST(scalar_max_op<float>));
  elementwise_ops_.emplace("Rint", 1);
  elementwise_ops_.emplace("Round", EIGEN_COST(scalar_round_op<float>));
  elementwise_ops_.emplace("Rsqrt", EIGEN_COST(scalar_rsqrt_op<float>));
  elementwise_ops_.emplace("Sigmoid", EIGEN_COST(scalar_logistic_op<float>));
  elementwise_ops_.emplace("Sign", EIGEN_COST(scalar_sign_op<float>));
  elementwise_ops_.emplace("Sin", EIGEN_COST(scalar_sin_op<float>));
  elementwise_ops_.emplace("Sqrt", EIGEN_COST(scalar_sqrt_op<float>));
  elementwise_ops_.emplace("Square", EIGEN_COST(scalar_square_op<float>));
  elementwise_ops_.emplace("Sum", EIGEN_COST(scalar_sum_op<float>));
  elementwise_ops_.emplace("Tan", EIGEN_COST(scalar_tan_op<float>));
  elementwise_ops_.emplace("Tanh", EIGEN_COST(scalar_tanh_op<float>));
  elementwise_ops_.emplace("TopKV2", EIGEN_COST(scalar_max_op<float>));
  // Binary ops alphabetically sorted
  elementwise_ops_.emplace("Add", EIGEN_COST(scalar_sum_op<float>));
  elementwise_ops_.emplace("AddV2", EIGEN_COST(scalar_sum_op<float>));
  elementwise_ops_.emplace("ApproximateEqual", 1);
  elementwise_ops_.emplace("BiasAdd", EIGEN_COST(scalar_sum_op<float>));
  elementwise_ops_.emplace("QuantizedBiasAdd",
                           EIGEN_COST(scalar_sum_op<float>));
  elementwise_ops_.emplace("Div", EIGEN_COST(scalar_quotient_op<float>));
  elementwise_ops_.emplace("Equal", 1);
  elementwise_ops_.emplace("FloorDiv", EIGEN_COST(scalar_quotient_op<float>));
  elementwise_ops_.emplace("FloorMod", EIGEN_COST(scalar_mod_op<float>));
  elementwise_ops_.emplace("Greater", 1);
  elementwise_ops_.emplace("GreaterEqual", 1);
  elementwise_ops_.emplace("Less", 1);
  elementwise_ops_.emplace("LessEqual", 1);
  elementwise_ops_.emplace("LogicalAnd", EIGEN_COST(scalar_boolean_and_op));
  elementwise_ops_.emplace("LogicalNot", 1);
  elementwise_ops_.emplace("LogicalOr", EIGEN_COST(scalar_boolean_or_op));
  elementwise_ops_.emplace("Maximum", EIGEN_COST(scalar_max_op<float>));
  elementwise_ops_.emplace("Minimum", EIGEN_COST(scalar_min_op<float>));
  elementwise_ops_.emplace("Mod", EIGEN_COST(scalar_mod_op<float>));
  elementwise_ops_.emplace("Mul", EIGEN_COST(scalar_product_op<float>));
  elementwise_ops_.emplace("NotEqual", 1);
  elementwise_ops_.emplace("QuantizedAdd", EIGEN_COST(scalar_sum_op<float>));
  elementwise_ops_.emplace("QuantizedMul",
                           EIGEN_COST(scalar_product_op<float>));
  elementwise_ops_.emplace("RealDiv", EIGEN_COST(scalar_quotient_op<float>));
  elementwise_ops_.emplace("ReluGrad", EIGEN_COST(scalar_max_op<float>));
  elementwise_ops_.emplace("Select", EIGEN_COST(scalar_boolean_or_op));
  elementwise_ops_.emplace("SelectV2", EIGEN_COST(scalar_boolean_or_op));
  elementwise_ops_.emplace("SquaredDifference",
                           EIGEN_COST(scalar_square_op<float>) +
                               EIGEN_COST(scalar_difference_op<float>));
  elementwise_ops_.emplace("Sub", EIGEN_COST(scalar_difference_op<float>));
  elementwise_ops_.emplace("TruncateDiv",
                           EIGEN_COST(scalar_quotient_op<float>));
  elementwise_ops_.emplace("TruncateMod", EIGEN_COST(scalar_mod_op<float>));
  elementwise_ops_.emplace("Where", 1);

#undef EIGEN_COST

  // By default, use sum of memory_time and compute_time for execution_time.
  compute_memory_overlap_ = false;
}
647 
PredictCosts(const OpContext & op_context) const648 Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const {
649   Costs costs;
650   NodeCosts node_costs;
651   if (PredictNodeCosts(op_context, &node_costs).ok()) {
652     if (node_costs.has_costs) {
653       return node_costs.costs;
654     }
655     // Convert NodeCosts to Costs.
656     if (node_costs.minimum_cost_op) {
657       // Override to minimum cost; Note that some ops with minimum cost may have
658       // non-typical device (e.g., channel for _Send), which may fail with
659       // GetDeviceInfo(), called from PredictOpCountBasedCost(). Make sure we
660       // directly set minimum values to Costs here, not calling
661       // PredictOpCountBasedCost().
662       costs.compute_time = kMinComputeTime;
663       costs.execution_time = kMinComputeTime;
664       costs.memory_time = 0;
665       costs.intermediate_memory_time = 0;
666       costs.intermediate_memory_read_time = 0;
667       costs.intermediate_memory_write_time = 0;
668     } else {
669       // Convert NodeCosts to Costs.
670       costs = PredictOpCountBasedCost(
671           node_costs.num_compute_ops, node_costs.num_total_read_bytes(),
672           node_costs.num_total_write_bytes(), op_context.op_info);
673     }
674     VLOG(1) << "Operation " << op_context.op_info.op() << " takes "
675             << costs.execution_time.count() << " ns.";
676     // Copy additional stats from NodeCosts to Costs.
677     costs.max_memory = node_costs.max_memory;
678     costs.persistent_memory = node_costs.persistent_memory;
679     costs.temporary_memory = node_costs.temporary_memory;
680     costs.inaccurate = node_costs.inaccurate;
681     costs.num_ops_with_unknown_shapes =
682         node_costs.num_nodes_with_unknown_shapes;
683     costs.num_ops_total = node_costs.num_nodes;
684     return costs;
685   }
686   // Errors during node cost estimate.
687   LOG(WARNING) << "Error in PredictCost() for the op: "
688                << op_context.op_info.ShortDebugString();
689   costs = Costs::ZeroCosts(/*inaccurate=*/true);
690   costs.num_ops_with_unknown_shapes = node_costs.num_nodes_with_unknown_shapes;
691   return costs;
692 }
693 
PredictNodeCosts(const OpContext & op_context,NodeCosts * node_costs) const694 Status OpLevelCostEstimator::PredictNodeCosts(const OpContext& op_context,
695                                               NodeCosts* node_costs) const {
696   const auto& op_info = op_context.op_info;
697   auto it = device_cost_impl_.find(op_info.op());
698   if (it != device_cost_impl_.end()) {
699     std::function<Status(const OpContext&, NodeCosts*)> estimator = it->second;
700     return estimator(op_context, node_costs);
701   }
702 
703   if (persistent_ops_.find(op_info.op()) != persistent_ops_.end()) {
704     return PredictVariable(op_context, node_costs);
705   }
706 
707   if (elementwise_ops_.find(op_info.op()) != elementwise_ops_.end()) {
708     return PredictCwiseOp(op_context, node_costs);
709   }
710 
711   VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
712 
713   node_costs->num_nodes_with_unknown_op_type = 1;
714   return PredictCostOfAnUnknownOp(op_context, node_costs);
715 }
716 
717 // This method assumes a typical system composed of CPUs and GPUs, connected
718 // through PCIe. To define device info more precisely, override this method.
GetDeviceInfo(const DeviceProperties & device) const719 DeviceInfo OpLevelCostEstimator::GetDeviceInfo(
720     const DeviceProperties& device) const {
721   double gflops = -1;
722   double gb_per_sec = -1;
723 
724   if (device.type() == "CPU") {
725     // Check if vector instructions are available, and refine performance
726     // prediction based on this.
727     // Frequencies are stored in MHz in the DeviceProperties.
728     gflops = device.num_cores() * device.frequency() * 1e-3;
729     if (gb_per_sec < 0) {
730       if (device.bandwidth() > 0) {
731         gb_per_sec = device.bandwidth() / 1e6;
732       } else {
733         gb_per_sec = 32;
734       }
735     }
736   } else if (device.type() == "GPU") {
737     const std::string architecture = device.environment().at("architecture");
738     int cores_per_multiprocessor;
739     if (architecture < "3") {
740       // Fermi
741       cores_per_multiprocessor = 32;
742     } else if (architecture < "4") {
743       // Kepler
744       cores_per_multiprocessor = 192;
745     } else if (architecture < "6") {
746       // Maxwell
747       cores_per_multiprocessor = 128;
748     } else {
749       // Pascal (compute capability version 6) and Volta (compute capability
750       // version 7)
751       cores_per_multiprocessor = 64;
752     }
753     gflops = device.num_cores() * device.frequency() * 1e-3 *
754              cores_per_multiprocessor * kOpsPerMac;
755     if (device.bandwidth() > 0) {
756       gb_per_sec = device.bandwidth() / 1e6;
757     } else {
758       gb_per_sec = 100;
759     }
760   } else {
761     LOG_EVERY_N(WARNING, 1000) << "Unknown device type: " << device.type()
762                                << ", assuming PCIe between CPU and GPU.";
763     gflops = 1;  // Dummy value; data transfer ops would not have compute ops.
764     gb_per_sec = 12;  // default PCIe x16 gen3.
765   }
766   VLOG(1) << "Device: " << device.type() << " gflops: " << gflops
767           << " gb_per_sec: " << gb_per_sec;
768 
769   return DeviceInfo(gflops, gb_per_sec);
770 }
771 
PredictCwiseOp(const OpContext & op_context,NodeCosts * node_costs) const772 Status OpLevelCostEstimator::PredictCwiseOp(const OpContext& op_context,
773                                             NodeCosts* node_costs) const {
774   const auto& op_info = op_context.op_info;
775   bool found_unknown_shapes = false;
776   // For element-wise operations, op count is the element count of any input. We
777   // use the count for the largest input here to be more robust in case that the
778   // shape is unknown or partially known for other input.
779   int64 op_count = CalculateLargestInputCount(op_info, &found_unknown_shapes);
780   // If output shape is available, try to use the element count calculated from
781   // that.
782   if (op_info.outputs_size() > 0) {
783     op_count = std::max(
784         op_count,
785         CalculateTensorElementCount(op_info.outputs(0), &found_unknown_shapes));
786   }
787   // Calculate the output shape possibly resulting from broadcasting.
788   if (op_info.inputs_size() >= 2) {
789     op_count = std::max(op_count, CwiseOutputElementCount(op_info));
790   }
791 
792   int op_cost = 1;
793   auto it = elementwise_ops_.find(op_info.op());
794   if (it != elementwise_ops_.end()) {
795     op_cost = it->second;
796   } else {
797     return errors::InvalidArgument("Not a cwise op: ", op_info.op());
798   }
799 
800   return PredictDefaultNodeCosts(op_count * op_cost, op_context,
801                                  &found_unknown_shapes, node_costs);
802 }
803 
PredictCostOfAnUnknownOp(const OpContext & op_context,NodeCosts * node_costs) const804 Status OpLevelCostEstimator::PredictCostOfAnUnknownOp(
805     const OpContext& op_context, NodeCosts* node_costs) const {
806   // Don't assume the operation is cwise, return cost based on input/output size
807   // and admit that it is inaccurate...
808   bool found_unknown_shapes = false;
809   node_costs->inaccurate = true;
810   return PredictDefaultNodeCosts(0, op_context, &found_unknown_shapes,
811                                  node_costs);
812 }
813 
PredictOpCountBasedCost(double operations,const OpInfo & op_info) const814 Costs OpLevelCostEstimator::PredictOpCountBasedCost(
815     double operations, const OpInfo& op_info) const {
816   bool unknown_shapes = false;
817   const double input_size = CalculateInputSize(op_info, &unknown_shapes);
818   const double output_size = CalculateOutputSize(op_info, &unknown_shapes);
819   Costs costs =
820       PredictOpCountBasedCost(operations, input_size, output_size, op_info);
821   costs.inaccurate = unknown_shapes;
822   costs.num_ops_with_unknown_shapes = unknown_shapes;
823   costs.max_memory = output_size;
824   return costs;
825 }
826 
PredictOpCountBasedCost(double operations,double input_io_bytes,double output_io_bytes,const OpInfo & op_info) const827 Costs OpLevelCostEstimator::PredictOpCountBasedCost(
828     double operations, double input_io_bytes, double output_io_bytes,
829     const OpInfo& op_info) const {
830   double total_io_bytes = input_io_bytes + output_io_bytes;
831   const DeviceInfo device_info = GetDeviceInfo(op_info.device());
832   if (device_info.gigaops <= 0 || device_info.gb_per_sec <= 0 ||
833       device_info.intermediate_read_gb_per_sec <= 0 ||
834       device_info.intermediate_write_gb_per_sec <= 0) {
835     VLOG(1) << "BAD DEVICE. Op:" << op_info.op()
836             << " device type:" << op_info.device().type()
837             << " device model:" << op_info.device().model();
838   }
839 
840   Costs::NanoSeconds compute_cost(std::ceil(operations / device_info.gigaops));
841   VLOG(1) << "Op:" << op_info.op() << " GOps:" << operations / 1e9
842           << " Compute Time (ns):" << compute_cost.count();
843 
844   Costs::NanoSeconds memory_cost(
845       std::ceil(total_io_bytes / device_info.gb_per_sec));
846   VLOG(1) << "Op:" << op_info.op() << " Size (KB):" << (total_io_bytes) / 1e3
847           << " Memory Time (ns):" << memory_cost.count();
848 
849   // Check if bytes > 0.  If it's not and the bandwidth is set to infinity
850   // then the result would be undefined.
851   double intermediate_read_time =
852       (input_io_bytes > 0)
853           ? std::ceil(input_io_bytes / device_info.intermediate_read_gb_per_sec)
854           : 0;
855 
856   double intermediate_write_time =
857       (output_io_bytes > 0)
858           ? std::ceil(output_io_bytes /
859                       device_info.intermediate_write_gb_per_sec)
860           : 0;
861 
862   Costs::NanoSeconds intermediate_memory_cost =
863       compute_memory_overlap_
864           ? std::max(intermediate_read_time, intermediate_write_time)
865           : (intermediate_read_time + intermediate_write_time);
866   VLOG(1) << "Op:" << op_info.op() << " Size (KB):" << (total_io_bytes) / 1e3
867           << " Intermediate Memory Time (ns):"
868           << intermediate_memory_cost.count();
869 
870   Costs costs = Costs::ZeroCosts();
871   costs.compute_time = compute_cost;
872   costs.memory_time = memory_cost;
873   costs.intermediate_memory_time = intermediate_memory_cost;
874   costs.intermediate_memory_read_time =
875       Costs::NanoSeconds(intermediate_read_time);
876   costs.intermediate_memory_write_time =
877       Costs::NanoSeconds(intermediate_write_time);
878   CombineCostsAndUpdateExecutionTime(compute_memory_overlap_, &costs);
879   return costs;
880 }
881 
CountConv2DOperations(const OpInfo & op_info,bool * found_unknown_shapes)882 int64 OpLevelCostEstimator::CountConv2DOperations(const OpInfo& op_info,
883                                                   bool* found_unknown_shapes) {
884   return CountConv2DOperations(op_info, nullptr, found_unknown_shapes);
885 }
886 
887 // Helper to translate the positional arguments into named fields.
888 /* static */
889 OpLevelCostEstimator::ConvolutionDimensions
ConvolutionDimensionsFromInputs(const TensorShapeProto & original_image_shape,const TensorShapeProto & original_filter_shape,const OpInfo & op_info,bool * found_unknown_shapes)890 OpLevelCostEstimator::ConvolutionDimensionsFromInputs(
891     const TensorShapeProto& original_image_shape,
892     const TensorShapeProto& original_filter_shape, const OpInfo& op_info,
893     bool* found_unknown_shapes) {
894   VLOG(2) << "op features: " << op_info.DebugString();
895   VLOG(2) << "Original image shape: " << original_image_shape.DebugString();
896   VLOG(2) << "Original filter shape: " << original_filter_shape.DebugString();
897 
898   int x_index, y_index, major_channel_index, minor_channel_index = -1;
899   const std::string& data_format = GetDataFormat(op_info);
900   if (data_format == "NCHW") {
901     major_channel_index = 1;
902     y_index = 2;
903     x_index = 3;
904   } else if (data_format == "NCHW_VECT_C") {
905     // Use NCHW_VECT_C
906     minor_channel_index = 1;
907     y_index = 2;
908     x_index = 3;
909     major_channel_index = 4;
910   } else {
911     // Use NHWC.
912     y_index = 1;
913     x_index = 2;
914     major_channel_index = 3;
915   }
916   const std::string& filter_format = GetFilterFormat(op_info);
917   int filter_x_index, filter_y_index, in_major_channel_index, out_channel_index,
918       in_minor_channel_index = -1;
919   if (filter_format == "HWIO") {
920     filter_y_index = 0;
921     filter_x_index = 1;
922     in_major_channel_index = 2;
923     out_channel_index = 3;
924   } else if (filter_format == "OIHW_VECT_I") {
925     out_channel_index = 0;
926     in_minor_channel_index = 1;
927     filter_y_index = 2;
928     filter_x_index = 3;
929     in_major_channel_index = 4;
930   } else {
931     // Use OIHW
932     out_channel_index = 0;
933     in_major_channel_index = 1;
934     filter_y_index = 2;
935     filter_x_index = 3;
936   }
937 
938   auto image_shape = MaybeGetMinimumShape(original_image_shape,
939                                           minor_channel_index >= 0 ? 5 : 4,
940                                           found_unknown_shapes);
941   auto filter_shape = MaybeGetMinimumShape(original_filter_shape,
942                                            in_minor_channel_index >= 0 ? 5 : 4,
943                                            found_unknown_shapes);
944   VLOG(2) << "Image shape: " << image_shape.DebugString();
945   VLOG(2) << "Filter shape: " << filter_shape.DebugString();
946 
947   int64 batch = image_shape.dim(0).size();
948   int64 ix = image_shape.dim(x_index).size();
949   int64 iy = image_shape.dim(y_index).size();
950   int64 iz = minor_channel_index >= 0
951                  ? image_shape.dim(minor_channel_index).size() *
952                        image_shape.dim(major_channel_index).size()
953                  : image_shape.dim(major_channel_index).size();
954   int64 kx = filter_shape.dim(filter_x_index).size();
955   int64 ky = filter_shape.dim(filter_y_index).size();
956   int64 kz = in_minor_channel_index >= 0
957                  ? filter_shape.dim(in_major_channel_index).size() *
958                        filter_shape.dim(in_minor_channel_index).size()
959                  : filter_shape.dim(in_major_channel_index).size();
960   std::vector<int64> strides = GetStrides(op_info);
961   const auto padding = GetPadding(op_info);
962   int64 sx = strides[x_index];
963   int64 sy = strides[y_index];
964   int64 ox = GetOutputSize(ix, kx, sx, padding);
965   int64 oy = GetOutputSize(iy, ky, sy, padding);
966   int64 oz = filter_shape.dim(out_channel_index).size();
967   // Only check equality when both sizes are known (in other words, when
968   // neither is set to a minimum dimension size of 1).
969   if (iz != 1 && kz != 1) {
970     DCHECK_EQ(iz % kz, 0) << "Input channel " << iz
971                           << " is not a multiple of filter channel " << kz
972                           << ".";
973     if (iz % kz) {
974       *found_unknown_shapes = true;
975     }
976   } else {
977     iz = kz = std::max<int64>(iz, kz);
978   }
979   OpLevelCostEstimator::ConvolutionDimensions conv_dims = {
980       batch, ix, iy, iz, kx, ky, kz, oz, ox, oy, sx, sy, padding};
981 
982   VLOG(1) << "Batch Size:" << batch;
983   VLOG(1) << "Image Dims:" << ix << "," << iy;
984   VLOG(1) << "Input Depth:" << iz;
985   VLOG(1) << "Kernel Dims:" << kx << "," << ky;
986   VLOG(1) << "Kernel Depth:" << kz;
987   VLOG(1) << "Output Dims:" << ox << "," << oy;
988   VLOG(1) << "Output Depth:" << oz;
989   VLOG(1) << "Strides:" << sx << "," << sy;
990   VLOG(1) << "Padding:" << (padding == Padding::VALID ? "VALID" : "SAME");
991   return conv_dims;
992 }
993 
CountConv2DOperations(const OpInfo & op_info,ConvolutionDimensions * conv_info,bool * found_unknown_shapes)994 int64 OpLevelCostEstimator::CountConv2DOperations(
995     const OpInfo& op_info, ConvolutionDimensions* conv_info,
996     bool* found_unknown_shapes) {
997   DCHECK(op_info.op() == kConv2d || op_info.op() == kDepthwiseConv2dNative)
998       << "Invalid Operation: not Conv2D nor DepthwiseConv2dNative";
999 
1000   if (op_info.inputs_size() < 2) {  // Unexpect inputs.
1001     *found_unknown_shapes = true;
1002     return 0;
1003   }
1004 
1005   ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
1006       op_info.inputs(0).shape(), op_info.inputs(1).shape(), op_info,
1007       found_unknown_shapes);
1008 
1009   //  in DepthwiseConv2dNative conv_dims.oz is actually the channel depth
1010   //  multiplier; The effective output channel depth oz_effective is
1011   //  conv_dims.iz * conv_dims.oz. thus # ops = N x H x W x oz_effective x 2RS.
1012   //  Compare to Conv2D where # ops =  N x H x W x kz x oz x 2RS,
1013   //  oz = oz_effective,  then Conv2D_ops / Depthwise_conv2d_native_ops = kz.
1014   int64 ops = conv_dims.batch;
1015   ops *= conv_dims.ox * conv_dims.oy;
1016   ops *= conv_dims.kx * conv_dims.ky;
1017   if (op_info.op() == kConv2d) {
1018     ops *= conv_dims.kz * conv_dims.oz;
1019   } else {
1020     // To ensure output tensor dims to be correct for DepthwiseConv2DNative,
1021     // although ops are the same as Conv2D.
1022     conv_dims.oz *= conv_dims.iz;
1023     ops *= conv_dims.oz;
1024   }
1025   ops *= kOpsPerMac;
1026 
1027   if (conv_info != nullptr) {
1028     *conv_info = conv_dims;
1029   }
1030   return ops;
1031 }
1032 
CountMatMulOperations(const OpInfo & op_info,bool * found_unknown_shapes)1033 int64 OpLevelCostEstimator::CountMatMulOperations(const OpInfo& op_info,
1034                                                   bool* found_unknown_shapes) {
1035   return CountMatMulOperations(op_info, nullptr, found_unknown_shapes);
1036 }
1037 
1038 // TODO(nishantpatil): Create separate estimator for Sparse Matmul
CountMatMulOperations(const OpInfo & op_info,MatMulDimensions * mat_mul,bool * found_unknown_shapes)1039 int64 OpLevelCostEstimator::CountMatMulOperations(const OpInfo& op_info,
1040                                                   MatMulDimensions* mat_mul,
1041                                                   bool* found_unknown_shapes) {
1042   double ops = 0;
1043 
1044   if (op_info.inputs_size() < 2) {
1045     LOG(ERROR) << "Need 2 inputs but got " << op_info.inputs_size();
1046     // TODO(pcma): Try to separate invalid inputs from unknown shapes
1047     *found_unknown_shapes = true;
1048     return 0;
1049   }
1050 
1051   auto& a_matrix = op_info.inputs(0);
1052   auto& b_matrix = op_info.inputs(1);
1053 
1054   bool transpose_a = false;
1055   bool transpose_b = false;
1056 
1057   double m_dim, n_dim, k_dim, k_dim_b = 0;
1058 
1059   for (const auto& item : op_info.attr()) {
1060     VLOG(1) << "Key:" << item.first
1061             << " Value:" << SummarizeAttrValue(item.second);
1062     if (item.first == "transpose_a" && item.second.b() == true)
1063       transpose_a = true;
1064     if (item.first == "transpose_b" && item.second.b() == true)
1065       transpose_b = true;
1066   }
1067   VLOG(1) << "transpose_a:" << transpose_a;
1068   VLOG(1) << "transpose_b:" << transpose_b;
1069   auto a_matrix_shape =
1070       MaybeGetMinimumShape(a_matrix.shape(), 2, found_unknown_shapes);
1071   auto b_matrix_shape =
1072       MaybeGetMinimumShape(b_matrix.shape(), 2, found_unknown_shapes);
1073   if (transpose_a) {
1074     m_dim = a_matrix_shape.dim(1).size();
1075     k_dim = a_matrix_shape.dim(0).size();
1076   } else {
1077     m_dim = a_matrix_shape.dim(0).size();
1078     k_dim = a_matrix_shape.dim(1).size();
1079   }
1080   if (transpose_b) {
1081     k_dim_b = b_matrix_shape.dim(1).size();
1082     n_dim = b_matrix_shape.dim(0).size();
1083   } else {
1084     k_dim_b = b_matrix_shape.dim(0).size();
1085     n_dim = b_matrix_shape.dim(1).size();
1086   }
1087 
1088   VLOG(1) << "M, N, K: " << m_dim << "," << n_dim << "," << k_dim;
1089   // Only check equality when both sizes are known (in other words, when
1090   // neither is set to a minimum dimension size of 1).
1091   if (k_dim_b != 1 && k_dim != 1 && k_dim_b != k_dim) {
1092     LOG(ERROR) << "Incompatible Matrix dimensions";
1093     return ops;
1094   } else {
1095     // One of k_dim and k_dim_b might be 1 (minimum dimension size).
1096     k_dim = std::max(k_dim, k_dim_b);
1097   }
1098 
1099   ops = m_dim * n_dim * k_dim * 2;
1100   VLOG(1) << "Operations for Matmul: " << ops;
1101 
1102   if (mat_mul != nullptr) {
1103     mat_mul->m = m_dim;
1104     mat_mul->n = n_dim;
1105     mat_mul->k = k_dim;
1106   }
1107   return ops;
1108 }
1109 
GenerateBatchMatmulContextFromEinsum(const OpContext & einsum_context,OpContext * batch_matmul_context,bool * found_unknown_shapes) const1110 bool OpLevelCostEstimator::GenerateBatchMatmulContextFromEinsum(
1111     const OpContext& einsum_context, OpContext* batch_matmul_context,
1112     bool* found_unknown_shapes) const {
1113   // This auxiliary function transforms an einsum OpContext into its equivalent
1114   // Batch Matmul OpContext. The function returns a boolean, which determines
1115   // whether it was successful in generating the output OpContext or not.
1116 
1117   // Einsum computes a generalized contraction between tensors of arbitrary
1118   // dimension as defined by the equation written in the Einstein summation
1119   // convention. The number of tensors in the computation and the number of
1120   // contractions can be arbitrarily long. The current model only contemplates
1121   // Einsum equations, which can be translated into a single BatchMatMul
1122   // operation. Einsum operations with more than two operands are not currently
1123   // supported. Subscripts where an axis appears more than once for a single
1124   // input and ellipsis are currently also excluded. See:
1125   // https://www.tensorflow.org/api_docs/python/tf/einsum
1126   // We distinguish four kinds of dimensions, depending on their placement in
1127   // the equation:
1128   // + B: Batch dimensions: Dimensions which appear in both operands and RHS.
1129   // + K: Contracting dimensions: These appear in both inputs but not RHS.
1130   // + M: Operand A dimensions: These appear in the first operand and the RHS.
1131   // + N: Operand B dimensions: These appear in the second operand and the RHS.
1132   // Then, the operation to estimate is BatchMatMul([B,M,K],[B,K,N])
1133 
1134   if (batch_matmul_context == nullptr) {
1135     VLOG(1) << "Output context should not be a nullptr.";
1136     return false;
1137   }
1138   if (!IsEinsumCorrectlyFormed(einsum_context)) return false;
1139   const auto& op_info = einsum_context.op_info;
1140   std::vector<std::string> equation_split =
1141       absl::StrSplit(op_info.attr().find("equation")->second.s(), "->");
1142   std::vector<absl::string_view> input_split =
1143       absl::StrSplit(equation_split[0], ',');
1144   const auto& a_input = op_info.inputs(0);
1145   const auto& b_input = op_info.inputs(1);
1146   absl::string_view rhs_str = equation_split[1];
1147   absl::string_view a_input_str = input_split[0];
1148   absl::string_view b_input_str = input_split[1];
1149 
1150   constexpr int kMatrixRank = 2;
1151 
1152   bool a_input_shape_unknown = false;
1153   bool b_input_shape_unknown = false;
1154 
1155   TensorShapeProto a_input_shape = MaybeGetMinimumShape(
1156       a_input.shape(), std::max(kMatrixRank, a_input.shape().dim_size()),
1157       &a_input_shape_unknown);
1158   TensorShapeProto b_input_shape = MaybeGetMinimumShape(
1159       b_input.shape(), std::max(kMatrixRank, b_input.shape().dim_size()),
1160       &b_input_shape_unknown);
1161 
1162   *found_unknown_shapes = a_input_shape_unknown || b_input_shape_unknown ||
1163                           (a_input.shape().dim_size() < kMatrixRank) ||
1164                           (b_input.shape().dim_size() < kMatrixRank);
1165 
1166   OpInfo batch_matmul_op_info = op_info;
1167   batch_matmul_op_info.mutable_inputs()->Clear();
1168   batch_matmul_op_info.set_op("BatchMatMul");
1169 
1170   AttrValue transpose_attribute;
1171   transpose_attribute.set_b(false);
1172   (*batch_matmul_op_info.mutable_attr())["transpose_a"] = transpose_attribute;
1173   (*batch_matmul_op_info.mutable_attr())["transpose_b"] = transpose_attribute;
1174 
1175   OpInfo::TensorProperties* a_matrix = batch_matmul_op_info.add_inputs();
1176   TensorShapeProto* a_matrix_shape = a_matrix->mutable_shape();
1177   a_matrix->set_dtype(a_input.dtype());
1178 
1179   OpInfo::TensorProperties* b_matrix = batch_matmul_op_info.add_inputs();
1180   b_matrix->set_dtype(b_input.dtype());
1181   TensorShapeProto* b_matrix_shape = b_matrix->mutable_shape();
1182 
1183   TensorShapeProto_Dim m_dim;
1184   TensorShapeProto_Dim n_dim;
1185   TensorShapeProto_Dim k_dim;
1186 
1187   m_dim.set_size(1);
1188   n_dim.set_size(1);
1189   k_dim.set_size(1);
1190 
1191   for (int i_idx = 0, a_input_str_size = a_input_str.size();
1192        i_idx < a_input_str_size; ++i_idx) {
1193     if (b_input_str.find(a_input_str[i_idx]) == std::string::npos) {
1194       if (rhs_str.find(a_input_str[i_idx]) == std::string::npos) {
1195         VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
1196         return false;
1197       }
1198 
1199       m_dim.set_size(m_dim.size() * a_input_shape.dim(i_idx).size());
1200       continue;
1201     } else if (rhs_str.find(a_input_str[i_idx]) == std::string::npos) {
1202       // The dimension does not appear in the RHS, therefore it is a contracting
1203       // dimension.
1204       k_dim.set_size(k_dim.size() * a_input_shape.dim(i_idx).size());
1205       continue;
1206     }
1207     // It appears in both input operands, therefore we place it as an outer
1208     // dimension for the Batch Matmul.
1209     *(a_matrix_shape->add_dim()) = a_input_shape.dim(i_idx);
1210     *(b_matrix_shape->add_dim()) = a_input_shape.dim(i_idx);
1211   }
1212   for (int i_idx = 0, b_input_str_size = b_input_str.size();
1213        i_idx < b_input_str_size; ++i_idx) {
1214     if (a_input_str.find(b_input_str[i_idx]) == std::string::npos) {
1215       if (rhs_str.find(b_input_str[i_idx]) == std::string::npos) {
1216         VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
1217         return false;
1218       }
1219       n_dim.set_size(n_dim.size() * b_input_shape.dim(i_idx).size());
1220     }
1221   }
1222 
1223   // The two inner-most dimensions of the Batch Matmul are added.
1224   *(a_matrix_shape->add_dim()) = m_dim;
1225   *(a_matrix_shape->add_dim()) = k_dim;
1226   *(b_matrix_shape->add_dim()) = k_dim;
1227   *(b_matrix_shape->add_dim()) = n_dim;
1228 
1229   *batch_matmul_context = einsum_context;
1230   batch_matmul_context->op_info = batch_matmul_op_info;
1231   return true;
1232 }
1233 
CountBatchMatMulOperations(const OpInfo & op_info,bool * found_unknown_shapes)1234 int64 OpLevelCostEstimator::CountBatchMatMulOperations(
1235     const OpInfo& op_info, bool* found_unknown_shapes) {
1236   return CountBatchMatMulOperations(op_info, nullptr, found_unknown_shapes);
1237 }
1238 
CountBatchMatMulOperations(const OpInfo & op_info,BatchMatMulDimensions * batch_mat_mul,bool * found_unknown_shapes)1239 int64 OpLevelCostEstimator::CountBatchMatMulOperations(
1240     const OpInfo& op_info, BatchMatMulDimensions* batch_mat_mul,
1241     bool* found_unknown_shapes) {
1242   if (op_info.op() != kBatchMatMul && op_info.op() != kBatchMatMulV2) {
1243     LOG(ERROR) << "Invalid Operation: " << op_info.op();
1244     // TODO(pcma): Try to separate invalid inputs from unknown shapes
1245     *found_unknown_shapes = true;
1246     return 0;
1247   }
1248   if (op_info.inputs_size() != 2) {
1249     LOG(ERROR) << "Expected 2 inputs but got " << op_info.inputs_size();
1250     // TODO(pcma): Try to separate invalid inputs from unknown shapes
1251     *found_unknown_shapes = true;
1252     return 0;
1253   }
1254 
1255   double ops = 0;
1256   const auto& a_input = op_info.inputs(0);
1257   const auto& b_input = op_info.inputs(1);
1258 
1259   // BatchMatMul requires inputs of at least matrix shape (rank 2).
1260   // The two most minor dimensions of each input are matrices that
1261   // need to be multiplied together. The other dimensions determine
1262   // the number of such MatMuls.  For example, if the BatchMatMul has
1263   // inputs of shape:
1264   //   a_input_shape = [2, 3, 4, 5]
1265   //   b_input_shape = [2, 3, 5, 6]
1266   // then there are 2*3 = 6 MatMuls of dimensions m = 4, k = 5, n = 6
1267   // in this BatchMatMul.
1268   const int matrix_rank = 2;
1269 
1270   bool a_input_shape_unknown = false;
1271   bool b_input_shape_unknown = false;
1272 
1273   TensorShapeProto a_input_shape = MaybeGetMinimumShape(
1274       a_input.shape(), std::max(matrix_rank, a_input.shape().dim_size()),
1275       &a_input_shape_unknown);
1276   TensorShapeProto b_input_shape = MaybeGetMinimumShape(
1277       b_input.shape(), std::max(matrix_rank, b_input.shape().dim_size()),
1278       &b_input_shape_unknown);
1279 
1280   *found_unknown_shapes = a_input_shape_unknown || b_input_shape_unknown ||
1281                           (a_input.shape().dim_size() < matrix_rank) ||
1282                           (b_input.shape().dim_size() < matrix_rank);
1283 
1284   // Compute the number of matmuls as the max indicated at each dimension
1285   // by either input. Note that the shapes do not have to have
1286   // the same rank due to incompleteness.
1287   TensorShapeProto* bigger_rank_shape = &a_input_shape;
1288   TensorShapeProto* smaller_rank_shape = &b_input_shape;
1289   if (b_input_shape.dim_size() > a_input_shape.dim_size()) {
1290     bigger_rank_shape = &b_input_shape;
1291     smaller_rank_shape = &a_input_shape;
1292   }
1293   int num_matmuls = 1;
1294   for (int b_i = 0,
1295            s_i = smaller_rank_shape->dim_size() - bigger_rank_shape->dim_size();
1296        b_i < bigger_rank_shape->dim_size() - matrix_rank; ++b_i, ++s_i) {
1297     int b_dim = bigger_rank_shape->dim(b_i).size();
1298     int s_dim = 1;
1299     if (s_i >= 0) {
1300       s_dim = smaller_rank_shape->dim(s_i).size();
1301     }
1302     if (batch_mat_mul != nullptr) {
1303       batch_mat_mul->batch_dims.push_back(s_dim);
1304     }
1305     num_matmuls *= std::max(b_dim, s_dim);
1306   }
1307 
1308   // Build the MatMul. Note that values are ignored here since we are just
1309   // counting ops (e.g. only shapes matter).
1310   OpInfo matmul_op_info;
1311   matmul_op_info.set_op("MatMul");
1312 
1313   AttrValue transpose_a;
1314   transpose_a.set_b(false);
1315   if (op_info.attr().find("adj_x") != op_info.attr().end()) {
1316     transpose_a.set_b(op_info.attr().at("adj_x").b());
1317   }
1318   (*matmul_op_info.mutable_attr())["transpose_a"] = transpose_a;
1319 
1320   AttrValue transpose_b;
1321   transpose_b.set_b(false);
1322   if (op_info.attr().find("adj_y") != op_info.attr().end()) {
1323     transpose_b.set_b(op_info.attr().at("adj_y").b());
1324   }
1325   (*matmul_op_info.mutable_attr())["transpose_b"] = transpose_b;
1326 
1327   OpInfo::TensorProperties* a_matrix = matmul_op_info.add_inputs();
1328   a_matrix->set_dtype(a_input.dtype());
1329   TensorShapeProto* a_matrix_shape = a_matrix->mutable_shape();
1330   for (int i = std::max(0, a_input_shape.dim_size() - matrix_rank);
1331        i < a_input_shape.dim_size(); ++i) {
1332     *(a_matrix_shape->add_dim()) = a_input_shape.dim(i);
1333   }
1334 
1335   OpInfo::TensorProperties* b_matrix = matmul_op_info.add_inputs();
1336   b_matrix->set_dtype(b_input.dtype());
1337   TensorShapeProto* b_matrix_shape = b_matrix->mutable_shape();
1338   for (int i = std::max(0, b_input_shape.dim_size() - matrix_rank);
1339        i < b_input_shape.dim_size(); ++i) {
1340     *(b_matrix_shape->add_dim()) = b_input_shape.dim(i);
1341   }
1342   if (batch_mat_mul != nullptr) {
1343     batch_mat_mul->matmul_dims.m = (transpose_a.b())
1344                                        ? a_matrix_shape->dim(1).size()
1345                                        : a_matrix_shape->dim(0).size();
1346     batch_mat_mul->matmul_dims.k = (transpose_a.b())
1347                                        ? a_matrix_shape->dim(0).size()
1348                                        : a_matrix_shape->dim(1).size();
1349     batch_mat_mul->matmul_dims.n = (transpose_b.b())
1350                                        ? b_matrix_shape->dim(0).size()
1351                                        : b_matrix_shape->dim(1).size();
1352   }
1353 
1354   for (int i = 0; i < num_matmuls; ++i) {
1355     bool matmul_unknown_shapes = false;
1356     ops += CountMatMulOperations(matmul_op_info, &matmul_unknown_shapes);
1357     *found_unknown_shapes |= matmul_unknown_shapes;
1358   }
1359   return ops;
1360 }
1361 
GetTensorShapeProtoFromTensorProto(const TensorProto & tensor_proto,TensorShapeProto * tensor_shape_proto)1362 bool GetTensorShapeProtoFromTensorProto(const TensorProto& tensor_proto,
1363                                         TensorShapeProto* tensor_shape_proto) {
1364   tensor_shape_proto->Clear();
1365   // First convert TensorProto into Tensor class so that it correctly parses
1366   // data values within TensorProto (whether it's in int_val, int64_val,
1367   // tensor_content, or anything.
1368   Tensor tensor(tensor_proto.dtype());
1369   if (!tensor.FromProto(tensor_proto)) {
1370     LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
1371                  << "failed to parse TensorProto: "
1372                  << tensor_proto.DebugString();
1373     return false;
1374   }
1375   if (tensor.dims() != 1) {
1376     LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
1377                  << "tensor is not 1D: " << tensor.dims();
1378     return false;
1379   }
1380   // Then, convert it back to TensorProto using AsProtoField, which makes sure
1381   // the data is in int_val, int64_val, or such repeated data fields, not in
1382   // tensor_content.
1383   TensorProto temp_tensor;
1384   tensor.AsProtoField(&temp_tensor);
1385 
1386 #define TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(type)        \
1387   do {                                                   \
1388     for (const auto& value : temp_tensor.type##_val()) { \
1389       tensor_shape_proto->add_dim()->set_size(value);    \
1390     }                                                    \
1391   } while (0)
1392 
1393   if (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT16 ||
1394       tensor.dtype() == DT_INT8 || tensor.dtype() == DT_UINT8) {
1395     TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(int);
1396   } else if (tensor.dtype() == DT_INT64) {
1397     TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(int64);
1398   } else if (tensor.dtype() == DT_UINT32) {
1399     TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(uint32);
1400   } else if (tensor.dtype() == DT_UINT64) {
1401     TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(uint64);
1402   } else {
1403     LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
1404                  << "Unsupported dtype: " << tensor.dtype();
1405     return false;
1406   }
1407 #undef TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO
1408 
1409   return true;
1410 }
1411 
1412 // TODO(cliffy): Dedup this method and CountConv2DBackpropFilterOperations.
CountConv2DBackpropInputOperations(const OpInfo & op_info,ConvolutionDimensions * returned_conv_dims,bool * found_unknown_shapes)1413 int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations(
1414     const OpInfo& op_info, ConvolutionDimensions* returned_conv_dims,
1415     bool* found_unknown_shapes) {
1416   int64 ops = 0;
1417 
1418   DCHECK(op_info.op() == kConv2dBackpropInput ||
1419          op_info.op() == kDepthwiseConv2dNativeBackpropInput)
1420       << "Invalid Operation: not kConv2dBackpropInput nor"
1421          "kDepthwiseConv2dNativeBackpropInput";
1422 
1423   if (op_info.inputs_size() < 2) {
1424     // TODO(pcma): Try to separate invalid inputs from unknown shapes
1425     *found_unknown_shapes = true;
1426     return ops;
1427   }
1428 
1429   TensorShapeProto input_shape;
1430   bool shape_found = false;
1431   if (op_info.inputs(0).has_value()) {
1432     const TensorProto& value = op_info.inputs(0).value();
1433     shape_found = GetTensorShapeProtoFromTensorProto(value, &input_shape);
1434   }
1435   if (!shape_found && op_info.outputs_size() == 1) {
1436     input_shape = op_info.outputs(0).shape();
1437     shape_found = true;
1438   }
1439   if (!shape_found) {
1440     // Set the minimum filter size that's feasible.
1441     input_shape.Clear();
1442     for (int i = 0; i < 4; ++i) {
1443       input_shape.add_dim()->set_size(1);
1444     }
1445     *found_unknown_shapes = true;
1446   }
1447 
1448   ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
1449       input_shape, op_info.inputs(1).shape(), op_info, found_unknown_shapes);
1450 
1451   ops = conv_dims.batch;
1452   ops *= conv_dims.ox * conv_dims.oy;
1453   ops *= conv_dims.kx * conv_dims.ky;
1454   if (op_info.op() == kConv2dBackpropInput) {
1455     ops *= conv_dims.kz * conv_dims.oz;
1456   } else {
1457     // conv_dims always use forward path definition regardless
1458     conv_dims.oz *= conv_dims.iz;
1459     ops *= conv_dims.oz;
1460   }
1461   ops *= kOpsPerMac;
1462 
1463   VLOG(1) << "Operations for" << op_info.op() << "  " << ops;
1464 
1465   if (returned_conv_dims != nullptr) {
1466     *returned_conv_dims = conv_dims;
1467   }
1468   return ops;
1469 }
1470 
CountConv2DBackpropFilterOperations(const OpInfo & op_info,ConvolutionDimensions * returned_conv_dims,bool * found_unknown_shapes)1471 int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations(
1472     const OpInfo& op_info, ConvolutionDimensions* returned_conv_dims,
1473     bool* found_unknown_shapes) {
1474   int64 ops = 0;
1475 
1476   DCHECK(op_info.op() == kConv2dBackpropFilter ||
1477          op_info.op() == kDepthwiseConv2dNativeBackpropFilter)
1478       << "Invalid Operation: not kConv2dBackpropFilter nor"
1479          "kDepthwiseConv2dNativeBackpropFilter";
1480 
1481   TensorShapeProto filter_shape;
1482   bool shape_found = false;
1483   if (op_info.inputs_size() >= 2 && op_info.inputs(1).has_value()) {
1484     const TensorProto& value = op_info.inputs(1).value();
1485     shape_found = GetTensorShapeProtoFromTensorProto(value, &filter_shape);
1486   }
1487   if (!shape_found && op_info.outputs_size() == 1) {
1488     filter_shape = op_info.outputs(0).shape();
1489     shape_found = true;
1490   }
1491   if (!shape_found) {
1492     // Set the minimum filter size that's feasible.
1493     filter_shape.Clear();
1494     for (int i = 0; i < 4; ++i) {
1495       filter_shape.add_dim()->set_size(1);
1496     }
1497     *found_unknown_shapes = true;
1498   }
1499 
1500   if (op_info.inputs_size() < 1) {
1501     // TODO(pcma): Try to separate invalid inputs from unknown shapes
1502     *found_unknown_shapes = true;
1503     return ops;
1504   }
1505   ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
1506       op_info.inputs(0).shape(), filter_shape, op_info, found_unknown_shapes);
1507 
1508   ops = conv_dims.batch;
1509   ops *= conv_dims.ox * conv_dims.oy;
1510   ops *= conv_dims.kx * conv_dims.ky;
1511   if (op_info.op() == kConv2dBackpropFilter) {
1512     ops *= conv_dims.kz * conv_dims.oz;
1513   } else {
1514     // conv_dims always use forward path definition regardless
1515     conv_dims.oz *= conv_dims.iz;
1516     ops *= conv_dims.oz;
1517   }
1518   ops *= kOpsPerMac;
1519   VLOG(1) << "Operations for" << op_info.op() << "  " << ops;
1520 
1521   if (returned_conv_dims != nullptr) {
1522     *returned_conv_dims = conv_dims;
1523   }
1524   return ops;
1525 }
1526 
CalculateTensorElementCount(const OpInfo::TensorProperties & tensor,bool * found_unknown_shapes)1527 int64 OpLevelCostEstimator::CalculateTensorElementCount(
1528     const OpInfo::TensorProperties& tensor, bool* found_unknown_shapes) {
1529   VLOG(2) << "   with " << DataTypeString(tensor.dtype()) << " tensor of shape "
1530           << tensor.shape().DebugString();
1531   int64 tensor_size = 1;
1532   int num_dims = std::max(1, tensor.shape().dim_size());
1533   auto tensor_shape =
1534       MaybeGetMinimumShape(tensor.shape(), num_dims, found_unknown_shapes);
1535   for (const auto& dim : tensor_shape.dim()) {
1536     tensor_size *= dim.size();
1537   }
1538   return tensor_size;
1539 }
1540 
CalculateTensorSize(const OpInfo::TensorProperties & tensor,bool * found_unknown_shapes)1541 int64 OpLevelCostEstimator::CalculateTensorSize(
1542     const OpInfo::TensorProperties& tensor, bool* found_unknown_shapes) {
1543   int64 count = CalculateTensorElementCount(tensor, found_unknown_shapes);
1544   int size = DataTypeSize(BaseType(tensor.dtype()));
1545   VLOG(2) << "Count: " << count << " DataTypeSize: " << size;
1546   return count * size;
1547 }
1548 
CalculateInputSize(const OpInfo & op_info,bool * found_unknown_shapes)1549 int64 OpLevelCostEstimator::CalculateInputSize(const OpInfo& op_info,
1550                                                bool* found_unknown_shapes) {
1551   int64 total_input_size = 0;
1552   for (auto& input : op_info.inputs()) {
1553     int64 input_size = CalculateTensorSize(input, found_unknown_shapes);
1554     total_input_size += input_size;
1555     VLOG(1) << "Input Size: " << input_size
1556             << " Total Input Size:" << total_input_size;
1557   }
1558   return total_input_size;
1559 }
1560 
CalculateInputTensorSize(const OpInfo & op_info,bool * found_unknown_shapes)1561 std::vector<int64> OpLevelCostEstimator::CalculateInputTensorSize(
1562     const OpInfo& op_info, bool* found_unknown_shapes) {
1563   std::vector<int64> input_tensor_size;
1564   input_tensor_size.reserve(op_info.inputs().size());
1565   for (auto& input : op_info.inputs()) {
1566     input_tensor_size.push_back(
1567         CalculateTensorSize(input, found_unknown_shapes));
1568   }
1569   return input_tensor_size;
1570 }
1571 
CalculateLargestInputCount(const OpInfo & op_info,bool * found_unknown_shapes)1572 int64 OpLevelCostEstimator::CalculateLargestInputCount(
1573     const OpInfo& op_info, bool* found_unknown_shapes) {
1574   int64 largest_input_count = 0;
1575   for (auto& input : op_info.inputs()) {
1576     int64 input_count =
1577         CalculateTensorElementCount(input, found_unknown_shapes);
1578     if (input_count > largest_input_count) {
1579       largest_input_count = input_count;
1580     }
1581     VLOG(1) << "Input Count: " << input_count
1582             << " Largest Input Count:" << largest_input_count;
1583   }
1584   return largest_input_count;
1585 }
1586 
CalculateOutputSize(const OpInfo & op_info,bool * found_unknown_shapes)1587 int64 OpLevelCostEstimator::CalculateOutputSize(const OpInfo& op_info,
1588                                                 bool* found_unknown_shapes) {
1589   int64 total_output_size = 0;
1590   // Use float as default for calculations.
1591   for (const auto& output : op_info.outputs()) {
1592     DataType dt = output.dtype();
1593     const auto& original_output_shape = output.shape();
1594     int64 output_size = DataTypeSize(BaseType(dt));
1595     int num_dims = std::max(1, original_output_shape.dim_size());
1596     auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims,
1597                                              found_unknown_shapes);
1598     for (const auto& dim : output_shape.dim()) {
1599       output_size *= dim.size();
1600     }
1601     total_output_size += output_size;
1602     VLOG(1) << "Output Size: " << output_size
1603             << " Total Output Size:" << total_output_size;
1604   }
1605   return total_output_size;
1606 }
1607 
CalculateOutputTensorSize(const OpInfo & op_info,bool * found_unknown_shapes)1608 std::vector<int64> OpLevelCostEstimator::CalculateOutputTensorSize(
1609     const OpInfo& op_info, bool* found_unknown_shapes) {
1610   std::vector<int64> output_tensor_size;
1611   output_tensor_size.reserve(op_info.outputs().size());
1612   // Use float as default for calculations.
1613   for (const auto& output : op_info.outputs()) {
1614     DataType dt = output.dtype();
1615     const auto& original_output_shape = output.shape();
1616     int64 output_size = DataTypeSize(BaseType(dt));
1617     int num_dims = std::max(1, original_output_shape.dim_size());
1618     auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims,
1619                                              found_unknown_shapes);
1620     for (const auto& dim : output_shape.dim()) {
1621       output_size *= dim.size();
1622     }
1623     output_tensor_size.push_back(output_size);
1624   }
1625   return output_tensor_size;
1626 }
1627 
PredictDefaultNodeCosts(const int64 num_compute_ops,const OpContext & op_context,bool * found_unknown_shapes,NodeCosts * node_costs)1628 Status OpLevelCostEstimator::PredictDefaultNodeCosts(
1629     const int64 num_compute_ops, const OpContext& op_context,
1630     bool* found_unknown_shapes, NodeCosts* node_costs) {
1631   const auto& op_info = op_context.op_info;
1632   node_costs->num_compute_ops = num_compute_ops;
1633   node_costs->num_input_bytes_accessed =
1634       CalculateInputTensorSize(op_info, found_unknown_shapes);
1635   node_costs->num_output_bytes_accessed =
1636       CalculateOutputTensorSize(op_info, found_unknown_shapes);
1637   node_costs->max_memory = node_costs->num_total_output_bytes();
1638   if (*found_unknown_shapes) {
1639     node_costs->inaccurate = true;
1640     node_costs->num_nodes_with_unknown_shapes = 1;
1641   }
1642   return Status::OK();
1643 }
1644 
HasZeroDim(const OpInfo & op_info)1645 bool HasZeroDim(const OpInfo& op_info) {
1646   for (int i = 0; i < op_info.inputs_size(); ++i) {
1647     const auto& input = op_info.inputs(i);
1648     for (int j = 0; j < input.shape().dim_size(); ++j) {
1649       const auto& dim = input.shape().dim(j);
1650       if (dim.size() == 0) {
1651         VLOG(1) << "Convolution config has zero dim "
1652                 << op_info.ShortDebugString();
1653         return true;
1654       }
1655     }
1656   }
1657   return false;
1658 }
1659 
PredictConv2D(const OpContext & op_context,NodeCosts * node_costs) const1660 Status OpLevelCostEstimator::PredictConv2D(const OpContext& op_context,
1661                                            NodeCosts* node_costs) const {
1662   const auto& op_info = op_context.op_info;
1663   if (HasZeroDim(op_info)) {
1664     node_costs->num_nodes_with_unknown_shapes = 1;
1665     return errors::InvalidArgument("Conv2D op includes zero dimension: ",
1666                                    op_info.ShortDebugString());
1667   }
1668   bool found_unknown_shapes = false;
1669   int64 num_compute_ops = CountConv2DOperations(op_info, &found_unknown_shapes);
1670   return PredictDefaultNodeCosts(num_compute_ops, op_context,
1671                                  &found_unknown_shapes, node_costs);
1672 }
1673 
PredictConv2DBackpropInput(const OpContext & op_context,NodeCosts * node_costs) const1674 Status OpLevelCostEstimator::PredictConv2DBackpropInput(
1675     const OpContext& op_context, NodeCosts* node_costs) const {
1676   const auto& op_info = op_context.op_info;
1677   if (HasZeroDim(op_info)) {
1678     node_costs->num_nodes_with_unknown_shapes = 1;
1679     return errors::InvalidArgument(
1680         "Conv2DBackpropInput op includes zero dimension",
1681         op_info.ShortDebugString());
1682   }
1683   bool found_unknown_shapes = false;
1684   int64 num_compute_ops = CountConv2DBackpropInputOperations(
1685       op_info, nullptr, &found_unknown_shapes);
1686   return PredictDefaultNodeCosts(num_compute_ops, op_context,
1687                                  &found_unknown_shapes, node_costs);
1688 }
1689 
PredictConv2DBackpropFilter(const OpContext & op_context,NodeCosts * node_costs) const1690 Status OpLevelCostEstimator::PredictConv2DBackpropFilter(
1691     const OpContext& op_context, NodeCosts* node_costs) const {
1692   const auto& op_info = op_context.op_info;
1693   if (HasZeroDim(op_info)) {
1694     node_costs->num_nodes_with_unknown_shapes = 1;
1695     return errors::InvalidArgument(
1696         "Conv2DBackpropFilter op includes zero dimension",
1697         op_info.ShortDebugString());
1698   }
1699   bool found_unknown_shapes = false;
1700   int64 num_compute_ops = CountConv2DBackpropFilterOperations(
1701       op_info, nullptr, &found_unknown_shapes);
1702   return PredictDefaultNodeCosts(num_compute_ops, op_context,
1703                                  &found_unknown_shapes, node_costs);
1704 }
1705 
PredictFusedConv2DBiasActivation(const OpContext & op_context,NodeCosts * node_costs) const1706 Status OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
1707     const OpContext& op_context, NodeCosts* node_costs) const {
1708   // FusedConv2DBiasActivation computes a fused kernel which implements:
1709   // 2D convolution, adds side input with separate scaling on convolution and
1710   // side inputs, then adds bias, and finally applies the ReLU activation
1711   // function to the result:
1712   //
1713   // Input -> Conv2D  ->  Add  -> BiasAdd  -> ReLU
1714   //            ^          ^         ^
1715   //          Filter   Side Input   Bias
1716   //
1717   // Note that when adding the side input, the operation multiplies the output
1718   // of Conv2D by conv_input_scale, confusingly, and the side_input by
1719   // side_input_scale.
1720   //
1721   // Note that in the special case that side_input_scale is 0, which we infer
1722   // from side_input having dimensions [], we skip that addition operation.
1723   //
1724   // For more information, see
1725   // contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
1726 
1727   // TODO(yaozhang): Support NHWC_VECT_W.
1728   std::string data_format = GetDataFormat(op_context.op_info);
1729   if (data_format != "NCHW" && data_format != "NHWC" &&
1730       data_format != "NCHW_VECT_C") {
1731     return errors::InvalidArgument(
1732         "Unsupported data format (", data_format,
1733         ") for op: ", op_context.op_info.ShortDebugString());
1734   }
1735   std::string filter_format = GetFilterFormat(op_context.op_info);
1736   if (filter_format != "HWIO" && filter_format != "OIHW" &&
1737       filter_format != "OIHW_VECT_I") {
1738     return errors::InvalidArgument(
1739         "Unsupported filter format (", filter_format,
1740         ") for op: ", op_context.op_info.ShortDebugString());
1741   }
1742 
1743   auto& conv_input = op_context.op_info.inputs(0);
1744   auto& filter = op_context.op_info.inputs(1);
1745   auto& side_input = op_context.op_info.inputs(3);
1746   auto& conv_input_scale = op_context.op_info.inputs(4);
1747   auto& side_input_scale = op_context.op_info.inputs(5);
1748 
1749   // Manually compute our convolution dimensions.
1750   bool found_unknown_shapes = false;
1751   auto dims = ConvolutionDimensionsFromInputs(
1752       conv_input.shape(), filter.shape(), op_context.op_info,
1753       &found_unknown_shapes);
1754   OpInfo::TensorProperties output;
1755   if (data_format == "NCHW" || data_format == "NCHW_VECT_C") {
1756     output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.oy, dims.ox});
1757   } else if (data_format == "NHWC") {
1758     output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oy, dims.ox, dims.oz});
1759   }
1760 
1761   // Add the operations the fused op always computes.
1762   std::vector<OpContext> component_ops = {
1763       FusedChildContext(op_context, "Conv2D", output, {conv_input, filter}),
1764       FusedChildContext(op_context, "Mul", output, {output, conv_input_scale}),
1765       FusedChildContext(
1766           op_context, "BiasAdd", output,
1767           {output, output}),  // Note we're no longer using bias at all
1768       FusedChildContext(op_context, "Relu", output, {output})};
1769 
1770   // Add our side_input iff it's non-empty.
1771   if (side_input.shape().dim_size() > 0) {
1772     component_ops.push_back(FusedChildContext(op_context, "Mul", side_input,
1773                                               {side_input, side_input_scale}));
1774     component_ops.push_back(FusedChildContext(
1775         op_context, "Add", output,
1776         {output, output}));  // Note that we're not using side_input here
1777   }
1778 
1779   // Construct an op_context which definitely has our output shape.
1780   auto op_context_with_output = op_context;
1781   op_context_with_output.op_info.mutable_outputs()->Clear();
1782   *op_context_with_output.op_info.mutable_outputs()->Add() = output;
1783 
1784   // Construct component operations and run the cost computation.
1785   if (found_unknown_shapes) {
1786     node_costs->inaccurate = true;
1787     node_costs->num_nodes_with_unknown_shapes = 1;
1788   }
1789   return PredictFusedOp(op_context_with_output, component_ops, node_costs);
1790 }
1791 
PredictMatMul(const OpContext & op_context,NodeCosts * node_costs) const1792 Status OpLevelCostEstimator::PredictMatMul(const OpContext& op_context,
1793                                            NodeCosts* node_costs) const {
1794   const auto& op_info = op_context.op_info;
1795   bool found_unknown_shapes = false;
1796   int64 num_compute_ops = CountMatMulOperations(op_info, &found_unknown_shapes);
1797   return PredictDefaultNodeCosts(num_compute_ops, op_context,
1798                                  &found_unknown_shapes, node_costs);
1799 }
1800 
PredictEinsum(const OpContext & op_context,NodeCosts * node_costs) const1801 Status OpLevelCostEstimator::PredictEinsum(const OpContext& op_context,
1802                                            NodeCosts* node_costs) const {
1803   const auto& op_info = op_context.op_info;
1804 
1805   auto it = op_info.attr().find("equation");
1806   if (it == op_info.attr().end()) {
1807     return errors::InvalidArgument("Einsum op doesn't have equation attr: ",
1808                                    op_info.ShortDebugString());
1809   }
1810 
1811   OpContext batch_matmul_op_context;
1812   bool found_unknown_shapes = false;
1813   bool success = GenerateBatchMatmulContextFromEinsum(
1814       op_context, &batch_matmul_op_context, &found_unknown_shapes);
1815   if (found_unknown_shapes) {
1816     node_costs->inaccurate = true;
1817     node_costs->num_nodes_with_unknown_shapes = 1;
1818   }
1819   if (!success) {
1820     return PredictCostOfAnUnknownOp(op_context, node_costs);
1821   }
1822   return PredictNodeCosts(batch_matmul_op_context, node_costs);
1823 }
1824 
PredictSparseTensorDenseMatMul(const OpContext & op_context,NodeCosts * node_costs) const1825 Status OpLevelCostEstimator::PredictSparseTensorDenseMatMul(
1826     const OpContext& op_context, NodeCosts* node_costs) const {
1827   const auto& op_info = op_context.op_info;
1828   bool found_unknown_shapes = false;
1829   // input[0]: indices in sparse matrix a
1830   // input[1]: values in sparse matrix a
1831   // input[2]: shape of matrix a
1832   // input[3]: matrix b
1833   // See
1834   // https://github.com/tensorflow/tensorflow/blob/9a43dfeac5/tensorflow/core/ops/sparse_ops.cc#L85
1835   int64 num_elems_in_a =
1836       CalculateTensorElementCount(op_info.inputs(1), &found_unknown_shapes);
1837   auto b_matrix = op_info.inputs(3);
1838   auto b_matrix_shape =
1839       MaybeGetMinimumShape(b_matrix.shape(), 2, &found_unknown_shapes);
1840   int64 n_dim = b_matrix_shape.dim(1).size();
1841 
1842   // Each element in A is multiplied and added with an element from each column
1843   // in b.
1844   const int64 op_count = kOpsPerMac * num_elems_in_a * n_dim;
1845 
1846   int64 a_indices_input_size =
1847       CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
1848   int64 a_values_input_size =
1849       CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
1850   int64 a_shape_input_size =
1851       CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
1852   int64 b_input_size =
1853       num_elems_in_a * n_dim * DataTypeSize(BaseType(b_matrix.dtype()));
1854   int64 output_size = CalculateOutputSize(op_info, &found_unknown_shapes);
1855 
1856   node_costs->num_compute_ops = op_count;
1857   node_costs->num_input_bytes_accessed = {a_indices_input_size,
1858                                           a_values_input_size,
1859                                           a_shape_input_size, b_input_size};
1860   node_costs->num_output_bytes_accessed = {output_size};
1861   if (found_unknown_shapes) {
1862     node_costs->inaccurate = true;
1863     node_costs->num_nodes_with_unknown_shapes = 1;
1864   }
1865   return Status::OK();
1866 }
1867 
PredictNoOp(const OpContext & op_context,NodeCosts * node_costs) const1868 Status OpLevelCostEstimator::PredictNoOp(const OpContext& op_context,
1869                                          NodeCosts* node_costs) const {
1870   const auto& op_info = op_context.op_info;
1871   VLOG(1) << "Op:" << op_info.op() << " Execution Time 0 (ns)";
1872   // By default, NodeCosts is initialized to zero ops and bytes.
1873   return Status::OK();
1874 }
1875 
PredictPureMemoryOp(const OpContext & op_context,NodeCosts * node_costs) const1876 Status OpLevelCostEstimator::PredictPureMemoryOp(const OpContext& op_context,
1877                                                  NodeCosts* node_costs) const {
1878   // Each output element is a copy of some element from input, with no required
1879   // computation, so just compute memory costs.
1880   bool found_unknown_shapes = false;
1881   node_costs->num_nodes_with_pure_memory_op = 1;
1882   return PredictDefaultNodeCosts(0, op_context, &found_unknown_shapes,
1883                                  node_costs);
1884 }
1885 
PredictIdentity(const OpContext & op_context,NodeCosts * node_costs) const1886 Status OpLevelCostEstimator::PredictIdentity(const OpContext& op_context,
1887                                              NodeCosts* node_costs) const {
1888   const auto& op_info = op_context.op_info;
1889   VLOG(1) << "Op:" << op_info.op() << " Minimum cost for Identity";
1890   node_costs->minimum_cost_op = true;
1891   node_costs->num_compute_ops = kMinComputeOp;
1892   // Identity op internally pass input tensor buffer's pointer to the output
1893   // tensor buffer; no actual memory operation.
1894   node_costs->num_input_bytes_accessed = {0};
1895   node_costs->num_output_bytes_accessed = {0};
1896   bool inaccurate = false;
1897   node_costs->max_memory = CalculateOutputSize(op_info, &inaccurate);
1898   if (inaccurate) {
1899     node_costs->inaccurate = true;
1900     node_costs->num_nodes_with_unknown_shapes = 1;
1901   }
1902   return Status::OK();
1903 }
1904 
PredictVariable(const OpContext & op_context,NodeCosts * node_costs) const1905 Status OpLevelCostEstimator::PredictVariable(const OpContext& op_context,
1906                                              NodeCosts* node_costs) const {
1907   const auto& op_info = op_context.op_info;
1908   VLOG(1) << "Op:" << op_info.op() << " Minimum cost for Variable";
1909   node_costs->minimum_cost_op = true;
1910   node_costs->num_compute_ops = kMinComputeOp;
1911   // Variables are persistent ops; initialized before step; hence, no memory
1912   // cost.
1913   node_costs->num_input_bytes_accessed = {0};
1914   node_costs->num_output_bytes_accessed = {0};
1915   bool inaccurate = false;
1916   node_costs->persistent_memory = CalculateOutputSize(op_info, &inaccurate);
1917   if (inaccurate) {
1918     node_costs->inaccurate = true;
1919     node_costs->num_nodes_with_unknown_shapes = 1;
1920   }
1921   return Status::OK();
1922 }
1923 
PredictBatchMatMul(const OpContext & op_context,NodeCosts * node_costs) const1924 Status OpLevelCostEstimator::PredictBatchMatMul(const OpContext& op_context,
1925                                                 NodeCosts* node_costs) const {
1926   const auto& op_info = op_context.op_info;
1927   bool found_unknown_shapes = false;
1928   int64 num_compute_ops =
1929       CountBatchMatMulOperations(op_info, &found_unknown_shapes);
1930   return PredictDefaultNodeCosts(num_compute_ops, op_context,
1931                                  &found_unknown_shapes, node_costs);
1932 }
1933 
PredictMetadata(const OpContext & op_context,NodeCosts * node_costs) const1934 Status OpLevelCostEstimator::PredictMetadata(const OpContext& op_context,
1935                                              NodeCosts* node_costs) const {
1936   const auto& op_info = op_context.op_info;
1937   node_costs->minimum_cost_op = true;
1938   node_costs->num_compute_ops = kMinComputeOp;
1939   node_costs->num_input_bytes_accessed = {0};
1940   node_costs->num_output_bytes_accessed = {0};
1941   bool inaccurate = false;
1942   node_costs->max_memory = CalculateOutputSize(op_info, &inaccurate);
1943   if (inaccurate) {
1944     node_costs->inaccurate = true;
1945     node_costs->num_nodes_with_unknown_shapes = 1;
1946   }
1947   return Status::OK();
1948 }
1949 
PredictGatherOrSlice(const OpContext & op_context,NodeCosts * node_costs) const1950 Status OpLevelCostEstimator::PredictGatherOrSlice(const OpContext& op_context,
1951                                                   NodeCosts* node_costs) const {
1952   // Gather & Slice ops can have a very large input, but only access a small
1953   // part of it. For these op the size of the output determines the memory cost.
1954   const auto& op_info = op_context.op_info;
1955 
1956   const int inputs_needed = op_info.op() == "Slice" ? 3 : 2;
1957   if (op_info.outputs_size() == 0 || op_info.inputs_size() < inputs_needed) {
1958     return errors::InvalidArgument(
1959         op_info.op(),
1960         " Op doesn't have valid input / output: ", op_info.ShortDebugString());
1961   }
1962 
1963   bool unknown_shapes = false;
1964 
1965   // Each output element is a copy of some element from input.
1966   // For roofline estimate we assume each copy has a unit cost.
1967   const int64 op_count =
1968       CalculateTensorElementCount(op_info.outputs(0), &unknown_shapes);
1969   node_costs->num_compute_ops = op_count;
1970 
1971   const int64 output_size = CalculateOutputSize(op_info, &unknown_shapes);
1972   node_costs->num_output_bytes_accessed = {output_size};
1973 
1974   node_costs->num_input_bytes_accessed.reserve(op_info.inputs().size());
1975   int64 input_size = output_size;
1976   // Note that input(0) byte accessed is not equal to input(0) tensor size.
1977   // It's equal to the output size; though, input access is indexed gather or
1978   // slice (ignore duplicate indices).
1979   node_costs->num_input_bytes_accessed.push_back(input_size);
1980   int begin_input_index = 1;
1981   int end_input_index;
1982   if (op_info.op() == "Slice") {
1983     // Slice: 'input' (omitted), 'begin', 'size'
1984     end_input_index = 3;
1985   } else if (op_info.op() == "StridedSlice") {
1986     // StridedSlice: 'input' (omitted), 'begin', 'end', 'strides'
1987     end_input_index = 4;
1988   } else {
1989     // Gather, GatherV2, GatherNd: 'params' (omitted), 'indices'
1990     end_input_index = 2;
1991   }
1992   for (int i = begin_input_index; i < end_input_index; ++i) {
1993     node_costs->num_input_bytes_accessed.push_back(
1994         CalculateTensorElementCount(op_info.inputs(i), &unknown_shapes));
1995   }
1996   if (unknown_shapes) {
1997     node_costs->inaccurate = true;
1998     node_costs->num_nodes_with_unknown_shapes = 1;
1999   }
2000   return Status::OK();
2001 }
2002 
PredictScatter(const OpContext & op_context,NodeCosts * node_costs) const2003 Status OpLevelCostEstimator::PredictScatter(const OpContext& op_context,
2004                                             NodeCosts* node_costs) const {
2005   // Scatter ops sparsely access a reference input and output tensor.
2006   const auto& op_info = op_context.op_info;
2007   bool found_unknown_shapes = false;
2008 
2009   // input[0]: ref tensor that will be sparsely accessed
2010   // input[1]: indices - A tensor of indices into the first dimension of ref.
2011   // input[2]: updates where updates.shape = indices.shape + ref.shape[1:]
2012   // See
2013   // https://www.tensorflow.org/api_docs/python/tf/scatter_add and
2014   // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/state_ops.cc#L146
2015 
2016   const int64 num_indices =
2017       CalculateTensorElementCount(op_info.inputs(1), &found_unknown_shapes);
2018 
2019   int64 num_elems_in_ref_per_index = 1;
2020   auto ref_tensor_shape = MaybeGetMinimumShape(
2021       op_info.inputs(0).shape(), op_info.inputs(0).shape().dim_size(),
2022       &found_unknown_shapes);
2023   for (int i = 1; i < ref_tensor_shape.dim().size(); ++i) {
2024     num_elems_in_ref_per_index *= ref_tensor_shape.dim(i).size();
2025   }
2026   const int64 op_count = num_indices * num_elems_in_ref_per_index;
2027   node_costs->num_compute_ops = op_count;
2028 
2029   // Sparsely access ref so input size depends on the number of operations
2030   int64 ref_input_size =
2031       op_count * DataTypeSize(BaseType(op_info.inputs(0).dtype()));
2032   int64 indices_input_size =
2033       CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
2034   int64 updates_input_size =
2035       CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
2036   node_costs->num_input_bytes_accessed = {ref_input_size, indices_input_size,
2037                                           updates_input_size};
2038 
2039   // Sparsely access ref so output size depends on the number of operations
2040   int64 output_size =
2041       op_count * DataTypeSize(BaseType(op_info.outputs(0).dtype()));
2042   node_costs->num_output_bytes_accessed = {output_size};
2043 
2044   if (found_unknown_shapes) {
2045     node_costs->inaccurate = true;
2046     node_costs->num_nodes_with_unknown_shapes = 1;
2047   }
2048   return Status::OK();
2049 }
2050 
PredictFusedOp(const OpContext & op_context,const std::vector<OpContext> & fused_op_contexts,NodeCosts * node_costs) const2051 Status OpLevelCostEstimator::PredictFusedOp(
2052     const OpContext& op_context,
2053     const std::vector<OpContext>& fused_op_contexts,
2054     NodeCosts* node_costs) const {
2055   // Note that PredictDefaultNodeCosts will get the correct memory costs from
2056   // the node's inputs and outputs; but we don't want to have to re-implement
2057   // the logic for computing the operation count of each of our component
2058   // operations here; so we simply add the compute times of each component
2059   // operation, then update the cost.
2060   bool found_unknown_shapes = false;
2061   Status s =
2062       PredictDefaultNodeCosts(0, op_context, &found_unknown_shapes, node_costs);
2063 
2064   for (auto& fused_op : fused_op_contexts) {
2065     NodeCosts fused_node_costs;
2066     s.Update(PredictNodeCosts(fused_op, &fused_node_costs));
2067     node_costs->num_compute_ops += fused_node_costs.num_compute_ops;
2068     node_costs->inaccurate |= fused_node_costs.inaccurate;
2069     // Set, not increment. Note that we are predicting the cost of one fused
2070     // node, not a function node composed of many nodes.
2071     node_costs->num_nodes_with_unknown_shapes |=
2072         fused_node_costs.num_nodes_with_unknown_shapes;
2073     node_costs->num_nodes_with_unknown_op_type |=
2074         fused_node_costs.num_nodes_with_unknown_op_type;
2075     node_costs->num_nodes_with_pure_memory_op |=
2076         fused_node_costs.num_nodes_with_pure_memory_op;
2077   }
2078 
2079   return Status::OK();
2080 }
2081 
2082 /* static */
FusedChildContext(const OpContext & parent,const std::string & op_name,const OpInfo::TensorProperties & output,const std::vector<OpInfo::TensorProperties> & inputs)2083 OpContext OpLevelCostEstimator::FusedChildContext(
2084     const OpContext& parent, const std::string& op_name,
2085     const OpInfo::TensorProperties& output,
2086     const std::vector<OpInfo::TensorProperties>& inputs) {
2087   // Setup the base parameters of our new context.
2088   OpContext new_context;
2089   new_context.name = op_name;
2090   new_context.device_name = parent.device_name;
2091   new_context.op_info = parent.op_info;
2092   new_context.op_info.set_op(op_name);
2093 
2094   // Setup the inputs of our new context.
2095   new_context.op_info.mutable_inputs()->Clear();
2096   for (const auto& input : inputs) {
2097     *new_context.op_info.mutable_inputs()->Add() = input;
2098   }
2099 
2100   // Setup the output of our new context.
2101   new_context.op_info.mutable_outputs()->Clear();
2102   *new_context.op_info.mutable_outputs()->Add() = output;
2103 
2104   return new_context;
2105 }
2106 
2107 /* static */
DescribeTensor(DataType type,const std::vector<int64> & dims)2108 OpInfo::TensorProperties OpLevelCostEstimator::DescribeTensor(
2109     DataType type, const std::vector<int64>& dims) {
2110   OpInfo::TensorProperties ret;
2111   ret.set_dtype(type);
2112 
2113   auto shape = ret.mutable_shape();
2114   for (const int dim : dims) {
2115     shape->add_dim()->set_size(dim);
2116   }
2117 
2118   return ret;
2119 }
2120 
2121 /* static */
2122 OpLevelCostEstimator::ConvolutionDimensions
OpDimensionsFromInputs(const TensorShapeProto & original_image_shape,const OpInfo & op_info,bool * found_unknown_shapes)2123 OpLevelCostEstimator::OpDimensionsFromInputs(
2124     const TensorShapeProto& original_image_shape, const OpInfo& op_info,
2125     bool* found_unknown_shapes) {
2126   VLOG(2) << "op features: " << op_info.DebugString();
2127   VLOG(2) << "Original image shape: " << original_image_shape.DebugString();
2128   auto image_shape =
2129       MaybeGetMinimumShape(original_image_shape, 4, found_unknown_shapes);
2130   VLOG(2) << "Image shape: " << image_shape.DebugString();
2131 
2132   int x_index, y_index, channel_index;
2133   const std::string& data_format = GetDataFormat(op_info);
2134   if (data_format == "NCHW") {
2135     channel_index = 1;
2136     y_index = 2;
2137     x_index = 3;
2138   } else {
2139     y_index = 1;
2140     x_index = 2;
2141     channel_index = 3;
2142   }
2143   int64 batch = image_shape.dim(0).size();
2144   int64 ix = image_shape.dim(x_index).size();
2145   int64 iy = image_shape.dim(y_index).size();
2146   int64 iz = image_shape.dim(channel_index).size();
2147 
2148   // Note that FusedBatchNorm doesn't have ksize attr, but GetKernelSize returns
2149   // {1, 1, 1, 1} in that case.
2150   std::vector<int64> ksize = GetKernelSize(op_info);
2151   int64 kx = ksize[x_index];
2152   int64 ky = ksize[y_index];
2153   // These ops don't support groupwise operation, therefore kz == iz.
2154   int64 kz = iz;
2155 
2156   std::vector<int64> strides = GetStrides(op_info);
2157   int64 sx = strides[x_index];
2158   int64 sy = strides[y_index];
2159   const auto padding = GetPadding(op_info);
2160 
2161   int64 ox = GetOutputSize(ix, kx, sx, padding);
2162   int64 oy = GetOutputSize(iy, ky, sy, padding);
2163   int64 oz = iz;
2164 
2165   OpLevelCostEstimator::ConvolutionDimensions conv_dims = {
2166       batch, ix, iy, iz, kx, ky, kz, oz, ox, oy, sx, sy, padding};
2167   return conv_dims;
2168 }
2169 
PredictMaxPool(const OpContext & op_context,NodeCosts * node_costs) const2170 Status OpLevelCostEstimator::PredictMaxPool(const OpContext& op_context,
2171                                             NodeCosts* node_costs) const {
2172   bool found_unknown_shapes = false;
2173   const auto& op_info = op_context.op_info;
2174   // x: op_info.inputs(0)
2175   ConvolutionDimensions dims = OpDimensionsFromInputs(
2176       op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
2177   // kx * ky - 1 comparisons per output (kx * xy > 1)
2178   // or 1 copy per output (kx * k1 = 1).
2179   int per_output_ops = dims.kx * dims.ky == 1 ? 1 : dims.kx * dims.ky - 1;
2180   int64 ops = dims.batch * dims.ox * dims.oy * dims.oz * per_output_ops;
2181   node_costs->num_compute_ops = ops;
2182 
2183   int64 input_size = 0;
2184   if (dims.ky >= dims.sy) {
2185     input_size = CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
2186   } else {  // dims.ky < dims.sy
2187     // Vertical stride is larger than vertical kernel; assuming row-major
2188     // format, skip unnecessary rows (or read every kx rows per sy rows, as the
2189     // others are not used for output).
2190     const auto data_size = DataTypeSize(BaseType(op_info.inputs(0).dtype()));
2191     input_size = data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz;
2192   }
2193   node_costs->num_input_bytes_accessed = {input_size};
2194   const int64 output_size = CalculateOutputSize(op_info, &found_unknown_shapes);
2195   node_costs->num_output_bytes_accessed = {output_size};
2196   node_costs->max_memory = output_size;
2197   if (found_unknown_shapes) {
2198     node_costs->inaccurate = true;
2199     node_costs->num_nodes_with_unknown_shapes = 1;
2200   }
2201   return Status::OK();
2202 }
2203 
PredictMaxPoolGrad(const OpContext & op_context,NodeCosts * node_costs) const2204 Status OpLevelCostEstimator::PredictMaxPoolGrad(const OpContext& op_context,
2205                                                 NodeCosts* node_costs) const {
2206   bool found_unknown_shapes = false;
2207   const auto& op_info = op_context.op_info;
2208   // x: op_info.inputs(0)
2209   // y: op_info.inputs(1)
2210   // y_grad: op_info.inputs(2)
2211   if (op_info.inputs_size() < 3) {
2212     return errors::InvalidArgument("MaxPoolGrad op has invalid inputs: ",
2213                                    op_info.ShortDebugString());
2214   }
2215 
2216   ConvolutionDimensions dims = OpDimensionsFromInputs(
2217       op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
2218 
2219   int64 ops = 0;
2220   if (dims.kx == 1 && dims.ky == 1) {
2221     // 1x1 window. No need to know which input was max.
2222     ops = dims.batch * dims.ix * dims.iy * dims.iz;
2223   } else if (dims.kx <= dims.sx && dims.ky <= dims.sy) {
2224     // Non-overlapping window: re-run maxpool, then assign zero or y_grad.
2225     ops = dims.batch * dims.iz *
2226           (dims.ox * dims.oy * (dims.kx * dims.ky - 1) + dims.ix * dims.iy);
2227   } else {
2228     // Overlapping window: initialize with zeros, re-run maxpool, then
2229     // accumulate y_gad to proper x_grad locations.
2230     ops = dims.batch * dims.iz *
2231           (dims.ox * dims.oy * (dims.kx * dims.ky - 1) + dims.ix * dims.iy * 2);
2232   }
2233   node_costs->num_compute_ops = ops;
2234 
2235   // Just read x and y_grad; no need to read y as we assume MaxPoolGrad re-run
2236   // MaxPool internally.
2237   const int64 input0_size =
2238       CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
2239   const int64 input2_size =
2240       CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
2241   node_costs->num_input_bytes_accessed = {input0_size, 0, input2_size};
2242   // Write x_grad; size equal to x.
2243   const int64 output_size =
2244       CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
2245   node_costs->num_output_bytes_accessed = {output_size};
2246   node_costs->max_memory = output_size;
2247 
2248   if (found_unknown_shapes) {
2249     node_costs->inaccurate = true;
2250     node_costs->num_nodes_with_unknown_shapes = 1;
2251   }
2252   return Status::OK();
2253 }
2254 
2255 /* This predict function handles three types of tensorflow ops
2256  * AssignVariableOp/AssignAddVariableOp/AssignSubVariableOp, broadcasting
2257  * was not possible for these ops, therefore the input tensor's shapes is
2258  * enough to compute the cost */
PredictAssignVariableOps(const OpContext & op_context,NodeCosts * node_costs) const2259 Status OpLevelCostEstimator::PredictAssignVariableOps(
2260     const OpContext& op_context, NodeCosts* node_costs) const {
2261   bool found_unknown_shapes = false;
2262   const auto& op_info = op_context.op_info;
2263   /* First input of these ops are reference to the assignee. */
2264   if (op_info.inputs_size() != 2) {
2265     return errors::InvalidArgument("AssignVariable op has invalid input: ",
2266                                    op_info.ShortDebugString());
2267   }
2268 
2269   const int64 ops = op_info.op() == kAssignVariableOp
2270                         ? 0
2271                         : CalculateTensorElementCount(op_info.inputs(1),
2272                                                       &found_unknown_shapes);
2273   node_costs->num_compute_ops = ops;
2274   const int64 input_size = CalculateInputSize(op_info, &found_unknown_shapes);
2275   node_costs->num_input_bytes_accessed = {input_size};
2276   // TODO(dyoon): check these ops' behavior whether it writes data;
2277   // Op itself doesn't have output tensor, but it may modify the input (ref or
2278   // resource). Maybe use node_costs->internal_write_bytes.
2279   node_costs->num_output_bytes_accessed = {0};
2280   if (found_unknown_shapes) {
2281     node_costs->inaccurate = true;
2282     node_costs->num_nodes_with_unknown_shapes = 1;
2283   }
2284   return Status::OK();
2285 }
2286 
PredictAvgPool(const OpContext & op_context,NodeCosts * node_costs) const2287 Status OpLevelCostEstimator::PredictAvgPool(const OpContext& op_context,
2288                                             NodeCosts* node_costs) const {
2289   bool found_unknown_shapes = false;
2290   const auto& op_info = op_context.op_info;
2291   // x: op_info.inputs(0)
2292   ConvolutionDimensions dims = OpDimensionsFromInputs(
2293       op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
2294 
2295   // kx * ky - 1 additions and 1 multiplication per output.
2296   int64 ops = dims.batch * dims.ox * dims.oy * dims.oz * dims.kx * dims.ky;
2297   node_costs->num_compute_ops = ops;
2298 
2299   int64 input_size;
2300   if (dims.ky >= dims.sy) {
2301     input_size = CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
2302   } else {  // dims.ky < dims.sy
2303     // vertical stride is larger than vertical kernel; assuming row-major
2304     // format, skip unnecessary rows (or read every kx rows per sy rows, as the
2305     // others are not used for output).
2306     const auto data_size = DataTypeSize(BaseType(op_info.inputs(0).dtype()));
2307     input_size = data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz;
2308   }
2309   node_costs->num_input_bytes_accessed = {input_size};
2310 
2311   const int64 output_size = CalculateOutputSize(op_info, &found_unknown_shapes);
2312   node_costs->num_output_bytes_accessed = {output_size};
2313   node_costs->max_memory = output_size;
2314 
2315   if (found_unknown_shapes) {
2316     node_costs->inaccurate = true;
2317     node_costs->num_nodes_with_unknown_shapes = 1;
2318   }
2319   return Status::OK();
2320 }
2321 
PredictAvgPoolGrad(const OpContext & op_context,NodeCosts * node_costs) const2322 Status OpLevelCostEstimator::PredictAvgPoolGrad(const OpContext& op_context,
2323                                                 NodeCosts* node_costs) const {
2324   bool found_unknown_shapes = false;
2325   const auto& op_info = op_context.op_info;
2326   // x's shape: op_info.inputs(0)
2327   // y_grad: op_info.inputs(1)
2328 
2329   // Extract x_shape from op_info.inputs(0).value() or op_info.outputs(0).
2330   bool shape_found = false;
2331   TensorShapeProto x_shape;
2332   if (op_info.inputs_size() >= 1 && op_info.inputs(0).has_value()) {
2333     const TensorProto& value = op_info.inputs(0).value();
2334     shape_found = GetTensorShapeProtoFromTensorProto(value, &x_shape);
2335   }
2336   if (!shape_found && op_info.outputs_size() > 0) {
2337     x_shape = op_info.outputs(0).shape();
2338     shape_found = true;
2339   }
2340   if (!shape_found) {
2341     // Set the minimum shape that's feasible.
2342     x_shape.Clear();
2343     for (int i = 0; i < 4; ++i) {
2344       x_shape.add_dim()->set_size(1);
2345     }
2346     found_unknown_shapes = true;
2347   }
2348 
2349   ConvolutionDimensions dims =
2350       OpDimensionsFromInputs(x_shape, op_info, &found_unknown_shapes);
2351 
2352   int64 ops = 0;
2353   if (dims.kx <= dims.sx && dims.ky <= dims.sy) {
2354     // Non-overlapping window.
2355     ops = dims.batch * dims.iz * (dims.ix * dims.iy + dims.ox * dims.oy);
2356   } else {
2357     // Overlapping window.
2358     ops = dims.batch * dims.iz *
2359           (dims.ix * dims.iy + dims.ox * dims.oy * (dims.kx * dims.ky + 1));
2360   }
2361   auto s = PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes,
2362                                    node_costs);
2363   node_costs->max_memory = node_costs->num_total_output_bytes();
2364   return s;
2365 }
2366 
PredictFusedBatchNorm(const OpContext & op_context,NodeCosts * node_costs) const2367 Status OpLevelCostEstimator::PredictFusedBatchNorm(
2368     const OpContext& op_context, NodeCosts* node_costs) const {
2369   bool found_unknown_shapes = false;
2370   const auto& op_info = op_context.op_info;
2371   // x: op_info.inputs(0)
2372   // scale: op_info.inputs(1)
2373   // offset: op_info.inputs(2)
2374   // mean: op_info.inputs(3)  --> only for inference
2375   // variance: op_info.inputs(4) --> only for inference
2376   ConvolutionDimensions dims = OpDimensionsFromInputs(
2377       op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
2378   const bool is_training = IsTraining(op_info);
2379 
2380   int64 ops = 0;
2381   const auto rsqrt_cost = Eigen::internal::functor_traits<
2382       Eigen::internal::scalar_rsqrt_op<float>>::Cost;
2383   if (is_training) {
2384     ops = dims.iz * (dims.batch * dims.ix * dims.iy * 4 + 6 + rsqrt_cost);
2385   } else {
2386     ops = dims.batch * dims.ix * dims.iy * dims.iz * 2;
2387   }
2388   node_costs->num_compute_ops = ops;
2389 
2390   const int64 size_nhwc =
2391       CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
2392   const int64 size_c =
2393       CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
2394   if (is_training) {
2395     node_costs->num_input_bytes_accessed = {size_nhwc, size_c, size_c};
2396     node_costs->num_output_bytes_accessed = {size_nhwc, size_c, size_c, size_c,
2397                                              size_c};
2398     // FusedBatchNorm in training mode internally re-reads the input tensor:
2399     // one for mean/variance, and the 2nd internal read forthe actual scaling.
2400     // Assume small intermediate data such as mean / variance (size_c) can be
2401     // cached on-chip.
2402     node_costs->internal_read_bytes = size_nhwc;
2403   } else {
2404     node_costs->num_input_bytes_accessed = {size_nhwc, size_c, size_c, size_c,
2405                                             size_c};
2406     node_costs->num_output_bytes_accessed = {size_nhwc};
2407   }
2408   node_costs->max_memory = node_costs->num_total_output_bytes();
2409 
2410   if (found_unknown_shapes) {
2411     node_costs->inaccurate = true;
2412     node_costs->num_nodes_with_unknown_shapes = 1;
2413   }
2414   return Status::OK();
2415 }
2416 
PredictFusedBatchNormGrad(const OpContext & op_context,NodeCosts * node_costs) const2417 Status OpLevelCostEstimator::PredictFusedBatchNormGrad(
2418     const OpContext& op_context, NodeCosts* node_costs) const {
2419   bool found_unknown_shapes = false;
2420   const auto& op_info = op_context.op_info;
2421   // y_backprop: op_info.inputs(0)
2422   // x: op_info.inputs(1)
2423   // scale: op_info.inputs(2)
2424   // mean: op_info.inputs(3)
2425   // variance or inverse of variance: op_info.inputs(4)
2426   ConvolutionDimensions dims = OpDimensionsFromInputs(
2427       op_info.inputs(1).shape(), op_info, &found_unknown_shapes);
2428 
2429   int64 ops = 0;
2430   const auto rsqrt_cost = Eigen::internal::functor_traits<
2431       Eigen::internal::scalar_rsqrt_op<float>>::Cost;
2432   ops = dims.iz * (dims.batch * dims.ix * dims.iy * 11 + 5 + rsqrt_cost);
2433   node_costs->num_compute_ops = ops;
2434 
2435   const int64 size_nhwc =
2436       CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
2437   const int64 size_c =
2438       CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
2439   // TODO(dyoon): fix missing memory cost for variance input (size_c) and
2440   // yet another read of y_backprop (size_nhwc) internally.
2441   node_costs->num_input_bytes_accessed = {size_nhwc, size_nhwc, size_c, size_c};
2442   node_costs->num_output_bytes_accessed = {size_nhwc, size_c, size_c};
2443   // FusedBatchNormGrad has to read y_backprop internally.
2444   node_costs->internal_read_bytes = size_nhwc;
2445   node_costs->max_memory = node_costs->num_total_output_bytes();
2446 
2447   if (found_unknown_shapes) {
2448     node_costs->inaccurate = true;
2449     node_costs->num_nodes_with_unknown_shapes = 1;
2450   }
2451   return Status::OK();
2452 }
2453 
PredictNaryOp(const OpContext & op_context,NodeCosts * node_costs) const2454 Status OpLevelCostEstimator::PredictNaryOp(const OpContext& op_context,
2455                                            NodeCosts* node_costs) const {
2456   const auto& op_info = op_context.op_info;
2457   bool found_unknown_shapes = false;
2458   // Calculate the largest known tensor size across all inputs and output.
2459   int64 op_count = CalculateLargestInputCount(op_info, &found_unknown_shapes);
2460   // If output shape is available, try to use the element count calculated from
2461   // that.
2462   if (op_info.outputs_size() > 0) {
2463     op_count = std::max(
2464         op_count,
2465         CalculateTensorElementCount(op_info.outputs(0), &found_unknown_shapes));
2466   }
2467   // Also calculate the output shape possibly resulting from broadcasting.
2468   // Note that the some Nary ops (such as AddN) do not support broadcasting,
2469   // but we're including this here for completeness.
2470   if (op_info.inputs_size() >= 2) {
2471     op_count = std::max(op_count, CwiseOutputElementCount(op_info));
2472   }
2473 
2474   // Nary ops perform one operation for every element in every input tensor.
2475   op_count *= op_info.inputs_size() - 1;
2476 
2477   const auto sum_cost = Eigen::internal::functor_traits<
2478       Eigen::internal::scalar_sum_op<float>>::Cost;
2479   return PredictDefaultNodeCosts(op_count * sum_cost, op_context,
2480                                  &found_unknown_shapes, node_costs);
2481 }
2482 
2483 // softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j]))
PredictSoftmax(const OpContext & op_context,NodeCosts * node_costs) const2484 Status OpLevelCostEstimator::PredictSoftmax(const OpContext& op_context,
2485                                             NodeCosts* node_costs) const {
2486   bool found_unknown_shapes = false;
2487   const int64 logits_size = CalculateTensorElementCount(
2488       op_context.op_info.inputs(0), &found_unknown_shapes);
2489   // Softmax input rank should be >=1.
2490   TensorShapeProto logits_shape = op_context.op_info.inputs(0).shape();
2491   if (logits_shape.unknown_rank() || logits_shape.dim_size() == 0) {
2492     return errors::InvalidArgument("Softmax op has invalid input: ",
2493                                    op_context.op_info.ShortDebugString());
2494   }
2495 
2496 #define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
2497 
2498   // Every element of <logits> will be exponentiated, have that result included
2499   // in a sum across j, and also have that result multiplied by the reciprocal
2500   // of the sum_j. In addition, we'll compute 1/sum_j for every i.
2501   auto ops =
2502       (EIGEN_COST(scalar_exp_op<float>) + EIGEN_COST(scalar_sum_op<float>) +
2503        EIGEN_COST(scalar_product_op<float>)) *
2504           logits_size +
2505       EIGEN_COST(scalar_inverse_op<float>) * logits_shape.dim(0).size();
2506 
2507 #undef EIGEN_COST
2508   return PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes,
2509                                  node_costs);
2510 }
2511 
PredictResizeBilinear(const OpContext & op_context,NodeCosts * node_costs) const2512 Status OpLevelCostEstimator::PredictResizeBilinear(
2513     const OpContext& op_context, NodeCosts* node_costs) const {
2514   bool found_unknown_shapes = false;
2515 
2516   if (op_context.op_info.outputs().empty() ||
2517       op_context.op_info.inputs().empty()) {
2518     return errors::InvalidArgument(
2519         "ResizeBilinear op has invalid input / output ",
2520         op_context.op_info.ShortDebugString());
2521   }
2522 
2523   const int64 output_elements = CalculateTensorElementCount(
2524       op_context.op_info.outputs(0), &found_unknown_shapes);
2525 
2526   const auto half_pixel_centers =
2527       op_context.op_info.attr().find("half_pixel_centers");
2528   bool use_half_pixel_centers = false;
2529   if (half_pixel_centers == op_context.op_info.attr().end()) {
2530     LOG(WARNING) << "half_pixel_centers attr not set for ResizeBilinear.";
2531     return PredictCostOfAnUnknownOp(op_context, node_costs);
2532   } else {
2533     use_half_pixel_centers = half_pixel_centers->second.b();
2534   }
2535 
2536   // Compose cost of bilinear interpolation.
2537   int64 ops = 0;
2538 
2539 #define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
2540   const auto sub_cost_float = EIGEN_COST(scalar_difference_op<float>);
2541   const auto sub_cost_int = EIGEN_COST(scalar_difference_op<int64>);
2542   const auto add_cost = EIGEN_COST(scalar_sum_op<float>);
2543   const auto mul_cost = EIGEN_COST(scalar_product_op<float>);
2544   const auto floor_cost = EIGEN_COST(scalar_floor_op<float>);
2545   const auto max_cost = EIGEN_COST(scalar_max_op<int64>);
2546   const auto min_cost = EIGEN_COST(scalar_min_op<int64>);
2547   const auto cast_to_int_cost = Eigen::internal::functor_traits<
2548       Eigen::internal::scalar_cast_op<float, int64>>::Cost;
2549   const auto cast_to_float_cost = Eigen::internal::functor_traits<
2550       Eigen::internal::scalar_cast_op<int64, float>>::Cost;
2551   const auto ceil_cost = EIGEN_COST(scalar_ceil_op<float>);
2552 #undef EIGEN_COST
2553 
2554   // Ops calcualted from tensorflow/core/kernels/image/resize_bilinear_op.cc.
2555 
2556   // Op counts taken from resize_bilinear implementation on 07/21/2020.
2557   // Computed op counts may become inaccurate if resize_bilinear implementation
2558   // changes.
2559 
2560   // resize_bilinear has an optimization where the interpolation weights are
2561   // precomputed and cached. Given input tensors of size [B,H1,W1,C] and output
2562   // tensors of size [B,H2,W2,C], the last dimension C that needs to be accessed
2563   // in the input for interpolation are identical at every point in the output.
2564   // These values are cached in the compute_interpolation_weights function. For
2565   // a particular y in [0...H2-1], the rows to be accessed in the input are the
2566   // same. Likewise, for a particular x in [0...H2-1], the columns to be accsed
2567   // are the same. So the precomputation only needs to be done for H2 + W2
2568   // values.
2569   const auto output_shape = MaybeGetMinimumShape(
2570       op_context.op_info.outputs(0).shape(), 4, &found_unknown_shapes);
2571   // Assume H is dim 1 and W is dim 2 to match logic in resize_bilinear, which
2572   // also makes this assumption.
2573   const int64 output_height = output_shape.dim(1).size();
2574   const int64 output_width = output_shape.dim(2).size();
2575   // Add the ops done outside of the scaler function in
2576   // compute_interpolation_weights.
2577   int64 interp_weight_cost = floor_cost + max_cost + min_cost + sub_cost_float +
2578                              sub_cost_int + ceil_cost + cast_to_int_cost * 2;
2579   // There are two options for computing the weight of each pixel in the
2580   // interpolation. Algorithm can use pixel centers, or corners, for the
2581   // weight. Ops depend on the scaler function passed into
2582   // compute_interpolation_weights.
2583   if (use_half_pixel_centers) {
2584     // Ops for HalfPixelScalaer.
2585     interp_weight_cost +=
2586         add_cost + mul_cost + sub_cost_float + cast_to_float_cost;
2587   } else {
2588     // Ops for LegacyScaler.
2589     interp_weight_cost += cast_to_float_cost + mul_cost;
2590   }
2591   // Cost for the interpolation is multipled by (H2 + w2), as mentioned above.
2592   ops += interp_weight_cost * (output_height + output_width);
2593 
2594   // Ops for computing the new values, done for every element. Logic is from
2595   // compute_lerp in the inner loop of resize_image which consists of:
2596   //   const float top = top_left + (top_right - top_left) * x_lerp;
2597   //   const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
2598   //   return top + (bottom - top) * y_lerp;
2599   ops += (add_cost * 3 + sub_cost_float * 3 + mul_cost * 3) * output_elements;
2600 
2601   return PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes,
2602                                  node_costs);
2603 }
2604 
PredictCropAndResize(const OpContext & op_context,NodeCosts * node_costs) const2605 Status OpLevelCostEstimator::PredictCropAndResize(const OpContext& op_context,
2606                                                   NodeCosts* node_costs) const {
2607   bool found_unknown_shapes = false;
2608 
2609   const auto method = op_context.op_info.attr().find("method");
2610   bool use_bilinear_interp;
2611   if (method == op_context.op_info.attr().end() ||
2612       method->second.s() == "bilinear") {
2613     use_bilinear_interp = true;
2614   } else if (method->second.s() == "nearest") {
2615     use_bilinear_interp = false;
2616   } else {
2617     LOG(WARNING) << "method attr in CropAndResize invalid; expected bilinear "
2618                     "or nearest.";
2619     return PredictCostOfAnUnknownOp(op_context, node_costs);
2620   }
2621 
2622   const int64 num_boxes = op_context.op_info.inputs(1).shape().dim(0).size();
2623   const auto crop_shape = MaybeGetMinimumShape(
2624       op_context.op_info.outputs(0).shape(), 4, &found_unknown_shapes);
2625   const int64 crop_height = crop_shape.dim(1).size();
2626   const int64 crop_width = crop_shape.dim(2).size();
2627   const int64 output_elements = CalculateTensorElementCount(
2628       op_context.op_info.outputs(0), &found_unknown_shapes);
2629 
2630 #define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
2631   const auto sub_cost = EIGEN_COST(scalar_difference_op<float>);
2632   const auto add_cost = EIGEN_COST(scalar_sum_op<float>);
2633   const auto mul_cost = EIGEN_COST(scalar_product_op<float>);
2634   auto div_cost = EIGEN_COST(scalar_div_cost<float>);
2635   const auto floor_cost = EIGEN_COST(scalar_floor_op<float>);
2636   const auto ceil_cost = EIGEN_COST(scalar_ceil_op<float>);
2637   auto round_cost = EIGEN_COST(scalar_round_op<float>);
2638   const auto cast_to_float_cost = Eigen::internal::functor_traits<
2639       Eigen::internal::scalar_cast_op<int64, float>>::Cost;
2640 #undef EIGEN_COST
2641 
2642   // Computing ops following
2643   // tensorflow/core/kernels/image/crop_and_resize_op.cc at 08/25/2020. Op
2644   // calculation differs from rough estimate in implementation, as it separates
2645   // out cost per box from cost per pixel and cost per element.
2646 
2647   // Ops for variables height_scale and width_scale.
2648   int64 ops = (sub_cost * 6 + mul_cost * 2 + div_cost * 2) * num_boxes;
2649   // Ops for variable in_y.
2650   ops += (mul_cost * 2 + sub_cost + add_cost) * crop_height * num_boxes;
2651   // Ops for variable in_x (same computation across both branches).
2652   ops += (mul_cost * 2 + sub_cost + add_cost) * crop_height * crop_width *
2653          num_boxes;
2654   // Specify op_cost based on the method.
2655   if (use_bilinear_interp) {
2656     // Ops for variables top_y_index, bottom_y_index, y_lerp.
2657     ops += (floor_cost + ceil_cost + sub_cost) * crop_height * num_boxes;
2658     // Ops for variables left_x, right_x, x_lerp;
2659     ops += (floor_cost + ceil_cost + sub_cost) * crop_height * crop_width *
2660            num_boxes;
2661     // Ops for innermost loop across depth.
2662     ops +=
2663         (cast_to_float_cost * 4 + add_cost * 3 + sub_cost * 3 + mul_cost * 3) *
2664         output_elements;
2665   } else /* method == "nearest" */ {
2666     // Ops for variables closest_x_index and closest_y_index.
2667     ops += round_cost * 2 * crop_height * crop_width * num_boxes;
2668     // Ops for innermost loop across depth.
2669     ops += cast_to_float_cost * output_elements;
2670   }
2671   return PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes,
2672                                  node_costs);
2673 }
2674 
2675 }  // end namespace grappler
2676 }  // end namespace tensorflow
2677