1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <string>
17 #include <vector>
18 
19 #include "absl/algorithm/container.h"
20 #include "tensorflow/cc/ops/const_op.h"
21 #include "tensorflow/cc/ops/image_ops.h"
22 #include "tensorflow/cc/ops/nn_ops.h"
23 #include "tensorflow/cc/ops/standard_ops.h"
24 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
25 #include "tensorflow/core/framework/fake_input.h"
26 #include "tensorflow/core/framework/node_def_builder.h"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/framework/types.pb.h"
29 #include "tensorflow/core/kernels/conv_ops_gpu.h"
30 #include "tensorflow/core/kernels/ops_testutil.h"
31 #include "tensorflow/core/kernels/ops_util.h"
32 #include "tensorflow/core/platform/test.h"
33 #include "tensorflow/core/platform/test_benchmark.h"
34 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
35 #include "tensorflow/core/public/session.h"
36 
37 namespace tensorflow {
38 
39 #if GOOGLE_CUDA
40 
// Test peer for ConvParameters: forwards to
// ConvParameters::ShouldIncludeWinogradNonfusedAlgoPreCudnn7<T>(), presumably
// to reach a member that is not public — confirm against conv_ops_gpu.h.
struct ConvParametersPeer {
  template <typename T>
  bool ShouldIncludeWinogradNonfusedAlgoPreCudnn7() {
    return params.ShouldIncludeWinogradNonfusedAlgoPreCudnn7<T>();
  }

  ConvParameters params;
};
49 
// Checks the pre-cuDNN-7 Winograd non-fused algorithm gating: a small
// convolution is allowed to consider the algorithm, while a larger one
// (same spatial size, more channels) is excluded.
TEST(ConvParameters, WinogradNonfusedAlgoSize) {
  ConvParametersPeer conv_params_small = {{
      1,            // batch
      32,           // in_depths
      {{300,        // in_rows
        300}},      // in_cols
      FORMAT_NCHW,  // compute_data_format
      128,          // out_depths
      {{3,          // filter_rows
        3}},        // filter_cols
      {{1,          // dilation_rows
        1}},        // dilation_cols
      {{1,          // stride_rows
        1}},        // stride_cols
      {{0,          // padding_rows
        0}},        // padding_cols
      DT_FLOAT,     // tensor datatype
      0,            // device_id
  }};
  EXPECT_TRUE(
      conv_params_small.ShouldIncludeWinogradNonfusedAlgoPreCudnn7<float>());

  // Identical shape except in_depths=128 and out_depths=768; this one must
  // be rejected by the heuristic.
  ConvParametersPeer conv_params_large = {{
      1,            // batch
      128,          // in_depths
      {{300,        // in_rows
        300}},      // in_cols
      FORMAT_NCHW,  // compute_data_format
      768,          // out_depths
      {{3,          // filter_rows
        3}},        // filter_cols
      {{1,          // dilation_rows
        1}},        // dilation_cols
      {{1,          // stride_rows
        1}},        // stride_cols
      {{0,          // padding_rows
        0}},        // padding_cols
      DT_FLOAT,     // tensor datatype
      0,            // device_id
  }};
  EXPECT_FALSE(
      conv_params_large.ShouldIncludeWinogradNonfusedAlgoPreCudnn7<float>());
}
93 
94 #endif  // GOOGLE_CUDA
95 
// Fixture for FusedResizeAndPadConv2D / FusedPadConv2D: one
// handwritten-expected-value check, plus comparative helpers that verify the
// fused kernels produce the same results as the equivalent unfused op chains.
class FusedResizePadConvOpTest : public OpsTestBase {
 protected:
  // Runs FusedResizeAndPadConv2D with an identity resize and zero padding on
  // a fixed 3x4 image with a 3x3 filter, and checks the output against
  // hand-computed values. `dtype` must correspond to the element type T.
  template <typename T>
  void HandwrittenConv(DataType dtype) {
    const int stride = 1;
    TF_EXPECT_OK(NodeDefBuilder("fused_resize_op", "FusedResizeAndPadConv2D")
                     .Input(FakeInput(dtype))
                     .Input(FakeInput(DT_INT32))
                     .Input(FakeInput(DT_INT32))
                     .Input(FakeInput(dtype))
                     .Attr("T", dtype)
                     .Attr("resize_align_corners", false)
                     .Attr("mode", "REFLECT")
                     .Attr("strides", {1, stride, stride, 1})
                     .Attr("padding", "SAME")
                     .Finalize(node_def()));
    TF_EXPECT_OK(InitOp());
    const int depth = 1;
    const int image_width = 4;
    const int image_height = 3;
    const int image_batch_count = 1;
    // The image matrix is:
    // |  1 |  2 |  3 |  4 |
    // |  5 |  6 |  7 |  8 |
    // |  9 | 10 | 11 | 12 |
    Tensor image(dtype, {image_batch_count, image_height, image_width, depth});
    test::FillValues<T>(&image, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});

    // The filter matrix is:
    // | 1 | 4 | 7 |
    // | 2 | 5 | 8 |
    // | 3 | 6 | 9 |
    const int filter_size = 3;
    const int filter_count = 1;
    Tensor filter(dtype, {filter_size, filter_size, depth, filter_count});
    test::FillValues<T>(&filter, {1, 4, 7, 2, 5, 8, 3, 6, 9});

    // Resize to the input's own size and pad by zero on every edge, so the
    // fused op reduces to a plain 'SAME' convolution.
    const int resized_width = image_width;
    const int resized_height = image_height;

    const int top_padding = 0;
    const int bottom_padding = 0;
    const int left_padding = 0;
    const int right_padding = 0;

    AddInputFromArray<T>(image.shape(), image.flat<T>());
    AddInputFromArray<int32>(TensorShape({2}), {resized_height, resized_width});
    AddInputFromArray<int32>(
        TensorShape({4, 2}),
        {0, 0, top_padding, bottom_padding, left_padding, right_padding, 0, 0});
    AddInputFromArray<T>(filter.shape(), filter.flat<T>());
    TF_ASSERT_OK(RunOpKernel());

    // We're sliding the 3x3 filter across the 3x4 image, with accesses outside
    // the input set to zero because we're using the 'SAME' padding mode.
    // The calculations behind the expected output are:
    // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)=105
    // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)=150
    // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)=183
    // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)=95
    // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)=235
    // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312
    // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357
    // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)=178
    // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)=187
    // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)=234
    // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)=261
    // (1*7)+(4*8)+(7*0)+(2*11)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)=121
    // This means we should end up with this matrix:
    // |  105  |  150  |  183  |   95  |
    // |  235  |  312  |  357  |  178  |
    // |  187  |  234  |  261  |  121  |
    const int expected_width = image_width;
    const int expected_height = image_height * filter_count;
    Tensor expected(dtype, TensorShape({image_batch_count, expected_height,
                                        expected_width, filter_count}));
    test::FillValues<T>(
        &expected, {105, 150, 183, 95, 235, 312, 357, 178, 187, 234, 261, 121});
    const Tensor& output = *GetOutput(0);
    test::ExpectTensorNear<T>(expected, output, 1e-5);
  }

  // Builds one graph containing both the unfused op chain
  // (ResizeBilinear -> MirrorPad -> Conv2D) and the single
  // FusedResizeAndPadConv2D op over the same iota-filled inputs, runs both
  // fetches in the same session, and checks the outputs are close.
  template <typename T>
  void CompareFusedAndSeparate(int input_width, int input_height,
                               int input_depth, int resize_width,
                               int resize_height, int y_padding, int x_padding,
                               int filter_size, int filter_count,
                               bool resize_align_corners,
                               const string& pad_mode, int stride,
                               const string& padding, DataType dtype) {
    Scope root = tensorflow::Scope::NewRootScope();
    using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)

    Tensor input_data(DT_FLOAT,
                      TensorShape({1, input_height, input_width, input_depth}));
    test::FillIota<float>(&input_data, 1.0f);
    Output input =
        Const(root.WithOpName("input"), Input::Initializer(input_data));
    Output casted_input = Cast(root.WithOpName("casted_input"), input, dtype);

    Tensor filter_data(DT_FLOAT, TensorShape({filter_size, filter_size,
                                              input_depth, filter_count}));
    test::FillIota<float>(&filter_data, 1.0f);
    Output filter =
        Const(root.WithOpName("filter"), Input::Initializer(filter_data));
    Output casted_filter =
        Cast(root.WithOpName("casted_filter"), filter, dtype);

    Output resize_size =
        Const(root.WithOpName("resize_size"), {resize_height, resize_width});
    Output resize =
        ResizeBilinear(root.WithOpName("resize"), input, resize_size,
                       ResizeBilinear::AlignCorners(resize_align_corners));
    // ResizeBilinear only outputs float; cast it to dtype to match the input.
    Output casted_resize = Cast(root.WithOpName("cast"), resize, dtype);
    Output paddings =
        Const(root.WithOpName("paddings"),
              {{0, 0}, {y_padding, y_padding}, {x_padding, x_padding}, {0, 0}});
    Output mirror_pad = MirrorPad(root.WithOpName("mirror_pad"), casted_resize,
                                  paddings, pad_mode);
    Output conv = Conv2D(root.WithOpName("conv"), mirror_pad, casted_filter,
                         {1, stride, stride, 1}, padding);

    Output fused_conv = FusedResizeAndPadConv2D(
        root.WithOpName("fused_conv"), casted_input, resize_size, paddings,
        casted_filter, pad_mode, {1, stride, stride, 1}, padding,
        FusedResizeAndPadConv2D::ResizeAlignCorners(resize_align_corners));

    tensorflow::GraphDef graph;
    TF_ASSERT_OK(root.ToGraphDef(&graph));

    std::unique_ptr<tensorflow::Session> session(
        tensorflow::NewSession(tensorflow::SessionOptions()));
    TF_ASSERT_OK(session->Create(graph));

    std::vector<Tensor> unfused_tensors;
    TF_ASSERT_OK(session->Run({}, {"conv"}, {}, &unfused_tensors));

    std::vector<Tensor> fused_tensors;
    TF_ASSERT_OK(session->Run({}, {"fused_conv"}, {}, &fused_tensors));

    test::ExpectClose(unfused_tensors[0], fused_tensors[0]);
  }

  // Same idea as CompareFusedAndSeparate but without a resize stage:
  // compares MirrorPad -> Conv2D against the single FusedPadConv2D op.
  template <typename T>
  void CompareFusedPadOnlyAndSeparate(int input_width, int input_height,
                                      int input_depth, int y_padding,
                                      int x_padding, int filter_size,
                                      int filter_count, const string& pad_mode,
                                      int stride, const string& padding,
                                      DataType dtype) {
    Scope root = tensorflow::Scope::NewRootScope();
    using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)

    Tensor input_data(DT_FLOAT,
                      TensorShape({1, input_height, input_width, input_depth}));
    test::FillIota<float>(&input_data, 1.0f);
    Output input =
        Const(root.WithOpName("input"), Input::Initializer(input_data));
    Output casted_input = Cast(root.WithOpName("casted_input"), input, dtype);

    Tensor filter_data(DT_FLOAT, TensorShape({filter_size, filter_size,
                                              input_depth, filter_count}));
    test::FillIota<float>(&filter_data, 1.0f);
    Output filter =
        Const(root.WithOpName("filter"), Input::Initializer(filter_data));
    Output casted_filter =
        Cast(root.WithOpName("casted_filter"), filter, dtype);

    Output paddings =
        Const(root.WithOpName("paddings"),
              {{0, 0}, {y_padding, y_padding}, {x_padding, x_padding}, {0, 0}});
    Output mirror_pad = MirrorPad(root.WithOpName("mirror_pad"), casted_input,
                                  paddings, pad_mode);
    Output conv = Conv2D(root.WithOpName("conv"), mirror_pad, casted_filter,
                         {1, stride, stride, 1}, padding);

    Output fused_conv = FusedPadConv2D(
        root.WithOpName("fused_conv"), casted_input, paddings, casted_filter,
        pad_mode, {1, stride, stride, 1}, padding);

    tensorflow::GraphDef graph;
    TF_ASSERT_OK(root.ToGraphDef(&graph));

    std::unique_ptr<tensorflow::Session> session(
        tensorflow::NewSession(tensorflow::SessionOptions()));
    TF_ASSERT_OK(session->Create(graph));

    std::vector<Tensor> unfused_tensors;
    TF_ASSERT_OK(session->Run({}, {"conv"}, {}, &unfused_tensors));

    std::vector<Tensor> fused_tensors;
    TF_ASSERT_OK(session->Run({}, {"fused_conv"}, {}, &fused_tensors));

    test::ExpectClose(unfused_tensors[0], fused_tensors[0]);
  }
};
293 
// Handwritten-expected-value checks across the supported element types.
TEST_F(FusedResizePadConvOpTest, HandwrittenConvHalf) {
  HandwrittenConv<Eigen::half>(DT_HALF);
}

