1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // TODO(intel): Improve error handling in this file; instead of CHECK failing
17 // all over the place, we should log an error and execute the original graph.
18 #ifdef INTEL_MKL
19 
20 #include <algorithm>
21 #include <functional>
22 #include <memory>
23 #include <queue>
24 #include <set>
25 #include <stack>
26 #include <tuple>
27 #include <unordered_set>
28 #include <utility>
29 #include <vector>
30 
31 #include "tensorflow/core/common_runtime/function.h"
32 #include "tensorflow/core/common_runtime/optimization_registry.h"
33 #include "tensorflow/core/framework/node_def_util.h"
34 #include "tensorflow/core/framework/tensor.pb.h"
35 #include "tensorflow/core/graph/algorithm.h"
36 #include "tensorflow/core/graph/graph.h"
37 #include "tensorflow/core/graph/node_builder.h"
38 #include "tensorflow/core/lib/core/status.h"
39 #include "tensorflow/core/lib/gtl/array_slice.h"
40 #include "tensorflow/core/lib/gtl/map_util.h"
41 #include "tensorflow/core/lib/hash/hash.h"
42 #include "tensorflow/core/platform/logging.h"
43 #include "tensorflow/core/util/tensor_format.h"
44 #include "tensorflow/core/util/util.h"
45 
46 #include "tensorflow/core/graph/mkl_graph_util.h"
47 #include "tensorflow/core/graph/mkl_layout_pass.h"
48 
49 namespace tensorflow {
50 
// This pass implements rewriting of the graph to support the following
// scenarios:
// (A) Merging nodes in the graph
// (B) Rewriting a node in the graph to a new node
//     Rewriting happens in the following scenario:
//     - Propagating Mkl layout as an additional output tensor
//        (we will loosely call a tensor that carries Mkl layout an Mkl tensor
//         henceforth.) from every Mkl supported NN layer.
58 //
59 // Example of A : Merging nodes in the graph
60 // -----------------------------------------
// Currently, we merge Conv2D+BiasAdd together. Consider Conv2D and BiasAdd as:
62 //
63 //           O = Conv2D(A, B)
64 //           P = BiasAdd(O, C)
65 //
// We merge them into Conv2DWithBias as (using the default contiguous ordering
// of Tensorflow and Mkl tensors described in B.1):
//           P = _MklConv2DWithBias(A, B, C, A_m, B_m, C_m)
68 //
69 // The meaning of A_m, B_m and C_m is explained in B.1.
70 //
71 // Merge rules:
72 //  - The merge for Conv2D and BiasAdd happens when the output of Conv2D _only_
73 //    goes to BiasAdd.
74 //  - Also, the intersection of attributes of both the nodes must have same
75 //    values.
76 //  - Both the nodes must have been assigned to same device (if any).
77 //
78 // Example of B.1 : Rewriting nodes to Mkl nodes
79 // ---------------------------------------------
80 // Consider a Relu node. Current definition of Relu node looks like:
81 //
82 //           O = Relu(A)
83 //
84 // Relu has 1 input (A), and 1 output (O).
85 //
86 // This rewrite pass will generate a new graph node for Relu (new node is
87 // called MklRelu) as:
88 //
89 //          O, O_m = MklRelu(A, A_m)
90 //
91 // MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here input A is
92 // same as input A of Relu; output O is same as output O of Relu. O_m is the
93 // additional output tensor that will be set by MklRelu, and it represents
94 // Mkl tensor corresponding to O -- in other words, O_m is some kind of
95 // metadata for O. A_m is additional input of Relu, and it represents metadata
96 // for A - as O_m is metadata for O, A_m is metadata for A. MklRelu receives
97 // this metadata from previous node in the graph.
98 //
99 // When a previous node in the graph is an Mkl node, A_m will represent a valid
100 // Mkl tensor. But when a previous node is not an Mkl node, A_m will represent
101 // a dummy Mkl tensor.
102 //
103 // Rewriting rules:
//  - Selection of a node for rewriting happens by registering the op type of
//    the node with the rewriting pass. If the op type is not registered, then
//    no node of this op type will be rewritten.
107 //  - Number of inputs after rewriting:
108 //      Since for every input Tensorflow tensor, the rewritten node gets Mkl
109 //      tensor(s), rewritten node gets 2*N inputs, where N is the number of
110 //      inputs for the original node.
111 //  - Number of outputs after rewriting:
112 //      Since for every output Tensorflow tensor, the rewritten node generates
113 //      Mkl tensor(s), the rewritten node generates 2*N outputs, where N is the
114 //      number of outputs of the original node.
115 //  - Ordering of Tensorflow tensors and Mkl tensors:
//      Since every rewritten node has twice the number of inputs and
//      outputs, one could imagine various orderings among Tensorflow tensors
//      and Mkl tensors. E.g., assume an op 'Conv2D' that takes (A, B) as
//      inputs, then the new op '_MklConv2D' can take inputs A, B, A_m and B_m
//      in A, A_m, B, B_m order or it can also take them in A, B, A_m, B_m
//      order. With N original inputs (2N tensors in total) there are (2N)!
//      possible orderings.
122 //
123 //      So the question is: which order do we follow? We support 2 types of
124 //      orderings: (1) interleaved, and (2) contiguous. Interleaved ordering
125 //      follows an intuitive order where an Mkl tensor follows the
126 //      corresponding Tensorflow tensor immediately. In the context of the
127 //      above example, it will be: A, A_m, B, B_m. Note that the ordering rule
128 //      applies to both the inputs and outputs. Contiguous ordering means
129 //      all the Tensorflow tensors are contiguous followed by all the Mkl
130 //      tensors. We use contiguous ordering as default.
131 //
132 // Graph rewrite algorithm:
133 //      Algorithm: Graph Rewrite
134 //      Input: Graph G, Names of the nodes to rewrite and their new names
135 //      Output: Modified Graph G' if the nodes are modified, G otherwise.
136 //      Start:
137 //        N = Topological_Sort(G) // N is a set of nodes in toposort order.
138 //        foreach node n in N
139 //        do
140 //          if (Is_MKL_Op(n))  // Can this node accept an Mkl layout as input.
141 //          then
142 //            E = set of <incoming edge and its src_output slot> of n
143 //            E' = {}   // a new set of edges for rewritten node
144 //            foreach <e,s> in E
145 //            do
146 //              E' U {<e,s>}  // First copy edge which generates Tensorflow
147 //                            // tensor as it is
148 //              m = Source node of edge e
149 //              if Is_Rewritten(m)  // Did we rewrite this node in this pass?
150 //              then
151 //                E' U {<m,s+1>}    // If yes, then m will generate an Mkl
152 //                                  // tensor as an additional output.
153 //              else
154 //                d = Generate_Dummy_Mkl_Tensor()  // If not, generate a dummy
155 //                                                 // Mkl tensor.
156 //                E' U {<d,0>}  // The dummy Mkl tensor has only 1 output slot.
157 //              fi
158 //            done
159 //            n' = Build_New_Node(G,new_name,E')
160 //            Mark_Rewritten(n')  // Mark the new node as being rewritten.
161 //          fi
162 //        done
163 //
164 //      Explanation:
165 //        For graph rewrite, we visit nodes of the input graph in the
166 //        topological sort order. With this ordering, we visit nodes in the
167 //        top-to-bottom fashion. We need this order because while visiting a
168 //        node we want that all of its input nodes are visited and rewritten if
169 //        applicable. This is because if we need to rewrite a given node
170 //        then all of its input nodes need to be fixed (in other words they
171 //        cannot be deleted later.)
172 //
173 //        While visiting a node, we first check if the op type of the node is
174 //        an Mkl op. If it is, then we rewrite that node after constructing
175 //        new inputs to the node. If the op type of the node is not Mkl op,
176 //        then we do not rewrite that node.
177 //
178 // Handling workspace propagation for certain ops:
179 //
180 //        Certain backward ops in MKL (MaxPool, LRN and BatchNorm) require
181 //        passing of a workspace from their respective forward ops. Workspace
182 //        tensors provide memory for storing results of intermediate operations
183 //        which are helpful in backward propagation. TensorFlow does not have
184 //        a notion of a workspace and as a result does not allow producing
185 //        additional outputs from these forward ops. For these ops, we need
186 //        to add 2 extra edges between forward ops and their corresponding
187 //        backward ops - the first extra edge carries a workspace tensor and
188 //        the second one carries an Mkl tensor for the workspace tensor.
189 //
190 //        Example:
191 //
192 //        Typical graph for MaxPool and its gradient looks like:
193 //
194 //        A = MaxPool(T)
195 //        B = MaxPoolGrad(X, A, Y)
196 //
197 //        We will transform this graph to propagate the workspace as:
198 //        (with the contiguous ordering)
199 //
200 //        A, W, A_m, W_m = MklMaxPool(T, T_m)
201 //        B, B_m = MklMaxPoolGrad(X, A, Y, W, X_m, A_m, Y_m, W_m)
202 //
203 //        Here W is the workspace tensor. Transformed tensor names with the
204 //        suffix _m are Mkl tensors, and this transformation has been done
205 //        using the algorithm discussed earlier. The transformation for
206 //        workspace propagation only adds extra outputs (W, W_m) for a forward
207 //        op and connects them to the corresponding backward ops.
208 //
209 //        Terms:
210 //
211 //        Forward op name = name of the op in the forward pass
212 //          where a workspace tensor originates (MaxPool in this example)
213 //        Backward op name = name of the op in the backward pass that receives
214 //          a workspace tensor from the forward op (MaxPoolGrad in the example)
215 //        Slot = Position of the output or input slot that will be
216 //               used by the workspace tensor (1 for MklMaxPool as W is the 2nd
217 //               output of MaxPool (0 is 1st); 3 for MklMaxPoolGrad)
218 //
219 //        Question:
220 //
//        How do we associate a backward op with a forward op? There can be
//        more than one node with the same op type.
//
//        In this example, we associate MaxPoolGrad with MaxPool. But there
//        could be more than one MaxPool op. To solve this problem, we look
//        for a _direct_ edge between a forward op and a backward op (tensor A
//        is flowing along this edge in the example).
228 //
229 //        How do we transform forward and backward ops when there is no direct
230 //        edge between them? In such a case, we generate dummy tensors for
231 //        workspace tensors. For the example, transformation of MaxPool will
232 //        be exactly same as it would be when there is a direct edge between
233 //        the forward and the backward op --- it is just that MaxPool won't
234 //        generate any workspace tensor. For MaxPoolGrad, the transformation
235 //        will also be same, but instead of connecting W and W_m with the
236 //        outputs of MaxPool, we will produce dummy tensors for them, and we
237 //        will set workspace_enabled attribute to false.
238 //
239 class MklLayoutRewritePass : public GraphOptimizationPass {
240  public:
MklLayoutRewritePass()241   MklLayoutRewritePass() {
242     // NOTE: names are alphabetically sorted.
243     csinfo_.addn = "AddN";
244     csinfo_.avg_pool = "AvgPool";
245     csinfo_.avg_pool_grad = "AvgPoolGrad";
246     csinfo_.avg_pool3d = "AvgPool3D";
247     csinfo_.avg_pool3d_grad = "AvgPool3DGrad";
248     csinfo_.bias_add = "BiasAdd";
249     csinfo_.bias_add_grad = "BiasAddGrad";
250     csinfo_.concat = "Concat";
251     csinfo_.concatv2 = "ConcatV2";
252     csinfo_.conv2d = "Conv2D";
253     csinfo_.conv2d_with_bias = "__MklDummyConv2DWithBias";
254     csinfo_.conv2d_grad_input = "Conv2DBackpropInput";
255     csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
256     csinfo_.conv2d_grad_filter_with_bias =
257         "__MklDummyConv2DBackpropFilterWithBias";
258     csinfo_.conv3d = "Conv3D";
259     csinfo_.conv3d_grad_input = "Conv3DBackpropInputV2";
260     csinfo_.conv3d_grad_filter = "Conv3DBackpropFilterV2";
261     csinfo_.depthwise_conv2d = "DepthwiseConv2dNative";
262     csinfo_.depthwise_conv2d_grad_input = "DepthwiseConv2dNativeBackpropInput";
263     csinfo_.depthwise_conv2d_grad_filter =
264         "DepthwiseConv2dNativeBackpropFilter";
265     csinfo_.fused_batch_norm = "FusedBatchNorm";
266     csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
267     csinfo_.fused_conv2d = "_FusedConv2D";
268     csinfo_.identity = "Identity";
269     csinfo_.leakyrelu = "LeakyRelu";
270     csinfo_.leakyrelu_grad = "LeakyReluGrad";
271     csinfo_.lrn = "LRN";
272     csinfo_.lrn_grad = "LRNGrad";
273     csinfo_.matmul = "MatMul";
274     csinfo_.max_pool = "MaxPool";
275     csinfo_.max_pool_grad = "MaxPoolGrad";
276     csinfo_.max_pool3d = "MaxPool3D";
277     csinfo_.max_pool3d_grad = "MaxPool3DGrad";
278     csinfo_.mkl_conv2d = "_MklConv2D";
279     csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput";
280     csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
281     csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
282     csinfo_.mkl_conv2d_grad_filter_with_bias =
283         "_MklConv2DBackpropFilterWithBias";
284     csinfo_.mkl_depthwise_conv2d_grad_input =
285         "_MklDepthwiseConv2dNativeBackpropInput";
286     csinfo_.mkl_depthwise_conv2d_grad_filter =
287         "_MklDepthwiseConv2dNativeBackpropFilter";
288     csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
289     csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
290     csinfo_.mkl_pad_with_fused_conv2d = "_MklPadWithFusedConv2D";
291     csinfo_.pad = "Pad";
292     csinfo_.pad_with_conv2d = "__MklDummyPadWithConv2D";
293     csinfo_.pad_with_fused_conv2d = "__MklDummyPadWithFusedConv2D";
294     csinfo_.quantized_avg_pool = "QuantizedAvgPool";
295     csinfo_.quantized_concatv2 = "QuantizedConcatV2";
296     csinfo_.quantized_conv2d = "QuantizedConv2D";
297     csinfo_.quantized_conv2d_with_requantize = "QuantizedConv2DAndRequantize";
298     csinfo_.quantized_conv2d_with_bias = "QuantizedConv2DWithBias";
299     csinfo_.quantized_conv2d_with_bias_and_requantize =
300         "QuantizedConv2DWithBiasAndRequantize";
301     csinfo_.quantized_conv2d_and_relu = "QuantizedConv2DAndRelu";
302     csinfo_.quantized_conv2d_and_relu_and_requantize =
303         "QuantizedConv2DAndReluAndRequantize";
304     csinfo_.quantized_conv2d_with_bias_and_relu =
305         "QuantizedConv2DWithBiasAndRelu";
306     csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize =
307         "QuantizedConv2DWithBiasAndReluAndRequantize";
308     csinfo_.quantized_max_pool = "QuantizedMaxPool";
309     csinfo_.quantized_conv2d_with_bias_sum_and_relu =
310         "QuantizedConv2DWithBiasSumAndRelu";
311     csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize =
312         "QuantizedConv2DWithBiasSumAndReluAndRequantize";
313     csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize =
314         "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize";
315     csinfo_.relu = "Relu";
316     csinfo_.relu_grad = "ReluGrad";
317     csinfo_.relu6 = "Relu6";
318     csinfo_.relu6_grad = "Relu6Grad";
319     csinfo_.requantize = "Requantize";
320     csinfo_.tanh = "Tanh";
321     csinfo_.tanh_grad = "TanhGrad";
322     csinfo_.reshape = "Reshape";
323     csinfo_.slice = "Slice";
324     csinfo_.softmax = "Softmax";
325     csinfo_.split = "Split";
326     csinfo_.transpose = "Transpose";
327     // Element-wise ops. Ensure you also add any new ops to IsOpElementWise
328     // in the MklUtil.h (IsMklElementWiseOp method) to ensure that the
329     // MklInputConversion op is added before it.
330     csinfo_.add = "Add";
331     csinfo_.maximum = "Maximum";
332     csinfo_.mul = "Mul";
333     csinfo_.squared_difference = "SquaredDifference";
334     csinfo_.sub = "Sub";
335     // End - element-wise ops. See note above.
336 
337     // NOTE: names are alphabetically sorted.
338     rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
339                       CopyAttrsAddN, AddNRewrite});
340     rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
341                       CopyAttrsDataType, AlwaysRewrite});
342     rinfo_.push_back({csinfo_.avg_pool,
343                       mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
344                       CopyAttrsPooling, AlwaysRewrite});
345     rinfo_.push_back({csinfo_.avg_pool_grad,
346                       mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad),
347                       CopyAttrsPooling, AlwaysRewrite});
348     rinfo_.push_back({csinfo_.avg_pool3d,
349                       mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d),
350                       CopyAttrsPooling, AlwaysRewrite});
351     rinfo_.push_back({csinfo_.avg_pool3d_grad,
352                       mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d_grad),
353                       CopyAttrsPooling, AlwaysRewrite});
354     rinfo_.push_back({csinfo_.concat,
355                       mkl_op_registry::GetMklOpName(csinfo_.concat),
356                       CopyAttrsConcat, AlwaysRewrite});
357     rinfo_.push_back({csinfo_.concatv2,
358                       mkl_op_registry::GetMklOpName(csinfo_.concatv2),
359                       CopyAttrsConcatV2, AlwaysRewrite});
360     rinfo_.push_back({csinfo_.conv2d,
361                       mkl_op_registry::GetMklOpName(csinfo_.conv2d),
362                       CopyAttrsConvCheckConstFilter, AlwaysRewrite});
363     rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias,
364                       CopyAttrsConvCheckConstFilter, AlwaysRewrite});
365     rinfo_.push_back({csinfo_.conv2d_grad_filter,
366                       mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter),
367                       CopyAttrsConv, AlwaysRewrite});
368     rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias,
369                       csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv,
370                       AlwaysRewrite});
371     rinfo_.push_back({csinfo_.conv2d_grad_input,
372                       mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
373                       CopyAttrsConv, AlwaysRewrite});
374     rinfo_.push_back({csinfo_.conv3d,
375                       mkl_op_registry::GetMklOpName(csinfo_.conv3d),
376                       CopyAttrsConvCheckConstFilter, AlwaysRewrite});
377     rinfo_.push_back({csinfo_.conv3d_grad_filter,
378                       mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_filter),
379                       CopyAttrsConv, AlwaysRewrite});
380     rinfo_.push_back({csinfo_.conv3d_grad_input,
381                       mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_input),
382                       CopyAttrsConv, AlwaysRewrite});
383     rinfo_.push_back({csinfo_.depthwise_conv2d,
384                       mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d),
385                       CopyAttrsConv2DDepthwiseCheckConstFilter, AlwaysRewrite});
386     rinfo_.push_back(
387         {csinfo_.depthwise_conv2d_grad_input,
388          mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_input),
389          CopyAttrsConv2DDepthwise, AlwaysRewrite});
390     rinfo_.push_back(
391         {csinfo_.depthwise_conv2d_grad_filter,
392          mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_filter),
393          CopyAttrsConv2DDepthwise, AlwaysRewrite});
394     rinfo_.push_back({csinfo_.fused_batch_norm,
395                       mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
396                       CopyAttrsFusedBatchNorm, AlwaysRewrite});
397     rinfo_.push_back(
398         {csinfo_.fused_batch_norm_grad,
399          mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
400          CopyAttrsFusedBatchNorm, AlwaysRewrite});
401     rinfo_.push_back({csinfo_.fused_conv2d, csinfo_.mkl_fused_conv2d,
402                       CopyAttrsFusedConv2D, FusedConv2DRewrite});
403     rinfo_.push_back({csinfo_.identity,
404                       mkl_op_registry::GetMklOpName(csinfo_.identity),
405                       CopyAttrsDataType, AlwaysRewrite});
406     rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
407                       CopyAttrsLRN, LrnRewrite});
408     rinfo_.push_back({csinfo_.lrn_grad,
409                       mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
410                       CopyAttrsLRN, LrnGradRewrite});
411     rinfo_.push_back({csinfo_.leakyrelu,
412                       mkl_op_registry::GetMklOpName(csinfo_.leakyrelu),
413                       CopyAttrsLeakyRelu, LeakyReluRewrite});
414     rinfo_.push_back({csinfo_.leakyrelu_grad,
415                       mkl_op_registry::GetMklOpName(csinfo_.leakyrelu_grad),
416                       CopyAttrsLeakyRelu, LeakyReluRewrite});
417     rinfo_.push_back({csinfo_.max_pool,
418                       mkl_op_registry::GetMklOpName(csinfo_.max_pool),
419                       CopyAttrsPooling, NonDepthBatchWisePoolRewrite});
420     rinfo_.push_back({csinfo_.max_pool_grad,
421                       mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
422                       CopyAttrsPooling, MaxpoolGradRewrite});
423     rinfo_.push_back({csinfo_.max_pool3d,
424                       mkl_op_registry::GetMklOpName(csinfo_.max_pool3d),
425                       CopyAttrsPooling, NonDepthBatchWisePoolRewrite});
426     rinfo_.push_back({csinfo_.max_pool3d_grad,
427                       mkl_op_registry::GetMklOpName(csinfo_.max_pool3d_grad),
428                       CopyAttrsPooling, AlwaysRewrite});
429     rinfo_.push_back({csinfo_.maximum,
430                       mkl_op_registry::GetMklOpName(csinfo_.maximum),
431                       CopyAttrsDataType, AlwaysRewrite});
432     rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul),
433                       CopyAttrsDataType, AlwaysRewrite});
434     rinfo_.push_back({csinfo_.pad_with_conv2d, csinfo_.mkl_pad_with_conv2d,
435                       CopyAttrsPadWithConv2D, AlwaysRewrite});
436     rinfo_.push_back({csinfo_.pad_with_fused_conv2d,
437                       csinfo_.mkl_pad_with_fused_conv2d,
438                       CopyAttrsPadWithFusedConv2D, AlwaysRewrite});
439     rinfo_.push_back({csinfo_.quantized_avg_pool,
440                       mkl_op_registry::GetMklOpName(csinfo_.quantized_avg_pool),
441                       CopyAttrsQuantizedPooling, AlwaysRewrite});
442     rinfo_.push_back({csinfo_.quantized_concatv2,
443                       mkl_op_registry::GetMklOpName(csinfo_.quantized_concatv2),
444                       CopyAttrsConcatV2, AlwaysRewrite});
445     rinfo_.push_back({csinfo_.quantized_conv2d,
446                       mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d),
447                       CopyAttrsQuantizedConv2D, AlwaysRewrite});
448     rinfo_.push_back({csinfo_.quantized_conv2d_with_requantize,
449                       mkl_op_registry::GetMklOpName(
450                           csinfo_.quantized_conv2d_with_requantize),
451                       CopyAttrsQuantizedConv2D, AlwaysRewrite});
452     rinfo_.push_back(
453         {csinfo_.quantized_conv2d_with_bias,
454          mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_with_bias),
455          CopyAttrsQuantizedConv2D, AlwaysRewrite});
456     rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_requantize,
457                       mkl_op_registry::GetMklOpName(
458                           csinfo_.quantized_conv2d_with_bias_and_requantize),
459                       CopyAttrsQuantizedConv2D, AlwaysRewrite});
460     rinfo_.push_back(
461         {csinfo_.quantized_conv2d_and_relu,
462          mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_and_relu),
463          CopyAttrsQuantizedConv2D, AlwaysRewrite});
464     rinfo_.push_back({csinfo_.quantized_conv2d_and_relu_and_requantize,
465                       mkl_op_registry::GetMklOpName(
466                           csinfo_.quantized_conv2d_and_relu_and_requantize),
467                       CopyAttrsQuantizedConv2D, AlwaysRewrite});
468     rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_relu,
469                       mkl_op_registry::GetMklOpName(
470                           csinfo_.quantized_conv2d_with_bias_and_relu),
471                       CopyAttrsQuantizedConv2D, AlwaysRewrite});
472     rinfo_.push_back(
473         {csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize,
474          mkl_op_registry::GetMklOpName(
475              csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize),
476          CopyAttrsQuantizedConv2D, AlwaysRewrite});
477     rinfo_.push_back({csinfo_.quantized_max_pool,
478                       mkl_op_registry::GetMklOpName(csinfo_.quantized_max_pool),
479                       CopyAttrsQuantizedPooling, AlwaysRewrite});
480     rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_sum_and_relu,
481                       mkl_op_registry::GetMklOpName(
482                           csinfo_.quantized_conv2d_with_bias_sum_and_relu),
483                       CopyAttrsQuantizedConv2D, AlwaysRewrite});
484     rinfo_.push_back(
485         {csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize,
486          mkl_op_registry::GetMklOpName(
487              csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize),
488          CopyAttrsQuantizedConv2D, AlwaysRewrite});
489     rinfo_.push_back(
490         {csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize,
491          mkl_op_registry::GetMklOpName(
492              csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize),
493          CopyAttrsQuantizedConv2D, AlwaysRewrite});
494     rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
495                       CopyAttrsDataType, AlwaysRewrite});
496     rinfo_.push_back({csinfo_.relu_grad,
497                       mkl_op_registry::GetMklOpName(csinfo_.relu_grad),
498                       CopyAttrsDataType, AlwaysRewrite});
499     rinfo_.push_back({csinfo_.relu6,
500                       mkl_op_registry::GetMklOpName(csinfo_.relu6),
501                       CopyAttrsDataType, AlwaysRewrite});
502     rinfo_.push_back({csinfo_.relu6_grad,
503                       mkl_op_registry::GetMklOpName(csinfo_.relu6_grad),
504                       CopyAttrsDataType, AlwaysRewrite});
505     rinfo_.push_back({csinfo_.requantize,
506                       mkl_op_registry::GetMklOpName(csinfo_.requantize),
507                       CopyAttrsRequantize, AlwaysRewrite});
508     /*
509     rinfo_.push_back({csinfo_.tanh,
510                       mkl_op_registry::GetMklOpName(csinfo_.tanh),
511                       CopyAttrsDataType, AlwaysRewrite});
512     rinfo_.push_back({csinfo_.tanh_grad,
513                       mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
514                       CopyAttrsDataType, AlwaysRewrite});
515     */
516     rinfo_.push_back({csinfo_.reshape,
517                       mkl_op_registry::GetMklOpName(csinfo_.reshape),
518                       CopyAttrsReshape, AlwaysRewrite});
519     rinfo_.push_back({csinfo_.slice,
520                       mkl_op_registry::GetMklOpName(csinfo_.slice),
521                       CopyAttrsSlice, AlwaysRewrite});
522     rinfo_.push_back({csinfo_.softmax,
523                       mkl_op_registry::GetMklOpName(csinfo_.softmax),
524                       CopyAttrsDataType, AlwaysRewrite});
525 
526     rinfo_.push_back({csinfo_.squared_difference,
527                       mkl_op_registry::GetMklOpName(csinfo_.squared_difference),
528                       CopyAttrsDataType, AlwaysRewrite});
529     rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub),
530                       CopyAttrsDataType, AlwaysRewrite});
531 
532     // Add info about which ops to add workspace edge to and the slots.
533     wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
534     wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3});
535     wsinfo_.push_back(
536         {csinfo_.max_pool3d, csinfo_.max_pool3d_grad, 0, 1, 1, 3});
537 
538     // Add a rule for merging nodes
539     minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add,
540                       csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd});
541 
542     minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad,
543                       csinfo_.conv2d_grad_filter_with_bias,
544                       GetConv2DBackpropFilterOrBiasAddGrad});
545     // Merge Pad and Conv2d, only if the pad op is "Pad"
546     // Doesn't merge if pad op is "PadV2" or "MirrorPad"
547     minfo_.push_back(
548         {csinfo_.pad, csinfo_.conv2d, csinfo_.pad_with_conv2d, GetPadOrConv2D});
549 
550     minfo_.push_back({csinfo_.pad, csinfo_.fused_conv2d,
551                       csinfo_.pad_with_fused_conv2d, GetPadOrFusedConv2D});
552 
553     // The fusion patterns in "finfo_" that show up first will get applied
554     // first, for example, graph "A->B->C-D" and finfo_ is {A->B->C to ABC,
555     // A->B->C->D to ABCD}, since the first gets applied first, the final
556     // graph will be ABC->D.
557 
558     //
559     // Add rules to fuse sequences such as "Transpose (NCHW -> NHWC) + Conv2D
560     // (NHWC) + Transpose (NHWC->
561     // NCHW)" into "Conv2D (NCHW)". Such patterns occur frequently in Keras.
562     // Note: we use the term "merge" to combine (exactly) 2 nodes into one,
563     // while "fusion" is for 3+ nodes situation.
564     //
565 
566     // Transpose + Conv2d + Transpose:
567     std::vector<int> transpose_to_nhwc = {NCHW::dim::N, NCHW::dim::H,
568                                           NCHW::dim::W, NCHW::dim::C};
569     std::vector<int> transpose_to_nchw = {NHWC::dim::N, NHWC::dim::C,
570                                           NHWC::dim::H, NHWC::dim::W};
571     auto CheckForTransposeToNHWC =
572         std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_nhwc);
573     auto CheckForConv2dOp =
574         std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.conv2d);
575     auto CheckForTransposeToNCHW =
576         std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_nchw);
577     auto FuseConv2D =
578         std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
579                   std::placeholders::_2, std::placeholders::_3, "NCHW");
580     finfo_.push_back(
581         {"transpose-elimination for Conv2D",
582          {CheckForTransposeToNHWC, CheckForConv2dOp, CheckForTransposeToNCHW},
583          // CheckForMklOp
584          FuseConv2D,
585          CopyAttrsConv});
586   }
587 
  // Standard interface to run pass.
  // NOTE(review): presumably the GraphOptimizationPass entry point (this file
  // includes optimization_registry.h); the definition is not in this chunk.
  Status Run(const GraphOptimizationPassOptions& options);

  // Helper function which does most of the heavy lifting for rewriting
  // Mkl nodes to propagate Mkl tensor as additional output.
  //
  // Extracts common functionality between Run public interface and
  // test interface.
  //
  // @input g - graph to rewrite; it may be modified in place.
  // @return true, if and only if graph is mutated; false otherwise.
  bool RunPass(std::unique_ptr<Graph>* g);
