1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
17 #include "tensorflow/core/framework/attr_value.pb.h"
18 #include "tensorflow/core/framework/attr_value_util.h"
19 #include "tensorflow/core/framework/tensor.h"
20 #include "tensorflow/core/framework/tensor.pb.h"
21 #include "tensorflow/core/framework/tensor_shape.h"
22 #include "tensorflow/core/framework/tensor_shape.pb.h"
23 #include "tensorflow/core/framework/types.h"
24 #include "tensorflow/core/platform/test.h"
25 #include "tensorflow/core/protobuf/device_properties.pb.h"
26
27 namespace tensorflow {
28 namespace grappler {
29
30 namespace {
31
// TODO(dyoon): Consider to use this Test class for all the test cases, and then
// remove friend in the OpLevelCostEstimator class header.
// Subclass that lets tests inject the DeviceInfo returned for any device and
// toggle compute/memory overlap, instead of using the estimator's built-in
// device model.
class TestOpLevelCostEstimator : public OpLevelCostEstimator {
 public:
  TestOpLevelCostEstimator() {
    // Defaults: overlap enabled, default-constructed device info.
    compute_memory_overlap_ = true;
    device_info_ = DeviceInfo();
  }
  ~TestOpLevelCostEstimator() override {}

  // Sets the DeviceInfo that GetDeviceInfo() will return for all devices.
  void SetDeviceInfo(const DeviceInfo& device_info) {
    device_info_ = device_info;
  }

  // Controls whether compute and memory time may overlap in predictions.
  void SetComputeMemoryOverlap(bool value) { compute_memory_overlap_ = value; }

 protected:
  // Ignores the queried device and returns the injected device info.
  DeviceInfo GetDeviceInfo(const DeviceProperties& device) const override {
    return device_info_;
  }

  DeviceInfo device_info_;
};
55
56 // Wrangles the minimum number of proto fields to set up a matrix.
DescribeMatrix(int rows,int columns,OpInfo * op_info)57 void DescribeMatrix(int rows, int columns, OpInfo* op_info) {
58 auto input = op_info->add_inputs();
59 auto shape = input->mutable_shape();
60 auto shape_rows = shape->add_dim();
61 shape_rows->set_size(rows);
62 auto shape_columns = shape->add_dim();
63 shape_columns->set_size(columns);
64 input->set_dtype(DT_FLOAT);
65 }
66
SetCpuDevice(OpInfo * op_info)67 void SetCpuDevice(OpInfo* op_info) {
68 auto device = op_info->mutable_device();
69 device->set_type("CPU");
70 device->set_num_cores(10);
71 device->set_bandwidth(10000000); // 10000000 KB/s = 10 GB/s
72 device->set_frequency(1000); // 1000 Mhz = 1 GHz
73 }
74
75 // Returns an OpInfo for MatMul with the minimum set of fields set up.
DescribeMatMul(int m,int n,int l,int k)76 OpContext DescribeMatMul(int m, int n, int l, int k) {
77 OpContext op_context;
78 SetCpuDevice(&op_context.op_info);
79 op_context.op_info.set_op("MatMul");
80
81 DescribeMatrix(m, l, &op_context.op_info);
82 DescribeMatrix(k, n, &op_context.op_info);
83 return op_context;
84 }
85
86 // Wrangles the minimum number of proto fields to set up an input of
87 // arbitrary rank and type.
DescribeArbitraryRankInput(const std::vector<int> & dims,DataType dtype,OpInfo * op_info)88 void DescribeArbitraryRankInput(const std::vector<int>& dims, DataType dtype,
89 OpInfo* op_info) {
90 auto input = op_info->add_inputs();
91 input->set_dtype(dtype);
92 auto shape = input->mutable_shape();
93 for (auto d : dims) {
94 shape->add_dim()->set_size(d);
95 }
96 }
97
98 // Wrangles the minimum number of proto fields to set up an output of
99 // arbitrary rank and type.
DescribeArbitraryRankOutput(const std::vector<int> & dims,DataType dtype,OpInfo * op_info)100 void DescribeArbitraryRankOutput(const std::vector<int>& dims, DataType dtype,
101 OpInfo* op_info) {
102 auto output = op_info->add_outputs();
103 output->set_dtype(dtype);
104 auto shape = output->mutable_shape();
105 for (auto d : dims) {
106 shape->add_dim()->set_size(d);
107 }
108 }
109
110 // Returns an OpInfo for a BatchMatMul
DescribeBatchMatMul(const std::vector<int> & dims_a,const std::vector<int> & dims_b)111 OpContext DescribeBatchMatMul(const std::vector<int>& dims_a,
112 const std::vector<int>& dims_b) {
113 OpContext op_context;
114 SetCpuDevice(&op_context.op_info);
115 op_context.op_info.set_op("BatchMatMul");
116
117 DescribeArbitraryRankInput(dims_a, DT_FLOAT, &op_context.op_info);
118 DescribeArbitraryRankInput(dims_b, DT_FLOAT, &op_context.op_info);
119 return op_context;
120 }
121
// Returns an OpInfo for a SparseTensorDenseMatMul
OpContext DescribeSparseTensorDenseMatMul(const int nnz_a,
                                          const std::vector<int>& dims_b,
                                          const std::vector<int>& dims_out) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("SparseTensorDenseMatMul");

  // Inputs mirror the op's signature: a_indices (nnz_a x 2), a_values
  // (nnz_a), a_shape (2), and the dense matrix b.
  DescribeArbitraryRankInput({nnz_a, 2}, DT_INT64, &op_context.op_info);
  DescribeArbitraryRankInput({nnz_a}, DT_FLOAT, &op_context.op_info);
  DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
  DescribeArbitraryRankInput(dims_b, DT_FLOAT, &op_context.op_info);
  DescribeArbitraryRankOutput(dims_out, DT_FLOAT, &op_context.op_info);
  return op_context;
}
137
138 // Wrangles the minimum number of proto fields to set up a 1D Tensor for cost
139 // estimation purposes.
DescribeTensor1D(int dim0,OpInfo::TensorProperties * tensor)140 void DescribeTensor1D(int dim0, OpInfo::TensorProperties* tensor) {
141 auto shape = tensor->mutable_shape();
142 shape->add_dim()->set_size(dim0);
143 tensor->set_dtype(DT_FLOAT);
144 }
145
146 // Wrangles the minimum number of proto fields to set up a 4D Tensor for cost
147 // estimation purposes.
DescribeTensor4D(int dim0,int dim1,int dim2,int dim3,OpInfo::TensorProperties * tensor)148 void DescribeTensor4D(int dim0, int dim1, int dim2, int dim3,
149 OpInfo::TensorProperties* tensor) {
150 auto shape = tensor->mutable_shape();
151 shape->add_dim()->set_size(dim0);
152 shape->add_dim()->set_size(dim1);
153 shape->add_dim()->set_size(dim2);
154 shape->add_dim()->set_size(dim3);
155 tensor->set_dtype(DT_FLOAT);
156 }
157
158 // DescribeConvolution constructs an OpContext for a Conv2D applied to an input
159 // tensor with shape (batch, ix, iy, iz1) and a kernel tensor with shape
160 // (kx, ky, iz2, oz).
DescribeConvolution(int batch,int ix,int iy,int iz1,int iz2,int kx,int ky,int oz)161 OpContext DescribeConvolution(int batch, int ix, int iy, int iz1, int iz2,
162 int kx, int ky, int oz) {
163 OpContext op_context;
164 SetCpuDevice(&op_context.op_info);
165 op_context.op_info.set_op("Conv2D");
166
167 DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
168 DescribeTensor4D(kx, ky, iz2, oz, op_context.op_info.add_inputs());
169
170 return op_context;
171 }
172
173 // Describe DepthwiseConvolution constructs an OpContext for a
174 // DepthwiseConv2dNative applied to an input
175 // tensor with shape (batch, ix, iy, iz1) and a kernel tensor with shape
176 // (kx, ky, iz2, cm). cm is channel multiplier
177
DescribeDepthwiseConv2dNative(int batch,int ix,int iy,int iz1,int iz2,int kx,int ky,int cm)178 OpContext DescribeDepthwiseConv2dNative(int batch, int ix, int iy, int iz1,
179 int iz2, int kx, int ky, int cm) {
180 OpContext op_context;
181 SetCpuDevice(&op_context.op_info);
182 op_context.op_info.set_op("DepthwiseConv2dNative");
183
184 DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
185 DescribeTensor4D(kx, ky, iz2, cm, op_context.op_info.add_inputs());
186
187 return op_context;
188 }
189
// DescribeFusedConv2DBiasActivation constructs an OpContext for a
// FusedConv2DBiasActivation applied to a convolution input tensor with shape
// (batch, ix, iy, iz1), a kernel tensor with shape (kx, ky, iz2, oz), a
// bias tensor with shape (oz), a side input tensor with shape
// (batch, ox, oy, oz) if has_side_input is set, and two scaling tensors with
// shape (1).
//
// The shapes above are for the NHWC data format and HWIO filter format;
// tensor dimensions are permuted accordingly for NCHW and OIHW.
OpContext DescribeFusedConv2DBiasActivation(int batch, int ix, int iy, int iz1,
                                            int iz2, int kx, int ky, int ox,
                                            int oy, int oz, bool has_side_input,
                                            const string& data_format,
                                            const string& filter_format) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("FusedConv2DBiasActivation");
  auto* attr_data_format = op_context.op_info.mutable_attr();
  SetAttrValue(data_format, &(*attr_data_format)["data_format"]);
  auto* attr_filter_format = op_context.op_info.mutable_attr();
  SetAttrValue(filter_format, &(*attr_filter_format)["filter_format"]);
  if (data_format == "NHWC") {
    DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
  } else {
    // Use the NCHW format.
    DescribeTensor4D(batch, iz1, ix, iy, op_context.op_info.add_inputs());
  }
  if (filter_format == "HWIO") {
    DescribeTensor4D(kx, ky, iz2, oz, op_context.op_info.add_inputs());
  } else {
    // Use the OIHW format.
    DescribeTensor4D(oz, iz2, kx, ky, op_context.op_info.add_inputs());
  }
  DescribeTensor1D(oz, op_context.op_info.add_inputs());

  // Add the side_input, if any. Note that the input slot is always appended;
  // when has_side_input is false it is left as an empty tensor.
  auto side_input = op_context.op_info.add_inputs();
  if (has_side_input) {
    if (data_format == "NHWC") {
      DescribeTensor4D(batch, ox, oy, oz, side_input);
    } else {
      DescribeTensor4D(batch, oz, ox, oy, side_input);
    }
  }

  // Add the scaling tensors.
  DescribeTensor1D(1, op_context.op_info.add_inputs());
  DescribeTensor1D(1, op_context.op_info.add_inputs());

  return op_context;
}
240
241 // DescribeUnaryOp constructs an OpContext for the given operation applied to
242 // a 4-tensor with shape (size1, 1, 1, 1).
DescribeUnaryOp(const string & op,int size1)243 OpContext DescribeUnaryOp(const string& op, int size1) {
244 OpContext op_context;
245 SetCpuDevice(&op_context.op_info);
246 op_context.op_info.set_op(op);
247
248 DescribeTensor4D(size1, 1, 1, 1, op_context.op_info.add_inputs());
249 DescribeTensor4D(size1, 1, 1, 1, op_context.op_info.add_outputs());
250
251 return op_context;
252 }
253
254 // DescribeBinaryOp constructs an OpContext for the given operation applied to
255 // a 4-tensor with dimensions (size1, 1, 1, 1) and a 4-tensor with dimensions
256 // (2 * size1, size2, 1, 1).
257 //
258 // The choice of dimension here is arbitrary, and is used strictly to test the
259 // cost model for applying elementwise operations to tensors with unequal
260 // dimension values.
DescribeBinaryOp(const string & op,int size1,int size2)261 OpContext DescribeBinaryOp(const string& op, int size1, int size2) {
262 OpContext op_context;
263 SetCpuDevice(&op_context.op_info);
264 op_context.op_info.set_op(op);
265
266 DescribeTensor4D(size1, 1, 1, 1, op_context.op_info.add_inputs());
267 DescribeTensor4D(2 * size1, size2, 1, 1, op_context.op_info.add_inputs());
268 DescribeTensor4D(2 * size1, size2, 1, 1, op_context.op_info.add_outputs());
269
270 return op_context;
271 }
272
273 // DescribeBiasAdd constructs an OpContext for a BiasAdd applied to a 4-tensor
274 // with dimensions (1, 1, size2, size1) and a bias with dimension (size1),
275 // according to the constraint that the bias must be 1D with size equal to that
276 // of the last dimension of the input value.
DescribeBiasAdd(int size1,int size2)277 OpContext DescribeBiasAdd(int size1, int size2) {
278 OpContext op_context;
279 SetCpuDevice(&op_context.op_info);
280 op_context.op_info.set_op("BiasAdd");
281
282 DescribeTensor4D(1, 1, size2, size1, op_context.op_info.add_inputs());
283 DescribeTensor1D(size1, op_context.op_info.add_inputs());
284 DescribeTensor4D(1, 1, size2, size1, op_context.op_info.add_outputs());
285
286 return op_context;
287 }
288
GetOutputSize(const int x,const int k,const int s,const string & padding)289 int GetOutputSize(const int x, const int k, const int s,
290 const string& padding) {
291 if (padding == "SAME") {
292 return (x + s - 1) / s;
293 } else {
294 return (x - k + s) / s;
295 }
296 }
297
// Computes the expected 4D output shape of a pooling op with the given input
// shape, window, strides, data format (NHWC or NCHW), and padding.
std::vector<int> GetPoolingOutputSize(const std::vector<int>& input,
                                      const std::vector<int>& ksize,
                                      const std::vector<int>& strides,
                                      const string& data_format,
                                      const string& padding) {
  // Locate the spatial and channel dimensions within the 4D layout; the
  // default indices correspond to NHWC.
  const bool is_nchw = (data_format == "NCHW");
  const int h_index = is_nchw ? 2 : 1;
  const int w_index = is_nchw ? 3 : 2;
  const int c_index = is_nchw ? 1 : 3;

  const int n = input[0];
  const int c = input[c_index];

  // Output spatial extents follow from the per-dimension window/stride and
  // the padding scheme.
  const int ho =
      GetOutputSize(input[h_index], ksize[h_index], strides[h_index], padding);
  const int wo =
      GetOutputSize(input[w_index], ksize[w_index], strides[w_index], padding);

  if (data_format == "NHWC") {
    return {n, ho, wo, c};
  }
  return {n, c, ho, wo};
}
334
335 // Helper functions for testing GetTensorShapeProtoFromTensorProto().
GetTensorProto(const DataType dtype,const std::vector<int64> & shape,const std::vector<int64> values,const bool tensor_content,TensorProto * tensor_proto)336 void GetTensorProto(const DataType dtype, const std::vector<int64>& shape,
337 const std::vector<int64> values, const bool tensor_content,
338 TensorProto* tensor_proto) {
339 tensor_proto->Clear();
340 TensorProto temp_tensor_proto;
341 temp_tensor_proto.set_dtype(dtype);
342 for (const auto& x : shape) {
343 temp_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(x);
344 }
345 for (const auto& x : values) {
346 if (dtype == DT_INT64) {
347 temp_tensor_proto.add_int64_val(x);
348 } else if (dtype == DT_INT32 || dtype == DT_INT16 || dtype == DT_INT8 ||
349 dtype == DT_UINT8) {
350 temp_tensor_proto.add_int_val(x);
351 } else if (dtype == DT_UINT32) {
352 temp_tensor_proto.add_uint32_val(x);
353 } else if (dtype == DT_UINT64) {
354 temp_tensor_proto.add_uint64_val(x);
355 } else {
356 CHECK(false) << "Unsupported dtype: " << dtype;
357 }
358 }
359 Tensor tensor(dtype);
360 CHECK(tensor.FromProto(temp_tensor_proto));
361 if (tensor_content) {
362 tensor.AsProtoTensorContent(tensor_proto);
363 } else {
364 tensor.AsProtoField(tensor_proto);
365 }
366 }
367
// Returns an OpContext for the given pooling op (AvgPool, MaxPool,
// AvgPoolGrad, or MaxPoolGrad) over an input of shape x, wiring up the
// inputs/outputs each variant expects plus the standard pooling attrs
// (data_format, padding, strides, ksize).
OpContext DescribePoolingOp(const string& op_name, const std::vector<int>& x,
                            const std::vector<int>& ksize,
                            const std::vector<int>& strides,
                            const string& data_format, const string& padding) {
  OpContext op_context;
  auto& op_info = op_context.op_info;
  SetCpuDevice(&op_info);
  op_info.set_op(op_name);

  // y is the forward-pass output shape implied by the pooling parameters.
  const std::vector<int> y =
      GetPoolingOutputSize(x, ksize, strides, data_format, padding);
  if (op_name == "AvgPool" || op_name == "MaxPool") {
    // input: x, output: y.
    DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
    DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_outputs());
  } else if (op_name == "AvgPoolGrad") {
    // input: x's shape, y_grad, output: x_grad.
    // The first input is a constant-valued shape tensor, so its value is
    // materialized into the input's TensorProto.
    DescribeArbitraryRankInput({4}, DT_INT32, &op_info);
    auto* tensor_proto = op_info.mutable_inputs(0)->mutable_value();
    GetTensorProto(DT_INT32, {4}, {x[0], x[1], x[2], x[3]},
                   /*tensor_content=*/false, tensor_proto);
    DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs());
    DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_outputs());
  } else if (op_name == "MaxPoolGrad") {
    // input: x, y, y_grad, output: x_grad.
    DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
    DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs());
    DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs());
    DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_outputs());
  }
  auto* attr = op_info.mutable_attr();
  SetAttrValue(data_format, &(*attr)["data_format"]);
  SetAttrValue(padding, &(*attr)["padding"]);
  SetAttrValue(strides, &(*attr)["strides"]);
  SetAttrValue(ksize, &(*attr)["ksize"]);
  return op_context;
}
405
// Returns an OpContext for FusedBatchNorm (or FusedBatchNormGrad when is_grad
// is set) over a 4D input of shape x, built by repurposing a unit-window
// MaxPool context so the 4D input/output tensors and data_format attr come
// for free.
OpContext DescribeFusedBatchNorm(const bool is_training, const bool is_grad,
                                 const std::vector<int>& x,
                                 const string& data_format) {
  // First, get MaxPool op info with unit stride and unit window.
  OpContext op_context = DescribePoolingOp("MaxPool", x, {1, 1, 1, 1},
                                           {1, 1, 1, 1}, data_format, "SAME");
  auto& op_info = op_context.op_info;
  // Override op name.
  if (is_grad) {
    op_info.set_op("FusedBatchNormGrad");
  } else {
    op_info.set_op("FusedBatchNorm");
  }

  // Add additional input output tensors.
  if (is_grad) {
    // The grad variant takes a second 4D input alongside the one MaxPool set.
    DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
  }
  // The remaining inputs are 1D vectors sized by x[3] (the last dimension of
  // the input shape as given).
  int num_1d_inputs = is_grad ? 3 : 4;
  for (int i = 0; i < num_1d_inputs; i++) {
    auto* tensor = op_info.add_inputs();
    auto* shape = tensor->mutable_shape();
    shape->add_dim()->set_size(x[3]);
    tensor->set_dtype(DT_FLOAT);
  }
  // Four additional 1D outputs, also sized by x[3], appended after the 4D
  // output inherited from the MaxPool context.
  for (int i = 0; i < 4; i++) {
    auto* tensor = op_info.add_outputs();
    auto* shape = tensor->mutable_shape();
    shape->add_dim()->set_size(x[3]);
    tensor->set_dtype(DT_FLOAT);
  }

  // Delete unnecessary attr.
  auto* attr = op_context.op_info.mutable_attr();
  attr->erase("ksize");
  attr->erase("strides");
  attr->erase("padding");

  // Additional attrs for FusedBatchNorm.
  SetAttrValue(is_training, &(*attr)["is_training"]);

  return op_context;
}
449 } // namespace
450
// Test fixture exposing OpLevelCostEstimator's private counting helpers and
// the compute_memory_overlap_ flag (the fixture is a friend of the estimator).
class OpLevelCostEstimatorTest : public ::testing::Test {
 protected:
  Costs PredictCosts(const OpContext& op_context) const {
    return estimator_.PredictCosts(op_context);
  }

