1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
17
18 #include <unordered_set>
19
20 #include "tensorflow/core/framework/attr_value.pb.h"
21 #include "tensorflow/core/framework/attr_value_util.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/framework/tensor.pb.h"
24 #include "tensorflow/core/framework/tensor_shape.h"
25 #include "tensorflow/core/framework/tensor_shape.pb.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/platform/test.h"
28 #include "tensorflow/core/protobuf/device_properties.pb.h"
29
30 namespace tensorflow {
31 namespace grappler {
32
33 namespace {
34
35 // TODO(dyoon): Consider to use this Test class for all the test cases, and then
36 // remove friend in the OpLevelCostEstimator class header.
37 class TestOpLevelCostEstimator : public OpLevelCostEstimator {
38 public:
TestOpLevelCostEstimator()39 TestOpLevelCostEstimator() {
40 compute_memory_overlap_ = true;
41 device_info_ = DeviceInfo();
42 }
~TestOpLevelCostEstimator()43 ~TestOpLevelCostEstimator() override {}
44
SetDeviceInfo(const DeviceInfo & device_info)45 void SetDeviceInfo(const DeviceInfo& device_info) {
46 device_info_ = device_info;
47 }
48
SetComputeMemoryOverlap(bool value)49 void SetComputeMemoryOverlap(bool value) { compute_memory_overlap_ = value; }
50
51 protected:
GetDeviceInfo(const DeviceProperties & device) const52 DeviceInfo GetDeviceInfo(const DeviceProperties& device) const override {
53 return device_info_;
54 }
55
56 DeviceInfo device_info_;
57 };
58
ExpectZeroCost(const Costs & cost)59 void ExpectZeroCost(const Costs& cost) {
60 EXPECT_TRUE(cost.inaccurate);
61 EXPECT_EQ(cost.compute_time, Costs::Duration::zero());
62 EXPECT_EQ(cost.execution_time, Costs::Duration::zero());
63 EXPECT_EQ(cost.memory_time, Costs::Duration::zero());
64 }
65
66 // Wrangles the minimum number of proto fields to set up a matrix.
DescribeMatrix(int rows,int columns,OpInfo * op_info)67 void DescribeMatrix(int rows, int columns, OpInfo* op_info) {
68 auto input = op_info->add_inputs();
69 auto shape = input->mutable_shape();
70 auto shape_rows = shape->add_dim();
71 shape_rows->set_size(rows);
72 auto shape_columns = shape->add_dim();
73 shape_columns->set_size(columns);
74 input->set_dtype(DT_FLOAT);
75 }
76
SetCpuDevice(OpInfo * op_info)77 void SetCpuDevice(OpInfo* op_info) {
78 auto device = op_info->mutable_device();
79 device->set_type("CPU");
80 device->set_num_cores(10);
81 device->set_bandwidth(10000000); // 10000000 KB/s = 10 GB/s
82 device->set_frequency(1000); // 1000 Mhz = 1 GHz
83 }
84
85 // Returns an OpInfo for MatMul with the minimum set of fields set up.
DescribeMatMul(int m,int n,int l,int k)86 OpContext DescribeMatMul(int m, int n, int l, int k) {
87 OpContext op_context;
88 SetCpuDevice(&op_context.op_info);
89 op_context.op_info.set_op("MatMul");
90
91 DescribeMatrix(m, l, &op_context.op_info);
92 DescribeMatrix(k, n, &op_context.op_info);
93 return op_context;
94 }
95
96 // Wrangles the minimum number of proto fields to set up an input of
97 // arbitrary rank and type.
DescribeArbitraryRankInput(const std::vector<int> & dims,DataType dtype,OpInfo * op_info)98 void DescribeArbitraryRankInput(const std::vector<int>& dims, DataType dtype,
99 OpInfo* op_info) {
100 auto input = op_info->add_inputs();
101 input->set_dtype(dtype);
102 auto shape = input->mutable_shape();
103 for (auto d : dims) {
104 shape->add_dim()->set_size(d);
105 }
106 }
107
108 // Wrangles the minimum number of proto fields to set up an output of
109 // arbitrary rank and type.
DescribeArbitraryRankOutput(const std::vector<int> & dims,DataType dtype,OpInfo * op_info)110 void DescribeArbitraryRankOutput(const std::vector<int>& dims, DataType dtype,
111 OpInfo* op_info) {
112 auto output = op_info->add_outputs();
113 output->set_dtype(dtype);
114 auto shape = output->mutable_shape();
115 for (auto d : dims) {
116 shape->add_dim()->set_size(d);
117 }
118 }
119
120 // Returns an OpInfo for a SparseTensorDenseMatMul
DescribeSparseTensorDenseMatMul(const int nnz_a,const std::vector<int> & dims_b,const std::vector<int> & dims_out)121 OpContext DescribeSparseTensorDenseMatMul(const int nnz_a,
122 const std::vector<int>& dims_b,
123 const std::vector<int>& dims_out) {
124 OpContext op_context;
125 SetCpuDevice(&op_context.op_info);
126 op_context.op_info.set_op("SparseTensorDenseMatMul");
127
128 DescribeArbitraryRankInput({nnz_a, 2}, DT_INT64, &op_context.op_info);
129 DescribeArbitraryRankInput({nnz_a}, DT_FLOAT, &op_context.op_info);
130 DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
131 DescribeArbitraryRankInput(dims_b, DT_FLOAT, &op_context.op_info);
132 DescribeArbitraryRankOutput(dims_out, DT_FLOAT, &op_context.op_info);
133 return op_context;
134 }
135
136 // Returns an OpInfo for an XlaEinsum
DescribeXlaEinsum(const std::vector<int> & dims_a,const std::vector<int> & dims_b,const string & equation)137 OpContext DescribeXlaEinsum(const std::vector<int>& dims_a,
138 const std::vector<int>& dims_b,
139 const string& equation) {
140 OpContext op_context;
141 SetCpuDevice(&op_context.op_info);
142 op_context.op_info.set_op("XlaEinsum");
143 AttrValue equation_attribute;
144 equation_attribute.set_s(equation);
145 (*op_context.op_info.mutable_attr())["equation"] = equation_attribute;
146 if (!dims_a.empty())
147 DescribeArbitraryRankInput(dims_a, DT_FLOAT, &op_context.op_info);
148 if (!dims_b.empty())
149 DescribeArbitraryRankInput(dims_b, DT_FLOAT, &op_context.op_info);
150 return op_context;
151 }
152
153 // Returns an OpInfo for an Einsum
DescribeEinsum(const std::vector<int> & dims_a,const std::vector<int> & dims_b,const string & equation)154 OpContext DescribeEinsum(const std::vector<int>& dims_a,
155 const std::vector<int>& dims_b,
156 const string& equation) {
157 OpContext op_context = DescribeXlaEinsum(dims_a, dims_b, equation);
158 op_context.op_info.set_op("Einsum");
159 return op_context;
160 }
161
DescribeDummyTensor(OpInfo::TensorProperties * tensor)162 void DescribeDummyTensor(OpInfo::TensorProperties* tensor) {
163 // Intentionally leave the tensor shape and type information missing.
164 }
165
166 // Wrangles the minimum number of proto fields to set up a 1D Tensor for cost
167 // estimation purposes.
DescribeTensor1D(int dim0,OpInfo::TensorProperties * tensor)168 void DescribeTensor1D(int dim0, OpInfo::TensorProperties* tensor) {
169 auto shape = tensor->mutable_shape();
170 shape->add_dim()->set_size(dim0);
171 tensor->set_dtype(DT_FLOAT);
172 }
173
174 // Wrangles the minimum number of proto fields to set up a 4D Tensor for cost
175 // estimation purposes.
DescribeTensor4D(int dim0,int dim1,int dim2,int dim3,OpInfo::TensorProperties * tensor)176 void DescribeTensor4D(int dim0, int dim1, int dim2, int dim3,
177 OpInfo::TensorProperties* tensor) {
178 auto shape = tensor->mutable_shape();
179 shape->add_dim()->set_size(dim0);
180 shape->add_dim()->set_size(dim1);
181 shape->add_dim()->set_size(dim2);
182 shape->add_dim()->set_size(dim3);
183 tensor->set_dtype(DT_FLOAT);
184 }
185
186 // Wrangles the minimum number of proto fields to set up a 4D Tensor for cost
187 // estimation purposes.
DescribeTensor5D(int dim0,int dim1,int dim2,int dim3,int dim4,OpInfo::TensorProperties * tensor)188 void DescribeTensor5D(int dim0, int dim1, int dim2, int dim3, int dim4,
189 OpInfo::TensorProperties* tensor) {
190 auto shape = tensor->mutable_shape();
191 shape->add_dim()->set_size(dim0);
192 shape->add_dim()->set_size(dim1);
193 shape->add_dim()->set_size(dim2);
194 shape->add_dim()->set_size(dim3);
195 shape->add_dim()->set_size(dim4);
196 tensor->set_dtype(DT_FLOAT);
197 }
198
199 // DescribeConvolution constructs an OpContext for a Conv2D applied to an input
200 // tensor with shape (batch, ix, iy, iz1) and a kernel tensor with shape
201 // (kx, ky, iz2, oz).
DescribeConvolution(int batch,int ix,int iy,int iz1,int iz2,int kx,int ky,int oz)202 OpContext DescribeConvolution(int batch, int ix, int iy, int iz1, int iz2,
203 int kx, int ky, int oz) {
204 OpContext op_context;
205 SetCpuDevice(&op_context.op_info);
206 op_context.op_info.set_op("Conv2D");
207
208 DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
209 DescribeTensor4D(kx, ky, iz2, oz, op_context.op_info.add_inputs());
210
211 return op_context;
212 }
213
214 // Describe DepthwiseConvolution constructs an OpContext for a
215 // DepthwiseConv2dNative applied to an input
216 // tensor with shape (batch, ix, iy, iz1) and a kernel tensor with shape
217 // (kx, ky, iz2, cm). cm is channel multiplier
218
DescribeDepthwiseConv2dNative(int batch,int ix,int iy,int iz1,int iz2,int kx,int ky,int cm)219 OpContext DescribeDepthwiseConv2dNative(int batch, int ix, int iy, int iz1,
220 int iz2, int kx, int ky, int cm) {
221 OpContext op_context;
222 SetCpuDevice(&op_context.op_info);
223 op_context.op_info.set_op("DepthwiseConv2dNative");
224
225 DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
226 DescribeTensor4D(kx, ky, iz2, cm, op_context.op_info.add_inputs());
227
228 return op_context;
229 }
230
231 // DescribeFusedConv2DBiasActivation constructs an OpContext for a
232 // FusedConv2DBiasActivation applied to a convolution input tensor with shape
233 // (batch, ix, iy, iz1), a kernel tensor with shape (kx, ky, iz2, oz), a
234 // bias tensor with shape (oz), a side input tensor with shape
235 // (batch, ox, oy, oz) if has_side_input is set, and two scaling tensors with
236 // shape (1). If a vectorized channel format is chosen (NCHW_VECT_C, e.g.) we'll
237 // default to 4 (the vector size most often used with this format on NVIDIA
238 // platforms) for the major channel size, and divide the input channel size by
239 // that amount.
240 //
241 // Note that this assumes the NHWC data format.
DescribeFusedConv2DBiasActivation(int batch,int ix,int iy,int iz1,int iz2,int kx,int ky,int ox,int oy,int oz,bool has_side_input,const string & data_format,const string & filter_format)242 OpContext DescribeFusedConv2DBiasActivation(int batch, int ix, int iy, int iz1,
243 int iz2, int kx, int ky, int ox,
244 int oy, int oz, bool has_side_input,
245 const string& data_format,
246 const string& filter_format) {
247 const int kVecWidth = 4;
248 OpContext op_context;
249 SetCpuDevice(&op_context.op_info);
250 op_context.op_info.set_op("FusedConv2DBiasActivation");
251 auto* attr_data_format = op_context.op_info.mutable_attr();
252 SetAttrValue(data_format, &(*attr_data_format)["data_format"]);
253 auto* attr_filter_format = op_context.op_info.mutable_attr();
254 SetAttrValue(filter_format, &(*attr_filter_format)["filter_format"]);
255 if (data_format == "NHWC") {
256 DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
257 } else if (data_format == "NCHW") {
258 DescribeTensor4D(batch, iz1, ix, iy, op_context.op_info.add_inputs());
259 } else {
260 // Use the NCHW_VECT_C format.
261 EXPECT_EQ(data_format, "NCHW_VECT_C");
262 EXPECT_EQ(iz1 % kVecWidth, 0);
263 DescribeTensor5D(batch, iz1 / kVecWidth, ix, iy, kVecWidth,
264 op_context.op_info.add_inputs());
265 }
266 if (filter_format == "HWIO") {
267 DescribeTensor4D(kx, ky, iz2, oz, op_context.op_info.add_inputs());
268 } else if (filter_format == "OIHW") {
269 DescribeTensor4D(oz, iz2, kx, ky, op_context.op_info.add_inputs());
270 } else {
271 EXPECT_EQ(filter_format, "OIHW_VECT_I");
272 EXPECT_EQ(iz2 % kVecWidth, 0);
273 // Use the OIHW_VECT_I format.
274 DescribeTensor5D(oz, iz2 / kVecWidth, kx, ky, kVecWidth,
275 op_context.op_info.add_inputs());
276 }
277 DescribeTensor1D(oz, op_context.op_info.add_inputs());
278
279 // Add the side_input, if any.
280 auto side_input = op_context.op_info.add_inputs();
281 if (has_side_input) {
282 if (data_format == "NHWC") {
283 DescribeTensor4D(batch, ox, oy, oz, side_input);
284 } else if (data_format == "NCHW") {
285 DescribeTensor4D(batch, oz, ox, oy, side_input);
286 } else {
287 // Use the NCHW_VECT_C format.
288 EXPECT_EQ(data_format, "NCHW_VECT_C");
289 EXPECT_EQ(oz % kVecWidth, 0);
290 DescribeTensor5D(batch, oz / kVecWidth, ox, oy, kVecWidth, side_input);
291 }
292 }
293
294 // Add the scaling tensors.
295 DescribeTensor1D(1, op_context.op_info.add_inputs());
296 DescribeTensor1D(1, op_context.op_info.add_inputs());
297
298 return op_context;
299 }
300
301 // DescribeUnaryOp constructs an OpContext for the given operation applied to
302 // a 4-tensor with shape (size1, 1, 1, 1).
DescribeUnaryOp(const string & op,int size1)303 OpContext DescribeUnaryOp(const string& op, int size1) {
304 OpContext op_context;
305 SetCpuDevice(&op_context.op_info);
306 op_context.op_info.set_op(op);
307
308 DescribeTensor4D(size1, 1, 1, 1, op_context.op_info.add_inputs());
309 DescribeTensor4D(size1, 1, 1, 1, op_context.op_info.add_outputs());
310
311 return op_context;
312 }
313
314 // DescribeBinaryOp constructs an OpContext for the given operation applied to
315 // a 4-tensor with dimensions (size1, 1, 1, 1) and a 4-tensor with dimensions
316 // (2 * size1, size2, 1, 1).
317 //
318 // The choice of dimension here is arbitrary, and is used strictly to test the
319 // cost model for applying elementwise operations to tensors with unequal
320 // dimension values.
DescribeBinaryOp(const string & op,int size1,int size2)321 OpContext DescribeBinaryOp(const string& op, int size1, int size2) {
322 OpContext op_context;
323 SetCpuDevice(&op_context.op_info);
324 op_context.op_info.set_op(op);
325
326 DescribeTensor4D(size1, 1, 1, 1, op_context.op_info.add_inputs());
327 DescribeTensor4D(2 * size1, size2, 1, 1, op_context.op_info.add_inputs());
328 DescribeTensor4D(2 * size1, size2, 1, 1, op_context.op_info.add_outputs());
329
330 return op_context;
331 }
332
333 // DescribeBiasAdd constructs an OpContext for a BiasAdd applied to a 4-tensor
334 // with dimensions (1, 1, size2, size1) and a bias with dimension (size1),
335 // according to the constraint that the bias must be 1D with size equal to that
336 // of the last dimension of the input value.
DescribeBiasAdd(int size1,int size2)337 OpContext DescribeBiasAdd(int size1, int size2) {
338 OpContext op_context;
339 SetCpuDevice(&op_context.op_info);
340 op_context.op_info.set_op("BiasAdd");
341
342 DescribeTensor4D(1, 1, size2, size1, op_context.op_info.add_inputs());
343 DescribeTensor1D(size1, op_context.op_info.add_inputs());
344 DescribeTensor4D(1, 1, size2, size1, op_context.op_info.add_outputs());
345
346 return op_context;
347 }
348
// Returns the output extent of a window of size `k` slid over an input extent
// `x` with stride `s`.
// - "SAME" padding: ceil(x / s); `k` is unused.
// - Any other padding (i.e. "VALID"): ceil((x - k + 1) / s).
int GetOutputSize(const int x, const int k, const int s,
                  const string& padding) {
  if (padding == "SAME") {
    return (x + s - 1) / s;
  }
  return (x - k + s) / s;
}

// Returns the 4-D output shape of a pooling op applied to `input`.
// `input`, `ksize` and `strides` must each have 4 entries laid out in
// `data_format`. "NCHW" selects (N, C, H, W); any other value is treated as
// (N, H, W, C). The output uses the same layout as the input, and the channel
// count is carried through unchanged.
std::vector<int> GetPoolingOutputSize(const std::vector<int>& input,
                                      const std::vector<int>& ksize,
                                      const std::vector<int>& strides,
                                      const string& data_format,
                                      const string& padding) {
  // Decide the layout once and use it for both reading and writing. Otherwise
  // an unrecognized data_format would be read as NHWC but emitted as NCHW.
  const bool is_nchw = data_format == "NCHW";
  const int h_index = is_nchw ? 2 : 1;
  const int w_index = is_nchw ? 3 : 2;
  const int c_index = is_nchw ? 1 : 3;

  const int n = input[0];
  const int c = input[c_index];
  const int ho = GetOutputSize(input[h_index], ksize[h_index],
                               strides[h_index], padding);
  const int wo = GetOutputSize(input[w_index], ksize[w_index],
                               strides[w_index], padding);

  if (is_nchw) {
    return {n, c, ho, wo};
  }
  return {n, ho, wo, c};
}
394
395 // Helper functions for testing GetTensorShapeProtoFromTensorProto().
GetTensorProto(const DataType dtype,const std::vector<int64> & shape,const std::vector<int64> values,const bool tensor_content,TensorProto * tensor_proto)396 void GetTensorProto(const DataType dtype, const std::vector<int64>& shape,
397 const std::vector<int64> values, const bool tensor_content,
398 TensorProto* tensor_proto) {
399 tensor_proto->Clear();
400 TensorProto temp_tensor_proto;
401 temp_tensor_proto.set_dtype(dtype);
402 for (const auto& x : shape) {
403 temp_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(x);
404 }
405 for (const auto& x : values) {
406 if (dtype == DT_INT64) {
407 temp_tensor_proto.add_int64_val(x);
408 } else if (dtype == DT_INT32 || dtype == DT_INT16 || dtype == DT_INT8 ||
409 dtype == DT_UINT8) {
410 temp_tensor_proto.add_int_val(x);
411 } else if (dtype == DT_UINT32) {
412 temp_tensor_proto.add_uint32_val(x);
413 } else if (dtype == DT_UINT64) {
414 temp_tensor_proto.add_uint64_val(x);
415 } else {
416 CHECK(false) << "Unsupported dtype: " << dtype;
417 }
418 }
419 Tensor tensor(dtype);
420 CHECK(tensor.FromProto(temp_tensor_proto));
421 if (tensor_content) {
422 tensor.AsProtoTensorContent(tensor_proto);
423 } else {
424 tensor.AsProtoField(tensor_proto);
425 }
426 }
427
DescribePoolingOp(const string & op_name,const std::vector<int> & x,const std::vector<int> & ksize,const std::vector<int> & strides,const string & data_format,const string & padding)428 OpContext DescribePoolingOp(const string& op_name, const std::vector<int>& x,
429 const std::vector<int>& ksize,
430 const std::vector<int>& strides,
431 const string& data_format, const string& padding) {
432 OpContext op_context;
433 auto& op_info = op_context.op_info;
434 SetCpuDevice(&op_info);
435 op_info.set_op(op_name);
436
437 const std::vector<int> y =
438 GetPoolingOutputSize(x, ksize, strides, data_format, padding);
439 if (op_name == "AvgPool" || op_name == "MaxPool") {
440 // input: x, output: y.
441 DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
442 DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_outputs());
443 } else if (op_name == "AvgPoolGrad") {
444 // input: x's shape, y_grad, output: x_grad.
445 DescribeArbitraryRankInput({4}, DT_INT32, &op_info);
446 auto* tensor_proto = op_info.mutable_inputs(0)->mutable_value();
447 GetTensorProto(DT_INT32, {4}, {x[0], x[1], x[2], x[3]},
448 /*tensor_content=*/false, tensor_proto);
449 DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs());
450 DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_outputs());
451 } else if (op_name == "MaxPoolGrad") {
452 // input: x, y, y_grad, output: x_grad.
453 DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
454 DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs());
455 DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs());
456 DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_outputs());
457 }
458 auto* attr = op_info.mutable_attr();
459 SetAttrValue(data_format, &(*attr)["data_format"]);
460 SetAttrValue(padding, &(*attr)["padding"]);
461 SetAttrValue(strides, &(*attr)["strides"]);
462 SetAttrValue(ksize, &(*attr)["ksize"]);
463 return op_context;
464 }
465
DescribeFusedBatchNorm(const bool is_training,const bool is_grad,const std::vector<int> & x,const string & data_format)466 OpContext DescribeFusedBatchNorm(const bool is_training, const bool is_grad,
467 const std::vector<int>& x,
468 const string& data_format) {
469 // First, get MaxPool op info with unit stride and unit window.
470 OpContext op_context = DescribePoolingOp("MaxPool", x, {1, 1, 1, 1},
471 {1, 1, 1, 1}, data_format, "SAME");
472 auto& op_info = op_context.op_info;
473 // Override op name.
474 if (is_grad) {
475 op_info.set_op("FusedBatchNormGrad");
476 } else {
477 op_info.set_op("FusedBatchNorm");
478 }
479
480 // Add additional input output tensors.
481 if (is_grad) {
482 DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
483 }
484 int num_1d_inputs = is_grad ? 3 : 4;
485 for (int i = 0; i < num_1d_inputs; i++) {
486 auto* tensor = op_info.add_inputs();
487 auto* shape = tensor->mutable_shape();
488 shape->add_dim()->set_size(x[3]);
489 tensor->set_dtype(DT_FLOAT);
490 }
491 for (int i = 0; i < 4; i++) {
492 auto* tensor = op_info.add_outputs();
493 auto* shape = tensor->mutable_shape();
494 shape->add_dim()->set_size(x[3]);
495 tensor->set_dtype(DT_FLOAT);
496 }
497
498 // Delete unnecessary attr.
499 auto* attr = op_context.op_info.mutable_attr();
500 attr->erase("ksize");
501 attr->erase("strides");
502 attr->erase("padding");
503
504 // Additional attrs for FusedBatchNorm.
505 SetAttrValue(is_training, &(*attr)["is_training"]);
506
507 return op_context;
508 }
509 } // namespace
510
511 class OpLevelCostEstimatorTest : public ::testing::Test {
512 protected:
513 using BatchMatMulDimensions = OpLevelCostEstimator::BatchMatMulDimensions;
514
PredictCosts(const OpContext & op_context) const515 Costs PredictCosts(const OpContext& op_context) const {
516 return estimator_.PredictCosts(op_context);
517 }
518
CountMatMulOperations(const OpInfo & op_info,bool * found_unknown_shapes) const519 int64 CountMatMulOperations(const OpInfo& op_info,
520 bool* found_unknown_shapes) const {
521 return estimator_.CountMatMulOperations(op_info, found_unknown_shapes);
522 }
523
CountBatchMatMulOperations(const OpInfo & op_info,bool * found_unknown_shapes) const524 int64 CountBatchMatMulOperations(const OpInfo& op_info,
525 bool* found_unknown_shapes) const {
526 return estimator_.CountBatchMatMulOperations(op_info, found_unknown_shapes);
527 }
528
CountBatchMatMulOperations(const OpInfo & op_info,BatchMatMulDimensions * batch_mat_mul,bool * found_unknown_shapes) const529 int64 CountBatchMatMulOperations(const OpInfo& op_info,
530 BatchMatMulDimensions* batch_mat_mul,
531 bool* found_unknown_shapes) const {
532 return estimator_.CountBatchMatMulOperations(op_info, batch_mat_mul,
533 found_unknown_shapes);
534 }
535
SetComputeMemoryOverlap(bool value)536 void SetComputeMemoryOverlap(bool value) {
537 estimator_.compute_memory_overlap_ = value;
538 }
539
ValidateOpDimensionsFromInputs(const int n,const int h,const int w,const int c,const int kx,const int ky,const int sx,const int sy,const string & data_format,const string & padding)540 void ValidateOpDimensionsFromInputs(const int n, const int h, const int w,
541 const int c, const int kx, const int ky,
542 const int sx, const int sy,
543 const string& data_format,
544 const string& padding) {
545 OpContext op_context;
546 int ho;
547 int wo;
548 if (data_format == "NHWC") {
549 op_context = DescribePoolingOp("MaxPool", {n, h, w, c}, {1, kx, ky, 1},
550 {1, sx, sy, 1}, "NHWC", padding);
551 ho = op_context.op_info.outputs(0).shape().dim(1).size();
552 wo = op_context.op_info.outputs(0).shape().dim(2).size();
553 } else {
554 op_context = DescribePoolingOp("MaxPool", {n, c, h, w}, {1, 1, kx, ky},
555 {1, 1, sx, sy}, "NCHW", padding);
556 ho = op_context.op_info.outputs(0).shape().dim(2).size();
557 wo = op_context.op_info.outputs(0).shape().dim(3).size();
558 }
559
560 bool found_unknown_shapes;
561 auto dims = OpLevelCostEstimator::OpDimensionsFromInputs(
562 op_context.op_info.inputs(0).shape(), op_context.op_info,
563 &found_unknown_shapes);
564 Padding padding_enum;
565 if (padding == "VALID") {
566 padding_enum = Padding::VALID;
567 } else {
568 padding_enum = Padding::SAME;
569 }
570 EXPECT_EQ(n, dims.batch);
571 EXPECT_EQ(h, dims.ix);
572 EXPECT_EQ(w, dims.iy);
573 EXPECT_EQ(c, dims.iz);
574 EXPECT_EQ(kx, dims.kx);
575 EXPECT_EQ(ky, dims.ky);
576 EXPECT_EQ(sx, dims.sx);
577 EXPECT_EQ(sy, dims.sy);
578 EXPECT_EQ(ho, dims.ox);
579 EXPECT_EQ(wo, dims.oy);
580 EXPECT_EQ(c, dims.oz);
581 EXPECT_EQ(padding_enum, dims.padding);
582 }
583
584 OpLevelCostEstimator estimator_;
585 };
586
587 class OpLevelBatchMatMulCostEstimatorTest
588 : public OpLevelCostEstimatorTest,
589 public ::testing::WithParamInterface<const char*> {
590 protected:
591 // Returns an OpInfo for a BatchMatMul
DescribeBatchMatMul(const std::vector<int> & dims_a,const std::vector<int> & dims_b)592 OpContext DescribeBatchMatMul(const std::vector<int>& dims_a,
593 const std::vector<int>& dims_b) {
594 OpContext op_context;
595 SetCpuDevice(&op_context.op_info);
596 op_context.op_info.set_op(GetParam());
597
598 DescribeArbitraryRankInput(dims_a, DT_FLOAT, &op_context.op_info);
599 DescribeArbitraryRankInput(dims_b, DT_FLOAT, &op_context.op_info);
600 return op_context;
601 }
602
CountBatchMatMulOperations(const OpInfo & op_info,bool * found_unknown_shapes) const603 int64 CountBatchMatMulOperations(const OpInfo& op_info,
604 bool* found_unknown_shapes) const {
605 return OpLevelCostEstimatorTest::CountBatchMatMulOperations(
606 op_info, found_unknown_shapes);
607 }
608
CountBatchMatMulDimProduct(const OpInfo & op_info,bool * found_unknown_shapes) const609 int64 CountBatchMatMulDimProduct(const OpInfo& op_info,
610 bool* found_unknown_shapes) const {
611 BatchMatMulDimensions batch_mat_mul;
612
613 batch_mat_mul.matmul_dims.n = 0;
614 batch_mat_mul.matmul_dims.m = 0;
615 batch_mat_mul.matmul_dims.k = 0;
616
617 OpLevelCostEstimatorTest::CountBatchMatMulOperations(
618 op_info, &batch_mat_mul, found_unknown_shapes);
619 int dimension_product = 1;
620 for (auto dim : batch_mat_mul.batch_dims) dimension_product *= dim;
621
622 dimension_product *= batch_mat_mul.matmul_dims.n;
623 dimension_product *= batch_mat_mul.matmul_dims.m;
624 dimension_product *= batch_mat_mul.matmul_dims.k;
625
626 return dimension_product;
627 }
628 };
629
// Each persistent op is expected to be charged Costs::Duration(1) of compute
// and execution time with no memory time. Each is counted as one accurate op
// that needs no temporary or persistent memory.
TEST_F(OpLevelCostEstimatorTest, TestPersistentOpCosts) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  const std::unordered_set<string> persistent_ops = {
      "Const",       "Variable",       "VariableV2", "AutoReloadVariable",
      "VarHandleOp", "ReadVariableOp",
  };
  for (const string& op_name : persistent_ops) {
    op_context.op_info.set_op(op_name);
    const auto cost = estimator_.PredictCosts(op_context);
    EXPECT_EQ(Costs::Duration(0), cost.memory_time);
    EXPECT_EQ(Costs::Duration(1), cost.compute_time);
    EXPECT_EQ(Costs::Duration(1), cost.execution_time);
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
    EXPECT_EQ(cost.temporary_memory, 0);
    EXPECT_EQ(cost.persistent_memory, 0);
  }
}
651
// All gather variants should get the same costs for a [16] index gather from
// a very large first input into a [16, 10] output.
TEST_F(OpLevelCostEstimatorTest, TestGatherCosts) {
  const std::vector<std::string> gather_ops = {"Gather", "GatherNd",
                                               "GatherV2"};

  for (const std::string& op_name : gather_ops) {
    OpContext ctx;
    OpInfo* info = &ctx.op_info;
    SetCpuDevice(info);
    info->set_op(op_name);

    // Huge first input shouldn't affect Gather execution and memory costs.
    DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, info);
    DescribeArbitraryRankInput({16}, DT_INT64, info);
    DescribeArbitraryRankOutput({16, 10}, DT_FLOAT, info);

    const auto cost = estimator_.PredictCosts(ctx);
    EXPECT_EQ(Costs::Duration(130), cost.memory_time);
    EXPECT_EQ(Costs::Duration(16), cost.compute_time);
    EXPECT_EQ(Costs::Duration(146), cost.execution_time);
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
    EXPECT_EQ(cost.temporary_memory, 0);
    EXPECT_EQ(cost.persistent_memory, 0);
  }
}
676
// With no output described, a Gather is expected to report zero times and be
// flagged inaccurate, while still counting as one op.
TEST_F(OpLevelCostEstimatorTest, TestGatherCostsWithoutOutput) {
  OpContext ctx;
  OpInfo* info = &ctx.op_info;
  SetCpuDevice(info);
  info->set_op("Gather");

  // Huge first input shouldn't affect Gather execution and memory costs.
  DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, info);
  DescribeArbitraryRankInput({16}, DT_INT64, info);
  // Intentionally no output description.

  const auto cost = estimator_.PredictCosts(ctx);
  EXPECT_EQ(Costs::Duration(0), cost.memory_time);
  EXPECT_EQ(Costs::Duration(0), cost.compute_time);
  EXPECT_EQ(Costs::Duration(0), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_TRUE(cost.inaccurate);
  EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(cost.temporary_memory, 0);
  EXPECT_EQ(cost.persistent_memory, 0);
}
696
// A Slice producing a [10, 10] output from a very large input should be costed
// by the slice size, not by the input size.
TEST_F(OpLevelCostEstimatorTest, TestSliceCosts) {
  OpContext ctx;
  OpInfo* info = &ctx.op_info;
  SetCpuDevice(info);
  info->set_op("Slice");

  // Huge first input shouldn't affect Slice execution and memory costs.
  DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, info);
  for (int i = 0; i < 2; ++i) {
    DescribeArbitraryRankInput({2}, DT_INT64, info);
  }
  DescribeArbitraryRankOutput({10, 10}, DT_FLOAT, info);

  const auto cost = estimator_.PredictCosts(ctx);
  EXPECT_EQ(Costs::Duration(81), cost.memory_time);
  EXPECT_EQ(Costs::Duration(10), cost.compute_time);
  EXPECT_EQ(Costs::Duration(91), cost.execution_time);
  EXPECT_EQ(cost.num_ops_total, 1);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(cost.temporary_memory, 0);
  EXPECT_EQ(cost.persistent_memory, 0);
}
718
// A StridedSlice producing a [10, 10] output from a very large input is
// expected to get the same costs as the equivalent Slice.
TEST_F(OpLevelCostEstimatorTest, TestStridedSliceCosts) {
  OpContext ctx;
  OpInfo* info = &ctx.op_info;
  SetCpuDevice(info);
  info->set_op("StridedSlice");

  // Huge first input shouldn't affect StridedSlice execution and memory costs.
  DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, info);
  for (int i = 0; i < 3; ++i) {
    DescribeArbitraryRankInput({2}, DT_INT64, info);
  }
  DescribeArbitraryRankOutput({10, 10}, DT_FLOAT, info);

  const auto cost = estimator_.PredictCosts(ctx);
  EXPECT_EQ(Costs::Duration(81), cost.memory_time);
  EXPECT_EQ(Costs::Duration(10), cost.compute_time);
  EXPECT_EQ(Costs::Duration(91), cost.execution_time);
  EXPECT_EQ(cost.num_ops_total, 1);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(cost.temporary_memory, 0);
  EXPECT_EQ(cost.persistent_memory, 0);
}
741
// Each Scatter* op is costed for two update layouts:
// - full [16, 10] updates with DT_INT64 indices;
// - scalar updates with DT_INT32 indices.
TEST_F(OpLevelCostEstimatorTest, TestScatterOps) {
  std::vector<string> scatter_ops = {"ScatterAdd", "ScatterDiv", "ScatterMax",
                                     "ScatterMin", "ScatterMul", "ScatterSub",
                                     "ScatterUpdate"};

  // Builds a scatter into a [10000000, 10] DT_FLOAT ref with [16] indices of
  // type `index_dtype` and DT_FLOAT updates of shape `updates_dims`.
  const auto describe_scatter = [](const string& op_name, DataType index_dtype,
                                   const std::vector<int>& updates_dims) {
    OpContext ctx;
    SetCpuDevice(&ctx.op_info);
    ctx.op_info.set_op(op_name);
    // Huge first dimension in input shouldn't affect Scatter execution and
    // memory costs.
    DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &ctx.op_info);
    DescribeArbitraryRankInput({16}, index_dtype, &ctx.op_info);
    DescribeArbitraryRankInput(updates_dims, DT_FLOAT, &ctx.op_info);
    DescribeArbitraryRankOutput({10000000, 10}, DT_FLOAT, &ctx.op_info);
    return ctx;
  };

  for (const auto& op : scatter_ops) {
    // Test updates.shape = indices.shape + ref.shape[1:]
    {
      const auto cost =
          estimator_.PredictCosts(describe_scatter(op, DT_INT64, {16, 10}));
      EXPECT_EQ(Costs::Duration(205), cost.memory_time);
      EXPECT_EQ(Costs::Duration(16), cost.compute_time);
      EXPECT_EQ(Costs::Duration(221), cost.execution_time);
      EXPECT_EQ(cost.num_ops_total, 1);
      EXPECT_FALSE(cost.inaccurate);
      EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
      EXPECT_EQ(cost.temporary_memory, 0);
      EXPECT_EQ(cost.persistent_memory, 0);
    }

    // Test updates.shape = [] and INT32 indices
    {
      const auto cost =
          estimator_.PredictCosts(describe_scatter(op, DT_INT32, {}));
      EXPECT_EQ(Costs::Duration(135), cost.memory_time);
      EXPECT_EQ(Costs::Duration(16), cost.compute_time);
      EXPECT_EQ(Costs::Duration(151), cost.execution_time);
      EXPECT_EQ(1, cost.num_ops_total);
      EXPECT_FALSE(cost.inaccurate);
      EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
    }
  }
}
794
TEST_F(OpLevelCostEstimatorTest, BiasAddExecutionTime) {
  // BiasAdd built by DescribeBiasAdd(1000, 10).
  const Costs cost = PredictCosts(DescribeBiasAdd(1000, 10));
  EXPECT_EQ(Costs::Duration(1000), cost.compute_time);
  EXPECT_EQ(Costs::Duration(8400), cost.memory_time);
  EXPECT_EQ(Costs::Duration(9400), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
806
TEST_F(OpLevelCostEstimatorTest, Conv2DExecutionTime) {
  // Reference Conv2D configuration; compute dominates the predicted time.
  const Costs cost =
      PredictCosts(DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
  EXPECT_EQ(Costs::Duration(354877440), cost.compute_time);
  EXPECT_EQ(Costs::Duration(233780), cost.memory_time);
  EXPECT_EQ(Costs::Duration(355111220), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
818
// For every convolution op, zeroing any single entry of an otherwise valid
// convolution config must yield zero costs flagged as inaccurate with unknown
// shapes.
TEST_F(OpLevelCostEstimatorTest, InvalidConv2DConfig) {
  // Convolution ops. The element type must be non-const: std::vector of
  // const elements is ill-formed.
  const std::vector<std::string> conv_ops = {
      "Conv2D",
      "Conv2DBackpropFilter",
      "Conv2DBackpropInput",
      "DepthwiseConv2dNative",
      "DepthwiseConv2dNativeBackpropFilter",
      "DepthwiseConv2dNativeBackpropInput",
  };
  // A valid Conv2D config.
  const std::vector<int> valid_conv_config = {16, 19, 19, 48, 48, 5, 5, 256};
  for (const auto& op : conv_ops) {
    // Test with setting one value in conv config to zero.
    // PredictCosts() should return zero costs.
    for (size_t i = 0; i < valid_conv_config.size(); ++i) {
      SCOPED_TRACE(::testing::Message()
                   << op << ": conv_config[" << i << "] = 0");
      std::vector<int> conv_config(valid_conv_config);
      conv_config[i] = 0;
      auto op_context = DescribeConvolution(
          conv_config[0], conv_config[1], conv_config[2], conv_config[3],
          conv_config[4], conv_config[5], conv_config[6], conv_config[7]);
      op_context.op_info.set_op(op);
      auto cost = PredictCosts(op_context);
      EXPECT_EQ(Costs::Duration(0), cost.memory_time);
      EXPECT_EQ(Costs::Duration(0), cost.compute_time);
      EXPECT_EQ(Costs::Duration(0), cost.execution_time);
      EXPECT_EQ(1, cost.num_ops_total);
      EXPECT_TRUE(cost.inaccurate);
      EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
    }
  }
}
851
TEST_F(OpLevelCostEstimatorTest, DepthwiseConv2dNativeExecutionTime) {
  // Depthwise convolution described by (16, 19, 19, 48, 48, 5, 5, 3).
  const Costs cost =
      PredictCosts(DescribeDepthwiseConv2dNative(16, 19, 19, 48, 48, 5, 5, 3));
  EXPECT_EQ(Costs::Duration(4158720), cost.compute_time);
  EXPECT_EQ(Costs::Duration(112340), cost.memory_time);
  EXPECT_EQ(Costs::Duration(4271060), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
864
TEST_F(OpLevelCostEstimatorTest, DummyExecutionTime) {
  // "Dummy" is predicted with memory time only (no compute time), and the
  // prediction is flagged inaccurate.
  const Costs cost = PredictCosts(DescribeBinaryOp("Dummy", 1000, 1));
  EXPECT_EQ(Costs::Duration(0), cost.compute_time);
  EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
  EXPECT_EQ(Costs::Duration(2000), cost.execution_time);
  EXPECT_TRUE(cost.inaccurate);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
876
// With compute/memory overlap enabled, execution_time is expected to be the
// max of compute_time and memory_time rather than their sum.
TEST_F(OpLevelCostEstimatorTest, ExecutionTimeSumOrMax) {
  SetComputeMemoryOverlap(true);
  {
    // "Dummy" has zero compute time, so sum and max coincide here.
    auto cost = PredictCosts(DescribeBinaryOp("Dummy", 1000, 1));
    EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
    EXPECT_EQ(Costs::Duration(0), cost.compute_time);
    EXPECT_EQ(Costs::Duration(2000), cost.execution_time);  // max(2000, 0)
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
    EXPECT_EQ(0, cost.temporary_memory);
    EXPECT_EQ(0, cost.persistent_memory);
  }
  {
    // "Mul" has non-zero compute time (the non-overlapped sum for this op is
    // 2000 + 200 = 2200), so this case actually distinguishes max from sum.
    auto cost = PredictCosts(DescribeBinaryOp("Mul", 1000, 1));
    EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
    EXPECT_EQ(Costs::Duration(200), cost.compute_time);
    EXPECT_EQ(Costs::Duration(2000), cost.execution_time);  // max(2000, 200)
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
    EXPECT_EQ(0, cost.temporary_memory);
    EXPECT_EQ(0, cost.persistent_memory);
  }
  SetComputeMemoryOverlap(false);  // Set it back to default.
}
890
TEST_F(OpLevelCostEstimatorTest,
       FusedConv2DBiasActivationNCHW_HWIO_NoSideInput) {
  // NCHW data with an HWIO filter, and no side input.
  constexpr bool kHasSideInput = false;
  const Costs cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, kHasSideInput, "NCHW", "HWIO"));
  EXPECT_EQ(Costs::Duration(355321037), cost.compute_time);
  EXPECT_EQ(Costs::Duration(825345), cost.memory_time);
  EXPECT_EQ(Costs::Duration(356146382), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
905
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_HWIO) {
  // NCHW data with an HWIO filter, side input present.
  constexpr bool kHasSideInput = true;
  const Costs cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, kHasSideInput, "NCHW", "HWIO"));
  EXPECT_EQ(Costs::Duration(355616768), cost.compute_time);
  EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
  EXPECT_EQ(Costs::Duration(357033576), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
919
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_OIHW) {
  // NCHW data with an OIHW filter, side input present.
  constexpr bool kHasSideInput = true;
  const Costs cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, kHasSideInput, "NCHW", "OIHW"));
  EXPECT_EQ(Costs::Duration(355616768), cost.compute_time);
  EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
  EXPECT_EQ(Costs::Duration(357033576), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
933
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNHWC_HWIO) {
  // NHWC data with an HWIO filter, side input present.
  constexpr bool kHasSideInput = true;
  const Costs cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, kHasSideInput, "NHWC", "HWIO"));
  EXPECT_EQ(Costs::Duration(355616768), cost.compute_time);
  EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
  EXPECT_EQ(Costs::Duration(357033576), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
947
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNHWC_OIHW) {
  // NHWC data with an OIHW filter, side input present.
  constexpr bool kHasSideInput = true;
  const Costs cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, kHasSideInput, "NHWC", "OIHW"));
  EXPECT_EQ(Costs::Duration(355616768), cost.compute_time);
  EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
  EXPECT_EQ(Costs::Duration(357033576), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
961
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_VECT_C_OIHW) {
  // NCHW_VECT_C data with an OIHW filter, side input present.
  constexpr bool kHasSideInput = true;
  const Costs cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, kHasSideInput, "NCHW_VECT_C",
      "OIHW"));
  EXPECT_EQ(Costs::Duration(355616768), cost.compute_time);
  EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
  EXPECT_EQ(Costs::Duration(357033576), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
975
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_OIHW_VECT_I) {
  // NCHW data with an OIHW_VECT_I filter, side input present.
  constexpr bool kHasSideInput = true;
  const Costs cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, kHasSideInput, "NCHW",
      "OIHW_VECT_I"));
  EXPECT_EQ(Costs::Duration(355616768), cost.compute_time);
  EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
  EXPECT_EQ(Costs::Duration(357033576), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
989
TEST_F(OpLevelCostEstimatorTest,
       FusedConv2DBiasActivationNCHW_VECT_C_OIHW_VECT_I) {
  // NCHW_VECT_C data with an OIHW_VECT_I filter, side input present.
  constexpr bool kHasSideInput = true;
  const Costs cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, kHasSideInput, "NCHW_VECT_C",
      "OIHW_VECT_I"));
  EXPECT_EQ(Costs::Duration(355616768), cost.compute_time);
  EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
  EXPECT_EQ(Costs::Duration(357033576), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
1004
TEST_F(OpLevelCostEstimatorTest, MulExecutionTime) {
  // Mul built by DescribeBinaryOp with sizes (1000, 1).
  const Costs cost = PredictCosts(DescribeBinaryOp("Mul", 1000, 1));
  EXPECT_EQ(Costs::Duration(200), cost.compute_time);
  EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
  EXPECT_EQ(Costs::Duration(2200), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
1016
TEST_F(OpLevelCostEstimatorTest, MulBroadcastExecutionTime) {
  // Mul built by DescribeBinaryOp with sizes (1000, 2).
  const Costs cost = PredictCosts(DescribeBinaryOp("Mul", 1000, 2));
  EXPECT_EQ(Costs::Duration(400), cost.compute_time);
  EXPECT_EQ(Costs::Duration(3600), cost.memory_time);
  EXPECT_EQ(Costs::Duration(4000), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
1028
TEST_F(OpLevelCostEstimatorTest, ModExecutionTime) {
  // Mod built by DescribeBinaryOp with sizes (1000, 1).
  const Costs cost = PredictCosts(DescribeBinaryOp("Mod", 1000, 1));
  EXPECT_EQ(Costs::Duration(1600), cost.compute_time);
  EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
  EXPECT_EQ(Costs::Duration(3600), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
1040
TEST_F(OpLevelCostEstimatorTest, SquaredDifferenceExecutionTime) {
  // SquaredDifference built by DescribeBinaryOp with sizes (1000, 2).
  const Costs cost =
      PredictCosts(DescribeBinaryOp("SquaredDifference", 1000, 2));
  EXPECT_EQ(Costs::Duration(800), cost.compute_time);
  EXPECT_EQ(Costs::Duration(3600), cost.memory_time);
  EXPECT_EQ(Costs::Duration(4400), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
1052
// For each unary op, the expected compute time is
// ceil(per_element_cost * kTensorSize / gigaops) of the op's device.
TEST_F(OpLevelCostEstimatorTest, UnaryOpExecutionTime) {
  // Op name paired with the per-element multiplier used in the expected
  // compute time below.
  const std::vector<std::pair<std::string, int>> unary_ops = {
      {"All", 1},      {"ArgMax", 1}, {"Cast", 1},  {"Max", 1},
      {"Min", 1},      {"Prod", 1},   {"Relu", 1},  {"Relu6", 1},
      {"Softmax", 43}, {"Sum", 1},    {"TopKV2", 1}};

  const int kTensorSize = 1000;
  const int kExpectedMemoryTime = 800;
  // Bind by const reference: the pairs hold std::string, so a by-value loop
  // variable would copy each entry.
  for (const auto& unary_op : unary_ops) {
    // Tag every failure below with the op name.
    SCOPED_TRACE(unary_op.first);
    OpContext op_context = DescribeUnaryOp(unary_op.first, kTensorSize);

    const int expected_compute_time = static_cast<int>(std::ceil(
        unary_op.second * kTensorSize /
        estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops));

    auto cost = PredictCosts(op_context);
    EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
    EXPECT_EQ(cost.compute_time, Costs::Duration(expected_compute_time));
    EXPECT_EQ(cost.execution_time,
              Costs::Duration(expected_compute_time + kExpectedMemoryTime));
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.temporary_memory, 0);
    EXPECT_EQ(cost.persistent_memory, 0);
  }
}
1081
// For each binary op, the expected compute time is
// ceil(multiplier * kTensorSize1 * kTensorSize2 * 2 / gigaops) of the op's
// device.
TEST_F(OpLevelCostEstimatorTest, BinaryOpExecutionTime) {
  // Op name paired with the multiplier used in the expected compute time below.
  const std::vector<std::pair<std::string, int>> binary_ops = {
      {"Select", 1},
      {"SelectV2", 1},
      {"SquaredDifference", 2},
      {"Where", 1},
  };

  const int kTensorSize1 = 1000;
  const int kTensorSize2 = 2;
  const int kExpectedMemoryTime = 3600;
  // Bind by const reference: the pairs hold std::string, so a by-value loop
  // variable would copy each entry.
  for (const auto& binary_op : binary_ops) {
    // Tag every failure below with the op name.
    SCOPED_TRACE(binary_op.first);
    OpContext op_context =
        DescribeBinaryOp(binary_op.first, kTensorSize1, kTensorSize2);

    const int expected_compute_time = static_cast<int>(std::ceil(
        binary_op.second * kTensorSize1 * kTensorSize2 * 2 /
        estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops));

    auto cost = PredictCosts(op_context);
    EXPECT_EQ(Costs::Duration(kExpectedMemoryTime), cost.memory_time);
    EXPECT_EQ(Costs::Duration(expected_compute_time), cost.compute_time);
    EXPECT_EQ(Costs::Duration(expected_compute_time + kExpectedMemoryTime),
              cost.execution_time);
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
    EXPECT_EQ(cost.temporary_memory, 0);
    EXPECT_EQ(cost.persistent_memory, 0);
  }
}
1116
TEST_F(OpLevelCostEstimatorTest, BroadcastAddExecutionTime) {
  // Add of a 1-D tensor of 100 elements and a 4-D tensor described by
  // (1, 10, 1, 1), i.e. operands of different shapes.
  OpContext op_context;
  auto* info = &op_context.op_info;
  SetCpuDevice(info);
  info->set_op("Add");
  DescribeTensor1D(100, info->add_inputs());
  DescribeTensor4D(1, 10, 1, 1, info->add_inputs());

  const Costs cost = PredictCosts(op_context);
  EXPECT_EQ(Costs::Duration(100), cost.compute_time);
  EXPECT_EQ(Costs::Duration(44), cost.memory_time);
  EXPECT_EQ(Costs::Duration(144), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
1135
TEST_F(OpLevelCostEstimatorTest, UnknownOrPartialShape) {
  // Fully specified shapes give an accurate prediction; a -1 dimension
  // (unknown size) marks the prediction inaccurate with one unknown-shape op.
  auto check = [](const char* label, const Costs& cost, bool shapes_known) {
    SCOPED_TRACE(label);
    EXPECT_EQ(1, cost.num_ops_total);
    if (shapes_known) {
      EXPECT_FALSE(cost.inaccurate);
      EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
    } else {
      EXPECT_TRUE(cost.inaccurate);
      EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
    }
  };

  check("MatMul known", PredictCosts(DescribeMatMul(2, 4, 7, 7)), true);
  check("MatMul first dim unknown", PredictCosts(DescribeMatMul(-1, 4, 7, 7)),
        false);
  check("MatMul third dim unknown", PredictCosts(DescribeMatMul(2, 4, -1, 7)),
        false);
  check("Conv known",
        PredictCosts(DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256)), true);
  check("Conv second dim unknown",
        PredictCosts(DescribeConvolution(16, -1, 19, 48, 48, 5, 5, 256)),
        false);
}
1170
TEST_P(OpLevelBatchMatMulCostEstimatorTest, TestBatchMatMul) {
  // Verifies the op count and unknown-shape bookkeeping of one prediction.
  auto expect_shape_status = [](const Costs& cost, bool shapes_known) {
    EXPECT_EQ(1, cost.num_ops_total);
    if (shapes_known) {
      EXPECT_FALSE(cost.inaccurate);
      EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
    } else {
      EXPECT_TRUE(cost.inaccurate);
      EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
    }
  };

  // Missing input shapes make the prediction inaccurate.
  expect_shape_status(PredictCosts(DescribeBatchMatMul({}, {})), false);
  expect_shape_status(PredictCosts(DescribeBatchMatMul({2, 4}, {})), false);
  // Fully specified shapes, with and without leading batch dimensions.
  expect_shape_status(PredictCosts(DescribeBatchMatMul({2, 4}, {4, 2})), true);
  expect_shape_status(
      PredictCosts(DescribeBatchMatMul({1, 2, 4}, {1, 4, 2})), true);
  expect_shape_status(
      PredictCosts(DescribeBatchMatMul({2, 4}, {1, 3, 4, 2})), true);

  bool matmul_inaccurate = false;
  bool batch_matmul_inaccurate = false;
  {
    // An un-batched BatchMatMul counts as many operations as the matching
    // MatMul, and both report the same accuracy.
    const auto matmul_ops = CountMatMulOperations(
        DescribeMatMul(2, 2, 4, 4).op_info, &matmul_inaccurate);
    const auto batch_ops = CountBatchMatMulOperations(
        DescribeBatchMatMul({2, 4}, {4, 2}).op_info, &batch_matmul_inaccurate);
    EXPECT_EQ(matmul_ops, batch_ops);
    EXPECT_EQ(matmul_inaccurate, batch_matmul_inaccurate);
  }
  {
    // A batch of 10 (rhs has an unknown leading dim) counts 10x the ops, and
    // the two accuracy flags differ.
    const auto matmul_ops = CountMatMulOperations(
        DescribeMatMul(2, 2, 4, 4).op_info, &matmul_inaccurate);
    const auto batch_ops = CountBatchMatMulOperations(
        DescribeBatchMatMul({10, 2, 4}, {-1, 10, 4, 2}).op_info,
        &batch_matmul_inaccurate);
    EXPECT_EQ(10 * matmul_ops, batch_ops);
    EXPECT_NE(matmul_inaccurate, batch_matmul_inaccurate);
  }
  {
    // Batch dims 2 x 10 count 20x the ops.
    const auto matmul_ops = CountMatMulOperations(
        DescribeMatMul(2, 2, 4, 4).op_info, &matmul_inaccurate);
    const auto batch_ops = CountBatchMatMulOperations(
        DescribeBatchMatMul({2, 10, 2, 4}, {-1, 10, 4, 2}).op_info,
        &batch_matmul_inaccurate);
    EXPECT_EQ(20 * matmul_ops, batch_ops);
    EXPECT_NE(matmul_inaccurate, batch_matmul_inaccurate);
  }

  // Test the count to make sure that they extracted the dimensions correctly.
  int dim_product = CountBatchMatMulDimProduct(
      DescribeBatchMatMul({2, 4}, {1, 3, 4, 2}).op_info,
      &batch_matmul_inaccurate);
  EXPECT_EQ(16, dim_product);
  EXPECT_FALSE(batch_matmul_inaccurate);

  // An op that is not a BatchMatMul yields a zero product and is flagged
  // inaccurate.
  OpContext bad_batch = DescribeBatchMatMul({2, 4}, {4, 2});
  bad_batch.op_info.set_op("notBatchMatMul");
  dim_product =
      CountBatchMatMulDimProduct(bad_batch.op_info, &batch_matmul_inaccurate);
  EXPECT_EQ(0, dim_product);
  EXPECT_TRUE(batch_matmul_inaccurate);

  // Exercise the adj_x / adj_y (transposed) path.
  OpContext transpose_batch = DescribeBatchMatMul({2, 4, 3, 1}, {4, 2});
  auto* attrs = transpose_batch.op_info.mutable_attr();
  (*attrs)["adj_x"].set_b(true);
  (*attrs)["adj_y"].set_b(true);
  dim_product = CountBatchMatMulDimProduct(transpose_batch.op_info,
                                           &batch_matmul_inaccurate);
  EXPECT_EQ(12, dim_product);
}
INSTANTIATE_TEST_SUITE_P(TestBatchMatMul, OpLevelBatchMatMulCostEstimatorTest,
                         ::testing::Values("BatchMatMul", "BatchMatMulV2"));
1251
TEST_F(OpLevelCostEstimatorTest, SparseTensorDenseMatMul) {
  // Any -1 (unknown) size makes the prediction inaccurate.
  auto expect_unknown = [](const char* label, const Costs& cost) {
    SCOPED_TRACE(label);
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
  };
  expect_unknown("first arg unknown",
                 PredictCosts(DescribeSparseTensorDenseMatMul(-1, {1, 1},
                                                              {1, 1})));
  expect_unknown("second shape dim 0 unknown",
                 PredictCosts(DescribeSparseTensorDenseMatMul(1, {-1, 1},
                                                              {1, 1})));
  expect_unknown("dim 1 unknown in both shapes",
                 PredictCosts(DescribeSparseTensorDenseMatMul(1, {1, -1},
                                                              {1, -1})));
  expect_unknown("third shape dim 0 unknown",
                 PredictCosts(DescribeSparseTensorDenseMatMul(1, {1, 1},
                                                              {-1, 1})));

  // Fully known shapes.
  auto expect_known = [](const char* label, const Costs& cost) {
    SCOPED_TRACE(label);
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
    EXPECT_EQ(Costs::Duration(200), cost.compute_time);
    EXPECT_EQ(Costs::Duration(2422), cost.memory_time);
  };
  expect_known("k_dim 1000", PredictCosts(DescribeSparseTensorDenseMatMul(
                                 10, {1000, 100}, {50, 100})));
  // Same cost as above because the cost does not depend on k_dim.
  expect_known("k_dim 100000", PredictCosts(DescribeSparseTensorDenseMatMul(
                                   10, {100000, 100}, {50, 100})));
}
1303
ExpectTensorShape(const std::vector<int64> & expected,const TensorShapeProto & tensor_shape_proto)1304 void ExpectTensorShape(const std::vector<int64>& expected,
1305 const TensorShapeProto& tensor_shape_proto) {
1306 TensorShape tensor_shape_expected(expected);
1307 TensorShape tensor_shape(tensor_shape_proto);
1308
1309 EXPECT_EQ(tensor_shape_expected, tensor_shape);
1310 }
1311
// GetTensorShapeProtoFromTensorProto() must reject malformed, non-1-D and
// non-integer inputs, and round-trip valid int32/int64 shape vectors.
TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
  TensorProto tensor_proto;
  TensorShapeProto tensor_shape_proto;

  // Only a [255] shape is set here (no dtype, no values); such a proto must be
  // rejected.
  // NOTE(review): 255 is not an out-of-range size. The rejection presumably
  // comes from the unset dtype/values when converting to a Tensor — confirm.
  tensor_proto.mutable_tensor_shape()->add_dim()->set_size(255);
  EXPECT_FALSE(
      GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));

  // Expect only 1D shape: a [1, 2] shape must be rejected.
  // NOTE(review): dtype is still unset here, so this case may be rejected
  // before any rank check is reached — confirm it exercises the rank check.
  tensor_proto.Clear();
  tensor_proto.mutable_tensor_shape()->add_dim()->set_size(1);
  tensor_proto.mutable_tensor_shape()->add_dim()->set_size(2);
  EXPECT_FALSE(
      GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));

  // Expect only integer data types: a DT_FLOAT proto must be rejected.
  // NOTE(review): the shape passed here is {}, so this may also be rejected
  // by a rank check rather than the dtype check — confirm.
  GetTensorProto(DT_FLOAT, {}, {}, /*tensor_content=*/false, &tensor_proto);
  EXPECT_FALSE(
      GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));

  // Valid 1-D int32/int64 shape vectors of length 4, built with
  // tensor_content both false and true, must round-trip to the same dims.
  auto expect_round_trip = [&](DataType dtype,
                               const std::vector<int64>& shape_expected,
                               bool tensor_content) {
    SCOPED_TRACE(::testing::Message() << "dtype=" << dtype
                                      << " tensor_content=" << tensor_content);
    GetTensorProto(dtype, {4}, shape_expected, tensor_content, &tensor_proto);
    EXPECT_TRUE(
        GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
    ExpectTensorShape(shape_expected, tensor_shape_proto);
  };
  expect_round_trip(DT_INT32, {10, 20, 30, 40}, /*tensor_content=*/false);
  expect_round_trip(DT_INT64, {40, 20, 90, 40}, /*tensor_content=*/false);
  expect_round_trip(DT_INT32, {10, 20, 30, 40}, /*tensor_content=*/true);
  expect_round_trip(DT_INT64, {40, 20, 90, 40}, /*tensor_content=*/true);
}
1371
TEST_F(OpLevelCostEstimatorTest, OpDimensionsFromInputs) {
  // Each configuration below is validated under every padding mode and data
  // format.
  const std::vector<string> kPaddings = {"VALID", "SAME"};
  const std::vector<string> kFormats = {"NHWC", "NCHW"};
  for (const string& padding : kPaddings) {
    for (const string& format : kFormats) {
      // Arguments: n, h, w, c, kx, ky, sx, sy, data_format, padding.
      ValidateOpDimensionsFromInputs(10, 20, 20, 100, 3, 3, 2, 2, format,
                                     padding);
      ValidateOpDimensionsFromInputs(10, 20, 20, 100, 1, 1, 3, 3, format,
                                     padding);
      ValidateOpDimensionsFromInputs(10, 200, 200, 100, 5, 5, 3, 3, format,
                                     padding);
      ValidateOpDimensionsFromInputs(10, 14, 14, 3840, 3, 3, 2, 2, format,
                                     padding);
    }
  }
}
1385
TEST_F(OpLevelCostEstimatorTest, PredictMaxPool) {
  // Predicts MaxPool costs for an NHWC input of shape [n, in, in, c] using a
  // square k x k window and an s x s stride.
  auto predict_max_pool = [this](const int n, const int in, const int c,
                                 const int k, const int s,
                                 const string& padding) -> Costs {
    OpContext op_context = DescribePoolingOp(
        "MaxPool", {n, in, in, c}, {1, k, k, 1}, {1, s, s, 1}, "NHWC", padding);
    return estimator_.PredictCosts(op_context);
  };

  {
    // Typical 3x3 window with 2x2 stride.
    auto costs = predict_max_pool(10, 20, 384, 3, 2, "SAME");
    EXPECT_EQ(Costs::Duration(1075200), costs.execution_time);
    EXPECT_EQ(Costs::Duration(307200), costs.compute_time);
    EXPECT_EQ(Costs::Duration(768000), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
    EXPECT_EQ(0, costs.temporary_memory);
    EXPECT_EQ(0, costs.persistent_memory);
  }
  {
    // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
    auto costs = predict_max_pool(10, 20, 384, 1, 2, "SAME");
    EXPECT_EQ(Costs::Duration(499200), costs.execution_time);
    EXPECT_EQ(Costs::Duration(38400), costs.compute_time);
    EXPECT_EQ(Costs::Duration(460800), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
  {
    // 2x2 window with 3x3 stride.
    auto costs = predict_max_pool(10, 20, 384, 2, 3, "VALID");
    EXPECT_EQ(Costs::Duration(561792), costs.execution_time);
    EXPECT_EQ(Costs::Duration(56448), costs.compute_time);
    EXPECT_EQ(Costs::Duration(505344), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
}
1428
TEST_F(OpLevelCostEstimatorTest, PredictMaxPoolGrad) {
  // Costs of MaxPoolGrad over an NHWC [n, in, in, c] input with a square
  // k x k window and an s x s stride.
  auto grad_costs = [this](const int n, const int in, const int c,
                           const int k, const int s,
                           const string& padding) -> Costs {
    OpContext ctx =
        DescribePoolingOp("MaxPoolGrad", {n, in, in, c}, {1, k, k, 1},
                          {1, s, s, 1}, "NHWC", padding);
    return estimator_.PredictCosts(ctx);
  };

  {
    // Typical 3x3 window with 2x2 stride.
    const Costs costs = grad_costs(10, 20, 384, 3, 2, "SAME");
    EXPECT_EQ(Costs::Duration(614400), costs.compute_time);
    EXPECT_EQ(Costs::Duration(1382400), costs.memory_time);
    EXPECT_EQ(Costs::Duration(1996800), costs.execution_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.temporary_memory);
    EXPECT_EQ(0, costs.persistent_memory);
  }
  {
    // 1x1 window with 2x2 stride (the resnet-50 shortcut shape).
    const Costs costs = grad_costs(10, 20, 384, 1, 2, "SAME");
    EXPECT_EQ(Costs::Duration(153600), costs.compute_time);
    EXPECT_EQ(Costs::Duration(1382400), costs.memory_time);
    EXPECT_EQ(Costs::Duration(1536000), costs.execution_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
    EXPECT_FALSE(costs.inaccurate);
  }
  {
    // 2x2 window with 3x3 stride.
    const Costs costs = grad_costs(10, 20, 384, 2, 3, "VALID");
    EXPECT_EQ(Costs::Duration(210048), costs.compute_time);
    EXPECT_EQ(Costs::Duration(1304064), costs.memory_time);
    EXPECT_EQ(Costs::Duration(1514112), costs.execution_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
    EXPECT_FALSE(costs.inaccurate);
  }
}
1472
TEST_F(OpLevelCostEstimatorTest, PredictAvgPool) {
  // Costs of AvgPool over an NHWC [n, in, in, c] input with a square k x k
  // window and an s x s stride.
  auto avg_pool_costs = [this](const int n, const int in, const int c,
                               const int k, const int s,
                               const string& padding) -> Costs {
    OpContext ctx = DescribePoolingOp(
        "AvgPool", {n, in, in, c}, {1, k, k, 1}, {1, s, s, 1}, "NHWC", padding);
    return estimator_.PredictCosts(ctx);
  };

  {
    // Typical 3x3 window with 2x2 stride.
    const Costs costs = avg_pool_costs(10, 20, 384, 3, 2, "SAME");
    EXPECT_EQ(Costs::Duration(345600), costs.compute_time);
    EXPECT_EQ(Costs::Duration(768000), costs.memory_time);
    EXPECT_EQ(Costs::Duration(1113600), costs.execution_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.temporary_memory);
    EXPECT_EQ(0, costs.persistent_memory);
  }
  {
    // 1x1 window with 2x2 stride (the resnet-50 shortcut shape).
    const Costs costs = avg_pool_costs(10, 20, 384, 1, 2, "SAME");
    EXPECT_EQ(Costs::Duration(38400), costs.compute_time);
    EXPECT_EQ(Costs::Duration(460800), costs.memory_time);
    EXPECT_EQ(Costs::Duration(499200), costs.execution_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
    EXPECT_FALSE(costs.inaccurate);
  }
  {
    // 2x2 window with 3x3 stride.
    const Costs costs = avg_pool_costs(10, 20, 384, 2, 3, "VALID");
    EXPECT_EQ(Costs::Duration(75264), costs.compute_time);
    EXPECT_EQ(Costs::Duration(505344), costs.memory_time);
    EXPECT_EQ(Costs::Duration(580608), costs.execution_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
    EXPECT_FALSE(costs.inaccurate);
  }
}
1515
TEST_F(OpLevelCostEstimatorTest, PredictAvgPoolGrad) {
  // Predicts the cost of an NHWC AvgPoolGrad for an n x in x in x c input with
  // a k x k window, an s x s stride, and the given padding scheme.
  auto predict_avg_pool_grad = [this](const int n, const int in, const int c,
                                      const int k, const int s,
                                      const string& padding) -> Costs {
    OpContext op_context =
        DescribePoolingOp("AvgPoolGrad", {n, in, in, c}, {1, k, k, 1},
                          {1, s, s, 1}, "NHWC", padding);
    return estimator_.PredictCosts(op_context);
  };

  {
    // Typical 3x3 window with 2x2 stride.
    auto costs = predict_avg_pool_grad(10, 20, 384, 3, 2, "SAME");
    EXPECT_EQ(Costs::Duration(1305602), costs.execution_time);
    EXPECT_EQ(Costs::Duration(537600), costs.compute_time);
    EXPECT_EQ(Costs::Duration(768002), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
    EXPECT_EQ(0, costs.temporary_memory);
    EXPECT_EQ(0, costs.persistent_memory);
  }
  {
    // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
    auto costs = predict_avg_pool_grad(10, 20, 384, 1, 2, "SAME");
    EXPECT_EQ(Costs::Duration(960002), costs.execution_time);
    EXPECT_EQ(Costs::Duration(192000), costs.compute_time);
    EXPECT_EQ(Costs::Duration(768002), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
  {
    // 2x2 window with 3x3 stride.
    auto costs = predict_avg_pool_grad(10, 20, 384, 2, 3, "VALID");
    EXPECT_EQ(Costs::Duration(862082), costs.execution_time);
    EXPECT_EQ(Costs::Duration(172416), costs.compute_time);
    EXPECT_EQ(Costs::Duration(689666), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
}
1559
TEST_F(OpLevelCostEstimatorTest, PredictFusedBatchNorm) {
  // Predicts the cost of a forward (is_grad=false) NHWC FusedBatchNorm over an
  // n x in x in x c input, in either training or inference mode.
  auto predict_fused_bn = [this](const int n, const int in, const int c,
                                 const bool is_training) -> Costs {
    return estimator_.PredictCosts(DescribeFusedBatchNorm(
        is_training, /*is_grad=*/false, {n, in, in, c}, "NHWC"));
  };

  // Bookkeeping shared by every case: a single op, an accurate estimate, and
  // no ops with unknown shapes.
  auto expect_one_accurate_op = [](const Costs& costs) {
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  };

  {
    const Costs costs = predict_fused_bn(10, 20, 96, /*is_training=*/true);
    EXPECT_EQ(Costs::Duration(614737), costs.execution_time);
    EXPECT_EQ(Costs::Duration(153706), costs.compute_time);
    EXPECT_EQ(Costs::Duration(461031), costs.memory_time);
    expect_one_accurate_op(costs);
    EXPECT_EQ(0, costs.temporary_memory);
    EXPECT_EQ(0, costs.persistent_memory);
  }

  {
    const Costs costs = predict_fused_bn(10, 20, 32, /*is_training=*/true);
    EXPECT_EQ(Costs::Duration(204913), costs.execution_time);
    EXPECT_EQ(Costs::Duration(51236), costs.compute_time);
    EXPECT_EQ(Costs::Duration(153677), costs.memory_time);
    expect_one_accurate_op(costs);
  }

  {
    const Costs costs = predict_fused_bn(10, 20, 96, /*is_training=*/false);
    EXPECT_EQ(Costs::Duration(384154), costs.execution_time);
    EXPECT_EQ(Costs::Duration(76800), costs.compute_time);
    EXPECT_EQ(Costs::Duration(307354), costs.memory_time);
    expect_one_accurate_op(costs);
  }

  {
    const Costs costs = predict_fused_bn(10, 20, 32, /*is_training=*/false);
    EXPECT_EQ(Costs::Duration(128052), costs.execution_time);
    EXPECT_EQ(Costs::Duration(25600), costs.compute_time);
    EXPECT_EQ(Costs::Duration(102452), costs.memory_time);
    expect_one_accurate_op(costs);
  }
}
1610
TEST_F(OpLevelCostEstimatorTest, PredictFusedBatchNormGrad) {
  // Predicts the cost of an NHWC FusedBatchNorm gradient (is_grad=true,
  // is_training=false) for an n x in x in x c input.
  auto predict_fused_bn_grad = [this](const int n, const int in,
                                      const int c) -> Costs {
    return estimator_.PredictCosts(DescribeFusedBatchNorm(
        /*is_training=*/false, /*is_grad=*/true, {n, in, in, c}, "NHWC"));
  };

  // Bookkeeping shared by every case: a single op, an accurate estimate, and
  // no ops with unknown shapes.
  auto expect_one_accurate_op = [](const Costs& costs) {
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  };

  {
    const Costs costs = predict_fused_bn_grad(10, 20, 96);
    EXPECT_EQ(Costs::Duration(1037050), costs.execution_time);
    EXPECT_EQ(Costs::Duration(422496), costs.compute_time);
    EXPECT_EQ(Costs::Duration(614554), costs.memory_time);
    expect_one_accurate_op(costs);
    EXPECT_EQ(0, costs.temporary_memory);
    EXPECT_EQ(0, costs.persistent_memory);
  }

  {
    const Costs costs = predict_fused_bn_grad(128, 7, 384);
    EXPECT_EQ(Costs::Duration(6503809), costs.execution_time);
    EXPECT_EQ(Costs::Duration(2649677), costs.compute_time);
    EXPECT_EQ(Costs::Duration(3854132), costs.memory_time);
    expect_one_accurate_op(costs);
  }
}
1641
TEST_F(OpLevelCostEstimatorTest, MaybeGetMinimumShape) {
  // Builds a known-rank TensorShapeProto whose dimensions have `sizes`.
  auto known_rank_shape = [](const std::vector<int>& sizes) {
    TensorShapeProto shape;
    shape.set_unknown_rank(false);
    for (const int size : sizes) {
      shape.add_dim()->set_size(size);
    }
    return shape;
  };

  {
    // Unknown rank, rank 4 requested: all-ones shape, flagged as unknown.
    TensorShapeProto x;
    x.set_unknown_rank(true);
    bool unknown_shapes = false;
    const TensorShapeProto y = MaybeGetMinimumShape(x, 4, &unknown_shapes);
    EXPECT_TRUE(unknown_shapes);
    ExpectTensorShape({1, 1, 1, 1}, y);
  }

  {
    // Known rank with no dims, rank 1 requested: {1}, not flagged.
    const TensorShapeProto x = known_rank_shape({});
    bool unknown_shapes = false;
    const TensorShapeProto y = MaybeGetMinimumShape(x, 1, &unknown_shapes);
    EXPECT_FALSE(unknown_shapes);
    ExpectTensorShape({1}, y);
  }

  {
    // Known rank with no dims, rank 2 requested: {1, 1}, not flagged.
    const TensorShapeProto x = known_rank_shape({});
    bool unknown_shapes = false;
    const TensorShapeProto y = MaybeGetMinimumShape(x, 2, &unknown_shapes);
    EXPECT_FALSE(unknown_shapes);
    ExpectTensorShape({1, 1}, y);
  }

  {
    // Rank-2 shape: unchanged for rank 2; padded with 1s and flagged when
    // rank 4 is requested.
    const TensorShapeProto x = known_rank_shape({10, 20});
    bool unknown_shapes = false;
    const TensorShapeProto y = MaybeGetMinimumShape(x, 2, &unknown_shapes);
    EXPECT_FALSE(unknown_shapes);
    ExpectTensorShape({10, 20}, y);

    unknown_shapes = false;
    const TensorShapeProto z = MaybeGetMinimumShape(x, 4, &unknown_shapes);
    EXPECT_TRUE(unknown_shapes);
    EXPECT_EQ(4, z.dim_size());
    ExpectTensorShape({10, 20, 1, 1}, z);
  }

  {
    // A -1 (unknown) dimension becomes 1 and the shape is flagged.
    const TensorShapeProto x = known_rank_shape({10, 20, -1, 20});
    bool unknown_shapes = false;
    const TensorShapeProto y = MaybeGetMinimumShape(x, 4, &unknown_shapes);
    EXPECT_TRUE(unknown_shapes);
    ExpectTensorShape({10, 20, 1, 20}, y);
  }

  {
    // Rank 4 shape with rank 2 requested: truncated to the leading dims and
    // flagged.
    const TensorShapeProto x = known_rank_shape({10, 20, 30, 20});
    bool unknown_shapes = false;
    const TensorShapeProto y = MaybeGetMinimumShape(x, 2, &unknown_shapes);
    EXPECT_TRUE(unknown_shapes);
    ExpectTensorShape({10, 20}, y);
  }
}
1713
TEST_F(OpLevelCostEstimatorTest, IntermediateRdWrBandwidth) {
  TestOpLevelCostEstimator estimator;

  // The same convolution is costed under every device configuration below.
  const auto conv = DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256);

  // With compute/memory overlap disabled, execution time must equal the sum
  // of the compute, memory, and intermediate-memory components.
  auto expect_sum_of_components = [](const Costs& c) {
    EXPECT_EQ(c.execution_time,
              c.compute_time + c.memory_time + c.intermediate_memory_time);
  };

  // Compute limited.
  estimator.SetDeviceInfo(DeviceInfo(/*gigaops=*/1,
                                     /*gb_per_sec=*/1));
  estimator.SetComputeMemoryOverlap(true);
  auto cost = estimator.PredictCosts(conv);
  EXPECT_EQ(Costs::Duration(3548774400), cost.execution_time);
  EXPECT_EQ(cost.execution_time, cost.compute_time);

  estimator.SetComputeMemoryOverlap(false);
  cost = estimator.PredictCosts(conv);
  EXPECT_EQ(Costs::Duration(3551112192), cost.execution_time);
  expect_sum_of_components(cost);

  // Memory limited.
  estimator.SetDeviceInfo(DeviceInfo(/*gigaops=*/99999,
                                     /*gb_per_sec=*/1));
  estimator.SetComputeMemoryOverlap(true);
  cost = estimator.PredictCosts(conv);
  EXPECT_EQ(Costs::Duration(2337792), cost.execution_time);
  EXPECT_EQ(cost.execution_time, cost.memory_time);

  estimator.SetComputeMemoryOverlap(false);
  cost = estimator.PredictCosts(conv);
  EXPECT_EQ(Costs::Duration(2373281), cost.execution_time);
  expect_sum_of_components(cost);

  // Intermediate memory bandwidth limited.
  estimator.SetDeviceInfo(DeviceInfo(/*gigaops=*/99999,
                                     /*gb_per_sec=*/9999,
                                     /*intermediate_read_gb_per_sec=*/1,
                                     /*intermediate_write_gb_per_sec=*/1));
  estimator.SetComputeMemoryOverlap(true);
  cost = estimator.PredictCosts(conv);
  EXPECT_EQ(Costs::Duration(2337792), cost.execution_time);
  EXPECT_EQ(cost.execution_time, cost.intermediate_memory_time);

  estimator.SetComputeMemoryOverlap(false);
  cost = estimator.PredictCosts(conv);
  EXPECT_EQ(Costs::Duration(2373515), cost.execution_time);
  expect_sum_of_components(cost);
}
1767
TEST_F(OpLevelCostEstimatorTest, Einsum) {
  // Each case checks the Einsum estimate, then checks that an XlaEinsum with
  // the same operands and equation gets the same execution time. The Einsum
  // side of that comparison reuses `cost` instead of predicting it a second
  // time, so the two copies of the argument list cannot drift apart.
  {  // Test a simple matrix multiplication.
    auto cost = PredictCosts(DescribeEinsum({100, 50}, {100, 50}, "ik,jk->ij"));
    EXPECT_EQ(Costs::Duration(104000), cost.execution_time);
    EXPECT_EQ(Costs::Duration(100 * 50 * 100 * 2 / (1000 * 10 * 1e-3)),
              cost.compute_time);
    EXPECT_EQ(Costs::Duration(4000), cost.memory_time);
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
    EXPECT_EQ(0, cost.temporary_memory);
    EXPECT_EQ(0, cost.persistent_memory);

    // Einsums and XlaEinsums should be estimated similarly.
    EXPECT_EQ(cost.execution_time,
              PredictCosts(DescribeXlaEinsum({100, 50}, {100, 50}, "ik,jk->ij"))
                  .execution_time);
  }
  {  // Test a simple batch matrix multiplication.
    auto cost = PredictCosts(
        DescribeEinsum({25, 100, 50}, {100, 50, 25}, "Bik,jkB->Bij"));
    EXPECT_EQ(Costs::Duration(25 * 104000), cost.execution_time);
    EXPECT_EQ(Costs::Duration(25 * 100 * 50 * 100 * 2 / (1000 * 10 * 1e-3)),
              cost.compute_time);
    EXPECT_EQ(Costs::Duration(25 * 4000), cost.memory_time);
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);

    // Einsums and XlaEinsums should be estimated similarly.
    EXPECT_EQ(cost.execution_time,
              PredictCosts(DescribeXlaEinsum({25, 100, 50}, {100, 50, 25},
                                             "Bik,jkB->Bij"))
                  .execution_time);
  }
  {  // Test multiple batch dimensions.
    auto cost = PredictCosts(DescribeEinsum(
        {25, 16, 100, 50}, {16, 100, 50, 25}, "BNik,NjkB->BNij"));
    EXPECT_EQ(Costs::Duration(16 * 25 * 104000), cost.execution_time);
    EXPECT_EQ(
        Costs::Duration(16 * 25 * 100 * 50 * 100 * 2 / (1000 * 10 * 1e-3)),
        cost.compute_time);
    EXPECT_EQ(Costs::Duration(16 * 25 * 4000), cost.memory_time);
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);

    // Einsums and XlaEinsums should be estimated similarly.
    EXPECT_EQ(cost.execution_time,
              PredictCosts(DescribeXlaEinsum({25, 16, 100, 50},
                                             {16, 100, 50, 25},
                                             "BNik,NjkB->BNij"))
                  .execution_time);
  }
  {  // Test multiple M dimensions.
    auto cost =
        PredictCosts(DescribeEinsum({25, 100, 50}, {100, 50}, "Aik,jk->Aij"));
    EXPECT_EQ(Costs::Duration(2552000), cost.execution_time);
    EXPECT_EQ(Costs::Duration(25 * 100 * 50 * 100 * 2 / (1000 * 10 * 1e-3)),
              cost.compute_time);
    EXPECT_EQ(Costs::Duration(52000), cost.memory_time);
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);

    // Einsums and XlaEinsums should be estimated similarly.
    EXPECT_EQ(
        cost.execution_time,
        PredictCosts(DescribeXlaEinsum({25, 100, 50}, {100, 50}, "Aik,jk->Aij"))
            .execution_time);
  }
  {  // Test multiple N dimensions.
    auto cost =
        PredictCosts(DescribeEinsum({100, 50}, {25, 100, 50}, "ik,Bjk->ijB"));
    EXPECT_EQ(Costs::Duration(2552000), cost.execution_time);
    EXPECT_EQ(Costs::Duration(25 * 100 * 50 * 100 * 2 / (1000 * 10 * 1e-3)),
              cost.compute_time);
    EXPECT_EQ(Costs::Duration(52000), cost.memory_time);
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);

    // Einsums and XlaEinsums should be estimated similarly.
    EXPECT_EQ(
        cost.execution_time,
        PredictCosts(DescribeXlaEinsum({100, 50}, {25, 100, 50}, "ik,Bjk->ijB"))
            .execution_time);
  }
  {  // Test multiple contracting dimensions.
    auto cost = PredictCosts(
        DescribeEinsum({100, 50, 25}, {100, 50, 25}, "ikl,jkl->ij"));
    EXPECT_EQ(Costs::Duration(2600000), cost.execution_time);
    EXPECT_EQ(Costs::Duration(100 * 50 * 25 * 100 * 2 / (1000 * 10 * 1e-3)),
              cost.compute_time);
    EXPECT_EQ(Costs::Duration(100000), cost.memory_time);
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);

    // Einsums and XlaEinsums should be estimated similarly.
    EXPECT_EQ(cost.execution_time,
              PredictCosts(DescribeXlaEinsum({100, 50, 25}, {100, 50, 25},
                                             "ikl,jkl->ij"))
                  .execution_time);
  }
  {  // Test a simple matrix transpose.
    auto cost = PredictCosts(DescribeEinsum({100, 50}, {}, "ij->ji"));
    EXPECT_EQ(Costs::Duration(2000), cost.execution_time);
    EXPECT_EQ(Costs::Duration(0), cost.compute_time);
    EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);

    // Einsums and XlaEinsums should be estimated similarly.
    EXPECT_EQ(cost.execution_time,
              PredictCosts(DescribeXlaEinsum({100, 50}, {}, "ij->ji"))
                  .execution_time);
  }
  {  // Test a malformed Einsum equation: Mismatch between shapes and equation.
    auto cost =
        PredictCosts(DescribeEinsum({100, 50, 25}, {50, 100}, "ik,kl->il"));
    EXPECT_EQ(Costs::Duration(52000), cost.execution_time);
    EXPECT_EQ(Costs::Duration(0), cost.compute_time);
    EXPECT_EQ(Costs::Duration(52000), cost.memory_time);
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);

    // Einsums and XlaEinsums should be estimated similarly.
    EXPECT_EQ(
        cost.execution_time,
        PredictCosts(DescribeXlaEinsum({100, 50, 25}, {50, 100}, "ik,kl->il"))
            .execution_time);

    cost = PredictCosts(DescribeEinsum({100, 50}, {50, 100, 25}, "ik,kl->il"));
    EXPECT_EQ(Costs::Duration(52000), cost.execution_time);
    EXPECT_EQ(Costs::Duration(0), cost.compute_time);
    EXPECT_EQ(Costs::Duration(52000), cost.memory_time);
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);

    // Einsums and XlaEinsums should be estimated similarly.
    EXPECT_EQ(
        cost.execution_time,
        PredictCosts(DescribeXlaEinsum({100, 50}, {50, 100, 25}, "ik,kl->il"))
            .execution_time);
  }
  {  // Test an unsupported Einsum: ellipsis
    auto cost = PredictCosts(DescribeEinsum(
        {100, 50, 25, 16}, {50, 100, 32, 12}, "ik...,kl...->il..."));
    EXPECT_EQ(Costs::Duration(1568000), cost.execution_time);
    EXPECT_EQ(Costs::Duration(0), cost.compute_time);
    EXPECT_EQ(Costs::Duration(1568000), cost.memory_time);
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);

    // Einsums and XlaEinsums should be estimated similarly.
    EXPECT_EQ(cost.execution_time,
              PredictCosts(DescribeXlaEinsum({100, 50, 25, 16},
                                             {50, 100, 32, 12},
                                             "ik...,kl...->il..."))
                  .execution_time);
  }
  {  // Test a malformed/unsupported Einsum: repeated indices
    auto cost =
        PredictCosts(DescribeEinsum({100, 100, 50}, {50, 100}, "iik,kl->il"));
    EXPECT_EQ(Costs::Duration(202000), cost.execution_time);
    EXPECT_EQ(Costs::Duration(0), cost.compute_time);
    EXPECT_EQ(Costs::Duration(202000), cost.memory_time);
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);

    // Einsums and XlaEinsums should be estimated similarly.
    EXPECT_EQ(
        cost.execution_time,
        PredictCosts(DescribeXlaEinsum({100, 100, 50}, {50, 100}, "iik,kl->il"))
            .execution_time);
  }
  {  // Test missing shapes.
    auto cost = PredictCosts(DescribeEinsum({-1, 50}, {100, 50}, "ik,jk->ij"));
    EXPECT_EQ(Costs::Duration(3020), cost.execution_time);
    EXPECT_EQ(Costs::Duration(1 * 50 * 100 * 2 / (1000 * 10 * 1e-3)),
              cost.compute_time);
    EXPECT_EQ(Costs::Duration(2020), cost.memory_time);
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);

    // Einsums and XlaEinsums should be estimated similarly.
    EXPECT_EQ(cost.execution_time,
              PredictCosts(DescribeXlaEinsum({-1, 50}, {100, 50}, "ik,jk->ij"))
                  .execution_time);
  }
}
1982
TEST_F(OpLevelCostEstimatorTest, PredictResourceVariableOps) {
  TestOpLevelCostEstimator estimator;
  estimator.SetDeviceInfo(DeviceInfo(/*gigaops=*/1, /*gb_per_sec=*/1));

  // Each case describes the op with a dummy tensor as its first input and a
  // 100-element 1-D tensor as its second input.
  {
    OpContext op_context;
    OpInfo& op_info = op_context.op_info;
    op_info.set_op("AssignVariableOp");
    DescribeDummyTensor(op_info.add_inputs());
    DescribeTensor1D(100, op_info.add_inputs());

    const auto cost = estimator.PredictCosts(op_context);
    EXPECT_EQ(Costs::Duration(400), cost.memory_time);
    EXPECT_EQ(Costs::Duration(0), cost.compute_time);
    EXPECT_EQ(Costs::Duration(400), cost.execution_time);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(0, cost.temporary_memory);
    EXPECT_EQ(0, cost.persistent_memory);
  }

  {
    OpContext op_context;
    OpInfo& op_info = op_context.op_info;
    op_info.set_op("AssignSubVariableOp");
    DescribeDummyTensor(op_info.add_inputs());
    DescribeTensor1D(100, op_info.add_inputs());

    // The estimator is configured with overlap enabled, so execution time
    // matches memory time here even though compute time is non-zero.
    const auto cost = estimator.PredictCosts(op_context);
    EXPECT_EQ(Costs::Duration(400), cost.memory_time);
    EXPECT_EQ(Costs::Duration(100), cost.compute_time);
    EXPECT_EQ(Costs::Duration(400), cost.execution_time);
    EXPECT_FALSE(cost.inaccurate);
  }
}
2013
TEST_F(OpLevelCostEstimatorTest, AddNExecutionTime) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("AddN");

  // AddN over three identically shaped 1x10x10x10 inputs.
  constexpr int kNumInputs = 3;
  for (int i = 0; i < kNumInputs; ++i) {
    DescribeTensor4D(1, 10, 10, 10, op_context.op_info.add_inputs());
  }

  const auto cost = PredictCosts(op_context);
  EXPECT_EQ(Costs::Duration(1200), cost.memory_time);
  EXPECT_EQ(Costs::Duration(200), cost.compute_time);
  EXPECT_EQ(Costs::Duration(1400), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
2033
TEST_F(OpLevelCostEstimatorTest, IdentityOpExecutionTime) {
  // Every op in this list is expected (per the assertions below) to cost one
  // unit of compute time and no memory time.
  const std::vector<std::string> identity_ops = {
      "_Recv",         "_Send",        "BitCast",         "Identity",
      "Enter",         "Exit",         "IdentityN",       "Merge",
      "NextIteration", "Placeholder",  "PreventGradient", "RefIdentity",
      "Reshape",       "StopGradient", "Switch"};

  // Loop-invariant expectations, hoisted out of the per-op loop.
  const int kTensorSize = 1000;
  const int kExpectedMemoryTime = 0;
  const int kExpectedComputeTime = 1;

  // Bind by const reference: iterating by value copied each std::string.
  for (const auto& identity_op : identity_ops) {
    // Tag any failure below with the op under test; otherwise a failure inside
    // the loop does not say which op produced it.
    SCOPED_TRACE(identity_op);
    OpContext op_context = DescribeUnaryOp(identity_op, kTensorSize);

    auto cost = PredictCosts(op_context);
    EXPECT_EQ(Costs::Duration(kExpectedMemoryTime), cost.memory_time);
    EXPECT_EQ(Costs::Duration(kExpectedComputeTime), cost.compute_time);
    EXPECT_EQ(Costs::Duration(kExpectedComputeTime + kExpectedMemoryTime),
              cost.execution_time);
    // kTensorSize elements at 4 bytes each (presumably one float32 tensor).
    EXPECT_EQ(cost.max_memory, kTensorSize * 4);
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
    EXPECT_EQ(cost.temporary_memory, 0);
    EXPECT_EQ(cost.persistent_memory, 0);
  }
}
2061
TEST_F(OpLevelCostEstimatorTest, PureMemoryOpExecutionTime) {
  // Every op in this list is expected (per the assertions below) to cost only
  // memory time, with no compute time.
  const std::vector<std::string> reshape_ops = {
      "ConcatV2",     "DataFormatVecPermute",
      "DepthToSpace", "ExpandDims",
      "Fill",         "OneHot",
      "Pack",         "Range",
      "SpaceToDepth", "Split",
      "Squeeze",      "Transpose",
      "Tile",         "Unpack"};

  // Loop-invariant expectations, hoisted out of the per-op loop.
  const int kTensorSize = 1000;
  const int kExpectedMemoryTime = 800;
  const int kExpectedComputeTime = 0;

  // Bind by const reference: iterating by value copied each std::string.
  for (const auto& reshape_op : reshape_ops) {
    // Tag any failure below with the op under test; otherwise a failure inside
    // the loop does not say which op produced it.
    SCOPED_TRACE(reshape_op);
    OpContext op_context = DescribeUnaryOp(reshape_op, kTensorSize);

    auto cost = PredictCosts(op_context);
    EXPECT_EQ(Costs::Duration(kExpectedMemoryTime), cost.memory_time);
    EXPECT_EQ(Costs::Duration(kExpectedComputeTime), cost.compute_time);
    EXPECT_EQ(Costs::Duration(kExpectedComputeTime + kExpectedMemoryTime),
              cost.execution_time);
    // kTensorSize elements at 4 bytes each (presumably one float32 tensor).
    EXPECT_EQ(cost.max_memory, kTensorSize * 4);
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
    EXPECT_EQ(cost.temporary_memory, 0);
    EXPECT_EQ(cost.persistent_memory, 0);
  }
}
2092
TEST_F(OpLevelCostEstimatorTest, ResizeBilinearExecutionTime) {
  const int kImageDim = 255;
  const int kChannelSize = 10;
  const int kComputeLerpCost = 9;

  // Sets the boolean "half_pixel_centers" attribute of `context` to `value`.
  auto set_half_pixel_centers = [](OpContext* context, const bool value) {
    AttrValue half_pixel_centers;
    half_pixel_centers.set_b(value);
    (*context->op_info.mutable_attr())["half_pixel_centers"] =
        half_pixel_centers;
  };

  {
    OpContext op_context;
    SetCpuDevice(&op_context.op_info);
    op_context.op_info.set_op("ResizeBilinear");
    DescribeTensor4D(1, kImageDim, kImageDim, kChannelSize,
                     op_context.op_info.add_inputs());
    // An input but no output: expected to cost nothing.
    auto cost = PredictCosts(op_context);
    ExpectZeroCost(cost);

    // An output but no input: also expected to cost nothing.
    op_context.op_info.clear_inputs();
    DescribeTensor4D(0, 0, 0, 0, op_context.op_info.add_outputs());
    cost = PredictCosts(op_context);
    ExpectZeroCost(cost);
  }
  {
    // Size 0 output.
    OpContext op_context;
    SetCpuDevice(&op_context.op_info);
    op_context.op_info.set_op("ResizeBilinear");
    DescribeTensor4D(1, kImageDim, kImageDim, kChannelSize,
                     op_context.op_info.add_inputs());
    DescribeTensor4D(0, 0, 0, 0, op_context.op_info.add_outputs());
    const int kExpectedMemoryTime = kImageDim * kImageDim * 4;

    // Without the half_pixel_centers attr the estimate is flagged inaccurate
    // and reports zero compute time.
    auto cost = PredictCosts(op_context);
    EXPECT_EQ(cost.compute_time, Costs::Duration(0));
    EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
    EXPECT_EQ(cost.execution_time, Costs::Duration(kExpectedMemoryTime));
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
    EXPECT_EQ(cost.temporary_memory, 0);
    EXPECT_EQ(cost.persistent_memory, 0);

    set_half_pixel_centers(&op_context, false);
    cost = PredictCosts(op_context);
    // Compute time depends only on output size, so compute time is 0.
    EXPECT_EQ(cost.compute_time, Costs::Duration(0));
    EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
    EXPECT_EQ(cost.execution_time, Costs::Duration(kExpectedMemoryTime));
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  }

  // Shared context for the non-zero output size cases below.
  const int kOutputImageDim = 100;
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("ResizeBilinear");
  DescribeTensor4D(1, kImageDim, kImageDim, kChannelSize,
                   op_context.op_info.add_inputs());
  DescribeTensor4D(1, kOutputImageDim, kOutputImageDim, kChannelSize,
                   op_context.op_info.add_outputs());
  const int kExpectedMemoryTime =
      (kImageDim * kImageDim + kOutputImageDim * kOutputImageDim) * 4;

  {
    // Cost of calculating weights without using half_pixel_centers.
    set_half_pixel_centers(&op_context, false);
    const int kInterpWeightCost = 10;
    const int num_ops =
        kInterpWeightCost * (kOutputImageDim * 2) +
        kComputeLerpCost * (kOutputImageDim * kOutputImageDim * kChannelSize);
    const int expected_compute_time = std::ceil(
        num_ops /
        estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops);

    const auto cost = PredictCosts(op_context);
    EXPECT_EQ(cost.compute_time, Costs::Duration(expected_compute_time));
    EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
    EXPECT_EQ(cost.execution_time,
              Costs::Duration(kExpectedMemoryTime + expected_compute_time));
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  }

  {
    // Cost of calculating weights using half_pixel_centers.
    set_half_pixel_centers(&op_context, true);
    const int kInterpWeightCost = 12;
    const int num_ops =
        kInterpWeightCost * (kOutputImageDim * 2) +
        kComputeLerpCost * (kOutputImageDim * kOutputImageDim * kChannelSize);
    const int expected_compute_time = std::ceil(
        num_ops /
        estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops);

    const auto cost = PredictCosts(op_context);
    EXPECT_EQ(cost.compute_time, Costs::Duration(expected_compute_time));
    EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
    EXPECT_EQ(cost.execution_time,
              Costs::Duration(kExpectedMemoryTime + expected_compute_time));
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  }

  {
    // Cost with a very large output tensor: its element count exceeds 2^32,
    // so the arithmetic below is done in int64.
    op_context.op_info.clear_outputs();
    constexpr int64 kLargeOutputImageDim = 40000;
    DescribeTensor4D(1, kLargeOutputImageDim, kLargeOutputImageDim,
                     kChannelSize, op_context.op_info.add_outputs());
    const int64 kInterpWeightCost = 12;
    // Using half_pixel_centers.
    set_half_pixel_centers(&op_context, true);

    const int64 num_ops =
        kInterpWeightCost * (kLargeOutputImageDim * 2) +
        kComputeLerpCost *
            (kLargeOutputImageDim * kLargeOutputImageDim * kChannelSize);
    const int64 expected_compute_time = std::ceil(
        num_ops /
        estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops);

    const int64 expected_memory_time =
        (kImageDim * kImageDim + kLargeOutputImageDim * kLargeOutputImageDim) *
        4;

    const auto cost = PredictCosts(op_context);
    EXPECT_EQ(cost.compute_time, Costs::Duration(expected_compute_time));
    EXPECT_EQ(cost.memory_time, Costs::Duration(expected_memory_time));
    EXPECT_EQ(cost.execution_time,
              Costs::Duration(expected_memory_time + expected_compute_time));
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  }
}
2241
TEST_F(OpLevelCostEstimatorTest, CropAndResizeExecutionTime) {
  const int kImageDim = 255;
  const int kChannelSize = 10;
  const int kOutputImageDim = 100;
  const int kNumBoxes = 10;
  const int kOutputElements =
      kNumBoxes * kOutputImageDim * kOutputImageDim * kChannelSize;

  // Sets the string "method" attribute of `context` to `name`.
  auto set_method = [](OpContext* context, const string& name) {
    AttrValue method;
    method.set_s(name);
    (*context->op_info.mutable_attr())["method"] = method;
  };

  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("CropAndResize");
  DescribeTensor4D(1, kImageDim, kImageDim, kChannelSize,
                   op_context.op_info.add_inputs());
  DescribeArbitraryRankInput({kNumBoxes, 4}, DT_INT64, &op_context.op_info);
  DescribeTensor4D(kNumBoxes, kOutputImageDim, kOutputImageDim, kChannelSize,
                   op_context.op_info.add_outputs());

  // This is a time [ns, the default unit of Duration in Costs], not a byte
  // count; the memory bandwidth from SetCpuDevice() is 10GB/s. The input and
  // output images both have kChannelSize (= 10) channels, which cancels the
  // division by the bandwidth for those two terms.
  const int kExpectedMemoryTime =
      (kImageDim * kImageDim * 4 +  // input image in float.
       kNumBoxes * 4 * 8 / 10 +     // boxes (kNumBoxes x 4) in int64.
       kNumBoxes * kOutputImageDim * kOutputImageDim * 4);  // output in float.

  {
    // Cost of CropAndResize with bilinear interpolation.
    set_method(&op_context, "bilinear");
    const int num_ops = 28 * kNumBoxes + 4 * kNumBoxes * kOutputImageDim +
                        4 * kNumBoxes * kOutputImageDim * kOutputImageDim +
                        3 * kNumBoxes * kOutputImageDim +
                        3 * kNumBoxes * kOutputImageDim * kOutputImageDim +
                        13 * kOutputElements;
    const int expected_compute_time = std::ceil(
        num_ops /
        estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops);

    const auto cost = PredictCosts(op_context);
    EXPECT_EQ(cost.compute_time, Costs::Duration(expected_compute_time));
    EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
    EXPECT_EQ(cost.execution_time,
              Costs::Duration(kExpectedMemoryTime + expected_compute_time));
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  }

  {
    // Cost of CropAndResize when nearest pixel is taken.
    set_method(&op_context, "nearest");
    const int num_ops = 28 * kNumBoxes + 4 * kNumBoxes * kOutputImageDim +
                        4 * kNumBoxes * kOutputImageDim * kOutputImageDim +
                        2 * kNumBoxes * kOutputImageDim * kOutputImageDim +
                        kOutputElements;
    const int expected_compute_time = std::ceil(
        num_ops /
        estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops);

    const auto cost = PredictCosts(op_context);
    EXPECT_EQ(cost.compute_time, Costs::Duration(expected_compute_time));
    EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
    EXPECT_EQ(cost.execution_time,
              Costs::Duration(kExpectedMemoryTime + expected_compute_time));
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  }
}
2312
2313 } // end namespace grappler
2314 } // end namespace tensorflow
2315