599 
  /// Structure to specify the name of an original node, its new name after
  /// rewrite, the function to be used to copy attributes for the op, and the
  /// rule (if any) which must hold for rewriting the node.
  typedef struct {
    string name;      // Original name of op of the node in the graph
    string new_name;  // New name of the op of the node in the graph
    // A function handler to copy attributes from an old node to a new node.
    // NOTE(review): the meaning of the trailing bool argument is not visible
    // here; confirm against the copy_attrs implementations.
    std::function<void(const Node*, NodeBuilder*, bool)> copy_attrs;
    // A rule under which to rewrite this node: returning true allows the
    // rewrite (see AlwaysRewrite, which always returns true).
    std::function<bool(const Node*)> rewrite_rule;
  } RewriteInfo;
612 
  /// Structure to specify a forward op, a backward op, and the slot numbers
  /// in the forward and backward ops where we will add a workspace edge.
  /// The fwd_slot/bwd_slot pair identifies the existing data connection
  /// between the two ops; ws_fwd_slot/ws_bwd_slot identify where the extra
  /// workspace edge is attached.
  typedef struct {
    string fwd_op;    // Name of a forward op in the graph
    string bwd_op;    // Name of a backward op in the graph
    int fwd_slot;     // Output slot in the forward op node where actual
                      // output tensor resides
    int bwd_slot;     // Input slot in the backward op node where actual
                      // input tensor resides
    int ws_fwd_slot;  // Output slot in the forward op node where workspace
                      // edge is added
    int ws_bwd_slot;  // Input slot in the backward op node where workspace
                      // edge is added
  } WorkSpaceInfo;
627 
  /// Structure to specify information used in node merge of 2 operators.
  /// Entries are brace-initialized in field order, e.g.
  /// {csinfo_.pad, csinfo_.fused_conv2d, csinfo_.pad_with_fused_conv2d,
  ///  GetPadOrFusedConv2D}, so the field order below is significant.
  typedef struct {
    string op1;       // Node string for one operator.
    string op2;       // Node string for second operator.
    string new_node;  // Name of the node after merge
    // Function that enables user of the node merger to specify how to find
    // second operator given the first operator. The finders in this class
    // return nullptr when no mergeable partner exists.
    std::function<Node*(const Node*)> get_node_to_be_merged;
  } MergeInfo;
637 
  // Structure to specify information used in node fusion of 3+ operators.
  // Entries are brace-initialized in field order, e.g. the
  // "transpose-elimination for Conv2D" rule is
  // {name, {one checker per pattern node}, fuse function, copy_attrs}.
  typedef struct {
    std::string pattern_name;  // Name to describe this pattern, such as
                               // "Transpose_Mklop_Transpose".
    std::vector<std::function<bool(const Node*)> >
        node_checkers;  // Extra restriction checker for these ops, one per
                        // node, listed in pattern order (producer first).
    // Performs the fusion of the matched nodes in the graph. Its last
    // parameter is an attribute-copy function (presumably 'copy_attrs'
    // below; confirm in FuseNode).
    std::function<Status(
        std::unique_ptr<Graph>*, std::vector<Node*>&,
        std::function<void(const Node*, NodeBuilder* nb, bool)>)>
        fuse_func;
    // Function used to copy attributes onto the fused node.
    std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs;
  } FusionInfo;
650 
651   //
652   // Dimension indices for 2D tensor.
653   //
654   struct NCHW {
655     enum dim { N = 0, C = 1, H = 2, W = 3 };
656   };
657 
658   struct NHWC {
659     enum dim { N = 0, H = 1, W = 2, C = 3 };
660   };
661 
662   //
663   // dimension indices for 3D tensor.
664   //
665   struct NCDHW {
666     enum dim { N = 0, C = 1, D = 2, H = 3, W = 4 };
667   };
668 
669   struct NDHWC {
670     enum dim { N = 0, D = 1, H = 2, W = 3, C = 4 };
671   };
672 
  /// Structure to store all constant strings (op names used by the rewrite
  /// rules).
  /// NOTE: names are mostly, but not strictly, alphabetically sorted; search
  /// the whole list before adding a new entry.
  typedef struct {
    string addn;
    string add;
    string avg_pool;
    string avg_pool_grad;
    string avg_pool3d;
    string avg_pool3d_grad;
    string bias_add;
    string bias_add_grad;
    string concat;
    string concatv2;
    string conv2d;
    string conv2d_with_bias;
    string conv2d_grad_input;
    string conv2d_grad_filter;
    string conv2d_grad_filter_with_bias;
    string conv3d;
    string conv3d_grad_input;
    string conv3d_grad_filter;
    string depthwise_conv2d;
    string depthwise_conv2d_grad_input;
    string depthwise_conv2d_grad_filter;
    string fused_batch_norm;
    string fused_batch_norm_grad;
    string fused_conv2d;
    string identity;
    string leakyrelu;
    string leakyrelu_grad;
    string lrn;
    string lrn_grad;
    string matmul;
    string max_pool;
    string max_pool_grad;
    string max_pool3d;
    string max_pool3d_grad;
    string maximum;
    string mkl_conv2d;
    string mkl_conv2d_grad_input;
    string mkl_conv2d_grad_filter;
    string mkl_conv2d_grad_filter_with_bias;
    string mkl_conv2d_with_bias;
    string mkl_depthwise_conv2d_grad_input;
    string mkl_depthwise_conv2d_grad_filter;
    string mkl_fused_conv2d;
    string mkl_pad_with_conv2d;
    string mkl_pad_with_fused_conv2d;
    string mul;
    string pad;
    string pad_with_conv2d;
    string pad_with_fused_conv2d;
    string quantized_avg_pool;
    string quantized_conv2d;
    string quantized_conv2d_with_requantize;
    string quantized_conv2d_with_bias;
    string quantized_conv2d_with_bias_and_requantize;
    string quantized_conv2d_and_relu;
    string quantized_conv2d_and_relu_and_requantize;
    string quantized_conv2d_with_bias_and_relu;
    string quantized_conv2d_with_bias_and_relu_and_requantize;
    string quantized_concatv2;
    string quantized_max_pool;
    string quantized_conv2d_with_bias_sum_and_relu;
    string quantized_conv2d_with_bias_sum_and_relu_and_requantize;
    string quant_conv2d_with_bias_signed_sum_and_relu_and_requantize;
    string relu;
    string relu_grad;
    string relu6;
    string relu6_grad;
    string requantize;
    string tanh;
    string tanh_grad;
    string transpose;
    string reshape;
    string slice;
    string softmax;
    string split;
    string squared_difference;
    string sub;
  } ConstStringsInfo;
754 
 private:
  /// Maintain info about nodes to rewrite
  std::vector<RewriteInfo> rinfo_;

  /// Maintain info about nodes to add workspace edge
  std::vector<WorkSpaceInfo> wsinfo_;

  /// Maintain info about nodes to be merged (2-operator merges)
  std::vector<MergeInfo> minfo_;

  /// Maintain info about nodes to be fused (3+-operator fusions). Patterns
  /// listed first are applied first.
  std::vector<FusionInfo> finfo_;

  /// Maintain structure of constant strings. Static, so it is shared by all
  /// instances and is usable from the static helper functions of this class.
  static ConstStringsInfo csinfo_;

 private:
772   // Is OpDef::ArgDef a list type? It could be N * T or list(type).
773   // Refer to opdef.proto for details of list type.
ArgIsList(const OpDef::ArgDef & arg) const774   inline bool ArgIsList(const OpDef::ArgDef& arg) const {
775     return !arg.type_list_attr().empty() || !arg.number_attr().empty();
776   }
777 
778   // Get length of a list in 'n' if 'arg' is of list type. Refer to
779   // description of ArgIsList for definition of list type.
GetTensorListLength(const OpDef::ArgDef & arg,Node * n)780   inline int GetTensorListLength(const OpDef::ArgDef& arg, Node* n) {
781     CHECK_EQ(ArgIsList(arg), true);
782     int N = 0;
783     const string attr_name = !arg.type_list_attr().empty()
784                                  ? arg.type_list_attr()
785                                  : arg.number_attr();
786     if (!arg.type_list_attr().empty()) {
787       std::vector<DataType> value;
788       TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value));
789       N = value.size();
790     } else {
791       TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &N));
792     }
793     return N;
794   }
795 
796   // Can op represented by node 'n' run on DEVICE_CPU?
797   // Op can run on CPU with MKL if the runtime assigned device or the
798   // user requested device contains device CPU, or both are empty.
CanOpRunOnCPUDevice(const Node * n)799   bool CanOpRunOnCPUDevice(const Node* n) {
800     bool result = true;
801     string reason;
802 
803     // Substring that should be checked for in device name for CPU device.
804     const char* const kCPUDeviceSubStr = "CPU";
805 
806     // If Op has been specifically assigned to a non-CPU device, then No.
807     if (!n->assigned_device_name().empty() &&
808         !str_util::StrContains(n->assigned_device_name(), kCPUDeviceSubStr)) {
809       result = false;
810       reason = "Op has been assigned a runtime device that is not CPU.";
811     }
812 
813     // If user has specifically assigned this op to a non-CPU device, then No.
814     if (!n->def().device().empty() &&
815         !str_util::StrContains(n->def().device(), kCPUDeviceSubStr)) {
816       result = false;
817       reason = "User has assigned a device that is not CPU.";
818     }
819 
820     if (result == false) {
821       VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node "
822               << n->type_string() << ", reason: " << reason;
823     }
824 
825     // Otherwise Yes.
826     return result;
827   }
828 
  // Return a node that can be merged with input node 'n'
  //
  // @return pointer to the node if we can find such a
  // node. Otherwise, it returns nullptr.
  Node* CheckForNodeMerge(const Node* n) const;

  // Merge node 'm' with node 'n'.
  // Currently, we merge (1) Conv2D with BiasAdd, (2) BiasAddGrad with
  // Conv2DBackpropFilter, and (3) Pad with Conv2D / _FusedConv2D (see the
  // minfo_ rules and the GetPadOr* partner finders).
  //
  // Input nodes m and n may be deleted if the call to
  // this function is successful. Attempt to use the pointers
  // after the call to function may result in undefined behaviors.
  //
  // @input g - input graph, m - graph node, n - graph node to be merged with m
  // @return Status::OK(), if merging is successful and supported.
  //         Returns appropriate Status error code otherwise.
  //         Graph is updated in case nodes are merged. Otherwise, it is
  //         not updated.
  Status MergeNode(std::unique_ptr<Graph>* g, Node* m, Node* n);

  // Helper functions implementing MergeNode for the individual pairs:
  //  - MergeConv2DWithBiasAdd: Conv2D + BiasAdd.
  //  - MergePadWithConv2D: Pad + Conv2D (NOTE(review): possibly also used for
  //    Pad + _FusedConv2D; confirm in the definition).
  //  - MergeConv2DBackpropFilterWithBiasAddGrad: Conv2DBackpropFilter +
  //    BiasAddGrad.
  Status MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, Node* m, Node* n);
  Status MergePadWithConv2D(std::unique_ptr<Graph>* g, Node* m, Node* n);
  Status MergeConv2DBackpropFilterWithBiasAddGrad(std::unique_ptr<Graph>* g,
                                                  Node* m, Node* n);