TEST_F(FusedResizePadConvOpTest, HandwrittenConvFloat) {
  HandwrittenConv<float>(DT_FLOAT);
}

TEST_F(FusedResizePadConvOpTest, HandwrittenConvDouble) {
  HandwrittenConv<double>(DT_DOUBLE);
}

// Identity cases (same-size resize, zero padding, 1x1 filter): fused and
// unfused graphs must agree for each element type.
TEST_F(FusedResizePadConvOpTest, IdentityComparativeHalf) {
  CompareFusedAndSeparate<Eigen::half>(10, 10, 1, 10, 10, 0, 0, 1, 1, false,
                                       "REFLECT", 1, "SAME", DT_HALF);
}

TEST_F(FusedResizePadConvOpTest, IdentityComparativeFloat) {
  CompareFusedAndSeparate<float>(10, 10, 1, 10, 10, 0, 0, 1, 1, false,
                                 "REFLECT", 1, "SAME", DT_FLOAT);
}

TEST_F(FusedResizePadConvOpTest, IdentityComparativeDouble) {
  CompareFusedAndSeparate<double>(10, 10, 1, 10, 10, 0, 0, 1, 1, false,
                                  "REFLECT", 1, "SAME", DT_DOUBLE);
}
320 
// Comparative cases exercising each stage of the fused op in turn:
// convolution only, resize only, resize+conv (with/without align_corners,
// strides, VALID padding), padding only, and resize+padding, in both
// REFLECT and SYMMETRIC mirror-pad modes.
TEST_F(FusedResizePadConvOpTest, ConvOnlyComparative) {
  CompareFusedAndSeparate<float>(10, 10, 3, 10, 10, 0, 0, 4, 4, false,
                                 "REFLECT", 1, "SAME", DT_FLOAT);
}

TEST_F(FusedResizePadConvOpTest, ResizeOnlyComparative) {
  CompareFusedAndSeparate<float>(10, 10, 1, 20, 20, 0, 0, 1, 1, false,
                                 "REFLECT", 1, "SAME", DT_FLOAT);
}

TEST_F(FusedResizePadConvOpTest, ResizeAndConvComparative) {
  CompareFusedAndSeparate<float>(2, 2, 4, 4, 2, 0, 0, 2, 2, false, "REFLECT", 1,
                                 "SAME", DT_FLOAT);
}

TEST_F(FusedResizePadConvOpTest, ResizeAlignAndConvComparative) {
  CompareFusedAndSeparate<float>(2, 2, 4, 4, 2, 0, 0, 2, 2, true, "REFLECT", 1,
                                 "SAME", DT_FLOAT);
}

TEST_F(FusedResizePadConvOpTest, ResizeAndConvStridedComparative) {
  CompareFusedAndSeparate<float>(2, 2, 4, 4, 2, 0, 0, 2, 2, false, "REFLECT", 2,
                                 "SAME", DT_FLOAT);
}

TEST_F(FusedResizePadConvOpTest, ResizeAlignAndConvValidComparative) {
  CompareFusedAndSeparate<float>(2, 2, 4, 4, 2, 0, 0, 2, 2, true, "REFLECT", 1,
                                 "VALID", DT_FLOAT);
}

TEST_F(FusedResizePadConvOpTest, PadOnlyComparative) {
  CompareFusedAndSeparate<float>(4, 4, 1, 4, 4, 2, 2, 1, 1, false, "REFLECT", 1,
                                 "SAME", DT_FLOAT);
}

TEST_F(FusedResizePadConvOpTest, PadOnlyWithChannelsComparative) {
  CompareFusedAndSeparate<float>(4, 4, 3, 4, 4, 2, 2, 1, 1, false, "REFLECT", 1,
                                 "SAME", DT_FLOAT);
}

TEST_F(FusedResizePadConvOpTest, ResizeAndPadComparative) {
  CompareFusedAndSeparate<float>(4, 4, 1, 6, 6, 2, 2, 1, 1, false, "REFLECT", 1,
                                 "SAME", DT_FLOAT);
}

TEST_F(FusedResizePadConvOpTest, PadOnlySymmetricComparative) {
  CompareFusedAndSeparate<float>(4, 4, 1, 4, 4, 2, 2, 1, 1, false, "SYMMETRIC",
                                 1, "SAME", DT_FLOAT);
}

TEST_F(FusedResizePadConvOpTest, ResizeAndPadSymmetricComparative) {
  CompareFusedAndSeparate<float>(4, 4, 3, 6, 6, 2, 2, 1, 1, false, "SYMMETRIC",
                                 1, "SAME", DT_FLOAT);
}

TEST_F(FusedResizePadConvOpTest, ResizeAndPadSymmetricComparativeLarge) {
  CompareFusedAndSeparate<float>(1000, 1000, 3, 1006, 1006, 2, 2, 1, 1, false,
                                 "SYMMETRIC", 1, "SAME", DT_FLOAT);
}
380 
// Pad-only comparatives: FusedPadConv2D (no resize stage) against the
// separate MirrorPad + Conv2D ops.
TEST_F(FusedResizePadConvOpTest, NoResizeIdentityComparativeHalf) {
  CompareFusedPadOnlyAndSeparate<Eigen::half>(10, 10, 1, 0, 0, 1, 1, "REFLECT",
                                              1, "SAME", DT_HALF);
}

TEST_F(FusedResizePadConvOpTest, NoResizeIdentityComparativeFloat) {
  CompareFusedPadOnlyAndSeparate<float>(10, 10, 1, 0, 0, 1, 1, "REFLECT", 1,
                                        "SAME", DT_FLOAT);
}

TEST_F(FusedResizePadConvOpTest, NoResizeIdentityComparativeDouble) {
  CompareFusedPadOnlyAndSeparate<double>(10, 10, 1, 0, 0, 1, 1, "REFLECT", 1,
                                         "SAME", DT_DOUBLE);
}

TEST_F(FusedResizePadConvOpTest, NoResizeConvOnlyComparative) {
  CompareFusedPadOnlyAndSeparate<float>(10, 10, 3, 0, 0, 4, 4, "REFLECT", 1,
                                        "SAME", DT_FLOAT);
}

TEST_F(FusedResizePadConvOpTest, NoResizePadOnlyComparative) {
  CompareFusedPadOnlyAndSeparate<float>(4, 4, 1, 2, 2, 1, 1, "REFLECT", 1,
                                        "SAME", DT_FLOAT);
}

TEST_F(FusedResizePadConvOpTest, NoResizePadOnlyWithChannelsComparative) {
  CompareFusedPadOnlyAndSeparate<float>(4, 4, 3, 2, 2, 1, 1, "REFLECT", 1,
                                        "SAME", DT_FLOAT);
}

TEST_F(FusedResizePadConvOpTest, NoResizePadOnlySymmetricComparative) {
  CompareFusedPadOnlyAndSeparate<float>(4, 4, 1, 2, 2, 1, 1, "SYMMETRIC", 1,
                                        "SAME", DT_FLOAT);
}
415 
// Direct tests of the plain Conv2D kernel against handwritten expected
// values.
class ConvOpTest : public OpsTestBase {
 protected:
  // Runs Conv2D with a 3x3 filter over a fixed 3x4 image in 'SAME' mode and
  // checks the output against hand-computed values.
  void HandwrittenConv() {
    const int stride = 1;
    TF_EXPECT_OK(NodeDefBuilder("conv_op", "Conv2D")
                     .Input(FakeInput(DT_FLOAT))
                     .Input(FakeInput(DT_FLOAT))
                     .Attr("T", DT_FLOAT)
                     .Attr("strides", {1, stride, stride, 1})
                     .Attr("padding", "SAME")
                     .Finalize(node_def()));
    TF_EXPECT_OK(InitOp());
    const int depth = 1;
    const int image_width = 4;
    const int image_height = 3;
    const int image_batch_count = 1;
    // The image matrix is:
    // |  1 |  2 |  3 |  4 |
    // |  5 |  6 |  7 |  8 |
    // |  9 | 10 | 11 | 12 |
    Tensor image(DT_FLOAT,
                 {image_batch_count, image_height, image_width, depth});
    test::FillValues<float>(&image, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});

