1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
17 
18 #include "third_party/eigen3/Eigen/Core"
19 #include "tensorflow/core/framework/attr_value.pb.h"
20 #include "tensorflow/core/framework/attr_value_util.h"
21 #include "tensorflow/core/framework/tensor.pb.h"
22 #include "tensorflow/core/framework/tensor_shape.pb.h"
23 #include "tensorflow/core/framework/types.h"
24 #include "tensorflow/core/grappler/clusters/utils.h"
25 
26 namespace tensorflow {
27 namespace grappler {
28 
29 constexpr int kOpsPerMac = 2;
30 constexpr char kGuaranteeConst[] = "GuaranteeConst";
31 constexpr char kConv2d[] = "Conv2D";
32 constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter";
33 constexpr char kConv2dBackpropInput[] = "Conv2DBackpropInput";
34 constexpr char kFusedConv2dBiasActivation[] = "FusedConv2DBiasActivation";
35 constexpr char kDepthwiseConv2dNative[] = "DepthwiseConv2dNative";
36 constexpr char kDepthwiseConv2dNativeBackpropFilter[] =
37     "DepthwiseConv2dNativeBackpropFilter";
38 constexpr char kDepthwiseConv2dNativeBackpropInput[] =
39     "DepthwiseConv2dNativeBackpropInput";
40 constexpr char kMatMul[] = "MatMul";
41 constexpr char kSparseMatMul[] = "SparseMatMul";
42 constexpr char kSparseTensorDenseMatMul[] = "SparseTensorDenseMatMul";
43 constexpr char kPlaceholder[] = "Placeholder";
44 constexpr char kIdentity[] = "Identity";
45 constexpr char kIdentityN[] = "IdentityN";
46 constexpr char kRefIdentity[] = "RefIdentity";
47 constexpr char kNoOp[] = "NoOp";
48 constexpr char kReshape[] = "Reshape";
49 constexpr char kSqueeze[] = "Squeeze";
50 constexpr char kRecv[] = "_Recv";
51 constexpr char kSend[] = "_Send";
52 constexpr char kBatchMatMul[] = "BatchMatMul";
53 constexpr char kRank[] = "Rank";
54 constexpr char kShape[] = "Shape";
55 constexpr char kShapeN[] = "ShapeN";
56 constexpr char kSize[] = "Size";
57 constexpr char kStopGradient[] = "StopGradient";
58 constexpr char kPreventGradient[] = "PreventGradient";
59 constexpr char kGather[] = "Gather";
60 constexpr char kGatherV2[] = "GatherV2";
61 constexpr char kSlice[] = "Slice";
62 constexpr char kMaxPool[] = "MaxPool";
63 constexpr char kMaxPoolGrad[] = "MaxPoolGrad";
64 constexpr char kAvgPool[] = "AvgPool";
65 constexpr char kAvgPoolGrad[] = "AvgPoolGrad";
66 constexpr char kFusedBatchNorm[] = "FusedBatchNorm";
67 constexpr char kFusedBatchNormGrad[] = "FusedBatchNormGrad";
68 constexpr char kQuantizedMatMul[] = "QuantizedMatMul";
69 constexpr char kQuantizedMatMulV2[] = "QuantizedMatMulV2";
70 // Persistent ops.
71 constexpr char kConst[] = "Const";
72 constexpr char kVariable[] = "Variable";
73 constexpr char kVariableV2[] = "VariableV2";
74 constexpr char kAutoReloadVariable[] = "AutoReloadVariable";
75 constexpr char kVarHandleOp[] = "VarHandleOp";
76 constexpr char kReadVariableOp[] = "ReadVariableOp";
77 
78 static const Costs::Duration kMinComputeTime(1);
79 
80 namespace {
81 
GetDataFormat(const OpInfo & op_info)82 string GetDataFormat(const OpInfo& op_info) {
83   string data_format = "NHWC";  // Default format.
84   if (op_info.attr().find("data_format") != op_info.attr().end()) {
85     data_format = op_info.attr().at("data_format").s();
86   }
87   return data_format;
88 }
89 
GetFilterFormat(const OpInfo & op_info)90 string GetFilterFormat(const OpInfo& op_info) {
91   string filter_format = "HWIO";  // Default format.
92   if (op_info.attr().find("filter_format") != op_info.attr().end()) {
93     filter_format = op_info.attr().at("filter_format").s();
94   }
95   return filter_format;
96 }
97 
GetPadding(const OpInfo & op_info)98 Padding GetPadding(const OpInfo& op_info) {
99   if (op_info.attr().find("padding") != op_info.attr().end() &&
100       op_info.attr().at("padding").s() == "VALID") {
101     return Padding::VALID;
102   }
103   return Padding::SAME;  // Default padding.
104 }
105 
IsTraining(const OpInfo & op_info)106 bool IsTraining(const OpInfo& op_info) {
107   if (op_info.attr().find("is_training") != op_info.attr().end() &&
108       op_info.attr().at("is_training").b()) {
109     return true;
110   }
111   return false;
112 }
113 
114 // TODO(dyoon): support non-4D tensors in the c ost functions of convolution
115 // related ops (Conv, Pool, BatchNorm, and their backprops) and the related
116 // helper functions.
GetStrides(const OpInfo & op_info)117 std::vector<int64> GetStrides(const OpInfo& op_info) {
118   if (op_info.attr().find("strides") != op_info.attr().end()) {
119     const auto strides = op_info.attr().at("strides").list().i();
120     CHECK(strides.size() == 4)
121         << "Attr strides is not a length-4 vector: " << op_info.DebugString();
122     return {strides[0], strides[1], strides[2], strides[3]};
123   }
124   return {1, 1, 1, 1};
125 }
126 
GetKernelSize(const OpInfo & op_info)127 std::vector<int64> GetKernelSize(const OpInfo& op_info) {
128   if (op_info.attr().find("ksize") != op_info.attr().end()) {
129     const auto ksize = op_info.attr().at("ksize").list().i();
130     CHECK(ksize.size() == 4)
131         << "Attr ksize is not a length-4 vector: " << op_info.DebugString();
132     return {ksize[0], ksize[1], ksize[2], ksize[3]};
133   }
134   // Note that FusedBatchNorm doesn't have ksize attr, but GetKernelSize returns
135   // {1, 1, 1, 1} in that case.
136   return {1, 1, 1, 1};
137 }
138 
// Returns the output size of one spatial dimension of a windowed op. The
// logic follows GetWindowedOutputSizeVerbose() in
// third_party/tensorflow/core/framework/common_shape_fns.cc:
//   VALID: (input - filter + 1) / stride, rounded up
//   SAME:  input / stride, rounded up
// The rounding-up division is done in integer arithmetic. For non-negative
// extents it is a true ceiling.
int64 GetOutputSize(const int64 input, const int64 filter, const int64 stride,
                    const Padding& padding) {
  const int64 extent =
      padding == Padding::VALID ? input - filter + 1 : input;
  return (extent + stride - 1) / stride;
}
149 
150 // Return the output element count of a binary element-wise op considering
151 // broadcasting.
CwiseOutputElementCount(const TensorShapeProto & input_shape_1,const TensorShapeProto & input_shape_2)152 int64 CwiseOutputElementCount(const TensorShapeProto& input_shape_1,
153                               const TensorShapeProto& input_shape_2) {
154   bool found_unknown_shapes;
155   int rank = std::max(1, input_shape_1.dim_size());
156   TensorShapeProto output_shape =
157       MaybeGetMinimumShape(input_shape_1, rank, &found_unknown_shapes);
158 
159   if (input_shape_1.dim_size() == input_shape_2.dim_size()) {
160     auto shape_1 =
161         MaybeGetMinimumShape(input_shape_1, rank, &found_unknown_shapes);
162     auto shape_2 =
163         MaybeGetMinimumShape(input_shape_2, rank, &found_unknown_shapes);
164     if (shape_1.dim_size() == shape_2.dim_size()) {
165       for (int i = 0; i < shape_1.dim_size(); i++) {
166         output_shape.mutable_dim(i)->set_size(
167             std::max(shape_1.dim(i).size(), shape_2.dim(i).size()));
168       }
169     }
170   }
171 
172   int64 count = 1;
173   for (int i = 0; i < output_shape.dim_size(); i++) {
174     count *= output_shape.dim(i).size();
175   }
176   return count;
177 }
178 
179 }  // namespace
180 
181 // Return a minimum shape if the shape is unknown. If known, return the original
182 // shape.
MaybeGetMinimumShape(const TensorShapeProto & original_shape,int rank,bool * found_unknown_shapes)183 TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
184                                       int rank, bool* found_unknown_shapes) {
185   auto shape = original_shape;
186   bool is_scalar = !shape.unknown_rank() && shape.dim_size() == 0;
187 
188   if (shape.unknown_rank() || (!is_scalar && shape.dim_size() < rank)) {
189     *found_unknown_shapes = true;
190     VLOG(2) << "Use minimum shape because the rank is unknown.";
191     // The size of each dimension is at least 1, if unknown.
192     for (int i = shape.dim_size(); i < rank; i++) {
193       shape.add_dim()->set_size(1);
194     }
195   } else if (is_scalar) {
196     for (int i = 0; i < rank; i++) {
197       shape.add_dim()->set_size(1);
198     }
199   } else if (shape.dim_size() > rank) {
200     *found_unknown_shapes = true;
201     shape.clear_dim();
202     for (int i = 0; i < rank; i++) {
203       shape.add_dim()->set_size(original_shape.dim(i).size());
204     }
205   } else {
206     for (int i = 0; i < shape.dim_size(); i++) {
207       if (shape.dim(i).size() < 0) {
208         *found_unknown_shapes = true;
209         VLOG(2) << "Use minimum dim size 1 because the shape is unknown.";
210         // The size of each dimension is at least 1, if unknown.
211         shape.mutable_dim(i)->set_size(1);
212       }
213     }
214   }
215   return shape;
216 }
217 
OpLevelCostEstimator()218 OpLevelCostEstimator::OpLevelCostEstimator() {
219   // Syntactic sugar to build and return a lambda that takes an OpInfo and
220   // returns a cost.
221   typedef Costs (OpLevelCostEstimator::*CostImpl)(const OpContext& op_context)
222       const;
223   auto wrap = [this](CostImpl impl) -> std::function<Costs(const OpContext&)> {
224     return [this, impl](const OpContext& op_context) {
225       return (this->*impl)(op_context);
226     };
227   };
228 
229   device_cost_impl_ = {
230       {kConv2d, wrap(&OpLevelCostEstimator::PredictConv2D)},
231       {kConv2dBackpropFilter,
232        wrap(&OpLevelCostEstimator::PredictConv2DBackpropFilter)},
233       {kConv2dBackpropInput,
234        wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput)},
235       {kFusedConv2dBiasActivation,
236        wrap(&OpLevelCostEstimator::PredictFusedConv2DBiasActivation)},
237       // reuse Conv2D for DepthwiseConv2dNative because the calculation is the
238       // same although the actual meaning of the parameters are different. See
239       // comments in PredictConv2D and related functions
240       {kDepthwiseConv2dNative, wrap(&OpLevelCostEstimator::PredictConv2D)},
241       {kDepthwiseConv2dNativeBackpropFilter,
242        wrap(&OpLevelCostEstimator::PredictConv2DBackpropFilter)},
243       {kDepthwiseConv2dNativeBackpropInput,
244        wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput)},
245       {kMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
246       {kSparseMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
247       {kSparseTensorDenseMatMul,
248        wrap(&OpLevelCostEstimator::PredictSparseTensorDenseMatMul)},
249       {kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)},
250       {kQuantizedMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
251       {kQuantizedMatMulV2, wrap(&OpLevelCostEstimator::PredictMatMul)},
252 
253       {kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)},
254       {kGuaranteeConst, wrap(&OpLevelCostEstimator::PredictNoOp)},
255 
256       {kGather, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
257       {kGatherV2, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
258       {kSlice, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
259 
260       {kPlaceholder, wrap(&OpLevelCostEstimator::PredictIdentity)},
261       {kIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)},
262       {kIdentityN, wrap(&OpLevelCostEstimator::PredictIdentity)},
263       {kRefIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)},
264       {kStopGradient, wrap(&OpLevelCostEstimator::PredictIdentity)},
265       {kPreventGradient, wrap(&OpLevelCostEstimator::PredictIdentity)},
266       {kReshape, wrap(&OpLevelCostEstimator::PredictIdentity)},
267       {kSqueeze, wrap(&OpLevelCostEstimator::PredictIdentity)},
268       {kRecv, wrap(&OpLevelCostEstimator::PredictIdentity)},
269       {kSend, wrap(&OpLevelCostEstimator::PredictIdentity)},
270 
271       {kRank, wrap(&OpLevelCostEstimator::PredictMetadata)},
272       {kShape, wrap(&OpLevelCostEstimator::PredictMetadata)},
273       {kShapeN, wrap(&OpLevelCostEstimator::PredictMetadata)},
274       {kSize, wrap(&OpLevelCostEstimator::PredictMetadata)},
275       {kMaxPool, wrap(&OpLevelCostEstimator::PredictMaxPool)},
276       {kMaxPoolGrad, wrap(&OpLevelCostEstimator::PredictMaxPoolGrad)},
277       {kAvgPool, wrap(&OpLevelCostEstimator::PredictAvgPool)},
278       {kAvgPoolGrad, wrap(&OpLevelCostEstimator::PredictAvgPoolGrad)},
279       {kFusedBatchNorm, wrap(&OpLevelCostEstimator::PredictFusedBatchNorm)},
280       {kFusedBatchNormGrad,
281        wrap(&OpLevelCostEstimator::PredictFusedBatchNormGrad)},
282   };
283 
284   persistent_ops_ = {
285       kConst,       kVariable,       kVariableV2, kAutoReloadVariable,
286       kVarHandleOp, kReadVariableOp,
287   };
288 
289 #define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
290 
291   // Quantize = apply min and max bounds, multiply by scale factor and round.
292   const int quantize_v2_cost =
293       EIGEN_COST(scalar_product_op<float>) + EIGEN_COST(scalar_max_op<float>) +
294       EIGEN_COST(scalar_min_op<float>) + EIGEN_COST(scalar_round_op<float>);
295 
296   elementwise_ops_ = {
297       // Unary ops alphabetically sorted
298       {"Acos", EIGEN_COST(scalar_acos_op<float>)},
299       {"Asin", EIGEN_COST(scalar_asin_op<float>)},
300       {"Atan", EIGEN_COST(scalar_atan_op<float>)},
301       {"Atan2", EIGEN_COST(scalar_quotient_op<float>) +
302                     EIGEN_COST(scalar_atan_op<float>)},
303       // For now, we use Eigen cost model for float to int16 cast as an example
304       // case; Eigen cost model is zero when src and dst types are identical,
305       // and it uses AddCost (1) when different. We may implement a separate
306       // cost functions for cast ops, using the actual input and output types.
307       {"Cast", Eigen::internal::functor_traits<
308                    Eigen::internal::scalar_cast_op<float, int16>>::Cost},
309       {"Ceil", EIGEN_COST(scalar_ceil_op<float>)},
310       {"Cos", EIGEN_COST(scalar_cos_op<float>)},
311       {"Dequantize", EIGEN_COST(scalar_product_op<float>)},
312       {"Erf", 1},
313       {"Erfc", 1},
314       {"Exp", EIGEN_COST(scalar_exp_op<float>)},
315       {"Expm1", EIGEN_COST(scalar_expm1_op<float>)},
316       {"Floor", EIGEN_COST(scalar_floor_op<float>)},
317       {"Inv", EIGEN_COST(scalar_inverse_op<float>)},
318       {"InvGrad", 1},
319       {"Lgamma", 1},
320       {"Log", EIGEN_COST(scalar_log_op<float>)},
321       {"Log1p", EIGEN_COST(scalar_log1p_op<float>)},
322       {"Neg", EIGEN_COST(scalar_opposite_op<float>)},
323       {"QuantizeV2", quantize_v2_cost},
324       {"Reciprocal", EIGEN_COST(scalar_inverse_op<float>)},
325       {"Rint", 1},
326       {"Round", EIGEN_COST(scalar_round_op<float>)},
327       {"Rsqrt", EIGEN_COST(scalar_rsqrt_op<float>)},
328       {"Sqrt", EIGEN_COST(scalar_sqrt_op<float>)},
329       {"Square", EIGEN_COST(scalar_square_op<float>)},
330       {"Tanh", EIGEN_COST(scalar_tanh_op<float>)},
331       {"Relu", EIGEN_COST(scalar_max_op<float>)},
332       {"Sigmoid", EIGEN_COST(scalar_logistic_op<float>)},
333       {"QuantizedSigmoid", EIGEN_COST(scalar_logistic_op<float>)},
334       {"Sign", EIGEN_COST(scalar_sign_op<float>)},
335       {"Sin", EIGEN_COST(scalar_sin_op<float>)},
336       {"Tan", EIGEN_COST(scalar_tan_op<float>)},
337       // Binary ops alphabetically sorted
338       {"Add", EIGEN_COST(scalar_sum_op<float>)},
339       {"ApproximateEqual", 1},
340       {"BiasAdd", EIGEN_COST(scalar_sum_op<float>)},
341       {"QuantizedBiasAdd", EIGEN_COST(scalar_sum_op<float>)},
342       {"Div", EIGEN_COST(scalar_quotient_op<float>)},
343       {"Equal", 1},
344       {"FloorDiv", EIGEN_COST(scalar_quotient_op<float>)},
345       {"FloorMod", EIGEN_COST(scalar_mod_op<float>)},
346       {"Greater", 1},
347       {"GreaterEqual", 1},
348       {"Less", 1},
349       {"LessEqual", 1},
350       {"LogicalAnd", EIGEN_COST(scalar_boolean_and_op)},
351       {"LogicalNot", 1},
352       {"LogicalOr", EIGEN_COST(scalar_boolean_or_op)},
353       {"Maximum", EIGEN_COST(scalar_max_op<float>)},
354       {"Minimum", EIGEN_COST(scalar_min_op<float>)},
355       {"Mod", EIGEN_COST(scalar_mod_op<float>)},
356       {"Mul", EIGEN_COST(scalar_product_op<float>)},
357       {"NotEqual", 1},
358       {"QuantizedAdd", EIGEN_COST(scalar_sum_op<float>)},
359       {"QuantizedMul", EIGEN_COST(scalar_product_op<float>)},
360       {"RealDiv", EIGEN_COST(scalar_quotient_op<float>)},
361       {"ReluGrad", EIGEN_COST(scalar_max_op<float>)},
362       {"SquareDifference", 1},
363       {"Sub", EIGEN_COST(scalar_difference_op<float>)},
364       {"TruncateDiv", EIGEN_COST(scalar_quotient_op<float>)},
365       {"TruncateMod", EIGEN_COST(scalar_mod_op<float>)}};
366 
367 #undef EIGEN_COST
368 
369   // By default, use sum of memory_time and compute_time for execution_time.
370   compute_memory_overlap_ = false;
371 }
372 
PredictCosts(const OpContext & op_context) const373 Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const {
374   const auto& op_info = op_context.op_info;
375   auto it = device_cost_impl_.find(op_info.op());
376   if (it != device_cost_impl_.end()) {
377     std::function<Costs(const OpContext&)> estimator = it->second;
378     Costs costs = estimator(op_context);
379     VLOG(1) << "Operation " << op_info.op() << " takes "
380             << costs.execution_time.count() << " ns.";
381     return costs;
382   }
383 
384   if (persistent_ops_.find(op_info.op()) != persistent_ops_.end()) {
385     return PredictVariable(op_context);
386   }
387 
388   if (elementwise_ops_.find(op_info.op()) != elementwise_ops_.end()) {
389     return PredictCwiseOp(op_context);
390   }
391 
392   VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
393 
394   return PredictCostOfAnUnknownOp(op_context);
395 }
396 
GetDeviceInfo(const DeviceProperties & device) const397 DeviceInfo OpLevelCostEstimator::GetDeviceInfo(
398     const DeviceProperties& device) const {
399   double gflops = -1;
400   double gb_per_sec = -1;
401 
402   if (device.type() == "CPU") {
403     // Check if vector instructions are available, and refine performance
404     // prediction based on this.
405     // Frequencies are stored in MHz in the DeviceProperties.
406     gflops = device.num_cores() * device.frequency() * 1e-3;
407     if (gb_per_sec < 0) {
408       if (device.bandwidth() > 0) {
409         gb_per_sec = device.bandwidth() / 1e6;
410       } else {
411         gb_per_sec = 32;
412       }
413     }
414   } else if (device.type() == "GPU") {
415     const string architecture = device.environment().at("architecture");
416     int cores_per_multiprocessor;
417     if (architecture < "3") {
418       // Fermi
419       cores_per_multiprocessor = 32;
420     } else if (architecture < "4") {
421       // Kepler
422       cores_per_multiprocessor = 192;
423     } else if (architecture < "6") {
424       // Maxwell
425       cores_per_multiprocessor = 128;
426     } else {
427       // Pascal (compute capability version 6) and Volta (compute capability
428       // version 7)
429       cores_per_multiprocessor = 64;
430     }
431     gflops = device.num_cores() * device.frequency() * 1e-3 *
432              cores_per_multiprocessor * kOpsPerMac;
433     if (device.bandwidth() > 0) {
434       gb_per_sec = device.bandwidth() / 1e6;
435     } else {
436       gb_per_sec = 100;
437     }
438   }
439   VLOG(1) << "Device: " << device.type() << " gflops: " << gflops
440           << " gb_per_sec: " << gb_per_sec;
441 
442   DCHECK_LT(0, gflops) << device.DebugString();
443   DCHECK_LT(0, gb_per_sec) << device.DebugString();
444 
445   return DeviceInfo(gflops, gb_per_sec);
446 }
447 
PredictCwiseOp(const OpContext & op_context) const448 Costs OpLevelCostEstimator::PredictCwiseOp(const OpContext& op_context) const {
449   const auto& op_info = op_context.op_info;
450   bool found_unknown_shapes = false;
451   // For unary or binary element-wise operations, op count is the element count
452   // of any input. We use the count for the largest input here to be more robust
453   // in case that the shape is unknown or partially known for other input.
454   int64 op_count = CalculateLargestInputCount(op_info, &found_unknown_shapes);
455   // If output shape is available, try use the element count calculated from
456   // that.
457   if (op_info.outputs_size() > 0) {
458     op_count = std::max(
459         op_count,
460         CalculateTensorElementCount(op_info.outputs(0), &found_unknown_shapes));
461   }
462   // For binary ops, calculate the output shape possibly resulting from
463   // broadcasting.
464   if (op_info.inputs_size() >= 2) {
465     op_count =
466         std::max(op_count, CwiseOutputElementCount(op_info.inputs(0).shape(),
467                                                    op_info.inputs(1).shape()));
468   }
469 
470   int op_cost = 1;
471   bool is_known_elementwise_op = false;
472   auto it = elementwise_ops_.find(op_info.op());
473   if (it != elementwise_ops_.end()) {
474     op_cost = it->second;
475     is_known_elementwise_op = true;
476   } else {
477     LOG(WARNING) << "Not a cwise op: " << op_info.op();
478   }
479 
480   Costs costs = PredictOpCountBasedCost(op_count * op_cost, op_info);
481   if (found_unknown_shapes || !is_known_elementwise_op) {
482     costs.inaccurate = true;
483   }
484   costs.num_ops_with_unknown_shapes = found_unknown_shapes;
485   return costs;
486 }
487 
PredictCostOfAnUnknownOp(const OpContext & op_context) const488 Costs OpLevelCostEstimator::PredictCostOfAnUnknownOp(
489     const OpContext& op_context) const {
490   // Don't assume the operation is cwise, return cost based on input/output size
491   // and admit that it is inaccurate...
492   auto costs = PredictOpCountBasedCost(0, op_context.op_info);
493   costs.inaccurate = true;
494   return costs;
495 }
496 
PredictOpCountBasedCost(double operations,const OpInfo & op_info) const497 Costs OpLevelCostEstimator::PredictOpCountBasedCost(
498     double operations, const OpInfo& op_info) const {
499   bool unknown_shapes = false;
500   const double input_size = CalculateInputSize(op_info, &unknown_shapes);
501   const double output_size = CalculateOutputSize(op_info, &unknown_shapes);
502   Costs costs =
503       PredictOpCountBasedCost(operations, input_size, output_size, op_info);
504   costs.inaccurate = unknown_shapes;
505   costs.num_ops_with_unknown_shapes = unknown_shapes;
506   costs.max_memory = output_size;
507   return costs;
508 }
509 
PredictOpCountBasedCost(double operations,double input_io_bytes,double output_io_bytes,const OpInfo & op_info) const510 Costs OpLevelCostEstimator::PredictOpCountBasedCost(
511     double operations, double input_io_bytes, double output_io_bytes,
512     const OpInfo& op_info) const {
513   double total_io_bytes = input_io_bytes + output_io_bytes;
514   const DeviceInfo device_info = GetDeviceInfo(op_info.device());
515   if (device_info.gigaops <= 0 || device_info.gb_per_sec <= 0 ||
516       device_info.intermediate_read_gb_per_sec <= 0 ||
517       device_info.intermediate_write_gb_per_sec <= 0) {
518     VLOG(1) << "BAD DEVICE. Op:" << op_info.op()
519             << " device type:" << op_info.device().type()
520             << " device model:" << op_info.device().model();
521   }
522 
523   Costs::NanoSeconds compute_cost(std::ceil(operations / device_info.gigaops));
524   VLOG(1) << "Op:" << op_info.op() << " GOps:" << operations / 1e9
525           << " Compute Time (ns):" << compute_cost.count();
526 
527   Costs::NanoSeconds memory_cost(
528       std::ceil(total_io_bytes / device_info.gb_per_sec));
529   VLOG(1) << "Op:" << op_info.op() << " Size (KB):" << (total_io_bytes) / 1e3
530           << " Memory Time (ns):" << memory_cost.count();
531 
532   // Check if bytes > 0.  If it's not and the bandwidth is set to infinity
533   // then the result would be undefined.
534   double intermediate_read_time =
535       (input_io_bytes > 0)
536           ? std::ceil(input_io_bytes / device_info.intermediate_read_gb_per_sec)
537           : 0;
538 
539   double intermediate_write_time =
540       (output_io_bytes > 0)
541           ? std::ceil(output_io_bytes /
542                       device_info.intermediate_write_gb_per_sec)
543           : 0;
544 
545   Costs::NanoSeconds intermediate_memory_cost =
546       compute_memory_overlap_
547           ? std::max(intermediate_read_time, intermediate_write_time)
548           : (intermediate_read_time + intermediate_write_time);
549   VLOG(1) << "Op:" << op_info.op() << " Size (KB):" << (total_io_bytes) / 1e3
550           << " Intermediate Memory Time (ns):"
551           << intermediate_memory_cost.count();
552 
553   Costs costs;
554   costs.compute_time = compute_cost;
555   costs.memory_time = memory_cost;
556   costs.intermediate_memory_time = intermediate_memory_cost;
557   costs.intermediate_memory_read_time =
558       Costs::NanoSeconds(intermediate_read_time);
559   costs.intermediate_memory_write_time =
560       Costs::NanoSeconds(intermediate_write_time);
561   CombineCostsAndUpdateExecutionTime(&costs);
562   return costs;
563 }
564 
CountConv2DOperations(const OpInfo & op_info,bool * found_unknown_shapes) const565 int64 OpLevelCostEstimator::CountConv2DOperations(
566     const OpInfo& op_info, bool* found_unknown_shapes) const {
567   return CountConv2DOperations(op_info, nullptr, found_unknown_shapes);
568 }
569 
570 // Helper to translate the positional arguments into named fields.
571 OpLevelCostEstimator::ConvolutionDimensions
ConvolutionDimensionsFromInputs(const TensorShapeProto & original_image_shape,const TensorShapeProto & original_filter_shape,const OpInfo & op_info,bool * found_unknown_shapes)572 OpLevelCostEstimator::ConvolutionDimensionsFromInputs(
573     const TensorShapeProto& original_image_shape,
574     const TensorShapeProto& original_filter_shape, const OpInfo& op_info,
575     bool* found_unknown_shapes) {
576   VLOG(2) << "op features: " << op_info.DebugString();
577   VLOG(2) << "Original image shape: " << original_image_shape.DebugString();
578   VLOG(2) << "Original filter shape: " << original_filter_shape.DebugString();
579   auto image_shape =
580       MaybeGetMinimumShape(original_image_shape, 4, found_unknown_shapes);
581   auto filter_shape =
582       MaybeGetMinimumShape(original_filter_shape, 4, found_unknown_shapes);
583   VLOG(2) << "Image shape: " << image_shape.DebugString();
584   VLOG(2) << "Filter shape: " << filter_shape.DebugString();
585 
586   int x_index, y_index, channel_index;
587   const string& data_format = GetDataFormat(op_info);
588   if (data_format == "NCHW") {
589     x_index = 2;
590     y_index = 3;
591     channel_index = 1;
592   } else {
593     // Use NHWC.
594     x_index = 1;
595     y_index = 2;
596     channel_index = 3;
597   }
598   const string& filter_format = GetFilterFormat(op_info);
599   int filter_x_index, filter_y_index, in_channel_index, out_channel_index;
600   if (filter_format == "HWIO") {
601     filter_x_index = 0;
602     filter_y_index = 1;
603     in_channel_index = 2;
604     out_channel_index = 3;
605   } else {
606     // Use OIHW
607     filter_x_index = 2;
608     filter_y_index = 3;
609     in_channel_index = 1;
610     out_channel_index = 0;
611   }
612   int64 batch = image_shape.dim(0).size();
613   int64 ix = image_shape.dim(x_index).size();
614   int64 iy = image_shape.dim(y_index).size();
615   int64 iz = image_shape.dim(channel_index).size();
616   int64 kx = filter_shape.dim(filter_x_index).size();
617   int64 ky = filter_shape.dim(filter_y_index).size();
618   std::vector<int64> strides = GetStrides(op_info);
619   const auto padding = GetPadding(op_info);
620   int64 sx = strides[x_index];
621   int64 sy = strides[y_index];
622   int64 ox = GetOutputSize(ix, kx, sx, padding);
623   int64 oy = GetOutputSize(iy, ky, sy, padding);
624   int64 oz = filter_shape.dim(out_channel_index).size();
625   // Only check equality when both sizes are known (in other words, when
626   // neither is set to a minimum dimension size of 1).
627   if (iz != 1 && filter_shape.dim(in_channel_index).size() != 1) {
628     CHECK_EQ(iz, filter_shape.dim(in_channel_index).size());
629   } else {
630     iz = std::max<int64>(iz, filter_shape.dim(in_channel_index).size());
631   }
632   OpLevelCostEstimator::ConvolutionDimensions conv_dims = {
633       batch, ix, iy, iz, kx, ky, oz, ox, oy, sx, sy, padding};
634 
635   VLOG(1) << "Batch Size:" << batch;
636   VLOG(1) << "Image Dims:" << ix << "," << iy;
637   VLOG(1) << "Input Features:" << iz;
638   VLOG(1) << "Kernel Dims:" << kx << "," << ky;
639   VLOG(1) << "Output Features:" << oz;
640   VLOG(1) << "Output Dims:" << ox << "," << oy;
641   VLOG(1) << "Strides:" << sx << "," << sy;
642   VLOG(1) << "Padding:" << (padding == Padding::VALID ? "VALID" : "SAME");
643   return conv_dims;
644 }
645 
CountConv2DOperations(const OpInfo & op_info,ConvolutionDimensions * conv_info,bool * found_unknown_shapes) const646 int64 OpLevelCostEstimator::CountConv2DOperations(
647     const OpInfo& op_info, ConvolutionDimensions* conv_info,
648     bool* found_unknown_shapes) const {
649   DCHECK(op_info.op() == kConv2d || op_info.op() == kDepthwiseConv2dNative)
650       << "Invalid Operation: not Conv2D nor DepthwiseConv2dNative";
651 
652   ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
653       op_info.inputs(0).shape(), op_info.inputs(1).shape(), op_info,
654       found_unknown_shapes);
655 
656   //  in DepthwiseConv2dNative conv_dims.oz is actually the channel depth
657   //  multiplier; The effective output channel depth oz_effective is
658   //  conv_dims.iz * conv_dims.oz. thus # ops = N x H x W x oz_effective x 2RS.
659   //  Compare to Conv2D where # ops =  N x H x W x iz x oz x 2RS,
660   //  oz = oz_effective,  then Conv2D_ops / Depthwise_conv2d_native_ops = iz.
661   int64 ops = conv_dims.batch;
662   ops *= conv_dims.ox * conv_dims.oy;
663   ops *= conv_dims.kx * conv_dims.ky;
664   if (op_info.op() == kConv2d) {
665     ops *= conv_dims.iz * conv_dims.oz;
666   } else {
667     // To ensure output tensor dims to be correct for DepthwiseConv2DNative,
668     // although ops are the same as Conv2D.
669     conv_dims.oz *= conv_dims.iz;
670     ops *= conv_dims.oz;
671   }
672   ops *= kOpsPerMac;
673 
674   if (conv_info != nullptr) {
675     *conv_info = conv_dims;
676   }
677   return ops;
678 }
679 
CountMatMulOperations(const OpInfo & op_info,bool * found_unknown_shapes) const680 int64 OpLevelCostEstimator::CountMatMulOperations(
681     const OpInfo& op_info, bool* found_unknown_shapes) const {
682   return CountMatMulOperations(op_info, nullptr, found_unknown_shapes);
683 }
684 
685 // TODO(nishantpatil): Create separate estimator for Sparse Matmul
CountMatMulOperations(const OpInfo & op_info,MatMulDimensions * mat_mul,bool * found_unknown_shapes) const686 int64 OpLevelCostEstimator::CountMatMulOperations(
687     const OpInfo& op_info, MatMulDimensions* mat_mul,
688     bool* found_unknown_shapes) const {
689   double ops = 0;
690 
691   if (op_info.inputs_size() < 2) {
692     LOG(ERROR) << "Need 2 inputs but got " << op_info.inputs_size();
693     // TODO(pcma): Try to separate invalid inputs from unknown shapes
694     *found_unknown_shapes = true;
695     return 0;
696   }
697 
698   auto& a_matrix = op_info.inputs(0);
699   auto& b_matrix = op_info.inputs(1);
700 
701   bool transpose_a = false;
702   bool transpose_b = false;
703 
704   double m_dim, n_dim, k_dim, k_dim_b = 0;
705 
706   for (const auto& item : op_info.attr()) {
707     VLOG(1) << "Key:" << item.first
708             << " Value:" << SummarizeAttrValue(item.second);
709     if (item.first == "transpose_a" && item.second.b() == true)
710       transpose_a = true;
711     if (item.first == "transpose_b" && item.second.b() == true)
712       transpose_b = true;
713   }
714   VLOG(1) << "transpose_a:" << transpose_a;
715   VLOG(1) << "transpose_b:" << transpose_b;
716   auto a_matrix_shape =
717       MaybeGetMinimumShape(a_matrix.shape(), 2, found_unknown_shapes);
718   auto b_matrix_shape =
719       MaybeGetMinimumShape(b_matrix.shape(), 2, found_unknown_shapes);
720   if (transpose_a) {
721     m_dim = a_matrix_shape.dim(1).size();
722     k_dim = a_matrix_shape.dim(0).size();
723   } else {
724     m_dim = a_matrix_shape.dim(0).size();
725     k_dim = a_matrix_shape.dim(1).size();
726   }
727   if (transpose_b) {
728     k_dim_b = b_matrix_shape.dim(1).size();
729     n_dim = b_matrix_shape.dim(0).size();
730   } else {
731     k_dim_b = b_matrix_shape.dim(0).size();
732     n_dim = b_matrix_shape.dim(1).size();
733   }
734 
735   VLOG(1) << "M, N, K: " << m_dim << "," << n_dim << "," << k_dim;
736   // Only check equality when both sizes are known (in other words, when
737   // neither is set to a minimum dimension size of 1).
738   if (k_dim_b != 1 && k_dim != 1 && k_dim_b != k_dim) {
739     LOG(ERROR) << "Incompatible Matrix dimensions";
740     return ops;
741   } else {
742     // One of k_dim and k_dim_b might be 1 (mininum dimension size).
743     k_dim = std::max(k_dim, k_dim_b);
744   }
745 
746   ops = m_dim * n_dim * k_dim * 2;
747   VLOG(1) << "Operations for Matmul: " << ops;
748 
749   if (mat_mul != nullptr) {
750     mat_mul->m = m_dim;
751     mat_mul->n = n_dim;
752     mat_mul->k = k_dim;
753   }
754   return ops;
755 }
756 
CountBatchMatMulOperations(const OpInfo & op_info,bool * found_unknown_shapes) const757 int64 OpLevelCostEstimator::CountBatchMatMulOperations(
758     const OpInfo& op_info, bool* found_unknown_shapes) const {
759   if (op_info.op() != kBatchMatMul) {
760     LOG(ERROR) << "Invalid Operation: " << op_info.op();
761     // TODO(pcma): Try to separate invalid inputs from unknown shapes
762     *found_unknown_shapes = true;
763     return 0;
764   }
765   if (op_info.inputs_size() != 2) {
766     LOG(ERROR) << "Expected 2 inputs but got " << op_info.inputs_size();
767     // TODO(pcma): Try to separate invalid inputs from unknown shapes
768     *found_unknown_shapes = true;
769     return 0;
770   }
771 
772   double ops = 0;
773   const auto& a_input = op_info.inputs(0);
774   const auto& b_input = op_info.inputs(1);
775 
776   // BatchMatMul requires inputs of at least matrix shape (rank 2).
777   // The two most minor dimensions of each input are matrices that
778   // need to be multiplied together. The other dimensions determine
779   // the number of such MatMuls.  For example, if the BatchMatMul has
780   // inputs of shape:
781   //   a_input_shape = [2, 3, 4, 5]
782   //   b_input_shape = [2, 3, 5, 6]
783   // then there are 2*3 = 6 MatMuls of dimensions m = 4, k = 5, n = 6
784   // in this BatchMatMul.
785   const int matrix_rank = 2;
786 
787   bool a_input_shape_unknown = false;
788   bool b_input_shape_unknown = false;
789 
790   TensorShapeProto a_input_shape = MaybeGetMinimumShape(
791       a_input.shape(), std::max(matrix_rank, a_input.shape().dim_size()),
792       &a_input_shape_unknown);
793   TensorShapeProto b_input_shape = MaybeGetMinimumShape(
794       b_input.shape(), std::max(matrix_rank, b_input.shape().dim_size()),
795       &b_input_shape_unknown);
796 
797   *found_unknown_shapes = a_input_shape_unknown || b_input_shape_unknown ||
798                           (a_input.shape().dim_size() < matrix_rank) ||
799                           (b_input.shape().dim_size() < matrix_rank);
800 
801   // Compute the number of matmuls as the max indicated at each dimension
802   // by either input. Note that the shapes do not have to have
803   // the same rank due to incompleteness.
804   TensorShapeProto* bigger_rank_shape = &a_input_shape;
805   TensorShapeProto* smaller_rank_shape = &b_input_shape;
806   if (b_input_shape.dim_size() > a_input_shape.dim_size()) {
807     bigger_rank_shape = &b_input_shape;
808     smaller_rank_shape = &a_input_shape;
809   }
810   int num_matmuls = 1;
811   for (int b_i = 0,
812            s_i = smaller_rank_shape->dim_size() - bigger_rank_shape->dim_size();
813        b_i < bigger_rank_shape->dim_size() - matrix_rank; ++b_i, ++s_i) {
814     int b_dim = bigger_rank_shape->dim(b_i).size();
815     int s_dim = 1;
816     if (s_i >= 0) {
817       s_dim = smaller_rank_shape->dim(s_i).size();
818     }
819     num_matmuls *= std::max(b_dim, s_dim);
820   }
821 
822   // Build the MatMul. Note that values are ignored here since we are just
823   // counting ops (e.g. only shapes matter).
824   OpInfo matmul_op_info;
825   matmul_op_info.set_op("MatMul");
826 
827   AttrValue transpose_a;
828   transpose_a.set_b(false);
829   if (op_info.attr().find("adj_x") != op_info.attr().end()) {
830     transpose_a.set_b(op_info.attr().at("adj_x").b());
831   }
832   (*matmul_op_info.mutable_attr())["transpose_a"] = transpose_a;
833 
834   AttrValue transpose_b;
835   transpose_b.set_b(false);
836   if (op_info.attr().find("adj_y") != op_info.attr().end()) {
837     transpose_b.set_b(op_info.attr().at("adj_y").b());
838   }
839   (*matmul_op_info.mutable_attr())["transpose_b"] = transpose_b;
840 
841   OpInfo::TensorProperties* a_matrix = matmul_op_info.add_inputs();
842   a_matrix->set_dtype(a_input.dtype());
843   TensorShapeProto* a_matrix_shape = a_matrix->mutable_shape();
844   for (int i = std::max(0, a_input_shape.dim_size() - matrix_rank);
845        i < a_input_shape.dim_size(); ++i) {
846     *(a_matrix_shape->add_dim()) = a_input_shape.dim(i);
847   }
848 
849   OpInfo::TensorProperties* b_matrix = matmul_op_info.add_inputs();
850   b_matrix->set_dtype(b_input.dtype());
851   TensorShapeProto* b_matrix_shape = b_matrix->mutable_shape();
852   for (int i = std::max(0, b_input_shape.dim_size() - matrix_rank);
853        i < b_input_shape.dim_size(); ++i) {
854     *(b_matrix_shape->add_dim()) = b_input_shape.dim(i);
855   }
856 
857   for (int i = 0; i < num_matmuls; ++i) {
858     bool matmul_unknown_shapes = false;
859     ops += CountMatMulOperations(matmul_op_info, &matmul_unknown_shapes);
860     *found_unknown_shapes |= matmul_unknown_shapes;
861   }
862   return ops;
863 }
864 
GetTensorShapeProtoFromTensorProto(const TensorProto & tensor_proto,TensorShapeProto * tensor_shape_proto)865 bool GetTensorShapeProtoFromTensorProto(const TensorProto& tensor_proto,
866                                         TensorShapeProto* tensor_shape_proto) {
867   tensor_shape_proto->Clear();
868   // First convert TensorProto into Tensor class so that it correctly parses
869   // data values within TensorProto (whether it's in int_val, int64_val,
870   // tensor_content, or anything.
871   Tensor tensor(tensor_proto.dtype());
872   if (!tensor.FromProto(tensor_proto)) {
873     LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
874                  << "failed to parse TensorProto: "
875                  << tensor_proto.DebugString();
876     return false;
877   }
878   if (tensor.dims() != 1) {
879     LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
880                  << "tensor is not 1D: " << tensor.dims();
881     return false;
882   }
883   // Then, convert it back to TensorProto using AsProtoField, which makes sure
884   // the data is in int_val, int64_val, or such repeated data fields, not in
885   // tensor_content.
886   TensorProto temp_tensor;
887   tensor.AsProtoField(&temp_tensor);
888 
889 #define TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(type)        \
890   do {                                                   \
891     for (const auto& value : temp_tensor.type##_val()) { \
892       tensor_shape_proto->add_dim()->set_size(value);    \
893     }                                                    \
894   } while (0)
895 
896   if (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT16 ||
897       tensor.dtype() == DT_INT8 || tensor.dtype() == DT_UINT8) {
898     TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(int);
899   } else if (tensor.dtype() == DT_INT64) {
900     TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(int64);
901   } else if (tensor.dtype() == DT_UINT32) {
902     TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(uint32);
903   } else if (tensor.dtype() == DT_UINT64) {
904     TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(uint64);
905   } else {
906     LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
907                  << "Unsupported dtype: " << tensor.dtype();
908     return false;
909   }
910 #undef TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO
911 
912   return true;
913 }
914 
915 // TODO(cliffy): Dedup this method and CountConv2DBackpropFilterOperations.
CountConv2DBackpropInputOperations(const OpInfo & op_info,ConvolutionDimensions * returned_conv_dims,bool * found_unknown_shapes) const916 int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations(
917     const OpInfo& op_info, ConvolutionDimensions* returned_conv_dims,
918     bool* found_unknown_shapes) const {
919   int64 ops = 0;
920 
921   DCHECK(op_info.op() == kConv2dBackpropInput ||
922          op_info.op() == kDepthwiseConv2dNativeBackpropInput)
923       << "Invalid Operation: not kConv2dBackpropInput nor"
924          "kDepthwiseConv2dNativeBackpropInput";
925 
926   if (op_info.inputs_size() < 2) {
927     // TODO(pcma): Try to separate invalid inputs from unknown shapes
928     *found_unknown_shapes = true;
929     return ops;
930   }
931 
932   TensorShapeProto input_shape;
933   bool shape_found = false;
934   if (op_info.inputs(0).has_value()) {
935     const TensorProto& value = op_info.inputs(0).value();
936     shape_found = GetTensorShapeProtoFromTensorProto(value, &input_shape);
937   }
938   if (!shape_found && op_info.outputs_size() == 1) {
939     input_shape = op_info.outputs(0).shape();
940     shape_found = true;
941   }
942   if (!shape_found) {
943     // Set the minimum filter size that's feasible.
944     input_shape.Clear();
945     for (int i = 0; i < 4; ++i) {
946       input_shape.add_dim()->set_size(1);
947     }
948     *found_unknown_shapes = true;
949   }
950 
951   ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
952       input_shape, op_info.inputs(1).shape(), op_info, found_unknown_shapes);
953 
954   ops = conv_dims.batch;
955   ops *= conv_dims.ox * conv_dims.oy;
956   ops *= conv_dims.kx * conv_dims.ky;
957   if (op_info.op() == kConv2dBackpropInput) {
958     ops *= conv_dims.iz * conv_dims.oz;
959   } else {
960     // conv_dims always use forward path definition regardless
961     conv_dims.oz *= conv_dims.iz;
962     ops *= conv_dims.oz;
963   }
964   ops *= kOpsPerMac;
965 
966   VLOG(1) << "Operations for" << op_info.op() << "  " << ops;
967 
968   if (returned_conv_dims != nullptr) {
969     *returned_conv_dims = conv_dims;
970   }
971   return ops;
972 }
973 
CountConv2DBackpropFilterOperations(const OpInfo & op_info,ConvolutionDimensions * returned_conv_dims,bool * found_unknown_shapes) const974 int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations(
975     const OpInfo& op_info, ConvolutionDimensions* returned_conv_dims,
976     bool* found_unknown_shapes) const {
977   int64 ops = 0;
978 
979   DCHECK(op_info.op() == kConv2dBackpropFilter ||
980          op_info.op() == kDepthwiseConv2dNativeBackpropFilter)
981       << "Invalid Operation: not kConv2dBackpropFilter nor"
982          "kDepthwiseConv2dNativeBackpropFilter";
983 
984   TensorShapeProto filter_shape;
985   bool shape_found = false;
986   if (op_info.inputs_size() >= 2 && op_info.inputs(1).has_value()) {
987     const TensorProto& value = op_info.inputs(1).value();
988     shape_found = GetTensorShapeProtoFromTensorProto(value, &filter_shape);
989   }
990   if (!shape_found && op_info.outputs_size() == 1) {
991     filter_shape = op_info.outputs(0).shape();
992     shape_found = true;
993   }
994   if (!shape_found) {
995     // Set the minimum filter size that's feasible.
996     filter_shape.Clear();
997     for (int i = 0; i < 4; ++i) {
998       filter_shape.add_dim()->set_size(1);
999     }
1000     *found_unknown_shapes = true;
1001   }
1002 
1003   if (op_info.inputs_size() < 1) {
1004     // TODO(pcma): Try to separate invalid inputs from unknown shapes
1005     *found_unknown_shapes = true;
1006     return ops;
1007   }
1008   ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
1009       op_info.inputs(0).shape(), filter_shape, op_info, found_unknown_shapes);
1010 
1011   ops = conv_dims.batch;
1012   ops *= conv_dims.ox * conv_dims.oy;
1013   ops *= conv_dims.kx * conv_dims.ky;
1014   if (op_info.op() == kConv2dBackpropFilter) {
1015     ops *= conv_dims.iz * conv_dims.oz;
1016   } else {
1017     // conv_dims always use forward path definition regardless
1018     conv_dims.oz *= conv_dims.iz;
1019     ops *= conv_dims.oz;
1020   }
1021   ops *= kOpsPerMac;
1022   VLOG(1) << "Operations for" << op_info.op() << "  " << ops;
1023 
1024   if (returned_conv_dims != nullptr) {
1025     *returned_conv_dims = conv_dims;
1026   }
1027   return ops;
1028 }
1029 
CalculateTensorElementCount(const OpInfo::TensorProperties & tensor,bool * found_unknown_shapes) const1030 int64 OpLevelCostEstimator::CalculateTensorElementCount(
1031     const OpInfo::TensorProperties& tensor, bool* found_unknown_shapes) const {
1032   VLOG(2) << "   with " << DataTypeString(tensor.dtype()) << " tensor of shape "
1033           << tensor.shape().DebugString();
1034   int64 tensor_size = 1;
1035   int num_dims = std::max(1, tensor.shape().dim_size());
1036   auto tensor_shape =
1037       MaybeGetMinimumShape(tensor.shape(), num_dims, found_unknown_shapes);
1038   for (const auto& dim : tensor_shape.dim()) {
1039     tensor_size *= dim.size();
1040   }
1041   return tensor_size;
1042 }
1043 
CalculateTensorSize(const OpInfo::TensorProperties & tensor,bool * found_unknown_shapes) const1044 int64 OpLevelCostEstimator::CalculateTensorSize(
1045     const OpInfo::TensorProperties& tensor, bool* found_unknown_shapes) const {
1046   int64 count = CalculateTensorElementCount(tensor, found_unknown_shapes);
1047   int size = DataTypeSize(BaseType(tensor.dtype()));
1048   VLOG(2) << "Count: " << count << " DataTypeSize: " << size;
1049   return count * size;
1050 }
1051 
CalculateInputSize(const OpInfo & op_info,bool * found_unknown_shapes) const1052 int64 OpLevelCostEstimator::CalculateInputSize(
1053     const OpInfo& op_info, bool* found_unknown_shapes) const {
1054   int64 total_input_size = 0;
1055   for (auto& input : op_info.inputs()) {
1056     int64 input_size = CalculateTensorSize(input, found_unknown_shapes);
1057     total_input_size += input_size;
1058     VLOG(1) << "Input Size: " << input_size
1059             << " Total Input Size:" << total_input_size;
1060   }
1061   return total_input_size;
1062 }
1063 
CalculateLargestInputCount(const OpInfo & op_info,bool * found_unknown_shapes) const1064 int64 OpLevelCostEstimator::CalculateLargestInputCount(
1065     const OpInfo& op_info, bool* found_unknown_shapes) const {
1066   int64 largest_input_count = 0;
1067   for (auto& input : op_info.inputs()) {
1068     int64 input_count =
1069         CalculateTensorElementCount(input, found_unknown_shapes);
1070     if (input_count > largest_input_count) {
1071       largest_input_count = input_count;
1072     }
1073     VLOG(1) << "Input Count: " << input_count
1074             << " Largest Input Count:" << largest_input_count;
1075   }
1076   return largest_input_count;
1077 }
1078 
CalculateOutputSize(const OpInfo & op_info,bool * found_unknown_shapes) const1079 int64 OpLevelCostEstimator::CalculateOutputSize(
1080     const OpInfo& op_info, bool* found_unknown_shapes) const {
1081   int64 total_output_size = 0;
1082   // use float as default for calculations
1083   for (const auto& output : op_info.outputs()) {
1084     DataType dt = output.dtype();
1085     const auto& original_output_shape = output.shape();
1086     int64 output_size = DataTypeSize(BaseType(dt));
1087     int num_dims = std::max(1, original_output_shape.dim_size());
1088     auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims,
1089                                              found_unknown_shapes);
1090     for (const auto& dim : output_shape.dim()) {
1091       output_size *= dim.size();
1092     }
1093     total_output_size += output_size;
1094     VLOG(1) << "Output Size: " << output_size
1095             << " Total Output Size:" << total_output_size;
1096   }
1097   return total_output_size;
1098 }
1099 
PredictConv2D(const OpContext & op_context) const1100 Costs OpLevelCostEstimator::PredictConv2D(const OpContext& op_context) const {
1101   const auto& op_info = op_context.op_info;
1102   bool found_unknown_shapes = false;
1103   auto costs = PredictOpCountBasedCost(
1104       CountConv2DOperations(op_info, &found_unknown_shapes), op_info);
1105   costs.inaccurate = found_unknown_shapes;
1106   costs.num_ops_with_unknown_shapes = found_unknown_shapes;
1107   return costs;
1108 }
1109 
PredictConv2DBackpropInput(const OpContext & op_context) const1110 Costs OpLevelCostEstimator::PredictConv2DBackpropInput(
1111     const OpContext& op_context) const {
1112   const auto& op_info = op_context.op_info;
1113   bool found_unknown_shapes = false;
1114   auto costs =
1115       PredictOpCountBasedCost(CountConv2DBackpropInputOperations(
1116                                   op_info, nullptr, &found_unknown_shapes),
1117                               op_info);
1118   costs.inaccurate = found_unknown_shapes;
1119   costs.num_ops_with_unknown_shapes = found_unknown_shapes;
1120   return costs;
1121 }
1122 
PredictConv2DBackpropFilter(const OpContext & op_context) const1123 Costs OpLevelCostEstimator::PredictConv2DBackpropFilter(
1124     const OpContext& op_context) const {
1125   const auto& op_info = op_context.op_info;
1126   bool found_unknown_shapes = false;
1127   auto costs =
1128       PredictOpCountBasedCost(CountConv2DBackpropFilterOperations(
1129                                   op_info, nullptr, &found_unknown_shapes),
1130                               op_info);
1131   costs.inaccurate = found_unknown_shapes;
1132   costs.num_ops_with_unknown_shapes = found_unknown_shapes;
1133   return costs;
1134 }
1135 
PredictFusedConv2DBiasActivation(const OpContext & op_context) const1136 Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
1137     const OpContext& op_context) const {
1138   // FusedConv2DBiasActivation computes a fused kernel which implements:
1139   // 2D convolution, adds side input with separate scaling on convolution and
1140   // side inputs, then adds bias, and finally applies the ReLU activation
1141   // function to the result:
1142   //
1143   // Input -> Conv2D  ->  Add  -> BiasAdd  -> ReLU
1144   //            ^          ^         ^
1145   //          Filter   Side Input   Bias
1146   //
1147   // Note that when adding the side input, the operation multiplies the output
1148   // of Conv2D by conv_input_scale, confusingly, and the side_input by
1149   // side_input_scale.
1150   //
1151   // Note that in the special case that side_input_scale is 0, which we infer
1152   // from side_input having dimensions [], we skip that addition operation.
1153   //
1154   // For more information, see
1155   // contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
1156 
1157   // TODO(yaozhang): Support other data formats (NCHW_VECT_C, NHWC_VECT_W) and
1158   // filter formats (OIHW_VECT_I).
1159   string data_format = GetDataFormat(op_context.op_info);
1160   if (data_format != "NCHW" && data_format != "NHWC") {
1161     LOG(WARNING) << "unsupported data format: " << data_format;
1162     Costs cost = Costs::ZeroCosts();
1163     cost.inaccurate = true;
1164     return cost;
1165   }
1166   string filter_format = GetFilterFormat(op_context.op_info);
1167   if (filter_format != "HWIO" && filter_format != "OIHW") {
1168     LOG(WARNING) << "unsupported filter format: " << filter_format;
1169     Costs cost = Costs::ZeroCosts();
1170     cost.inaccurate = true;
1171     return cost;
1172   }
1173 
1174   auto& conv_input = op_context.op_info.inputs(0);
1175   auto& filter = op_context.op_info.inputs(1);
1176   auto& bias = op_context.op_info.inputs(2);
1177   auto& side_input = op_context.op_info.inputs(3);
1178   auto& conv_input_scale = op_context.op_info.inputs(4);
1179   auto& side_input_scale = op_context.op_info.inputs(5);
1180 
1181   // Manually compute our convolution dimensions.
1182   bool found_unknown_shapes = false;
1183   auto dims = ConvolutionDimensionsFromInputs(
1184       conv_input.shape(), filter.shape(), op_context.op_info,
1185       &found_unknown_shapes);
1186 
1187   // Construct the shape of our output tensor from our convolution dimensions
1188   // and format, as it may not be available yet.
1189   // TODO(varomodt): should we centralize the Conv2D input/output shapes?
1190   OpInfo::TensorProperties output;
1191   if (data_format == "NCHW") {
1192     output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.ox, dims.oy});
1193   } else if (data_format == "NHWC") {
1194     output = DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz});
1195   }
1196 
1197   // Add the operations the fused op always computes.
1198   std::vector<OpContext> component_ops = {
1199       FusedChildContext(op_context, "Conv2D", output, {conv_input, filter}),
1200       FusedChildContext(op_context, "Mul", output, {output, conv_input_scale}),
1201       FusedChildContext(op_context, "BiasAdd", output, {output, bias}),
1202       FusedChildContext(op_context, "Relu", output, {output})};
1203 
1204   // Add our side_input iff it's non-empty.
1205   if (side_input.shape().dim_size() > 0) {
1206     component_ops.push_back(FusedChildContext(op_context, "Mul", side_input,
1207                                               {side_input, side_input_scale}));
1208     component_ops.push_back(
1209         FusedChildContext(op_context, "Add", output, {side_input, output}));
1210   }
1211 
1212   // Construct an op_context which definitely has our output shape.
1213   auto op_context_with_output = op_context;
1214   op_context_with_output.op_info.mutable_outputs()->Clear();
1215   *op_context_with_output.op_info.mutable_outputs()->Add() = output;
1216 
1217   // Construct component operations and run the cost computation.
1218   auto costs = PredictFusedOp(op_context_with_output, component_ops);
1219   costs.inaccurate |= found_unknown_shapes;
1220   costs.num_ops_with_unknown_shapes = costs.inaccurate;
1221   return costs;
1222 }
1223 
PredictMatMul(const OpContext & op_context) const1224 Costs OpLevelCostEstimator::PredictMatMul(const OpContext& op_context) const {
1225   const auto& op_info = op_context.op_info;
1226   bool found_unknown_shapes = false;
1227   auto costs = PredictOpCountBasedCost(
1228       CountMatMulOperations(op_info, &found_unknown_shapes), op_info);
1229   costs.inaccurate = found_unknown_shapes;
1230   costs.num_ops_with_unknown_shapes = found_unknown_shapes;
1231   return costs;
1232 }
1233 
PredictSparseTensorDenseMatMul(const OpContext & op_context) const1234 Costs OpLevelCostEstimator::PredictSparseTensorDenseMatMul(
1235     const OpContext& op_context) const {
1236   const auto& op_info = op_context.op_info;
1237   bool found_unknown_shapes = false;
1238   // input[0]: indices in sparse matrix a
1239   // input[1]: values in sparse matrix a
1240   // input[2]: shape of matrix a
1241   // input[3]: matrix b
1242   // See
1243   // https://github.com/tensorflow/tensorflow/blob/9a43dfeac5/tensorflow/core/ops/sparse_ops.cc#L85
1244   int64 num_elems_in_a =
1245       CalculateTensorElementCount(op_info.inputs(1), &found_unknown_shapes);
1246   auto b_matrix = op_info.inputs(3);
1247   auto b_matrix_shape =
1248       MaybeGetMinimumShape(b_matrix.shape(), 2, &found_unknown_shapes);
1249   int64 n_dim = b_matrix_shape.dim(1).size();
1250 
1251   // Each element in A is multiplied and added with an element from each column
1252   // in b.
1253   const int64 op_count = kOpsPerMac * num_elems_in_a * n_dim;
1254 
1255   int64 a_indices_input_size =
1256       CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
1257   int64 a_values_input_size =
1258       CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
1259   int64 a_shape_input_size =
1260       CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
1261   int64 b_input_size =
1262       num_elems_in_a * n_dim * DataTypeSize(BaseType(b_matrix.dtype()));
1263   double input_size = a_indices_input_size + a_values_input_size +
1264                       a_shape_input_size + b_input_size;
1265 
1266   double output_size = CalculateOutputSize(op_info, &found_unknown_shapes);
1267 
1268   auto costs =
1269       PredictOpCountBasedCost(op_count, input_size, output_size, op_info);
1270   costs.inaccurate = found_unknown_shapes;
1271   costs.num_ops_with_unknown_shapes = found_unknown_shapes;
1272   costs.max_memory = output_size;
1273 
1274   return costs;
1275 }
1276 
PredictNoOp(const OpContext & op_context) const1277 Costs OpLevelCostEstimator::PredictNoOp(const OpContext& op_context) const {
1278   const auto& op_info = op_context.op_info;
1279   VLOG(1) << "Op:" << op_info.op() << " Execution Time 0 (ns)";
1280   return Costs::ZeroCosts();
1281 }
1282 
PredictIdentity(const OpContext & op_context) const1283 Costs OpLevelCostEstimator::PredictIdentity(const OpContext& op_context) const {
1284   const auto& op_info = op_context.op_info;
1285   VLOG(1) << "Op:" << op_info.op() << " Execution Time 0 (ns)";
1286   Costs result = Costs::ZeroCosts();
1287   result.max_memory = CalculateOutputSize(op_info, &result.inaccurate);
1288   result.num_ops_with_unknown_shapes = result.inaccurate;
1289   // Assign the minimum amount of time we can represent to the identity op since
1290   // it tends to be really cheap.
1291   result.compute_time = kMinComputeTime;
1292   result.execution_time = result.compute_time;
1293   return result;
1294 }
1295 
PredictVariable(const OpContext & op_context) const1296 Costs OpLevelCostEstimator::PredictVariable(const OpContext& op_context) const {
1297   const auto& op_info = op_context.op_info;
1298   VLOG(1) << "Op:" << op_info.op() << " Execution Time 0 (ns)";
1299   Costs result = Costs::ZeroCosts();
1300   result.persistent_memory = CalculateOutputSize(op_info, &result.inaccurate);
1301   result.num_ops_with_unknown_shapes = result.inaccurate;
1302 
1303   result.compute_time = kMinComputeTime;
1304   result.execution_time = result.compute_time;
1305   return result;
1306 }
1307 
PredictBatchMatMul(const OpContext & op_context) const1308 Costs OpLevelCostEstimator::PredictBatchMatMul(
1309     const OpContext& op_context) const {
1310   const auto& op_info = op_context.op_info;
1311   bool found_unknown_shapes = false;
1312   Costs costs = PredictOpCountBasedCost(
1313       CountBatchMatMulOperations(op_info, &found_unknown_shapes), op_info);
1314   costs.inaccurate = found_unknown_shapes;
1315   costs.num_ops_with_unknown_shapes = found_unknown_shapes;
1316   return costs;
1317 }
1318 
PredictMetadata(const OpContext & op_context) const1319 Costs OpLevelCostEstimator::PredictMetadata(const OpContext& op_context) const {
1320   const auto& op_info = op_context.op_info;
1321   Costs costs = Costs::ZeroCosts();
1322   costs.max_memory = CalculateOutputSize(op_info, &costs.inaccurate);
1323   costs.num_ops_with_unknown_shapes = costs.inaccurate;
1324   // Metadata operations are so cheap we assume they take the minimum amount of
1325   // time we can represent (1 ns).
1326   costs.compute_time = kMinComputeTime;
1327   costs.execution_time = costs.compute_time;
1328 
1329   return costs;
1330 }
1331 
PredictGatherOrSlice(const OpContext & op_context) const1332 Costs OpLevelCostEstimator::PredictGatherOrSlice(
1333     const OpContext& op_context) const {
1334   // Gather & Slice ops can have a very large input, but only access a small
1335   // part of it. For these op the size of the output determines the memory cost.
1336   const auto& op_info = op_context.op_info;
1337 
1338   const int inputs_needed = op_info.op() == "Slice" ? 3 : 2;
1339   if (op_info.outputs_size() == 0 || op_info.inputs_size() < inputs_needed) {
1340     Costs costs = Costs::ZeroCosts();
1341     costs.inaccurate = true;
1342     return costs;
1343   }
1344 
1345   bool unknown_shapes = false;
1346 
1347   // Each output element is a copy of some element from input.
1348   // For roofline estimate we assume each copy has a unit cost.
1349   const int64 op_count =
1350       CalculateTensorElementCount(op_info.outputs(0), &unknown_shapes);
1351 
1352   const double output_size = CalculateOutputSize(op_info, &unknown_shapes);
1353   double input_size = output_size;
1354   if (op_info.op() == "Slice") {
1355     // Add 'begin' & 'size' tensors sizes.
1356     input_size +=
1357         CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes) +
1358         CalculateTensorElementCount(op_info.inputs(2), &unknown_shapes);
1359   } else {
1360     // Assuming this is "Gather" or "GatherV2" op, add 'indices' size.
1361     input_size +=
1362         CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes);
1363   }
1364 
1365   Costs costs =
1366       PredictOpCountBasedCost(op_count, input_size, output_size, op_info);
1367   costs.inaccurate = unknown_shapes;
1368   costs.num_ops_with_unknown_shapes = unknown_shapes;
1369   costs.max_memory = output_size;
1370 
1371   return costs;
1372 }
1373 
PredictFusedOp(const OpContext & op_context,const std::vector<OpContext> & fused_op_contexts) const1374 Costs OpLevelCostEstimator::PredictFusedOp(
1375     const OpContext& op_context,
1376     const std::vector<OpContext>& fused_op_contexts) const {
1377   // Note that PredictOpCountBasedCost will get the correct memory_time from
1378   // the node's inputs and outputs; but we don't want to have to re-implement
1379   // the logic for computing the operation count of each of our component
1380   // operations here; so we simply add the compute times of each component
1381   // operation, then update the execution time.
1382   Costs fused_cost = PredictOpCountBasedCost(0, op_context.op_info);
1383 
1384   fused_cost.compute_time = 0;
1385   fused_cost.inaccurate = false;
1386   for (auto& fused_op : fused_op_contexts) {
1387     auto op_cost = PredictCosts(fused_op);
1388 
1389     fused_cost.compute_time += op_cost.compute_time;
1390     fused_cost.inaccurate |= op_cost.inaccurate;
1391     fused_cost.intermediate_memory_time += op_cost.intermediate_memory_time;
1392   }
1393 
1394   CombineCostsAndUpdateExecutionTime(&fused_cost);
1395   return fused_cost;
1396 }
1397 
1398 /* static */
OpContext OpLevelCostEstimator::FusedChildContext(
    const OpContext& parent, const string& op_name,
    const OpInfo::TensorProperties& output,
    const std::vector<OpInfo::TensorProperties>& inputs) {
  // Builds the OpContext of one component of a fused op. It starts from a
  // copy of the parent's OpInfo (so attributes and other fields carry over),
  // renames the op to op_name, and replaces its inputs with `inputs` and its
  // outputs with the single `output`. The name is op_name and the device name
  // is inherited from the parent.
  OpContext child;
  child.name = op_name;
  child.device_name = parent.device_name;

  OpInfo& child_info = child.op_info;
  child_info = parent.op_info;
  child_info.set_op(op_name);

  auto* child_inputs = child_info.mutable_inputs();
  child_inputs->Clear();
  for (const OpInfo::TensorProperties& tensor : inputs) {
    *child_inputs->Add() = tensor;
  }

  auto* child_outputs = child_info.mutable_outputs();
  child_outputs->Clear();
  *child_outputs->Add() = output;

  return child;
}
1422 
1423 /* static */
DescribeTensor(DataType type,const std::vector<int64> & dims)1424 OpInfo::TensorProperties OpLevelCostEstimator::DescribeTensor(
1425     DataType type, const std::vector<int64>& dims) {
1426   OpInfo::TensorProperties ret;
1427   ret.set_dtype(type);
1428 
1429   auto shape = ret.mutable_shape();
1430   for (const int dim : dims) {
1431     shape->add_dim()->set_size(dim);
1432   }
1433 
1434   return ret;
1435 }
1436 
1437 /* static */
1438 OpLevelCostEstimator::ConvolutionDimensions
OpDimensionsFromInputs(const TensorShapeProto & original_image_shape,const OpInfo & op_info,bool * found_unknown_shapes)1439 OpLevelCostEstimator::OpDimensionsFromInputs(
1440     const TensorShapeProto& original_image_shape, const OpInfo& op_info,
1441     bool* found_unknown_shapes) {
1442   VLOG(2) << "op features: " << op_info.DebugString();
1443   VLOG(2) << "Original image shape: " << original_image_shape.DebugString();
1444   auto image_shape =
1445       MaybeGetMinimumShape(original_image_shape, 4, found_unknown_shapes);
1446   VLOG(2) << "Image shape: " << image_shape.DebugString();
1447 
1448   int x_index, y_index, channel_index;
1449   const string& data_format = GetDataFormat(op_info);
1450   if (data_format == "NCHW") {
1451     x_index = 2;
1452     y_index = 3;
1453     channel_index = 1;
1454   } else {
1455     x_index = 1;
1456     y_index = 2;
1457     channel_index = 3;
1458   }
1459   int64 batch = image_shape.dim(0).size();
1460   int64 ix = image_shape.dim(x_index).size();
1461   int64 iy = image_shape.dim(y_index).size();
1462   int64 iz = image_shape.dim(channel_index).size();
1463 
1464   // Note that FusedBatchNorm doesn't have ksize attr, but GetKernelSize returns
1465   // {1, 1, 1, 1} in that case.
1466   std::vector<int64> ksize = GetKernelSize(op_info);
1467   int64 kx = ksize[x_index];
1468   int64 ky = ksize[y_index];
1469 
1470   std::vector<int64> strides = GetStrides(op_info);
1471   int64 sx = strides[x_index];
1472   int64 sy = strides[y_index];
1473   const auto padding = GetPadding(op_info);
1474 
1475   int64 ox = GetOutputSize(ix, kx, sx, padding);
1476   int64 oy = GetOutputSize(iy, ky, sy, padding);
1477   int64 oz = iz;
1478 
1479   OpLevelCostEstimator::ConvolutionDimensions conv_dims = {
1480       batch, ix, iy, iz, kx, ky, oz, ox, oy, sx, sy, padding};
1481   return conv_dims;
1482 }
1483 
PredictMaxPool(const OpContext & op_context) const1484 Costs OpLevelCostEstimator::PredictMaxPool(const OpContext& op_context) const {
1485   bool found_unknown_shapes = false;
1486   const auto& op_info = op_context.op_info;
1487   // x: op_info.inputs(0)
1488   ConvolutionDimensions dims = OpDimensionsFromInputs(
1489       op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
1490   // kx * ky - 1 comparisons per output (kx * xy > 1)
1491   // or 1 copy per output (kx * k1 = 1).
1492   int per_output_ops = dims.kx * dims.ky == 1 ? 1 : dims.kx * dims.ky - 1;
1493   int64 ops = dims.batch * dims.ox * dims.oy * dims.oz * per_output_ops;
1494 
1495   double total_input_size = 0;
1496   if (dims.ky >= dims.sy) {
1497     total_input_size =
1498         CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
1499   } else {  // dims.ky < dims.sy
1500     // Vertical stride is larger than vertical kernel; assuming row-major
1501     // format, skip unnecessary rows (or read every kx rows per sy rows, as the
1502     // others are not used for output).
1503     const auto data_size = DataTypeSize(BaseType(op_info.inputs(0).dtype()));
1504     total_input_size =
1505         data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz;
1506   }
1507   const double total_output_size =
1508       CalculateOutputSize(op_info, &found_unknown_shapes);
1509 
1510   Costs costs = PredictOpCountBasedCost(ops, total_input_size,
1511                                         total_output_size, op_info);
1512   costs.inaccurate = found_unknown_shapes;
1513   costs.num_ops_with_unknown_shapes = found_unknown_shapes;
1514   costs.max_memory = total_output_size;
1515   return costs;
1516 }
1517 
PredictMaxPoolGrad(const OpContext & op_context) const1518 Costs OpLevelCostEstimator::PredictMaxPoolGrad(
1519     const OpContext& op_context) const {
1520   bool found_unknown_shapes = false;
1521   const auto& op_info = op_context.op_info;
1522   // x: op_info.inputs(0)
1523   // y: op_info.inputs(1)
1524   // y_grad: op_info.inputs(2)
1525   ConvolutionDimensions dims = OpDimensionsFromInputs(
1526       op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
1527 
1528   int64 ops = 0;
1529   if (dims.kx == 1 && dims.ky == 1) {
1530     // 1x1 window. No need to know which input was max.
1531     ops = dims.batch * dims.ix * dims.iy * dims.iz;
1532   } else if (dims.kx <= dims.sx && dims.ky <= dims.sy) {
1533     // Non-overlapping window: re-run maxpool, then assign zero or y_grad.
1534     ops = dims.batch * dims.iz *
1535           (dims.ox * dims.oy * (dims.kx * dims.ky - 1) + dims.ix * dims.iy);
1536   } else {
1537     // Overlapping window: initialize with zeros, re-run maxpool, then
1538     // accumulate y_gad to proper x_grad locations.
1539     ops = dims.batch * dims.iz *
1540           (dims.ox * dims.oy * (dims.kx * dims.ky - 1) + dims.ix * dims.iy * 2);
1541   }
1542 
1543   // Just read x and y_grad; no need to read y as we assume MaxPoolGrad re-run
1544   // MaxPool internally.
1545   double total_input_size =
1546       CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
1547   total_input_size +=
1548       CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
1549   // Write x_grad; size equal to x.
1550   const double total_output_size =
1551       CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
1552 
1553   Costs costs = PredictOpCountBasedCost(ops, total_input_size,
1554                                         total_output_size, op_info);
1555   costs.inaccurate = found_unknown_shapes;
1556   costs.num_ops_with_unknown_shapes = found_unknown_shapes;
1557   costs.max_memory = total_output_size;
1558   return costs;
1559 }
1560 
PredictAvgPool(const OpContext & op_context) const1561 Costs OpLevelCostEstimator::PredictAvgPool(const OpContext& op_context) const {
1562   bool found_unknown_shapes = false;
1563   const auto& op_info = op_context.op_info;
1564   // x: op_info.inputs(0)
1565   ConvolutionDimensions dims = OpDimensionsFromInputs(
1566       op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
1567 
1568   // kx * ky - 1 additions and 1 multiplication per output.
1569   int64 ops = dims.batch * dims.ox * dims.oy * dims.oz * dims.kx * dims.ky;
1570 
1571   double total_input_size = 0;
1572   if (dims.ky >= dims.sy) {
1573     total_input_size =
1574         CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
1575   } else {  // dims.ky < dims.sy
1576     // vertical stride is larger than vertical kernel; assuming row-major
1577     // format, skip unnecessary rows (or read every kx rows per sy rows, as the
1578     // others are not used for output).
1579     const auto data_size = DataTypeSize(BaseType(op_info.inputs(0).dtype()));
1580     total_input_size =
1581         data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz;
1582   }
1583   const double total_output_size =
1584       CalculateOutputSize(op_info, &found_unknown_shapes);
1585 
1586   Costs costs = PredictOpCountBasedCost(ops, total_input_size,
1587                                         total_output_size, op_info);
1588   costs.inaccurate = found_unknown_shapes;
1589   costs.num_ops_with_unknown_shapes = found_unknown_shapes;
1590   costs.max_memory = total_output_size;
1591   return costs;
1592 }
1593 
PredictAvgPoolGrad(const OpContext & op_context) const1594 Costs OpLevelCostEstimator::PredictAvgPoolGrad(
1595     const OpContext& op_context) const {
1596   bool found_unknown_shapes = false;
1597   const auto& op_info = op_context.op_info;
1598   // x's shape: op_info.inputs(0)
1599   // y_grad: op_info.inputs(1)
1600 
1601   // Extract x_shape from op_info.inputs(0).value() or op_info.outputs(0).
1602   bool shape_found = false;
1603   TensorShapeProto x_shape;
1604   if (op_info.inputs_size() >= 1 && op_info.inputs(0).has_value()) {
1605     const TensorProto& value = op_info.inputs(0).value();
1606     shape_found = GetTensorShapeProtoFromTensorProto(value, &x_shape);
1607   }
1608   if (!shape_found && op_info.outputs_size() > 0) {
1609     x_shape = op_info.outputs(0).shape();
1610     shape_found = true;
1611   }
1612   if (!shape_found) {
1613     // Set the minimum shape that's feasible.
1614     x_shape.Clear();
1615     for (int i = 0; i < 4; ++i) {
1616       x_shape.add_dim()->set_size(1);
1617     }
1618     found_unknown_shapes = true;
1619   }
1620 
1621   ConvolutionDimensions dims =
1622       OpDimensionsFromInputs(x_shape, op_info, &found_unknown_shapes);
1623 
1624   int64 ops = 0;
1625   if (dims.kx <= dims.sx && dims.ky <= dims.sy) {
1626     // Non-overlapping window.
1627     ops = dims.batch * dims.iz * (dims.ix * dims.iy + dims.ox * dims.oy);
1628   } else {
1629     // Overlapping window.
1630     ops = dims.batch * dims.iz *
1631           (dims.ix * dims.iy + dims.ox * dims.oy * (dims.kx * dims.ky + 1));
1632   }
1633 
1634   const double total_input_size =
1635       CalculateInputSize(op_info, &found_unknown_shapes);
1636   const double total_output_size =
1637       CalculateOutputSize(op_info, &found_unknown_shapes);
1638 
1639   Costs costs = PredictOpCountBasedCost(ops, total_input_size,
1640                                         total_output_size, op_info);
1641   costs.inaccurate = found_unknown_shapes;
1642   costs.num_ops_with_unknown_shapes = found_unknown_shapes;
1643   costs.max_memory = total_output_size;
1644   return costs;
1645 }
1646 
PredictFusedBatchNorm(const OpContext & op_context) const1647 Costs OpLevelCostEstimator::PredictFusedBatchNorm(
1648     const OpContext& op_context) const {
1649   bool found_unknown_shapes = false;
1650   const auto& op_info = op_context.op_info;
1651   // x: op_info.inputs(0)
1652   // scale: op_info.inputs(1)
1653   // offset: op_info.inputs(2)
1654   // mean: op_info.inputs(3)  --> only for inference
1655   // variance: op_info.inputs(4) --> only for inference
1656   ConvolutionDimensions dims = OpDimensionsFromInputs(
1657       op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
1658   const bool is_training = IsTraining(op_info);
1659 
1660   int64 ops = 0;
1661   const auto rsqrt_cost = Eigen::internal::functor_traits<
1662       Eigen::internal::scalar_rsqrt_op<float>>::Cost;
1663   if (is_training) {
1664     ops = dims.iz * (dims.batch * dims.ix * dims.iy * 4 + 6 + rsqrt_cost);
1665   } else {
1666     ops = dims.batch * dims.ix * dims.iy * dims.iz * 2;
1667   }
1668 
1669   const double size_nhwc =
1670       CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
1671   const double size_c =
1672       CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
1673   double total_input_size = 0.0;
1674   double total_internal_read_size = 0.0;
1675   double total_output_size = 0.0;
1676   if (is_training) {
1677     total_input_size = size_nhwc + size_c * 2;
1678     total_output_size = size_nhwc + size_c * 4;
1679     total_internal_read_size = size_nhwc;
1680   } else {
1681     total_input_size = size_nhwc + size_c * 4;
1682     total_output_size = size_nhwc;
1683   }
1684 
1685   Costs costs =
1686       PredictOpCountBasedCost(ops, total_input_size + total_internal_read_size,
1687                               total_output_size, op_info);
1688   costs.inaccurate = found_unknown_shapes;
1689   costs.num_ops_with_unknown_shapes = found_unknown_shapes;
1690   costs.max_memory = total_output_size;
1691   return costs;
1692 }
1693 
PredictFusedBatchNormGrad(const OpContext & op_context) const1694 Costs OpLevelCostEstimator::PredictFusedBatchNormGrad(
1695     const OpContext& op_context) const {
1696   bool found_unknown_shapes = false;
1697   const auto& op_info = op_context.op_info;
1698   // y_backprop: op_info.inputs(0)
1699   // x: op_info.inputs(1)
1700   // scale: op_info.inputs(2)
1701   // mean: op_info.inputs(3)
1702   // variance or inverse of variance: op_info.inputs(4)
1703   ConvolutionDimensions dims = OpDimensionsFromInputs(
1704       op_info.inputs(1).shape(), op_info, &found_unknown_shapes);
1705 
1706   int64 ops = 0;
1707   const auto rsqrt_cost = Eigen::internal::functor_traits<
1708       Eigen::internal::scalar_rsqrt_op<float>>::Cost;
1709   ops = dims.iz * (dims.batch * dims.ix * dims.iy * 11 + 5 + rsqrt_cost);
1710 
1711   const double size_nhwc =
1712       CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
1713   const double size_c =
1714       CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
1715   double total_input_size = size_nhwc * 2 + size_c * 2;
1716   double total_internal_read_size = size_nhwc;
1717   double total_output_size = size_nhwc * 1 + size_c * 2;
1718 
1719   Costs costs =
1720       PredictOpCountBasedCost(ops, total_input_size + total_internal_read_size,
1721                               total_output_size, op_info);
1722   costs.inaccurate = found_unknown_shapes;
1723   costs.num_ops_with_unknown_shapes = found_unknown_shapes;
1724   costs.max_memory = total_output_size;
1725   return costs;
1726 }
1727 
1728 /* static */
CombineCostsAndUpdateExecutionTime(Costs * costs) const1729 void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime(
1730     Costs* costs) const {
1731   if (compute_memory_overlap_) {
1732     costs->execution_time =
1733         std::max(costs->intermediate_memory_time,
1734                  std::max(costs->compute_time, costs->memory_time));
1735   } else {
1736     costs->execution_time = costs->compute_time + costs->memory_time +
1737                             costs->intermediate_memory_time;
1738   }
1739 }
1740 }  // end namespace grappler
1741 }  // end namespace tensorflow
1742