855 
856   // Find BiasAdd or Conv2D node that can be merged with input node 'm'.
857   // If input 'm' is BiasAdd, then check if there exists Conv2D node that can be
858   // merged with 'm'. If input 'm' is Conv2D, then check if there exists BiasAdd
859   // node that can be merged with 'm'.
GetConv2DOrBiasAdd(const Node * m)860   static Node* GetConv2DOrBiasAdd(const Node* m) {
861     CHECK_NOTNULL(m);
862     Node* n = nullptr;
863 
864     DataType T_m;
865     TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
866 
867     // Don't try to merge if datatype is not DT_FLOAT
868     if (T_m != DT_FLOAT) return n;
869 
870     if (m->type_string() == csinfo_.bias_add) {
871       // If a is BiasAdd, then Conv2D is 0th input of BiasAdd.
872       TF_CHECK_OK(m->input_node(0, &n));
873     } else {
874       CHECK_EQ(m->type_string(), csinfo_.conv2d);
875       // Go over all output edges and search for BiasAdd Node.
876       // 0th input of BiasAdd is Conv2D.
877       for (const Edge* e : m->out_edges()) {
878         if (!e->IsControlEdge() &&
879             e->dst()->type_string() == csinfo_.bias_add &&
880             e->dst_input() == 0) {
881           n = e->dst();
882           break;
883         }
884       }
885     }
886 
887     if (n == nullptr) {
888       VLOG(1) << "MklLayoutRewritePass: Could not find matching "
889               << "Conv2D and BiasAdd node for merging. Input node: "
890               << m->DebugString();
891     }
892 
893     return n;
894   }
895 
896   // Find Pad or Conv2D node that can be merged with input node 'm'.
897   // If input 'm' is Pad, then check if there exists Conv2D node that can be
898   // merged with 'm'. If input 'm' is Conv2D, then check if there exists Pad
899   // node that can be merged with 'm'.
GetPadOrConv2D(const Node * m)900   static Node* GetPadOrConv2D(const Node* m) {
901     DCHECK(m);
902     Node* n = nullptr;
903 
904     DataType T_m;
905     TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
906 
907     // Don't try to merge if datatype is not DT_FLOAT
908     if (T_m != DT_FLOAT) return n;
909 
910     const Node* conv_node;
911     if (m->type_string() == csinfo_.pad) {
912       // If m is Pad, then Conv2D is the output of Pad.
913       for (const Edge* e : m->out_edges()) {
914         if (!e->IsControlEdge() && e->dst()->type_string() == csinfo_.conv2d) {
915           n = e->dst();
916           conv_node = n;
917           break;
918         }
919       }
920     } else {
921       DCHECK_EQ(m->type_string(), csinfo_.conv2d);
922       // If m is conv2D, Go over all input edges
923       // and search for Pad  Node.
924       for (const Edge* e : m->in_edges()) {
925         if (!e->IsControlEdge() && e->src()->type_string() == csinfo_.pad) {
926           n = e->src();
927           conv_node = m;
928           break;
929         }
930       }
931     }
932     // Check if only VALID type of padding is used
933     // or not.
934     if (n != nullptr) {
935       string padding;
936       TF_CHECK_OK(GetNodeAttr(conv_node->def(), "padding", &padding));
937       if (padding != "VALID")
938         // Then do not merge.
939         // Only VALID type of padding in conv op can be
940         // merged with Pad op.
941         n = nullptr;
942     } else {
943       VLOG(1) << "MklLayoutRewritePass: Could not find matching "
944               << "Pad and Conv2D node for merging. Input node: "
945               << m->DebugString();
946     }
947 
948     return n;
949   }
950 
951   // Find Pad or _FusedConv2D node that can be merged with input node 'm'.
952   // If input 'm' is Pad, then check if there exists _FusedConv2D node that can
953   // be merged with 'm'. If input 'm' is _FusedConv2D, then check if there
954   // exists Pad node that can be merged with 'm'.
GetPadOrFusedConv2D(const Node * m)955   static Node* GetPadOrFusedConv2D(const Node* m) {
956     DCHECK(m);
957     Node* n = nullptr;
958 
959     const Node* conv_node;
960     if (m->type_string() == csinfo_.pad) {
961       // If m is Pad, then _FusedConv2D is the output of Pad.
962       for (const Edge* e : m->out_edges()) {
963         if (!e->IsControlEdge() &&
964             e->dst()->type_string() == csinfo_.fused_conv2d) {
965           n = e->dst();
966           conv_node = n;
967           break;
968         }
969       }
970     } else {
971       DCHECK_EQ(m->type_string(), csinfo_.fused_conv2d);
972       // If m is _FusedConv2D, Go over all input edges
973       // and search for Pad node.
974       for (const Edge* e : m->in_edges()) {
975         if (!e->IsControlEdge() && e->src()->type_string() == csinfo_.pad) {
976           n = e->src();
977           conv_node = m;
978           break;
979         }
980       }
981     }
982     // Check if only VALID type of padding is used or not.
983     if (n != nullptr) {
984       string padding;
985       TF_CHECK_OK(GetNodeAttr(conv_node->def(), "padding", &padding));
986       if (padding != "VALID") {
987         // Then do not merge.
988         n = nullptr;
989         VLOG(1) << "MklLayoutRewritePass: Could match Pad and _FusedConv2D "
990                 << "nodes but cannot merge them. Only conv ops with padding "
991                 << "type VALID can be merged with Pad op Input node: "
992                 << m->DebugString();
993       }
994     } else {
995       VLOG(1) << "MklLayoutRewritePass: Could not find matching "
996               << "Pad and _FusedConv2D node for merging. Input node: "
997               << m->DebugString();
998     }
999 
1000     return n;
1001   }
1002 
1003   // Find Conv2DBackpropFilter or BiasAddGrad node that can be merged with input
1004   // node 'm'. If input 'm' is Conv2DBackpropFilter, then check if there exists
1005   // BiasAddGrad node that can be merged with 'm'. If input 'm' is BiasAddGrad,
1006   // then check if there exists Conv2DBackpropFilter node that can be merged
1007   // with 'm'.
1008   //
1009   // Graph that will allow us to connect Conv2DBackpropFilter with BiasAddGrad
1010   // would look like:
1011   //
1012   // _ = Conv2DBackpropFilter(F, _, G)
1013   // _ = BiasAddGrad(G)
1014   //
1015   // So 1st input of BiasAddGrad connects with 3rd input of
1016   // Conv2DBackpropFilter and vice versa.
GetConv2DBackpropFilterOrBiasAddGrad(const Node * m)1017   static Node* GetConv2DBackpropFilterOrBiasAddGrad(const Node* m) {
1018     CHECK_NOTNULL(m);
1019     Node* n = nullptr;
1020 
1021     DataType T_m;
1022     TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
1023 
1024     // Don't try to merge if datatype is not DT_FLOAT
1025     if (T_m != DT_FLOAT) return n;
1026 
1027     if (m->type_string() == csinfo_.bias_add_grad) {
1028       // Get 1st input 'g' of BiasAddGrad.
1029       Node* g = nullptr;
1030       TF_CHECK_OK(m->input_node(0, &g));
1031       // Now traverse all outgoing edges from g that have destination node as
1032       // Conv2DBackpropFilter.
1033       for (const Edge* e : g->out_edges()) {
1034         if (!e->IsControlEdge() &&
1035             e->dst()->type_string() == csinfo_.conv2d_grad_filter &&
1036             e->dst_input() == 2 /* 3rd input of BackpropFilter */) {
1037           n = e->dst();
1038           break;
1039         }
1040       }
1041     } else {
1042       CHECK_EQ(m->type_string(), csinfo_.conv2d_grad_filter);
1043       // Get 3rd input 'g' of Conv2DBackpropFilter.
1044       Node* g = nullptr;
1045       TF_CHECK_OK(m->input_node(2, &g));
1046       // Now traverse all outgoing edges from g that have destination node as
1047       // BiasAddGrad.
1048       for (const Edge* e : g->out_edges()) {
1049         if (!e->IsControlEdge() &&
1050             e->dst()->type_string() == csinfo_.bias_add_grad &&
1051             e->dst_input() == 0 /* 1st input of BiasAddGrad */) {
1052           n = e->dst();
1053           break;
1054         }
1055       }
1056     }
1057 
1058     if (n == nullptr) {
1059       VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1060               << "Conv2DBackpropFilter and BiasAddGrad node for merging. "
1061               << "Input node: " << m->DebugString();
1062     }
1063     return n;
1064   }
1065 
  // Return nodes that can be fused with input node 'n', according to the
  // patterns registered in finfo_.
  //
  // @return tuple. If we can find such nodes, the first element of the tuple
  // is true; the second element then holds the matched nodes and the third
  // the FusionInfo of the matched pattern (NOTE(review): node order presumably
  // follows the pattern's node_checkers; confirm in the definition).
  // Otherwise, the first element is false.
  std::tuple<bool, std::vector<Node*>, const MklLayoutRewritePass::FusionInfo>
  CheckForNodeFusion(Node* n) const;

  // Fuse nodes in the vector "nodes" using the fusion rule 'fi'.
  Status FuseNode(std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
                  const MklLayoutRewritePass::FusionInfo fi);

  // Fuse transpose(to "NHWC") + mklop("NHWC") + transpose(to "NCHW") into
  // mklop("NCHW").
  // Here "mklop" can be any MKL-DNN supported op, such as Conv2D.
  // 'copy_attrs' copies attributes onto the fused node and 'data_format' is
  // the layout of the resulting op (the Conv2D rule passes "NCHW").
  static Status FuseTransposeMklOpTranspose(
      std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
      std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs,
      string data_format);
1084 
CheckForTranspose(const Node * node,std::vector<int> perm)1085   static bool CheckForTranspose(const Node* node, std::vector<int> perm) {
1086     // Check if node's type is "Transpose"
1087     if (node->type_string() != "Transpose") return false;
1088 
1089     // If "Transpose" has multiple output data edges, also don't fuse it.
1090     if (node->num_outputs() > 1 || node->out_edges().size() > 1) return false;
1091 
1092     // Check if has out control edge. If true, this is a training graph.
1093     // Currently we focus on inference and do no fusion in training.
1094     // Note: this constraint will eventually be removed, if we enabled this
1095     // fusion for training
1096     // in the future.
1097     for (const Edge* e : node->out_edges()) {
1098       if (e->IsControlEdge()) {
1099         return false;
1100       }
1101     }
1102 
1103     // If "Transpose" has input control edges, don't fuse on it.
1104     for (const Edge* e : node->in_edges()) {
1105       if (e->IsControlEdge()) {
1106         return false;
1107       }
1108     }
1109 
1110     // We compared the tensor containing the permutation order ("perm_node")
1111     // with our desired order ("perm"). If they're exactly match, this check
1112     // succeed and returns true.
1113     for (const Edge* e : node->in_edges()) {
1114       if (!e->IsControlEdge()) {
1115         const Node* perm_node = e->src();
1116 
1117         const int kPermTensorIndex = 1;
1118         if (perm_node->type_string() == "Const" &&
1119             e->dst_input() == kPermTensorIndex) {
1120           // we find the "perm" node, now try to retrieve its value.
1121           const TensorProto* proto = nullptr;
1122           TF_CHECK_OK(GetNodeAttr(perm_node->def(), "value", &proto));
1123 
1124           DataType type;
1125           GetNodeAttr(perm_node->def(), "dtype", &type);
1126 
1127           // Here we directly access to the "tensor_content", rather than
1128           // "int_val". This is because we find "int_val" is
1129           // not set properly under some circumstances.
1130           if (type == DT_INT32) {
1131             const int type_size = 4;
1132             const int* tensor_content =
1133                 reinterpret_cast<const int*>(proto->tensor_content().c_str());
1134             const int tensor_content_size =
1135                 proto->tensor_content().size() / type_size;
1136 
1137             std::vector<int> perm_value(tensor_content,
1138                                         tensor_content + tensor_content_size);
1139 
1140             return perm_value == perm;
1141           } else if (type == DT_INT64) {
1142             const int type_size = 8;
1143             const long* tensor_content =
1144                 reinterpret_cast<const long*>(proto->tensor_content().c_str());
1145             const int tensor_content_size =
1146                 proto->tensor_content().size() / type_size;
1147 
1148             std::vector<long> perm_value(tensor_content,
1149                                          tensor_content + tensor_content_size);
1150             std::vector<long> long_perm(perm.cbegin(), perm.cend());
1151 
1152             return perm_value == long_perm;
1153           }
1154           return false;
1155         }
1156       }
1157     }
1158     return false;
1159   }
1160 
CheckForMklOp(const Node * node,string name="")1161   static bool CheckForMklOp(const Node* node, string name = "") {
1162     if (node == nullptr) return false;
1163 
1164     if (!name.empty() && node->type_string() != name) {
1165       return false;
1166     }
1167 
1168     // if mklop has multiple outputs, don't fuse it.
1169     if (node->num_outputs() > 1) return false;
1170 
1171     if (node->out_edges().size() > 1) return false;
1172 
1173     DataType T;
1174     TF_CHECK_OK(GetNodeAttr(node->def(), "T", &T));
1175     return mkl_op_registry::IsMklOp(
1176         mkl_op_registry::GetMklOpName(node->type_string()), T);
1177   }
1178 
  // Check if the node 'n' has any applicable rewrite rule.
  // We check for 2 scenarios for rewrite (NOTE(review): the scenarios are
  // described at the definition; AlwaysRewrite is the default rule of
  // scenario 1).
  //
  // @return RewriteInfo* for the applicable rewrite rule
  //         (presumably nullptr when none applies; confirm in the definition).
  const RewriteInfo* CheckForNodeRewrite(const Node* n) const;
  // Same as CheckForNodeRewrite, for quantized ops.
  const RewriteInfo* CheckForQuantizedNodeRewrite(const Node* n) const;
1185 
1186   // Default rewrite rule to be used in scenario 1 for rewrite.
1187   // @return - true (since we want to always rewrite)
AlwaysRewrite(const Node * n)1188   static bool AlwaysRewrite(const Node* n) { return true; }
1189 
1190   // Check if we are performing pooling on depth or batch. If it is, then we
1191   // do not rewrite MaxPool node to Mkl version.
1192   // @return - true (if it is not a depth/batch wise pooling case);
1193   //           false otherwise.
NonDepthBatchWisePoolRewrite(const Node * n)1194   static bool NonDepthBatchWisePoolRewrite(const Node* n) {
1195     CHECK_NOTNULL(n);
1196 
1197     string data_format_str;
1198     TensorFormat data_format;
1199     std::vector<int32> ksize, strides;
1200     CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true);
1201     CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true);
1202     CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), true);
1203     CHECK_EQ(FormatFromString(data_format_str, &data_format), true);
1204 
1205     // Condition that specifies non-batch-wise and non-depth-wise pooling.
1206     if (GetTensorDim(ksize, data_format, 'N') == 1 &&
1207         GetTensorDim(strides, data_format, 'N') == 1 &&
1208         GetTensorDim(ksize, data_format, 'C') == 1 &&
1209         GetTensorDim(strides, data_format, 'C') == 1) {
1210       return true;
1211     }
1212 
1213     return false;
1214   }
1215 
1216   // If the depth_radius of LRN is not 2, then MKL DNN takes unoptimized
1217   // path. The unoptimized path is slow. Thus we dont rewrite the node
1218   // and use default Eigen. But for depth_radius=2, MKL DNN optimized
1219   // path is taken, i.e., eigen node is rewritten by MKl DNN node.
LrnRewrite(const Node * n)1220   static bool LrnRewrite(const Node* n) {
1221     CHECK_NOTNULL(n);
1222 
1223     int depth_radius;
1224     CHECK_EQ(GetNodeAttr(n->def(), "depth_radius", &depth_radius).ok(), true);
1225 
1226     // if the depth_radius of LRN is not 2, don't rewrite the node by MKL DNN
1227     // and use eigen node instead
1228     if (depth_radius == 2) {
1229       return true;
1230     }
1231     VLOG(1) << "LrnRewrite: The model sets depth_radius as not 2 which"
1232             << "case is not optimized by Intel MKL, thus using Eigen op"
1233             << "for LRN ";
1234 
1235     return false;
1236   }
1237 
LrnGradRewrite(const Node * n)1238   static bool LrnGradRewrite(const Node* n) {
1239     CHECK_NOTNULL(n);
1240     bool do_rewrite = false;
1241 
1242     for (const Edge* e : n->in_edges()) {
1243       // Rewrite only if there is corresponding LRN, i.e workspace is available
1244       if (e->dst()->type_string() == csinfo_.lrn_grad && e->dst_input() == 2 &&
1245           e->src()->type_string() ==
1246               mkl_op_registry::GetMklOpName(csinfo_.lrn) &&
1247           e->src_output() == 0) {
1248         do_rewrite = true;
1249         break;
1250       }
1251     }
1252     return do_rewrite;
1253   }
1254 
1255   // MKL-DNN's LeakyRelu(feature) = feature          (if feature > 0), or
1256   //                                feature * alpha  (otherwise),
1257   // while TensorFlow's LeakyRelu(feature) = max(feature, feature * alpha).
1258   // These two algorithms are not consistent when alpha > 1,
1259   // so we only rewrite LeakyRelu to MKL OP when alpha <= 1.
LeakyReluRewrite(const Node * n)1260   static bool LeakyReluRewrite(const Node* n) {
1261     DCHECK(n);
1262 
1263     float alpha;
1264     bool has_attr = GetNodeAttr(n->def(), "alpha", &alpha).ok();
1265     DCHECK(has_attr);
1266 
1267     // If the alpha of LeakyRelu is less than 1, rewrite the node.
1268     // Otherwise eigen node is used instead.
1269     if (alpha <= 1) {
1270       return true;
1271     }
1272     VLOG(1) << "LeakyReluRewrite: The model sets alpha is greater than 1 "
1273             << "which case is not optimized by Intel MKL, thus using Eigen op"
1274             << "for LeakyRelu ";
1275 
1276     return false;
1277   }
1278 
MaxpoolGradRewrite(const Node * n)1279   static bool MaxpoolGradRewrite(const Node* n) {
1280     CHECK_NOTNULL(n);
1281     bool do_rewrite = false;
1282     for (const Edge* e : n->in_edges()) {
1283       // Rewrite only if there is corresponding Maxpool, i.e workspace is
1284       // available
1285       if (e->dst()->type_string() == csinfo_.max_pool_grad &&
1286           e->dst_input() == 1 &&
1287           e->src()->type_string() ==
1288               mkl_op_registry::GetMklOpName(csinfo_.max_pool) &&
1289           e->src_output() == 0) {
1290         do_rewrite = true;
1291         break;
1292       }
1293     }
1294     return do_rewrite;
1295   }
1296 
AddNRewrite(const Node * n)1297   static bool AddNRewrite(const Node* n) {
1298     CHECK_NOTNULL(n);
1299 
1300     int num;
1301     CHECK_EQ(GetNodeAttr(n->def(), "N", &num).ok(), true);
1302 
1303     // Condition that specifies non-batch-wise and non-depth-wise pooling.
1304     if (num == 2) {
1305       return true;
1306     }
1307 
1308     return false;
1309   }
1310 
FusedConv2DRewrite(const Node * n)1311   static bool FusedConv2DRewrite(const Node* n) {
1312     // MKL DNN currently doesn't support all fusions that grappler fuses
1313     // together with Conv2D (ex. batchnorm). We rewrite _FusedConv2D only if
1314     // it includes those we support.
1315     DataType T;
1316     if (!GetNodeAttr(n->def(), "T", &T).ok() ||
1317         !mkl_op_registry::IsMklOp(csinfo_.mkl_fused_conv2d, T)) {
1318       return false;
1319     }
1320 
1321     std::vector<string> fused_ops;
1322     TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops));
1323     return (fused_ops == std::vector<string>{"BiasAdd"} ||
1324             fused_ops == std::vector<string>{"Relu"} ||
1325             fused_ops == std::vector<string>{"BiasAdd", "Relu"});
1326   }
1327 
  // Rewrites input node to a new node specified by its matching rewrite info.
  //
  // The matching rewrite info 'ri' is supplied by the caller (see
  // CheckForNodeRewrite) and is used to build the replacement node.
  //
  // Input node may be deleted in case of rewrite. Attempt to use the node
  // after the call can result in undefined behaviors.
  //
  // @input  g - input graph, n - Node to be rewritten,
  //         ri - matching rewriteinfo
  // @return Status::OK(), if the input node is rewritten;
  //         Returns appropriate Status error code otherwise.
  //         Graph is updated in case the input node is rewritten.
  //         Otherwise, it is not updated.
  Status RewriteNode(std::unique_ptr<Graph>* g, Node* n, const RewriteInfo* ri);