  int64 CountMatMulOperations(const OpInfo& op_info,
                              bool* found_unknown_shapes) const {
    return estimator_.CountMatMulOperations(op_info, found_unknown_shapes);
  }

  int64 CountBatchMatMulOperations(const OpInfo& op_info,
                                   bool* found_unknown_shapes) const {
    return estimator_.CountBatchMatMulOperations(op_info, found_unknown_shapes);
  }

  void SetComputeMemoryOverlap(bool value) {
    estimator_.compute_memory_overlap_ = value;
  }

  // Builds a MaxPool op from the given parameters, runs
  // OpDimensionsFromInputs on it, and checks every extracted dimension,
  // including the output extents implied by the padding scheme.
  // TODO: fix the typo in the name ("Imputs" -> "Inputs"); callers elsewhere
  // would need updating too.
  void ValidateOpDimensionsFromImputs(const int n, const int h, const int w,
                                      const int c, const int kx, const int ky,
                                      const int sx, const int sy,
                                      const string& data_format,
                                      const string& padding) {
    OpContext op_context;
    int ho;
    int wo;
    if (data_format == "NHWC") {
      op_context = DescribePoolingOp("MaxPool", {n, h, w, c}, {1, kx, ky, 1},
                                     {1, sx, sy, 1}, "NHWC", padding);
      ho = op_context.op_info.outputs(0).shape().dim(1).size();
      wo = op_context.op_info.outputs(0).shape().dim(2).size();
    } else {
      op_context = DescribePoolingOp("MaxPool", {n, c, h, w}, {1, 1, kx, ky},
                                     {1, 1, sx, sy}, "NCHW", padding);
      ho = op_context.op_info.outputs(0).shape().dim(2).size();
      wo = op_context.op_info.outputs(0).shape().dim(3).size();
    }

    bool found_unknown_shapes;
    auto dims = OpLevelCostEstimator::OpDimensionsFromInputs(
        op_context.op_info.inputs(0).shape(), op_context.op_info,
        &found_unknown_shapes);
    Padding padding_enum;
    if (padding == "VALID") {
      padding_enum = Padding::VALID;
    } else {
      padding_enum = Padding::SAME;
    }
    EXPECT_EQ(n, dims.batch);
    EXPECT_EQ(h, dims.ix);
    EXPECT_EQ(w, dims.iy);
    EXPECT_EQ(c, dims.iz);
    EXPECT_EQ(kx, dims.kx);
    EXPECT_EQ(ky, dims.ky);
    EXPECT_EQ(sx, dims.sx);
    EXPECT_EQ(sy, dims.sy);
    EXPECT_EQ(ho, dims.ox);
    EXPECT_EQ(wo, dims.oy);
    EXPECT_EQ(c, dims.oz);
    EXPECT_EQ(padding_enum, dims.padding);
  }

