1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/generic_layout_optimizer.h"
17
18 #include "absl/memory/memory.h"
19 #include "absl/strings/string_view.h"
20 #include "tensorflow/cc/ops/array_ops.h"
21 #include "tensorflow/cc/ops/const_op.h"
22 #include "tensorflow/cc/ops/nn_ops.h"
23 #include "tensorflow/cc/ops/standard_ops.h"
24 #include "tensorflow/core/framework/function_testlib.h"
25 #include "tensorflow/core/framework/node_def_util.h"
26 #include "tensorflow/core/framework/tensor_testutil.h"
27 #include "tensorflow/core/grappler/clusters/cluster.h"
28 #include "tensorflow/core/grappler/clusters/single_machine.h"
29 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
30 #include "tensorflow/core/grappler/devices.h"
31 #include "tensorflow/core/grappler/grappler_item.h"
32 #include "tensorflow/core/grappler/utils/graph_view.h"
33 #include "tensorflow/core/grappler/utils/grappler_test.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35 #include "tensorflow/core/platform/test.h"
36
37 namespace tensorflow {
38 namespace grappler {
39
40 using ::tensorflow::Scope;
41 using ::tensorflow::ops::Conv2D;
42 using ::tensorflow::ops::Identity;
43 using ::tensorflow::ops::RandomUniform;
44
45 constexpr int kBatchSize = 32;
46 constexpr int kWidth = 10;
47 constexpr int kHeight = 10;
48 constexpr int kDepthIn = 8;
49 constexpr int kKernel = 3;
50 constexpr int kDepthOut = 16;
51
52 // When there is a GPU, we test generic_layout_optimization for the conversion
53 // from NHWC to NCHW format. When there is only CPU, we test the conversion
54 // from NCHW to NHWC format. The following macros help setting tensor shapes,
55 // source and destination format strings, and transpose permutation vectors
56 // appropriately for NHWC -> NCHW conversion (when GPU) and NCHW -> NHWC
57 // conversion (when only CPU).
58
59 #if (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
60 #define DIMS(n, h, w, c) \
61 { n, h, w, c }
62 #define SRC_DATA_FORMAT "NHWC"
63 #define DST_DATA_FORMAT "NCHW"
64 #define DEVICE "GPU"
65 #define REWRITER_CONFIG \
66 RewriterConfig::DEFAULT, RewriterConfig::NO_CONVERSION_ON_CPU
67 #define PERMUTATION_SRC_TO_DST \
68 { 0, 3, 1, 2 }
69 #define PERMUTATION_DST_TO_SRC \
70 { 0, 2, 3, 1 }
71 #else
72 #define DIMS(n, h, w, c) \
73 { n, c, h, w }
74 #define SRC_DATA_FORMAT "NCHW"
75 #define DST_DATA_FORMAT "NHWC"
76 #define DEVICE "CPU"
77 #define REWRITER_CONFIG RewriterConfig::DEFAULT, RewriterConfig::NCHW_TO_NHWC
78 #define PERMUTATION_SRC_TO_DST \
79 { 0, 2, 3, 1 }
80 #define PERMUTATION_DST_TO_SRC \
81 { 0, 3, 1, 2 }
82 #endif // (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
83
84 template <typename T = float>
SimpleConv2D(tensorflow::Scope * s,int input_size,int filter_size,const string & padding,const string & device)85 Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
86 const string& padding, const string& device) {
87 int batch_size = 8;
88 int input_height = input_size;
89 int input_width = input_size;
90 int input_depth = 3;
91 int filter_count = 2;
92 int stride = 1;
93 TensorShape input_shape(
94 DIMS(batch_size, input_height, input_width, input_depth));
95 Tensor input_data(DataTypeToEnum<T>::value, input_shape);
96 test::FillIota<T>(&input_data, static_cast<T>(1));
97 Output input =
98 ops::Const(s->WithOpName("Input"), Input::Initializer(input_data));
99
100 TensorShape filter_shape(
101 {filter_size, filter_size, input_depth, filter_count});
102 Tensor filter_data(DataTypeToEnum<T>::value, filter_shape);
103 test::FillIota<T>(&filter_data, static_cast<T>(1));
104 Output filter =
105 ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));
106
107 Output conv = ops::Conv2D(s->WithOpName("Conv2D").WithDevice(device), input,
108 filter, DIMS(1, stride, stride, 1), padding,
109 ops::Conv2D::Attrs().DataFormat(SRC_DATA_FORMAT));
110 return conv;
111 }
112
SimpleConv2DBackpropInput(tensorflow::Scope * s,int input_size,int filter_size,const string & padding,bool dilated,const int input_sizes_length)113 Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
114 int filter_size, const string& padding,
115 bool dilated, const int input_sizes_length) {
116 int batch_size = 128;
117 int input_height = input_size;
118 int input_width = input_size;
119 int input_depth = 3;
120 int filter_count = 2;
121 int stride = 1;
122 TensorShape input_sizes_shape({input_sizes_length});
123 Tensor input_data(DT_INT32, input_sizes_shape);
124 if (input_sizes_length == 4) {
125 test::FillValues<int>(
126 &input_data, DIMS(batch_size, input_height, input_width, input_depth));
127 } else {
128 test::FillValues<int>(&input_data, {input_height, input_width});
129 }
130 Output input_sizes =
131 ops::Const(s->WithOpName("InputSizes"), Input::Initializer(input_data));
132
133 TensorShape filter_shape(
134 {filter_size, filter_size, input_depth, filter_count});
135 Output filter =
136 ops::Variable(s->WithOpName("Filter"), filter_shape, DT_FLOAT);
137
138 int output_height = input_height;
139 int output_width = input_width;
140 TensorShape output_shape(
141 DIMS(batch_size, output_height, output_width, filter_count));
142 Tensor output_data(DT_FLOAT, output_shape);
143 test::FillIota<float>(&output_data, 1.0f);
144 Output output =
145 ops::Const(s->WithOpName("Output"), Input::Initializer(output_data));
146
147 Output conv_backprop_input;
148 Output input_sizes_i =
149 ops::Identity(s->WithOpName("InputSizesIdentity"), input_sizes);
150 ops::Conv2DBackpropInput::Attrs attrs;
151 attrs = attrs.DataFormat(SRC_DATA_FORMAT);
152 if (dilated) {
153 attrs = attrs.Dilations(DIMS(1, 2, 2, 1));
154 }
155 conv_backprop_input = ops::Conv2DBackpropInput(
156 s->WithOpName("Conv2DBackpropInput"), input_sizes_i, filter, output,
157 DIMS(1, stride, stride, 1), padding, attrs);
158
159 return conv_backprop_input;
160 }
161
162 class GenericLayoutOptimizerTest : public GrapplerTest {
163 protected:
SetUp()164 void SetUp() override {
165 bool gpu_available = GetNumAvailableGPUs() > 0;
166
167 if (gpu_available) {
168 virtual_cluster_ =
169 absl::make_unique<SingleMachine>(/*timeout_s=*/10, 1, 1);
170 } else {
171 DeviceProperties cpu_device;
172 cpu_device.set_type("CPU");
173 cpu_device.set_frequency(1000);
174 cpu_device.set_num_cores(4);
175 cpu_device.set_bandwidth(32);
176 cpu_device.set_l1_cache_size(32 * 1024);
177 cpu_device.set_l2_cache_size(256 * 1024);
178 cpu_device.set_l3_cache_size(4 * 1024 * 1024);
179 cpu_device.set_memory_size(1024 * 1024);
180 #if (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
181 DeviceProperties gpu_device;
182 gpu_device.set_type("GPU");
183 gpu_device.mutable_environment()->insert({"architecture", "6"});
184 virtual_cluster_ =
185 absl::WrapUnique(new VirtualCluster({{"/CPU:0", cpu_device},
186 { "/GPU:1",
187 gpu_device }}));
188 #else
189 virtual_cluster_ =
190 absl::WrapUnique(new VirtualCluster({{"/CPU:0", cpu_device}}));
191 #endif // (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
192 }
193 TF_ASSERT_OK(virtual_cluster_->Provision());
194 }
195
TearDown()196 void TearDown() override { TF_ASSERT_OK(virtual_cluster_->Shutdown()); }
197
198 std::unique_ptr<Cluster> virtual_cluster_;
199 };
200
VerifyRegularFaninMatch(const utils::NodeView * node,int port,absl::string_view fanin_name,int fanin_port)201 void VerifyRegularFaninMatch(const utils::NodeView* node, int port,
202 absl::string_view fanin_name, int fanin_port) {
203 ASSERT_GE(node->NumRegularFanins(), port);
204 const auto& fanin = node->GetRegularFanin(port);
205 EXPECT_EQ(fanin.node_view()->GetName(), fanin_name);
206 EXPECT_EQ(fanin.index(), fanin_port);
207 }
208
VerifyRegularFanoutMatch(const utils::NodeView * node,int port,absl::string_view fanout_name,int fanout_port)209 void VerifyRegularFanoutMatch(const utils::NodeView* node, int port,
210 absl::string_view fanout_name, int fanout_port) {
211 bool found = false;
212 for (const auto& regular_fanout : node->GetRegularFanout(port)) {
213 if (regular_fanout.node_view()->GetName() == fanout_name &&
214 regular_fanout.index() == fanout_port) {
215 found = true;
216 }
217 }
218 EXPECT_TRUE(found);
219 }
220
VerifyDataFormatAttributeMatch(const utils::NodeView * node,absl::string_view attr_value)221 void VerifyDataFormatAttributeMatch(const utils::NodeView* node,
222 absl::string_view attr_value) {
223 const auto* attr = node->GetAttr("data_format");
224 ASSERT_NE(attr, nullptr);
225 EXPECT_EQ(attr->s(), attr_value);
226 }
227
TEST_F(GenericLayoutOptimizerTest, OptimizeSimpleConv2DGraph) {
  // A simple graph contains 1 Conv2D node, 2 input and 1 output nodes.
  // Data format is NHWC on GPU, while NCHW on CPU.
  Scope scope = Scope::NewRootScope();

  auto conv2d = SimpleConv2D(&scope, 4, 2, "VALID", "");
  auto identity = Identity(scope.WithOpName("Output"), conv2d);
  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  Status status;
  utils::GraphView graph_view(&output, &status);
  TF_ASSERT_OK(status);
  // The expected optimized graph contains 2 extra sets of Transpose nodes and
  // has the Conv2D's data_format set to "NCHW" on GPU, while "NHWC" on CPU.
  // Transpose node names follow the layout optimizer's naming scheme:
  // "<consumer>-<port>-Transpose<src>To<dst>-LayoutOptimizer".
  auto* input_transpose_node = graph_view.GetNode(
      absl::StrCat("Conv2D-0-Transpose", SRC_DATA_FORMAT, "To", DST_DATA_FORMAT,
                   "-LayoutOptimizer"));

  // The inserted input transpose reads the original "Input" constant.
  ASSERT_NE(input_transpose_node, nullptr);
  ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node, 0, "Input", 0);

  // Conv2D now consumes the transposed input and the untouched filter, and
  // its data_format has been flipped to the destination format.
  auto* conv2d_node = graph_view.GetNode("Conv2D");
  ASSERT_NE(conv2d_node, nullptr);
  ASSERT_EQ(conv2d_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(conv2d_node, 0, input_transpose_node->GetName(), 0);
  VerifyRegularFaninMatch(conv2d_node, 1, "Filter", 0);
  VerifyDataFormatAttributeMatch(conv2d_node, DST_DATA_FORMAT);

  // The inverse transpose converts the Conv2D output back to the source
  // format before it reaches the graph's "Output" identity.
  auto* output_transpose_node = graph_view.GetNode(
      absl::StrCat("Conv2D-0-0-Transpose", DST_DATA_FORMAT, "To",
                   SRC_DATA_FORMAT, "-LayoutOptimizer"));
  ASSERT_NE(output_transpose_node, nullptr);
  ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(output_transpose_node, 0, conv2d_node->GetName(), 0);

  auto* output_node = graph_view.GetNode("Output");
  ASSERT_NE(output_node, nullptr);
  ASSERT_EQ(output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_node, 0, output_transpose_node->GetName(), 0);
}
274
// A node listed in the item's fetch set must be preserved: its data_format
// stays at the source format after optimization.
TEST_F(GenericLayoutOptimizerTest, PreserveFetch) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&scope, 4, 2, "VALID", "");
  auto identity = ops::Identity(scope.WithOpName("i"), conv);
  GrapplerItem item;
  item.fetch.push_back("Conv2D");
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  GraphDef optimized_graph;
  TF_ASSERT_OK(
      optimizer.Optimize(virtual_cluster_.get(), item, &optimized_graph));

  Status status;
  utils::GraphView graph_view(&optimized_graph, &status);
  TF_ASSERT_OK(status);
  auto* conv_node = graph_view.GetNode("Conv2D");
  ASSERT_NE(conv_node, nullptr);
  // Unchanged data format: the fetched node was not converted.
  VerifyDataFormatAttributeMatch(conv_node, SRC_DATA_FORMAT);
}
294
// With no device assignment on the Conv2D node, the optimizer converts it to
// the destination data format.
TEST_F(GenericLayoutOptimizerTest, EmptyDevice) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&scope, 4, 2, "VALID", "");
  Output fetch = ops::Identity(scope.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  GraphDef optimized_graph;
  TF_ASSERT_OK(
      optimizer.Optimize(virtual_cluster_.get(), item, &optimized_graph));

  Status status;
  utils::GraphView graph_view(&optimized_graph, &status);
  TF_ASSERT_OK(status);
  auto* conv_node = graph_view.GetNode("Conv2D");
  ASSERT_NE(conv_node, nullptr);
  VerifyDataFormatAttributeMatch(conv_node, DST_DATA_FORMAT);
}
313
// A Conv2D explicitly placed on a GPU device is converted to "NCHW" under
// the default optimizer settings. Skipped when neither CUDA nor ROCm is
// compiled in.
TEST_F(GenericLayoutOptimizerTest, GPUDevice) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&scope, 4, 2, "VALID",
                           "/job:w/replica:0/task:0/device:GPU:0");
  Output fetch = ops::Identity(scope.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  // Default-constructed optimizer, i.e. stock rewriter configuration.
  GenericLayoutOptimizer optimizer;
  GraphDef optimized_graph;
  TF_ASSERT_OK(
      optimizer.Optimize(virtual_cluster_.get(), item, &optimized_graph));

  Status status;
  utils::GraphView graph_view(&optimized_graph, &status);
  TF_ASSERT_OK(status);
  auto* conv_node = graph_view.GetNode("Conv2D");
  ASSERT_NE(conv_node, nullptr);
  VerifyDataFormatAttributeMatch(conv_node, "NCHW");
}
336
// Conv2D explicitly pinned to CPU. The expected result differs per build:
// in CUDA/ROCm builds the CPU-pinned node keeps "NHWC" (the source format),
// while in CPU-only builds the NCHW -> NHWC conversion applies and the node
// ends up in DST_DATA_FORMAT.
TEST_F(GenericLayoutOptimizerTest, CPUDevice) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID", "/CPU:0");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  Status status;
  utils::GraphView graph_view(&output, &status);
  TF_ASSERT_OK(status);
  auto* conv_node = graph_view.GetNode("Conv2D");
  ASSERT_NE(conv_node, nullptr);
