/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/generic_layout_optimizer.h"

#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils/graph_view.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace grappler {

using ::tensorflow::Scope;
using ::tensorflow::ops::Conv2D;
using ::tensorflow::ops::Identity;
using ::tensorflow::ops::RandomUniform;

constexpr int kBatchSize = 32;
constexpr int kWidth = 10;
constexpr int kHeight = 10;
constexpr int kDepthIn = 8;
constexpr int kKernel = 3;
constexpr int kDepthOut = 16;

// When there is a GPU, we test the generic layout optimization for the
// conversion from NHWC to NCHW format. When there is only a CPU, we test the
// conversion from NCHW to NHWC format. The following macros help set tensor
// shapes, source and destination format strings, and transpose permutation
// vectors appropriately for NHWC -> NCHW conversion (with a GPU) and for
// NCHW -> NHWC conversion (with only a CPU).

#if (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
#define DIMS(n, h, w, c) \
  { n, h, w, c }
#define SRC_DATA_FORMAT "NHWC"
#define DST_DATA_FORMAT "NCHW"
#define DEVICE "GPU"
#define REWRITER_CONFIG \
  RewriterConfig::DEFAULT, RewriterConfig::NO_CONVERSION_ON_CPU
#define PERMUTATION_SRC_TO_DST \
  { 0, 3, 1, 2 }
#define PERMUTATION_DST_TO_SRC \
  { 0, 2, 3, 1 }
#else
#define DIMS(n, h, w, c) \
  { n, c, h, w }
#define SRC_DATA_FORMAT "NCHW"
#define DST_DATA_FORMAT "NHWC"
#define DEVICE "CPU"
#define REWRITER_CONFIG RewriterConfig::DEFAULT, RewriterConfig::NCHW_TO_NHWC
#define PERMUTATION_SRC_TO_DST \
  { 0, 2, 3, 1 }
#define PERMUTATION_DST_TO_SRC \
  { 0, 3, 1, 2 }
#endif  // (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)

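// Builds a small graph containing a constant input, a constant filter, and a
// Conv2D node placed on `device` with its data_format set to SRC_DATA_FORMAT.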
template <typename T = float>
Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
                    const string& padding, const string& device) {
  int batch_size = 8;
  int input_height = input_size;
  int input_width = input_size;
  int input_depth = 3;
  int filter_count = 2;
  int stride = 1;
  TensorShape input_shape(
      DIMS(batch_size, input_height, input_width, input_depth));
  Tensor input_data(DataTypeToEnum<T>::value, input_shape);
  test::FillIota<T>(&input_data, static_cast<T>(1));
  Output input =
      ops::Const(s->WithOpName("Input"), Input::Initializer(input_data));

  TensorShape filter_shape(
      {filter_size, filter_size, input_depth, filter_count});
  Tensor filter_data(DataTypeToEnum<T>::value, filter_shape);
  test::FillIota<T>(&filter_data, static_cast<T>(1));
  Output filter =
      ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));

  Output conv = ops::Conv2D(s->WithOpName("Conv2D").WithDevice(device), input,
                            filter, DIMS(1, stride, stride, 1), padding,
                            ops::Conv2D::Attrs().DataFormat(SRC_DATA_FORMAT));
  return conv;
}

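// Builds a graph containing a Conv2DBackpropInput node with its data_format
// set to SRC_DATA_FORMAT. `input_sizes_length` selects whether the input
// sizes constant holds the full 4-D shape (4) or only the spatial dimensions
// (2). When `dilated` is true, a dilation of 2 is set on the spatial
// dimensions.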
Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
                                 int filter_size, const string& padding,
                                 bool dilated, const int input_sizes_length) {
  int batch_size = 128;
  int input_height = input_size;
  int input_width = input_size;
  int input_depth = 3;
  int filter_count = 2;
  int stride = 1;
  TensorShape input_sizes_shape({input_sizes_length});
  Tensor input_data(DT_INT32, input_sizes_shape);
  if (input_sizes_length == 4) {
    test::FillValues<int>(
        &input_data, DIMS(batch_size, input_height, input_width, input_depth));
  } else {
    test::FillValues<int>(&input_data, {input_height, input_width});
  }
  Output input_sizes =
      ops::Const(s->WithOpName("InputSizes"), Input::Initializer(input_data));

  TensorShape filter_shape(
      {filter_size, filter_size, input_depth, filter_count});
  Output filter =
      ops::Variable(s->WithOpName("Filter"), filter_shape, DT_FLOAT);

  int output_height = input_height;
  int output_width = input_width;
  TensorShape output_shape(
      DIMS(batch_size, output_height, output_width, filter_count));
  Tensor output_data(DT_FLOAT, output_shape);
  test::FillIota<float>(&output_data, 1.0f);
  Output output =
      ops::Const(s->WithOpName("Output"), Input::Initializer(output_data));

  Output conv_backprop_input;
  Output input_sizes_i =
      ops::Identity(s->WithOpName("InputSizesIdentity"), input_sizes);
  ops::Conv2DBackpropInput::Attrs attrs;
  attrs = attrs.DataFormat(SRC_DATA_FORMAT);
  if (dilated) {
    attrs = attrs.Dilations(DIMS(1, 2, 2, 1));
  }
  conv_backprop_input = ops::Conv2DBackpropInput(
      s->WithOpName("Conv2DBackpropInput"), input_sizes_i, filter, output,
      DIMS(1, stride, stride, 1), padding, attrs);

  return conv_backprop_input;
}

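// Test fixture that provisions the cluster used by the optimizer: a
// SingleMachine with one GPU when a GPU is available, otherwise a
// VirtualCluster built from simulated device descriptions.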
class GenericLayoutOptimizerTest : public GrapplerTest {
 protected:
  void SetUp() override {
    bool gpu_available = GetNumAvailableGPUs() > 0;

    if (gpu_available) {
      virtual_cluster_ =
          absl::make_unique<SingleMachine>(/*timeout_s=*/10, 1, 1);
    } else {
      DeviceProperties cpu_device;
      cpu_device.set_type("CPU");
      cpu_device.set_frequency(1000);
      cpu_device.set_num_cores(4);
      cpu_device.set_bandwidth(32);
      cpu_device.set_l1_cache_size(32 * 1024);
      cpu_device.set_l2_cache_size(256 * 1024);
      cpu_device.set_l3_cache_size(4 * 1024 * 1024);
      cpu_device.set_memory_size(1024 * 1024);
#if (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
      DeviceProperties gpu_device;
      gpu_device.set_type("GPU");
      gpu_device.mutable_environment()->insert({"architecture", "6"});
      virtual_cluster_ =
          absl::WrapUnique(new VirtualCluster(
              {{"/CPU:0", cpu_device}, {"/GPU:1", gpu_device}}));
#else
      virtual_cluster_ =
          absl::WrapUnique(new VirtualCluster({{"/CPU:0", cpu_device}}));
#endif  // (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
    }
    TF_ASSERT_OK(virtual_cluster_->Provision());
  }

  void TearDown() override { TF_ASSERT_OK(virtual_cluster_->Shutdown()); }

  std::unique_ptr<Cluster> virtual_cluster_;
};

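// Helpers for checking the optimized graph: the regular fanin of a node, the
// regular fanouts of a node, and its data_format attribute.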
void VerifyRegularFaninMatch(const utils::NodeView* node, int port,
                             absl::string_view fanin_name, int fanin_port) {
  ASSERT_GE(node->NumRegularFanins(), port + 1);
  const auto& fanin = node->GetRegularFanin(port);
  EXPECT_EQ(fanin.node_view()->GetName(), fanin_name);
  EXPECT_EQ(fanin.index(), fanin_port);
}

void VerifyRegularFanoutMatch(const utils::NodeView* node, int port,
                              absl::string_view fanout_name, int fanout_port) {
  bool found = false;
  for (const auto& regular_fanout : node->GetRegularFanout(port)) {
    if (regular_fanout.node_view()->GetName() == fanout_name &&
        regular_fanout.index() == fanout_port) {
      found = true;
    }
  }
  EXPECT_TRUE(found);
}

void VerifyDataFormatAttributeMatch(const utils::NodeView* node,
                                    absl::string_view attr_value) {
  const auto* attr = node->GetAttr("data_format");
  ASSERT_NE(attr, nullptr);
  EXPECT_EQ(attr->s(), attr_value);
}

TEST_F(GenericLayoutOptimizerTest, OptimizeSimpleConv2DGraph) {
  // A simple graph with one Conv2D node, two input nodes, and one output node.
  // The data format is NHWC on GPU and NCHW on CPU.
  Scope scope = Scope::NewRootScope();

  auto conv2d = SimpleConv2D(&scope, 4, 2, "VALID", "");
  auto identity = Identity(scope.WithOpName("Output"), conv2d);
  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  Status status;
  utils::GraphView graph_view(&output, &status);
  TF_ASSERT_OK(status);
  // The expected optimized graph contains two extra sets of Transpose nodes
  // and has the Conv2D's data_format set to "NCHW" on GPU and "NHWC" on CPU.
  auto* input_transpose_node = graph_view.GetNode(
      absl::StrCat("Conv2D-0-Transpose", SRC_DATA_FORMAT, "To", DST_DATA_FORMAT,
                   "-LayoutOptimizer"));

  ASSERT_NE(input_transpose_node, nullptr);
  ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node, 0, "Input", 0);

  auto* conv2d_node = graph_view.GetNode("Conv2D");
  ASSERT_NE(conv2d_node, nullptr);
  ASSERT_EQ(conv2d_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(conv2d_node, 0, input_transpose_node->GetName(), 0);
  VerifyRegularFaninMatch(conv2d_node, 1, "Filter", 0);
  VerifyDataFormatAttributeMatch(conv2d_node, DST_DATA_FORMAT);

  auto* output_transpose_node = graph_view.GetNode(
      absl::StrCat("Conv2D-0-0-Transpose", DST_DATA_FORMAT, "To",
                   SRC_DATA_FORMAT, "-LayoutOptimizer"));
  ASSERT_NE(output_transpose_node, nullptr);
  ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(output_transpose_node, 0, conv2d_node->GetName(), 0);

  auto* output_node = graph_view.GetNode("Output");
  ASSERT_NE(output_node, nullptr);
  ASSERT_EQ(output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_node, 0, output_transpose_node->GetName(), 0);
}

TEST_F(GenericLayoutOptimizerTest, PreserveFetch) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID", "");
  auto i = ops::Identity(s.WithOpName("i"), conv);
  GrapplerItem item;
  item.fetch.push_back("Conv2D");
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  Status status;
  utils::GraphView graph_view(&output, &status);
  TF_ASSERT_OK(status);
  auto* conv_node = graph_view.GetNode("Conv2D");
  ASSERT_NE(conv_node, nullptr);
  VerifyDataFormatAttributeMatch(conv_node, SRC_DATA_FORMAT);
}

TEST_F(GenericLayoutOptimizerTest, EmptyDevice) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID", "");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  Status status;
  utils::GraphView graph_view(&output, &status);
  TF_ASSERT_OK(status);
  auto* conv_node = graph_view.GetNode("Conv2D");
  ASSERT_NE(conv_node, nullptr);
  VerifyDataFormatAttributeMatch(conv_node, DST_DATA_FORMAT);
}

TEST_F(GenericLayoutOptimizerTest, GPUDevice) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv =
      SimpleConv2D(&s, 4, 2, "VALID", "/job:w/replica:0/task:0/device:GPU:0");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GenericLayoutOptimizer optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  Status status;
  utils::GraphView graph_view(&output, &status);
  TF_ASSERT_OK(status);
  auto* conv_node = graph_view.GetNode("Conv2D");
  ASSERT_NE(conv_node, nullptr);
  VerifyDataFormatAttributeMatch(conv_node, "NCHW");
}

TEST_F(GenericLayoutOptimizerTest, CPUDevice) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID", "/CPU:0");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  Status status;
  utils::GraphView graph_view(&output, &status);
  TF_ASSERT_OK(status);
  auto* conv_node = graph_view.GetNode("Conv2D");
  ASSERT_NE(conv_node, nullptr);
#if (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  VerifyDataFormatAttributeMatch(conv_node, "NHWC");
#else
  VerifyDataFormatAttributeMatch(conv_node, DST_DATA_FORMAT);
#endif  // (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
}

TEST_F(GenericLayoutOptimizerTest, NoOptimizeIntegerConvolution) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D<int32>(&s, 4, 2, "VALID", "");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  Status status;
  utils::GraphView graph_view(&output, &status);
  TF_ASSERT_OK(status);
  auto* conv_node = graph_view.GetNode("Conv2D");
  ASSERT_NE(conv_node, nullptr);
  VerifyDataFormatAttributeMatch(conv_node, SRC_DATA_FORMAT);
}

TEST_F(GenericLayoutOptimizerTest, Connectivity) {
  Scope scope = Scope::NewRootScope();
  auto conv = SimpleConv2D(&scope, 4, 2, "VALID",
                           absl::StrCat("/device:", DEVICE, ":0"));
  auto i1 = ops::Identity(scope.WithOpName("i1"), conv);
  auto i2 = ops::Identity(scope.WithOpName("i2"), i1);
  auto i3 = ops::Identity(scope.WithOpName("i3"), i2);
  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  // Shuffle the graph out of topological order to test the handling of
  // multi-hop connectivity (here two nodes are considered connected if all
  // nodes between them are layout agnostic). If the graph were already in
  // topological order, the problem would be easier: the layout optimizer
  // would only need to check single-hop connectivity.
  Status status;
  utils::GraphView graph_view_original(&item.graph, &status);
  const int i1_index = graph_view_original.GetNode("i1")->node_index();
  const int i2_index = graph_view_original.GetNode("i2")->node_index();
  item.graph.mutable_node()->SwapElements(i1_index, i2_index);

  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  utils::GraphView graph_view(&output, &status);
  TF_ASSERT_OK(status);
  auto* node_i2_output = graph_view.GetNode("i2");
  ASSERT_NE(node_i2_output, nullptr);
  // The layout optimizer should process i2, as it detects that i2 is
  // connected to the Conv2D node two hops away. Similarly, i1 is processed as
  // well, since i1 is directly connected to the Conv2D node.
  ASSERT_EQ(node_i2_output->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(node_i2_output, 0, "i1", 0);
}

TEST_F(GenericLayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) {
  for (const int input_sizes_length : {2, 4}) {
    Scope s = Scope::NewRootScope();
    auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false,
                                          input_sizes_length);
    Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
    GrapplerItem item;
    TF_ASSERT_OK(s.ToGraphDef(&item.graph));

    GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
    GraphDef output;
    TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

    Status status;
    utils::GraphView graph_view(&output, &status);
    TF_ASSERT_OK(status);
    auto* conv2d_backprop_node = graph_view.GetNode("Conv2DBackpropInput");
    ASSERT_NE(conv2d_backprop_node, nullptr);
    ASSERT_EQ(conv2d_backprop_node->NumRegularFanins(), 3);
    VerifyRegularFaninMatch(
        conv2d_backprop_node, 0,
        absl::StrCat("Conv2DBackpropInput-0-DataFormatVecPermute",
                     SRC_DATA_FORMAT, "To", DST_DATA_FORMAT,
                     "-LayoutOptimizer"),
        0);
    auto* input_sizes_node = graph_view.GetNode(absl::StrCat(
        "Conv2DBackpropInput-0-DataFormatVecPermute", SRC_DATA_FORMAT, "To",
        DST_DATA_FORMAT, "-LayoutOptimizer"));
    ASSERT_NE(input_sizes_node, nullptr);
    EXPECT_EQ(input_sizes_node->GetOp(), "DataFormatVecPermute");
    ASSERT_EQ(input_sizes_node->NumRegularFanins(), 1);
    VerifyRegularFaninMatch(input_sizes_node, 0, "InputSizesIdentity", 0);
  }
}

TEST_F(GenericLayoutOptimizerTest, Conv2DDataFormatVecPermuteCollapse) {
  Scope scope =
      Scope::NewRootScope().WithDevice(absl::StrCat("/device:", DEVICE, ":0"));
  auto conv = SimpleConv2D(&scope, 4, 2, "VALID",
                           absl::StrCat("/device:", DEVICE, ":0"));
  auto shape = ops::Shape(scope.WithOpName("shape"), conv);
  auto value = ops::Const(scope.WithOpName("value"), 0, {});
  auto fill = ops::Fill(scope.WithOpName("fill"), shape, value);
  auto i = ops::Identity(scope.WithOpName("i"), fill);
  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  // Graph before optimization:
  // input -> conv2d -> shape -> fill -> output
  //
  // Graph after expansion:
  // input -> T -> conv2d -> T' -> T -> shape -> D' -> D -> fill -> T' -> output
  //
  // Graph after collapsing:
  // input -> T -> conv2d -> shape -> fill -> T' -> output
  Status status;
  utils::GraphView graph_view(&output, &status);
  TF_ASSERT_OK(status);
  auto* conv2d_node = graph_view.GetNode("Conv2D");
  ASSERT_NE(conv2d_node, nullptr);
  ASSERT_EQ(conv2d_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(
      conv2d_node, 0,
      absl::StrCat("Conv2D-0-Transpose", SRC_DATA_FORMAT, "To", DST_DATA_FORMAT,
                   "-LayoutOptimizer"),
      0);

  auto* shape_node = graph_view.GetNode("shape");
  ASSERT_NE(shape_node, nullptr);
  ASSERT_EQ(shape_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(shape_node, 0, conv2d_node->GetName(), 0);

  auto* fill_node = graph_view.GetNode("fill");
  ASSERT_NE(fill_node, nullptr);
  ASSERT_EQ(fill_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(fill_node, 0, shape_node->GetName(), 0);
  VerifyRegularFanoutMatch(
      fill_node, 0,
      absl::StrCat("fill-0-0-Transpose", DST_DATA_FORMAT, "To", SRC_DATA_FORMAT,
                   "-LayoutOptimizer"),
      0);

  auto* graph_output = graph_view.GetNode("i");
  ASSERT_NE(graph_output, nullptr);
  ASSERT_EQ(graph_output->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(
      graph_output, 0,
      absl::StrCat("fill-0-0-Transpose", DST_DATA_FORMAT, "To", SRC_DATA_FORMAT,
                   "-LayoutOptimizer"),
      0);
}

TEST_F(GenericLayoutOptimizerTest, DoNotPruneNonAddedCancellableTransposes) {
  GrapplerItem item;
  {
    Scope scope = Scope::NewRootScope().WithDevice(
        absl::StrCat("/device:", DEVICE, ":0"));
    auto input = ops::RandomUniform(scope.WithOpName("input"),
                                    DIMS(kBatchSize, kHeight, kWidth, kDepthIn),
                                    DT_FLOAT);
    // Permutation from the source to the destination data format.
    // GPU: NHWC -> NCHW: {0, 3, 1, 2}
    // CPU: NCHW -> NHWC: {0, 2, 3, 1}
    auto input_in_transpose =
        ops::Transpose(scope.WithOpName("input_in_transpose"), input,
                       ops::Const(scope, PERMUTATION_SRC_TO_DST, {4}));
    // Permutation from the destination to the source data format.
    // GPU: NCHW -> NHWC: {0, 2, 3, 1}
    // CPU: NHWC -> NCHW: {0, 3, 1, 2}
    auto input_out_transpose = ops::Transpose(
        scope.WithOpName("input_out_transpose"), input_in_transpose,
        ops::Const(scope, PERMUTATION_DST_TO_SRC, {4}));
    Tensor bias_data(DT_FLOAT, TensorShape({kDepthIn}));
    test::FillIota<float>(&bias_data, 1.0f);
    auto bias_add = ops::BiasAdd(
        scope.WithOpName("bias_add"), input_out_transpose, bias_data,
        ops::BiasAdd::Attrs().DataFormat(SRC_DATA_FORMAT));
    auto output_in_transpose =
        ops::Transpose(scope.WithOpName("output_in_transpose"), bias_add,
                       ops::Const(scope, PERMUTATION_SRC_TO_DST, {4}));
    auto output_out_transpose = ops::Transpose(
        scope.WithOpName("output_out_transpose"), output_in_transpose,
        ops::Const(scope, PERMUTATION_DST_TO_SRC, {4}));
    auto output =
        ops::Identity(scope.WithOpName("output"), output_out_transpose);
    TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  }

  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  Status status;
  utils::GraphView graph_view(&output, &status);
  TF_ASSERT_OK(status);

  auto* input_node = graph_view.GetNode("input");
  ASSERT_NE(input_node, nullptr);

  auto* input_in_transpose_node = graph_view.GetNode("input_in_transpose");
  ASSERT_NE(input_in_transpose_node, nullptr);
  ASSERT_EQ(input_in_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_in_transpose_node, 0, input_node->GetName(), 0);

  auto* input_out_transpose_node = graph_view.GetNode("input_out_transpose");
  ASSERT_NE(input_out_transpose_node, nullptr);
  ASSERT_EQ(input_out_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_out_transpose_node, 0,
                          input_in_transpose_node->GetName(), 0);

  auto* bias_add_in_transpose_node = graph_view.GetNode(
      absl::StrCat("bias_add-0-Transpose", SRC_DATA_FORMAT, "To",
                   DST_DATA_FORMAT, "-LayoutOptimizer"));
  ASSERT_NE(bias_add_in_transpose_node, nullptr);
  ASSERT_EQ(bias_add_in_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(bias_add_in_transpose_node, 0,
                          input_out_transpose_node->GetName(), 0);

  auto* bias_add_node = graph_view.GetNode("bias_add");
  ASSERT_NE(bias_add_node, nullptr);
  ASSERT_EQ(bias_add_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(bias_add_node, 0,
                          bias_add_in_transpose_node->GetName(), 0);

  auto* bias_add_out_transpose_node = graph_view.GetNode(
      absl::StrCat("bias_add-0-0-Transpose", DST_DATA_FORMAT, "To",
                   SRC_DATA_FORMAT, "-LayoutOptimizer"));
  ASSERT_NE(bias_add_out_transpose_node, nullptr);
  ASSERT_EQ(bias_add_out_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(bias_add_out_transpose_node, 0,
                          bias_add_node->GetName(), 0);

  auto* output_in_transpose_node = graph_view.GetNode("output_in_transpose");
  ASSERT_NE(output_in_transpose_node, nullptr);
  ASSERT_EQ(output_in_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(output_in_transpose_node, 0,
                          bias_add_out_transpose_node->GetName(), 0);

  auto* output_out_transpose_node = graph_view.GetNode("output_out_transpose");
  ASSERT_NE(output_out_transpose_node, nullptr);
  ASSERT_EQ(output_out_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(output_out_transpose_node, 0,
                          output_in_transpose_node->GetName(), 0);

  auto* output_node = graph_view.GetNode("output");
  ASSERT_NE(output_node, nullptr);
  ASSERT_EQ(output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_node, 0, output_out_transpose_node->GetName(),
                          0);
}

TEST_F(GenericLayoutOptimizerTest, CancelTransposeAroundPad) {
  using test::function::NDef;

  GenericLayoutOptimizer optimizer(
      RewriterConfig::AGGRESSIVE,
      RewriterConfig::NCHW_TO_NHWC /* CPU settings */);

  const Tensor kPermuteNhwcToNchw = test::AsTensor<int32>({0, 3, 1, 2});
  const Tensor kPermuteNchwToNhwc = test::AsTensor<int32>({0, 2, 3, 1});
  const Tensor kPad = test::AsTensor<int32>({1, 2, 3, 4, 5, 6, 7, 8}, {4, 2});

  GrapplerItem item;
  item.graph = test::function::GDef({
      NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}),

      NDef("paddings", "Const", {}, {{"dtype", DT_INT32}, {"value", kPad}}),
      NDef("perm_nhwc_to_nchw", "Const", {},
           {{"dtype", DT_INT32}, {"value", kPermuteNhwcToNchw}}),
      NDef("perm_nchw_to_nhwc", "Const", {},
           {{"dtype", DT_INT32}, {"value", kPermuteNchwToNhwc}}),

      NDef("transpose_0", "Transpose", {"x", "perm_nhwc_to_nchw"},
           {{"T", DT_FLOAT}, {"Tperm", DT_INT32}}),
      NDef("pad", "Pad", {"transpose_0", "paddings"},
           {{"T", DT_FLOAT}, {"Tpaddings", DT_INT32}}),
      NDef("transpose_1", "Transpose", {"pad", "perm_nchw_to_nhwc"},
           {{"T", DT_FLOAT}, {"Tperm", DT_INT32}}),
      NDef("transpose_2", "Transpose", {"pad", "perm_nchw_to_nhwc"},
           {{"T", DT_FLOAT}, {"Tperm", DT_INT32}}),
  });

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  const Tensor kPermutedPaddings =
      test::AsTensor<int32>({1, 2, 5, 6, 7, 8, 3, 4}, {4, 2});

  GraphDef expected = test::function::GDef({
      NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}),

      NDef("paddings", "Const", {},
           {{"dtype", DT_INT32}, {"value", kPermutedPaddings}}),
      NDef("perm_nhwc_to_nchw", "Const", {},
           {{"dtype", DT_INT32}, {"value", kPermuteNhwcToNchw}}),
      NDef("perm_nchw_to_nhwc", "Const", {},
           {{"dtype", DT_INT32}, {"value", kPermuteNchwToNhwc}}),

      // Transpose nodes replaced by Identity nodes.
      NDef("transpose_0", "Identity", {"x"}, {{"T", DT_FLOAT}}),
      NDef("pad", "Pad", {"transpose_0", "paddings"},
           {{"T", DT_FLOAT}, {"Tpaddings", DT_INT32}}),
      NDef("transpose_1", "Identity", {"pad"}, {{"T", DT_FLOAT}}),
      NDef("transpose_2", "Identity", {"pad"}, {{"T", DT_FLOAT}}),
  });

  CompareGraphs(expected, output);

  Tensor x = GenerateRandomTensor<DT_FLOAT>({2, 6, 6, 8});
  item.fetch = {"transpose_1", "transpose_2"};
  item.feed.emplace_back("x", x);
  auto tensors_expected = EvaluateFetchNodes(item);
  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);
  ASSERT_EQ(tensors.size(), 2);
  ASSERT_EQ(tensors_expected.size(), 2);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
  test::ExpectTensorEqual<float>(tensors_expected[1], tensors[1]);
}

TEST_F(GenericLayoutOptimizerTest, PreserveInputShapes) {
  using test::function::NDef;

  GenericLayoutOptimizer optimizer(RewriterConfig::AGGRESSIVE);

  AttrValue output_shapes;
  auto* shape = output_shapes.mutable_list()->add_shape();
  shape->add_dim()->set_size(-1);

  GrapplerItem item;
  item.graph = test::function::GDef({NDef(
      "x", "_Arg", {},
      {{"T", DT_FLOAT}, {"index", 0}, {"_output_shapes", output_shapes}})});

  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  Status status;
  utils::GraphView graph_view(&output, &status);
  TF_ASSERT_OK(status);

  auto* arg = graph_view.GetNode("x");
  ASSERT_NE(arg, nullptr);
  EXPECT_TRUE(arg->HasAttr("_output_shapes"));
}

// TODO(yanzha): Add more complex graphs for testing.

}  // namespace grappler
}  // namespace tensorflow