1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
17 #include "tensorflow/core/framework/attr_value.pb.h"
18 #include "tensorflow/core/framework/attr_value_util.h"
19 #include "tensorflow/core/framework/tensor.h"
20 #include "tensorflow/core/framework/tensor.pb.h"
21 #include "tensorflow/core/framework/tensor_shape.h"
22 #include "tensorflow/core/framework/tensor_shape.pb.h"
23 #include "tensorflow/core/framework/types.h"
24 #include "tensorflow/core/platform/test.h"
25 #include "tensorflow/core/protobuf/device_properties.pb.h"
26 
27 namespace tensorflow {
28 namespace grappler {
29 
30 namespace {
31 
// TODO(dyoon): Consider to use this Test class for all the test cases, and then
// remove friend in the OpLevelCostEstimator class header.
// Estimator subclass that lets tests inject the DeviceInfo returned by
// GetDeviceInfo() and toggle the compute/memory-overlap flag directly.
class TestOpLevelCostEstimator : public OpLevelCostEstimator {
 public:
  TestOpLevelCostEstimator() {
    // Defaults: overlap modeling on, default-constructed DeviceInfo until
    // SetDeviceInfo() is called.
    compute_memory_overlap_ = true;
    device_info_ = DeviceInfo();
  }
  ~TestOpLevelCostEstimator() override {}

  // Replaces the DeviceInfo that GetDeviceInfo() will report.
  void SetDeviceInfo(const DeviceInfo& device_info) {
    device_info_ = device_info;
  }

  // Sets the estimator's compute_memory_overlap_ flag.
  void SetComputeMemoryOverlap(bool value) { compute_memory_overlap_ = value; }

 protected:
  // Ignores `device` and serves the injected device_info_ instead.
  DeviceInfo GetDeviceInfo(const DeviceProperties& device) const override {
    return device_info_;
  }

  // Device characteristics handed back to the estimator under test.
  DeviceInfo device_info_;
};
55 
56 // Wrangles the minimum number of proto fields to set up a matrix.
DescribeMatrix(int rows,int columns,OpInfo * op_info)57 void DescribeMatrix(int rows, int columns, OpInfo* op_info) {
58   auto input = op_info->add_inputs();
59   auto shape = input->mutable_shape();
60   auto shape_rows = shape->add_dim();
61   shape_rows->set_size(rows);
62   auto shape_columns = shape->add_dim();
63   shape_columns->set_size(columns);
64   input->set_dtype(DT_FLOAT);
65 }
66 
SetCpuDevice(OpInfo * op_info)67 void SetCpuDevice(OpInfo* op_info) {
68   auto device = op_info->mutable_device();
69   device->set_type("CPU");
70   device->set_num_cores(10);
71   device->set_bandwidth(10000000);  // 10000000 KB/s = 10 GB/s
72   device->set_frequency(1000);      // 1000 Mhz = 1 GHz
73 }
74 
75 // Returns an OpInfo for MatMul with the minimum set of fields set up.
DescribeMatMul(int m,int n,int l,int k)76 OpContext DescribeMatMul(int m, int n, int l, int k) {
77   OpContext op_context;
78   SetCpuDevice(&op_context.op_info);
79   op_context.op_info.set_op("MatMul");
80 
81   DescribeMatrix(m, l, &op_context.op_info);
82   DescribeMatrix(k, n, &op_context.op_info);
83   return op_context;
84 }
85 
86 // Wrangles the minimum number of proto fields to set up an input of
87 // arbitrary rank and type.
DescribeArbitraryRankInput(const std::vector<int> & dims,DataType dtype,OpInfo * op_info)88 void DescribeArbitraryRankInput(const std::vector<int>& dims, DataType dtype,
89                                 OpInfo* op_info) {
90   auto input = op_info->add_inputs();
91   input->set_dtype(dtype);
92   auto shape = input->mutable_shape();
93   for (auto d : dims) {
94     shape->add_dim()->set_size(d);
95   }
96 }
97 
98 // Wrangles the minimum number of proto fields to set up an output of
99 // arbitrary rank and type.
DescribeArbitraryRankOutput(const std::vector<int> & dims,DataType dtype,OpInfo * op_info)100 void DescribeArbitraryRankOutput(const std::vector<int>& dims, DataType dtype,
101                                  OpInfo* op_info) {
102   auto output = op_info->add_outputs();
103   output->set_dtype(dtype);
104   auto shape = output->mutable_shape();
105   for (auto d : dims) {
106     shape->add_dim()->set_size(d);
107   }
108 }
109 
110 // Returns an OpInfo for a BatchMatMul
DescribeBatchMatMul(const std::vector<int> & dims_a,const std::vector<int> & dims_b)111 OpContext DescribeBatchMatMul(const std::vector<int>& dims_a,
112                               const std::vector<int>& dims_b) {
113   OpContext op_context;
114   SetCpuDevice(&op_context.op_info);
115   op_context.op_info.set_op("BatchMatMul");
116 
117   DescribeArbitraryRankInput(dims_a, DT_FLOAT, &op_context.op_info);
118   DescribeArbitraryRankInput(dims_b, DT_FLOAT, &op_context.op_info);
119   return op_context;
120 }
121 
122 // Returns an OpInfo for a SparseTensorDenseMatMul
DescribeSparseTensorDenseMatMul(const int nnz_a,const std::vector<int> & dims_b,const std::vector<int> & dims_out)123 OpContext DescribeSparseTensorDenseMatMul(const int nnz_a,
124                                           const std::vector<int>& dims_b,
125                                           const std::vector<int>& dims_out) {
126   OpContext op_context;
127   SetCpuDevice(&op_context.op_info);
128   op_context.op_info.set_op("SparseTensorDenseMatMul");
129 
130   DescribeArbitraryRankInput({nnz_a, 2}, DT_INT64, &op_context.op_info);
131   DescribeArbitraryRankInput({nnz_a}, DT_FLOAT, &op_context.op_info);
132   DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
133   DescribeArbitraryRankInput(dims_b, DT_FLOAT, &op_context.op_info);
134   DescribeArbitraryRankOutput(dims_out, DT_FLOAT, &op_context.op_info);
135   return op_context;
136 }
137 
138 // Wrangles the minimum number of proto fields to set up a 1D Tensor for cost
139 // estimation purposes.
DescribeTensor1D(int dim0,OpInfo::TensorProperties * tensor)140 void DescribeTensor1D(int dim0, OpInfo::TensorProperties* tensor) {
141   auto shape = tensor->mutable_shape();
142   shape->add_dim()->set_size(dim0);
143   tensor->set_dtype(DT_FLOAT);
144 }
145 
146 // Wrangles the minimum number of proto fields to set up a 4D Tensor for cost
147 // estimation purposes.
DescribeTensor4D(int dim0,int dim1,int dim2,int dim3,OpInfo::TensorProperties * tensor)148 void DescribeTensor4D(int dim0, int dim1, int dim2, int dim3,
149                       OpInfo::TensorProperties* tensor) {
150   auto shape = tensor->mutable_shape();
151   shape->add_dim()->set_size(dim0);
152   shape->add_dim()->set_size(dim1);
153   shape->add_dim()->set_size(dim2);
154   shape->add_dim()->set_size(dim3);
155   tensor->set_dtype(DT_FLOAT);
156 }
157 
158 // DescribeConvolution constructs an OpContext for a Conv2D applied to an input
159 // tensor with shape (batch, ix, iy, iz1) and a kernel tensor with shape
160 // (kx, ky, iz2, oz).
DescribeConvolution(int batch,int ix,int iy,int iz1,int iz2,int kx,int ky,int oz)161 OpContext DescribeConvolution(int batch, int ix, int iy, int iz1, int iz2,
162                               int kx, int ky, int oz) {
163   OpContext op_context;
164   SetCpuDevice(&op_context.op_info);
165   op_context.op_info.set_op("Conv2D");
166 
167   DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
168   DescribeTensor4D(kx, ky, iz2, oz, op_context.op_info.add_inputs());
169 
170   return op_context;
171 }
172 
173 // Describe DepthwiseConvolution constructs an OpContext for a
174 // DepthwiseConv2dNative applied to an input
175 // tensor with shape (batch, ix, iy, iz1) and a kernel tensor with shape
176 // (kx, ky, iz2, cm). cm is channel multiplier
177 
DescribeDepthwiseConv2dNative(int batch,int ix,int iy,int iz1,int iz2,int kx,int ky,int cm)178 OpContext DescribeDepthwiseConv2dNative(int batch, int ix, int iy, int iz1,
179                                         int iz2, int kx, int ky, int cm) {
180   OpContext op_context;
181   SetCpuDevice(&op_context.op_info);
182   op_context.op_info.set_op("DepthwiseConv2dNative");
183 
184   DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
185   DescribeTensor4D(kx, ky, iz2, cm, op_context.op_info.add_inputs());
186 
187   return op_context;
188 }
189 
190 // DescribeFusedConv2DBiasActivation constructs an OpContext for a
191 // FusedConv2DBiasActivation applied to a convolution input tensor with shape
192 // (batch, ix, iy, iz1), a kernel tensor with shape (kx, ky, iz2, oz), a
193 // bias tensor with shape (oz), a side input tensor with shape
194 // (batch, ox, oy, oz) if has_side_input is set, and two scaling tensors with
195 // shape (1).
196 //
197 // Note that this assumes the NHWC data format.
DescribeFusedConv2DBiasActivation(int batch,int ix,int iy,int iz1,int iz2,int kx,int ky,int ox,int oy,int oz,bool has_side_input,const string & data_format,const string & filter_format)198 OpContext DescribeFusedConv2DBiasActivation(int batch, int ix, int iy, int iz1,
199                                             int iz2, int kx, int ky, int ox,
200                                             int oy, int oz, bool has_side_input,
201                                             const string& data_format,
202                                             const string& filter_format) {
203   OpContext op_context;
204   SetCpuDevice(&op_context.op_info);
205   op_context.op_info.set_op("FusedConv2DBiasActivation");
206   auto* attr_data_format = op_context.op_info.mutable_attr();
207   SetAttrValue(data_format, &(*attr_data_format)["data_format"]);
208   auto* attr_filter_format = op_context.op_info.mutable_attr();
209   SetAttrValue(filter_format, &(*attr_filter_format)["filter_format"]);
210   if (data_format == "NHWC") {
211     DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
212   } else {
213     // Use the NCHW format.
214     DescribeTensor4D(batch, iz1, ix, iy, op_context.op_info.add_inputs());
215   }
216   if (filter_format == "HWIO") {
217     DescribeTensor4D(kx, ky, iz2, oz, op_context.op_info.add_inputs());
218   } else {
219     // Use the OIHW format.
220     DescribeTensor4D(oz, iz2, kx, ky, op_context.op_info.add_inputs());
221   }
222   DescribeTensor1D(oz, op_context.op_info.add_inputs());
223 
224   // Add the side_input, if any.
225   auto side_input = op_context.op_info.add_inputs();
226   if (has_side_input) {
227     if (data_format == "NHWC") {
228       DescribeTensor4D(batch, ox, oy, oz, side_input);
229     } else {
230       DescribeTensor4D(batch, oz, ox, oy, side_input);
231     }
232   }
233 
234   // Add the scaling tensors.
235   DescribeTensor1D(1, op_context.op_info.add_inputs());
236   DescribeTensor1D(1, op_context.op_info.add_inputs());
237 
238   return op_context;
239 }
240 
241 // DescribeUnaryOp constructs an OpContext for the given operation applied to
242 // a 4-tensor with shape (size1, 1, 1, 1).
DescribeUnaryOp(const string & op,int size1)243 OpContext DescribeUnaryOp(const string& op, int size1) {
244   OpContext op_context;
245   SetCpuDevice(&op_context.op_info);
246   op_context.op_info.set_op(op);
247 
248   DescribeTensor4D(size1, 1, 1, 1, op_context.op_info.add_inputs());
249   DescribeTensor4D(size1, 1, 1, 1, op_context.op_info.add_outputs());
250 
251   return op_context;
252 }
253 
254 // DescribeBinaryOp constructs an OpContext for the given operation applied to
255 // a 4-tensor with dimensions (size1, 1, 1, 1) and a 4-tensor with dimensions
256 // (2 * size1, size2, 1, 1).
257 //
258 // The choice of dimension here is arbitrary, and is used strictly to test the
259 // cost model for applying elementwise operations to tensors with unequal
260 // dimension values.
DescribeBinaryOp(const string & op,int size1,int size2)261 OpContext DescribeBinaryOp(const string& op, int size1, int size2) {
262   OpContext op_context;
263   SetCpuDevice(&op_context.op_info);
264   op_context.op_info.set_op(op);
265 
266   DescribeTensor4D(size1, 1, 1, 1, op_context.op_info.add_inputs());
267   DescribeTensor4D(2 * size1, size2, 1, 1, op_context.op_info.add_inputs());
268   DescribeTensor4D(2 * size1, size2, 1, 1, op_context.op_info.add_outputs());
269 
270   return op_context;
271 }
272 
273 // DescribeBiasAdd constructs an OpContext for a BiasAdd applied to a 4-tensor
274 // with dimensions (1, 1, size2, size1) and a bias with dimension (size1),
275 // according to the constraint that the bias must be 1D with size equal to that
276 // of the last dimension of the input value.
DescribeBiasAdd(int size1,int size2)277 OpContext DescribeBiasAdd(int size1, int size2) {
278   OpContext op_context;
279   SetCpuDevice(&op_context.op_info);
280   op_context.op_info.set_op("BiasAdd");
281 
282   DescribeTensor4D(1, 1, size2, size1, op_context.op_info.add_inputs());
283   DescribeTensor1D(size1, op_context.op_info.add_inputs());
284   DescribeTensor4D(1, 1, size2, size1, op_context.op_info.add_outputs());
285 
286   return op_context;
287 }
288 
GetOutputSize(const int x,const int k,const int s,const string & padding)289 int GetOutputSize(const int x, const int k, const int s,
290                   const string& padding) {
291   if (padding == "SAME") {
292     return (x + s - 1) / s;
293   } else {
294     return (x - k + s) / s;
295   }
296 }
297 
// Computes the expected output shape of a pooling op given its input shape,
// kernel sizes, strides, data format ("NHWC"/"NCHW"), and padding
// ("SAME"/"VALID"). Shapes/ksize/strides are in the given data format.
std::vector<int> GetPoolingOutputSize(const std::vector<int>& input,
                                      const std::vector<int>& ksize,
                                      const std::vector<int>& strides,
                                      const string& data_format,
                                      const string& padding) {
  // Spatial/channel axis positions; NHWC unless explicitly NCHW.
  const bool nchw = (data_format == "NCHW");
  const int h_index = nchw ? 2 : 1;
  const int w_index = nchw ? 3 : 2;
  const int c_index = nchw ? 1 : 3;

  // Pull the relevant parameters out of the format-dependent slots.
  const int n = input[0];
  const int h = input[h_index];
  const int w = input[w_index];
  const int c = input[c_index];
  const int sx = strides[h_index];
  const int sy = strides[w_index];
  const int kx = ksize[h_index];
  const int ky = ksize[w_index];

  // Output spatial extents per the padding scheme.
  const int ho = GetOutputSize(h, kx, sx, padding);
  const int wo = GetOutputSize(w, ky, sy, padding);

  if (data_format == "NHWC") {
    return {n, ho, wo, c};
  }
  return {n, c, ho, wo};
}
334 
335 // Helper functions for testing GetTensorShapeProtoFromTensorProto().
GetTensorProto(const DataType dtype,const std::vector<int64> & shape,const std::vector<int64> values,const bool tensor_content,TensorProto * tensor_proto)336 void GetTensorProto(const DataType dtype, const std::vector<int64>& shape,
337                     const std::vector<int64> values, const bool tensor_content,
338                     TensorProto* tensor_proto) {
339   tensor_proto->Clear();
340   TensorProto temp_tensor_proto;
341   temp_tensor_proto.set_dtype(dtype);
342   for (const auto& x : shape) {
343     temp_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(x);
344   }
345   for (const auto& x : values) {
346     if (dtype == DT_INT64) {
347       temp_tensor_proto.add_int64_val(x);
348     } else if (dtype == DT_INT32 || dtype == DT_INT16 || dtype == DT_INT8 ||
349                dtype == DT_UINT8) {
350       temp_tensor_proto.add_int_val(x);
351     } else if (dtype == DT_UINT32) {
352       temp_tensor_proto.add_uint32_val(x);
353     } else if (dtype == DT_UINT64) {
354       temp_tensor_proto.add_uint64_val(x);
355     } else {
356       CHECK(false) << "Unsupported dtype: " << dtype;
357     }
358   }
359   Tensor tensor(dtype);
360   CHECK(tensor.FromProto(temp_tensor_proto));
361   if (tensor_content) {
362     tensor.AsProtoTensorContent(tensor_proto);
363   } else {
364     tensor.AsProtoField(tensor_proto);
365   }
366 }
367 
DescribePoolingOp(const string & op_name,const std::vector<int> & x,const std::vector<int> & ksize,const std::vector<int> & strides,const string & data_format,const string & padding)368 OpContext DescribePoolingOp(const string& op_name, const std::vector<int>& x,
369                             const std::vector<int>& ksize,
370                             const std::vector<int>& strides,
371                             const string& data_format, const string& padding) {
372   OpContext op_context;
373   auto& op_info = op_context.op_info;
374   SetCpuDevice(&op_info);
375   op_info.set_op(op_name);
376 
377   const std::vector<int> y =
378       GetPoolingOutputSize(x, ksize, strides, data_format, padding);
379   if (op_name == "AvgPool" || op_name == "MaxPool") {
380     // input: x, output: y.
381     DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
382     DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_outputs());
383   } else if (op_name == "AvgPoolGrad") {
384     // input: x's shape, y_grad, output: x_grad.
385     DescribeArbitraryRankInput({4}, DT_INT32, &op_info);
386     auto* tensor_proto = op_info.mutable_inputs(0)->mutable_value();
387     GetTensorProto(DT_INT32, {4}, {x[0], x[1], x[2], x[3]},
388                    /*tensor_content=*/false, tensor_proto);
389     DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs());
390     DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_outputs());
391   } else if (op_name == "MaxPoolGrad") {
392     // input: x, y, y_grad, output: x_grad.
393     DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
394     DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs());
395     DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs());
396     DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_outputs());
397   }
398   auto* attr = op_info.mutable_attr();
399   SetAttrValue(data_format, &(*attr)["data_format"]);
400   SetAttrValue(padding, &(*attr)["padding"]);
401   SetAttrValue(strides, &(*attr)["strides"]);
402   SetAttrValue(ksize, &(*attr)["ksize"]);
403   return op_context;
404 }
405 
DescribeFusedBatchNorm(const bool is_training,const bool is_grad,const std::vector<int> & x,const string & data_format)406 OpContext DescribeFusedBatchNorm(const bool is_training, const bool is_grad,
407                                  const std::vector<int>& x,
408                                  const string& data_format) {
409   // First, get MaxPool op info with unit stride and unit window.
410   OpContext op_context = DescribePoolingOp("MaxPool", x, {1, 1, 1, 1},
411                                            {1, 1, 1, 1}, data_format, "SAME");
412   auto& op_info = op_context.op_info;
413   // Override op name.
414   if (is_grad) {
415     op_info.set_op("FusedBatchNormGrad");
416   } else {
417     op_info.set_op("FusedBatchNorm");
418   }
419 
420   // Add additional input output tensors.
421   if (is_grad) {
422     DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
423   }
424   int num_1d_inputs = is_grad ? 3 : 4;
425   for (int i = 0; i < num_1d_inputs; i++) {
426     auto* tensor = op_info.add_inputs();
427     auto* shape = tensor->mutable_shape();
428     shape->add_dim()->set_size(x[3]);
429     tensor->set_dtype(DT_FLOAT);
430   }
431   for (int i = 0; i < 4; i++) {
432     auto* tensor = op_info.add_outputs();
433     auto* shape = tensor->mutable_shape();
434     shape->add_dim()->set_size(x[3]);
435     tensor->set_dtype(DT_FLOAT);
436   }
437 
438   // Delete unnecessary attr.
439   auto* attr = op_context.op_info.mutable_attr();
440   attr->erase("ksize");
441   attr->erase("strides");
442   attr->erase("padding");
443 
444   // Additional attrs for FusedBatchNorm.
445   SetAttrValue(is_training, &(*attr)["is_training"]);
446 
447   return op_context;
448 }
449 }  // namespace
450 
// Test fixture exposing OpLevelCostEstimator internals (operation counters and
// the compute_memory_overlap_ flag) to the tests below.
class OpLevelCostEstimatorTest : public ::testing::Test {
 protected:
  // Runs the estimator under test on the given op description.
  Costs PredictCosts(const OpContext& op_context) const {
    return estimator_.PredictCosts(op_context);
  }