#if (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  VerifyDataFormatAttributeMatch(conv_node, "NHWC");
#else
  VerifyDataFormatAttributeMatch(conv_node, DST_DATA_FORMAT);
#endif  // (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
}
359
// An int32 convolution is left untouched: its data_format remains the source
// format after optimization.
TEST_F(GenericLayoutOptimizerTest, NoOptimizeIntegerConvolution) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D<int32>(&scope, 4, 2, "VALID", "");
  Output fetch = ops::Identity(scope.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  GraphDef optimized_graph;
  TF_ASSERT_OK(
      optimizer.Optimize(virtual_cluster_.get(), item, &optimized_graph));

  Status status;
  utils::GraphView graph_view(&optimized_graph, &status);
  TF_ASSERT_OK(status);
  auto* conv_node = graph_view.GetNode("Conv2D");
  ASSERT_NE(conv_node, nullptr);
  VerifyDataFormatAttributeMatch(conv_node, SRC_DATA_FORMAT);
}
378
// Builds Conv2D -> i1 -> i2 -> i3 and then deliberately breaks topological
// order to exercise multi-hop connectivity detection in the optimizer.
TEST_F(GenericLayoutOptimizerTest, Connectivity) {
  Scope scope = Scope::NewRootScope();
  auto conv = SimpleConv2D(&scope, 4, 2, "VALID",
                           absl::StrCat("/device:", DEVICE, ":0"));
  auto i1 = ops::Identity(scope.WithOpName("i1"), conv);
  auto i2 = ops::Identity(scope.WithOpName("i2"), i1);
  auto i3 = ops::Identity(scope.WithOpName("i3"), i2);
  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  // Make the graph not in topological order to test the handling of multi-hop
  // connectivity (here we say two nodes are connected if all nodes in the
  // middle are layout agnostic). If the graph is already in topological order,
  // the problem is easier, where layout optimizer only needs to check
  // single-hop connectivity.
  Status status;
  utils::GraphView graph_view_original(&item.graph, &status);
  const int i1_index = graph_view_original.GetNode("i1")->node_index();
  const int i2_index = graph_view_original.GetNode("i2")->node_index();
  // Swap the positions of i1 and i2 inside the NodeDef list.
  item.graph.mutable_node()->SwapElements(i1_index, i2_index);

  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  // NOTE: `status` declared above is reused for the post-optimization view.
  utils::GraphView graph_view(&output, &status);
  TF_ASSERT_OK(status);
  auto* node_i2_output = graph_view.GetNode("i2");
  ASSERT_NE(node_i2_output, nullptr);
  // Layout optimizer should process i2, as it detects i2 is connected with the
  // Conv2D node two hops away. Similarly i1 is processed as well, as i1 is
  // directly connected to the Conv2D node.
  ASSERT_EQ(node_i2_output->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(node_i2_output, 0, "i1", 0);
}
413
// Conv2DBackpropInput whose input_sizes arrives through an Identity (not a
// direct Const fanin): the optimizer must insert a DataFormatVecPermute node
// in front of fanin 0. Exercised for both the 2-element ({height, width})
// and 4-element (full shape) input_sizes forms.
TEST_F(GenericLayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) {
  for (const int input_sizes_length : {2, 4}) {
    Scope s = Scope::NewRootScope();
    auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false,
                                          input_sizes_length);
    Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
    GrapplerItem item;
    TF_ASSERT_OK(s.ToGraphDef(&item.graph));

    GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
    GraphDef output;
    TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

    Status status;
    utils::GraphView graph_view(&output, &status);
    TF_ASSERT_OK(status);
    auto* conv2d_backprop_node = graph_view.GetNode("Conv2DBackpropInput");
    ASSERT_NE(conv2d_backprop_node, nullptr);
    ASSERT_EQ(conv2d_backprop_node->NumRegularFanins(), 3);
    // Fanin 0 (input_sizes) is now fed by the inserted DataFormatVecPermute.
    VerifyRegularFaninMatch(
        conv2d_backprop_node, 0,
        absl::StrCat("Conv2DBackpropInput-0-DataFormatVecPermute",
                     SRC_DATA_FORMAT, "To", DST_DATA_FORMAT,
                     "-LayoutOptimizer"),
        0);
    // The inserted node itself reads the original InputSizesIdentity.
    auto* input_sizes_node = graph_view.GetNode(absl::StrCat(
        "Conv2DBackpropInput-0-DataFormatVecPermute", SRC_DATA_FORMAT, "To",
        DST_DATA_FORMAT, "-LayoutOptimizer"));
    ASSERT_NE(input_sizes_node, nullptr);
    EXPECT_EQ(input_sizes_node->GetOp(), "DataFormatVecPermute");
    ASSERT_EQ(input_sizes_node->NumRegularFanins(), 1);
    VerifyRegularFaninMatch(input_sizes_node, 0, "InputSizesIdentity", 0);
  }
}
448
// Conv2D feeding Shape/Fill: after the optimizer expands and then collapses
// cancelling Transpose (T/T') and DataFormatVecPermute (D/D') pairs, only one
// transpose remains on each side of the graph.
TEST_F(GenericLayoutOptimizerTest, Conv2DDataFormatVecPermuteCollapse) {
  Scope scope =
      Scope::NewRootScope().WithDevice(absl::StrCat("/device:", DEVICE, ":0"));
  auto conv = SimpleConv2D(&scope, 4, 2, "VALID",
                           absl::StrCat("/device:", DEVICE, ":0"));
  auto shape = ops::Shape(scope.WithOpName("shape"), conv);
  auto value = ops::Const(scope.WithOpName("value"), 0, {});
  auto fill = ops::Fill(scope.WithOpName("fill"), shape, value);
  auto i = ops::Identity(scope.WithOpName("i"), fill);
  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  // Graph before optimization:
  // input -> conv2d -> shape -> fill -> output
  //
  // Graph after expansion:
  // input -> T -> conv2d -> T' -> T -> shape -> D' -> D -> fill -> T' -> output
  //
  // Graph after collapse:
  // input -> T -> conv2d -> shape -> fill -> T' -> output
  Status status;
  utils::GraphView graph_view(&output, &status);
  TF_ASSERT_OK(status);
  auto* conv2d_node = graph_view.GetNode("Conv2D");
  ASSERT_NE(conv2d_node, nullptr);
  ASSERT_EQ(conv2d_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(
      conv2d_node, 0,
      absl::StrCat("Conv2D-0-Transpose", SRC_DATA_FORMAT, "To", DST_DATA_FORMAT,
                   "-LayoutOptimizer"),
      0);

  // shape reads Conv2D directly: the T' -> T pair between them collapsed.
  auto* shape_node = graph_view.GetNode("shape");
  ASSERT_NE(shape_node, nullptr);
  ASSERT_EQ(shape_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(shape_node, 0, conv2d_node->GetName(), 0);

  // fill reads shape directly (D' -> D pair collapsed) and feeds the single
  // remaining output transpose.
  auto* fill_node = graph_view.GetNode("fill");
  ASSERT_NE(fill_node, nullptr);
  ASSERT_EQ(fill_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(fill_node, 0, shape_node->GetName(), 0);
  VerifyRegularFanoutMatch(
      fill_node, 0,
      absl::StrCat("fill-0-0-Transpose", DST_DATA_FORMAT, "To", SRC_DATA_FORMAT,
                   "-LayoutOptimizer"),
      0);

  auto* graph_output = graph_view.GetNode("i");
  ASSERT_NE(graph_output, nullptr);
  ASSERT_EQ(graph_output->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(
      graph_output, 0,
      absl::StrCat("fill-0-0-Transpose", DST_DATA_FORMAT, "To", SRC_DATA_FORMAT,
                   "-LayoutOptimizer"),
      0);
}
509
// Transpose pairs that were already present in the input graph (i.e. not
// inserted by the layout optimizer) must NOT be pruned, even though each pair
// cancels out. Only the optimizer's own transposes around bias_add are added.
TEST_F(GenericLayoutOptimizerTest, DoNotPruneNonAddedCancellableTransposes) {
  GrapplerItem item;
  {
    Scope scope = Scope::NewRootScope().WithDevice(
        absl::StrCat("/device:", DEVICE, ":0"));
    auto input = ops::RandomUniform(scope.WithOpName("input"),
                                    DIMS(kBatchSize, kHeight, kWidth, kDepthIn),
                                    DT_FLOAT);
    // Permutation for source to destination data format.
    // GPU: NHWC -> NCHW: {0, 3, 1, 2}
    // CPU: NCHW -> NHWC: {0, 2, 3, 1}
    auto input_in_transpose =
        ops::Transpose(scope.WithOpName("input_in_transpose"), input,
                       ops::Const(scope, PERMUTATION_SRC_TO_DST, {4}));
    // Permutation for destination to source data format.
    // GPU: NCHW -> NHWC: {0, 2, 3, 1}
    // CPU: NHWC -> NCHW: {0, 3, 1, 2}
    auto input_out_transpose = ops::Transpose(
        scope.WithOpName("input_out_transpose"), input_in_transpose,
        ops::Const(scope, PERMUTATION_DST_TO_SRC, {4}));
    Tensor bias_data(DT_FLOAT, TensorShape({kDepthIn}));
    test::FillIota<float>(&bias_data, 1.0f);
    // bias_add is the only layout-sensitive node the optimizer will rewrite.
    auto bias_add = ops::BiasAdd(
        scope.WithOpName("bias_add"), input_out_transpose, bias_data,
        ops::BiasAdd::Attrs().DataFormat(SRC_DATA_FORMAT));
    auto output_in_transpose =
        ops::Transpose(scope.WithOpName("output_in_transpose"), bias_add,
                       ops::Const(scope, PERMUTATION_SRC_TO_DST, {4}));
    auto output_out_transpose = ops::Transpose(
        scope.WithOpName("output_out_transpose"), output_in_transpose,
        ops::Const(scope, PERMUTATION_DST_TO_SRC, {4}));
    auto output =
        ops::Identity(scope.WithOpName("output"), output_out_transpose);
    TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  }

  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  Status status;
  utils::GraphView graph_view(&output, &status);
  TF_ASSERT_OK(status);

  auto* input_node = graph_view.GetNode("input");
  ASSERT_NE(input_node, nullptr);

  // The user's cancelling transpose pair before bias_add survives intact.
  auto* input_in_transpose_node = graph_view.GetNode("input_in_transpose");
  ASSERT_NE(input_in_transpose_node, nullptr);
  ASSERT_EQ(input_in_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_in_transpose_node, 0, input_node->GetName(), 0);

  auto* input_out_transpose_node = graph_view.GetNode("input_out_transpose");
  ASSERT_NE(input_out_transpose_node, nullptr);
  ASSERT_EQ(input_out_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_out_transpose_node, 0,
                          input_in_transpose_node->GetName(), 0);

  // The optimizer's own transpose inserted in front of bias_add.
  auto* bias_add_in_transpose_node = graph_view.GetNode(
      absl::StrCat("bias_add-0-Transpose", SRC_DATA_FORMAT, "To",
                   DST_DATA_FORMAT, "-LayoutOptimizer"));
  ASSERT_NE(bias_add_in_transpose_node, nullptr);
  ASSERT_EQ(bias_add_in_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(bias_add_in_transpose_node, 0,
                          input_out_transpose_node->GetName(), 0);

  auto* bias_add_node = graph_view.GetNode("bias_add");
  ASSERT_NE(bias_add_node, nullptr);
  ASSERT_EQ(bias_add_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(bias_add_node, 0,
                          bias_add_in_transpose_node->GetName(), 0);

  // The optimizer's own transpose inserted after bias_add.
  auto* bias_add_out_transpose_node = graph_view.GetNode(
      absl::StrCat("bias_add-0-0-Transpose", DST_DATA_FORMAT, "To",
                   SRC_DATA_FORMAT, "-LayoutOptimizer"));
  ASSERT_NE(bias_add_out_transpose_node, nullptr);
  ASSERT_EQ(bias_add_out_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(bias_add_out_transpose_node, 0,
                          bias_add_node->GetName(), 0);

  // The user's cancelling transpose pair after bias_add also survives.
  auto* output_in_transpose_node = graph_view.GetNode("output_in_transpose");
  ASSERT_NE(output_in_transpose_node, nullptr);
  ASSERT_EQ(output_in_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(output_in_transpose_node, 0,
                          bias_add_out_transpose_node->GetName(), 0);

  auto* output_out_transpose_node = graph_view.GetNode("output_out_transpose");
  ASSERT_NE(output_out_transpose_node, nullptr);
  ASSERT_EQ(output_out_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(output_out_transpose_node, 0,
                          output_in_transpose_node->GetName(), 0);

  auto* output_node = graph_view.GetNode("output");
  ASSERT_NE(output_node, nullptr);
  ASSERT_EQ(output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_node, 0, output_out_transpose_node->GetName(),
                          0);
}
608
// Under AGGRESSIVE settings, a Transpose -> Pad -> Transpose chain where the
// transposes cancel is rewritten in place: the paddings constant is permuted
// to match and all three Transpose nodes become Identity. Numeric output is
// verified to be unchanged.
TEST_F(GenericLayoutOptimizerTest, CancelTransposeAroundPad) {
  using test::function::NDef;

  GenericLayoutOptimizer optimizer(
      RewriterConfig::AGGRESSIVE,
      RewriterConfig::NCHW_TO_NHWC /* CPU settings*/);

  const Tensor kPermuteNhwcToNchw = test::AsTensor<int32>({0, 3, 1, 2});
  const Tensor kPermuteNchwToNhwc = test::AsTensor<int32>({0, 2, 3, 1});
  // 4x2 paddings: one (before, after) pair per dimension.
  const Tensor kPad = test::AsTensor<int32>({1, 2, 3, 4, 5, 6, 7, 8}, {4, 2});

  // x -> transpose_0 -> pad -> {transpose_1, transpose_2} (two consumers).
  GrapplerItem item;
  item.graph = test::function::GDef({
      NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}),

      NDef("paddings", "Const", {}, {{"dtype", DT_INT32}, {"value", kPad}}),
      NDef("perm_nhwc_to_nchw", "Const", {},
           {{"dtype", DT_INT32}, {"value", kPermuteNhwcToNchw}}),
      NDef("perm_nchw_to_nhwc", "Const", {},
           {{"dtype", DT_INT32}, {"value", kPermuteNchwToNhwc}}),

      NDef("transpose_0", "Transpose", {"x", "perm_nhwc_to_nchw"},
           {{"T", DT_FLOAT}, {"Tperm", DT_INT32}}),
      NDef("pad", "Pad", {"transpose_0", "paddings"},
           {{"T", DT_FLOAT}, {"Tpaddings", DT_INT32}}),
      NDef("transpose_1", "Transpose", {"pad", "perm_nchw_to_nhwc"},
           {{"T", DT_FLOAT}, {"Tperm", DT_INT32}}),
      NDef("transpose_2", "Transpose", {"pad", "perm_nchw_to_nhwc"},
           {{"T", DT_FLOAT}, {"Tperm", DT_INT32}}),
  });

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  // The paddings rows are permuted by {0, 3, 1, 2} (NHWC -> NCHW order).
  const Tensor kPermutedPaddings =
      test::AsTensor<int32>({1, 2, 5, 6, 7, 8, 3, 4}, {4, 2});

  GraphDef expected = test::function::GDef({
      NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}),

      NDef("paddings", "Const", {},
           {{"dtype", DT_INT32}, {"value", kPermutedPaddings}}),
      NDef("perm_nhwc_to_nchw", "Const", {},
           {{"dtype", DT_INT32}, {"value", kPermuteNhwcToNchw}}),
      NDef("perm_nchw_to_nhwc", "Const", {},
           {{"dtype", DT_INT32}, {"value", kPermuteNchwToNhwc}}),

      // Transpose nodes replaced by Identity nodes.
      NDef("transpose_0", "Identity", {"x"}, {{"T", DT_FLOAT}}),
      NDef("pad", "Pad", {"transpose_0", "paddings"},
           {{"T", DT_FLOAT}, {"Tpaddings", DT_INT32}}),
      NDef("transpose_1", "Identity", {"pad"}, {{"T", DT_FLOAT}}),
      NDef("transpose_2", "Identity", {"pad"}, {{"T", DT_FLOAT}}),
  });

  CompareGraphs(expected, output);

  // Evaluate both the original and the optimized graph on the same random
  // input; the rewrite must be numerically transparent.
  Tensor x = GenerateRandomTensor<DT_FLOAT>({2, 6, 6, 8});
  item.fetch = {"transpose_1", "transpose_2"};
  item.feed.emplace_back("x", x);
  auto tensors_expected = EvaluateFetchNodes(item);
  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);
  ASSERT_EQ(tensors.size(), 2);
  ASSERT_EQ(tensors_expected.size(), 2);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
  test::ExpectTensorEqual<float>(tensors_expected[1], tensors[1]);
}
677
// The "_output_shapes" attribute on a _Arg node must survive optimization.
TEST_F(GenericLayoutOptimizerTest, PreserveInputShapes) {
  using test::function::NDef;

  GenericLayoutOptimizer optimizer(RewriterConfig::AGGRESSIVE);

  // A rank-1 shape with a single unknown (-1) dimension.
  AttrValue output_shapes;
  auto* shape = output_shapes.mutable_list()->add_shape();
  shape->add_dim()->set_size(-1);

  // Graph consisting of a single _Arg node carrying that attribute.
  GrapplerItem item;
  item.graph = test::function::GDef({NDef(
      "x", "_Arg", {},
      {{"T", DT_FLOAT}, {"index", 0}, {"_output_shapes", output_shapes}})});

  GraphDef optimized_graph;
  TF_ASSERT_OK(
      optimizer.Optimize(virtual_cluster_.get(), item, &optimized_graph));

  Status status;
  utils::GraphView graph_view(&optimized_graph, &status);
  TF_ASSERT_OK(status);

  auto* arg_node = graph_view.GetNode("x");
  ASSERT_NE(arg_node, nullptr);
  EXPECT_TRUE(arg_node->HasAttr("_output_shapes"));
}
703
704 // TODO(yanzha): Add more complex Graph for test.
705
706 } // namespace grappler
707 } // namespace tensorflow
708