    // The filter matrix is:
    // | 1 | 4 | 7 |
    // | 2 | 5 | 8 |
    // | 3 | 6 | 9 |
    const int filter_size = 3;
    const int filter_count = 1;
    Tensor filter(DT_FLOAT, {filter_size, filter_size, depth, filter_count});
    test::FillValues<float>(&filter, {1, 4, 7, 2, 5, 8, 3, 6, 9});

    AddInputFromArray<float>(image.shape(), image.flat<float>());
    AddInputFromArray<float>(filter.shape(), filter.flat<float>());
    TF_ASSERT_OK(RunOpKernel());

    // We're sliding the 3x3 filter across the 3x4 image, with accesses outside
    // the input set to zero because we're using the 'SAME' padding mode.
    // The calculations behind the expected output are:
    // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)=105
    // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)=150
    // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)=183
    // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)=95
    // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)=235
    // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312
    // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357
    // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)=178
    // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)=187
    // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)=234
    // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)=261
    // (1*7)+(4*8)+(7*0)+(2*11)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)=121
    // This means we should end up with this matrix:
    // |  105  |  150  |  183  |   95  |
    // |  235  |  312  |  357  |  178  |
    // |  187  |  234  |  261  |  121  |
    const int expected_width = image_width;
    const int expected_height = image_height * filter_count;
    Tensor expected(DT_FLOAT, TensorShape({image_batch_count, expected_height,
                                           expected_width, filter_count}));
    test::FillValues<float>(
        &expected, {105, 150, 183, 95, 235, 312, 357, 178, 187, 234, 261, 121});
    const Tensor& output = *GetOutput(0);
    test::ExpectTensorNear<float>(expected, output, 1e-5);
  }

  // Runs Conv2D with different horizontal (3) and vertical (1) strides in
  // 'VALID' mode and checks the output against hand-computed values.
  void AnisotropicStrides() {
    const int stride_width = 3;
    const int stride_height = 1;
    TF_EXPECT_OK(NodeDefBuilder("conv_op", "Conv2D")
                     .Input(FakeInput(DT_FLOAT))
                     .Input(FakeInput(DT_FLOAT))
                     .Attr("T", DT_FLOAT)
                     .Attr("strides", {1, stride_height, stride_width, 1})
                     .Attr("padding", "VALID")
                     .Finalize(node_def()));
    TF_EXPECT_OK(InitOp());
    const int depth = 1;
    const int image_width = 6;
    const int image_height = 3;
    const int image_batch_count = 1;
    Tensor image(DT_FLOAT,
                 {image_batch_count, image_height, image_width, depth});
    test::FillValues<float>(&image, {
                                        3, 2, 1, -1, -2, -3,  //
                                        4, 3, 2, -2, -3, -4,  //
                                        5, 4, 3, -3, -4, -5,  //
                                    });
    const int filter_size = 2;
    const int filter_count = 1;
    Tensor filter(DT_FLOAT, {filter_size, filter_size, depth, filter_count});
    test::FillValues<float>(&filter, {
                                         1, 2,  //
                                         3, 4,  //
                                     });

    AddInputFromArray<float>(image.shape(), image.flat<float>());
    AddInputFromArray<float>(filter.shape(), filter.flat<float>());
    TF_ASSERT_OK(RunOpKernel());

    const int expected_width = 2;
    const int expected_height = 2;
    Tensor expected(DT_FLOAT, TensorShape({image_batch_count, expected_height,
                                           expected_width, filter_count}));
    test::FillValues<float>(&expected, {31, -23, 41, -33});
    const Tensor& output = *GetOutput(0);
    test::ExpectTensorNear<float>(expected, output, 1e-5);
  }
};
525 
// Entry points for the ConvOpTest helpers above.
TEST_F(ConvOpTest, HandwrittenConv) { HandwrittenConv(); }

TEST_F(ConvOpTest, AnisotropicStride) { AnisotropicStrides(); }
529 
// Typed fixture for _FusedConv2D tests; T is the tensor element type under
// test. Helpers below build reference graphs whose outputs are compared
// against the fused kernel.
template <typename T>
class FusedConv2DOpTest : public OpsTestBase {
 protected:
  // Default input dimensions shared by the fused-conv tests.
  static constexpr int kDepth = 3;
  static constexpr int kImageWidth = 32;
  static constexpr int kImageHeight = 32;
  static constexpr int kImageBatchCount = 8;

  // Callback that runs a conv+bias(-style) graph over the given inputs and
  // writes the fetched result into `out`.
  using BiasAddGraphRunner =
      std::function<void(const Tensor& input_data, const Tensor& filter_data,
                         const Tensor& bias_data, Tensor* out)>;