  // Forwards to the estimator's MatMul op counter.
  int64 CountMatMulOperations(const OpInfo& op_info,
                              bool* found_unknown_shapes) const {
    return estimator_.CountMatMulOperations(op_info, found_unknown_shapes);
  }

  // Forwards to the estimator's BatchMatMul op counter.
  int64 CountBatchMatMulOperations(const OpInfo& op_info,
                                   bool* found_unknown_shapes) const {
    return estimator_.CountBatchMatMulOperations(op_info, found_unknown_shapes);
  }

  // Directly toggles the estimator's compute_memory_overlap_ flag.
  void SetComputeMemoryOverlap(bool value) {
    estimator_.compute_memory_overlap_ = value;
  }

  // Builds a MaxPool op with the given parameters and checks that
  // OpDimensionsFromInputs() recovers every dimension, stride, kernel size,
  // output size, and the padding enum from the op description.
  // NOTE(review): "Imputs" is a typo for "Inputs"; kept as-is since renaming
  // would break callers outside this view.
  void ValidateOpDimensionsFromImputs(const int n, const int h, const int w,
                                      const int c, const int kx, const int ky,
                                      const int sx, const int sy,
                                      const string& data_format,
                                      const string& padding) {
    OpContext op_context;
    int ho;
    int wo;
    // ksize/strides slots depend on the data format; read the expected output
    // spatial sizes back from the described op.
    if (data_format == "NHWC") {
      op_context = DescribePoolingOp("MaxPool", {n, h, w, c}, {1, kx, ky, 1},
                                     {1, sx, sy, 1}, "NHWC", padding);
      ho = op_context.op_info.outputs(0).shape().dim(1).size();
      wo = op_context.op_info.outputs(0).shape().dim(2).size();
    } else {
      op_context = DescribePoolingOp("MaxPool", {n, c, h, w}, {1, 1, kx, ky},
                                     {1, 1, sx, sy}, "NCHW", padding);
      ho = op_context.op_info.outputs(0).shape().dim(2).size();
      wo = op_context.op_info.outputs(0).shape().dim(3).size();
    }

    bool found_unknown_shapes;
    auto dims = OpLevelCostEstimator::OpDimensionsFromInputs(
        op_context.op_info.inputs(0).shape(), op_context.op_info,
        &found_unknown_shapes);
    Padding padding_enum;
    if (padding == "VALID") {
      padding_enum = Padding::VALID;
    } else {
      padding_enum = Padding::SAME;
    }
    EXPECT_EQ(n, dims.batch);
    EXPECT_EQ(h, dims.ix);
    EXPECT_EQ(w, dims.iy);
    EXPECT_EQ(c, dims.iz);
    EXPECT_EQ(kx, dims.kx);
    EXPECT_EQ(ky, dims.ky);
    EXPECT_EQ(sx, dims.sx);
    EXPECT_EQ(sy, dims.sy);
    EXPECT_EQ(ho, dims.ox);
    EXPECT_EQ(wo, dims.oy);
    EXPECT_EQ(c, dims.oz);
    EXPECT_EQ(padding_enum, dims.padding);
  }