  OpLevelCostEstimator estimator_;
};
517
TEST_F(OpLevelCostEstimatorTest, TestPersistentOpCosts) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  // Ops whose outputs persist across steps (constants and variables).
  std::unordered_set<string> persistent_ops = {
      "Const",       "Variable",       "VariableV2", "AutoReloadVariable",
      "VarHandleOp", "ReadVariableOp",
  };
  // Minimum cost for all persistent ops: no memory time and a single unit of
  // compute time.
  for (const auto& op : persistent_ops) {
    op_context.op_info.set_op(op);
    auto cost = estimator_.PredictCosts(op_context);
    EXPECT_EQ(Costs::Duration(0), cost.memory_time);
    EXPECT_EQ(Costs::Duration(1), cost.compute_time);
    EXPECT_EQ(Costs::Duration(1), cost.execution_time);
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  }
}
537
TEST_F(OpLevelCostEstimatorTest, TestGatherCosts) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("Gather");

  // Gather 16 rows of width 10; the huge params tensor must not influence
  // the predicted execution or memory costs.
  DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
  DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info);
  DescribeArbitraryRankOutput({16, 10}, DT_FLOAT, &op_context.op_info);

  const auto costs = PredictCosts(op_context);
  EXPECT_EQ(Costs::Duration(130), costs.memory_time);
  EXPECT_EQ(Costs::Duration(16), costs.compute_time);
  EXPECT_EQ(Costs::Duration(146), costs.execution_time);
  EXPECT_EQ(1, costs.num_ops_total);
  EXPECT_FALSE(costs.inaccurate);
  EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
}
556
TEST_F(OpLevelCostEstimatorTest, TestGatherCostsWithoutOutput) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("Gather");

  // Same inputs as TestGatherCosts, but with no output described: the
  // prediction degenerates to zero cost and is flagged inaccurate.
  DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
  DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info);

  const auto costs = PredictCosts(op_context);
  EXPECT_EQ(Costs::Duration(0), costs.memory_time);
  EXPECT_EQ(Costs::Duration(0), costs.compute_time);
  EXPECT_EQ(Costs::Duration(0), costs.execution_time);
  EXPECT_EQ(1, costs.num_ops_total);
  EXPECT_TRUE(costs.inaccurate);
  EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
}
574
TEST_F(OpLevelCostEstimatorTest, TestSliceCosts) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("Slice");

  // Slice a 10x10 window; the huge first input must not influence the
  // predicted execution or memory costs.
  DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
  DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
  DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
  DescribeArbitraryRankOutput({10, 10}, DT_FLOAT, &op_context.op_info);

  const auto costs = PredictCosts(op_context);
  EXPECT_EQ(Costs::Duration(81), costs.memory_time);
  EXPECT_EQ(Costs::Duration(10), costs.compute_time);
  EXPECT_EQ(Costs::Duration(91), costs.execution_time);
  EXPECT_EQ(1, costs.num_ops_total);
  EXPECT_FALSE(costs.inaccurate);
  EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
}
594
TEST_F(OpLevelCostEstimatorTest, BiasAddExecutionTime) {
  // Golden costs for BiasAdd over a (1, 1, 10, 1000) input with a 1000-wide
  // bias.
  const auto costs = PredictCosts(DescribeBiasAdd(1000, 10));
  EXPECT_EQ(Costs::Duration(8400), costs.memory_time);
  EXPECT_EQ(Costs::Duration(1000), costs.compute_time);
  EXPECT_EQ(Costs::Duration(9400), costs.execution_time);
  EXPECT_EQ(1, costs.num_ops_total);
  EXPECT_FALSE(costs.inaccurate);
  EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
}
604
TEST_F(OpLevelCostEstimatorTest, Conv2DExecutionTime) {
  // Golden costs for a 16x19x19x48 input convolved with a 5x5x48x256 kernel.
  const auto costs =
      PredictCosts(DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
  EXPECT_EQ(Costs::Duration(233780), costs.memory_time);
  EXPECT_EQ(Costs::Duration(354877440), costs.compute_time);
  EXPECT_EQ(Costs::Duration(355111220), costs.execution_time);
  EXPECT_EQ(1, costs.num_ops_total);
  EXPECT_FALSE(costs.inaccurate);
  EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
}
614
TEST_F(OpLevelCostEstimatorTest, DepthwiseConv2dNativeExecutionTime) {
  // Golden costs for a 16x19x19x48 input with a 5x5 kernel and channel
  // multiplier 3.
  const auto costs =
      PredictCosts(DescribeDepthwiseConv2dNative(16, 19, 19, 48, 48, 5, 5, 3));
  EXPECT_EQ(Costs::Duration(112340), costs.memory_time);
  EXPECT_EQ(Costs::Duration(4158720), costs.compute_time);
  EXPECT_EQ(Costs::Duration(4271060), costs.execution_time);
  EXPECT_EQ(1, costs.num_ops_total);
  EXPECT_FALSE(costs.inaccurate);
  EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
}
625
TEST_F(OpLevelCostEstimatorTest, DummyExecutionTime) {
  // An unknown op: only memory time is charged and the result is flagged
  // inaccurate.
  const auto costs = PredictCosts(DescribeBinaryOp("Dummy", 1000, 1));
  EXPECT_EQ(Costs::Duration(2000), costs.memory_time);
  EXPECT_EQ(Costs::Duration(0), costs.compute_time);
  EXPECT_EQ(Costs::Duration(2000), costs.execution_time);
  EXPECT_EQ(1, costs.num_ops_total);
  EXPECT_TRUE(costs.inaccurate);
  EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
}
635
TEST_F(OpLevelCostEstimatorTest, ExecutionTimeSumOrMax) {
  // With compute/memory overlap enabled, execution_time should be the max of
  // compute_time and memory_time rather than their sum.
  SetComputeMemoryOverlap(true);
  auto cost = PredictCosts(DescribeBinaryOp("Dummy", 1000, 1));
  EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
  EXPECT_EQ(Costs::Duration(0), cost.compute_time);
  EXPECT_EQ(Costs::Duration(2000), cost.execution_time);  // max(2000, 200)
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_TRUE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  SetComputeMemoryOverlap(false);  // Set it back to default.
}
647
TEST_F(OpLevelCostEstimatorTest,
       FusedConv2DBiasActivationNCHW_HWIO_NoSideInput) {
  // Without the side input, memory and compute times drop relative to the
  // has_side_input variants below.
  const auto costs = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ false,
      "NCHW", "HWIO"));
  EXPECT_EQ(Costs::Duration(825345), costs.memory_time);
  EXPECT_EQ(Costs::Duration(355321038), costs.compute_time);
  EXPECT_EQ(Costs::Duration(356146383), costs.execution_time);
  EXPECT_EQ(1, costs.num_ops_total);
  EXPECT_FALSE(costs.inaccurate);
  EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
}
660
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_HWIO) {
  // NCHW data with HWIO filters; golden costs match the other supported
  // format combinations.
  const auto costs = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NCHW", "HWIO"));
  EXPECT_EQ(Costs::Duration(1416808), costs.memory_time);
  EXPECT_EQ(Costs::Duration(355616770), costs.compute_time);
  EXPECT_EQ(Costs::Duration(357033578), costs.execution_time);
  EXPECT_EQ(1, costs.num_ops_total);
  EXPECT_FALSE(costs.inaccurate);
  EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
}
672
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_OIHW) {
  // NCHW data with OIHW filters; costs are identical to the HWIO case since
  // only the dimension ordering differs.
  const auto costs = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NCHW", "OIHW"));
  EXPECT_EQ(Costs::Duration(1416808), costs.memory_time);
  EXPECT_EQ(Costs::Duration(355616770), costs.compute_time);
  EXPECT_EQ(Costs::Duration(357033578), costs.execution_time);
  EXPECT_EQ(1, costs.num_ops_total);
  EXPECT_FALSE(costs.inaccurate);
  EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
}
684
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNHWC_HWIO) {
  // NHWC data with HWIO filters; costs are identical to the NCHW cases since
  // only the dimension ordering differs.
  const auto costs = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NHWC", "HWIO"));
  EXPECT_EQ(Costs::Duration(1416808), costs.memory_time);
  EXPECT_EQ(Costs::Duration(355616770), costs.compute_time);
  EXPECT_EQ(Costs::Duration(357033578), costs.execution_time);
  EXPECT_EQ(1, costs.num_ops_total);
  EXPECT_FALSE(costs.inaccurate);
  EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
}
696
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNHWC_OIHW) {
  // NHWC data with OIHW filters; costs are identical to the other supported
  // format combinations since only the dimension ordering differs.
  const auto costs = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NHWC", "OIHW"));
  EXPECT_EQ(Costs::Duration(1416808), costs.memory_time);
  EXPECT_EQ(Costs::Duration(355616770), costs.compute_time);
  EXPECT_EQ(Costs::Duration(357033578), costs.execution_time);
  EXPECT_EQ(1, costs.num_ops_total);
  EXPECT_FALSE(costs.inaccurate);
  EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
}
708
709 // TODO(yaozhang): Update once NCHW_VECT_C is supported.
TEST_F(OpLevelCostEstimatorTest,FusedConv2DBiasActivationNCHW_VECT_C_OIHW)710 TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_VECT_C_OIHW) {
711 auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
712 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
713 "NCHW_VECT_C", "OIHW"));
714 EXPECT_EQ(Costs::Duration(0), cost.memory_time);
715 EXPECT_EQ(Costs::Duration(0), cost.compute_time);
716 EXPECT_EQ(Costs::Duration(0), cost.execution_time);
717 EXPECT_EQ(1, cost.num_ops_total);
718 EXPECT_TRUE(cost.inaccurate);
719 EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
720 }
721
722 // TODO(yaozhang): Update once OIHW_VECT_I is supported.
TEST_F(OpLevelCostEstimatorTest,FusedConv2DBiasActivationNCHW_OIHW_VECT_I)723 TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_OIHW_VECT_I) {
724 auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
725 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
726 "NCHW", "OIHW_VECT_I"));
727 EXPECT_EQ(Costs::Duration(0), cost.memory_time);
728 EXPECT_EQ(Costs::Duration(0), cost.compute_time);
729 EXPECT_EQ(Costs::Duration(0), cost.execution_time);
730 EXPECT_EQ(1, cost.num_ops_total);
731 EXPECT_TRUE(cost.inaccurate);
732 EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
733 }
734
TEST_F(OpLevelCostEstimatorTest, MulExecutionTime) {
  // Elementwise Mul with size2 == 1 (no broadcast along the second
  // dimension).
  const auto costs = PredictCosts(DescribeBinaryOp("Mul", 1000, 1));
  EXPECT_EQ(Costs::Duration(2000), costs.memory_time);
  EXPECT_EQ(Costs::Duration(200), costs.compute_time);
  EXPECT_EQ(Costs::Duration(2200), costs.execution_time);
  EXPECT_EQ(1, costs.num_ops_total);
  EXPECT_FALSE(costs.inaccurate);
  EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
}
744
TEST_F(OpLevelCostEstimatorTest, MulBroadcastExecutionTime) {
  // Mul over 1000 elements with a broadcast second operand (see
  // DescribeBinaryOp's second size argument).
  const Costs prediction = PredictCosts(DescribeBinaryOp("Mul", 1000, 2));
  EXPECT_EQ(Costs::Duration(400), prediction.compute_time);
  EXPECT_EQ(Costs::Duration(3600), prediction.memory_time);
  EXPECT_EQ(Costs::Duration(4000), prediction.execution_time);
  EXPECT_EQ(1, prediction.num_ops_total);
  EXPECT_EQ(0, prediction.num_ops_with_unknown_shapes);
  EXPECT_FALSE(prediction.inaccurate);
}
754
TEST_F(OpLevelCostEstimatorTest, ModExecutionTime) {
  // Element-wise Mod over 1000 elements; same memory cost as Mul but a
  // higher per-element compute cost.
  const Costs prediction = PredictCosts(DescribeBinaryOp("Mod", 1000, 1));
  EXPECT_EQ(Costs::Duration(1600), prediction.compute_time);
  EXPECT_EQ(Costs::Duration(2000), prediction.memory_time);
  EXPECT_EQ(Costs::Duration(3600), prediction.execution_time);
  EXPECT_EQ(1, prediction.num_ops_total);
  EXPECT_EQ(0, prediction.num_ops_with_unknown_shapes);
  EXPECT_FALSE(prediction.inaccurate);
}
764
TEST_F(OpLevelCostEstimatorTest, ReluExecutionTime) {
  // Unary Relu over 1000 elements.
  const Costs prediction = PredictCosts(DescribeUnaryOp("Relu", 1000));
  EXPECT_EQ(Costs::Duration(100), prediction.compute_time);
  EXPECT_EQ(Costs::Duration(800), prediction.memory_time);
  EXPECT_EQ(Costs::Duration(900), prediction.execution_time);
  EXPECT_EQ(1, prediction.num_ops_total);
  EXPECT_EQ(0, prediction.num_ops_with_unknown_shapes);
  EXPECT_FALSE(prediction.inaccurate);
}
774
TEST_F(OpLevelCostEstimatorTest, CastExecutionTime) {
  // Unary Cast over 1000 elements; costed identically to Relu.
  const Costs prediction = PredictCosts(DescribeUnaryOp("Cast", 1000));
  EXPECT_EQ(Costs::Duration(100), prediction.compute_time);
  EXPECT_EQ(Costs::Duration(800), prediction.memory_time);
  EXPECT_EQ(Costs::Duration(900), prediction.execution_time);
  EXPECT_EQ(1, prediction.num_ops_total);
  EXPECT_EQ(0, prediction.num_ops_with_unknown_shapes);
  EXPECT_FALSE(prediction.inaccurate);
}
784
TEST_F(OpLevelCostEstimatorTest, UnknownOrPartialShape) {
  // Verifies that a -1 (unknown) dimension in MatMul or Conv2D inputs is
  // surfaced via `inaccurate` and `num_ops_with_unknown_shapes`.
  const auto expect_known_shapes = [](const Costs& costs) {
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  };
  const auto expect_unknown_shapes = [](const Costs& costs) {
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_TRUE(costs.inaccurate);
    EXPECT_EQ(1, costs.num_ops_with_unknown_shapes);
  };

  // Fully specified MatMul.
  expect_known_shapes(PredictCosts(DescribeMatMul(2, 4, 7, 7)));
  // MatMul with an unknown first or third dimension.
  expect_unknown_shapes(PredictCosts(DescribeMatMul(-1, 4, 7, 7)));
  expect_unknown_shapes(PredictCosts(DescribeMatMul(2, 4, -1, 7)));
  // Fully specified convolution.
  expect_known_shapes(
      PredictCosts(DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256)));
  // Convolution with an unknown spatial dimension.
  expect_unknown_shapes(
      PredictCosts(DescribeConvolution(16, -1, 19, 48, 48, 5, 5, 256)));
}
819
TEST_F(OpLevelCostEstimatorTest, BatchMatMul) {
  // Unknown-shape inputs: an empty shape vector marks the op inaccurate.
  {
    auto cost = PredictCosts(DescribeBatchMatMul({}, {}));
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
  }
  {
    auto cost = PredictCosts(DescribeBatchMatMul({2, 4}, {}));
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
  }
  // Fully known shapes: rank-2, rank-3, and mixed rank-2 x rank-4 inputs are
  // all costed without being flagged inaccurate.
  {
    auto cost = PredictCosts(DescribeBatchMatMul({2, 4}, {4, 2}));
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  }
  {
    auto cost = PredictCosts(DescribeBatchMatMul({1, 2, 4}, {1, 4, 2}));
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  }
  {
    auto cost = PredictCosts(DescribeBatchMatMul({2, 4}, {1, 3, 4, 2}));
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  }
  // Cross-check op counting: a trivial batch must count exactly the same
  // number of operations as the equivalent plain MatMul, and neither counter
  // should report inaccuracy.
  // NOTE: the two flags below are reused (not reset) across the remaining
  // EXPECTs, so statement order matters here.
  bool matmul_inaccurate = false;
  bool batch_matmul_inaccurate = false;
  EXPECT_EQ(
      CountMatMulOperations(DescribeMatMul(2, 2, 4, 4).op_info,
                            &matmul_inaccurate),
      CountBatchMatMulOperations(DescribeBatchMatMul({2, 4}, {4, 2}).op_info,
                                 &batch_matmul_inaccurate));
  EXPECT_EQ(matmul_inaccurate, batch_matmul_inaccurate);
  // With a -1 batch dimension the batch count still comes out as 10x a single
  // MatMul, but only the batch counter reports inaccuracy — hence EXPECT_NE.
  EXPECT_EQ(10 * CountMatMulOperations(DescribeMatMul(2, 2, 4, 4).op_info,
                                       &matmul_inaccurate),
            CountBatchMatMulOperations(
                DescribeBatchMatMul({10, 2, 4}, {-1, 10, 4, 2}).op_info,
                &batch_matmul_inaccurate));
  EXPECT_NE(matmul_inaccurate, batch_matmul_inaccurate);
  // Same with a composite batch (2 x 10 = 20 single MatMuls).
  EXPECT_EQ(20 * CountMatMulOperations(DescribeMatMul(2, 2, 4, 4).op_info,
                                       &matmul_inaccurate),
            CountBatchMatMulOperations(
                DescribeBatchMatMul({2, 10, 2, 4}, {-1, 10, 4, 2}).op_info,
                &batch_matmul_inaccurate));
  EXPECT_NE(matmul_inaccurate, batch_matmul_inaccurate);
}
872
TEST_F(OpLevelCostEstimatorTest, SparseTensorDenseMatMul) {
  // A -1 in any input dimension must mark the op inaccurate.
  const auto expect_unknown_shapes = [](const Costs& costs) {
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_TRUE(costs.inaccurate);
    EXPECT_EQ(1, costs.num_ops_with_unknown_shapes);
  };

  // Unknown shape cases.
  expect_unknown_shapes(
      PredictCosts(DescribeSparseTensorDenseMatMul(-1, {1, 1}, {1, 1})));
  expect_unknown_shapes(
      PredictCosts(DescribeSparseTensorDenseMatMul(1, {-1, 1}, {1, 1})));
  expect_unknown_shapes(
      PredictCosts(DescribeSparseTensorDenseMatMul(1, {1, -1}, {1, -1})));
  expect_unknown_shapes(
      PredictCosts(DescribeSparseTensorDenseMatMul(1, {1, 1}, {-1, 1})));

  // Known shape cases.
  {
    const auto costs = PredictCosts(
        DescribeSparseTensorDenseMatMul(10, {1000, 100}, {50, 100}));
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
    EXPECT_EQ(Costs::Duration(200), costs.compute_time);
    EXPECT_EQ(Costs::Duration(2422), costs.memory_time);
  }
  {
    // Identical costs to the previous case: the prediction does not depend
    // on k_dim.
    const auto costs = PredictCosts(
        DescribeSparseTensorDenseMatMul(10, {100000, 100}, {50, 100}));
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
    EXPECT_EQ(Costs::Duration(200), costs.compute_time);
    EXPECT_EQ(Costs::Duration(2422), costs.memory_time);
  }
}
924
ExpectTensorShape(const std::vector<int64> & expected,const TensorShapeProto & tensor_shape_proto)925 void ExpectTensorShape(const std::vector<int64>& expected,
926 const TensorShapeProto& tensor_shape_proto) {
927 TensorShape tensor_shape_expected(expected);
928 TensorShape tensor_shape(tensor_shape_proto);
929
930 LOG(INFO) << "Expected: " << tensor_shape_expected.DebugString();
931 LOG(INFO) << "TensorShape: " << tensor_shape.DebugString();
932 EXPECT_TRUE(tensor_shape_expected == tensor_shape);
933 }
934
TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
  // Exercises failure modes first, then the four supported encodings
  // (int32/int64 x repeated-field/tensor_content). `tensor_proto` is reused
  // and cleared between sections, so the order below matters.
  TensorProto tensor_proto;
  TensorShapeProto tensor_shape_proto;

  // Dimension larger than max value; should fail while converting to
  // Tensor class.
  // NOTE(review): this proto also carries no backing data for the declared
  // shape, which by itself could make the conversion fail — confirm which
  // check actually fires.
  tensor_proto.mutable_tensor_shape()->add_dim()->set_size(255);
  EXPECT_FALSE(
      GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));

  tensor_proto.Clear();
  // Expect only 1D shape: a 2-D shape tensor is rejected.
  tensor_proto.mutable_tensor_shape()->add_dim()->set_size(1);
  tensor_proto.mutable_tensor_shape()->add_dim()->set_size(2);
  EXPECT_FALSE(
      GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));

  // Expect only handle integer data types: DT_FLOAT is rejected.
  GetTensorProto(DT_FLOAT, {}, {}, /*tensor_content=*/false, &tensor_proto);
  EXPECT_FALSE(
      GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));

  // Check GetTensorShapeProtoFromTensorProto() returns correct values.
  {
    // int32 values stored as a repeated field.
    std::vector<int64> shape_expected = {10, 20, 30, 40};
    GetTensorProto(DT_INT32, {4}, shape_expected,
                   /*tensor_content=*/false, &tensor_proto);
    EXPECT_TRUE(
        GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
    ExpectTensorShape(shape_expected, tensor_shape_proto);
  }

  {
    // int64 values stored as a repeated field.
    std::vector<int64> shape_expected = {40, 20, 90, 40};
    GetTensorProto(DT_INT64, {4}, shape_expected,
                   /*tensor_content=*/false, &tensor_proto);
    EXPECT_TRUE(
        GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
    ExpectTensorShape(shape_expected, tensor_shape_proto);
  }

  {
    // int32 values packed into tensor_content bytes.
    std::vector<int64> shape_expected = {10, 20, 30, 40};
    GetTensorProto(DT_INT32, {4}, shape_expected,
                   /*tensor_content=*/true, &tensor_proto);
    EXPECT_TRUE(
        GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
    ExpectTensorShape(shape_expected, tensor_shape_proto);
  }

  {
    // int64 values packed into tensor_content bytes.
    std::vector<int64> shape_expected = {40, 20, 90, 40};
    GetTensorProto(DT_INT64, {4}, shape_expected,
                   /*tensor_content=*/true, &tensor_proto);
    EXPECT_TRUE(
        GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
    ExpectTensorShape(shape_expected, tensor_shape_proto);
  }
}
994
TEST_F(OpLevelCostEstimatorTest, OpDimensionsFromInputs) {
  // Exercises every (padding, data_format) combination against several
  // pooling configurations.
  // NOTE: the helper's name spells "Inputs" as "Imputs"; it is declared
  // elsewhere in this file, so it cannot be renamed here.
  const std::vector<string> paddings = {"VALID", "SAME"};
  const std::vector<string> formats = {"NHWC", "NCHW"};
  for (const auto& padding : paddings) {
    for (const auto& format : formats) {
      // Arguments: n, h, w, c, kx, ky, sx, sy, data_format, padding.
      ValidateOpDimensionsFromImputs(10, 20, 20, 100, 3, 3, 2, 2, format,
                                     padding);
      ValidateOpDimensionsFromImputs(10, 20, 20, 100, 1, 1, 3, 3, format,
                                     padding);
      ValidateOpDimensionsFromImputs(10, 200, 200, 100, 5, 5, 3, 3, format,
                                     padding);
      ValidateOpDimensionsFromImputs(10, 14, 14, 3840, 3, 3, 2, 2, format,
                                     padding);
    }
  }
}
1008
TEST_F(OpLevelCostEstimatorTest, PredictMaxPool) {
  // Builds a MaxPool op with a square k x k window and s x s stride over an
  // n x in x in x c NHWC input, and returns its predicted costs.
  const auto max_pool_costs = [this](const int n, const int in, const int c,
                                     const int k, const int s,
                                     const string& padding) -> Costs {
    OpContext op_context = DescribePoolingOp(
        "MaxPool", {n, in, in, c}, {1, k, k, 1}, {1, s, s, 1}, "NHWC", padding);
    return estimator_.PredictCosts(op_context);
  };

  {
    // Typical 3x3 window with 2x2 stride.
    const Costs costs = max_pool_costs(10, 20, 384, 3, 2, "SAME");
    EXPECT_EQ(Costs::Duration(1075200), costs.execution_time);
    EXPECT_EQ(Costs::Duration(307200), costs.compute_time);
    EXPECT_EQ(Costs::Duration(768000), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
  {
    // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
    const Costs costs = max_pool_costs(10, 20, 384, 1, 2, "SAME");
    EXPECT_EQ(Costs::Duration(499200), costs.execution_time);
    EXPECT_EQ(Costs::Duration(38400), costs.compute_time);
    EXPECT_EQ(Costs::Duration(460800), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
  {
    // 2x2 window with 3x3 stride.
    const Costs costs = max_pool_costs(10, 20, 384, 2, 3, "VALID");
    EXPECT_EQ(Costs::Duration(561792), costs.execution_time);
    EXPECT_EQ(Costs::Duration(56448), costs.compute_time);
    EXPECT_EQ(Costs::Duration(505344), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
}
1049
TEST_F(OpLevelCostEstimatorTest, PredictMaxPoolGrad) {
  // Builds a MaxPoolGrad op with a square k x k window and s x s stride over
  // an n x in x in x c NHWC input, and returns its predicted costs.
  const auto max_pool_grad_costs = [this](const int n, const int in,
                                          const int c, const int k,
                                          const int s,
                                          const string& padding) -> Costs {
    OpContext op_context =
        DescribePoolingOp("MaxPoolGrad", {n, in, in, c}, {1, k, k, 1},
                          {1, s, s, 1}, "NHWC", padding);
    return estimator_.PredictCosts(op_context);
  };

  {
    // Typical 3x3 window with 2x2 stride.
    const Costs costs = max_pool_grad_costs(10, 20, 384, 3, 2, "SAME");
    EXPECT_EQ(Costs::Duration(1996800), costs.execution_time);
    EXPECT_EQ(Costs::Duration(614400), costs.compute_time);
    EXPECT_EQ(Costs::Duration(1382400), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
  {
    // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
    const Costs costs = max_pool_grad_costs(10, 20, 384, 1, 2, "SAME");
    EXPECT_EQ(Costs::Duration(1536000), costs.execution_time);
    EXPECT_EQ(Costs::Duration(153600), costs.compute_time);
    EXPECT_EQ(Costs::Duration(1382400), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
  {
    // 2x2 window with 3x3 stride.
    const Costs costs = max_pool_grad_costs(10, 20, 384, 2, 3, "VALID");
    EXPECT_EQ(Costs::Duration(1514112), costs.execution_time);
    EXPECT_EQ(Costs::Duration(210048), costs.compute_time);
    EXPECT_EQ(Costs::Duration(1304064), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
}
1091
TEST_F(OpLevelCostEstimatorTest, PredictAvgPool) {
  // Builds an AvgPool op with a square k x k window and s x s stride over an
  // n x in x in x c NHWC input, and returns its predicted costs.
  const auto avg_pool_costs = [this](const int n, const int in, const int c,
                                     const int k, const int s,
                                     const string& padding) -> Costs {
    OpContext op_context = DescribePoolingOp(
        "AvgPool", {n, in, in, c}, {1, k, k, 1}, {1, s, s, 1}, "NHWC", padding);
    return estimator_.PredictCosts(op_context);
  };

  {
    // Typical 3x3 window with 2x2 stride.
    const Costs costs = avg_pool_costs(10, 20, 384, 3, 2, "SAME");
    EXPECT_EQ(Costs::Duration(1113600), costs.execution_time);
    EXPECT_EQ(Costs::Duration(345600), costs.compute_time);
    EXPECT_EQ(Costs::Duration(768000), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
  {
    // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
    const Costs costs = avg_pool_costs(10, 20, 384, 1, 2, "SAME");
    EXPECT_EQ(Costs::Duration(499200), costs.execution_time);
    EXPECT_EQ(Costs::Duration(38400), costs.compute_time);
    EXPECT_EQ(Costs::Duration(460800), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
  {
    // 2x2 window with 3x3 stride.
    const Costs costs = avg_pool_costs(10, 20, 384, 2, 3, "VALID");
    EXPECT_EQ(Costs::Duration(580608), costs.execution_time);
    EXPECT_EQ(Costs::Duration(75264), costs.compute_time);
    EXPECT_EQ(Costs::Duration(505344), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
}
1132
TEST_F(OpLevelCostEstimatorTest, PredictAvgPoolGrad) {
  // Builds an AvgPoolGrad op with a square k x k window and s x s stride over
  // an n x in x in x c NHWC input, and returns its predicted costs.
  const auto avg_pool_grad_costs = [this](const int n, const int in,
                                          const int c, const int k,
                                          const int s,
                                          const string& padding) -> Costs {
    OpContext op_context =
        DescribePoolingOp("AvgPoolGrad", {n, in, in, c}, {1, k, k, 1},
                          {1, s, s, 1}, "NHWC", padding);
    return estimator_.PredictCosts(op_context);
  };

  {
    // Typical 3x3 window with 2x2 stride.
    const Costs costs = avg_pool_grad_costs(10, 20, 384, 3, 2, "SAME");
    EXPECT_EQ(Costs::Duration(1305602), costs.execution_time);
    EXPECT_EQ(Costs::Duration(537600), costs.compute_time);
    EXPECT_EQ(Costs::Duration(768002), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
  {
    // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
    const Costs costs = avg_pool_grad_costs(10, 20, 384, 1, 2, "SAME");
    EXPECT_EQ(Costs::Duration(960002), costs.execution_time);
    EXPECT_EQ(Costs::Duration(192000), costs.compute_time);
    EXPECT_EQ(Costs::Duration(768002), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
  {
    // 2x2 window with 3x3 stride.
    const Costs costs = avg_pool_grad_costs(10, 20, 384, 2, 3, "VALID");
    EXPECT_EQ(Costs::Duration(862082), costs.execution_time);
    EXPECT_EQ(Costs::Duration(172416), costs.compute_time);
    EXPECT_EQ(Costs::Duration(689666), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
}
1174
TEST_F(OpLevelCostEstimatorTest, PredictFusedBatchNorm) {
  // Predicts costs for FusedBatchNorm over an n x in x in x c NHWC input, in
  // either training or inference mode.
  const auto fused_bn_costs = [this](const int n, const int in, const int c,
                                     const bool is_training) -> Costs {
    OpContext op_context = DescribeFusedBatchNorm(
        is_training, /*is_grad=*/false, {n, in, in, c}, "NHWC");
    return estimator_.PredictCosts(op_context);
  };

  {
    const Costs costs = fused_bn_costs(10, 20, 96, /*is_training=*/true);
    EXPECT_EQ(Costs::Duration(614737), costs.execution_time);
    EXPECT_EQ(Costs::Duration(153706), costs.compute_time);
    EXPECT_EQ(Costs::Duration(461031), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }

  {
    const Costs costs = fused_bn_costs(10, 20, 32, /*is_training=*/true);
    EXPECT_EQ(Costs::Duration(204913), costs.execution_time);
    EXPECT_EQ(Costs::Duration(51236), costs.compute_time);
    EXPECT_EQ(Costs::Duration(153677), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }

  {
    const Costs costs = fused_bn_costs(10, 20, 96, /*is_training=*/false);
    EXPECT_EQ(Costs::Duration(384154), costs.execution_time);
    EXPECT_EQ(Costs::Duration(76800), costs.compute_time);
    EXPECT_EQ(Costs::Duration(307354), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }

  {
    const Costs costs = fused_bn_costs(10, 20, 32, /*is_training=*/false);
    EXPECT_EQ(Costs::Duration(128052), costs.execution_time);
    EXPECT_EQ(Costs::Duration(25600), costs.compute_time);
    EXPECT_EQ(Costs::Duration(102452), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
}
1223
TEST_F(OpLevelCostEstimatorTest, PredictFusedBatchNormGrad) {
  // Predicts costs for the FusedBatchNorm gradient over an n x in x in x c
  // NHWC input.
  const auto fused_bn_grad_costs = [this](const int n, const int in,
                                          const int c) -> Costs {
    OpContext op_context = DescribeFusedBatchNorm(
        /*is_training=*/false, /*is_grad=*/true, {n, in, in, c}, "NHWC");
    return estimator_.PredictCosts(op_context);
  };

  {
    const Costs costs = fused_bn_grad_costs(10, 20, 96);
    EXPECT_EQ(Costs::Duration(1037050), costs.execution_time);
    EXPECT_EQ(Costs::Duration(422496), costs.compute_time);
    EXPECT_EQ(Costs::Duration(614554), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }

  {
    const Costs costs = fused_bn_grad_costs(128, 7, 384);
    EXPECT_EQ(Costs::Duration(6503809), costs.execution_time);
    EXPECT_EQ(Costs::Duration(2649677), costs.compute_time);
    EXPECT_EQ(Costs::Duration(3854132), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
}
1252
TEST_F(OpLevelCostEstimatorTest, MaybeGetMinimumShape) {
  // Unknown rank: the result is all 1s at the requested rank, and unknown
  // shapes are reported.
  {
    TensorShapeProto shape_proto;
    shape_proto.set_unknown_rank(true);
    bool unknown_shapes = false;
    const TensorShapeProto minimum =
        MaybeGetMinimumShape(shape_proto, 4, &unknown_shapes);
    EXPECT_TRUE(unknown_shapes);
    ExpectTensorShape({1, 1, 1, 1}, minimum);
  }

  // Known rank with no dimensions: padded with 1s up to the requested rank,
  // without reporting unknown shapes.
  {
    TensorShapeProto shape_proto;
    shape_proto.set_unknown_rank(false);
    bool unknown_shapes = false;
    const TensorShapeProto minimum =
        MaybeGetMinimumShape(shape_proto, 1, &unknown_shapes);
    EXPECT_FALSE(unknown_shapes);
    ExpectTensorShape({1}, minimum);
  }

  {
    TensorShapeProto shape_proto;
    shape_proto.set_unknown_rank(false);
    bool unknown_shapes = false;
    const TensorShapeProto minimum =
        MaybeGetMinimumShape(shape_proto, 2, &unknown_shapes);
    EXPECT_FALSE(unknown_shapes);
    ExpectTensorShape({1, 1}, minimum);
  }

  // Matching rank passes dimensions through unchanged; requesting a larger
  // rank pads with 1s and reports unknown shapes.
  {
    TensorShapeProto shape_proto;
    shape_proto.set_unknown_rank(false);
    shape_proto.add_dim()->set_size(10);
    shape_proto.add_dim()->set_size(20);
    bool unknown_shapes = false;
    const TensorShapeProto same_rank =
        MaybeGetMinimumShape(shape_proto, 2, &unknown_shapes);
    EXPECT_FALSE(unknown_shapes);
    ExpectTensorShape({10, 20}, same_rank);

    unknown_shapes = false;
    const TensorShapeProto padded =
        MaybeGetMinimumShape(shape_proto, 4, &unknown_shapes);
    EXPECT_TRUE(unknown_shapes);
    EXPECT_EQ(4, padded.dim_size());
    ExpectTensorShape({10, 20, 1, 1}, padded);
  }

  // A -1 dimension becomes 1 and reports unknown shapes.
  {
    TensorShapeProto shape_proto;
    shape_proto.set_unknown_rank(false);
    shape_proto.add_dim()->set_size(10);
    shape_proto.add_dim()->set_size(20);
    shape_proto.add_dim()->set_size(-1);
    shape_proto.add_dim()->set_size(20);
    bool unknown_shapes = false;
    const TensorShapeProto minimum =
        MaybeGetMinimumShape(shape_proto, 4, &unknown_shapes);
    EXPECT_TRUE(unknown_shapes);
    ExpectTensorShape({10, 20, 1, 20}, minimum);
  }

  // Requesting a smaller rank than the input truncates and reports unknown
  // shapes.
  {
    TensorShapeProto shape_proto;
    shape_proto.set_unknown_rank(false);
    shape_proto.add_dim()->set_size(10);
    shape_proto.add_dim()->set_size(20);
    shape_proto.add_dim()->set_size(30);
    shape_proto.add_dim()->set_size(20);
    bool unknown_shapes = false;
    const TensorShapeProto minimum =
        MaybeGetMinimumShape(shape_proto, 2, &unknown_shapes);
    EXPECT_TRUE(unknown_shapes);
    ExpectTensorShape({10, 20}, minimum);
  }
}
1324
TEST_F(OpLevelCostEstimatorTest, IntermediateRdWrBandwidth) {
  // Verifies how compute time, memory time, and intermediate-memory time are
  // combined into execution_time depending on which resource is the
  // bottleneck and on whether compute/memory overlap is enabled.
  // NOTE: the estimator is reconfigured between sections, so order matters.
  TestOpLevelCostEstimator estimator;

  // Compute limited: 1 giga-op/s against 1 GB/s memory bandwidth.
  estimator.SetDeviceInfo(DeviceInfo(/*gigaops=*/1,
                                     /*gb_per_sec=*/1));
  estimator.SetComputeMemoryOverlap(true);
  auto cost = estimator.PredictCosts(
      DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
  EXPECT_EQ(Costs::Duration(3548774400), cost.execution_time);
  // With overlap, execution time equals the dominant component — compute.
  EXPECT_EQ(cost.execution_time, cost.compute_time);

  estimator.SetComputeMemoryOverlap(false);
  cost = estimator.PredictCosts(
      DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
  EXPECT_EQ(Costs::Duration(3551112192), cost.execution_time);
  // Without overlap, the component times are summed.
  EXPECT_EQ(cost.execution_time, cost.compute_time + cost.memory_time +
                                     cost.intermediate_memory_time);

  // Memory limited: very high gigaops so memory bandwidth dominates.
  estimator.SetDeviceInfo(DeviceInfo(/*gigaops=*/99999,
                                     /*gb_per_sec=*/1));
  estimator.SetComputeMemoryOverlap(true);
  cost = estimator.PredictCosts(
      DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
  EXPECT_EQ(Costs::Duration(2337792), cost.execution_time);
  EXPECT_EQ(cost.execution_time, cost.memory_time);

  estimator.SetComputeMemoryOverlap(false);
  cost = estimator.PredictCosts(
      DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
  EXPECT_EQ(Costs::Duration(2373281), cost.execution_time);
  EXPECT_EQ(cost.execution_time, cost.compute_time + cost.memory_time +
                                     cost.intermediate_memory_time);

  // Intermediate memory bandwidth limited: fast compute and main memory, but
  // slow intermediate read/write bandwidth.
  estimator.SetDeviceInfo(DeviceInfo(/*gigaops=*/99999,
                                     /*gb_per_sec=*/9999,
                                     /*intermediate_read_gb_per_sec=*/1,
                                     /*intermediate_write_gb_per_sec=*/1));
  estimator.SetComputeMemoryOverlap(true);
  cost = estimator.PredictCosts(
      DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
  EXPECT_EQ(Costs::Duration(2337792), cost.execution_time);
  EXPECT_EQ(cost.execution_time, cost.intermediate_memory_time);

  estimator.SetComputeMemoryOverlap(false);
  cost = estimator.PredictCosts(
      DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
  EXPECT_EQ(Costs::Duration(2373515), cost.execution_time);
  EXPECT_EQ(cost.execution_time, cost.compute_time + cost.memory_time +
                                     cost.intermediate_memory_time);
}
1378 } // end namespace grappler
1379 } // end namespace tensorflow
1380