  // Callback that runs a conv+batch-norm(-style) graph over the given inputs
  // and writes the fetched result into `out`.
  using BatchNormGraphRunner = std::function<void(
      const Tensor& input_data, const Tensor& filter_data,
      const Tensor& scale_data, const Tensor& offset_data,
      const Tensor& mean_data, const Tensor& variance_data, Tensor* out)>;
546 
  // Runs a Tensorflow graph defined by the root scope, and fetches the result
  // of 'fetch' node into the output Tensor. Optional `fetch_node` parameter
  // allows to define a fetch node directly using a NodeDef for the ops that
  // are not supported by the C++ Api.
  // All graph nodes are pinned to a single device (GPU only when
  // `allow_gpu_device` is set and a GPU is available, CPU otherwise), and all
  // graph optimizations are disabled so the graph runs exactly as written.
  void RunAndFetch(const tensorflow::Scope& root, const string& fetch,
                   Tensor* output, bool allow_gpu_device,
                   const NodeDef* fetch_node = nullptr) {
    tensorflow::GraphDef graph;
    TF_ASSERT_OK(root.ToGraphDef(&graph));

    // Splice in a hand-built fetch node when the op has no C++ API wrapper.
    if (fetch_node) {
      *graph.add_node() = *fetch_node;
    }

    // We really want to make sure that graph executed exactly as we passed it
    // to the session, so we disable various optimizations.
    tensorflow::SessionOptions session_options;

    // Disable common runtime constant folding.
    session_options.config.mutable_graph_options()
        ->mutable_optimizer_options()
        ->set_opt_level(OptimizerOptions::L0);

    // Disable Grappler optimizations for tests (constant folding, layout
    // changes, and op remapping would all perturb the graph under test).
    tensorflow::RewriterConfig* cfg =
        session_options.config.mutable_graph_options()
            ->mutable_rewrite_options();
    cfg->set_constant_folding(tensorflow::RewriterConfig::OFF);
    cfg->set_layout_optimizer(tensorflow::RewriterConfig::OFF);
    cfg->set_remapping(tensorflow::RewriterConfig::OFF);

    std::unique_ptr<tensorflow::Session> session(
        tensorflow::NewSession(session_options));

    std::vector<DeviceAttributes> available_devices;
    TF_ASSERT_OK(session->ListDevices(&available_devices))
        << "Failed to get available session devices";

    // Check if session has an available GPU device.
    const bool has_gpu_device =
        absl::c_any_of(available_devices, [](const DeviceAttributes& device) {
          return device.device_type() == DEVICE_GPU;
        });

    // Some of the `FusedConv2D` fusion types are implemented only for CPU, and
    // in this test we don't want to compare GPU vs CPU numbers, so place all
    // nodes on CPU in this case.
    const bool place_all_on_gpu = allow_gpu_device && has_gpu_device;

    // Pin every node to one device so fused and unfused graphs run on the
    // same hardware.
    const string device = place_all_on_gpu ? "/device:GPU:0" : "/device:CPU:0";
    for (NodeDef& mutable_node : *graph.mutable_node()) {
      mutable_node.set_device(device);
    }

    TF_ASSERT_OK(session->Create(graph));

    std::vector<Tensor> unfused_tensors;
    TF_ASSERT_OK(session->Run({}, {fetch}, {}, &unfused_tensors));

    *output = unfused_tensors[0];
  }
608 
RunConv2DWithBias(const Tensor & input_data,const Tensor & filter_data,const Tensor & bias_data,Tensor * output,bool allow_gpu_device=false,int stride=1)609   void RunConv2DWithBias(const Tensor& input_data, const Tensor& filter_data,
610                          const Tensor& bias_data, Tensor* output,
611                          bool allow_gpu_device = false, int stride = 1) {
612     Scope root = tensorflow::Scope::NewRootScope();
613 
614     ops::Conv2D conv = ops::Conv2D(
615         root.WithOpName("conv"),
616         ops::Const(root.WithOpName("input"), Input::Initializer(input_data)),
617         ops::Const(root.WithOpName("filter"), Input::Initializer(filter_data)),
618         {1, stride, stride, 1}, "SAME");
619 
620     ops::BiasAdd with_bias = ops::BiasAdd(
621         root.WithOpName("with_bias"), conv,
622         ops::Const(root.WithOpName("bias"), Input::Initializer(bias_data)));
623 
624     RunAndFetch(root, "with_bias", output, allow_gpu_device);
625   }
626 
RunConv2DWithBiasAndRelu(const Tensor & input_data,const Tensor & filter_data,const Tensor & bias_data,Tensor * output,bool allow_gpu_device=false,int stride=1)627   void RunConv2DWithBiasAndRelu(const Tensor& input_data,
628                                 const Tensor& filter_data,
629                                 const Tensor& bias_data, Tensor* output,
630                                 bool allow_gpu_device = false, int stride = 1) {
631     Scope root = tensorflow::Scope::NewRootScope();
632 
633     ops::Conv2D conv = ops::Conv2D(
634         root.WithOpName("conv"),
635         ops::Const(root.WithOpName("input"), Input::Initializer(input_data)),
636         ops::Const(root.WithOpName("filter"), Input::Initializer(filter_data)),
637         {1, stride, stride, 1}, "SAME");
638 
639     ops::BiasAdd with_bias = ops::BiasAdd(
640         root.WithOpName("with_bias"), conv,
641         ops::Const(root.WithOpName("bias"), Input::Initializer(bias_data)));
642 
643     ops::Relu with_relu = ops::Relu(root.WithOpName("with_relu"), with_bias);
644 
645     RunAndFetch(root, "with_relu", output, allow_gpu_device);
646   }
647 
RunConv2DWithBatchNorm(const Tensor & input_data,const Tensor & filter_data,const Tensor & scale_data,const Tensor & offset_data,const Tensor & mean_data,const Tensor & variance_data,Tensor * output,bool allow_gpu_device=false,int stride=1)648   void RunConv2DWithBatchNorm(const Tensor& input_data,
649                               const Tensor& filter_data,
650                               const Tensor& scale_data,
651                               const Tensor& offset_data,
652                               const Tensor& mean_data,
653                               const Tensor& variance_data, Tensor* output,
654                               bool allow_gpu_device = false, int stride = 1) {
655     Scope root = tensorflow::Scope::NewRootScope();
656 
657     ops::Conv2D conv = ops::Conv2D(
658         root.WithOpName("conv"),
659         ops::Const(root.WithOpName("input"), Input::Initializer(input_data)),
660         ops::Const(root.WithOpName("filter"), Input::Initializer(filter_data)),
661         {1, stride, stride, 1}, "SAME");
662 
663     ops::FusedBatchNorm::Attrs attr;
664     attr = attr.IsTraining(false);
665 
666     ops::FusedBatchNorm with_fused_batch_norm = ops::FusedBatchNorm(
667         root.WithOpName("with_fused_batch_norm"), conv,
668         ops::Const(root.WithOpName("scale"), Input::Initializer(scale_data)),
669         ops::Const(root.WithOpName("offset"), Input::Initializer(offset_data)),
670         ops::Const(root.WithOpName("mean"), Input::Initializer(mean_data)),
671         ops::Const(root.WithOpName("var"), Input::Initializer(variance_data)),
672         attr);
673 
674     RunAndFetch(root, "with_fused_batch_norm", output, allow_gpu_device);
675   }
676 
RunConv2DWithBatchNormAndRelu(const Tensor & input_data,const Tensor & filter_data,const Tensor & scale_data,const Tensor & offset_data,const Tensor & mean_data,const Tensor & variance_data,Tensor * output,bool allow_gpu_device=false,int stride=1)677   void RunConv2DWithBatchNormAndRelu(
678       const Tensor& input_data, const Tensor& filter_data,
679       const Tensor& scale_data, const Tensor& offset_data,
680       const Tensor& mean_data, const Tensor& variance_data, Tensor* output,
681       bool allow_gpu_device = false, int stride = 1) {
682     Scope root = tensorflow::Scope::NewRootScope();
683 
684     ops::Conv2D conv = ops::Conv2D(
685         root.WithOpName("conv"),
686         ops::Const(root.WithOpName("input"), Input::Initializer(input_data)),
687         ops::Const(root.WithOpName("filter"), Input::Initializer(filter_data)),
688         {1, stride, stride, 1}, "SAME");
689 
690     ops::FusedBatchNorm::Attrs attr;
691     attr = attr.IsTraining(false);
692 
693     ops::FusedBatchNorm with_fused_batch_norm = ops::FusedBatchNorm(
694         root.WithOpName("with_fused_batch_norm"), conv,
695         ops::Const(root.WithOpName("scale"), Input::Initializer(scale_data)),
696         ops::Const(root.WithOpName("offset"), Input::Initializer(offset_data)),
697         ops::Const(root.WithOpName("mean"), Input::Initializer(mean_data)),
698         ops::Const(root.WithOpName("var"), Input::Initializer(variance_data)),
699         attr);
700 
701     ops::Relu with_relu =
702         ops::Relu(root.WithOpName("with_relu"), with_fused_batch_norm.y);
703 
704     RunAndFetch(root, "with_relu", output, allow_gpu_device);
705   }
706 
  // Builds and runs a graph with a single _FusedConv2D node that fuses the
  // operations named in `fused_ops` (e.g. {"BiasAdd", "Relu"}). Extra inputs
  // required by the fused ops (bias, batch-norm parameters, ...) are supplied
  // through `args_data` and become "arg0", "arg1", ... constant nodes.
  void RunFusedConv2DOp(const Tensor& input_data, const Tensor& filter_data,
                        const std::vector<Tensor>& args_data,
                        const std::vector<string>& fused_ops, Tensor* output,
                        bool allow_gpu_device = false, int stride = 1) {
    Scope root = tensorflow::Scope::NewRootScope();

    DataType dtype = DataTypeToEnum<T>::v();
    int num_args = static_cast<int>(args_data.size());

    Output input =
        ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
    Output filter =
        ops::Const(root.WithOpName("filter"), Input::Initializer(filter_data));

    // One constant node per fused-op argument; the NodeDef below references
    // them by node name rather than by Output handle.
    std::vector<NodeDefBuilder::NodeOut> args;
    for (int i = 0; i < num_args; ++i) {
      Output arg = ops::Const(root.WithOpName(absl::StrCat("arg", i)),
                              Input::Initializer(args_data[i]));
      args.emplace_back(arg.name(), 0, dtype);
    }

    // _FusedConv2D is constructed manually via NodeDefBuilder (presumably it
    // has no generated C++ ops wrapper — confirm) and handed to RunAndFetch
    // to be inserted into the graph before running.
    NodeDef fused_conv2d;
    TF_EXPECT_OK(NodeDefBuilder("fused_conv", "_FusedConv2D")
                     .Input({input.name(), 0, dtype})
                     .Input({filter.name(), 0, dtype})
                     .Input(args)
                     .Attr("num_args", num_args)
                     .Attr("T", dtype)
                     .Attr("strides", {1, stride, stride, 1})
                     .Attr("padding", "SAME")
                     .Attr("fused_ops", fused_ops)
                     .Finalize(&fused_conv2d));

    RunAndFetch(root, fused_conv2d.name(), output, allow_gpu_device,
                &fused_conv2d);
  }
743 
  // Generates random image/filter/bias tensors, runs both graph runners on
  // the same data, and asserts the two outputs are element-wise close.
  void VerifyBiasAddTensorsNear(int depth, int image_width, int image_height,
                                int image_batch_count, int filter_size,
                                int filter_count,
                                const BiasAddGraphRunner& run_default,
                                const BiasAddGraphRunner& run_fused) {
    DataType dtype = DataTypeToEnum<T>::v();

    Tensor image(dtype, {image_batch_count, image_height, image_width, depth});
    image.flat<T>() = image.flat<T>().setRandom();

    // Add some negative values to filter to properly test Relu.
    Tensor filter(dtype, {filter_size, filter_size, depth, filter_count});
    filter.flat<T>() = filter.flat<T>().setRandom();
    filter.flat<T>() -= filter.flat<T>().constant(static_cast<T>(0.5f));

    // Bias length must match the convolution's output channel count.
    // Shifted up by 0.5 (presumably to bias values toward positive — confirm
    // setRandom's range).
    const int bias_size = filter_count;
    Tensor bias(dtype, {bias_size});
    bias.flat<T>() = bias.flat<T>().setRandom();
    bias.flat<T>() += bias.flat<T>().constant(static_cast<T>(0.5f));

    Tensor conv_2d;
    Tensor fused_conv_2d;

    run_default(image, filter, bias, &conv_2d);
    run_fused(image, filter, bias, &fused_conv_2d);

    ASSERT_EQ(conv_2d.dtype(), fused_conv_2d.dtype());
    ASSERT_EQ(conv_2d.shape(), fused_conv_2d.shape());

    // NOTE(intel-tf): When filter_size is equal to the input image size,
    // conv2d essentially is element-wise multiplication followed by
    // a full sum reduction, which causes larger numerical error
    // than usual cases.
    if (image_width == filter_size && image_height == filter_size) {
      test::ExpectClose(conv_2d, fused_conv_2d, /*atol=*/1e-4);
    } else {
      test::ExpectClose(conv_2d, fused_conv_2d, /*atol=*/1e-6);
    }
  }
783 
  // Generates random image/filter and batch-norm parameter tensors, runs both
  // graph runners on the same data, and asserts the outputs are close.
  void VerifyFusedBatchNormTensorsNear(int depth, int image_width,
                                       int image_height, int image_batch_count,
                                       int filter_size, int filter_count,
                                       const BatchNormGraphRunner& run_default,
                                       const BatchNormGraphRunner& run_fused) {
    DataType dtype = DataTypeToEnum<T>::v();

    Tensor image(dtype, {image_batch_count, image_height, image_width, depth});
    image.flat<T>() = image.flat<T>().setRandom();

    // Add some negative values to filter to properly test Relu.
    Tensor filter(dtype, {filter_size, filter_size, depth, filter_count});
    filter.flat<T>() = filter.flat<T>().setRandom();
    filter.flat<T>() -= filter.flat<T>().constant(static_cast<T>(0.5f));

    // All batch-norm parameter vectors have one entry per output channel.
    const int scale_size = filter_count;

    Tensor scale(dtype, {scale_size});
    scale.flat<T>() = scale.flat<T>().setRandom();

    Tensor offset(dtype, {scale_size});
    offset.flat<T>() = offset.flat<T>().setRandom();

    Tensor mean(dtype, {scale_size});
    mean.flat<T>() = mean.flat<T>().setRandom();

    // Variance is shifted up by 0.5 (presumably to keep it away from zero,
    // which would blow up the normalization — confirm setRandom's range).
    Tensor variance(dtype, {scale_size});
    variance.flat<T>() = variance.flat<T>().setRandom();
    variance.flat<T>() += variance.flat<T>().constant(static_cast<T>(0.5f));

    Tensor conv_2d;
    Tensor fused_conv_2d;

    run_default(image, filter, scale, offset, mean, variance, &conv_2d);
    run_fused(image, filter, scale, offset, mean, variance, &fused_conv_2d);

    ASSERT_EQ(conv_2d.dtype(), fused_conv_2d.dtype());
    ASSERT_EQ(conv_2d.shape(), fused_conv_2d.shape());

    // NOTE(intel-tf): When filter_size is equal to the input image size,
    // conv2d essentially is element-wise multiplication followed by
    // a full sum reduction, which causes larger numerical error
    // than usual cases.
    if (image_width == filter_size && image_height == filter_size) {
      test::ExpectClose(conv_2d, fused_conv_2d, /*atol=*/1e-4);
    } else {
      test::ExpectClose(conv_2d, fused_conv_2d, /*atol=*/1e-6);
    }
  }
833 
834   // Verifies that computing Conv2D+BiasAdd in a graph is identical to
835   // FusedConv2D.
VerifyConv2DWithBias(int filter_size,int filter_count,int depth=kDepth,int image_width=kImageWidth,int image_height=kImageHeight,int image_batch_count=kImageBatchCount)836   void VerifyConv2DWithBias(int filter_size, int filter_count,
837                             int depth = kDepth, int image_width = kImageWidth,
838                             int image_height = kImageHeight,
839                             int image_batch_count = kImageBatchCount) {
840     const BiasAddGraphRunner run_default =
841         [this](const Tensor& input_data, const Tensor& filter_data,
842                const Tensor& bias_data, Tensor* out) {
843           RunConv2DWithBias(input_data, filter_data, bias_data, out);
844         };
845 
846     const BiasAddGraphRunner run_fused = [this](const Tensor& input_data,
847                                                 const Tensor& filter_data,
848                                                 const Tensor& bias_data,
849                                                 Tensor* out) {
850       RunFusedConv2DOp(input_data, filter_data, {bias_data}, {"BiasAdd"}, out);
851     };
852 
853     VerifyBiasAddTensorsNear(depth, image_width, image_height,
854                              image_batch_count, filter_size, filter_count,
855                              run_default, run_fused);
856   }
857 
858   // Verifies that computing Conv2D+BiasAdd+Relu in a graph is identical to
859   // FusedConv2D.
VerifyConv2DWithBiasAndRelu(int filter_size,int filter_count,int depth=kDepth,int image_width=kImageWidth,int image_height=kImageHeight,int image_batch_count=kImageBatchCount)860   void VerifyConv2DWithBiasAndRelu(int filter_size, int filter_count,
861                                    int depth = kDepth,
862                                    int image_width = kImageWidth,
863                                    int image_height = kImageHeight,
864                                    int image_batch_count = kImageBatchCount) {
865     const BiasAddGraphRunner run_default =
866         [this](const Tensor& input_data, const Tensor& filter_data,
867                const Tensor& bias_data, Tensor* out) {
868           RunConv2DWithBiasAndRelu(input_data, filter_data, bias_data, out,
869                                    /*allow_gpu_device=*/true);
870         };
871 
872     const BiasAddGraphRunner run_fused =
873         [this](const Tensor& input_data, const Tensor& filter_data,
874                const Tensor& bias_data, Tensor* out) {
875           RunFusedConv2DOp(input_data, filter_data, {bias_data},
876                            {"BiasAdd", "Relu"}, out, /*allow_gpu_device=*/true);
877         };
878 
879     VerifyBiasAddTensorsNear(depth, image_width, image_height,
880                              image_batch_count, filter_size, filter_count,
881                              run_default, run_fused);
882   }
883 
884   // Verifies that computing Conv2D+FusedBatchNorm in a graph is identical to
885   // FusedConv2D.
VerifyConv2DWithBatchNorm(int filter_size,int filter_count,int depth=kDepth,int image_width=kImageWidth,int image_height=kImageHeight,int image_batch_count=kImageBatchCount)886   void VerifyConv2DWithBatchNorm(int filter_size, int filter_count,
887                                  int depth = kDepth,
888                                  int image_width = kImageWidth,
889                                  int image_height = kImageHeight,
890                                  int image_batch_count = kImageBatchCount) {
891     const BatchNormGraphRunner run_default =
892         [this](const Tensor& input_data, const Tensor& filter_data,
893                const Tensor& scale_data, const Tensor& offset_data,
894                const Tensor& mean_data, const Tensor& variance_data,
895                Tensor* out) {
896           RunConv2DWithBatchNorm(input_data, filter_data, scale_data,
897                                  offset_data, mean_data, variance_data, out);
898         };
899 
900     const BatchNormGraphRunner run_fused =
901         [this](const Tensor& input_data, const Tensor& filter_data,
902                const Tensor& scale_data, const Tensor& offset_data,
903                const Tensor& mean_data, const Tensor& variance_data,
904                Tensor* out) {
905           RunFusedConv2DOp(input_data, filter_data,
906                            {scale_data, offset_data, mean_data, variance_data},
907                            {"FusedBatchNorm"}, out);
908         };
909 
910     VerifyFusedBatchNormTensorsNear(depth, image_width, image_height,
911                                     image_batch_count, filter_size,
912                                     filter_count, run_default, run_fused);
913   }
914 
915   // Verifies that computing Conv2D+FusedBatchNorm+Relu in a graph is identical
916   // to FusedConv2D.
VerifyConv2DWithBatchNormAndRelu(int filter_size,int filter_count,int depth=kDepth,int image_width=kImageWidth,int image_height=kImageHeight,int image_batch_count=kImageBatchCount)917   void VerifyConv2DWithBatchNormAndRelu(
918       int filter_size, int filter_count, int depth = kDepth,
919       int image_width = kImageWidth, int image_height = kImageHeight,
920       int image_batch_count = kImageBatchCount) {
921     const BatchNormGraphRunner run_default =
922         [this](const Tensor& input_data, const Tensor& filter_data,
923                const Tensor& scale_data, const Tensor& offset_data,
924                const Tensor& mean_data, const Tensor& variance_data,
925                Tensor* out) {
926           RunConv2DWithBatchNormAndRelu(input_data, filter_data, scale_data,
927                                         offset_data, mean_data, variance_data,
928                                         out);
929         };
930 
931     const BatchNormGraphRunner run_fused =
932         [this](const Tensor& input_data, const Tensor& filter_data,
933                const Tensor& scale_data, const Tensor& offset_data,
934                const Tensor& mean_data, const Tensor& variance_data,
935                Tensor* out) {
936           RunFusedConv2DOp(input_data, filter_data,
937                            {scale_data, offset_data, mean_data, variance_data},
938                            {"FusedBatchNorm", "Relu"}, out);
939         };
940 
941     VerifyFusedBatchNormTensorsNear(depth, image_width, image_height,
942                                     image_batch_count, filter_size,
943                                     filter_count, run_default, run_fused);
944   }
945 };
946 
947 // Conv2D with BatchNorm can be tested only with `T=float`, because default
948 // `FusedBatchNorm` kernel supports only floats for scale, mean and variance.
949 
// Parameterized typed-test fixtures; instantiated below for the element
// types each fusion supports (Bias: float/double, BatchNorm: float only).
template <typename T>
class FusedConv2DWithBiasOpTest : public FusedConv2DOpTest<T> {};
template <typename T>
class FusedConv2DWithBatchNormOpTest : public FusedConv2DOpTest<T> {};