  // Estimator instance exercised by all tests in this fixture.
  OpLevelCostEstimator estimator_;
};
517 
// Persistent ops (constants/variables) should all report the minimum cost:
// no memory time and one unit of compute time. (Fixes typos: "persisent" ->
// "persistent", "Minmum" -> "Minimum".)
TEST_F(OpLevelCostEstimatorTest, TestPersistentOpCosts) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  std::unordered_set<string> persistent_ops = {
      "Const",       "Variable",       "VariableV2", "AutoReloadVariable",
      "VarHandleOp", "ReadVariableOp",
  };
  // Minimum cost for all persistent ops.
  for (const auto& op : persistent_ops) {
    op_context.op_info.set_op(op);
    auto cost = estimator_.PredictCosts(op_context);
    EXPECT_EQ(Costs::Duration(0), cost.memory_time);
    EXPECT_EQ(Costs::Duration(1), cost.compute_time);
    EXPECT_EQ(Costs::Duration(1), cost.execution_time);
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  }
}
537 
// Gather's cost must be driven by the indices/output sizes, not by the size of
// the (huge) params tensor; golden durations pin that behavior.
TEST_F(OpLevelCostEstimatorTest, TestGatherCosts) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("Gather");

  // Huge first input shouldn't affect Gather execution and memory costs.
  DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
  DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info);
  DescribeArbitraryRankOutput({16, 10}, DT_FLOAT, &op_context.op_info);

  auto cost = estimator_.PredictCosts(op_context);
  EXPECT_EQ(Costs::Duration(130), cost.memory_time);
  EXPECT_EQ(Costs::Duration(16), cost.compute_time);
  EXPECT_EQ(Costs::Duration(146), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
}
556 
// With no output described, Gather's cost is unknown: all durations are zero
// and the estimate is flagged inaccurate.
TEST_F(OpLevelCostEstimatorTest, TestGatherCostsWithoutOutput) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("Gather");

  // Huge first input shouldn't affect Gather execution and memory costs.
  DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
  DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info);

  auto cost = estimator_.PredictCosts(op_context);
  EXPECT_EQ(Costs::Duration(0), cost.memory_time);
  EXPECT_EQ(Costs::Duration(0), cost.compute_time);
  EXPECT_EQ(Costs::Duration(0), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_TRUE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
}
574 
// Slice's cost must be driven by the slice (begin/size/output) description,
// not by the size of the huge input tensor.
TEST_F(OpLevelCostEstimatorTest, TestSliceCosts) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("Slice");

  // Huge first input shouldn't affect Slice execution and memory costs.
  DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
  DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
  DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
  DescribeArbitraryRankOutput({10, 10}, DT_FLOAT, &op_context.op_info);

  auto cost = estimator_.PredictCosts(op_context);
  EXPECT_EQ(Costs::Duration(81), cost.memory_time);
  EXPECT_EQ(Costs::Duration(10), cost.compute_time);
  EXPECT_EQ(Costs::Duration(91), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
}
594 
// Golden-value check of BiasAdd cost on a (1, 1, 10, 1000) input with a
// 1000-element bias.
TEST_F(OpLevelCostEstimatorTest, BiasAddExecutionTime) {
  auto cost = PredictCosts(DescribeBiasAdd(1000, 10));
  EXPECT_EQ(Costs::Duration(8400), cost.memory_time);
  EXPECT_EQ(Costs::Duration(1000), cost.compute_time);
  EXPECT_EQ(Costs::Duration(9400), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
}
604 
// Golden-value check of Conv2D cost: 16x19x19x48 input, 5x5x48x256 kernel.
TEST_F(OpLevelCostEstimatorTest, Conv2DExecutionTime) {
  auto cost = PredictCosts(DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
  EXPECT_EQ(Costs::Duration(233780), cost.memory_time);
  EXPECT_EQ(Costs::Duration(354877440), cost.compute_time);
  EXPECT_EQ(Costs::Duration(355111220), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
}
614 
// Golden-value check of DepthwiseConv2dNative cost: 16x19x19x48 input,
// 5x5x48 kernel with channel multiplier 3.
TEST_F(OpLevelCostEstimatorTest, DepthwiseConv2dNativeExecutionTime) {
  auto cost =
      PredictCosts(DescribeDepthwiseConv2dNative(16, 19, 19, 48, 48, 5, 5, 3));
  EXPECT_EQ(Costs::Duration(112340), cost.memory_time);
  EXPECT_EQ(Costs::Duration(4158720), cost.compute_time);
  EXPECT_EQ(Costs::Duration(4271060), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
}
625 
// An unknown op ("Dummy") gets a memory-only estimate, zero compute time, and
// is flagged inaccurate.
TEST_F(OpLevelCostEstimatorTest, DummyExecutionTime) {
  auto cost = PredictCosts(DescribeBinaryOp("Dummy", 1000, 1));
  EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
  EXPECT_EQ(Costs::Duration(0), cost.compute_time);
  EXPECT_EQ(Costs::Duration(2000), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_TRUE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
}
635 
// With compute_memory_overlap_ enabled, execution time is the max of compute
// and memory time rather than their sum (same op as DummyExecutionTime).
TEST_F(OpLevelCostEstimatorTest, ExecutionTimeSumOrMax) {
  SetComputeMemoryOverlap(true);
  auto cost = PredictCosts(DescribeBinaryOp("Dummy", 1000, 1));
  EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
  EXPECT_EQ(Costs::Duration(0), cost.compute_time);
  EXPECT_EQ(Costs::Duration(2000), cost.execution_time);  // max(2000, 200)
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_TRUE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  SetComputeMemoryOverlap(false);  // Set it back to default.
}
647 
// Golden-value check: NCHW data / HWIO filter, no side input. Compare with the
// has_side_input variants below — omitting the side input lowers both memory
// and compute time.
TEST_F(OpLevelCostEstimatorTest,
       FusedConv2DBiasActivationNCHW_HWIO_NoSideInput) {
  auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ false,
      "NCHW", "HWIO"));
  EXPECT_EQ(Costs::Duration(825345), cost.memory_time);
  EXPECT_EQ(Costs::Duration(355321038), cost.compute_time);
  EXPECT_EQ(Costs::Duration(356146383), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
}
660 
// Golden-value check: NCHW data / HWIO filter, with side input.
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_HWIO) {
  auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NCHW", "HWIO"));
  EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
  EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
  EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
}
672 
// Golden-value check: NCHW data / OIHW filter, with side input. Same expected
// costs as the other supported layout combinations — layout should not change
// the estimate.
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_OIHW) {
  auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NCHW", "OIHW"));
  EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
  EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
  EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
}
684 
// Golden-value check: NHWC data / HWIO filter, with side input.
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNHWC_HWIO) {
  auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NHWC", "HWIO"));
  EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
  EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
  EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
}
696 
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNHWC_OIHW) {
  // NHWC data format with OIHW filters; costs match the other supported
  // format combinations.
  const auto op_context = DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NHWC", "OIHW");
  const auto cost = PredictCosts(op_context);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
  EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
  // Execution time is the sum of compute and memory time.
  EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
}
708 
// TODO(yaozhang): Update once NCHW_VECT_C is supported.
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_VECT_C_OIHW) {
  auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NCHW_VECT_C", "OIHW"));
  // Unsupported data format: the estimator returns zero durations and
  // flags the estimate as inaccurate.
  EXPECT_EQ(Costs::Duration(0), cost.memory_time);
  EXPECT_EQ(Costs::Duration(0), cost.compute_time);
  EXPECT_EQ(Costs::Duration(0), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_TRUE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
}
721 
// TODO(yaozhang): Update once OIHW_VECT_I is supported.
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_OIHW_VECT_I) {
  auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NCHW", "OIHW_VECT_I"));
  // Unsupported filter format: the estimator returns zero durations and
  // flags the estimate as inaccurate.
  EXPECT_EQ(Costs::Duration(0), cost.memory_time);
  EXPECT_EQ(Costs::Duration(0), cost.compute_time);
  EXPECT_EQ(Costs::Duration(0), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_TRUE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
}
734 
TEST_F(OpLevelCostEstimatorTest, MulExecutionTime) {
  // Element-wise Mul over 1000 elements without broadcasting.
  const auto costs = PredictCosts(DescribeBinaryOp("Mul", 1000, 1));
  EXPECT_EQ(1, costs.num_ops_total);
  EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  EXPECT_FALSE(costs.inaccurate);
  EXPECT_EQ(Costs::Duration(200), costs.compute_time);
  EXPECT_EQ(Costs::Duration(2000), costs.memory_time);
  // Execution time is the sum of compute and memory time.
  EXPECT_EQ(Costs::Duration(2200), costs.execution_time);
}
744 
TEST_F(OpLevelCostEstimatorTest, MulBroadcastExecutionTime) {
  // Mul with broadcasting (second DescribeBinaryOp argument differs from
  // the non-broadcast case above); both compute and memory costs grow.
  const auto costs = PredictCosts(DescribeBinaryOp("Mul", 1000, 2));
  EXPECT_EQ(1, costs.num_ops_total);
  EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  EXPECT_FALSE(costs.inaccurate);
  EXPECT_EQ(Costs::Duration(400), costs.compute_time);
  EXPECT_EQ(Costs::Duration(3600), costs.memory_time);
  // Execution time is the sum of compute and memory time.
  EXPECT_EQ(Costs::Duration(4000), costs.execution_time);
}
754 
TEST_F(OpLevelCostEstimatorTest, ModExecutionTime) {
  // Mod over 1000 elements; per-element compute cost is higher than Mul
  // (1600 vs. 200 here) while memory cost is identical.
  const auto costs = PredictCosts(DescribeBinaryOp("Mod", 1000, 1));
  EXPECT_EQ(1, costs.num_ops_total);
  EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  EXPECT_FALSE(costs.inaccurate);
  EXPECT_EQ(Costs::Duration(1600), costs.compute_time);
  EXPECT_EQ(Costs::Duration(2000), costs.memory_time);
  // Execution time is the sum of compute and memory time.
  EXPECT_EQ(Costs::Duration(3600), costs.execution_time);
}
764 
TEST_F(OpLevelCostEstimatorTest, ReluExecutionTime) {
  // Unary Relu over 1000 elements.
  const auto costs = PredictCosts(DescribeUnaryOp("Relu", 1000));
  EXPECT_EQ(1, costs.num_ops_total);
  EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  EXPECT_FALSE(costs.inaccurate);
  EXPECT_EQ(Costs::Duration(100), costs.compute_time);
  EXPECT_EQ(Costs::Duration(800), costs.memory_time);
  // Execution time is the sum of compute and memory time.
  EXPECT_EQ(Costs::Duration(900), costs.execution_time);
}
774 
TEST_F(OpLevelCostEstimatorTest, CastExecutionTime) {
  // Unary Cast over 1000 elements; costed identically to Relu.
  const auto costs = PredictCosts(DescribeUnaryOp("Cast", 1000));
  EXPECT_EQ(1, costs.num_ops_total);
  EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  EXPECT_FALSE(costs.inaccurate);
  EXPECT_EQ(Costs::Duration(100), costs.compute_time);
  EXPECT_EQ(Costs::Duration(800), costs.memory_time);
  // Execution time is the sum of compute and memory time.
  EXPECT_EQ(Costs::Duration(900), costs.execution_time);
}
784 
TEST_F(OpLevelCostEstimatorTest, UnknownOrPartialShape) {
  // Any -1 dimension in an input shape should flag the estimate as
  // inaccurate and count the op in num_ops_with_unknown_shapes.
  {
    // Fully known MatMul shapes: accurate.
    auto cost = PredictCosts(DescribeMatMul(2, 4, 7, 7));
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  }
  {
    // Unknown dimension in the first MatMul input.
    auto cost = PredictCosts(DescribeMatMul(-1, 4, 7, 7));
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
  }
  {
    // Unknown dimension in the second MatMul input.
    auto cost = PredictCosts(DescribeMatMul(2, 4, -1, 7));
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
  }
  {
    // Fully known convolution shapes: accurate.
    auto cost =
        PredictCosts(DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  }
  {
    // Unknown spatial dimension in the convolution input.
    auto cost =
        PredictCosts(DescribeConvolution(16, -1, 19, 48, 48, 5, 5, 256));
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
  }
}
819 
TEST_F(OpLevelCostEstimatorTest, BatchMatMul) {
  // Empty / partially-empty input shapes: flagged inaccurate.
  {
    auto cost = PredictCosts(DescribeBatchMatMul({}, {}));
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
  }
  {
    auto cost = PredictCosts(DescribeBatchMatMul({2, 4}, {}));
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
  }
  // Fully known shapes, including mismatched ranks: accurate.
  {
    auto cost = PredictCosts(DescribeBatchMatMul({2, 4}, {4, 2}));
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  }
  {
    auto cost = PredictCosts(DescribeBatchMatMul({1, 2, 4}, {1, 4, 2}));
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  }
  {
    auto cost = PredictCosts(DescribeBatchMatMul({2, 4}, {1, 3, 4, 2}));
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  }
  // A batch-of-one BatchMatMul counts the same operations as the
  // equivalent plain MatMul, and both agree on accuracy.
  bool matmul_inaccurate = false;
  bool batch_matmul_inaccurate = false;
  EXPECT_EQ(
      CountMatMulOperations(DescribeMatMul(2, 2, 4, 4).op_info,
                            &matmul_inaccurate),
      CountBatchMatMulOperations(DescribeBatchMatMul({2, 4}, {4, 2}).op_info,
                                 &batch_matmul_inaccurate));
  EXPECT_EQ(matmul_inaccurate, batch_matmul_inaccurate);
  // Batch dimension of 10 (with an unknown -1 dim in the second input):
  // 10x the single-matmul op count, and only the batch version is
  // flagged inaccurate.
  EXPECT_EQ(10 * CountMatMulOperations(DescribeMatMul(2, 2, 4, 4).op_info,
                                       &matmul_inaccurate),
            CountBatchMatMulOperations(
                DescribeBatchMatMul({10, 2, 4}, {-1, 10, 4, 2}).op_info,
                &batch_matmul_inaccurate));
  EXPECT_NE(matmul_inaccurate, batch_matmul_inaccurate);
  // Two batch dimensions (2 x 10 = 20 matmuls).
  EXPECT_EQ(20 * CountMatMulOperations(DescribeMatMul(2, 2, 4, 4).op_info,
                                       &matmul_inaccurate),
            CountBatchMatMulOperations(
                DescribeBatchMatMul({2, 10, 2, 4}, {-1, 10, 4, 2}).op_info,
                &batch_matmul_inaccurate));
  EXPECT_NE(matmul_inaccurate, batch_matmul_inaccurate);
}
872 
TEST_F(OpLevelCostEstimatorTest, SparseTensorDenseMatMul) {
  // DescribeSparseTensorDenseMatMul(nnz, a_shape, b_shape) — first
  // argument presumably the number of nonzeros in the sparse input;
  // TODO confirm against the helper's definition.
  // Unknown shape cases
  {
    auto cost =
        PredictCosts(DescribeSparseTensorDenseMatMul(-1, {1, 1}, {1, 1}));
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
  }
  {
    auto cost =
        PredictCosts(DescribeSparseTensorDenseMatMul(1, {-1, 1}, {1, 1}));
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
  }
  {
    auto cost =
        PredictCosts(DescribeSparseTensorDenseMatMul(1, {1, -1}, {1, -1}));
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
  }
  {
    auto cost =
        PredictCosts(DescribeSparseTensorDenseMatMul(1, {1, 1}, {-1, 1}));
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
  }
  // Known shape cases
  {
    auto cost = PredictCosts(
        DescribeSparseTensorDenseMatMul(10, {1000, 100}, {50, 100}));
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
    EXPECT_EQ(Costs::Duration(200), cost.compute_time);
    EXPECT_EQ(Costs::Duration(2422), cost.memory_time);
  }
  {
    // Same cost as above case because cost does not depend on k_dim
    // (the sparse input's first dimension grows 100x, costs unchanged).
    auto cost = PredictCosts(
        DescribeSparseTensorDenseMatMul(10, {100000, 100}, {50, 100}));
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
    EXPECT_EQ(Costs::Duration(200), cost.compute_time);
    EXPECT_EQ(Costs::Duration(2422), cost.memory_time);
  }
}
924 
ExpectTensorShape(const std::vector<int64> & expected,const TensorShapeProto & tensor_shape_proto)925 void ExpectTensorShape(const std::vector<int64>& expected,
926                        const TensorShapeProto& tensor_shape_proto) {
927   TensorShape tensor_shape_expected(expected);
928   TensorShape tensor_shape(tensor_shape_proto);
929 
930   LOG(INFO) << "Expected: " << tensor_shape_expected.DebugString();
931   LOG(INFO) << "TensorShape: " << tensor_shape.DebugString();
932   EXPECT_TRUE(tensor_shape_expected == tensor_shape);
933 }
934 
TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
  TensorProto tensor_proto;
  TensorShapeProto tensor_shape_proto;

  // Dimension larger than max value; should fail while converting to
  // Tensor class.
  tensor_proto.mutable_tensor_shape()->add_dim()->set_size(255);
  EXPECT_FALSE(
      GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));

  tensor_proto.Clear();
  // Expect only 1D shape.
  tensor_proto.mutable_tensor_shape()->add_dim()->set_size(1);
  tensor_proto.mutable_tensor_shape()->add_dim()->set_size(2);
  EXPECT_FALSE(
      GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));

  // Expect only handle integer data types.
  GetTensorProto(DT_FLOAT, {}, {}, /*tensor_content=*/false, &tensor_proto);
  EXPECT_FALSE(
      GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));

  // Check GetTensorShapeProtoFromTensorProto() returns correct values.
  // Each case below stores shape values in a 1-D tensor (as int32 or
  // int64, as raw values or as tensor_content) and expects them back as
  // a TensorShapeProto.
  {
    std::vector<int64> shape_expected = {10, 20, 30, 40};
    GetTensorProto(DT_INT32, {4}, shape_expected,
                   /*tensor_content=*/false, &tensor_proto);
    EXPECT_TRUE(
        GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
    ExpectTensorShape(shape_expected, tensor_shape_proto);
  }

  {
    std::vector<int64> shape_expected = {40, 20, 90, 40};
    GetTensorProto(DT_INT64, {4}, shape_expected,
                   /*tensor_content=*/false, &tensor_proto);
    EXPECT_TRUE(
        GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
    ExpectTensorShape(shape_expected, tensor_shape_proto);
  }

  {
    std::vector<int64> shape_expected = {10, 20, 30, 40};
    GetTensorProto(DT_INT32, {4}, shape_expected,
                   /*tensor_content=*/true, &tensor_proto);
    EXPECT_TRUE(
        GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
    ExpectTensorShape(shape_expected, tensor_shape_proto);
  }

  {
    std::vector<int64> shape_expected = {40, 20, 90, 40};
    GetTensorProto(DT_INT64, {4}, shape_expected,
                   /*tensor_content=*/true, &tensor_proto);
    EXPECT_TRUE(
        GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
    ExpectTensorShape(shape_expected, tensor_shape_proto);
  }
}
994 
TEST_F(OpLevelCostEstimatorTest, OpDimensionsFromInputs) {
  // Exercises dimension extraction across all combinations of padding
  // and data format; the helper performs the validation internally.
  // NOTE(review): the helper name spells "Imputs"; it must match the
  // declaration elsewhere in this file -- consider a file-wide rename.
  std::vector<string> paddings = {"VALID", "SAME"};
  std::vector<string> formats = {"NHWC", "NCHW"};
  for (const auto& p : paddings) {
    for (const auto& f : formats) {
      // n, h, w, c, kx, ky, sx, sy, data_format, padding.
      ValidateOpDimensionsFromImputs(10, 20, 20, 100, 3, 3, 2, 2, f, p);
      ValidateOpDimensionsFromImputs(10, 20, 20, 100, 1, 1, 3, 3, f, p);
      ValidateOpDimensionsFromImputs(10, 200, 200, 100, 5, 5, 3, 3, f, p);
      ValidateOpDimensionsFromImputs(10, 14, 14, 3840, 3, 3, 2, 2, f, p);
    }
  }
}
1008 
TEST_F(OpLevelCostEstimatorTest, PredictMaxPool) {
  // Builds a MaxPool op over an NHWC input of shape {n, in, in, c} with a
  // k x k window and s x s stride, and returns its predicted costs.
  auto predict_max_pool = [this](const int n, const int in, const int c,
                                 const int k, const int s,
                                 const string& padding) -> Costs {
    OpContext op_context = DescribePoolingOp(
        "MaxPool", {n, in, in, c}, {1, k, k, 1}, {1, s, s, 1}, "NHWC", padding);
    return estimator_.PredictCosts(op_context);
  };

  {
    // Typical 3x3 window with 2x2 stride.
    auto costs = predict_max_pool(10, 20, 384, 3, 2, "SAME");
    EXPECT_EQ(Costs::Duration(1075200), costs.execution_time);
    EXPECT_EQ(Costs::Duration(307200), costs.compute_time);
    EXPECT_EQ(Costs::Duration(768000), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
  {
    // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
    auto costs = predict_max_pool(10, 20, 384, 1, 2, "SAME");
    EXPECT_EQ(Costs::Duration(499200), costs.execution_time);
    EXPECT_EQ(Costs::Duration(38400), costs.compute_time);
    EXPECT_EQ(Costs::Duration(460800), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
  {
    // 2x2 window with 3x3 stride.
    auto costs = predict_max_pool(10, 20, 384, 2, 3, "VALID");
    EXPECT_EQ(Costs::Duration(561792), costs.execution_time);
    EXPECT_EQ(Costs::Duration(56448), costs.compute_time);
    EXPECT_EQ(Costs::Duration(505344), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
}
1049 
TEST_F(OpLevelCostEstimatorTest, PredictMaxPoolGrad) {
  // Builds a MaxPoolGrad op over an NHWC input of shape {n, in, in, c}
  // with a k x k window and s x s stride, and returns its predicted costs.
  auto predict_max_pool_grad = [this](const int n, const int in, const int c,
                                      const int k, const int s,
                                      const string& padding) -> Costs {
    OpContext op_context =
        DescribePoolingOp("MaxPoolGrad", {n, in, in, c}, {1, k, k, 1},
                          {1, s, s, 1}, "NHWC", padding);
    return estimator_.PredictCosts(op_context);
  };

  {
    // Typical 3x3 window with 2x2 stride.
    auto costs = predict_max_pool_grad(10, 20, 384, 3, 2, "SAME");
    EXPECT_EQ(Costs::Duration(1996800), costs.execution_time);
    EXPECT_EQ(Costs::Duration(614400), costs.compute_time);
    EXPECT_EQ(Costs::Duration(1382400), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
  {
    // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
    auto costs = predict_max_pool_grad(10, 20, 384, 1, 2, "SAME");
    EXPECT_EQ(Costs::Duration(1536000), costs.execution_time);
    EXPECT_EQ(Costs::Duration(153600), costs.compute_time);
    EXPECT_EQ(Costs::Duration(1382400), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
  {
    // 2x2 window with 3x3 stride.
    auto costs = predict_max_pool_grad(10, 20, 384, 2, 3, "VALID");
    EXPECT_EQ(Costs::Duration(1514112), costs.execution_time);
    EXPECT_EQ(Costs::Duration(210048), costs.compute_time);
    EXPECT_EQ(Costs::Duration(1304064), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
}
1091 
TEST_F(OpLevelCostEstimatorTest, PredictAvgPool) {
  // Builds an AvgPool op over an NHWC input of shape {n, in, in, c} with
  // a k x k window and s x s stride, and returns its predicted costs.
  auto predict_avg_pool = [this](const int n, const int in, const int c,
                                 const int k, const int s,
                                 const string& padding) -> Costs {
    OpContext op_context = DescribePoolingOp(
        "AvgPool", {n, in, in, c}, {1, k, k, 1}, {1, s, s, 1}, "NHWC", padding);
    return estimator_.PredictCosts(op_context);
  };

  {
    // Typical 3x3 window with 2x2 stride.
    auto costs = predict_avg_pool(10, 20, 384, 3, 2, "SAME");
    EXPECT_EQ(Costs::Duration(1113600), costs.execution_time);
    EXPECT_EQ(Costs::Duration(345600), costs.compute_time);
    EXPECT_EQ(Costs::Duration(768000), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
  {
    // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
    auto costs = predict_avg_pool(10, 20, 384, 1, 2, "SAME");
    EXPECT_EQ(Costs::Duration(499200), costs.execution_time);
    EXPECT_EQ(Costs::Duration(38400), costs.compute_time);
    EXPECT_EQ(Costs::Duration(460800), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
  {
    // 2x2 window with 3x3 stride.
    auto costs = predict_avg_pool(10, 20, 384, 2, 3, "VALID");
    EXPECT_EQ(Costs::Duration(580608), costs.execution_time);
    EXPECT_EQ(Costs::Duration(75264), costs.compute_time);
    EXPECT_EQ(Costs::Duration(505344), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
}
1132 
TEST_F(OpLevelCostEstimatorTest, PredictAvgPoolGrad) {
  // Builds an AvgPoolGrad op over an NHWC input of shape {n, in, in, c}
  // with a k x k window and s x s stride, and returns its predicted costs.
  auto predict_avg_pool_grad = [this](const int n, const int in, const int c,
                                      const int k, const int s,
                                      const string& padding) -> Costs {
    OpContext op_context =
        DescribePoolingOp("AvgPoolGrad", {n, in, in, c}, {1, k, k, 1},
                          {1, s, s, 1}, "NHWC", padding);
    return estimator_.PredictCosts(op_context);
  };

  {
    // Typical 3x3 window with 2x2 stride.
    auto costs = predict_avg_pool_grad(10, 20, 384, 3, 2, "SAME");
    EXPECT_EQ(Costs::Duration(1305602), costs.execution_time);
    EXPECT_EQ(Costs::Duration(537600), costs.compute_time);
    EXPECT_EQ(Costs::Duration(768002), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
  {
    // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
    auto costs = predict_avg_pool_grad(10, 20, 384, 1, 2, "SAME");
    EXPECT_EQ(Costs::Duration(960002), costs.execution_time);
    EXPECT_EQ(Costs::Duration(192000), costs.compute_time);
    EXPECT_EQ(Costs::Duration(768002), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
  {
    // 2x2 window with 3x3 stride.
    auto costs = predict_avg_pool_grad(10, 20, 384, 2, 3, "VALID");
    EXPECT_EQ(Costs::Duration(862082), costs.execution_time);
    EXPECT_EQ(Costs::Duration(172416), costs.compute_time);
    EXPECT_EQ(Costs::Duration(689666), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
}
1174 
TEST_F(OpLevelCostEstimatorTest, PredictFusedBatchNorm) {
  // Builds a FusedBatchNorm op (forward pass, not the gradient) over an
  // NHWC input of shape {n, in, in, c} and returns its predicted costs.
  auto predict_fused_bn = [this](const int n, const int in, const int c,
                                 const bool is_training) -> Costs {
    OpContext op_context = DescribeFusedBatchNorm(
        is_training, /*is_grad=*/false, {n, in, in, c}, "NHWC");
    return estimator_.PredictCosts(op_context);
  };

  {
    auto costs = predict_fused_bn(10, 20, 96, /*is_training=*/true);
    EXPECT_EQ(Costs::Duration(614737), costs.execution_time);
    EXPECT_EQ(Costs::Duration(153706), costs.compute_time);
    EXPECT_EQ(Costs::Duration(461031), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }

  {
    auto costs = predict_fused_bn(10, 20, 32, /*is_training=*/true);
    EXPECT_EQ(Costs::Duration(204913), costs.execution_time);
    EXPECT_EQ(Costs::Duration(51236), costs.compute_time);
    EXPECT_EQ(Costs::Duration(153677), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }

  {
    // Inference mode is cheaper than training for the same shape.
    auto costs = predict_fused_bn(10, 20, 96, /*is_training=*/false);
    EXPECT_EQ(Costs::Duration(384154), costs.execution_time);
    EXPECT_EQ(Costs::Duration(76800), costs.compute_time);
    EXPECT_EQ(Costs::Duration(307354), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }

  {
    auto costs = predict_fused_bn(10, 20, 32, /*is_training=*/false);
    EXPECT_EQ(Costs::Duration(128052), costs.execution_time);
    EXPECT_EQ(Costs::Duration(25600), costs.compute_time);
    EXPECT_EQ(Costs::Duration(102452), costs.memory_time);
    // Assertion order fixed to match the other cases (num_ops_total was
    // previously checked after the inaccurate flag here).
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
}
1223 
TEST_F(OpLevelCostEstimatorTest, PredictFusedBatchNormGrad) {
  // Builds a FusedBatchNorm gradient op over an NHWC input of shape
  // {n, in, in, c} and returns its predicted costs.
  auto predict_fused_bn_grad = [this](const int n, const int in,
                                      const int c) -> Costs {
    OpContext op_context = DescribeFusedBatchNorm(
        /*is_training=*/false, /*is_grad=*/true, {n, in, in, c}, "NHWC");
    return estimator_.PredictCosts(op_context);
  };

  {
    auto costs = predict_fused_bn_grad(10, 20, 96);
    EXPECT_EQ(Costs::Duration(1037050), costs.execution_time);
    EXPECT_EQ(Costs::Duration(422496), costs.compute_time);
    EXPECT_EQ(Costs::Duration(614554), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }

  {
    auto costs = predict_fused_bn_grad(128, 7, 384);
    EXPECT_EQ(Costs::Duration(6503809), costs.execution_time);
    EXPECT_EQ(Costs::Duration(2649677), costs.compute_time);
    EXPECT_EQ(Costs::Duration(3854132), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
}
1252 
TEST_F(OpLevelCostEstimatorTest, MaybeGetMinimumShape) {
  {
    // Unknown rank: result is rank-4 of all 1s, flagged unknown.
    TensorShapeProto x;
    x.set_unknown_rank(true);
    bool unknown_shapes = false;
    TensorShapeProto y = MaybeGetMinimumShape(x, 4, &unknown_shapes);
    EXPECT_TRUE(unknown_shapes);
    ExpectTensorShape({1, 1, 1, 1}, y);
  }

  {
    // Known rank but no dims: padded with 1s up to the requested rank.
    TensorShapeProto x;
    x.set_unknown_rank(false);
    bool unknown_shapes = false;
    TensorShapeProto y = MaybeGetMinimumShape(x, 1, &unknown_shapes);
    EXPECT_FALSE(unknown_shapes);
    ExpectTensorShape({1}, y);
  }

  {
    TensorShapeProto x;
    x.set_unknown_rank(false);
    bool unknown_shapes = false;
    TensorShapeProto y = MaybeGetMinimumShape(x, 2, &unknown_shapes);
    EXPECT_FALSE(unknown_shapes);
    ExpectTensorShape({1, 1}, y);
  }

  {
    // Matching rank: shape returned as-is. A larger requested rank pads
    // with 1s and flags unknown.
    TensorShapeProto x;
    x.set_unknown_rank(false);
    x.add_dim()->set_size(10);
    x.add_dim()->set_size(20);
    bool unknown_shapes = false;
    TensorShapeProto y = MaybeGetMinimumShape(x, 2, &unknown_shapes);
    EXPECT_FALSE(unknown_shapes);
    ExpectTensorShape({10, 20}, y);

    unknown_shapes = false;
    TensorShapeProto z = MaybeGetMinimumShape(x, 4, &unknown_shapes);
    EXPECT_TRUE(unknown_shapes);
    EXPECT_EQ(4, z.dim_size());
    ExpectTensorShape({10, 20, 1, 1}, z);
  }

  {
    // A -1 (unknown) dim is replaced with 1 and flags unknown.
    TensorShapeProto x;
    x.set_unknown_rank(false);
    x.add_dim()->set_size(10);
    x.add_dim()->set_size(20);
    x.add_dim()->set_size(-1);
    x.add_dim()->set_size(20);
    bool unknown_shapes = false;
    TensorShapeProto y = MaybeGetMinimumShape(x, 4, &unknown_shapes);
    EXPECT_TRUE(unknown_shapes);
    ExpectTensorShape({10, 20, 1, 20}, y);
  }

  {
    // Rank larger than requested: truncated and flagged unknown.
    TensorShapeProto x;
    x.set_unknown_rank(false);
    x.add_dim()->set_size(10);
    x.add_dim()->set_size(20);
    x.add_dim()->set_size(30);
    x.add_dim()->set_size(20);
    bool unknown_shapes = false;
    TensorShapeProto y = MaybeGetMinimumShape(x, 2, &unknown_shapes);
    EXPECT_TRUE(unknown_shapes);
    ExpectTensorShape({10, 20}, y);
  }
}
1324 
// Verifies how compute, memory, and intermediate-memory times combine
// into execution_time for each device bottleneck, both with and without
// compute/memory overlap.
TEST_F(OpLevelCostEstimatorTest, IntermediateRdWrBandwidth) {
  TestOpLevelCostEstimator estimator;

  // Compute limited.
  estimator.SetDeviceInfo(DeviceInfo(/*gigaops=*/1,
                                     /*gb_per_sec=*/1));
  estimator.SetComputeMemoryOverlap(true);
  auto cost = estimator.PredictCosts(
      DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
  // With overlap, execution time is the max of the components; here the
  // compute time dominates.
  EXPECT_EQ(Costs::Duration(3548774400), cost.execution_time);
  EXPECT_EQ(cost.execution_time, cost.compute_time);

  estimator.SetComputeMemoryOverlap(false);
  cost = estimator.PredictCosts(
      DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
  // Without overlap, execution time is the sum of all components.
  EXPECT_EQ(Costs::Duration(3551112192), cost.execution_time);
  EXPECT_EQ(cost.execution_time, cost.compute_time + cost.memory_time +
                                     cost.intermediate_memory_time);

  // Memory limited.
  estimator.SetDeviceInfo(DeviceInfo(/*gigaops=*/99999,
                                     /*gb_per_sec=*/1));
  estimator.SetComputeMemoryOverlap(true);
  cost = estimator.PredictCosts(
      DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
  EXPECT_EQ(Costs::Duration(2337792), cost.execution_time);
  EXPECT_EQ(cost.execution_time, cost.memory_time);

  estimator.SetComputeMemoryOverlap(false);
  cost = estimator.PredictCosts(
      DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
  EXPECT_EQ(Costs::Duration(2373281), cost.execution_time);
  EXPECT_EQ(cost.execution_time, cost.compute_time + cost.memory_time +
                                     cost.intermediate_memory_time);

  // Intermediate memory bandwidth limited.
  estimator.SetDeviceInfo(DeviceInfo(/*gigaops=*/99999,
                                     /*gb_per_sec=*/9999,
                                     /*intermediate_read_gb_per_sec=*/1,
                                     /*intermediate_write_gb_per_sec=*/1));
  estimator.SetComputeMemoryOverlap(true);
  cost = estimator.PredictCosts(
      DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
  EXPECT_EQ(Costs::Duration(2337792), cost.execution_time);
  EXPECT_EQ(cost.execution_time, cost.intermediate_memory_time);

  estimator.SetComputeMemoryOverlap(false);
  cost = estimator.PredictCosts(
      DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
  EXPECT_EQ(Costs::Duration(2373515), cost.execution_time);
  EXPECT_EQ(cost.execution_time, cost.compute_time + cost.memory_time +
                                     cost.intermediate_memory_time);
}
1378 }  // end namespace grappler
1379 }  // end namespace tensorflow
1380