1343 
  // Get nodes that will feed a list of TF tensors to the new
  // node that we are constructing.
  //
  // @input inputs - inputs to old node that we are using for constructing
  //                 new inputs,
  // @input input_idx - the index in the 'inputs' vector pointing to the
  //                    current input that we have processed so far
  // @output input_idx - index will be incremented by the number of nodes
  //                     from 'inputs' that are processed
  // @input list_length - The expected length of list of TF tensors
  // @output output_nodes - the list of new nodes creating TF tensors
  //
  // @return None
  void GetNodesProducingTFTensorList(
      const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
      int* input_idx, int list_length,
      std::vector<NodeBuilder::NodeOut>* output_nodes);
1362 
  // Get nodes that will feed a list of Mkl tensors to the new
  // node that we are constructing.
  //
  // @input g - input graph (new dummy nodes may be added to it),
  // @input orig_node - Original node that we are rewriting
  // @input inputs - inputs to old node that we are using for constructing
  //                 new inputs,
  // @input input_idx - the index in the 'inputs' vector pointing to the
  //                    current input that we have processed so far
  // @output input_idx - index will be incremented by the number of nodes
  //                     from 'inputs' that are processed
  // @input list_length - The expected length of list of Mkl tensors
  // @output output_nodes - the list of new nodes creating Mkl tensors
  //
  // @return None
  void GetNodesProducingMklTensorList(
      std::unique_ptr<Graph>* g, Node* orig_node,
      const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
      int* input_idx, int list_length,
      std::vector<NodeBuilder::NodeOut>* output_nodes);

  // Get a node that will feed an Mkl tensor to the new
  // node that we are constructing. The output node could be (1) 'n'
  // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
  // if 'n' is not an Mkl layer.
  //
  // @input g - input graph (a dummy node may be added to it),
  // @input orig_node - Original node that we are rewriting,
  // @input n - Node based on which we are creating Mkl node,
  // @input n_output_slot - the output slot of node 'n'
  //            which is feeding to the node that we are constructing
  // @output mkl_node - the new node that will feed Mkl tensor
  // @output mkl_node_output_slot - the slot number of mkl_node that
  //                                will feed the tensor
  // @return None
  void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node,
                                 Node* n, int n_output_slot, Node** mkl_node,
                                 int* mkl_node_output_slot);
1401 
1402   // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
1403   // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
1404   // set up in contiguous fashion. 'workspace_tensors' carry graph nodes
1405   // producing workspace edges if 'are_workspace_tensors_available' is true.
1406   // Otherwise, 'workspace_tensors' is empty vector.
1407   //
1408   // For details, refer to 'Ordering of inputs after rewriting' section in the
1409   // documentation above.
1410   //
  // Returns the number of input slots added to 'nb' (TF tensors, Mkl tensors,
  // and workspace tensors if any); a list-typed input counts as one slot.
1413   int SetUpContiguousInputs(
1414       std::unique_ptr<Graph>* g,
1415       const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
1416       NodeBuilder* nb, Node* old_node,
1417       std::vector<NodeBuilder::NodeOut>* workspace_tensors,
1418       bool are_workspace_tensors_available);
1419 
1420   // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
1421   // in graph 'g'. Original node is input in 'orig_node'.
1422   //
1423   // For details, refer to 'Ordering of Tensorflow tensors and Mkl tensors'
1424   // section in the documentation above.
1425   //
1426   // Returns Status::OK() if setting up inputs is successful, otherwise
1427   // returns appropriate status code.
1428   Status SetUpInputs(std::unique_ptr<Graph>* g,
1429                      const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1430                      NodeBuilder* nb, Node* orig_node);
1431 
  // Add workspace edge on the input or output side of Node 'orig_node' by using
  // NodeBuilder 'nb' for the new node provided. If 'orig_node' does not dictate
  // adding workspace edge then do not add it. Workspace Tensorflow and Mkl
  // tensors, if they need to be added, are returned in 'ws_tensors'; in that
  // case '*are_ws_tensors_added' should be set to true (false otherwise).
1437   void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, Node* orig_node,
1438                                 NodeBuilder* nb,
1439                                 std::vector<NodeBuilder::NodeOut>* ws_tensors,
1440                                 bool* are_ws_tensors_added);
1441 
1442   // Helper function used by FixMklMetaDataEdges. Fixes the metadata edge
1443   // pointed by 'e_metadata' corresponding to the data edge 'e_data' in graph
  // 'g'. Returns true if fixup was done; otherwise, it returns false.
1445   bool FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g, const Edge* e_data,
1446                                   const Edge* e_metadata);
1447 
1448   // Are the input Mkl metadata edges for node 'n' in graph 'g' correctly
1449   // connected? If not, then fix them. This is needed because a graph may have
1450   // some input Mkl metadata edges incorrectly setup after node merge and
1451   // rewrite passes. This could happen because GetReversePostOrder function may
1452   // not provide topologically sorted order if a graph contains cycles. The
1453   // function returns true if at least one Mkl metadata edge for node 'n' was
1454   // fixed. Otherwise, it returns false.
1455   //
1456   // Example:
1457   //
1458   // X = MklConv2D(_, _, _)
1459   // Y = MklConv2DWithBias(_, _, _, _, _, _)
1460   // Z = MklAdd(X, Y, DummyMklTensor, Y:1)
1461   //
1462   // For a graph such as shown above, note that 3rd argument of MklAdd contains
1463   // DummyMklTensor. Actually, it should be getting the Mkl metadata from
1464   // MklConv2D op (specifically, X:2). This incorrect plumbing could be possible
1465   // (although rare) if the Mkl NodeMerge + NodeRewrite passes visit Z before X
1466   // (possible if X, Y, Z are part of a loop.) This function fixes the Mkl
1467   // metadata edges only - it does not rewrite nodes nor does it modify the Mkl
1468   // data edges (1st and 2nd arguments of MklAdd).
1469   bool FixMklMetaDataEdges(std::unique_ptr<Graph>* g, Node* n);
1470 
1471   // Functions specific to operators to copy attributes
1472   // We need operator-specific function to copy attributes because the framework
1473   // does not provide any generic function for it.
1474   // NOTE: names are alphabetically sorted.
1475   static void CopyAttrsAddN(const Node* orig_node, NodeBuilder* nb,
1476                             bool change_format = false);
1477   static void CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb,
1478                                    bool change_format = false);
1479   static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb,
1480                               bool change_format = false);
1481   static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb,
1482                                 bool change_format = false);
1483   static void CopyAttrsConv(const Node* orig_node, NodeBuilder* nb,
1484                             bool change_format = false);
1485   static void CopyAttrsConv2DDepthwise(const Node* orig_node, NodeBuilder* nb,
1486                                        bool change_format = false);
1487   static void CopyAttrsConv2DDepthwiseCheckConstFilter(
1488       const Node* orig_node, NodeBuilder* nb, bool change_format = false);
1489   static void CopyAttrsConvCheckConstFilter(const Node* orig_node,
1490                                             NodeBuilder* nb,
1491                                             bool change_format = false);
1492   static void CopyAttrsDataType(const Node* orig_node, NodeBuilder* nb,
1493                                 bool change_format = false);
1494   static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb,
1495                                       bool change_format = false);
1496   static void CopyAttrsLeakyRelu(const Node* orig_node, NodeBuilder* nb,
1497                                  bool change_format = false);
1498   static void CopyAttrsFusedConv2D(const Node* orig_node, NodeBuilder* nb,
1499                                    bool change_format = false);
1500   static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb,
1501                            bool change_format = false);
1502   static void CopyAttrsPadWithConv2D(const Node* orig_node, NodeBuilder* nb,
1503                                      bool change_format = false);
1504   static void CopyAttrsPadWithFusedConv2D(const Node* orig_node,
1505                                           NodeBuilder* nb,
1506                                           bool change_format = false);
1507   static void CopyAttrsFromPadAndConv2D(const Node* orig_node1,
1508                                         const Node* orig_node2, NodeBuilder* nb,
1509                                         bool change_format = false);
1510   static void CopyAttrsFromPadAndFusedConv2D(const Node* orig_node1,
1511                                              const Node* orig_node2,
1512                                              NodeBuilder* nb,
1513                                              bool change_format = false);
1514   static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb,
1515                                bool change_format = false);
1516   static void CopyAttrsQuantizedPooling(const Node* orig_node, NodeBuilder* nb,
1517                                         bool change_format = false);
1518   static void CopyAttrsQuantizedConv2D(const Node* orig_node, NodeBuilder* nb,
1519                                        bool change_format = false);
1520   static void CopyAttrsQuantizedConcat(const Node* orig_node, NodeBuilder* nb,
1521                                        bool change_format = false);
1522   static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb,
1523                                bool change_format = false);
1524   static void CopyAttrsRequantize(const Node* orig_node, NodeBuilder* nb,
1525                                   bool change_format = false);
1526   static void CopyAttrsSlice(const Node* orig_node, NodeBuilder* nb,
1527                              bool change_format = false);
1528   static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb,
1529                              bool change_format = false);
1530   static void CopyFormatAttrsConv(const Node* orig_node, NodeBuilder* nb,
1531                                   const std::vector<int32>& strides,
1532                                   const std::vector<int32>& dilations,
1533                                   bool change_format = false);
1534 
1535   // Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
  // placed on the same device as 'orig_node', and return it in '*out'.
1537   // TODO(nhasabni) We should move this to mkl_util.h
1538   void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out,
1539                              Node* orig_node);
1540   void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out,
1541                                    Node* orig_node);
1542 };
1543 
// Out-of-class definition of the static table of op-name strings used by the
// pass.
MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;

// We register Mkl rewrite pass for phase 1 in post partitioning group.
// We register it here so that we get a complete picture of all users of Mkl
// nodes. Do not change the ordering of the Mkl passes.
// The registration itself is compiled only when ENABLE_MKL is defined;
// otherwise the pass is built but never registered.
const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
    OptimizationPassRegistry::POST_PARTITIONING;