TYPED_TEST_SUITE_P(FusedConv2DWithBiasOpTest);
TYPED_TEST_SUITE_P(FusedConv2DWithBatchNormOpTest);
957 
958 // -------------------------------------------------------------------------- //
959 // Conv2D + BiasAdd + {Relu}                                                  //
960 // -------------------------------------------------------------------------- //
961 
TYPED_TEST_P(FusedConv2DWithBiasOpTest, OneByOneConvolution) {
  this->VerifyConv2DWithBias(/*filter_size=*/1, /*filter_count=*/12);
}

TYPED_TEST_P(FusedConv2DWithBiasOpTest, ImageSizeConvolution) {
  // Filter spans the whole image (looser tolerance path in the verifier).
  this->VerifyConv2DWithBias(/*filter_size=*/TestFixture::kImageWidth,
                             /*filter_count=*/12);
}

TYPED_TEST_P(FusedConv2DWithBiasOpTest, SpatialConvolution) {
  this->VerifyConv2DWithBias(/*filter_size=*/3, /*filter_count=*/12);
}

TYPED_TEST_P(FusedConv2DWithBiasOpTest, OneByOneConvolutionAndRelu) {
  this->VerifyConv2DWithBiasAndRelu(/*filter_size=*/1, /*filter_count=*/12);
}

TYPED_TEST_P(FusedConv2DWithBiasOpTest, ImageSizeConvolutionAndRelu) {
  // Filter spans the whole image (looser tolerance path in the verifier).
  this->VerifyConv2DWithBiasAndRelu(/*filter_size=*/TestFixture::kImageWidth,
                                    /*filter_count=*/12);
}

TYPED_TEST_P(FusedConv2DWithBiasOpTest, SpatialConvolutionAndRelu) {
  this->VerifyConv2DWithBiasAndRelu(/*filter_size=*/3, /*filter_count=*/12);
}
997 
998 // -------------------------------------------------------------------------- //
999 // Conv2D + FusedBatchNorm + {Relu}                                           //
1000 // -------------------------------------------------------------------------- //
1001 
TYPED_TEST_P(FusedConv2DWithBatchNormOpTest, OneByOneConvolution) {
  this->VerifyConv2DWithBatchNorm(/*filter_size=*/1, /*filter_count=*/12);
}

TYPED_TEST_P(FusedConv2DWithBatchNormOpTest, ImageSizeConvolution) {
  // Filter spans the whole image (looser tolerance path in the verifier).
  this->VerifyConv2DWithBatchNorm(/*filter_size=*/TestFixture::kImageWidth,
                                  /*filter_count=*/12);
}

TYPED_TEST_P(FusedConv2DWithBatchNormOpTest, SpatialConvolution) {
  this->VerifyConv2DWithBatchNorm(/*filter_size=*/3, /*filter_count=*/12);
}

TYPED_TEST_P(FusedConv2DWithBatchNormOpTest, OneByOneConvolutionAndRelu) {
  this->VerifyConv2DWithBatchNormAndRelu(/*filter_size=*/1,
                                         /*filter_count=*/12);
}

TYPED_TEST_P(FusedConv2DWithBatchNormOpTest, ImageSizeConvolutionAndRelu) {
  // Filter spans the whole image (looser tolerance path in the verifier).
  this->VerifyConv2DWithBatchNormAndRelu(
      /*filter_size=*/TestFixture::kImageWidth, /*filter_count=*/12);
}

TYPED_TEST_P(FusedConv2DWithBatchNormOpTest, SpatialConvolutionAndRelu) {
  this->VerifyConv2DWithBatchNormAndRelu(/*filter_size=*/3,
                                         /*filter_count=*/12);
}
1037 
// Register every test case with its parameterized suite, then instantiate the
// suites for the supported element types.
REGISTER_TYPED_TEST_SUITE_P(FusedConv2DWithBiasOpTest,    //
                            OneByOneConvolution,          //
                            ImageSizeConvolution,         //
                            SpatialConvolution,           //
                            OneByOneConvolutionAndRelu,   //
                            ImageSizeConvolutionAndRelu,  //
                            SpatialConvolutionAndRelu);

REGISTER_TYPED_TEST_SUITE_P(FusedConv2DWithBatchNormOpTest,  //
                            OneByOneConvolution,             //
                            ImageSizeConvolution,            //
                            SpatialConvolution,              //
                            OneByOneConvolutionAndRelu,      //
                            ImageSizeConvolutionAndRelu,     //
                            SpatialConvolutionAndRelu);

using FusedBiasAddDataTypes = ::testing::Types<float, double>;
INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedConv2DWithBiasOpTest,
                               FusedBiasAddDataTypes);

// BatchNorm fusion is float-only (see the note above the fixtures).
using FusedBatchNormDataTypes = ::testing::Types<float>;
INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedConv2DWithBatchNormOpTest,
                               FusedBatchNormDataTypes);
1061 
1062 ////////////////////////////////////////////////////////////////////////////////
1063 // Performance benchmarks for the FusedConv2DWithBiasOp.                      //
1064 ////////////////////////////////////////////////////////////////////////////////
1065 
// Result of the Conv2D() builder below: the constructed graph and a handle to
// its Conv2D node so further ops can be chained onto it.
struct Conv2DGraph {
  Graph* graph;
  Node* conv2d;
};
1070 
// Result of Conv2DWithBias(): graph plus handles to the Conv2D and BiasAdd
// nodes.
struct Conv2DWithBiasGraph {
  Graph* graph;
  Node* conv2d;
  Node* bias;
};
1076 
// Result of Conv2DWithBiasAndRelu(): graph plus handles to the Conv2D,
// BiasAdd and Relu nodes.
struct Conv2DWithBiasAndReluGraph {
  Graph* graph;
  Node* conv2d;
  Node* bias;
  Node* relu;
};
1083 
// Result of Conv2DWithBatchNorm(): graph plus handles to the Conv2D and
// FusedBatchNorm nodes.
struct Conv2DWithBatchNormGraph {
  Graph* graph;
  Node* conv2d;
  Node* batch_norm;
};
1089 
// Result of Conv2DWithBatchNormAndRelu(): graph plus handles to the Conv2D,
// FusedBatchNorm and Relu nodes.
struct Conv2DWithBatchNormAndReluGraph {
  Graph* graph;
  Node* conv2d;
  Node* batch_norm;
  Node* relu;
};
1096 
MakeRandomTensor(const TensorShape & shape)1097 static Tensor MakeRandomTensor(const TensorShape& shape) {
1098   Tensor tensor(DT_FLOAT, TensorShape(shape));
1099   tensor.flat<float>() = tensor.flat<float>().setRandom();
1100   return tensor;
1101 }
1102 
1103 // Creates a simple Tensorflow graph with single Conv2D node.
Conv2D(int batch,int height,int width,int in_depth,int filter_w,int filter_h,int out_depth)1104 static Conv2DGraph Conv2D(int batch, int height, int width, int in_depth,
1105                           int filter_w, int filter_h, int out_depth) {
1106   Graph* graph = new Graph(OpRegistry::Global());
1107 
1108   Tensor images_t = MakeRandomTensor({batch, height, width, in_depth});
1109   Tensor filter_t = MakeRandomTensor({filter_w, filter_h, in_depth, out_depth});
1110 
1111   Node* images = test::graph::Constant(graph, images_t, "images");
1112   Node* filter = test::graph::Constant(graph, filter_t, "filter");
1113 
1114   Node* conv2d;
1115   TF_CHECK_OK(NodeBuilder(graph->NewName("conv"), "Conv2D")
1116                   .Input(images)
1117                   .Input(filter)
1118                   .Attr("T", DT_FLOAT)
1119                   .Attr("strides", {1, 1, 1, 1})
1120                   .Attr("padding", "SAME")
1121                   .Finalize(graph, &conv2d));
1122 
1123   return {graph, conv2d};
1124 }
1125 
1126 // Creates a Tensorflow graph with a Conv2D node followed by BiasAdd.
Conv2DWithBias(int batch,int height,int width,int in_depth,int filter_w,int filter_h,int out_depth)1127 static Conv2DWithBiasGraph Conv2DWithBias(int batch, int height, int width,
1128                                           int in_depth, int filter_w,
1129                                           int filter_h, int out_depth) {
1130   Conv2DGraph conv_graph =
1131       Conv2D(batch, height, width, in_depth, filter_w, filter_h, out_depth);
1132 
1133   Graph* graph = conv_graph.graph;
1134   Node* conv2d = conv_graph.conv2d;
1135 
1136   Tensor bias_t = MakeRandomTensor({out_depth});
1137   Node* bias = test::graph::Constant(graph, bias_t, "bias");
1138 
1139   Node* out;
1140   TF_CHECK_OK(NodeBuilder(graph->NewName("bias"), "BiasAdd")
1141                   .Input(conv2d)
1142                   .Input(bias)
1143                   .Attr("T", DT_FLOAT)
1144                   .Attr("data_format", "NHWC")
1145                   .Finalize(graph, &out));
1146 
1147   return {graph, conv2d, out};
1148 }
1149 
1150 // Creates a Tensorflow graph with a Conv2D node followed by BiasAdd and Relu.
Conv2DWithBiasAndRelu(int batch,int height,int width,int in_depth,int filter_w,int filter_h,int out_depth)1151 static Conv2DWithBiasAndReluGraph Conv2DWithBiasAndRelu(int batch, int height,
1152                                                         int width, int in_depth,
1153                                                         int filter_w,
1154                                                         int filter_h,
1155                                                         int out_depth) {
1156   Conv2DWithBiasGraph conv_graph = Conv2DWithBias(
1157       batch, height, width, in_depth, filter_w, filter_h, out_depth);
1158 
1159   Graph* graph = conv_graph.graph;
1160   Node* conv2d = conv_graph.conv2d;
1161   Node* bias = conv_graph.bias;
1162 
1163   Node* relu;
1164   TF_CHECK_OK(NodeBuilder(graph->NewName("relu"), "Relu")
1165                   .Input(bias)
1166                   .Attr("T", DT_FLOAT)
1167                   .Finalize(graph, &relu));
1168 
1169   return {graph, conv2d, bias, relu};
1170 }
1171 
1172 // Creates a Tensorflow graph with a Conv2D node followed by FusedBatchNorm.
// Creates a Tensorflow graph with a Conv2D node followed by FusedBatchNorm
// in inference mode (is_training=false), fed by random per-channel
// scale/offset/mean/variance constants.
static Conv2DWithBatchNormGraph Conv2DWithBatchNorm(int batch, int height,
                                                    int width, int in_depth,
                                                    int filter_w, int filter_h,
                                                    int out_depth) {
  Conv2DGraph conv_graph =
      Conv2D(batch, height, width, in_depth, filter_w, filter_h, out_depth);

  Graph* graph = conv_graph.graph;
  Node* conv2d = conv_graph.conv2d;

  // One value per output channel for each batch-norm parameter.
  Tensor scale_t = MakeRandomTensor({out_depth});
  Tensor offset_t = MakeRandomTensor({out_depth});
  Tensor mean_t = MakeRandomTensor({out_depth});
  Tensor variance_t = MakeRandomTensor({out_depth});

  Node* scale = test::graph::Constant(graph, scale_t, "scale");
  Node* offset = test::graph::Constant(graph, offset_t, "offset");
  Node* mean = test::graph::Constant(graph, mean_t, "mean");
  Node* variance = test::graph::Constant(graph, variance_t, "variance");

  Node* out;
  TF_CHECK_OK(NodeBuilder(graph->NewName("batch_norm"), "FusedBatchNorm")
                  .Input(conv2d)
                  .Input(scale)
                  .Input(offset)
                  .Input(mean)
                  .Input(variance)
                  .Attr("T", DT_FLOAT)
                  .Attr("is_training", false)
                  .Finalize(graph, &out));

  return {graph, conv2d, out};
}
1206 
1207 // Creates a Tensorflow graph with a Conv2D node followed by FusedBatchNorm and
1208 // Relu.
Conv2DWithBatchNormAndRelu(int batch,int height,int width,int in_depth,int filter_w,int filter_h,int out_depth)1209 static Conv2DWithBatchNormAndReluGraph Conv2DWithBatchNormAndRelu(
1210     int batch, int height, int width, int in_depth, int filter_w, int filter_h,
1211     int out_depth) {
1212   Conv2DWithBatchNormGraph conv_graph = Conv2DWithBatchNorm(
1213       batch, height, width, in_depth, filter_w, filter_h, out_depth);
1214 
1215   Graph* graph = conv_graph.graph;
1216   Node* conv2d = conv_graph.conv2d;
1217   Node* batch_norm = conv_graph.batch_norm;
1218 
1219   Node* relu;
1220   TF_CHECK_OK(NodeBuilder(graph->NewName("relu"), "Relu")
1221                   .Input(batch_norm)
1222                   .Attr("T", DT_FLOAT)
1223                   .Finalize(graph, &relu));
1224 
1225   return {graph, conv2d, batch_norm, relu};
1226 }
1227 
1228 // Creates a tensorflow graph with a single FusedConv2D (with BiasAdd) node and
1229 // fuses into it additional computations (e.g. Relu).
// Creates a tensorflow graph with a single FusedConv2D (with BiasAdd) node and
// fuses into it additional computations (e.g. Relu) listed in `fused_ops`.
static Graph* FusedConv2DWithBias(int batch, int height, int width,
                                  int in_depth, int filter_w, int filter_h,
                                  int out_depth,
                                  const std::vector<string>& fused_ops = {}) {
  Graph* graph = new Graph(OpRegistry::Global());

  Tensor images_t = MakeRandomTensor({batch, height, width, in_depth});
  Tensor filter_t = MakeRandomTensor({filter_w, filter_h, in_depth, out_depth});
  Tensor bias_t = MakeRandomTensor({out_depth});

  Node* images = test::graph::Constant(graph, images_t, "images");
  Node* filter = test::graph::Constant(graph, filter_t, "filter");
  Node* bias = test::graph::Constant(graph, bias_t, "bias");

  // The bias is the single variadic "args" input of _FusedConv2D
  // (num_args = 1).
  std::vector<NodeBuilder::NodeOut> args = {bias};

  Node* conv;
  TF_CHECK_OK(NodeBuilder(graph->NewName("conv"), "_FusedConv2D")
                  .Input(images)
                  .Input(filter)
                  .Attr("num_args", 1)
                  .Input(args)
                  .Attr("T", DT_FLOAT)
                  .Attr("strides", {1, 1, 1, 1})
                  .Attr("padding", "SAME")
                  .Attr("fused_ops", fused_ops)
                  .Finalize(graph, &conv));

  return graph;
}
1260 
1261 // Creates a tensorflow graph with a single FusedConv2D (with FusedBatchNorm)
1262 // node and fuses into it additional computations (e.g. Relu).
// Creates a tensorflow graph with a single FusedConv2D (with FusedBatchNorm)
// node and fuses into it additional computations (e.g. Relu) listed in
// `fused_ops`.
static Graph* FusedConv2DWithBatchNorm(
    int batch, int height, int width, int in_depth, int filter_w, int filter_h,
    int out_depth, const std::vector<string>& fused_ops = {}) {
  Graph* graph = new Graph(OpRegistry::Global());

  Tensor images_t = MakeRandomTensor({batch, height, width, in_depth});
  Tensor filter_t = MakeRandomTensor({filter_w, filter_h, in_depth, out_depth});
  // One value per output channel for each batch-norm parameter.
  Tensor scale_t = MakeRandomTensor({out_depth});
  Tensor offset_t = MakeRandomTensor({out_depth});
  Tensor mean_t = MakeRandomTensor({out_depth});
  Tensor variance_t = MakeRandomTensor({out_depth});

  Node* images = test::graph::Constant(graph, images_t, "images");
  Node* filter = test::graph::Constant(graph, filter_t, "filter");
  Node* scale = test::graph::Constant(graph, scale_t, "scale");
  Node* offset = test::graph::Constant(graph, offset_t, "offset");
  Node* mean = test::graph::Constant(graph, mean_t, "mean");
  Node* variance = test::graph::Constant(graph, variance_t, "variance");

  // The four batch-norm parameters are the variadic "args" inputs of
  // _FusedConv2D (num_args = 4); order must match the fused op's expectation.
  std::vector<NodeBuilder::NodeOut> args = {scale, offset, mean, variance};

  Node* conv;
  TF_CHECK_OK(NodeBuilder(graph->NewName("conv"), "_FusedConv2D")
                  .Input(images)
                  .Input(filter)
                  .Attr("num_args", 4)
                  .Input(args)
                  .Attr("T", DT_FLOAT)
                  .Attr("strides", {1, 1, 1, 1})
                  .Attr("padding", "SAME")
                  .Attr("fused_ops", fused_ops)
                  .Finalize(graph, &conv));

  return graph;
}
1298 
1299 // Macro arguments names: --------------------------------------------------- //
1300 //    N: batch size
1301 //    H: height
1302 //    W: width
1303 //    C: channels
//   FC: filter count (number of filters, i.e. output channels)
1305 //   FH: filter height
1306 //   FW: filter width
1307 
// Common per-benchmark bookkeeping: reports the number of processed input
// elements (N*H*W*C per iteration) and attaches LABEL to the benchmark output.
// Relies on an `iters` variable being in scope in the enclosing benchmark
// function body. NOTE(review): the `type` and `NAME` arguments are unused.
#define BM_SETUP(N, H, W, C, type, LABEL, NAME)                               \
  testing::ItemsProcessed(static_cast<int64>(iters) * (N) * (H) * (W) * (C)); \
  testing::SetLabel(LABEL);