#ifdef ENABLE_MKL
REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
#endif  // ENABLE_MKL
1554 
1555 //////////////////////////////////////////////////////////////////////////
1556 //           Helper functions for creating new node
1557 //////////////////////////////////////////////////////////////////////////
1558 
FillInputs(const Node * n,gtl::InlinedVector<Node *,4> * control_edges,gtl::InlinedVector<std::pair<Node *,int>,4> * in)1559 static void FillInputs(const Node* n,
1560                        gtl::InlinedVector<Node*, 4>* control_edges,
1561                        gtl::InlinedVector<std::pair<Node*, int>, 4>* in) {
1562   control_edges->clear();
1563   for (const Edge* e : n->in_edges()) {
1564     if (e->IsControlEdge()) {
1565       control_edges->push_back(e->src());
1566     } else {
1567       (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output());
1568     }
1569   }
1570   std::sort(control_edges->begin(), control_edges->end());
1571   if (n->op_def().is_commutative()) {
1572     // For commutative inputs, we sort the input by the input Node*
1573     // to get a canonical ordering (so that add(a,b) and add(b, a) will
1574     // hash to the same value if is_commutative is true for 'add').
1575     std::sort(in->begin(), in->end());
1576   }
1577 }
1578 
GetNodesProducingTFTensorList(const gtl::InlinedVector<std::pair<Node *,int>,4> & inputs,int * input_idx,int list_length,std::vector<NodeBuilder::NodeOut> * output_nodes)1579 void MklLayoutRewritePass::GetNodesProducingTFTensorList(
1580     const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
1581     int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
1582   CHECK_LT(*input_idx, inputs.size());
1583   CHECK_GT(list_length, 0);
1584   CHECK_NOTNULL(output_nodes);
1585   output_nodes->reserve(list_length);
1586 
1587   while (list_length != 0) {
1588     CHECK_GT(list_length, 0);
1589     CHECK_LT(*input_idx, inputs.size());
1590     Node* n = inputs[*input_idx].first;
1591     int slot = inputs[*input_idx].second;
1592     // If input node 'n' is just producing a single tensor at
1593     // output slot 'slot' then we just add that single node.
1594     output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
1595     (*input_idx)++;
1596     list_length--;
1597   }
1598 }
1599 
1600 // TODO(nhasabni) We should move this to mkl_util.h.
GetDummyMklTensorNode(std::unique_ptr<Graph> * g,Node ** out,Node * orig_node)1601 void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
1602                                                  Node** out, Node* orig_node) {
1603   // We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent
1604   // dummy Mkl tensor. 8 = 2*size_t.
1605   const DataType dt = DataTypeToEnum<uint8>::v();
1606   TensorProto proto;
1607   proto.set_dtype(dt);
1608   uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0};
1609   proto.set_tensor_content(string(reinterpret_cast<char*>(&zero), 8));
1610   TensorShape dummy_shape({8});
1611   dummy_shape.AsProto(proto.mutable_tensor_shape());
1612   TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
1613                   .Attr("value", proto)
1614                   .Attr("dtype", dt)
1615                   .Device(orig_node->def().device())  // We place this node on
1616                                                       // the same device as the
1617                                                       // device of the original
1618                                                       // node.
1619                   .Finalize(&**g, out));
1620   CHECK_NOTNULL(*out);  // Make sure we got a valid object before using it
1621 
1622   // If number of inputs to the original node is > 0, then we add
1623   // control dependency between 1st input (index 0) of the original node and
1624   // the dummy Mkl node. This is needed because control-flow ops such as Enter,
1625   // Merge, etc, require frame_name of the dummy Mkl node to be same as the
1626   // rewritten node. Adding control edge between 1st input of the original node
1627   // and the dummy Mkl node ensures that the dummy node is in the same frame
1628   // as the original node. Choosing 1st input is not necessary - any input of
1629   // the original node is fine because all the inputs of a node are always in
1630   // the same frame.
1631   if (orig_node->num_inputs() > 0) {
1632     Node* orig_input0 = nullptr;
1633     TF_CHECK_OK(
1634         orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
1635     // Allow duplicate while adding control edge as it would fail (return
1636     // NULL) if we try to add duplicate edge.
1637     CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true));
1638   }
1639 
1640   (*out)->set_assigned_device_name(orig_node->assigned_device_name());
1641 }
1642 
GetNodesProducingMklTensorList(std::unique_ptr<Graph> * g,Node * orig_node,const gtl::InlinedVector<std::pair<Node *,int>,4> & inputs,int * input_idx,int list_length,std::vector<NodeBuilder::NodeOut> * output_nodes)1643 void MklLayoutRewritePass::GetNodesProducingMklTensorList(
1644     std::unique_ptr<Graph>* g, Node* orig_node,
1645     const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
1646     int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
1647   CHECK_LT(*input_idx, inputs.size());
1648   CHECK_GT(list_length, 0);
1649   CHECK_NOTNULL(output_nodes);
1650   output_nodes->reserve(list_length);
1651 
1652   while (list_length != 0) {
1653     CHECK_GT(list_length, 0);
1654     CHECK_LT(*input_idx, inputs.size());
1655     Node* n = inputs[*input_idx].first;
1656     int slot = inputs[*input_idx].second;
1657     // If 'n' is producing a single tensor, then create a single Mkl tensor
1658     // node.
1659     Node* mkl_node = nullptr;
1660     int mkl_node_output_slot = 0;
1661     GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node,
1662                               &mkl_node_output_slot);
1663     output_nodes->push_back(
1664         NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
1665     (*input_idx)++;
1666     list_length--;
1667   }
1668 }
1669 
1670 // Get an input node that will feed Mkl tensor to the new
1671 // node that we are constructing. An input node could be (1) 'n'
1672 // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
1673 // if 'n' is not an Mkl layer.
GetNodeProducingMklTensor(std::unique_ptr<Graph> * g,Node * orig_node,Node * n,int n_output_slot,Node ** mkl_node,int * mkl_node_output_slot)1674 void MklLayoutRewritePass::GetNodeProducingMklTensor(
1675     std::unique_ptr<Graph>* g, Node* orig_node, Node* n, int n_output_slot,
1676     Node** mkl_node, int* mkl_node_output_slot) {
1677   CHECK_NOTNULL(n);
1678   CHECK_NOTNULL(mkl_node);
1679   CHECK_NOTNULL(mkl_node_output_slot);
1680 
1681   // If this is an MKL op, then it will create extra output for MKL layout.
1682   DataType T;
1683   if (GetNodeAttr(n->def(), "T", &T).ok() &&
1684       mkl_op_registry::IsMklOp(n->type_string(), T)) {
1685     // If this is an MKL op, then it will generate an edge that will receive
1686     // Mkl tensor from a node.
1687     // output slot number for Mkl tensor would be N+slot number of TensorFlow
1688     // tensor, where N is total number of TensorFlow tensors.
1689     *mkl_node = n;
1690     *mkl_node_output_slot =
1691         GetTensorMetaDataIndex(n_output_slot, n->num_outputs());
1692   } else {
1693     // If we have not visited the node and rewritten it, then we need
1694     // to create a dummy node that will feed a dummy Mkl tensor to this node.
1695     // DummyMklTensor node has no input and generates only 1 output
1696     // (dummy Mkl tensor) as output slot number 0.
1697     GetDummyMklTensorNode(g, mkl_node, orig_node);
1698     CHECK_NOTNULL(*mkl_node);
1699     *mkl_node_output_slot = 0;
1700   }
1701 }
1702 
SetUpContiguousInputs(std::unique_ptr<Graph> * g,const gtl::InlinedVector<std::pair<Node *,int>,4> & old_node_inputs,NodeBuilder * nb,Node * old_node,std::vector<NodeBuilder::NodeOut> * workspace_tensors,bool are_workspace_tensors_available)1703 int MklLayoutRewritePass::SetUpContiguousInputs(
1704     std::unique_ptr<Graph>* g,
1705     const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
1706     NodeBuilder* nb, Node* old_node,
1707     std::vector<NodeBuilder::NodeOut>* workspace_tensors,
1708     bool are_workspace_tensors_available) {
1709   CHECK_NOTNULL(workspace_tensors);
1710   CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
1711 
1712   // TODO(nhasabni): Temporary solution to connect filter input of
1713   // BackpropInput with the converted filter from Conv2D.
1714   bool do_connect_conv2d_backprop_input_filter = false;
1715   Node* conv2d_node = nullptr;
1716   // Filter node is 2nd input (slot index 1) of Conv2D.
1717   int kConv2DFilterInputSlotIdx = 1;
1718   int kConv2DBackpropInputFilterInputSlotIdx = 1;
1719   int kConv2DFilterOutputSlotIdx = 1;
1720   if (old_node->type_string() == csinfo_.conv2d_grad_input) {
1721     // We need to find Conv2D node from Conv2DBackpropInput.
1722     // For that let's first find filter node that is 2nd input (slot 1)
1723     // of BackpropInput.
1724     Node* filter_node = nullptr;
1725     TF_CHECK_OK(old_node->input_node(kConv2DBackpropInputFilterInputSlotIdx,
1726                                      &filter_node));
1727     CHECK_NOTNULL(filter_node);
1728 
1729     // Now check which nodes receive from filter_node. Filter feeds as
1730     // 2nd input (slot 1) of _MklConv2D, _MklConv2DWithBias, and
1731     // _MklFusedConv2D.
1732     for (const Edge* e : filter_node->out_edges()) {
1733       if ((e->dst()->type_string() == csinfo_.mkl_conv2d ||
1734            e->dst()->type_string() == csinfo_.mkl_pad_with_conv2d ||
1735            e->dst()->type_string() == csinfo_.mkl_pad_with_fused_conv2d ||
1736            e->dst()->type_string() == csinfo_.mkl_conv2d_with_bias ||
1737            e->dst()->type_string() == csinfo_.mkl_fused_conv2d) &&
1738           e->dst_input() == kConv2DFilterInputSlotIdx
1739           /* filter is 2nd input of Conv2D and _MklConv2D. */) {
1740         if (conv2d_node != nullptr) {
1741           VLOG(1) << "MklLayoutRewritePass: unusual case of same filter"
1742                   << " feeding multiple Conv2D nodes: "
1743                   << filter_node->DebugString();
1744           // We will not connect filter input of Conv2DBackpropInput
1745           // to be safe here.
1746           do_connect_conv2d_backprop_input_filter = false;
1747           break;
1748         } else {
1749           conv2d_node = e->dst();
1750           do_connect_conv2d_backprop_input_filter = true;
1751         }
1752       }
1753     }
1754   }
1755 
1756   // Number of input slots to original op
1757   // Input slots are represented by .Input() calls in REGISTER_OP.
1758   int old_node_input_slots = old_node->op_def().input_arg_size();
1759   // Actual number of inputs can be greater than or equal to number
1760   // of Input slots because inputs of type list could be unfolded.
1761   CHECK_GE(old_node_inputs.size(), old_node_input_slots);
1762   int nn_slot_idx = 0;  // slot index for inputs of new node
1763 
1764   // Let's copy all inputs (TF tensors) of original node to new node.
1765   int iidx = 0;
1766   for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
1767     // An input slot could be a single tensor or a list. We need
1768     // to handle this case accordingly.
1769     CHECK_LT(iidx, old_node_inputs.size());
1770     const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
1771     if (ArgIsList(arg)) {
1772       std::vector<NodeBuilder::NodeOut> new_node_inputs;
1773       int N = GetTensorListLength(arg, old_node);
1774       GetNodesProducingTFTensorList(old_node_inputs, &iidx, N,
1775                                     &new_node_inputs);
1776       nb->Input(new_node_inputs);
1777       nn_slot_idx++;
1778     } else {
1779       // Special case for connecting filter input of Conv2DBackpropInput
1780       if (do_connect_conv2d_backprop_input_filter &&
1781           iidx == kConv2DBackpropInputFilterInputSlotIdx) {
1782         nb->Input(conv2d_node, kConv2DFilterOutputSlotIdx);
1783       } else {
1784         nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
1785       }
1786       iidx++;
1787       nn_slot_idx++;
1788     }
1789   }
1790 
1791   // If workspace tensors are available for this op and we are using
1792   // contiguous ordering then we need to add Tensorflow tensor for
1793   // workspace here because Tensorflow tensor for workspace is the
1794   // last tensor in the list of Tensorflow tensors.
1795   if (are_workspace_tensors_available) {
1796     CHECK_EQ(workspace_tensors->size(), 2);
1797     // Tensorflow tensor
1798     nb->Input((*workspace_tensors)[0].node, (*workspace_tensors)[0].index);
1799     nn_slot_idx++;
1800   }
1801 
1802   // Let's now setup all Mkl inputs to a new node.
1803   // Number of Mkl inputs must be same as number of TF inputs.
1804   iidx = 0;
1805   for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
1806     // An input slot could be a single tensor or a list. We need
1807     // to handle this case accordingly.
1808     CHECK_LT(iidx, old_node_inputs.size());
1809     const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
1810     if (ArgIsList(arg)) {
1811       std::vector<NodeBuilder::NodeOut> new_node_inputs;
1812       int N = GetTensorListLength(arg, old_node);
1813       GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N,
1814                                      &new_node_inputs);
1815       nb->Input(new_node_inputs);
1816       nn_slot_idx++;
1817     } else {
1818       Node* mkl_node = nullptr;
1819       int mkl_node_output_slot = 0;
1820       // Special case for connecting filter input of Conv2DBackpropInput
1821       if (do_connect_conv2d_backprop_input_filter &&
1822           iidx == kConv2DBackpropInputFilterInputSlotIdx) {
1823         GetNodeProducingMklTensor(g, old_node, conv2d_node,
1824                                   kConv2DFilterOutputSlotIdx, &mkl_node,
1825                                   &mkl_node_output_slot);
1826       } else {
1827         GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first,
1828                                   old_node_inputs[iidx].second, &mkl_node,
1829                                   &mkl_node_output_slot);
1830       }
1831       nb->Input(mkl_node, mkl_node_output_slot);
1832       iidx++;
1833       nn_slot_idx++;
1834     }
1835   }
1836 
1837   // If workspace tensors are available for this op and we are using
1838   // contiguous ordering then we need to add Mkl tensor for
1839   // workspace here because Mkl tensor for workspace is the
1840   // last tensor in the list of Mkl tensors.
1841   if (are_workspace_tensors_available) {
1842     CHECK_EQ(workspace_tensors->size(), 2);
1843     // Mkl tensor
1844     nb->Input((*workspace_tensors)[1].node, (*workspace_tensors)[1].index);
1845     nn_slot_idx++;
1846   }
1847 
1848   return nn_slot_idx;
1849 }
1850 
SetUpInputs(std::unique_ptr<Graph> * g,const gtl::InlinedVector<std::pair<Node *,int>,4> & old_node_inputs,NodeBuilder * nb,Node * old_node)1851 Status MklLayoutRewritePass::SetUpInputs(
1852     std::unique_ptr<Graph>* g,
1853     const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
1854     NodeBuilder* nb, Node* old_node) {
1855   // Let's check if we need to add workspace tensors for this node.
1856   // We add workspace edge only for MaxPool, LRN and BatchNorm.
1857   std::vector<NodeBuilder::NodeOut> workspace_tensors;
1858   bool are_workspace_tensors_available = false;
1859 
1860   // Avoid workspace check for QuantizedConv2D and the fused
1861   // Ops as they don't have attribute: "T".
1862   std::vector<string> quant_ops{
1863       "QuantizedConv2D",
1864       "QuantizedConv2DWithBias",
1865       "QuantizedConv2DAndRelu",
1866       "QuantizedConv2DWithBiasAndRelu",
1867       "QuantizedConv2DWithBiasSumAndRelu",
1868       "QuantizedConv2DAndRequantize",
1869       "QuantizedConv2DWithBiasAndRequantize",
1870       "QuantizedConv2DAndReluAndRequantize",
1871       "QuantizedConv2DWithBiasAndReluAndRequantize",
1872       "QuantizedConv2DWithBiasSumAndReluAndRequantize",
1873       "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize"};
1874   bool should_check_workspace =
1875       std::find(std::begin(quant_ops), std::end(quant_ops),
1876                 old_node->type_string()) == std::end(quant_ops);
1877   if (should_check_workspace)
1878     AddWorkSpaceEdgeIfNeeded(g, old_node, nb, &workspace_tensors,
1879                              &are_workspace_tensors_available);
1880 
1881   int new_node_input_slots = 0;
1882   if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
1883     // TODO(nhasabni): implement this function just for same of completion.
1884     // We do not use interleaved ordering right now.
1885     return Status(
1886         error::Code::UNIMPLEMENTED,
1887         "Interleaved ordering of tensors is currently not supported.");
1888   } else {
1889     CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
1890     new_node_input_slots = SetUpContiguousInputs(
1891         g, old_node_inputs, nb, old_node, &workspace_tensors,
1892         are_workspace_tensors_available);
1893   }
1894 
1895   // Sanity check
1896   int old_node_input_slots = old_node->op_def().input_arg_size();
1897   if (!are_workspace_tensors_available) {
1898     // If we are not adding workspace tensors for this op, then the total
1899     // number of input slots to the new node _must_ be 2 times the number
1900     // of input slots to the original node: N original Tensorflow tensors and
1901     // N for Mkl tensors corresponding to each Tensorflow tensors.
1902     CHECK_EQ(new_node_input_slots, old_node_input_slots * 2);
1903   } else {
1904     // If we are adding workspace tensors for this op, then the total
1905     // The total number of input slots to new node _must_ be 2 times the number
1906     // of input slots to the original node: N original Tensorflow tensors and
1907     // N for Mkl tensors corresponding to each Tensorflow tensors plus 2
1908     // (for workspace Tensorflow tensor and workspace Mkl tensor).
1909     CHECK_EQ(new_node_input_slots, old_node_input_slots * 2 + 2);
1910   }
1911 
1912   return Status::OK();
1913 }
1914 
1915 //////////////////////////////////////////////////////////////////////////
1916 //           Helper functions related to workspace pass
1917 //////////////////////////////////////////////////////////////////////////
1918 
1919 // TODO(nhasabni) We should move this to mkl_util.h.
GetDummyWorkspaceTensorNode(std::unique_ptr<Graph> * g,Node ** out,Node * orig_node)1920 void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
1921     std::unique_ptr<Graph>* g, Node** out, Node* orig_node) {
1922   // We use uint8 tensor of shape 8 with content {0,0,0,0,0,0,0,0} to represent
1923   // workspace tensor.
1924   GetDummyMklTensorNode(g, out, orig_node);
1925 }
1926 
AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph> * g,Node * orig_node,NodeBuilder * nb,std::vector<NodeBuilder::NodeOut> * ws_tensors,bool * are_ws_tensors_added)1927 void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
1928     std::unique_ptr<Graph>* g, Node* orig_node, NodeBuilder* nb,
1929     std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added) {
1930   bool workspace_edge_added = false;  // Default initializer
1931   CHECK_NOTNULL(are_ws_tensors_added);
1932   *are_ws_tensors_added = false;  // Default initializer
1933 
1934   DataType T;
1935   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
1936   for (auto ws : wsinfo_) {
1937     if (orig_node->type_string() == ws.fwd_op &&
1938         mkl_op_registry::IsMklOp(
1939             mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) {
1940       // If this op is a fwd op, then we need to check if there is an
1941       // edge from this node's fwd_slot to bwdop's bwd_slot. If there is
1942       // an edge, then we just add an attribute on this node for setting
1943       // workspace_passed to true. We don't add actual workspace edge
1944       // in this node. Actual workspace edge gets added in the backward
1945       // op for this node.
1946       for (const Edge* e : orig_node->out_edges()) {
1947         if (e->src_output() == ws.fwd_slot &&
1948             e->dst()->type_string() == ws.bwd_op &&
1949             e->dst_input() == ws.bwd_slot) {
1950           nb->Attr("workspace_enabled", true);
1951           VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
1952                   << orig_node->type_string();
1953           workspace_edge_added = true;
1954           // We found the edge that we were looking for, so break.
1955           break;
1956         }
1957       }
1958 
1959       if (!workspace_edge_added) {
1960         // If we are here, then we did not find backward operator for this
1961         // node.
1962         nb->Attr("workspace_enabled", false);
1963       }
1964     } else if (orig_node->type_string() == ws.bwd_op &&
1965                mkl_op_registry::IsMklOp(
1966                    mkl_op_registry::GetMklOpName(orig_node->type_string()),
1967                    T)) {
1968       // If this op is a bwd op, then we need to add workspace edge and
1969       // it's Mkl tensor edge between its corresponding fwd op and this
1970       // op. Corresponding fwd op is specified in 'fwd_op' field of
1971       // workspace info. fwd_slot and bwd_slot in workspace info specify
1972       // an edge between which slots connect forward and backward op.
1973       // Once all these criteria match, we add a workspace edge between
1974       // ws_fwd_slot and ws_bwd_slot. Its corresponding Mkl tensor is
1975       // determined by interleaved/contiguous ordering. Function
1976       // DataIndexToMetaDataIndex tells us the location of Mkl tensor
1977       // from the location of the Tensorflow tensor.
1978       for (const Edge* e : orig_node->in_edges()) {
1979         if (e->src_output() == ws.fwd_slot &&
1980             // We would have rewritten the forward op, so we need to use
1981             // GetMklOpName call to get its Mkl name.
1982             e->src()->type_string() ==
1983                 mkl_op_registry::GetMklOpName(ws.fwd_op) &&
1984             e->dst_input() == ws.bwd_slot) {
1985           nb->Attr("workspace_enabled", true);
1986           CHECK_NOTNULL(ws_tensors);
1987           // Add workspace edge between fwd op and bwd op.
1988           ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot));
1989           // Add Mkl tensor edge for workspace edge between fwd op and bwd op.
1990           ws_tensors->push_back(NodeBuilder::NodeOut(
1991               e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot,
1992                                                  e->src()->num_outputs())));
1993           *are_ws_tensors_added = true;
1994           // In terms of input ordering, we add these calls to add Input
1995           // here because workspace edge (and its Mkl tensor) is the last
1996           // edge in the fwdop and bwdop. So all inputs before workspace
1997           // tensor have been added by SetUpInputs function.
1998           VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
1999                   << orig_node->type_string();
2000           workspace_edge_added = true;
2001           // We found the edge that we were looking for, so break.
2002           break;
2003         }
2004       }
2005 
2006       // If we are here means we did not find fwd op that feeds to this
2007       // bwd op. So in this case, we need to generate dummy tensors for
2008       // workspace input and Mkl tensor for workspace, and set
2009       // workspace_enabled to false.
2010       if (!workspace_edge_added) {
2011         nb->Attr("workspace_enabled", false);
2012         Node* dmt_ws = nullptr;      // Dummy tensor for workspace
2013         Node* dmt_mkl_ws = nullptr;  // Dummy Mkl tensor for workspace
2014         GetDummyWorkspaceTensorNode(g, &dmt_ws, orig_node);
2015         GetDummyMklTensorNode(g, &dmt_mkl_ws, orig_node);
2016         CHECK_NOTNULL(dmt_ws);
2017         CHECK_NOTNULL(dmt_mkl_ws);
2018         CHECK_NOTNULL(ws_tensors);
2019         // We add dummy tensor as workspace tensor.
2020         ws_tensors->push_back(NodeBuilder::NodeOut(dmt_ws, 0));
2021         // We add dummy tensor as Mkl tensor for workspace tensor.
2022         ws_tensors->push_back(NodeBuilder::NodeOut(dmt_mkl_ws, 0));
2023         *are_ws_tensors_added = true;
2024         VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for "
2025                 << orig_node->type_string();
2026       }
2027     } else {
2028       // If this node does not match any workspace info, then we do not
2029       // do anything special for workspace propagation for it.
2030     }
2031   }
2032 }
2033 
2034 //////////////////////////////////////////////////////////////////////////
2035 // Op-specific functions to copy attributes from old node to new node
2036 //////////////////////////////////////////////////////////////////////////
2037 
CopyAttrsConvCheckConstFilter(const Node * orig_node,NodeBuilder * nb,bool change_format)2038 void MklLayoutRewritePass::CopyAttrsConvCheckConstFilter(const Node* orig_node,
2039                                                          NodeBuilder* nb,
2040                                                          bool change_format) {
2041   DataType T;
2042   string padding;
2043   std::vector<int32> strides;
2044   std::vector<int32> dilations;
2045 
2046   // Get all attributes from old node.
2047   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2048   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2049   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2050   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2051 
2052   Node* filter_node = nullptr;
2053   orig_node->input_node(1, &filter_node);
2054 
2055   // Add attributes to new node.
2056   nb->Attr("T", T);
2057   nb->Attr("padding", padding);
2058   nb->Attr("is_filter_const", filter_node->IsConstant());
2059 
2060   // Add attributes related to `data_format`.
2061   CopyFormatAttrsConv(orig_node, nb, strides, dilations, change_format);
2062 }
2063 
CopyAttrsConv(const Node * orig_node,NodeBuilder * nb,bool change_format)2064 void MklLayoutRewritePass::CopyAttrsConv(const Node* orig_node, NodeBuilder* nb,
2065                                          bool change_format) {
2066   DataType T;
2067   string padding;
2068   std::vector<int32> strides;
2069   std::vector<int32> dilations;
2070 
2071   // Get all attributes from old node.
2072   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2073   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2074   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2075   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2076 
2077   // Add attributes to new node.
2078   nb->Attr("T", T);
2079   nb->Attr("padding", padding);
2080 
2081   // Add attributes related to `data_format`.
2082   CopyFormatAttrsConv(orig_node, nb, strides, dilations, change_format);
2083 }
2084 
2085 // Used in rinfo when replacing __MklDummyPadWithConv2D by _MklPadWithConv2D
CopyAttrsPadWithConv2D(const Node * orig_node,NodeBuilder * nb,bool change_format)2086 void MklLayoutRewritePass::CopyAttrsPadWithConv2D(const Node* orig_node,
2087                                                   NodeBuilder* nb,
2088                                                   bool change_format) {
2089   DataType Tpaddings;
2090   DataType T;
2091   string data_format;
2092   string padding;
2093   std::vector<int32> strides;
2094   std::vector<int32> dilations;
2095   bool use_cudnn_on_gpu;
2096 
2097   // Get all attributes from old node.
2098   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2099   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2100   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2101   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2102   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2103   TF_CHECK_OK(
2104       GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
2105   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tpaddings", &Tpaddings));
2106 
2107   Node* filter_node = nullptr;
2108   orig_node->input_node(1, &filter_node);
2109 
2110   // Add attributes to new node.
2111   nb->Attr("T", T);
2112   nb->Attr("strides", strides);
2113   nb->Attr("dilations", dilations);
2114   nb->Attr("padding", padding);
2115   nb->Attr("is_filter_const", filter_node->IsConstant());
2116   nb->Attr("data_format", data_format);
2117   nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
2118   nb->Attr("Tpaddings", Tpaddings);
2119 }
2120 
CopyAttrsPadWithFusedConv2D(const Node * orig_node,NodeBuilder * nb,bool change_format)2121 void MklLayoutRewritePass::CopyAttrsPadWithFusedConv2D(const Node* orig_node,
2122                                                        NodeBuilder* nb,
2123                                                        bool change_format) {
2124   DataType Tpaddings;
2125 
2126   CopyAttrsFusedConv2D(orig_node, nb, change_format);
2127 
2128   // Get attributes from old node.
2129   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tpaddings", &Tpaddings));
2130   // Check if filter is a constant.
2131   Node* filter_node = nullptr;
2132   orig_node->input_node(1, &filter_node);
2133 
2134   // Add attributes to new node.
2135   nb->Attr("Tpaddings", Tpaddings);
2136   nb->Attr("is_filter_const", filter_node->IsConstant());
2137 }
2138 
2139 // Used with MergePadWithConv2D
CopyAttrsFromPadAndConv2D(const Node * orig_node1,const Node * orig_node2,NodeBuilder * nb,bool change_format)2140 void MklLayoutRewritePass::CopyAttrsFromPadAndConv2D(const Node* orig_node1,
2141                                                      const Node* orig_node2,
2142                                                      NodeBuilder* nb,
2143                                                      bool change_format) {
2144   DataType Tpaddings;
2145   DataType T;
2146   string data_format;
2147   string padding;
2148   std::vector<int32> strides;
2149   std::vector<int32> dilations;
2150   bool use_cudnn_on_gpu;
2151 
2152   // Get all attributes from old node 1.
2153   TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "T", &T));
2154   TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "strides", &strides));
2155   TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "dilations", &dilations));
2156   TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "padding", &padding));
2157   TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "data_format", &data_format));
2158   TF_CHECK_OK(
2159       GetNodeAttr(orig_node1->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
2160   // Get all attributes from old node 2.
2161   TF_CHECK_OK(GetNodeAttr(orig_node2->def(), "Tpaddings", &Tpaddings));
2162 
2163   // Add attributes to new node.
2164   nb->Attr("T", T);
2165   nb->Attr("strides", strides);
2166   nb->Attr("dilations", dilations);
2167   nb->Attr("padding", padding);
2168   nb->Attr("data_format", data_format);
2169   nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
2170   nb->Attr("Tpaddings", Tpaddings);
2171 }
2172 
CopyAttrsFromPadAndFusedConv2D(const Node * fused_conv2d,const Node * pad,NodeBuilder * nb,bool change_format)2173 void MklLayoutRewritePass::CopyAttrsFromPadAndFusedConv2D(
2174     const Node* fused_conv2d, const Node* pad, NodeBuilder* nb,
2175     bool change_format) {
2176   DataType T;
2177   int num_args;
2178   string data_format;
2179   string padding;
2180   std::vector<int32> strides;
2181   std::vector<int32> dilations;
2182   float epsilon;
2183   std::vector<string> fused_ops;
2184   DataType Tpaddings;
2185 
2186   // Get all attributes from old node.
2187   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "T", &T));
2188   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "num_args", &num_args));
2189   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "strides", &strides));
2190   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "padding", &padding));
2191   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "data_format", &data_format));
2192   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "dilations", &dilations));
2193   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "fused_ops", &fused_ops));
2194   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "epsilon", &epsilon));
2195   TF_CHECK_OK(GetNodeAttr(pad->def(), "Tpaddings", &Tpaddings));
2196 
2197   // Add attributes to new node.
2198   nb->Attr("T", T);
2199   nb->Attr("num_args", num_args);
2200   nb->Attr("strides", strides);
2201   nb->Attr("padding", padding);
2202   nb->Attr("data_format", data_format);
2203   nb->Attr("dilations", dilations);
2204   nb->Attr("epsilon", epsilon);
2205   nb->Attr("Tpaddings", Tpaddings);
2206   nb->Attr("fused_ops", fused_ops);
2207 }
2208 
CopyAttrsConv2DDepthwise(const Node * orig_node,NodeBuilder * nb,bool change_format)2209 void MklLayoutRewritePass::CopyAttrsConv2DDepthwise(const Node* orig_node,
2210                                                     NodeBuilder* nb,
2211                                                     bool change_format) {
2212   DataType T;
2213   string data_format;
2214   string padding;
2215   std::vector<int32> strides;
2216   std::vector<int32> dilations;
2217 
2218   // Get all attributes from old node.
2219   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2220   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2221   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2222   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2223   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2224 
2225   // Add attributes to new node.
2226   nb->Attr("T", T);
2227   nb->Attr("strides", strides);
2228   nb->Attr("dilations", dilations);
2229   nb->Attr("padding", padding);
2230   nb->Attr("data_format", data_format);
2231 }
2232 
CopyAttrsConv2DDepthwiseCheckConstFilter(const Node * orig_node,NodeBuilder * nb,bool change_format)2233 void MklLayoutRewritePass::CopyAttrsConv2DDepthwiseCheckConstFilter(
2234     const Node* orig_node, NodeBuilder* nb, bool change_format) {
2235   DataType T;
2236   string data_format;
2237   string padding;
2238   std::vector<int32> strides;
2239   std::vector<int32> dilations;
2240 
2241   // Get all attributes from old node.
2242   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2243   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2244   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2245   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2246   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2247 
2248   Node* filter_node = nullptr;
2249   orig_node->input_node(1, &filter_node);
2250 
2251   // Add attributes to new node.
2252   nb->Attr("T", T);
2253   nb->Attr("strides", strides);
2254   nb->Attr("dilations", dilations);
2255   nb->Attr("padding", padding);
2256   nb->Attr("is_filter_const", filter_node->IsConstant());
2257   nb->Attr("data_format", data_format);
2258 }
2259 
CopyAttrsAddN(const Node * orig_node,NodeBuilder * nb,bool change_format)2260 void MklLayoutRewritePass::CopyAttrsAddN(const Node* orig_node, NodeBuilder* nb,
2261                                          bool change_format) {
2262   DataType T;
2263   int N;
2264 
2265   // Get all attributes from old node.
2266   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2267   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
2268 
2269   // Add attributes to new node.
2270   nb->Attr("T", T);
2271   nb->Attr("N", N);
2272 }
2273 
CopyAttrsBiasAddGrad(const Node * orig_node,NodeBuilder * nb,bool change_format)2274 void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orig_node,
2275                                                 NodeBuilder* nb,
2276                                                 bool change_format) {
2277   DataType T;
2278   string data_format;
2279   std::vector<int32> strides;
2280 
2281   // Get all attributes from old node.
2282   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2283   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2284   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2285 
2286   // Add attributes to new node.
2287   nb->Attr("T", T);
2288   nb->Attr("strides", strides);
2289   nb->Attr("data_format", data_format);
2290 }
2291 
CopyAttrsLRN(const Node * orig_node,NodeBuilder * nb,bool change_format)2292 void MklLayoutRewritePass::CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb,
2293                                         bool change_format) {
2294   DataType T;
2295   int depth_radius;
2296   float bias;
2297   float alpha;
2298   float beta;
2299 
2300   // Get all attributes from old node.
2301   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2302   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "depth_radius", &depth_radius));
2303   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "bias", &bias));
2304   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "alpha", &alpha));
2305   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "beta", &beta));
2306 
2307   // Add attributes to new node.
2308   nb->Attr("T", T);
2309   nb->Attr("depth_radius", depth_radius);
2310   nb->Attr("bias", bias);
2311   nb->Attr("alpha", alpha);
2312   nb->Attr("beta", beta);
2313 }
2314 
CopyAttrsLeakyRelu(const Node * orig_node,NodeBuilder * nb,bool change_format)2315 void MklLayoutRewritePass::CopyAttrsLeakyRelu(const Node* orig_node,
2316                                               NodeBuilder* nb,
2317                                               bool change_format) {
2318   DataType T;
2319   float alpha;
2320 
2321   // Get all attributes from old node.
2322   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2323   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "alpha", &alpha));
2324 
2325   // Add attributes to new node.
2326   nb->Attr("T", T);
2327   nb->Attr("alpha", alpha);
2328 }
2329 
CopyAttrsPooling(const Node * orig_node,NodeBuilder * nb,bool change_format)2330 void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node,
2331                                             NodeBuilder* nb,
2332                                             bool change_format) {
2333   DataType T;
2334   string data_format;
2335   string padding;
2336   std::vector<int32> ksize, strides;
2337 
2338   // Get all attributes from old node.
2339   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2340   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize));
2341   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2342   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2343   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2344 
2345   // Add attributes to new node.
2346   nb->Attr("T", T);
2347   nb->Attr("ksize", ksize);
2348   nb->Attr("strides", strides);
2349   nb->Attr("padding", padding);
2350   nb->Attr("data_format", data_format);
2351 }
2352 
CopyAttrsDataType(const Node * orig_node,NodeBuilder * nb,bool change_format)2353 void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node,
2354                                              NodeBuilder* nb,
2355                                              bool change_format) {
2356   DataType T;
2357 
2358   // Get all attributes from old node.
2359   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2360 
2361   // Add attributes to new node.
2362   nb->Attr("T", T);
2363 }
2364 
CopyAttrsQuantizedPooling(const Node * orig_node,NodeBuilder * nb,bool change_format)2365 void MklLayoutRewritePass::CopyAttrsQuantizedPooling(const Node* orig_node,
2366                                                      NodeBuilder* nb,
2367                                                      bool change_format) {
2368   DataType T;
2369   string padding;
2370   std::vector<int32> ksize, strides;
2371 
2372   // Get all attributes from old node.
2373   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2374   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize));
2375   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2376   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2377 
2378   // Add attributes to new node.
2379   nb->Attr("T", T);
2380   nb->Attr("ksize", ksize);
2381   nb->Attr("strides", strides);
2382   nb->Attr("padding", padding);
2383 }
2384 
CopyAttrsQuantizedConv2D(const Node * orig_node,NodeBuilder * nb,bool change_format)2385 void MklLayoutRewritePass::CopyAttrsQuantizedConv2D(const Node* orig_node,
2386                                                     NodeBuilder* nb,
2387                                                     bool change_format) {
2388   DataType Tinput, Tfilter, out_type;
2389   string padding;
2390   string data_format("NHWC");
2391   std::vector<int32> strides, dilations, padding_list;
2392   bool has_padding_list = HasNodeAttr(orig_node->def(), "padding_list");
2393 
2394   // Get all attributes from old node.
2395   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tinput", &Tinput));
2396   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tfilter", &Tfilter));
2397   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "out_type", &out_type));
2398   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2399   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2400   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2401   if (has_padding_list) {
2402     TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding_list", &padding_list));
2403   }
2404 
2405   Node* filter_node = nullptr;
2406   orig_node->input_node(1, &filter_node);
2407 
2408   // Add attributes to new node.
2409   nb->Attr("Tinput", Tinput);
2410   nb->Attr("Tfilter", Tfilter);
2411   nb->Attr("out_type", out_type);
2412   nb->Attr("padding", padding);
2413   nb->Attr("is_filter_const", filter_node->IsConstant());
2414   nb->Attr("strides", strides);
2415   nb->Attr("dilations", dilations);
2416   nb->Attr("T", out_type);  // added "T" for facilitating MklToTf conversion.
2417   nb->Attr("data_format", data_format);
2418   if (has_padding_list) {
2419     nb->Attr("padding_list", padding_list);
2420   }
2421 
2422   // Requantization attr Tbias.
2423   DataType Tbias;
2424   Status bias_status = GetNodeAttr(orig_node->def(), "Tbias", &Tbias);
2425   if (bias_status.ToString() == "OK") nb->Attr("Tbias", Tbias);
2426 }
2427 
CopyAttrsRequantize(const Node * orig_node,NodeBuilder * nb,bool change_format)2428 void MklLayoutRewritePass::CopyAttrsRequantize(const Node* orig_node,
2429                                                NodeBuilder* nb,
2430                                                bool change_format) {
2431   DataType Tinput, out_type;
2432 
2433   // Get all attributes from old node.
2434   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tinput", &Tinput));
2435   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "out_type", &out_type));
2436 
2437   // Add attributes to new node.
2438   nb->Attr("Tinput", Tinput);
2439   nb->Attr("out_type", out_type);
2440 }
2441 
CopyAttrsReshape(const Node * orig_node,NodeBuilder * nb,bool change_format)2442 void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
2443                                             NodeBuilder* nb,
2444                                             bool change_format) {
2445   DataType T;
2446   DataType Tshape;
2447 
2448   // Get all attributes from old node.
2449   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2450   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape));
2451 
2452   // Add attributes to new node.
2453   nb->Attr("T", T);
2454   nb->Attr("Tshape", Tshape);
2455 }
2456 
CopyAttrsSlice(const Node * orig_node,NodeBuilder * nb,bool change_format)2457 void MklLayoutRewritePass::CopyAttrsSlice(const Node* orig_node,
2458                                           NodeBuilder* nb, bool change_format) {
2459   DataType T;
2460   DataType Index;
2461 
2462   // Get all attributes from old node.
2463   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2464   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Index", &Index));
2465 
2466   // Add attributes to new node.
2467   nb->Attr("T", T);
2468   nb->Attr("Index", Index);
2469 }
2470 
CopyAttrsSplit(const Node * orig_node,NodeBuilder * nb,bool change_format)2471 void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node,
2472                                           NodeBuilder* nb, bool change_format) {
2473   DataType T;
2474   string data_format;
2475   int num_split;
2476 
2477   // Get all attributes from old node.
2478   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2479   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "num_split", &num_split));
2480   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2481 
2482   // Add attributes to new node.
2483   nb->Attr("T", T);
2484   nb->Attr("num_split", num_split);
2485   nb->Attr("data_format", data_format);
2486 }
2487 
CopyFormatAttrsConv(const Node * orig_node,NodeBuilder * nb,const std::vector<int32> & strides,const std::vector<int32> & dilations,bool change_format)2488 void MklLayoutRewritePass::CopyFormatAttrsConv(
2489     const Node* orig_node, NodeBuilder* nb, const std::vector<int32>& strides,
2490     const std::vector<int32>& dilations, bool change_format) {
2491   string data_format;
2492 
2493   if (!change_format) {
2494     nb->Attr("strides", strides);
2495     nb->Attr("dilations", dilations);
2496 
2497     TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2498     nb->Attr("data_format", data_format);
2499   } else {
2500     std::vector<int32> new_strides;
2501     std::vector<int32> new_dilations;
2502     if (strides.size() == 5) {
2503       // `strides` and `dilations` also need to be changed according to
2504       // `data_format`. In this case, from `NDHWC` to `NCDHW`.
2505       new_strides = {strides[NDHWC::dim::N], strides[NDHWC::dim::C],
2506                      strides[NDHWC::dim::D], strides[NDHWC::dim::H],
2507                      strides[NDHWC::dim::W]};
2508 
2509       new_dilations = {dilations[NDHWC::dim::N], dilations[NDHWC::dim::C],
2510                        dilations[NDHWC::dim::D], dilations[NDHWC::dim::H],
2511                        dilations[NDHWC::dim::W]};
2512     } else {
2513       // `strides` and `dilations` also need to be changed according to
2514       // `data_format`. In this case, from `NHWC` to `NCHW`.
2515 
2516       new_strides = {strides[NHWC::dim::N], strides[NHWC::dim::C],
2517                      strides[NHWC::dim::H], strides[NHWC::dim::W]};
2518 
2519       new_dilations = {dilations[NHWC::dim::N], dilations[NHWC::dim::C],
2520                        dilations[NHWC::dim::H], dilations[NHWC::dim::W]};
2521     }
2522     nb->Attr("strides", new_strides);
2523     nb->Attr("dilations", new_dilations);
2524   }
2525 }
2526 
CopyAttrsConcat(const Node * orig_node,NodeBuilder * nb,bool change_format)2527 void MklLayoutRewritePass::CopyAttrsConcat(const Node* orig_node,
2528                                            NodeBuilder* nb,
2529                                            bool change_format) {
2530   DataType T;
2531   int N;
2532 
2533   // Get all attributes from old node.
2534   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2535   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
2536 
2537   // Add attributes to new node.
2538   nb->Attr("T", T);
2539   nb->Attr("N", N);
2540 }
2541 
CopyAttrsConcatV2(const Node * orig_node,NodeBuilder * nb,bool change_format)2542 void MklLayoutRewritePass::CopyAttrsConcatV2(const Node* orig_node,
2543                                              NodeBuilder* nb,
2544                                              bool change_format) {
2545   DataType T;
2546   int N;
2547   DataType tidx;
2548 
2549   // Get all attributes from old node.
2550   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2551   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
2552   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tidx", &tidx));
2553 
2554   // Add attributes to new node.
2555   nb->Attr("T", T);
2556   nb->Attr("N", N);
2557   nb->Attr("Tidx", tidx);
2558 }
2559 
CopyAttrsFusedBatchNorm(const Node * orig_node,NodeBuilder * nb,bool change_format)2560 void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node,
2561                                                    NodeBuilder* nb,
2562                                                    bool change_format) {
2563   DataType T;
2564   float epsilon;
2565   string data_format;
2566   bool is_training;
2567 
2568   // Get all attributes from old node.
2569   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2570   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "epsilon", &epsilon));
2571   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2572   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "is_training", &is_training));
2573 
2574   // Add attributes to new node.
2575   nb->Attr("T", T);
2576   nb->Attr("epsilon", epsilon);
2577   nb->Attr("data_format", data_format);
2578   nb->Attr("is_training", is_training);
2579 }
2580 
CopyAttrsFusedConv2D(const Node * orig_node,NodeBuilder * nb,bool change_format)2581 void MklLayoutRewritePass::CopyAttrsFusedConv2D(const Node* orig_node,
2582                                                 NodeBuilder* nb,
2583                                                 bool change_format) {
2584   DataType T;
2585   int num_args;
2586   float epsilon;
2587   string data_format;
2588   string padding;
2589   std::vector<int32> strides;
2590   std::vector<int32> dilations;
2591   std::vector<string> fused_ops;
2592 
2593   // Get all attributes from old node.
2594   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2595   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "num_args", &num_args));
2596   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2597   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2598   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2599   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2600   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "fused_ops", &fused_ops));
2601   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "epsilon", &epsilon));
2602 
2603   Node* filter_node = nullptr;
2604   orig_node->input_node(1, &filter_node);
2605 
2606   // Add attributes to new node.
2607   nb->Attr("T", T);
2608   nb->Attr("num_args", num_args);
2609   nb->Attr("strides", strides);
2610   nb->Attr("padding", padding);
2611   nb->Attr("is_filter_const", filter_node->IsConstant());
2612   nb->Attr("data_format", data_format);
2613   nb->Attr("dilations", dilations);
2614   nb->Attr("fused_ops", fused_ops);
2615   nb->Attr("epsilon", epsilon);
2616 }
2617 
2618 //////////////////////////////////////////////////////////////////////////
2619 //           Helper functions related to node merge pass
2620 //////////////////////////////////////////////////////////////////////////
2621 
// Returns a node that 'a' can be merged with according to minfo_, or nullptr.
//
// Every MergeInfo whose op1 or op2 equals a's op type is a candidate; more
// than one candidate is allowed for extensibility. A candidate's
// get_node_to_be_merged() picks the partner node. The partner is accepted
// only when FillInputs yields identical control-edge lists for 'a' and the
// partner. The first accepted partner is returned.
Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const {
  // TODO(nhasabni) Add check for type of node similar to CheckForNodeRewrite
  // once we support BiasAddGrad as Mkl layer.
  const string& op_type = a->type_string();

  std::vector<const MergeInfo*> candidates;
  for (const MergeInfo& info : minfo_) {
    if (op_type == info.op1 || op_type == info.op2) {
      candidates.push_back(&info);
    }
  }

  for (const MergeInfo* info : candidates) {
    Node* partner = info->get_node_to_be_merged(a);
    if (partner == nullptr) continue;

    // FillInputs also fills the data inputs, which are sized by each node's
    // input count; only the control edges are compared here.
    gtl::InlinedVector<Node*, 4> a_ctrl;
    gtl::InlinedVector<std::pair<Node*, int>, 4> a_inputs(a->num_inputs());
    FillInputs(a, &a_ctrl, &a_inputs);

    gtl::InlinedVector<Node*, 4> partner_ctrl;
    gtl::InlinedVector<std::pair<Node*, int>, 4> partner_inputs(
        partner->num_inputs());
    FillInputs(partner, &partner_ctrl, &partner_inputs);

    // Nodes with differing control dependencies must not be merged.
    if (a_ctrl != partner_ctrl) continue;
    return partner;
  }

  return nullptr;
}
2664 
MergeConv2DWithBiasAdd(std::unique_ptr<Graph> * g,Node * m,Node * n)2665 Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
2666                                                     Node* m, Node* n) {
2667   CHECK_EQ(((m->type_string() == csinfo_.bias_add &&
2668              n->type_string() == csinfo_.conv2d)) ||
2669                ((n->type_string() == csinfo_.bias_add &&
2670                  m->type_string() == csinfo_.conv2d)),
2671            true);
2672 
2673   // If 'm' is BiasAdd, then 'n' is Conv2D. Since Conv2D feeds BiasAdd,
2674   // BiasAdd is successor node, and Conv2D predecessor node.
2675   Node* pred = m->type_string() == csinfo_.bias_add ? n : m;
2676   Node* succ = m->type_string() == csinfo_.bias_add ? m : n;
2677 
2678   // 1. Get all attributes from input nodes.
2679   DataType T_pred, T_succ;
2680   string padding;
2681   std::vector<int32> strides;
2682   std::vector<int32> dilations;
2683   string data_format_pred, data_format_succ;
2684   bool use_cudnn_on_gpu;
2685   TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
2686   TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
2687   TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding));
2688   TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides));
2689   TF_CHECK_OK(GetNodeAttr(pred->def(), "dilations", &dilations));
2690   TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred));
2691   TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
2692   TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
2693   // We check to ensure that data formats of both succ and pred are same.
2694   // We expect them to be same, so we can enforce this as assert.
2695   // But assert can be too strict, so we enforce this as a check.
2696   // If the check fails, then we do not merge two nodes.
2697   // We also do same check for devices.
2698   if (data_format_pred != data_format_succ || T_pred != T_succ ||
2699       pred->assigned_device_name() != succ->assigned_device_name() ||
2700       pred->def().device() != succ->def().device()) {
2701     return Status(error::Code::INVALID_ARGUMENT,
2702                   "data_format or T attribute or devices of Conv2D and "
2703                   "BiasAdd do not match. Will skip node merge optimization");
2704   }
2705 
2706   const int succ_num = succ->num_inputs();
2707   gtl::InlinedVector<Node*, 4> succ_control_edges;
2708   gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
2709   FillInputs(succ, &succ_control_edges, &succ_in);
2710 
2711   const int pred_num = pred->num_inputs();
2712   gtl::InlinedVector<Node*, 4> pred_control_edges;
2713   gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
2714   FillInputs(pred, &pred_control_edges, &pred_in);
2715 
2716   // We need to ensure that Conv2D only feeds to BiasAdd (some other operator is
2717   // not expecting output of Conv2D). If this is not the case, then we cannot
2718   // merge Conv2D with BiasAdd.
2719   const int kFirstOutputSlot = 0;
2720   for (const Edge* e : pred->out_edges()) {
2721     if (e->src_output() == kFirstOutputSlot && e->dst() != succ) {
2722       return Status(error::Code::INVALID_ARGUMENT,
2723                     "Conv2D does not feed to BiasAdd, or "
2724                     "it feeds BiasAdd but has multiple outputs. "
2725                     "Will skip node merge optimization");
2726     }
2727   }
2728 
2729   // 2. Get inputs from both the nodes.
2730   // Find the 2 inputs from the conv and the bias from the add Bias.
2731   // Get operand 0, 1 of conv2D.
2732   CHECK_EQ(pred->in_edges().size(), 2);  // Conv2D must have 2 inputs.
2733   // Get operand 1 of add_bias
2734   // BiasAdd must have 2 inputs: Conv, bias
2735   CHECK_EQ(succ->in_edges().size(), 2);
2736 
2737   // We will use the node name of BiasAdd as the name of new node
2738   // Build new node. We use same name as original node, but change the op
2739   // name.
2740   NodeBuilder nb(succ->name(), csinfo_.conv2d_with_bias);
2741   nb.Input(pred_in[0].first, pred_in[0].second);  // In1 of Conv2D
2742   // pred_in[1] will be 2nd Tensorflow tensor for Conv2D.
2743   nb.Input(pred_in[1].first, pred_in[1].second);  // In2 of Conv2D
2744   // In1 of BiasAdd is same as output of Conv2D.
2745   nb.Input(succ_in[1].first, succ_in[1].second);  // In2 of BiasAdd
2746 
2747   // Copy attributes from Conv2D to Conv2DWithBias.
2748   CopyAttrsConvCheckConstFilter(const_cast<const Node*>(pred), &nb);
2749 
2750   // Copy the device assigned to old node to new node.
2751   nb.Device(succ->def().device());
2752 
2753   // Create node.
2754   Node* new_node;
2755   TF_CHECK_OK(nb.Finalize(&**g, &new_node));
2756   CHECK_NOTNULL(new_node);
2757 
2758   // Incoming data edges from 'pred' node and 'succ' node to new 'new_node'
2759   // node are already copied in BuildNode. We handle control edges now.
2760   for (const Edge* e : pred->in_edges()) {
2761     if (e->IsControlEdge()) {
2762       // Allow duplicate while adding control edge as it would fail (return
2763       // NULL) if we try to add duplicate edge.
2764       CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true));
2765     }
2766   }
2767   for (const Edge* e : succ->in_edges()) {
2768     if (e->IsControlEdge()) {
2769       // Allow duplicate while adding control edge as it would fail (return
2770       // NULL) if we try to add duplicate edge.
2771       CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true));
2772     }
2773   }
2774 
2775   // Incoming edges are fixed, we will fix the outgoing edges now.
2776   // First, we will fix outgoing control edges from 'pred' node.
2777   for (const Edge* e : pred->out_edges()) {
2778     if (e->IsControlEdge()) {
2779       // Allow duplicate while adding control edge as it would fail (return
2780       // NULL) if we try to add duplicate edge.
2781       CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
2782     }
2783   }
2784 
2785   // Second, we will fix outgoing control and data edges from 'succ' node.
2786   for (const Edge* e : succ->out_edges()) {
2787     if (e->IsControlEdge()) {
2788       // Allow duplicate while adding control edge as it would fail (return
2789       // NULL) if we try to add duplicate edge.
2790       CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
2791     } else {
2792       // BiasAdd has only 1 output (at slot 0) and merged node also has only 1
2793       // output (at slot 0).
2794       const int kConv2DWithBiasOutputSlot = 0;
2795       CHECK_NOTNULL((*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot, e->dst(),
2796                                   e->dst_input()));
2797     }
2798   }
2799 
2800   // Copy device assigned to old node to new node.
2801   // It's ok to use pred or succ as we have enforced a check that
2802   // both have same device assigned.
2803   new_node->set_assigned_device_name(pred->assigned_device_name());
2804 
2805   VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
2806           << ", and node: " << succ->DebugString()
2807           << ", into node:" << new_node->DebugString();
2808 
2809   (*g)->RemoveNode(succ);
2810   (*g)->RemoveNode(pred);
2811 
2812   return Status::OK();
2813 }
2814 
MergePadWithConv2D(std::unique_ptr<Graph> * g,Node * m,Node * n)2815 Status MklLayoutRewritePass::MergePadWithConv2D(std::unique_ptr<Graph>* g,
2816                                                 Node* m, Node* n) {
2817   DCHECK((m->type_string() == csinfo_.pad &&
2818           (n->type_string() == csinfo_.conv2d ||
2819            n->type_string() == csinfo_.fused_conv2d)) ||
2820          (n->type_string() == csinfo_.pad &&
2821           (m->type_string() == csinfo_.conv2d ||
2822            m->type_string() == csinfo_.fused_conv2d)));
2823 
2824   bool is_fused_conv2d = n->type_string() == csinfo_.fused_conv2d ||
2825                          m->type_string() == csinfo_.fused_conv2d;
2826   // Conv2D is successor node, and Pad predecessor node.
2827   Node* pred = m->type_string() == csinfo_.pad ? m : n;
2828   Node* succ = m->type_string() == csinfo_.pad ? n : m;
2829 
2830   // 1. Get all attributes from input nodes.
2831   DataType T_pred, T_succ;
2832   string padding;
2833   std::vector<int32> strides;
2834   std::vector<int32> dilations;
2835   string data_format_pred, data_format_succ;
2836 
2837   TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
2838   TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
2839   TF_CHECK_OK(GetNodeAttr(succ->def(), "padding", &padding));
2840   TF_CHECK_OK(GetNodeAttr(succ->def(), "strides", &strides));
2841   TF_CHECK_OK(GetNodeAttr(succ->def(), "dilations", &dilations));
2842   // Check if the devices of both succ and pred are the same.
2843   // Assert is not used because it can be too strict.
2844   // Don't need to check for data formats because it is not available in Pad.
2845   if (T_pred != T_succ ||
2846       pred->assigned_device_name() != succ->assigned_device_name() ||
2847       pred->def().device() != succ->def().device()) {
2848     return Status(error::Code::INVALID_ARGUMENT,
2849                   "T attribute or devices of Conv2D and "
2850                   "Pad do not match. Will skip node merge optimization");
2851   }
2852 
2853   const int succ_num = succ->num_inputs();
2854   gtl::InlinedVector<Node*, 4> succ_control_edges;
2855   gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
2856   FillInputs(succ, &succ_control_edges, &succ_in);
2857 
2858   const int pred_num = pred->num_inputs();
2859   gtl::InlinedVector<Node*, 4> pred_control_edges;
2860   gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
2861   FillInputs(pred, &pred_control_edges, &pred_in);
2862 
2863   // We need to ensure that Pad only feeds to Conv2D (some other operator is
2864   // not expecting output of Pad). If this is not the case, then we cannot
2865   // merge Conv2D with Pad.
2866   const int kFirstOutputSlot = 0;
2867   for (const Edge* e : pred->out_edges()) {
2868     if (e->src_output() == kFirstOutputSlot && e->dst() != succ) {
2869       return Status(error::Code::INVALID_ARGUMENT,
2870                     "Pad does not feed to Conv2D, or "
2871                     "it feeds Conv2D but has multiple outputs. "
2872                     "Will skip node merge optimization");
2873     }
2874   }
2875 
2876   // 2. Get inputs from both the nodes.
2877 
2878   // Pad must have 2 data inputs: "input" and paddings.
2879   int PadDataInputEdges = 0;
2880   for (const Edge* e : pred->in_edges()) {
2881     if (!e->IsControlEdge()) {
2882       PadDataInputEdges++;
2883     }
2884   }
2885   DCHECK_EQ(PadDataInputEdges, 2);
2886 
2887   // Conv2D must have 2 data inputs: Pad output and Filter
2888   // FusedConv2D have 3 data inputs: Pad output, Filter and Args;
2889   int ConvDataInputEdges = 0;
2890   for (const Edge* e : succ->in_edges()) {
2891     if (!e->IsControlEdge()) {
2892       ConvDataInputEdges++;
2893     }
2894   }
2895 
2896   DCHECK_EQ(ConvDataInputEdges, is_fused_conv2d ? 3 : 2);
2897 
2898   // We will use the node name of Conv2D as the name of new node
2899   // Build new node. We use same name as original node, but change the op
2900   // name.
2901 
2902   NodeBuilder nb(succ->name(), is_fused_conv2d ? csinfo_.pad_with_fused_conv2d
2903                                                : csinfo_.pad_with_conv2d);
2904   nb.Input(pred_in[0].first, pred_in[0].second);  // In1 (input data)  of Pad
2905   // pred_in[1] will be 2nd Tensorflow tensor for Conv2D.
2906   nb.Input(succ_in[1].first, succ_in[1].second);  // In2 (filter) of conv2d
2907   // In1 of Conv2D is same as output of Pad.
2908   // Thus, only need to add In2 of Conv2D
2909 
2910   if (is_fused_conv2d) {
2911     // FusedConv2D has one additional input, args
2912     std::vector<NodeBuilder::NodeOut> args;
2913     args.emplace_back(succ_in[2].first, succ_in[2].second);
2914     nb.Input(gtl::ArraySlice<NodeBuilder::NodeOut>{
2915         args});                                     // In3 (args) of FusedConv2D
2916     nb.Input(pred_in[1].first, pred_in[1].second);  // In2 (paddings) of Pad
2917     // Copy attributes from Pad and FusedConv2D to PadWithFusedConv2D.
2918     CopyAttrsFromPadAndFusedConv2D(const_cast<const Node*>(succ),
2919                                    const_cast<const Node*>(pred), &nb);
2920   } else {
2921     nb.Input(pred_in[1].first, pred_in[1].second);  // In2 (paddings) of Pad
2922     // Copy attributes from Pad and conv2D to PadWithConv2D.
2923     CopyAttrsFromPadAndConv2D(const_cast<const Node*>(succ),
2924                               const_cast<const Node*>(pred), &nb);
2925   }
2926 
2927   // Copy the device assigned to old node to new node.
2928   nb.Device(succ->def().device());
2929 
2930   // Create node.
2931   Node* new_node;
2932   TF_CHECK_OK(nb.Finalize(&**g, &new_node));
2933   DCHECK(new_node);
2934 
2935   // Incoming data edges from 'pred' node and 'succ' node to new 'new_node'
2936   // node are already copied in BuildNode.
2937   // We handle control edges now.
2938   for (const Edge* e : pred->in_edges()) {
2939     if (e->IsControlEdge()) {
2940       // Don't allow duplicate edge
2941       (*g)->AddControlEdge(e->src(), new_node, false);
2942     }
2943   }
2944   for (const Edge* e : succ->in_edges()) {
2945     if (e->IsControlEdge()) {
2946       // Don't allow duplicate edge
2947       (*g)->AddControlEdge(e->src(), new_node, false);
2948     }
2949   }
2950 
2951   // Incoming edges are fixed, we will fix the outgoing edges now.
2952   // First, we will fix outgoing control edges from 'pred' node.
2953   for (const Edge* e : pred->out_edges()) {
2954     if (e->IsControlEdge()) {
2955       // Don't allow duplicate edge
2956       (*g)->AddControlEdge(new_node, e->dst(), false);
2957     }
2958   }
2959 
2960   // Second, we will fix outgoing control and data edges from 'succ' node.
2961   for (const Edge* e : succ->out_edges()) {
2962     if (e->IsControlEdge()) {
2963       // Allow duplicate while adding control edge as it would fail (return
2964       // NULL) if we try to add duplicate edge.
2965       (*g)->AddControlEdge(new_node, e->dst(), false);
2966     } else {
2967       // Conv2D has only 1 output (at slot 0) and merged node also has only 1
2968       // output (at slot 0).
2969       const int kPadWithConv2DOutputSlot = 0;
2970       (*g)->AddEdge(new_node, kPadWithConv2DOutputSlot, e->dst(),
2971                     e->dst_input());
2972     }
2973   }
2974 
2975   // Copy device assigned to old node to new node.
2976   // It's ok to use pred or succ as we have enforced a check that
2977   // both have same device assigned.
2978   new_node->set_assigned_device_name(pred->assigned_device_name());
2979 
2980   VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
2981           << ", and node: " << succ->DebugString()
2982           << ", into node:" << new_node->DebugString();
2983 
2984   (*g)->RemoveNode(succ);
2985   (*g)->RemoveNode(pred);
2986 
2987   return Status::OK();
2988 }
2989 
MergeConv2DBackpropFilterWithBiasAddGrad(std::unique_ptr<Graph> * g,Node * m,Node * n)2990 Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad(
2991     std::unique_ptr<Graph>* g, Node* m, Node* n) {
2992   CHECK_EQ(((m->type_string() == csinfo_.bias_add_grad &&
2993              n->type_string() == csinfo_.conv2d_grad_filter)) ||
2994                ((n->type_string() == csinfo_.bias_add_grad &&
2995                  m->type_string() == csinfo_.conv2d_grad_filter)),
2996            true);
2997 
2998   // If 'm' is BiasAddGrad, then 'n' is BackpropFilter.
2999   Node* badd = m->type_string() == csinfo_.bias_add_grad ? m : n;
3000   Node* fltr = m->type_string() == csinfo_.bias_add_grad ? n : m;
3001 
3002   // Sanity check for attributes from input nodes.
3003   DataType T_b, T_f;
3004   string data_format_b, data_format_f;
3005   TF_CHECK_OK(GetNodeAttr(badd->def(), "T", &T_b));
3006   TF_CHECK_OK(GetNodeAttr(fltr->def(), "T", &T_f));
3007   TF_CHECK_OK(GetNodeAttr(badd->def(), "data_format", &data_format_b));
3008   TF_CHECK_OK(GetNodeAttr(fltr->def(), "data_format", &data_format_f));
3009   if (data_format_b != data_format_f || T_b != T_f ||
3010       badd->assigned_device_name() != fltr->assigned_device_name() ||
3011       badd->def().device() != fltr->def().device()) {
3012     return Status(error::Code::INVALID_ARGUMENT,
3013                   "data_format or T attribute or devices of "
3014                   "Conv2DBackpropFilter and BiasAddGrad do not match. "
3015                   "Will skip node merge optimization");
3016   }
3017 
3018   // We will use the node name of Conv2DBackpropFilter as the name of new node.
3019   // This is because BackpropFilterWithBias is going to emit bias output also.
3020   NodeBuilder nb(fltr->name(), csinfo_.conv2d_grad_filter_with_bias);
3021   // Since Conv2DBackpropFilterWithBias has same number of inputs as
3022   // Conv2DBackpropFilter, we can just copy input edges directly. We dont need
3023   // to copy any data input of BiasAddGrad because that input also goes to
3024   // Conv2DBackpropFilter.
3025   const int fltr_ins = fltr->num_inputs();
3026   gtl::InlinedVector<Node*, 4> fltr_control_edges;
3027   gtl::InlinedVector<std::pair<Node*, int>, 4> fltr_in_edges(fltr_ins);
3028   FillInputs(fltr, &fltr_control_edges, &fltr_in_edges);
3029   for (int idx = 0; idx < fltr_ins; idx++) {
3030     nb.Input(fltr_in_edges[idx].first, fltr_in_edges[idx].second);
3031   }
3032 
3033   // Copy attributes from Conv2DBackpropFilter.
3034   CopyAttrsConv(const_cast<const Node*>(fltr), &nb);
3035 
3036   // Copy the device assigned to old node to new node.
3037   nb.Device(fltr->def().device());
3038 
3039   // Create node.
3040   Node* new_node;
3041   TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3042   CHECK_NOTNULL(new_node);
3043 
3044   // Incoming data edges from BiasAddGrad node and Conv2DBackpropFilter node to
3045   // new 'new_node' node are already copied in BuildNode. We handle control
3046   // edges now.
3047   for (const Edge* e : badd->in_edges()) {
3048     if (e->IsControlEdge()) {
3049       // Allow duplicate while adding control edge as it would fail (return
3050       // NULL) if we try to add duplicate edge.
3051       CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true));
3052     }
3053   }
3054   for (const Edge* e : fltr->in_edges()) {
3055     if (e->IsControlEdge()) {
3056       // Allow duplicate while adding control edge as it would fail (return
3057       // NULL) if we try to add duplicate edge.
3058       CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true));
3059     }
3060   }
3061 
3062   // Incoming edges are fixed, we will fix the outgoing edges now.
3063   // First, we will fix outgoing control edges from 'badd' node.
3064   // Conv2DBackpropFilter has 1 output -- filter_grad.
3065   // Conv2DBackpropFilterWithBias has 2 outputs -- filter_grad and
3066   // bias_grad. But filter_grad is at same slot number (0) in both the
3067   // nodes. bias_grad is at slot number 1 in Conv2DBackpropFilterWithBias, while
3068   // it is at slot number 0 in BiasAddGrad.
3069   const int kMergedNodeFilterGradOutputIdx = 0;
3070   const int kMergedNodeBiasGradOutputIdx = 1;
3071 
3072   for (const Edge* e : badd->out_edges()) {
3073     if (e->IsControlEdge()) {
3074       // Allow duplicate while adding control edge as it would fail (return
3075       // NULL) if we try to add duplicate edge.
3076       CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
3077     } else {
3078       CHECK_NOTNULL((*g)->AddEdge(new_node, kMergedNodeBiasGradOutputIdx,
3079                                   e->dst(), e->dst_input()));
3080     }
3081   }
3082 
3083   // Second, we will fix outgoing control and data edges from 'fltr' node.
3084   for (const Edge* e : fltr->out_edges()) {
3085     if (e->IsControlEdge()) {
3086       // We allow duplicate edge for this case since we already add control
3087       // edge from new_node in line 3990. Line below could be adding same
3088       // edge to same destination again. In such case, if we do not allow
3089       // duplicate edge, then this call will fail.
3090       CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
3091     } else {
3092       CHECK_NOTNULL((*g)->AddEdge(new_node, kMergedNodeFilterGradOutputIdx,
3093                                   e->dst(), e->dst_input()));
3094     }
3095   }
3096 
3097   // Copy device assigned to old node to new node.
3098   // It's ok to use badd or fltr as we have enforced a check that
3099   // both have same device assigned.
3100   new_node->set_assigned_device_name(badd->assigned_device_name());
3101 
3102   VLOG(1) << "MklLayoutRewritePass: Merged old node:" << badd->DebugString()
3103           << ", and node: " << fltr->DebugString()
3104           << ", into node:" << new_node->DebugString();
3105 
3106   (*g)->RemoveNode(badd);
3107   (*g)->RemoveNode(fltr);
3108 
3109   return Status::OK();
3110 }
3111 
MergeNode(std::unique_ptr<Graph> * g,Node * m,Node * n)3112 Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* m,
3113                                        Node* n) {
3114   CHECK_NOTNULL(m);
3115   CHECK_NOTNULL(n);
3116 
3117   if (((m->type_string() == csinfo_.bias_add &&
3118         n->type_string() == csinfo_.conv2d)) ||
3119       ((n->type_string() == csinfo_.bias_add &&
3120         m->type_string() == csinfo_.conv2d))) {
3121     return this->MergeConv2DWithBiasAdd(g, m, n);
3122   }
3123   if ((m->type_string() == csinfo_.pad &&
3124        (n->type_string() == csinfo_.conv2d ||
3125         (n->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(n)))) ||
3126       (n->type_string() == csinfo_.pad &&
3127        (m->type_string() == csinfo_.conv2d ||
3128         (m->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(m))))) {
3129     return this->MergePadWithConv2D(g, m, n);
3130   }
3131 
3132   if (((m->type_string() == csinfo_.bias_add_grad &&
3133         n->type_string() == csinfo_.conv2d_grad_filter)) ||
3134       ((n->type_string() == csinfo_.bias_add_grad &&
3135         m->type_string() == csinfo_.conv2d_grad_filter))) {
3136     return this->MergeConv2DBackpropFilterWithBiasAddGrad(g, m, n);
3137   }
3138 
3139   return Status(error::Code::UNIMPLEMENTED,
3140                 "Unimplemented case for node merge optimization.");
3141 }
3142 
3143 //////////////////////////////////////////////////////////////////////////
3144 //           Helper functions for node rewrite
3145 //////////////////////////////////////////////////////////////////////////
3146 
RewriteNode(std::unique_ptr<Graph> * g,Node * orig_node,const RewriteInfo * ri)3147 Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
3148                                          Node* orig_node,
3149                                          const RewriteInfo* ri) {
3150   CHECK_NOTNULL(ri);
3151   CHECK_NOTNULL(orig_node);
3152 
3153   VLOG(1) << "MklLayoutRewritePass: Original node:" << orig_node->DebugString();
3154 
3155   // Get all inputs.
3156   int num_inputs = orig_node->in_edges().size();
3157 
3158   // Drop count for control edges from inputs
3159   for (const Edge* e : orig_node->in_edges()) {
3160     if (e->IsControlEdge()) {
3161       num_inputs--;
3162     }
3163   }
3164 
3165   gtl::InlinedVector<Node*, 4> control_edges;
3166   gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_inputs);
3167   FillInputs(orig_node, &control_edges, &inputs);
3168 
3169   // Build new node. We use same name as original node, but change the op name.
3170   NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str());
3171   // Copy user-specified device assigned to original node to new node.
3172   nb.Device(orig_node->def().device());
3173   // Set up new inputs to the rewritten node.
3174   Status s = SetUpInputs(g, inputs, &nb, orig_node);
3175   if (s != Status::OK()) {
3176     return s;
3177   }
3178 
3179   const bool kPartialCopyAttrs = false;
3180   ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, kPartialCopyAttrs);
3181 
3182   // Set the Mkl layer label for this op.
3183   if (DataTypeIsQuantized(orig_node->input_type(0)) ||
3184       DataTypeIsQuantized(orig_node->output_type(0))) {
3185     nb.Attr("_kernel", mkl_op_registry::kMklQuantizedOpLabel);
3186   } else {
3187     nb.Attr("_kernel", mkl_op_registry::kMklOpLabel);
3188   }
3189   // Finalize graph and get new node.
3190   Node* new_node = nullptr;
3191   TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3192   CHECK_NOTNULL(new_node);
3193 
3194   // Incoming data edges from 'orig_node' node to new 'new_node' node are
3195   // already copied in BuildNode. We need to handle control edges now.
3196   for (const Edge* e : orig_node->in_edges()) {
3197     if (e->IsControlEdge()) {
3198       // Allow duplicate while adding control edge as it would fail (return
3199       // NULL) if we try to add duplicate edge.
3200       CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true));
3201     }
3202   }
3203 
3204   // Copy outgoing edges from 'orig_node' node to new
3205   // 'new_node' node, since the output also follows same ordering among
3206   // Tensorflow tensors and Mkl tensors. We need to connect Tensorflow
3207   // tensors appropriately. Specifically, nth output of the original node
3208   // will become 2*nth output of the Mkl node for the interleaved ordering
3209   // of the tensors. For the contiguous ordering of the tensors, it will be n.
3210   // GetTensorDataIndex provides this mapping function.
3211   for (const Edge* e : orig_node->out_edges()) {
3212     if (e->IsControlEdge()) {
3213       // Allow duplicate while adding control edge as it would fail (return
3214       // NULL) if we try to add duplicate edge.
3215       CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
3216     } else {
3217       CHECK_NOTNULL((*g)->AddEdge(
3218           new_node,
3219           GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
3220           e->dst(), e->dst_input()));
3221     }
3222   }
3223 
3224   // Copy the runtime device assigned from original code to new node.
3225   new_node->set_assigned_device_name(orig_node->assigned_device_name());
3226 
3227   // Delete original node and mark new node as rewritten.
3228   (*g)->RemoveNode(orig_node);
3229 
3230   VLOG(1) << "MklLayoutRewritePass: New node:" << new_node->DebugString();
3231   return Status::OK();
3232 }
3233 
3234 // TODO(mdfaijul): Is there any other elegent way to check for quantized ops
3235 // having attributes other than "T"?
3236 // Current implementation reflects only QuantizedConv2D and its fused Ops.
3237 const MklLayoutRewritePass::RewriteInfo*
CheckForQuantizedNodeRewrite(const Node * n) const3238 MklLayoutRewritePass::CheckForQuantizedNodeRewrite(const Node* n) const {
3239   DataType Tinput, Tfilter;
3240   if (!(GetNodeAttr(n->def(), "Tinput", &Tinput).ok() &&
3241         GetNodeAttr(n->def(), "Tfilter", &Tfilter).ok())) {
3242     return nullptr;
3243   }
3244   if (mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
3245                                Tinput, Tfilter)) {
3246     for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
3247       if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
3248         return &*ri;
3249       }
3250     }
3251   }
3252   return nullptr;
3253 }
3254 
3255 const MklLayoutRewritePass::RewriteInfo*
CheckForNodeRewrite(const Node * n) const3256 MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
3257   CHECK_NOTNULL(n);
3258 
3259   // QuntizedOps may have attributes other than "T", so decoupled the check
3260   // with a function, CheckForQuantizedNodeRewrite(const Node*).
3261   const RewriteInfo* ri = CheckForQuantizedNodeRewrite(n);
3262   if (ri != nullptr) return ri;
3263 
3264   // First check if node along with its type is supported by MKL layer.
3265   // We do not want to rewrite an op into Mkl op if types are not supported.
3266   // E.g., MklRelu does not support INT32. So we cannot rewrite Relu to
3267   // MklRelu if type is INT32.
3268   DataType T;
3269   if (!GetNodeAttr(n->def(), "T", &T).ok()) {
3270     return nullptr;
3271   }
3272 
3273   // We make an exception for __MklDummyConv2DWithBias,
3274   // __MklConv2DBackpropFilterWithBias, and __MklDummyPadWithConv2D since their
3275   // names do not match Mkl node names.
3276   if (n->type_string() != csinfo_.conv2d_with_bias &&
3277       n->type_string() != csinfo_.pad_with_conv2d &&
3278       n->type_string() != csinfo_.pad_with_fused_conv2d &&
3279       n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
3280       n->type_string() != csinfo_.fused_conv2d &&
3281       !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
3282                                 T)) {
3283     return nullptr;
3284   }
3285 
3286   // For elementwise node, we reuse the Eigen implementation and pass the MKL
3287   // metadata tensor through so we can avoid conversions. However, if all
3288   // incoming edges are in TF format, we don't need all this overhead, so
3289   // replace the elementwise node only if at least one of its parents is a MKL
3290   // node.
3291   //
3292   // Identity nodes can also skip replacement if they are not being served by
3293   // any MKL nodes.
3294   //
3295   // TODO(vrane): Add implementation for element-wise ops that doesn't reuse
3296   // eigen code to reduce cross-library dependency.
3297   VLOG(1) << "ELEMENTWISE: checking op: " << n->type_string();
3298   if (mkl_op_registry::IsMklElementWiseOp(
3299           mkl_op_registry::GetMklOpName(n->type_string()), T) ||
3300       n->type_string().find("Identity") != string::npos) {
3301     VLOG(1) << "ELEMENTWISE: op is elementwise: " << n->type_string();
3302     bool incoming_mkl_edge = false;
3303     int num_parent = 0;
3304     for (auto parent : n->in_edges()) {
3305       if (mkl_op_registry::IsMklOp(parent->src()->type_string(), T)) {
3306         VLOG(1) << "ELEMENTWISE: parent " << num_parent++
3307                 << " is MKL op: " << parent->src()->type_string();
3308         incoming_mkl_edge = true;
3309         break;
3310       } else {
3311         VLOG(1) << "ELEMENTWISE: parent " << num_parent++
3312                 << " is NON-MKL op: " << parent->src()->type_string();
3313       }
3314     }
3315     if (incoming_mkl_edge == false) {
3316       VLOG(1) << "ELEMENTWISE: Skipping replacement of elementwise node which "
3317                  "has no MKL "
3318                  "parents.";
3319       return nullptr;
3320     } else {
3321       VLOG(1) << "ELEMENTWISE: Replacing elementwise node " << n->type_string()
3322               << " which has MKL parents";
3323     }
3324   }
3325 
3326   // We now check if rewrite rule applies for this op. If rewrite rule passes
3327   // for this op, then we rewrite it to Mkl op.
3328   // Find matching RewriteInfo and then check that rewrite rule applies.
3329   for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
3330     if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
3331       return &*ri;
3332     }
3333   }
3334 
3335   // Else return not found.
3336   return nullptr;
3337 }
3338 
3339 //////////////////////////////////////////////////////////////////////////
3340 //           Helper functions for node fusion
3341 //////////////////////////////////////////////////////////////////////////
FuseTransposeMklOpTranspose(std::unique_ptr<Graph> * g,std::vector<Node * > & nodes,std::function<void (const Node *,NodeBuilder * nb,bool)> copy_attrs,string data_format)3342 Status MklLayoutRewritePass::FuseTransposeMklOpTranspose(
3343     std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
3344     std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs,
3345     string data_format) {
3346   Node* transpose_to_nhwc = nodes[0];
3347   Node* mklop = nodes[1];
3348   Node* transpose_to_nchw = nodes[2];
3349 
3350   const int transpose_nhwc_num_inputs = transpose_to_nhwc->num_inputs();
3351   gtl::InlinedVector<Node*, 4> transpose_nhwc_control_edges;
3352   gtl::InlinedVector<std::pair<Node*, int>, 4> transpose_nhwc_in(
3353       transpose_nhwc_num_inputs);
3354   FillInputs(transpose_to_nhwc, &transpose_nhwc_control_edges,
3355              &transpose_nhwc_in);
3356 
3357   const int mklop_num_inputs = mklop->num_inputs();
3358   gtl::InlinedVector<Node*, 4> mklop_control_edges;
3359   gtl::InlinedVector<std::pair<Node*, int>, 4> mklop_in(mklop_num_inputs);
3360   FillInputs(mklop, &mklop_control_edges, &mklop_in);
3361 
3362   const int transpose_nchw_num_inputs = transpose_to_nchw->num_inputs();
3363   gtl::InlinedVector<Node*, 4> transpose_nchw_control_edges;
3364   gtl::InlinedVector<std::pair<Node*, int>, 4> transpose_nchw_in(
3365       transpose_nchw_num_inputs);
3366   FillInputs(transpose_to_nchw, &transpose_nchw_control_edges,
3367              &transpose_nchw_in);
3368 
3369   // We use same name as original node, but change the op
3370   // type.
3371   NodeBuilder nb(mklop->name(), mklop->type_string());
3372 
3373   // Storing the output slots of the input nodes.
3374   for (int i = 0; i < mklop_num_inputs; i++) {
3375     if (mklop_in[i].first == transpose_to_nhwc) {
3376       // Fill "x":
3377       nb.Input(transpose_nhwc_in[0].first, transpose_nhwc_in[0].second);
3378     } else {
3379       // Fill inputs other than "x":
3380       nb.Input(mklop_in[i].first, mklop_in[i].second);
3381     }
3382   }
3383 
3384   copy_attrs(const_cast<const Node*>(mklop), &nb, true);
3385   nb.Attr("data_format", data_format);
3386 
3387   // Copy the device assigned to old node to new node.
3388   nb.Device(mklop->def().device());
3389 
3390   // Create node.
3391   Node* new_node;
3392   TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3393   DCHECK(new_node);
3394 
3395   // Fill outputs.
3396   for (const Edge* e : transpose_to_nchw->out_edges()) {
3397     if (!e->IsControlEdge()) {
3398       const int kTransposeWithMklOpOutputSlot = 0;
3399       auto new_edge = (*g)->AddEdge(new_node, kTransposeWithMklOpOutputSlot,
3400                                     e->dst(), e->dst_input());
3401       DCHECK(new_edge);
3402     }
3403   }
3404 
3405   // Copy device assigned to old node to new node.
3406   new_node->set_assigned_device_name(mklop->assigned_device_name());
3407 
3408   // Copy requested_device and assigned_device_name_index
3409   new_node->set_requested_device(mklop->requested_device());
3410   new_node->set_assigned_device_name_index(mklop->assigned_device_name_index());
3411 
3412   (*g)->RemoveNode(transpose_to_nhwc);
3413   (*g)->RemoveNode(mklop);
3414   (*g)->RemoveNode(transpose_to_nchw);
3415 
3416   return Status::OK();
3417 }
3418 
FuseNode(std::unique_ptr<Graph> * g,std::vector<Node * > & nodes,const MklLayoutRewritePass::FusionInfo fi)3419 Status MklLayoutRewritePass::FuseNode(
3420     std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
3421     const MklLayoutRewritePass::FusionInfo fi) {
3422   return fi.fuse_func(g, nodes, fi.copy_attrs);
3423 }
3424 
3425 std::tuple<bool, std::vector<Node*>, const MklLayoutRewritePass::FusionInfo>
CheckForNodeFusion(Node * a) const3426 MklLayoutRewritePass::CheckForNodeFusion(Node* a) const {
3427   // Stores matched nodes, in the same order as node_checkers.
3428   std::vector<Node*> nodes;
3429 
3430   for (auto fi = finfo_.begin(); fi != finfo_.end(); ++fi) {
3431     //
3432     // Make sure node "a" and its succeding nodes (b, c ...), match the pattern
3433     // defined in fusion info (ops[0], ops[1], ...),
3434     // a.k.a. "a->b->c" matches "op1->op2->op3"
3435     //
3436 
3437     // Stores the first unvisted outgoing edge of each matched node in "nodes".
3438     std::stack<EdgeSet::const_iterator> current_neighbor_stack;
3439     nodes.clear();
3440 
3441     auto node_checker = fi->node_checkers.begin();
3442     if (a != nullptr && (*node_checker)(a)) {
3443       nodes.push_back(a);
3444       current_neighbor_stack.push(a->out_edges().begin());
3445       ++node_checker;
3446     }
3447 
3448     while (!nodes.empty()) {
3449       auto& current_neighbor_iter = current_neighbor_stack.top();
3450 
3451       if (current_neighbor_iter != nodes.back()->out_edges().end()) {
3452         // Found an unvisited edge. Goes through the edge to get the neighbor.
3453         Node* neighbor_node = (*current_neighbor_iter)->dst();
3454         ++current_neighbor_stack.top();  // Retrieves the next unvisited edge.
3455 
3456         if ((*node_checker)(neighbor_node)) {
3457           // Found a match. Stores the node and moves to the next checker.
3458           nodes.push_back(neighbor_node);
3459           current_neighbor_stack.push(neighbor_node->out_edges().begin());
3460           if (++node_checker == fi->node_checkers.end()) {
3461             return make_tuple(true, nodes, *fi);
3462           }
3463         }
3464       } else {
3465         // Removes the current node since none of its neighbor leads to a
3466         // further match.
3467         nodes.pop_back();
3468         current_neighbor_stack.pop();
3469         --node_checker;
3470       }
3471     }
3472   }
3473 
3474   return make_tuple(false, std::vector<Node*>(), FusionInfo());
3475 }
3476 
3477 ///////////////////////////////////////////////////////////////////////////////
3478 //              Post-rewrite Mkl metadata fixup pass
3479 ///////////////////////////////////////////////////////////////////////////////
FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph> * g,const Edge * e_data,const Edge * e_metadata)3480 bool MklLayoutRewritePass::FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g,
3481                                                       const Edge* e_data,
3482                                                       const Edge* e_metadata) {
3483   if (g == nullptr || e_data == nullptr || e_metadata == nullptr) {
3484     return false;
3485   }
3486 
3487   Node* n_data = e_data->src();
3488   int n_data_op_slot = e_data->src_output();
3489   int n_metadata_op_slot =
3490       GetTensorMetaDataIndex(n_data_op_slot, n_data->num_outputs());
3491 
3492   // If the source of meta edge is a constant node (producing dummy Mkl metadata
3493   // tensor), then we will need to fix.
3494   if (IsConstant(e_metadata->src())) {
3495     Node* e_metadata_dst = e_metadata->dst();
3496     int e_metadata_in_slot = e_metadata->dst_input();
3497     CHECK_NOTNULL((*g)->AddEdge(n_data, n_metadata_op_slot, e_metadata_dst,
3498                                 e_metadata_in_slot));
3499 
3500     (*g)->RemoveEdge(e_metadata);
3501     return true;
3502   }
3503 
3504   return false;
3505 }
3506 
FixMklMetaDataEdges(std::unique_ptr<Graph> * g,Node * n)3507 bool MklLayoutRewritePass::FixMklMetaDataEdges(std::unique_ptr<Graph>* g,
3508                                                Node* n) {
3509   bool result = false;
3510 
3511   // If graph node is not Mkl node, then return.
3512   DataType T = DT_INVALID;
3513   if (!GetNodeAttr(n->def(), "T", &T).ok() ||
3514       !mkl_op_registry::IsMklOp(n->type_string(), T)) {
3515     return result;
3516   }
3517 
3518   // If it is Mkl node, then check if the input edges to this node that carry
3519   // Mkl metadata are linked up correctly with the source node.
3520 
3521   // For Mkl nodes, we generate twice the number of input tensors (n for Mkl
3522   // data tensors + n for Mkl metadata tensors). We need to check for correct
3523   // connection of n metadata tensors only.
3524   int num_data_inputs = n->num_inputs() / 2;
3525   for (int idx = 0; idx < num_data_inputs; idx++) {
3526     // Get the edge connecting input slot with index (idx).
3527     const Edge* e = nullptr;
3528     TF_CHECK_OK(n->input_edge(idx, &e));
3529 
3530     // If e is control edge, then skip.
3531     if (e->IsControlEdge()) {
3532       continue;
3533     }
3534 
3535     // Check that the source node for edge 'e' is Mkl node. If it is not an Mkl
3536     // node, then we don't need to do anything.
3537     Node* e_src = e->src();
3538     if (GetNodeAttr(e_src->def(), "T", &T).ok() &&
3539         mkl_op_registry::IsMklOp(e_src->type_string(), T)) {
3540       // Source node for edge 'e' is Mkl node.
3541       // Destination node and destination input slot of e is node 'n' and 'idx'
3542       // resp.
3543       CHECK_EQ(e->dst(), n);
3544       CHECK_EQ(e->dst_input(), idx);
3545 
3546       // Let's get edge that carries Mkl metadata corresponding to Mkl data edge
3547       // 'e'. For that, let's first get the input slot of 'n' where the meta
3548       // edge will feed the value.
3549       int e_meta_in_slot =
3550           GetTensorMetaDataIndex(e->dst_input(), n->num_inputs());
3551       const Edge* e_meta = nullptr;
3552       TF_CHECK_OK(n->input_edge(e_meta_in_slot, &e_meta));
3553 
3554       // Let's check if we need to fix this meta edge.
3555       if (FixMklMetaDataEdgeIfNeeded(g, e, e_meta)) {
3556         result = true;
3557       }
3558     }
3559   }
3560 
3561   return result;
3562 }
3563 
3564 ///////////////////////////////////////////////////////////////////////////////
3565 //              Run function for the pass
3566 ///////////////////////////////////////////////////////////////////////////////
3567 
RunPass(std::unique_ptr<Graph> * g)3568 bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {
3569   bool result = false;
3570   CHECK_NOTNULL(g);
3571 
3572   DumpGraph("Before running MklLayoutRewritePass", &**g);
3573 
3574   std::vector<Node*> order;
3575   GetReversePostOrder(**g, &order);  // This will give us topological sort.
3576   for (Node* n : order) {
3577     // If node is not an op or it cannot run on CPU device, then skip.
3578     if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
3579       continue;
3580     }
3581 
3582     Node* m = nullptr;
3583     if ((m = CheckForNodeMerge(n)) != nullptr && CanOpRunOnCPUDevice(m)) {
3584       // Check if the node 'n' can be merged with any other node. If it can
3585       // be 'm' contains the node with which it can be merged.
3586       string n1_name = n->name();
3587       string n2_name = m->name();
3588 
3589       VLOG(1) << "MklLayoutRewritePass: Scheduled nodes " << n1_name << " and "
3590               << n2_name << " for merging";
3591 
3592       if (MergeNode(g, n, m) == Status::OK()) {
3593         VLOG(1) << "MklLayoutRewritePass: Merged nodes " << n1_name << " and "
3594                 << n2_name;
3595         result = true;
3596       }
3597     }
3598   }
3599 
3600   DumpGraph("After running MklLayoutRewritePass(NodeMerge)", &**g);
3601 
3602   order.clear();
3603   GetReversePostOrder(**g, &order);  // This will give us topological sort.
3604   for (Node* n : order) {
3605     // If node is not an op or it cannot run on CPU device, then skip.
3606     if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
3607       continue;
3608     }
3609 
3610     auto check_result = CheckForNodeFusion(n);
3611     bool found_pattern = std::get<0>(check_result);
3612     std::vector<Node*> nodes = std::get<1>(check_result);
3613     const FusionInfo fi = std::get<2>(check_result);
3614 
3615     // if "found_pattern" is true, we can do the fusion.
3616     if (found_pattern) {
3617       if (FuseNode(g, nodes, fi) == Status::OK()) {
3618         result = true;
3619       }
3620     }
3621   }
3622   DumpGraph("After running MklLayoutRewritePass(NodeFusion)", &**g);
3623 
3624   order.clear();
3625   GetReversePostOrder(**g, &order);  // This will give us topological sort.
3626   for (Node* n : order) {
3627     // If node is not an op or it cannot run on CPU device, then skip.
3628     if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
3629       continue;
3630     }
3631 
3632     const RewriteInfo* ri = nullptr;
3633     // We will first search if node is to be rewritten.
3634     if ((ri = CheckForNodeRewrite(n)) != nullptr) {
3635       string node_name = n->name();
3636       string op_name = n->type_string();
3637 
3638       VLOG(1) << "MklLayoutRewritePass: Scheduled node " << node_name
3639               << " with op " << op_name << " for rewrite using"
3640               << " layout optimization.";
3641 
3642       if (RewriteNode(g, n, ri) == Status::OK()) {
3643         VLOG(1) << "MklLayoutRewritePass: rewrote node " << node_name
3644                 << " with op " << op_name << " for Mkl layout optimization.";
3645         result = true;
3646       }
3647     }
3648   }
3649 
3650   DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite)", &**g);
3651 
3652   order.clear();
3653   GetReversePostOrder(**g, &order);  // This will give us topological sort.
3654   for (Node* n : order) {
3655     // If node is not an op or it cannot run on CPU device, then skip.
3656     if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
3657       continue;
3658     }
3659     if (FixMklMetaDataEdges(g, n)) {
3660       string node_name = n->name();
3661       string op_name = n->type_string();
3662 
3663       VLOG(1) << "MklLayoutRewritePass: fixed metadata edges for node "
3664               << node_name << " with op " << op_name;
3665       result = true;
3666     }
3667   }
3668   DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite+Fixup)",
3669             &**g);
3670 
3671   return result;
3672 }
3673 
RunMklLayoutRewritePass(std::unique_ptr<Graph> * g)3674 bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) {
3675   return MklLayoutRewritePass().RunPass(g);
3676 }
3677 
Run(const GraphOptimizationPassOptions & options)3678 Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
3679   if (options.graph == nullptr && options.partition_graphs == nullptr) {
3680     return Status::OK();
3681   }
3682   if (DisableMKL()) {
3683     VLOG(2) << "TF-MKL: Disabling MKL";
3684     return Status::OK();
3685   }
3686 
3687   auto process_graph = [&](std::unique_ptr<Graph>* g) {
3688     // Get the ownership of a graph
3689     std::unique_ptr<Graph>* ng = std::move(g);
3690     RunPass(ng);
3691     // Return the ownership of a graph back
3692     g->reset(ng->release());
3693   };
3694 
3695   if (kMklLayoutRewritePassGroup !=
3696       OptimizationPassRegistry::POST_PARTITIONING) {
3697     // For any pre-partitioning phase, a graph is stored in options.graph.
3698     process_graph(options.graph);
3699   } else {
3700     // For post partitioning phase, graphs are stored in
3701     // options.partition_graphs.
3702     for (auto& pg : *options.partition_graphs) {
3703       process_graph(&pg.second);
3704     }
3705   }
3706 
3707   return Status::OK();
3708 }
3709 
3710 }  // namespace tensorflow
3711 
3712 #endif
3713