1311 
// Pastes the benchmark name, device type and all shape parameters into a
// unique function identifier, e.g. BM_Conv2D_cpu_8_32_32_128_1_1_1024.
#define BM_NAME(name, type, N, H, W, C, FW, FH, FC) \
  name##_##type##_##N##_##H##_##W##_##C##_##FW##_##FH##_##FC
1314 
// Benchmarks a plain Conv2D graph on the device named by `type` (cpu/gpu).
// `Conv2D(...)` is a graph-building helper defined earlier in this file;
// presumably it returns a struct whose `graph` member is the built Graph*.
#define BM_Conv2D(N, H, W, C, FW, FH, FC, type, LABEL)                       \
  static void BM_NAME(BM_Conv2D, type, N, H, W, C, FW, FH, FC)(int iters) {  \
    BM_SETUP(N, H, W, C, type, LABEL, Conv2D);                               \
    test::Benchmark(#type, Conv2D(N, H, W, C, FW, FH, FC).graph).Run(iters); \
  }                                                                          \
  BENCHMARK(BM_NAME(BM_Conv2D, type, N, H, W, C, FW, FH, FC));
1321 
// Benchmarks un-fused Conv2D followed by a separate BiasAdd node — the
// baseline for the fused BiasAdd benchmarks below.
#define BM_Conv2DWithBias(N, H, W, C, FW, FH, FC, type, LABEL)           \
  static void BM_NAME(BM_Conv2DWithBias, type, N, H, W, C, FW, FH,       \
                      FC)(int iters) {                                   \
    BM_SETUP(N, H, W, C, type, LABEL, Conv2D);                           \
    test::Benchmark(#type, Conv2DWithBias(N, H, W, C, FW, FH, FC).graph) \
        .Run(iters);                                                     \
  }                                                                      \
  BENCHMARK(BM_NAME(BM_Conv2DWithBias, type, N, H, W, C, FW, FH, FC));
1330 
// Benchmarks un-fused Conv2D + BiasAdd + Relu as separate graph nodes — the
// baseline for the fused BiasAdd+Relu benchmarks below.
#define BM_Conv2DWithBiasAndRelu(N, H, W, C, FW, FH, FC, type, LABEL)     \
  static void BM_NAME(BM_Conv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, \
                      FC)(int iters) {                                    \
    BM_SETUP(N, H, W, C, type, LABEL, Conv2D);                            \
    test::Benchmark(#type,                                                \
                    Conv2DWithBiasAndRelu(N, H, W, C, FW, FH, FC).graph)  \
        .Run(iters);                                                      \
  }                                                                       \
  BENCHMARK(BM_NAME(BM_Conv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, FC));
1340 
// Benchmarks a single _FusedConv2D node with fused_ops={"BiasAdd"}.
// FusedConv2DWithBias returns the Graph* directly (no `.graph` member).
#define BM_FusedConv2DWithBias(N, H, W, C, FW, FH, FC, type, LABEL)           \
  static void BM_NAME(BM_FusedConv2DWithBias, type, N, H, W, C, FW, FH,       \
                      FC)(int iters) {                                        \
    BM_SETUP(N, H, W, C, type, LABEL, Conv2D);                                \
    test::Benchmark(#type,                                                    \
                    FusedConv2DWithBias(N, H, W, C, FW, FH, FC, {"BiasAdd"})) \
        .Run(iters);                                                          \
  }                                                                           \
  BENCHMARK(BM_NAME(BM_FusedConv2DWithBias, type, N, H, W, C, FW, FH, FC));
1350 
// Benchmarks a single _FusedConv2D node with fused_ops={"BiasAdd", "Relu"}.
#define BM_FusedConv2DWithBiasAndRelu(N, H, W, C, FW, FH, FC, type, LABEL)     \
  static void BM_NAME(BM_FusedConv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, \
                      FC)(int iters) {                                         \
    BM_SETUP(N, H, W, C, type, LABEL, Conv2D);                                 \
    test::Benchmark(#type, FusedConv2DWithBias(N, H, W, C, FW, FH, FC,         \
                                               {"BiasAdd", "Relu"}))           \
        .Run(iters);                                                           \
  }                                                                            \
  BENCHMARK(                                                                   \
      BM_NAME(BM_FusedConv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, FC));
1361 
// Benchmarks un-fused Conv2D followed by a separate FusedBatchNorm node —
// the baseline for the fused batch norm benchmarks below.
#define BM_Conv2DWithBatchNorm(N, H, W, C, FW, FH, FC, type, LABEL)           \
  static void BM_NAME(BM_Conv2DWithBatchNorm, type, N, H, W, C, FW, FH,       \
                      FC)(int iters) {                                        \
    BM_SETUP(N, H, W, C, type, LABEL, Conv2D);                                \
    test::Benchmark(#type, Conv2DWithBatchNorm(N, H, W, C, FW, FH, FC).graph) \
        .Run(iters);                                                          \
  }                                                                           \
  BENCHMARK(BM_NAME(BM_Conv2DWithBatchNorm, type, N, H, W, C, FW, FH, FC));
1370 
// Benchmarks un-fused Conv2D + FusedBatchNorm + Relu as separate graph nodes.
#define BM_Conv2DWithBatchNormAndRelu(N, H, W, C, FW, FH, FC, type, LABEL)     \
  static void BM_NAME(BM_Conv2DWithBatchNormAndRelu, type, N, H, W, C, FW, FH, \
                      FC)(int iters) {                                         \
    BM_SETUP(N, H, W, C, type, LABEL, Conv2D);                                 \
    test::Benchmark(#type,                                                     \
                    Conv2DWithBatchNormAndRelu(N, H, W, C, FW, FH, FC).graph)  \
        .Run(iters);                                                           \
  }                                                                            \
  BENCHMARK(                                                                   \
      BM_NAME(BM_Conv2DWithBatchNormAndRelu, type, N, H, W, C, FW, FH, FC));
1381 
// Benchmarks a single _FusedConv2D node with fused_ops={"FusedBatchNorm"}.
#define BM_FusedConv2DWithBatchNorm(N, H, W, C, FW, FH, FC, type, LABEL)     \
  static void BM_NAME(BM_FusedConv2DWithBatchNorm, type, N, H, W, C, FW, FH, \
                      FC)(int iters) {                                       \
    BM_SETUP(N, H, W, C, type, LABEL, Conv2D);                               \
    test::Benchmark(#type, FusedConv2DWithBatchNorm(N, H, W, C, FW, FH, FC,  \
                                                    {"FusedBatchNorm"}))     \
        .Run(iters);                                                         \
  }                                                                          \
  BENCHMARK(BM_NAME(BM_FusedConv2DWithBatchNorm, type, N, H, W, C, FW, FH, FC));
1391 
// Benchmarks a single _FusedConv2D node with
// fused_ops={"FusedBatchNorm", "Relu"}.
#define BM_FusedConv2DWithBatchNormAndRelu(N, H, W, C, FW, FH, FC, type,      \
                                           LABEL)                             \
  static void BM_NAME(BM_FusedConv2DWithBatchNormAndRelu, type, N, H, W, C,   \
                      FW, FH, FC)(int iters) {                                \
    BM_SETUP(N, H, W, C, type, LABEL, Conv2D);                                \
    test::Benchmark(#type,                                                    \
                    FusedConv2DWithBatchNorm(N, H, W, C, FW, FH, FC,          \
                                             {"FusedBatchNorm", "Relu"}))     \
        .Run(iters);                                                          \
  }                                                                           \
  BENCHMARK(BM_NAME(BM_FusedConv2DWithBatchNormAndRelu, type, N, H, W, C, FW, \
                    FH, FC));
1404 
// -------------------------------------------------------------------------- //
// Pixel CNN convolutions.
// -------------------------------------------------------------------------- //

// All CPU benchmarks below use NHWC inputs of shape (batch, 32, 32, 128) with
// 1024 output channels; only the batch size (8/16/32) varies within a group.

// 1x1 Convolution: MatMulFunctor

BM_Conv2D(8, 32, 32, 128, 1, 1, 1024, cpu, "1x1 /b 8");
BM_Conv2D(16, 32, 32, 128, 1, 1, 1024, cpu, "1x1 /b 16");
BM_Conv2D(32, 32, 32, 128, 1, 1, 1024, cpu, "1x1 /b 32");

// 1) BiasAdd {+ Relu}

BM_Conv2DWithBias(8, 32, 32, 128, 1, 1, 1024, cpu, "1x1 /b 8");
BM_Conv2DWithBias(16, 32, 32, 128, 1, 1, 1024, cpu, "1x1 /b 16");
BM_Conv2DWithBias(32, 32, 32, 128, 1, 1, 1024, cpu, "1x1 /b 32");

BM_Conv2DWithBiasAndRelu(8, 32, 32, 128, 1, 1, 1024, cpu, "1x1 /b 8");
BM_Conv2DWithBiasAndRelu(16, 32, 32, 128, 1, 1, 1024, cpu, "1x1 /b 16");
BM_Conv2DWithBiasAndRelu(32, 32, 32, 128, 1, 1, 1024, cpu, "1x1 /b 32");

BM_FusedConv2DWithBias(8, 32, 32, 128, 1, 1, 1024, cpu, "1x1 /b 8");
BM_FusedConv2DWithBias(16, 32, 32, 128, 1, 1, 1024, cpu, "1x1 /b 16");
BM_FusedConv2DWithBias(32, 32, 32, 128, 1, 1, 1024, cpu, "1x1 /b 32");

BM_FusedConv2DWithBiasAndRelu(8, 32, 32, 128, 1, 1, 1024, cpu, "1x1 /b 8");
BM_FusedConv2DWithBiasAndRelu(16, 32, 32, 128, 1, 1, 1024, cpu, "1x1 /b 16");
BM_FusedConv2DWithBiasAndRelu(32, 32, 32, 128, 1, 1, 1024, cpu, "1x1 /b 32");

// 2) FusedBatchNorm {+ Relu}

BM_Conv2DWithBatchNorm(8, 32, 32, 128, 1, 1, 1024, cpu, "1x1 /b 8");
BM_Conv2DWithBatchNorm(16, 32, 32, 128, 1, 1, 1024, cpu, "1x1 /b 16");
BM_Conv2DWithBatchNorm(32, 32, 32, 128, 1, 1, 1024, cpu, "1x1 /b 32");

BM_Conv2DWithBatchNormAndRelu(8, 32, 32, 128, 1, 1, 1024, cpu, "1x1 /b 8");
BM_Conv2DWithBatchNormAndRelu(16, 32, 32, 128, 1, 1, 1024, cpu, "1x1 /b 16");
BM_Conv2DWithBatchNormAndRelu(32, 32, 32, 128, 1, 1, 1024, cpu, "1x1 /b 32");

BM_FusedConv2DWithBatchNorm(8, 32, 32, 128, 1, 1, 1024, cpu, "1x1 /b 8");
BM_FusedConv2DWithBatchNorm(16, 32, 32, 128, 1, 1, 1024, cpu, "1x1 /b 16");
BM_FusedConv2DWithBatchNorm(32, 32, 32, 128, 1, 1, 1024, cpu, "1x1 /b 32");

BM_FusedConv2DWithBatchNormAndRelu(8, 32, 32, 128, 1, 1, 1024, cpu, "1x1 /b 8");
BM_FusedConv2DWithBatchNormAndRelu(16, 32, 32, 128, 1, 1, 1024, cpu,
                                   "1x1 /b 16");
BM_FusedConv2DWithBatchNormAndRelu(32, 32, 32, 128, 1, 1, 1024, cpu,
                                   "1x1 /b 32");
1452 
1453 // -------------------------------------------------------------------------- //
1454 // 3x3 Convolution: SpatialConvolution
1455 // -------------------------------------------------------------------------- //
1456 
1457 BM_Conv2D(8, 32, 32, 128, 3, 3, 1024, cpu, "3x3 /b 8");
1458 BM_Conv2D(16, 32, 32, 128, 3, 3, 1024, cpu, "3x3 /b 16");
1459 BM_Conv2D(32, 32, 32, 128, 3, 3, 1024, cpu, "3x3 /b 32");
1460 
1461 // 1) BiasAdd {+ Relu}
1462 
1463 BM_Conv2DWithBias(8, 32, 32, 128, 3, 3, 1024, cpu, "3x3 /b 8");
1464 BM_Conv2DWithBias(16, 32, 32, 128, 3, 3, 1024, cpu, "3x3 /b 16");
1465 BM_Conv2DWithBias(32, 32, 32, 128, 3, 3, 1024, cpu, "3x3 /b 32");
1466 
1467 BM_Conv2DWithBiasAndRelu(8, 32, 32, 128, 3, 3, 1024, cpu, "3x3 /b 8");
1468 BM_Conv2DWithBiasAndRelu(16, 32, 32, 128, 3, 3, 1024, cpu, "3x3 /b 16");
1469 BM_Conv2DWithBiasAndRelu(32, 32, 32, 128, 3, 3, 1024, cpu, "3x3 /b 32");
1470 
1471 BM_FusedConv2DWithBias(8, 32, 32, 128, 3, 3, 1024, cpu, "3x3 /b 8");
1472 BM_FusedConv2DWithBias(16, 32, 32, 128, 3, 3, 1024, cpu, "3x3 /b 16");
1473 BM_FusedConv2DWithBias(32, 32, 32, 128, 3, 3, 1024, cpu, "3x3 /b 32");
1474 
1475 BM_FusedConv2DWithBiasAndRelu(8, 32, 32, 128, 3, 3, 1024, cpu, "3x3 /b 8");
1476 BM_FusedConv2DWithBiasAndRelu(16, 32, 32, 128, 3, 3, 1024, cpu, "3x3 /b 16");
1477 BM_FusedConv2DWithBiasAndRelu(32, 32, 32, 128, 3, 3, 1024, cpu, "3x3 /b 32");
1478 
1479 // 2) FusedBatchNorm {+ Relu}
1480 
1481 BM_Conv2DWithBatchNorm(8, 32, 32, 128, 3, 3, 1024, cpu, "1x1 /b 8");
1482 BM_Conv2DWithBatchNorm(16, 32, 32, 128, 3, 3, 1024, cpu, "1x1 /b 16");
1483 BM_Conv2DWithBatchNorm(32, 32, 32, 128, 3, 3, 1024, cpu, "1x1 /b 32");
1484 
1485 BM_Conv2DWithBatchNormAndRelu(8, 32, 32, 128, 3, 3, 1024, cpu, "3x3 /b 8");
1486 BM_Conv2DWithBatchNormAndRelu(16, 32, 32, 128, 3, 3, 1024, cpu, "3x3 /b 16");
1487 BM_Conv2DWithBatchNormAndRelu(32, 32, 32, 128, 3, 3, 1024, cpu, "3x3 /b 32");
1488 
1489 BM_FusedConv2DWithBatchNorm(8, 32, 32, 128, 3, 3, 1024, cpu, "1x1 /b 8");
1490 BM_FusedConv2DWithBatchNorm(16, 32, 32, 128, 3, 3, 1024, cpu, "1x1 /b 16");
1491 BM_FusedConv2DWithBatchNorm(32, 32, 32, 128, 3, 3, 1024, cpu, "1x1 /b 32");
1492 
1493 BM_FusedConv2DWithBatchNormAndRelu(8, 32, 32, 128, 3, 3, 1024, cpu, "3x3 /b 8");
1494 BM_FusedConv2DWithBatchNormAndRelu(16, 32, 32, 128, 3, 3, 1024, cpu,
1495                                    "3x3 /b 16");
1496 BM_FusedConv2DWithBatchNormAndRelu(32, 32, 32, 128, 3, 3, 1024, cpu,
1497                                    "3x3 /b 32");
1498 
#if GOOGLE_CUDA
// GPU variants of the benchmarks above, run on the "gpu" benchmark device.
// -------------------------------------------------------------------------- //
// 1x1 Convolution
// -------------------------------------------------------------------------- //

BM_Conv2D(8, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 8");
BM_Conv2D(16, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 16");
BM_Conv2D(32, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 32");

BM_Conv2DWithBiasAndRelu(8, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 8");
BM_Conv2DWithBiasAndRelu(16, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 16");
BM_Conv2DWithBiasAndRelu(32, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 32");

BM_FusedConv2DWithBiasAndRelu(8, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 8");
BM_FusedConv2DWithBiasAndRelu(16, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 16");
BM_FusedConv2DWithBiasAndRelu(32, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 32");

// -------------------------------------------------------------------------- //
// 3x3 Convolution
// -------------------------------------------------------------------------- //

BM_Conv2D(8, 32, 32, 128, 3, 3, 1024, gpu, "3x3 /b 8");
BM_Conv2D(16, 32, 32, 128, 3, 3, 1024, gpu, "3x3 /b 16");
BM_Conv2D(32, 32, 32, 128, 3, 3, 1024, gpu, "3x3 /b 32");

BM_Conv2DWithBiasAndRelu(8, 32, 32, 128, 3, 3, 1024, gpu, "3x3 /b 8");
BM_Conv2DWithBiasAndRelu(16, 32, 32, 128, 3, 3, 1024, gpu, "3x3 /b 16");
BM_Conv2DWithBiasAndRelu(32, 32, 32, 128, 3, 3, 1024, gpu, "3x3 /b 32");

BM_FusedConv2DWithBiasAndRelu(8, 32, 32, 128, 3, 3, 1024, gpu, "3x3 /b 8");
BM_FusedConv2DWithBiasAndRelu(16, 32, 32, 128, 3, 3, 1024, gpu, "3x3 /b 16");
BM_FusedConv2DWithBiasAndRelu(32, 32, 32, 128, 3, 3, 1024, gpu, "3x3 /b 32");
#endif  // GOOGLE_CUDA
1532 
1533 }  // namespace tensorflow
1534