1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // TODO(intel): Improve error handling in this file; instead of CHECK failing
17 // all over the place, we should log an error and execute the original graph.
18 #ifdef INTEL_MKL
19 
20 #include "tensorflow/core/common_runtime/mkl_layout_pass.h"
21 
22 #include <algorithm>
23 #include <functional>
24 #include <memory>
25 #include <queue>
26 #include <set>
27 #include <stack>
28 #include <tuple>
29 #include <unordered_set>
30 #include <utility>
31 #include <vector>
32 
33 #include "tensorflow/core/common_runtime/function.h"
34 #include "tensorflow/core/common_runtime/optimization_registry.h"
35 #include "tensorflow/core/framework/node_def_util.h"
36 #include "tensorflow/core/framework/tensor.pb.h"
37 #include "tensorflow/core/graph/algorithm.h"
38 #include "tensorflow/core/graph/graph.h"
39 #include "tensorflow/core/graph/mkl_graph_util.h"
40 #include "tensorflow/core/graph/node_builder.h"
41 #include "tensorflow/core/lib/core/status.h"
42 #include "tensorflow/core/lib/gtl/array_slice.h"
43 #include "tensorflow/core/lib/gtl/map_util.h"
44 #include "tensorflow/core/lib/hash/hash.h"
45 #include "tensorflow/core/platform/logging.h"
46 #include "tensorflow/core/util/tensor_format.h"
47 #include "tensorflow/core/util/util.h"
48 
49 namespace tensorflow {
50 
51 // This pass implements rewriting of graph to support following scenarios:
52 // (A) Merging nodes in the graph
53 // (B) Rewriting a node in the graph to a new node
54 //     Rewrite happens under following scenario:
55 //     - Propagating Mkl layout as an additional output tensor
56 //        (we will loosely call a tensor that carries Mkl layout as Mkl tensor
57 //         henceforth.) from every Mkl supported NN layer.
58 //
59 // Example of A : Merging nodes in the graph
60 // -----------------------------------------
61 // Currently, we merge Conv2D+BiasAdd together. Consider Conv2D and BiasAdd as:
62 //
63 //           O = Conv2D(A, B)
64 //           P = BiasAdd(O, C)
65 //
66 // We merge them into Conv2DWithBias as:
67 //           P = _MklConv2DWithBias(A, A_m, B, B_m, C, C_m)
68 //
69 // The meaning of A_m, B_m and C_m is explained in B.1.
70 //
71 // Merge rules:
72 //  - The merge for Conv2D and BiasAdd happens when the output of Conv2D _only_
73 //    goes to BiasAdd.
74 //  - Also, the intersection of attributes of both the nodes must have same
75 //    values.
76 //  - Both the nodes must have been assigned to same device (if any).
77 //
78 // Example of B.1 : Rewriting nodes to Mkl nodes
79 // ---------------------------------------------
80 // Consider a Relu node. Current definition of Relu node looks like:
81 //
82 //           O = Relu(A)
83 //
84 // Relu has 1 input (A), and 1 output (O).
85 //
86 // This rewrite pass will generate a new graph node for Relu (new node is
87 // called MklRelu) as:
88 //
89 //          O, O_m = MklRelu(A, A_m)
90 //
91 // MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here input A is
92 // same as input A of Relu; output O is same as output O of Relu. O_m is the
93 // additional output tensor that will be set by MklRelu, and it represents
94 // Mkl tensor corresponding to O -- in other words, O_m is some kind of
95 // metadata for O. A_m is the additional input of MklRelu, and it represents
96 // metadata for A - as O_m is metadata for O, A_m is metadata for A. MklRelu
97 // receives this metadata from the previous node in the graph.
98 //
99 // When a previous node in the graph is an Mkl node, A_m will represent a valid
100 // Mkl tensor. But when a previous node is not an Mkl node, A_m will represent
101 // a dummy Mkl tensor.
102 //
103 // Rewriting rules:
104 //  - Selection of a node for rewriting happens by registering the op type of
105 //    the node with the rewriting pass. If the op type is not registered, then
106 //    no node of this op type will be rewritten.
107 //  - Number of inputs after rewriting:
108 //      Since for every input Tensorflow tensor, the rewritten node gets Mkl
109 //      tensor(s), rewritten node gets 2*N inputs, where N is the number of
110 //      inputs for the original node.
111 //  - Number of outputs after rewriting:
112 //      Since for every output Tensorflow tensor, the rewritten node generates
113 //      Mkl tensor(s), the rewritten node generates 2*N outputs, where N is the
114 //      number of outputs of the original node.
115 //  - Ordering of Tensorflow tensors and Mkl tensors:
116 //      Since every rewritten node generates twice the number of inputs and
117 //      outputs, one could imagine various orderings among Tensorflow tensors
118 //      and Mkl tensors. E.g., assume an op 'Conv2D' that takes (A, B) as
119 //      inputs, then the new op '_MklConv2D' can take inputs A, B, A_m and B_m
120 //      in A, A_m, B, B_m order or it can also take them in A, B, A_m, B_m
121 //      order. Among N inputs one can get N! permutations.
122 //
123 //      So the question is: which order do we follow? We support 2 types of
124 //      orderings: (1) interleaved, and (2) contiguous. Interleaved ordering
125 //      follows an intuitive order where an Mkl tensor follows the
126 //      corresponding Tensorflow tensor immediately. In the context of the
127 //      above example, it will be: A, A_m, B, B_m. Note that the ordering rule
128 //      applies to both the inputs and outputs. Contiguous ordering means
129 //      all the Tensorflow tensors are contiguous followed by all the Mkl
130 //      tensors. We use contiguous ordering as default.
131 //
132 // Graph rewrite algorithm:
133 //      Algorithm: Graph Rewrite
134 //      Input: Graph G, Names of the nodes to rewrite and their new names
135 //      Output: Modified Graph G' if the nodes are modified, G otherwise.
136 //      Start:
137 //        N = Topological_Sort(G) // N is the list of nodes in toposort order.
138 //        foreach node n in N
139 //        do
140 //          if (Is_MKL_Op(n))  // Can this node accept an Mkl layout as input.
141 //          then
142 //            E = set of <incoming edge and its src_output slot> of n
143 //            E' = {}   // a new set of edges for rewritten node
144 //            foreach <e,s> in E
145 //            do
146 //              E' U {<e,s>}  // First copy edge which generates Tensorflow
147 //                            // tensor as it is
148 //              m = Source node of edge e
149 //              if Is_Rewritten(m)  // Did we rewrite this node in this pass?
150 //              then
151 //                E' U {<m,s+1>}    // If yes, then m will generate an Mkl
152 //                                  // tensor as an additional output.
153 //              else
154 //                d = Generate_Dummy_Mkl_Tensor()  // If not, generate a dummy
155 //                                                 // Mkl tensor.
156 //                E' U {<d,0>}  // The dummy Mkl tensor has only 1 output slot.
157 //              fi
158 //            done
159 //            n' = Build_New_Node(G,new_name,E')
160 //            Mark_Rewritten(n')  // Mark the new node as being rewritten.
161 //          fi
162 //        done
163 //
164 //      Explanation:
165 //        For graph rewrite, we visit nodes of the input graph in the
166 //        topological sort order. With this ordering, we visit nodes in the
167 //        top-to-bottom fashion. We need this order because while visiting a
168 //        node we want that all of its input nodes are visited and rewritten if
169 //        applicable. This is because if we need to rewrite a given node
170 //        then all of its input nodes need to be fixed (in other words they
171 //        cannot be deleted later.)
172 //
173 //        While visiting a node, we first check if the op type of the node is
174 //        an Mkl op. If it is, then we rewrite that node after constructing
175 //        new inputs to the node. If the op type of the node is not Mkl op,
176 //        then we do not rewrite that node.
177 //
178 // Handling workspace propagation for certain ops:
179 //
180 //        Certain backward ops in MKL (MaxPool, LRN and BatchNorm) require
181 //        passing of a workspace from their respective forward ops. Workspace
182 //        tensors provide memory for storing results of intermediate operations
183 //        which are helpful in backward propagation. TensorFlow does not have
184 //        a notion of a workspace and as a result does not allow producing
185 //        additional outputs from these forward ops. For these ops, we need
186 //        to add 2 extra edges between forward ops and their corresponding
187 //        backward ops - the first extra edge carries a workspace tensor and
188 //        the second one carries an Mkl tensor for the workspace tensor.
189 //
190 //        Example:
191 //
192 //        Typical graph for MaxPool and its gradient looks like:
193 //
194 //        A = MaxPool(T)
195 //        B = MaxPoolGrad(X, A, Y)
196 //
197 //        We will transform this graph to propagate the workspace as:
198 //        (with the contiguous ordering)
199 //
200 //        A, W, A_m, W_m = MklMaxPool(T, T_m)
201 //        B, B_m = MklMaxPoolGrad(X, A, Y, W, X_m, A_m, Y_m, W_m)
202 //
203 //        Here W is the workspace tensor. Transformed tensor names with the
204 //        suffix _m are Mkl tensors, and this transformation has been done
205 //        using the algorithm discussed earlier. The transformation for
206 //        workspace propagation only adds extra outputs (W, W_m) for a forward
207 //        op and connects them to the corresponding backward ops.
208 //
209 //        Terms:
210 //
211 //        Forward op name = name of the op in the forward pass
212 //          where a workspace tensor originates (MaxPool in this example)
213 //        Backward op name = name of the op in the backward pass that receives
214 //          a workspace tensor from the forward op (MaxPoolGrad in the example)
215 //        Slot = Position of the output or input slot that will be
216 //               used by the workspace tensor (1 for MklMaxPool as W is the 2nd
217 //               output of MklMaxPool (0 is 1st); 3 for MklMaxPoolGrad)
218 //
219 //        Question:
220 //
221 //        How do we associate a backward op to a forward op? There can be more
222 //        than one op with the exact same name.
223 //
224 //        In this example, we associate MaxPoolGrad with MaxPool. But there
225 //        could be more than one MaxPool op. To solve this problem, we look
226 //        for a _direct_ edge between a forward op and a backward op (tensor A
227 //        is flowing along this edge in the example).
228 //
229 //        How do we transform forward and backward ops when there is no direct
230 //        edge between them? In such a case, we generate dummy tensors for
231 //        workspace tensors. For the example, transformation of MaxPool will
232 //        be exactly same as it would be when there is a direct edge between
233 //        the forward and the backward op --- it is just that MaxPool won't
234 //        generate any workspace tensor. For MaxPoolGrad, the transformation
235 //        will also be same, but instead of connecting W and W_m with the
236 //        outputs of MaxPool, we will produce dummy tensors for them, and we
237 //        will set workspace_enabled attribute to false.
238 //
239 class MklLayoutRewritePass : public GraphOptimizationPass {
240  public:
MklLayoutRewritePass()241   MklLayoutRewritePass() {
242     // NOTE: names are alphabetically sorted.
243     csinfo_.addn = "AddN";
244     csinfo_.avg_pool = "AvgPool";
245     csinfo_.avg_pool_grad = "AvgPoolGrad";
246     csinfo_.avg_pool3d = "AvgPool3D";
247     csinfo_.avg_pool3d_grad = "AvgPool3DGrad";
248     csinfo_.batch_matmul = "BatchMatMul";
249     csinfo_.batch_matmul_v2 = "BatchMatMulV2";
250     csinfo_.bias_add = "BiasAdd";
251     csinfo_.bias_add_grad = "BiasAddGrad";
252     csinfo_.concat = "Concat";
253     csinfo_.concatv2 = "ConcatV2";
254     csinfo_.conjugate_transpose = "ConjugateTranspose";
255     csinfo_.conv2d = "Conv2D";
256     csinfo_.conv2d_with_bias = "__MklDummyConv2DWithBias";
257     csinfo_.conv2d_grad_input = "Conv2DBackpropInput";
258     csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
259     csinfo_.conv2d_grad_filter_with_bias =
260         "__MklDummyConv2DBackpropFilterWithBias";
261     csinfo_.conv3d = "Conv3D";
262     csinfo_.conv3d_grad_input = "Conv3DBackpropInputV2";
263     csinfo_.conv3d_grad_filter = "Conv3DBackpropFilterV2";
264     csinfo_.depthwise_conv2d = "DepthwiseConv2dNative";
265     csinfo_.depthwise_conv2d_grad_input = "DepthwiseConv2dNativeBackpropInput";
266     csinfo_.depthwise_conv2d_grad_filter =
267         "DepthwiseConv2dNativeBackpropFilter";
268     csinfo_.dequantize = "Dequantize";
269     csinfo_.fused_batch_norm = "FusedBatchNorm";
270     csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
271     csinfo_.fused_batch_norm_ex = "_FusedBatchNormEx";
272     csinfo_.fused_batch_norm_v2 = "FusedBatchNormV2";
273     csinfo_.fused_batch_norm_grad_v2 = "FusedBatchNormGradV2";
274     csinfo_.fused_batch_norm_v3 = "FusedBatchNormV3";
275     csinfo_.fused_batch_norm_grad_v3 = "FusedBatchNormGradV3";
276     csinfo_.fused_conv2d = "_FusedConv2D";
277     csinfo_.fused_depthwise_conv2d = "_FusedDepthwiseConv2dNative";
278     csinfo_.fused_matmul = "_FusedMatMul";
279     csinfo_.identity = "Identity";
280     csinfo_.leakyrelu = "LeakyRelu";
281     csinfo_.leakyrelu_grad = "LeakyReluGrad";
282     csinfo_.lrn = "LRN";
283     csinfo_.lrn_grad = "LRNGrad";
284     csinfo_.matmul = "MatMul";
285     csinfo_.max_pool = "MaxPool";
286     csinfo_.max_pool_grad = "MaxPoolGrad";
287     csinfo_.max_pool3d = "MaxPool3D";
288     csinfo_.max_pool3d_grad = "MaxPool3DGrad";
289     csinfo_.mkl_conv2d = "_MklConv2D";
290     csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput";
291     csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
292     csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
293     csinfo_.mkl_conv2d_grad_filter_with_bias =
294         "_MklConv2DBackpropFilterWithBias";
295     csinfo_.mkl_depthwise_conv2d_grad_input =
296         "_MklDepthwiseConv2dNativeBackpropInput";
297     csinfo_.mkl_depthwise_conv2d_grad_filter =
298         "_MklDepthwiseConv2dNativeBackpropFilter";
299     csinfo_.mkl_fused_batch_norm_ex = "_MklFusedBatchNormEx";
300     csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
301     csinfo_.mkl_fused_depthwise_conv2d = "_MklFusedDepthwiseConv2dNative";
302     csinfo_.mkl_fused_matmul = "_MklFusedMatMul";
303     csinfo_.mkl_native_conv2d_with_bias = "_MklNativeConv2DWithBias";
304     csinfo_.mkl_native_fused_batch_norm_ex = "_MklNativeFusedBatchNormEx";
305     csinfo_.mkl_native_fused_conv2d = "_MklNativeFusedConv2D";
306     csinfo_.mkl_native_fused_depthwise_conv2d =
307         "_MklNativeFusedDepthwiseConv2dNative";
308     csinfo_.mkl_native_fused_matmul = "_MklNativeFusedMatMul";
309     csinfo_.mkl_native_pad_with_conv2d = "_MklNativePadWithConv2D";
310     csinfo_.mkl_native_pad_with_fused_conv2d = "_MklNativePadWithFusedConv2D";
311     csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
312     csinfo_.mkl_pad_with_fused_conv2d = "_MklPadWithFusedConv2D";
313     csinfo_.pad = "Pad";
314     csinfo_.pad_with_conv2d = "__MklDummyPadWithConv2D";
315     csinfo_.pad_with_fused_conv2d = "__MklDummyPadWithFusedConv2D";
316     csinfo_.quantized_avg_pool = "QuantizedAvgPool";
317     csinfo_.quantized_concatv2 = "QuantizedConcatV2";
318     csinfo_.quantized_conv2d = "QuantizedConv2D";
319     csinfo_.quantized_conv2d_per_channel = "QuantizedConv2DPerChannel";
320     csinfo_.quantized_conv2d_with_requantize = "QuantizedConv2DAndRequantize";
321     csinfo_.quantized_conv2d_with_bias = "QuantizedConv2DWithBias";
322     csinfo_.quantized_conv2d_with_bias_and_requantize =
323         "QuantizedConv2DWithBiasAndRequantize";
324     csinfo_.quantized_conv2d_and_relu = "QuantizedConv2DAndRelu";
325     csinfo_.quantized_conv2d_and_relu_and_requantize =
326         "QuantizedConv2DAndReluAndRequantize";
327     csinfo_.quantized_conv2d_with_bias_and_relu =
328         "QuantizedConv2DWithBiasAndRelu";
329     csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize =
330         "QuantizedConv2DWithBiasAndReluAndRequantize";
331     csinfo_.quantized_max_pool = "QuantizedMaxPool";
332     csinfo_.quantized_conv2d_with_bias_sum_and_relu =
333         "QuantizedConv2DWithBiasSumAndRelu";
334     csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize =
335         "QuantizedConv2DWithBiasSumAndReluAndRequantize";
336     csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize =
337         "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize";
338     csinfo_.quantized_matmul_with_bias = "QuantizedMatMulWithBias";
339     csinfo_.quantized_matmul_with_bias_and_relu =
340         "QuantizedMatMulWithBiasAndRelu";
341     csinfo_.quantized_matmul_with_bias_and_relu_and_requantize =
342         "QuantizedMatMulWithBiasAndReluAndRequantize";
343     csinfo_.quantized_matmul_with_bias_and_dequantize =
344         "QuantizedMatMulWithBiasAndDequantize";
345     csinfo_.quantized_matmul_with_bias_and_requantize =
346         "QuantizedMatMulWithBiasAndRequantize";
347     csinfo_.quantized_depthwise_conv2d = "QuantizedDepthwiseConv2D";
348     csinfo_.quantized_depthwise_conv2d_with_bias =
349         "QuantizedDepthwiseConv2DWithBias";
350     csinfo_.quantized_depthwise_conv2d_with_bias_and_relu =
351         "QuantizedDepthwiseConv2DWithBiasAndRelu";
352     csinfo_.quantized_depthwise_conv2d_with_bias_and_relu_and_requantize =
353         "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize";
354     csinfo_.quantize_v2 = "QuantizeV2";
355     csinfo_.relu = "Relu";
356     csinfo_.relu_grad = "ReluGrad";
357     csinfo_.relu6 = "Relu6";
358     csinfo_.relu6_grad = "Relu6Grad";
359     csinfo_.requantize = "Requantize";
360     csinfo_.tanh = "Tanh";
361     csinfo_.tanh_grad = "TanhGrad";
362     csinfo_.reshape = "Reshape";
363     csinfo_.slice = "Slice";
364     csinfo_.softmax = "Softmax";
365     csinfo_.split = "Split";
366     csinfo_.transpose = "Transpose";
367     // Element-wise ops. Ensure you also add any new ops to IsOpElementWise
368     // in the MklUtil.h (IsMklElementWiseOp method) to ensure that the
369     // MklInputConversion op is added before it.
370     csinfo_.add = "Add";
371     csinfo_.add_v2 = "AddV2";
372     csinfo_.maximum = "Maximum";
373     csinfo_.mul = "Mul";
374     csinfo_.squared_difference = "SquaredDifference";
375     csinfo_.sub = "Sub";
376     // End - element-wise ops. See note above.
377 
378     const bool native_fmt = NativeFormatEnabled();
379     // NOTE: names are alphabetically sorted.
380     rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
381                       CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
382     rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
383                       CopyAttrsAll, RewriteIfAtleastOneMklInput,
384                       GetRewriteCause()});
385     rinfo_.push_back(
386         {csinfo_.add_v2, mkl_op_registry::GetMklOpName(csinfo_.add_v2),
387          CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
388     rinfo_.push_back({csinfo_.avg_pool,
389                       mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
390                       CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
391     rinfo_.push_back({csinfo_.avg_pool_grad,
392                       mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad),
393                       CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
394     rinfo_.push_back({csinfo_.avg_pool3d,
395                       mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d),
396                       CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
397     rinfo_.push_back({csinfo_.avg_pool3d_grad,
398                       mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d_grad),
399                       CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
400     rinfo_.push_back({csinfo_.batch_matmul,
401                       mkl_op_registry::GetMklOpName(csinfo_.batch_matmul),
402                       CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
403     rinfo_.push_back({csinfo_.batch_matmul_v2,
404                       mkl_op_registry::GetMklOpName(csinfo_.batch_matmul_v2),
405                       CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
406     rinfo_.push_back({csinfo_.concat,
407                       mkl_op_registry::GetMklOpName(csinfo_.concat),
408                       CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
409     rinfo_.push_back({csinfo_.concatv2,
410                       mkl_op_registry::GetMklOpName(csinfo_.concatv2),
411                       CopyAttrsAll, ConcatV2Rewrite, GetRewriteCause()});
412     rinfo_.push_back(
413         {csinfo_.conjugate_transpose,
414          mkl_op_registry::GetMklOpName(csinfo_.conjugate_transpose),
415          CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
416     rinfo_.push_back(
417         {csinfo_.conv2d, mkl_op_registry::GetMklOpName(csinfo_.conv2d),
418          CopyAttrsConvCheckConstFilter, AlwaysRewrite, GetRewriteCause()});
419     rinfo_.push_back({csinfo_.conv2d_with_bias,
420                       native_fmt ? csinfo_.mkl_native_conv2d_with_bias
421                                  : csinfo_.mkl_conv2d_with_bias,
422                       CopyAttrsConvCheckConstFilter, AlwaysRewrite,
423                       GetRewriteCause()});
424     rinfo_.push_back({csinfo_.conv2d_grad_filter,
425                       mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter),
426                       CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
427     rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias,
428                       csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv,
429                       AlwaysRewrite, GetRewriteCause()});
430     rinfo_.push_back({csinfo_.conv2d_grad_input,
431                       mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
432                       CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
433     rinfo_.push_back(
434         {csinfo_.conv3d, mkl_op_registry::GetMklOpName(csinfo_.conv3d),
435          CopyAttrsConvCheckConstFilter, AlwaysRewrite, GetRewriteCause()});
436     rinfo_.push_back({csinfo_.conv3d_grad_filter,
437                       mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_filter),
438                       CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
439     rinfo_.push_back({csinfo_.conv3d_grad_input,
440                       mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_input),
441                       CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
442     rinfo_.push_back({csinfo_.depthwise_conv2d,
443                       mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d),
444                       CopyAttrsConvCheckConstFilter, AlwaysRewrite,
445                       GetRewriteCause()});
446     rinfo_.push_back(
447         {csinfo_.depthwise_conv2d_grad_input,
448          mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_input),
449          CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
450     rinfo_.push_back(
451         {csinfo_.depthwise_conv2d_grad_filter,
452          mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_filter),
453          CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
454     rinfo_.push_back({csinfo_.dequantize,
455                       mkl_op_registry::GetMklOpName(csinfo_.dequantize),
456                       CopyAttrsAll, DequantizeRewrite, GetRewriteCause()});
457     rinfo_.push_back({csinfo_.fused_batch_norm,
458                       mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
459                       CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
460     rinfo_.push_back(
461         {csinfo_.fused_batch_norm_grad,
462          mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
463          CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
464     rinfo_.push_back(
465         {csinfo_.fused_batch_norm_v2,
466          mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v2),
467          CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
468     rinfo_.push_back(
469         {csinfo_.fused_batch_norm_grad_v2,
470          mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v2),
471          CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
472 
473     // Using CopyAttrsAll for V3 on CPU, as there are no additional
474     // attributes.
475     rinfo_.push_back(
476         {csinfo_.fused_batch_norm_v3,
477          mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v3),
478          CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()});
479     rinfo_.push_back(
480         {csinfo_.fused_batch_norm_grad_v3,
481          mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3),
482          CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()});
483     rinfo_.push_back({csinfo_.fused_batch_norm_ex,
484                       native_fmt ? csinfo_.mkl_native_fused_batch_norm_ex
485                                  : csinfo_.mkl_fused_batch_norm_ex,
486                       CopyAttrsAll, FusedBatchNormExRewrite,
487                       GetRewriteCause()});
488     rinfo_.push_back({csinfo_.fused_conv2d,
489                       native_fmt ? csinfo_.mkl_native_fused_conv2d
490                                  : csinfo_.mkl_fused_conv2d,
491                       CopyAttrsAllCheckConstFilter, FusedConv2DRewrite,
492                       GetRewriteCause()});
493     rinfo_.push_back({csinfo_.fused_depthwise_conv2d,
494                       native_fmt ? csinfo_.mkl_native_fused_depthwise_conv2d
495                                  : csinfo_.mkl_fused_depthwise_conv2d,
496                       CopyAttrsAllCheckConstFilter, FusedDepthwiseConv2DRewrite,
497                       GetRewriteCause()});
498     rinfo_.push_back({csinfo_.fused_matmul,
499                       native_fmt ? csinfo_.mkl_native_fused_matmul
500                                  : csinfo_.mkl_fused_matmul,
501                       CopyAttrsAllCheckConstFilter, FusedMatMulRewrite,
502                       GetRewriteCause()});
503 
504     rinfo_.push_back(
505         {csinfo_.identity, mkl_op_registry::GetMklOpName(csinfo_.identity),
506          CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
507     rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
508                       CopyAttrsAll, LrnRewrite, GetRewriteCause()});
509     rinfo_.push_back({csinfo_.lrn_grad,
510                       mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
511                       CopyAttrsAll, LrnGradRewrite, GetRewriteCause()});
512     rinfo_.push_back({csinfo_.matmul,
513                       mkl_op_registry::GetMklOpName(csinfo_.matmul),
514                       CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
515     rinfo_.push_back({csinfo_.leakyrelu,
516                       mkl_op_registry::GetMklOpName(csinfo_.leakyrelu),
517                       CopyAttrsAll, LeakyReluRewrite, GetRewriteCause()});
518     rinfo_.push_back({csinfo_.leakyrelu_grad,
519                       mkl_op_registry::GetMklOpName(csinfo_.leakyrelu_grad),
520                       CopyAttrsAll, LeakyReluRewrite, GetRewriteCause()});
521     rinfo_.push_back(
522         {csinfo_.max_pool, mkl_op_registry::GetMklOpName(csinfo_.max_pool),
523          CopyAttrsAll, NonDepthBatchWisePoolRewrite, GetRewriteCause()});
524     rinfo_.push_back({csinfo_.max_pool_grad,
525                       mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
526                       CopyAttrsAll, MaxpoolGradRewrite, GetRewriteCause()});
527     rinfo_.push_back(
528         {csinfo_.max_pool3d, mkl_op_registry::GetMklOpName(csinfo_.max_pool3d),
529          CopyAttrsAll, NonDepthBatchWisePoolRewrite, GetRewriteCause()});
530     rinfo_.push_back({csinfo_.max_pool3d_grad,
531                       mkl_op_registry::GetMklOpName(csinfo_.max_pool3d_grad),
532                       CopyAttrsAll, Maxpool3DGradRewrite, GetRewriteCause()});
533     rinfo_.push_back(
534         {csinfo_.maximum, mkl_op_registry::GetMklOpName(csinfo_.maximum),
535          CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
536     rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul),
537                       CopyAttrsAll, RewriteIfAtleastOneMklInput,
538                       GetRewriteCause()});
539     rinfo_.push_back({csinfo_.pad_with_conv2d,
540                       native_fmt ? csinfo_.mkl_native_pad_with_conv2d
541                                  : csinfo_.mkl_pad_with_conv2d,
542                       CopyAttrsAllCheckConstFilter, AlwaysRewrite,
543                       GetRewriteCause()});
544     rinfo_.push_back({csinfo_.pad_with_fused_conv2d,
545                       native_fmt ? csinfo_.mkl_native_pad_with_fused_conv2d
546                                  : csinfo_.mkl_pad_with_fused_conv2d,
547                       CopyAttrsAllCheckConstFilter, AlwaysRewrite,
548                       GetRewriteCause()});
549     rinfo_.push_back({csinfo_.quantized_avg_pool,
550                       mkl_op_registry::GetMklOpName(csinfo_.quantized_avg_pool),
551                       CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
552     rinfo_.push_back({csinfo_.quantized_concatv2,
553                       mkl_op_registry::GetMklOpName(csinfo_.quantized_concatv2),
554                       CopyAttrsAll, ConcatV2Rewrite, GetRewriteCause()});
555     rinfo_.push_back({csinfo_.quantized_conv2d,
556                       mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d),
557                       CopyAttrsQuantizedConv2D, AlwaysRewrite,
558                       GetRewriteCause()});
559     rinfo_.push_back(
560         {csinfo_.quantized_conv2d_per_channel,
561          mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_per_channel),
562          CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
563     rinfo_.push_back({csinfo_.quantized_conv2d_with_requantize,
564                       mkl_op_registry::GetMklOpName(
565                           csinfo_.quantized_conv2d_with_requantize),
566                       CopyAttrsQuantizedConv2D, AlwaysRewrite,
567                       GetRewriteCause()});
568     rinfo_.push_back(
569         {csinfo_.quantized_conv2d_with_bias,
570          mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_with_bias),
571          CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
572     rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_requantize,
573                       mkl_op_registry::GetMklOpName(
574                           csinfo_.quantized_conv2d_with_bias_and_requantize),
575                       CopyAttrsQuantizedConv2D, AlwaysRewrite,
576                       GetRewriteCause()});
577     rinfo_.push_back(
578         {csinfo_.quantized_conv2d_and_relu,
579          mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_and_relu),
580          CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
581     rinfo_.push_back({csinfo_.quantized_conv2d_and_relu_and_requantize,
582                       mkl_op_registry::GetMklOpName(
583                           csinfo_.quantized_conv2d_and_relu_and_requantize),
584                       CopyAttrsQuantizedConv2D, AlwaysRewrite,
585                       GetRewriteCause()});
586     rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_relu,
587                       mkl_op_registry::GetMklOpName(
588                           csinfo_.quantized_conv2d_with_bias_and_relu),
589                       CopyAttrsQuantizedConv2D, AlwaysRewrite,
590                       GetRewriteCause()});
591     rinfo_.push_back(
592         {csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize,
593          mkl_op_registry::GetMklOpName(
594              csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize),
595          CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
596     rinfo_.push_back({csinfo_.quantized_max_pool,
597                       mkl_op_registry::GetMklOpName(csinfo_.quantized_max_pool),
598                       CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
599     rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_sum_and_relu,
600                       mkl_op_registry::GetMklOpName(
601                           csinfo_.quantized_conv2d_with_bias_sum_and_relu),
602                       CopyAttrsQuantizedConv2D, AlwaysRewrite,
603                       GetRewriteCause()});
604     rinfo_.push_back(
605         {csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize,
606          mkl_op_registry::GetMklOpName(
607              csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize),
608          CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
609     rinfo_.push_back(
610         {csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize,
611          mkl_op_registry::GetMklOpName(
612              csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize),
613          CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
614     rinfo_.push_back(
615         {csinfo_.quantized_matmul_with_bias,
616          mkl_op_registry::GetMklOpName(csinfo_.quantized_matmul_with_bias),
617          CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite});
618     rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_relu,
619                       mkl_op_registry::GetMklOpName(
620                           csinfo_.quantized_matmul_with_bias_and_relu),
621                       CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite});
622     rinfo_.push_back(
623         {csinfo_.quantized_matmul_with_bias_and_relu_and_requantize,
624          mkl_op_registry::GetMklOpName(
625              csinfo_.quantized_matmul_with_bias_and_relu_and_requantize),
626          CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite});
627     rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_requantize,
628                       mkl_op_registry::GetMklOpName(
629                           csinfo_.quantized_matmul_with_bias_and_requantize),
630                       CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite});
631     rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_dequantize,
632                       mkl_op_registry::GetMklOpName(
633                           csinfo_.quantized_matmul_with_bias_and_dequantize),
634                       CopyAttrsQuantizedMatMulWithBiasAndDequantize,
635                       AlwaysRewrite});
636     rinfo_.push_back(
637         {csinfo_.quantized_depthwise_conv2d,
638          mkl_op_registry::GetMklOpName(csinfo_.quantized_depthwise_conv2d),
639          CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
640     rinfo_.push_back({csinfo_.quantized_depthwise_conv2d_with_bias,
641                       mkl_op_registry::GetMklOpName(
642                           csinfo_.quantized_depthwise_conv2d_with_bias),
643                       CopyAttrsQuantizedConv2D, AlwaysRewrite,
644                       GetRewriteCause()});
645     rinfo_.push_back(
646         {csinfo_.quantized_depthwise_conv2d_with_bias_and_relu,
647          mkl_op_registry::GetMklOpName(
648              csinfo_.quantized_depthwise_conv2d_with_bias_and_relu),
649          CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
650     rinfo_.push_back(
651         {csinfo_.quantized_depthwise_conv2d_with_bias_and_relu_and_requantize,
652          mkl_op_registry::GetMklOpName(
653              csinfo_
654                  .quantized_depthwise_conv2d_with_bias_and_relu_and_requantize),
655          CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
656     rinfo_.push_back({csinfo_.quantize_v2,
657                       mkl_op_registry::GetMklOpName(csinfo_.quantize_v2),
658                       CopyAttrsAll, QuantizeOpRewrite, GetRewriteCause()});
659     rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
660                       CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
661     rinfo_.push_back({csinfo_.relu_grad,
662                       mkl_op_registry::GetMklOpName(csinfo_.relu_grad),
663                       CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
664     rinfo_.push_back({csinfo_.relu6,
665                       mkl_op_registry::GetMklOpName(csinfo_.relu6),
666                       CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
667     rinfo_.push_back({csinfo_.relu6_grad,
668                       mkl_op_registry::GetMklOpName(csinfo_.relu6_grad),
669                       CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
670     rinfo_.push_back({csinfo_.requantize,
671                       mkl_op_registry::GetMklOpName(csinfo_.requantize),
672                       CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
673     // Optimized TanhGrad support exists only in DNNL 1.x.
674     rinfo_.push_back({csinfo_.tanh, mkl_op_registry::GetMklOpName(csinfo_.tanh),
675                       CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
676     rinfo_.push_back({csinfo_.tanh_grad,
677                       mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
678                       CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
679     rinfo_.push_back({csinfo_.reshape,
680                       mkl_op_registry::GetMklOpName(csinfo_.reshape),
681                       CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
682     rinfo_.push_back(
683         {csinfo_.slice, mkl_op_registry::GetMklOpName(csinfo_.slice),
684          CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
685     rinfo_.push_back({csinfo_.softmax,
686                       mkl_op_registry::GetMklOpName(csinfo_.softmax),
687                       CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
688 
689     rinfo_.push_back({csinfo_.squared_difference,
690                       mkl_op_registry::GetMklOpName(csinfo_.squared_difference),
691                       CopyAttrsAll, RewriteIfAtleastOneMklInput,
692                       GetRewriteCause()});
693     rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub),
694                       CopyAttrsAll, RewriteIfAtleastOneMklInput,
695                       GetRewriteCause()});
696     rinfo_.push_back({csinfo_.transpose,
697                       mkl_op_registry::GetMklOpName(csinfo_.transpose),
698                       CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
699 
700     // Add info about which ops to add workspace edge to and the slots.
701     wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
702     wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3});
703     wsinfo_.push_back(
704         {csinfo_.max_pool3d, csinfo_.max_pool3d_grad, 0, 1, 1, 3});
705 
706     // Add a rule for merging nodes
707     minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add,
708                       csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd});
709 
710     // Merge Pad and Conv2d, only if the pad op is "Pad"
711     // Doesn't merge if pad op is "PadV2" or "MirrorPad"
712     minfo_.push_back(
713         {csinfo_.pad, csinfo_.conv2d, csinfo_.pad_with_conv2d, GetPadOrConv2D});
714 
715     minfo_.push_back({csinfo_.pad, csinfo_.fused_conv2d,
716                       csinfo_.pad_with_fused_conv2d, GetPadOrFusedConv2D});
717 
718     if (!native_fmt) {
719       minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad,
720                         csinfo_.conv2d_grad_filter_with_bias,
721                         GetConv2DBackpropFilterOrBiasAddGrad});
722 
723       // The fusion patterns in "finfo_" that show up first will get applied
724       // first, for example, graph "A->B->C->D" and finfo_ is {A->B->C to ABC,
725       // A->B->C->D to ABCD}, since the first gets applied first, the final
726       // graph will be ABC->D.
727 
728       //
729       // Add rules to fuse sequences such as "Transpose (NCHW -> NHWC) + Conv2D
730       // (NHWC) + Transpose (NHWC->
731       // NCHW)" into "Conv2D (NCHW)". Such patterns occur frequently in Keras.
732       // Note: we use the term "merge" to combine (exactly) 2 nodes into one,
733       // while "fusion" is for 3+ nodes situation.
734       //
735 
736       // Transpose + Conv2d + Transpose:
737       std::vector<int> transpose_to_nhwc = {NCHW::dim::N, NCHW::dim::H,
738                                             NCHW::dim::W, NCHW::dim::C};
739       std::vector<int> transpose_to_nchw = {NHWC::dim::N, NHWC::dim::C,
740                                             NHWC::dim::H, NHWC::dim::W};
741       auto CheckForTransposeToNHWC = std::bind(
742           CheckForTranspose, std::placeholders::_1, transpose_to_nhwc);
743       auto CheckForConv2dOp =
744           std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.conv2d);
745       auto CheckForTransposeToNCHW = std::bind(
746           CheckForTranspose, std::placeholders::_1, transpose_to_nchw);
747       auto FuseConv2D =
748           std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
749                     std::placeholders::_2, std::placeholders::_3, "NCHW");
750       finfo_.push_back(
751           {"transpose-elimination for Conv2D",
752            {CheckForTransposeToNHWC, CheckForConv2dOp, CheckForTransposeToNCHW},
753            // CheckForMklOp
754            FuseConv2D,
755            CopyAttrsConv});
756 
757       // Transpose + Conv3d + Transpose:
758       std::vector<int> transpose_to_ndhwc = {NCDHW::dim::N, NCDHW::dim::D,
759                                              NCDHW::dim::H, NCDHW::dim::W,
760                                              NCDHW::dim::C};
761       std::vector<int> transpose_to_ncdhw = {NDHWC::dim::N, NDHWC::dim::C,
762                                              NDHWC::dim::D, NDHWC::dim::H,
763                                              NDHWC::dim::W};
764 
765       auto CheckForTransposeToNDHWC = std::bind(
766           CheckForTranspose, std::placeholders::_1, transpose_to_ndhwc);
767       auto CheckForConv3dOp =
768           std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.conv3d);
769       auto CheckForTransposeToNCDHW = std::bind(
770           CheckForTranspose, std::placeholders::_1, transpose_to_ncdhw);
771       auto FuseConv3D =
772           std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
773                     std::placeholders::_2, std::placeholders::_3, "NCDHW");
774 
775       finfo_.push_back({"transpose-elimination for Conv3D",
776                         {CheckForTransposeToNDHWC, CheckForConv3dOp,
777                          CheckForTransposeToNCDHW},
778                         // CheckForMklOp
779                         FuseConv3D,
780                         CopyAttrsConv});
781 
782       auto CheckForMaxPool3DOp =
783           std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.max_pool3d);
784       auto FuseMaxPool3D =
785           std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
786                     std::placeholders::_2, std::placeholders::_3, "NCDHW");
787       finfo_.push_back({"transpose-elimination for MaxPool3D",
788                         {CheckForTransposeToNDHWC, CheckForMaxPool3DOp,
789                          CheckForTransposeToNCDHW},
790                         // CheckForMklOp
791                         FuseMaxPool3D,
792                         CopyAttrsPooling});
793     }
794   }
795 
796   // Standard interface to run pass
797   Status Run(const GraphOptimizationPassOptions& options);
798 
799   // Helper function which does most of heavy lifting for rewriting
800   // Mkl nodes to propagate Mkl tensor as additional output
801   //
802   // Extracts common functionality between Run public interface and
803   // test interface.
804   //
805   // @return true, if and only if graph is mutated; false otherwise.
806   bool RunPass(std::unique_ptr<Graph>* g);
807 
808   /// Cause for rewrite
809   /// Currently, we only support 2 causes - either for Mkl layout propagation
810   /// which is the most common case, or for just a name change (used in case
811   /// of ops like MatMul, Transpose, which do not support Mkl layout)
enum RewriteCause {
  // Rewrite done for Mkl layout propagation. As the first enumerator it has
  // value 0, so it is also what a value-initialized RewriteCause holds.
  kRewriteForLayoutPropagation,
  // Rewrite done only to change the op name.
  kRewriteForOpNameChange
};
813 
814   // Get the op rewrite cause depending on whether native format mode
815   // is enabled or not.
GetRewriteCause()816   RewriteCause GetRewriteCause() {
817     if (NativeFormatEnabled()) {
818       return kRewriteForOpNameChange;
819     } else {
820       return kRewriteForLayoutPropagation;
821     }
822   }
823 
824   /// Structure to specify the name of an original node, its new name after
825   /// rewrite, the function to be used to copy attributes for the op, the
826   /// rule (if any) which must hold for rewriting the node, and the cause of
827   /// the rewrite.
828   typedef struct {
829     string name;      // Original name of op of the node in the graph
830     string new_name;  // New name of the op of the node in the graph
831     // A function handler to copy attributes from an old node to a new node.
832     std::function<void(const Node*, NodeBuilder*, bool)> copy_attrs;
833     // A rule under which to rewrite this node
834     std::function<bool(const Node*)> rewrite_rule;
835     // Why are we rewriting?
836     RewriteCause rewrite_cause;
837   } RewriteInfo;
838 
839   /// Structure to specify a forward op, a backward op, and the slot numbers
840   /// in the forward and backward ops where we will add a workspace edge.
struct WorkSpaceInfo {
  // Name of a forward op in the graph.
  string fwd_op;
  // Name of a backward op in the graph.
  string bwd_op;
  // Output slot in the forward op node where the actual output tensor
  // resides.
  int fwd_slot;
  // Input slot in the backward op node where the actual input tensor
  // resides.
  int bwd_slot;
  // Output slot in the forward op node where the workspace edge is added.
  int ws_fwd_slot;
  // Input slot in the backward op node where the workspace edge is added.
  int ws_bwd_slot;
};
853 
854   /// Structure to specify information used in node merge of 2 operators
855   typedef struct {
856     string op1;       // Node string for one operator.
857     string op2;       // Node string for second operator.
858     string new_node;  // Name of the node after merge
859     // Function that enables user of the node merger to specify how to find
860     // second operator given the first operator.
861     std::function<Node*(const Node*)> get_node_to_be_merged;
862   } MergeInfo;
863 
864   // Structure to specify information used in node fusion of 3+ operators
865   typedef struct {
866     std::string pattern_name;  // Name to describe this pattern, such as
867                                // "Transpose_Mklop_Transpose".
868     std::vector<std::function<bool(const Node*)> >
869         node_checkers;  // Extra restriction checker for these ops
870     std::function<Status(
871         std::unique_ptr<Graph>*, std::vector<Node*>&,
872         std::function<void(const Node*, NodeBuilder* nb, bool)>)>
873         fuse_func;
874     std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs;
875   } FusionInfo;
876 
877   //
878   // Dimension indices for 4-D tensors in NCHW and NHWC order.
879   //
struct NCHW {
  // Position of each dimension in N, C, H, W order.
  enum dim {
    N = 0,
    C = 1,
    H = 2,
    W = 3
  };
};
883 
struct NHWC {
  // Position of each dimension in N, H, W, C order.
  enum dim {
    N = 0,
    H = 1,
    W = 2,
    C = 3
  };
};
887 
888   //
889   // Dimension indices for 5-D tensors in NCDHW and NDHWC order.
890   //
struct NCDHW {
  // Position of each dimension in N, C, D, H, W order.
  enum dim {
    N = 0,
    C = 1,
    D = 2,
    H = 3,
    W = 4
  };
};
894 
struct NDHWC {
  // Position of each dimension in N, D, H, W, C order.
  enum dim {
    N = 0,
    D = 1,
    H = 2,
    W = 3,
    C = 4
  };
};
898 
899   /// Structure to store all constant strings
900   /// NOTE: names are only roughly alphabetical; the order is not strict.
struct ConstStringsInfo {
  // One string member per op name used by the pass. Member order is the
  // same as before; only the declaration form changed.
  string addn;
  string add;
  string add_v2;
  string avg_pool;
  string avg_pool_grad;
  string avg_pool3d;
  string avg_pool3d_grad;
  string batch_matmul;
  string batch_matmul_v2;
  string bias_add;
  string bias_add_grad;
  string concat;
  string concatv2;
  string conjugate_transpose;
  string conv2d;
  string conv2d_with_bias;
  string conv2d_grad_input;
  string conv2d_grad_filter;
  string conv2d_grad_filter_with_bias;
  string conv3d;
  string conv3d_grad_input;
  string conv3d_grad_filter;
  string depthwise_conv2d;
  string depthwise_conv2d_grad_input;
  string depthwise_conv2d_grad_filter;
  string dequantize;
  string fused_batch_norm;
  string fused_batch_norm_grad;
  string fused_batch_norm_ex;
  string fused_batch_norm_v2;
  string fused_batch_norm_grad_v2;
  string fused_batch_norm_v3;
  string fused_batch_norm_grad_v3;
  string fused_conv2d;
  string fused_depthwise_conv2d;
  string fused_matmul;
  string identity;
  string leakyrelu;
  string leakyrelu_grad;
  string lrn;
  string lrn_grad;
  string matmul;
  string max_pool;
  string max_pool_grad;
  string max_pool3d;
  string max_pool3d_grad;
  string maximum;

  // Members whose identifiers start with "mkl_".
  string mkl_conv2d;
  string mkl_conv2d_grad_input;
  string mkl_conv2d_grad_filter;
  string mkl_conv2d_grad_filter_with_bias;
  string mkl_conv2d_with_bias;
  string mkl_depthwise_conv2d_grad_input;
  string mkl_depthwise_conv2d_grad_filter;
  string mkl_fused_batch_norm_ex;
  string mkl_fused_conv2d;
  string mkl_fused_depthwise_conv2d;
  string mkl_fused_matmul;
  string mkl_native_conv2d_with_bias;
  string mkl_native_fused_batch_norm_ex;
  string mkl_native_fused_conv2d;
  string mkl_native_fused_depthwise_conv2d;
  string mkl_native_fused_matmul;
  string mkl_native_pad_with_conv2d;
  string mkl_native_pad_with_fused_conv2d;
  string mkl_pad_with_conv2d;
  string mkl_pad_with_fused_conv2d;

  string mul;
  string pad;
  string pad_with_conv2d;
  string pad_with_fused_conv2d;

  // Members whose identifiers start with "quant".
  string quantized_avg_pool;
  string quantized_conv2d;
  string quantized_conv2d_per_channel;
  string quantized_conv2d_with_requantize;
  string quantized_conv2d_with_bias;
  string quantized_conv2d_with_bias_and_requantize;
  string quantized_conv2d_and_relu;
  string quantized_conv2d_and_relu_and_requantize;
  string quantized_conv2d_with_bias_and_relu;
  string quantized_conv2d_with_bias_and_relu_and_requantize;
  string quantized_concatv2;
  string quantized_max_pool;
  string quantized_conv2d_with_bias_sum_and_relu;
  string quantized_conv2d_with_bias_sum_and_relu_and_requantize;
  string quant_conv2d_with_bias_signed_sum_and_relu_and_requantize;
  string quantized_matmul_with_bias;
  string quantized_matmul_with_bias_and_relu;
  string quantized_matmul_with_bias_and_relu_and_requantize;
  string quantized_matmul_with_bias_and_requantize;
  string quantized_matmul_with_bias_and_dequantize;
  string quantized_depthwise_conv2d;
  string quantized_depthwise_conv2d_with_bias;
  string quantized_depthwise_conv2d_with_bias_and_relu;
  string quantized_depthwise_conv2d_with_bias_and_relu_and_requantize;
  string quantize_v2;

  string relu;
  string relu_grad;
  string relu6;
  string relu6_grad;
  string requantize;
  string tanh;
  string tanh_grad;
  string transpose;
  string reshape;
  string slice;
  string softmax;
  string split;
  string squared_difference;
  string sub;
};
1013 
1014  private:
1015   /// Maintain info about nodes to rewrite
1016   std::vector<RewriteInfo> rinfo_;
1017 
1018   /// Maintain info about nodes to add workspace edge
1019   std::vector<WorkSpaceInfo> wsinfo_;
1020 
1021   /// Maintain info about nodes to be merged
1022   std::vector<MergeInfo> minfo_;
1023 
1024   /// Maintain info about nodes to be fused
1025   std::vector<FusionInfo> finfo_;
1026 
1027   /// Maintain structure of constant strings
1028   static ConstStringsInfo csinfo_;
1029 
1030  private:
1031   // Is OpDef::ArgDef a list type? It could be N * T or list(type).
1032   // Refer to opdef.proto for details of list type.
ArgIsList(const OpDef::ArgDef & arg) const1033   inline bool ArgIsList(const OpDef::ArgDef& arg) const {
1034     return !arg.type_list_attr().empty() || !arg.number_attr().empty();
1035   }
1036 
1037   // Get length of a list in 'n' if 'arg' is of list type. Refer to
1038   // description of ArgIsList for definition of list type.
GetTensorListLength(const OpDef::ArgDef & arg,const Node * n)1039   inline int GetTensorListLength(const OpDef::ArgDef& arg, const Node* n) {
1040     CHECK_EQ(ArgIsList(arg), true);
1041     int N = 0;
1042     const string attr_name = !arg.type_list_attr().empty()
1043                                  ? arg.type_list_attr()
1044                                  : arg.number_attr();
1045     if (!arg.type_list_attr().empty()) {
1046       std::vector<DataType> value;
1047       TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value));
1048       N = value.size();
1049     } else {
1050       TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &N));
1051     }
1052     return N;
1053   }
1054 
1055   // Can op represented by node 'n' run on DEVICE_CPU?
1056   // Op can run on CPU with MKL if the runtime assigned device or the
1057   // user requested device contains device CPU, or both are empty.
CanOpRunOnCPUDevice(const Node * n)1058   bool CanOpRunOnCPUDevice(const Node* n) {
1059     bool result = true;
1060     string reason;
1061 
1062     // Substring that should be checked for in device name for CPU device.
1063     const char* const kCPUDeviceSubStr = "CPU";
1064 
1065     // If Op has been specifically assigned to a non-CPU device, then No.
1066     if (!n->assigned_device_name().empty() &&
1067         !absl::StrContains(n->assigned_device_name(), kCPUDeviceSubStr)) {
1068       result = false;
1069       reason = "Op has been assigned a runtime device that is not CPU.";
1070     }
1071 
1072     // If user has specifically assigned this op to a non-CPU device, then No.
1073     if (!n->def().device().empty() &&
1074         !absl::StrContains(n->def().device(), kCPUDeviceSubStr)) {
1075       result = false;
1076       reason = "User has assigned a device that is not CPU.";
1077     }
1078 
1079     if (result == false) {
1080       VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node "
1081               << n->type_string() << ", reason: " << reason;
1082     }
1083 
1084     // Otherwise Yes.
1085     return result;
1086   }
1087 
1088   // Return a node that can be merged with input node 'n'
1089   //
1090   // @return pointer to the node if we can find such a
1091   // node. Otherwise, it returns nullptr.
1092   Node* CheckForNodeMerge(const Node* n) const;
1093 
1094   // Merge node 'm' with node 'n'.
1095   // Currently, we merge (1) Conv2D with BiasAdd, (2) Pad with Conv2D, and
1096   // (3) BiasAddGrad with Conv2DBackpropFilter.
1097   //
1098   // Input nodes m and n may be deleted if the call to
1099   // this function is successful. Attempt to use the pointers
1100   // after the call to function may result in undefined behaviors.
1101   //
1102   // @input g - input graph, m - graph node, n - graph node to be merged with m
1103   // @return Status::OK(), if merging is successful and supported.
1104   //         Returns appropriate Status error code otherwise.
1105   //         Graph is updated in case nodes are merged. Otherwise, it is
1106   //         not updated.
1107   Status MergeNode(std::unique_ptr<Graph>* g, Node* m, Node* n);
1108 
1109   // Helper function to merge different nodes
1110   Status MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, Node* m, Node* n);
1111   Status MergePadWithConv2D(std::unique_ptr<Graph>* g, Node* m, Node* n);
1112   Status MergeConv2DBackpropFilterWithBiasAddGrad(std::unique_ptr<Graph>* g,
1113                                                   Node* m, Node* n);
1114 
1115   // Find BiasAdd or Conv2D node that can be merged with input node 'm'.
1116   // If input 'm' is BiasAdd, then check if there exists Conv2D node that can be
1117   // merged with 'm'. If input 'm' is Conv2D, then check if there exists BiasAdd
1118   // node that can be merged with 'm'.
static Node* GetConv2DOrBiasAdd(const Node* m) {
  DCHECK(m);

  // The "T" attribute of 'm' decides whether a merge is attempted at all.
  DataType T_m;
  TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));

#ifdef ENABLE_INTEL_MKL_BFLOAT16
  // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
  const bool mergeable_type = (T_m == DT_FLOAT || T_m == DT_BFLOAT16);
#else
  // Don't try to merge if datatype is not DT_FLOAT
  const bool mergeable_type = (T_m == DT_FLOAT);
#endif
  if (!mergeable_type) return nullptr;

  Node* n = nullptr;
  if (m->type_string() != csinfo_.bias_add) {
    CHECK_EQ(m->type_string(), csinfo_.conv2d);
    // 'm' is Conv2D: look through its output edges for a BiasAdd that takes
    // it as 0th input. The first such node wins.
    for (const Edge* edge : m->out_edges()) {
      if (edge->IsControlEdge()) continue;
      if (edge->dst()->type_string() != csinfo_.bias_add) continue;
      if (edge->dst_input() != 0) continue;
      n = edge->dst();
      break;
    }
  } else {
    // 'm' is BiasAdd: the candidate is the node feeding its 0th input.
    TF_CHECK_OK(m->input_node(0, &n));
  }

  if (n == nullptr) {
    VLOG(1) << "MklLayoutRewritePass: Could not find matching "
            << "Conv2D and BiasAdd node for merging. Input node: "
            << m->DebugString();
  }

  return n;
}
1159 
1160   // Find Pad or Conv2D node that can be merged with input node 'm'.
1161   // If input 'm' is Pad, then check if there exists Conv2D node that can be
1162   // merged with 'm'. If input 'm' is Conv2D, then check if there exists Pad
1163   // node that can be merged with 'm'.
static Node* GetPadOrConv2D(const Node* m) {
  // @input  m - a Pad or Conv2D node.
  // @return the node to pair with 'm' (a Conv2D when 'm' is Pad, a Pad when
  //         'm' is Conv2D), or nullptr when the "T" attribute of 'm' is not
  //         a supported type, no partner is found, or the Conv2D of the pair
  //         does not use "VALID" padding.
  DCHECK(m);
  Node* n = nullptr;

  DataType T_m;
  TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));

#ifndef ENABLE_INTEL_MKL_BFLOAT16
  // Don't try to merge if datatype is not DT_FLOAT
  if (T_m != DT_FLOAT) return n;
#else
  // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
  if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
#endif

  // Conv2D node of the candidate pair. Initialized so that it never holds
  // an indeterminate value; it is set together with 'n' below.
  const Node* conv_node = nullptr;
  if (m->type_string() == csinfo_.pad) {
    // If m is Pad, then Conv2D is the output of Pad.
    for (const Edge* e : m->out_edges()) {
      if (!e->IsControlEdge() && e->dst()->type_string() == csinfo_.conv2d) {
        n = e->dst();
        conv_node = n;
        break;
      }
    }
  } else {
    DCHECK_EQ(m->type_string(), csinfo_.conv2d);
    // If m is Conv2D, go over all input edges and search for a Pad node.
    for (const Edge* e : m->in_edges()) {
      if (!e->IsControlEdge() && e->src()->type_string() == csinfo_.pad) {
        n = e->src();
        conv_node = m;
        break;
      }
    }
  }

  if (n == nullptr) {
    VLOG(1) << "MklLayoutRewritePass: Could not find matching "
            << "Pad and Conv2D node for merging. Input node: "
            << m->DebugString();
    return nullptr;
  }

  // Only VALID type of padding in conv op can be merged with Pad op.
  string padding;
  TF_CHECK_OK(GetNodeAttr(conv_node->def(), "padding", &padding));
  if (padding != "VALID") {
    // Report the rejection, as GetPadOrFusedConv2D does for its pair.
    VLOG(1) << "MklLayoutRewritePass: Could match Pad and Conv2D "
            << "nodes but cannot merge them. Only conv ops with padding "
            << "type VALID can be merged with Pad op. Input node: "
            << m->DebugString();
    return nullptr;
  }

  return n;
}
1219 
1220   // Find Pad or _FusedConv2D node that can be merged with input node 'm'.
1221   // If input 'm' is Pad, then check if there exists _FusedConv2D node that can
1222   // be merged with 'm'. If input 'm' is _FusedConv2D, then check if there
1223   // exists Pad node that can be merged with 'm'.
static Node* GetPadOrFusedConv2D(const Node* m) {
  // @input  m - a Pad or _FusedConv2D node.
  // @return the node to pair with 'm' (a _FusedConv2D when 'm' is Pad, a Pad
  //         when 'm' is _FusedConv2D), or nullptr when no partner is found
  //         or the conv node of the pair does not use "VALID" padding.
  DCHECK(m);
  Node* n = nullptr;

  // Conv node of the candidate pair. Initialized so that it never holds an
  // indeterminate value; it is set together with 'n' below.
  const Node* conv_node = nullptr;
  if (m->type_string() == csinfo_.pad) {
    // If m is Pad, then _FusedConv2D is the output of Pad.
    for (const Edge* e : m->out_edges()) {
      if (!e->IsControlEdge() &&
          e->dst()->type_string() == csinfo_.fused_conv2d) {
        n = e->dst();
        conv_node = n;
        break;
      }
    }
  } else {
    DCHECK_EQ(m->type_string(), csinfo_.fused_conv2d);
    // If m is _FusedConv2D, go over all input edges and search for a Pad
    // node.
    for (const Edge* e : m->in_edges()) {
      if (!e->IsControlEdge() && e->src()->type_string() == csinfo_.pad) {
        n = e->src();
        conv_node = m;
        break;
      }
    }
  }

  if (n == nullptr) {
    VLOG(1) << "MklLayoutRewritePass: Could not find matching "
            << "Pad and _FusedConv2D node for merging. Input node: "
            << m->DebugString();
    return nullptr;
  }

  // Check if only VALID type of padding is used or not.
  string padding;
  TF_CHECK_OK(GetNodeAttr(conv_node->def(), "padding", &padding));
  if (padding != "VALID") {
    // Then do not merge. The message now separates its two sentences.
    VLOG(1) << "MklLayoutRewritePass: Could match Pad and _FusedConv2D "
            << "nodes but cannot merge them. Only conv ops with padding "
            << "type VALID can be merged with Pad op. Input node: "
            << m->DebugString();
    return nullptr;
  }

  return n;
}
1271 
1272   // Find Conv2DBackpropFilter or BiasAddGrad node that can be merged with input
1273   // node 'm'. If input 'm' is Conv2DBackpropFilter, then check if there exists
1274   // BiasAddGrad node that can be merged with 'm'. If input 'm' is BiasAddGrad,
1275   // then check if there exists Conv2DBackpropFilter node that can be merged
1276   // with 'm'.
1277   //
1278   // Graph that will allow us to connect Conv2DBackpropFilter with BiasAddGrad
1279   // would look like:
1280   //
1281   // _ = Conv2DBackpropFilter(F, _, G)
1282   // _ = BiasAddGrad(G)
1283   //
1284   // So 1st input of BiasAddGrad connects with 3rd input of
1285   // Conv2DBackpropFilter and vice versa.
GetConv2DBackpropFilterOrBiasAddGrad(const Node * m)1286   static Node* GetConv2DBackpropFilterOrBiasAddGrad(const Node* m) {
1287     DCHECK(m);
1288     Node* n = nullptr;
1289     const Node* conv2d_backprop_filter = nullptr;
1290 
1291     DataType T_m;
1292     TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
1293 
1294 #ifndef ENABLE_INTEL_MKL_BFLOAT16
1295     // Don't try to merge if datatype is not DT_FLOAT
1296     if (T_m != DT_FLOAT) return n;
1297 #else
1298     // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
1299     if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
1300 #endif
1301 
1302     if (m->type_string() == csinfo_.bias_add_grad) {
1303       // Get 1st input 'g' of BiasAddGrad.
1304       Node* g = nullptr;
1305       TF_CHECK_OK(m->input_node(0, &g));
1306       // Now traverse all outgoing edges from g that have destination node as
1307       // Conv2DBackpropFilter.
1308       for (const Edge* e : g->out_edges()) {
1309         if (!e->IsControlEdge() &&
1310             e->dst()->type_string() == csinfo_.conv2d_grad_filter &&
1311             e->dst_input() == 2 /* 3rd input of BackpropFilter */) {
1312           n = e->dst();
1313           conv2d_backprop_filter = n;
1314           break;
1315         }
1316       }
1317     } else {
1318       conv2d_backprop_filter = m;
1319       CHECK_EQ(m->type_string(), csinfo_.conv2d_grad_filter);
1320       // Get 3rd input 'g' of Conv2DBackpropFilter.
1321       Node* g = nullptr;
1322       TF_CHECK_OK(m->input_node(2, &g));
1323       // Now traverse all outgoing edges from g that have destination node as
1324       // BiasAddGrad.
1325       for (const Edge* e : g->out_edges()) {
1326         if (!e->IsControlEdge() &&
1327             e->dst()->type_string() == csinfo_.bias_add_grad &&
1328             e->dst_input() == 0 /* 1st input of BiasAddGrad */) {
1329           n = e->dst();
1330           break;
1331         }
1332       }
1333     }
1334 
1335     // Do not merge if padding type is EXPLICIT.
1336     // TODO(intel): Support `EXPLICIT` padding for MklConv2DBackpropFilter.
1337     if (conv2d_backprop_filter != nullptr) {
1338       string padding;
1339       TF_CHECK_OK(
1340           GetNodeAttr(conv2d_backprop_filter->def(), "padding", &padding));
1341       if (padding == "EXPLICIT") {
1342         // Then do not merge.
1343         VLOG(1) << "MklLayoutRewritePass: Could match Conv2DBackpropFilter "
1344                 << "and BiasAddGrad nodes but cannot merge them. "
1345                 << "EXPLICIT padding is not supported now. "
1346                 << conv2d_backprop_filter->DebugString();
1347         return nullptr;
1348       }
1349     }
1350 
1351     if (n == nullptr) {
1352       VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1353               << "Conv2DBackpropFilter and BiasAddGrad node for merging. "
1354               << "Input node: " << m->DebugString();
1355     }
1356     return n;
1357   }
1358 
1359   // Return a node that can be fused with input node 'n'
1360   //
1361   // @return tuple. If we can find such nodes, the first
1362   // element of the tuple is a true. Otherwise, it's false.
1363   std::tuple<bool, std::vector<Node*>, const MklLayoutRewritePass::FusionInfo>
1364   CheckForNodeFusion(Node* n) const;
1365 
1366   // Fuse nodes in the vector "nodes"
1367   Status FuseNode(std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
1368                   const MklLayoutRewritePass::FusionInfo fi);
1369 
1370   // Fuse transpose(to "NHWC") + mklop("NHWC") + transpose(to "NCHW") into
1371   // mklop("NCHW").
1372   // Here "mklop" can be any MKL-DNN supported op, such as Conv2D.
1373   static Status FuseTransposeMklOpTranspose(
1374       std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
1375       std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs,
1376       string data_format);
1377 
CheckForTranspose(const Node * node,std::vector<int> perm)1378   static bool CheckForTranspose(const Node* node, std::vector<int> perm) {
1379     // Check if node's type is "Transpose"
1380     if (node->type_string() != "Transpose") return false;
1381 
1382     // If "Transpose" has multiple output data edges, also don't fuse it.
1383     if (node->num_outputs() > 1 || node->out_edges().size() > 1) return false;
1384 
1385     // Check if has out control edge. If true, this is a training graph.
1386     // Currently we focus on inference and do no fusion in training.
1387     // Note: this constraint will eventually be removed, if we enabled this
1388     // fusion for training
1389     // in the future.
1390     for (const Edge* e : node->out_edges()) {
1391       if (e->IsControlEdge()) {
1392         return false;
1393       }
1394     }
1395 
1396     // If "Transpose" has input control edges, don't fuse on it.
1397     for (const Edge* e : node->in_edges()) {
1398       if (e->IsControlEdge()) {
1399         return false;
1400       }
1401     }
1402 
1403     // We compared the tensor containing the permutation order ("perm_node")
1404     // with our desired order ("perm"). If they're exactly match, this check
1405     // succeed and returns true.
1406     for (const Edge* e : node->in_edges()) {
1407       if (!e->IsControlEdge()) {
1408         const Node* perm_node = e->src();
1409 
1410         const int kPermTensorIndex = 1;
1411         if (perm_node->type_string() == "Const" &&
1412             e->dst_input() == kPermTensorIndex) {
1413           // we find the "perm" node, now try to retrieve its value.
1414           const TensorProto* proto = nullptr;
1415           TF_CHECK_OK(GetNodeAttr(perm_node->def(), "value", &proto));
1416 
1417           DataType type;
1418           TF_CHECK_OK(GetNodeAttr(perm_node->def(), "dtype", &type));
1419 
1420           Tensor tensor;
1421           if (!tensor.FromProto(*proto)) {
1422             TF_CHECK_OK(errors::InvalidArgument(
1423                 "Could not construct Tensor from TensorProto in node: ",
1424                 node->name()));
1425             return false;
1426           }
1427           // Current fusion only supports 4D or 5D tensors according to `perm`
1428           // vector, return false otherwise.
1429           if (tensor.dim_size(0) != perm.size()) return false;
1430           DCHECK_EQ(tensor.dims(), 1);
1431           if (type == DT_INT32) {
1432             const auto tensor_content = tensor.flat<int>().data();
1433             for (int i = 0; i < perm.size(); ++i)
1434               if (tensor_content[i] != perm[i]) return false;
1435             return true;
1436           } else if (type == DT_INT64) {
1437             const auto tensor_content = tensor.flat<int64>().data();
1438             for (int i = 0; i < perm.size(); ++i)
1439               if (tensor_content[i] != perm[i]) return false;
1440             return true;
1441           }
1442           return false;
1443         }
1444       }
1445     }
1446     return false;
1447   }
1448 
CheckForMklOp(const Node * node,string name="")1449   static bool CheckForMklOp(const Node* node, string name = "") {
1450     if (node == nullptr) return false;
1451 
1452     if (!name.empty() && node->type_string() != name) {
1453       return false;
1454     }
1455 
1456     // if mklop has multiple outputs, don't fuse it.
1457     if (node->num_outputs() > 1) return false;
1458 
1459     if (node->out_edges().size() > 1) return false;
1460 
1461     DataType T;
1462     TF_CHECK_OK(GetNodeAttr(node->def(), "T", &T));
1463     return mkl_op_registry::IsMklLayoutDependentOp(
1464         mkl_op_registry::GetMklOpName(node->type_string()), T);
1465   }
1466 
1467   // Check if the node 'n' has any applicable rewrite rule
1468   // We check for 2 scenarios for rewrite.
1469   //
1470   // @return RewriteInfo* for the applicable rewrite rule
1471   const RewriteInfo* CheckForNodeRewrite(const Node* n) const;
1472   const RewriteInfo* CheckForQuantizedNodeRewrite(const Node* n) const;
1473 
1474   // Default rewrite rule to be used in scenario 1 for rewrite.
1475   // @return - true (since we want to always rewrite)
static bool AlwaysRewrite(const Node* n) {
  // Unconditional rule: 'n' is not inspected.
  return true;
}
1477 
1478   // Rewrite rule which considers "context" of the current node to decide if we
1479   // should rewrite. By "context" we currently mean all the inputs of current
1480   // node. The idea is if none of the inputs of current node are MKL nodes,
1481   // then rewriting current node to MKL node _may not_ offer any performance
1482   // improvement.
1483   //
1484   // One such case is element-wise ops. For such ops, we reuse the Eigen
1485   // implementation and pass the MKL metadata tensor through so we can avoid
1486   // conversions. However, if all incoming edges are in TF format, we don't
1487   // need all this overhead, so replace the elementwise node only if at least
1488   // one of its parents is a MKL node.
1489   //
1490   // More generally, all memory- or IO-bound ops (such as Identity) may fall
1491   // under this category.
1492   //
1493   // @input - Input graph node to be rewritten
1494   // @return - true if node is to be rewritten as MKL node; false otherwise.
RewriteIfAtleastOneMklInput(const Node * n)1495   static bool RewriteIfAtleastOneMklInput(const Node* n) {
1496     DataType T;
1497     if (GetNodeAttr(n->def(), "T", &T).ok() &&
1498         mkl_op_registry::IsMklOp(
1499             mkl_op_registry::GetMklOpName(n->type_string()), T)) {
1500       for (auto e : n->in_edges()) {
1501         if (e->IsControlEdge()) continue;
1502         if (mkl_op_registry::IsMklOp(e->src())) {
1503           return true;
1504         }
1505       }
1506     }
1507     return false;
1508   }
1509 
MatMulRewrite(const Node * n)1510   static bool MatMulRewrite(const Node* n) {
1511     DataType T;
1512     TF_CHECK_OK(GetNodeAttr(n->def(), "T", &T));
1513     if ((T == DT_FLOAT) || (T == DT_BFLOAT16)) {
1514       VLOG(2) << "Rewriting MatMul to _MklMatMul";
1515       return true;
1516     }
1517     return false;
1518   }
1519   // For oneDNN, only int32 is supported for axis data type
ConcatV2Rewrite(const Node * n)1520   static bool ConcatV2Rewrite(const Node* n) {
1521     DataType T;
1522     TF_CHECK_OK(GetNodeAttr(n->def(), "Tidx", &T));
1523     return (T == DT_INT32);
1524   }
1525 
DequantizeRewrite(const Node * n)1526   static bool DequantizeRewrite(const Node* n) {
1527     DCHECK(n);
1528     Node* input = nullptr;
1529     TF_CHECK_OK(n->input_node(0, &input));
1530     string mode_string;
1531     TF_CHECK_OK(GetNodeAttr(n->def(), "mode", &mode_string));
1532     if (mode_string != "SCALED") {
1533       VLOG(1) << "DequantizeRewrite: Mode is not SCALED. "
1534               << "This case is not optimized by Intel MKL kernel, thus using "
1535                  "Eigen op for Dequantize op.";
1536       return false;
1537     }
1538     if (input->IsConstant()) {
1539       VLOG(1) << "DequantizeRewrite: Trying to dequantize a Const node which "
1540               << "could possibly be a filter. "
1541               << "This case is not supported by Intel MKL kernel, thus using "
1542                  "Eigen op for Dequantize op.";
1543       return false;
1544     }
1545     return true;
1546   }
1547 
1548   // Rewrite rule for _FusedMatMul.
1549   // @return - true (no transpose attribute for input 1);
1550   //           false otherwise.
FusedMatMulRewrite(const Node * n)1551   static bool FusedMatMulRewrite(const Node* n) {
1552     bool trans_a;
1553 
1554     // Do not rewrite with transpose attribute because reorder has performance
1555     // impact.
1556     TF_CHECK_OK(GetNodeAttr(n->def(), "transpose_a", &trans_a));
1557 
1558     return !trans_a;
1559   }
1560 
1561   // Check if we are performing pooling on depth or batch. If it is, then we
1562   // do not rewrite MaxPool node to Mkl version.
1563   // @return - true (if it is not a depth/batch wise pooling case);
1564   //           false otherwise.
NonDepthBatchWisePoolRewrite(const Node * n)1565   static bool NonDepthBatchWisePoolRewrite(const Node* n) {
1566     DCHECK(n);
1567 
1568     string data_format_str;
1569     TensorFormat data_format;
1570     std::vector<int32> ksize, strides;
1571     TF_CHECK_OK(GetNodeAttr(n->def(), "ksize", &ksize));
1572     TF_CHECK_OK(GetNodeAttr(n->def(), "strides", &strides));
1573     TF_CHECK_OK(GetNodeAttr(n->def(), "data_format", &data_format_str));
1574     bool result = FormatFromString(data_format_str, &data_format);
1575     DCHECK(result);
1576 
1577     // Condition that specifies non-batch-wise and non-depth-wise pooling.
1578     if (GetTensorDim(ksize, data_format, 'N') == 1 &&
1579         GetTensorDim(strides, data_format, 'N') == 1 &&
1580         GetTensorDim(ksize, data_format, 'C') == 1 &&
1581         GetTensorDim(strides, data_format, 'C') == 1) {
1582       return true;
1583     }
1584 
1585     return false;
1586   }
1587 
1588   // If the depth_radius of LRN is not 2, then MKL DNN takes unoptimized
1589   // path. The unoptimized path is slow. Thus we don't rewrite the node
1590   // and use default Eigen. But for depth_radius=2, MKL DNN optimized
1591   // path is taken, i.e., eigen node is rewritten by MKL DNN node.
LrnRewrite(const Node * n)1592   static bool LrnRewrite(const Node* n) {
1593     DCHECK(n);
1594 
1595     int depth_radius;
1596     TF_CHECK_OK(GetNodeAttr(n->def(), "depth_radius", &depth_radius));
1597 
1598     // if the depth_radius of LRN is not 2, don't rewrite the node by MKL DNN
1599     // and use eigen node instead
1600     if (depth_radius == 2) {
1601       return true;
1602     }
1603     VLOG(1) << "LrnRewrite: The model sets depth_radius as not 2 which"
1604             << "case is not optimized by Intel MKL, thus using Eigen op"
1605             << "for LRN ";
1606 
1607     return false;
1608   }
1609 
LrnGradRewrite(const Node * n)1610   static bool LrnGradRewrite(const Node* n) {
1611     DCHECK(n);
1612     bool do_rewrite = false;
1613 
1614     for (const Edge* e : n->in_edges()) {
1615       // Rewrite only if there is corresponding LRN, i.e workspace is available
1616       if (e->dst()->type_string() == csinfo_.lrn_grad && e->dst_input() == 2 &&
1617           e->src()->type_string() ==
1618               mkl_op_registry::GetMklOpName(csinfo_.lrn) &&
1619           e->src_output() == 0) {
1620         do_rewrite = true;
1621         break;
1622       }
1623     }
1624     return do_rewrite;
1625   }
1626 
1627   // MKL-DNN's LeakyRelu(feature) = feature          (if feature > 0), or
1628   //                                feature * alpha  (otherwise),
1629   // while TensorFlow's LeakyRelu(feature) = max(feature, feature * alpha).
1630   // These two algorithms are not consistent when alpha > 1,
1631   // so we only rewrite LeakyRelu to MKL OP when alpha <= 1.
LeakyReluRewrite(const Node * n)1632   static bool LeakyReluRewrite(const Node* n) {
1633     DCHECK(n);
1634 
1635     float alpha;
1636     bool has_attr = TryGetNodeAttr(n->def(), "alpha", &alpha);
1637     DCHECK(has_attr);
1638 
1639     // If the alpha of LeakyRelu is less than 1, rewrite the node.
1640     // Otherwise eigen node is used instead.
1641     if (alpha <= 1) {
1642       return true;
1643     }
1644     VLOG(1) << "LeakyReluRewrite: The model sets alpha is greater than 1 "
1645             << "which case is not optimized by Intel MKL, thus using Eigen op"
1646             << "for LeakyRelu ";
1647 
1648     return false;
1649   }
1650 
QuantizeOpRewrite(const Node * n)1651   static bool QuantizeOpRewrite(const Node* n) {
1652     DCHECK(n);
1653     Node* filter_node = nullptr;
1654     TF_CHECK_OK(n->input_node(0, &filter_node));
1655     bool narrow_range = false;
1656     int axis = -1;
1657     string mode_string;
1658     string round_mode_string;
1659     DataType type;
1660     TryGetNodeAttr(n->def(), "narrow_range", &narrow_range);
1661     TryGetNodeAttr(n->def(), "axis", &axis);
1662     TF_CHECK_OK(GetNodeAttr(n->def(), "mode", &mode_string));
1663     TF_CHECK_OK(GetNodeAttr(n->def(), "round_mode", &round_mode_string));
1664     TF_CHECK_OK(GetNodeAttr(n->def(), "T", &type));
1665 
1666     if (narrow_range) {
1667       VLOG(1) << "QuantizeOpRewrite: narrow range is enabled for quantization."
1668               << "This case is not optimized by Intel MKL, "
1669               << "thus using Eigen op for Quantize op ";
1670       return false;
1671     }
1672     if (axis != -1) {
1673       VLOG(1) << "QuantizeOpRewrite: dimension is specified for "
1674               << "per slice quantization."
1675               << "This case is not optimized by Intel MKL, "
1676               << "thus using Eigen op for Quantize op ";
1677       return false;
1678     }
1679     if (!((mode_string == "SCALED" && round_mode_string == "HALF_TO_EVEN") ||
1680           (mode_string == "MIN_FIRST"))) {
1681       VLOG(1) << "QuantizeOpRewrite: Mode is not SCALED or MIN_FIRST and/or"
1682               << "rounding mode is not HALF_TO_EVEN. "
1683               << "This case is not optimized by Intel MKL, thus using Eigen op"
1684               << "for Quantize op ";
1685       return false;
1686     }
1687     if (filter_node->IsConstant()) {
1688       VLOG(1) << "QuantizeOpRewrite: Trying to quantize a node which "
1689               << "is a constant. "
1690               << "This case is not supported by the kernel, thus using Eigen op"
1691               << "for Quantize op ";
1692 
1693       return false;
1694     }
1695     if (mode_string == "MIN_FIRST") {
1696       if (type != DT_QUINT8) {
1697         VLOG(1) << "QuantizeOpRewrite: For MIN_FIRST mode the data type is "
1698                 << "not DT_UINT8. This case is not optimized by Intel MKL, "
1699                 << "thus using Eigen op for Quantize op ";
1700         return false;
1701       }
1702     }
1703     return true;
1704   }
1705 
MaxpoolGradRewrite(const Node * n)1706   static bool MaxpoolGradRewrite(const Node* n) {
1707     DCHECK(n);
1708     bool do_rewrite = false;
1709     for (const Edge* e : n->in_edges()) {
1710       // Rewrite only if there is corresponding Maxpool, i.e workspace is
1711       // available
1712       if (e->dst()->type_string() == csinfo_.max_pool_grad &&
1713           e->dst_input() == 1 &&
1714           e->src()->type_string() ==
1715               mkl_op_registry::GetMklOpName(csinfo_.max_pool) &&
1716           e->src_output() == 0) {
1717         do_rewrite = true;
1718         break;
1719       }
1720     }
1721     return do_rewrite;
1722   }
1723 
Maxpool3DGradRewrite(const Node * n)1724   static bool Maxpool3DGradRewrite(const Node* n) {
1725     DCHECK(n);
1726     for (const Edge* e : n->in_edges()) {
1727       // Rewrite only if there is corresponding Maxpool3D, i.e., workspace is
1728       // available
1729       if (e->dst()->type_string() == csinfo_.max_pool3d_grad &&
1730           e->dst_input() == 1 &&
1731           e->src()->type_string() ==
1732               mkl_op_registry::GetMklOpName(csinfo_.max_pool3d) &&
1733           e->src_output() == 0) {
1734         return true;
1735       }
1736     }
1737     return false;
1738   }
1739 
FusedBatchNormV3Rewrite(const Node * n)1740   static bool FusedBatchNormV3Rewrite(const Node* n) {
1741     DCHECK(n);
1742     if (Check5DFormat(n->def())) {
1743       VLOG(1) << "Graph Rewrite: FusedBatchNorm(Grad)V3 op currently does not "
1744               << "support 5D tensors.";
1745       return false;
1746     }
1747     return true;
1748   }
1749 
FusedBatchNormExRewrite(const Node * n)1750   static bool FusedBatchNormExRewrite(const Node* n) {
1751     DCHECK(n);
1752 
1753     int num_side_inputs;
1754     TF_CHECK_OK(GetNodeAttr(n->def(), "num_side_inputs", &num_side_inputs));
1755     string activation_mode;
1756     TF_CHECK_OK(GetNodeAttr(n->def(), "activation_mode", &activation_mode));
1757 
1758     // if the num_side_inputs is not 0, don't rewrite the node.
1759     if (num_side_inputs != 0) {
1760       VLOG(1) << "FusedBatchNormExRewrite: The model sets num_side_inputs"
1761               << "larger than 0 is not optimized by Intel MKL.";
1762       return false;
1763     }
1764 
1765     // if the activation_mode is not 'Relu', don't rewrite the node.
1766     if (activation_mode != "Relu") {
1767       VLOG(1) << "FusedBatchNormExRewrite: Only Relu activation mode is"
1768               << "supported by Intel MKL.";
1769       return false;
1770     }
1771 
1772     return true;
1773   }
1774 
FusedConv2DRewrite(const Node * n)1775   static bool FusedConv2DRewrite(const Node* n) {
1776     // MKL DNN currently doesn't support all fusions that grappler fuses
1777     // together with Conv2D (ex. batchnorm). We rewrite _FusedConv2D only if
1778     // it includes those we support.
1779     DataType T;
1780     if (!TryGetNodeAttr(n->def(), "T", &T) ||
1781         !mkl_op_registry::IsMklLayoutDependentOp(csinfo_.mkl_fused_conv2d, T)) {
1782       return false;
1783     }
1784 
1785     std::vector<string> fused_ops;
1786     TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops));
1787     return (fused_ops == std::vector<string>{"BiasAdd"} ||
1788             fused_ops == std::vector<string>{"Relu"} ||
1789             fused_ops == std::vector<string>{"Relu6"} ||
1790             fused_ops == std::vector<string>{"Elu"} ||
1791             fused_ops == std::vector<string>{"BiasAdd", "Relu"} ||
1792             fused_ops == std::vector<string>{"BiasAdd", "Relu6"} ||
1793             fused_ops == std::vector<string>{"BiasAdd", "Elu"} ||
1794             fused_ops == std::vector<string>{"BiasAdd", "Add"} ||
1795             fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"} ||
1796             fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu6"} ||
1797             fused_ops == std::vector<string>{"BiasAdd", "Add", "Elu"} ||
1798             fused_ops == std::vector<string>{"LeakyRelu"} ||
1799             fused_ops == std::vector<string>{"BiasAdd", "LeakyRelu"} ||
1800             fused_ops == std::vector<string>{"BiasAdd", "Add", "LeakyRelu"});
1801   }
1802 
FusedDepthwiseConv2DRewrite(const Node * n)1803   static bool FusedDepthwiseConv2DRewrite(const Node* n) {
1804     // MKL DNN currently doesn't support all fusions that grappler fuses
1805     // together with DepthwiseConv2D (ex. batchnorm). We rewrite
1806     // _FusedDepthwiseConv2DNative only if it includes those we support.
1807     DataType T;
1808     if (!TryGetNodeAttr(n->def(), "T", &T) ||
1809         !mkl_op_registry::IsMklLayoutDependentOp(
1810             csinfo_.mkl_fused_depthwise_conv2d, T)) {
1811       return false;
1812     }
1813 
1814     std::vector<string> fused_ops;
1815     TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops));
1816     return (fused_ops == std::vector<string>{"BiasAdd"} ||
1817             fused_ops == std::vector<string>{"BiasAdd", "Relu"} ||
1818             fused_ops == std::vector<string>{"BiasAdd", "Relu6"} ||
1819             fused_ops == std::vector<string>{"BiasAdd", "Elu"});
1820   }
1821 
1822   // Rewrites input node to a new node specified by its matching rewrite info.
1823   //
1824   // Method first searches matching rewrite info for input node and then
1825   // uses that info to rewrite.
1826   //
1827   // Input node may be deleted in case of rewrite. Attempt to use the node
1828   // after the call can result in undefined behaviors.
1829   //
1830   // @input  g - input graph, n - Node to be rewritten,
1831   //         ri - matching rewriteinfo
1832   // @return Status::OK(), if the input node is rewritten;
1833   //         Returns appropriate Status error code otherwise.
1834   //         Graph is updated in case the input node is rewritten.
1835   //         Otherwise, it is not updated.
1836   Status RewriteNode(std::unique_ptr<Graph>* g, Node* n, const RewriteInfo* ri);
1837 
1838   // Rewrites input node to just change its operator name. The number of
1839   // inputs to the node and the number of outputs remain the same. Attributes
1840   // of the new node could be copied from attributes of the old node or
1841   // modified. copy_attrs field of RewriteInfo controls this.
1842   //
1843   // Conceptually, it allows us to rewrite:
1844   //
1845   //        f[a=v1,b=v2](x,y) -> g[a'=v3,b'=v4](x,y)
1846   //
1847   // Attributes can be altered without any restrictions --- they could be
1848   // copied, modified, or deleted completely.
1849   //
1850   // @input  g - input graph, orig_node - Node to be rewritten,
1851   //         ri - matching rewriteinfo
1852   // @output new_node - points to newly created node
1853   // @return Status::OK(), if the input node is rewritten;
1854   //         Returns appropriate Status error code otherwise.
1855   //         Graph is only updated when the input node is rewritten.
1856   Status RewriteNodeForJustOpNameChange(std::unique_ptr<Graph>* g,
1857                                         const Node* orig_node, Node** new_node,
1858                                         const RewriteInfo* ri);
1859 
1860   // Rewrites input node to enable MKL layout propagation. Please also refer to
1861   // documentation for the function RewriteNodeForJustOpNameChange() to
1862   // understand what it means.
1863   //
1864   // @input  g - input graph, orig_node - Node to be rewritten,
1865   //         ri - matching rewriteinfo
1866   // @output new_node - points to newly created node
1867   // @return Status::OK(), if the input node is rewritten;
1868   //         Returns appropriate Status error code otherwise.
1869   //         Graph is updated in case the input node is rewritten.
1870   //         Otherwise, it is not updated.
1871   Status RewriteNodeForLayoutPropagation(std::unique_ptr<Graph>* g,
1872                                          const Node* orig_node, Node** new_node,
1873                                          const RewriteInfo* ri);
1874 
1875   // Get nodes that will feed a list of TF tensors to the new
1876   // node that we are constructing.
1877   //
1878   // @input g - input graph,
1879   // @input inputs - inputs to old node that we are using for constructing
1880   //                 new inputs,
1881   // @input input_idx - the index in the 'inputs' vector pointing to the
1882   //                    current input that we have processed so far
1883   // @output input_idx - index will be incremented by the number of nodes
1884   //                     from 'inputs' that are processed
1885   // @input list_length - The expected length of list of TF tensors
1886   // @output output_nodes - the list of new nodes creating TF tensors
1887   //
1888   // @return None
1889   void GetNodesProducingTFTensorList(
1890       const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1891       int* input_idx, int list_length,
1892       std::vector<NodeBuilder::NodeOut>* output_nodes);
1893 
1894   // Get nodes that will feed a list of Mkl tensors to the new
1895   // node that we are constructing.
1896   //
1897   // @input g - input graph,
1898   // @input orig_node - Original node that we are rewriting
1899   // @input inputs - inputs to old node that we are using for constructing
1900   //                 new inputs,
1901   // @input input_idx - the index in the 'inputs' vector pointing to the
1902   //                    current input that we have processed so far
1903   // @output input_idx - index will be incremented by the number of nodes
1904   //                     from 'inputs' that are processed
1905   // @input list_length - The expected length of list of Mkl tensors
1906   // @output output_nodes - the list of new nodes creating Mkl tensors
1907   //
1908   // @return None
1909   void GetNodesProducingMklTensorList(
1910       std::unique_ptr<Graph>* g, const Node* orig_node,
1911       const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1912       int* input_idx, int list_length,
1913       std::vector<NodeBuilder::NodeOut>* output_nodes);
1914 
1915   // Get a node that will feed an Mkl tensor to the new
1916   // node that we are constructing. The output node could be (1) 'n'
1917   // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
1918   // if 'n' is not an Mkl layer.
1919   //
1920   // @input g - input graph,
1921   // @input orig_node - Original node that we are rewriting,
1922   // @input n - Node based on which we are creating Mkl node,
1923   // @input n_output_slot - the output slot of node 'n'
1924   //            which is feeding to the node that we are constructing
1925   // @output mkl_node - the new node that will feed Mkl tensor
1926   // @output mkl_node_output_slot - the slot number of mkl_node that
1927   //                                will feed the tensor
1928   // @return None
1929   void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g,
1930                                  const Node* orig_node, Node* n,
1931                                  int n_output_slot, Node** mkl_node,
1932                                  int* mkl_node_output_slot);
1933 
1934   // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
1935   // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
1936   // set up in contiguous fashion. 'workspace_tensors' carry graph nodes
1937   // producing workspace edges if 'are_workspace_tensors_available' is true.
1938   // Otherwise, 'workspace_tensors' is empty vector.
1939   //
1940   // For details, refer to 'Ordering of inputs after rewriting' section in the
1941   // documentation above.
1942   //
1943   // Returns Status::OK() if setting up inputs is successful, otherwise
1944   // returns appropriate status code.
1945   int SetUpContiguousInputs(
1946       std::unique_ptr<Graph>* g,
1947       const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
1948       NodeBuilder* nb, const Node* old_node,
1949       std::vector<NodeBuilder::NodeOut>* workspace_tensors,
1950       bool are_workspace_tensors_available);
1951 
1952   // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
1953   // in graph 'g'. Original node is input in 'orig_node'.
1954   //
1955   // For details, refer to 'Ordering of Tensorflow tensors and Mkl tensors'
1956   // section in the documentation above.
1957   //
1958   // Returns Status::OK() if setting up inputs is successful, otherwise
1959   // returns appropriate status code.
1960   Status SetUpInputs(std::unique_ptr<Graph>* g,
1961                      const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1962                      NodeBuilder* nb, const Node* orig_node);
1963 
1964   // Create new inputs by copying old inputs 'inputs' for the rewritten node
1965   // in 'nb' in graph 'g'. Original node is input in 'orig_node'. This is mostly
1966   // used in the context of rewrite for just operator name change in which
1967   // inputs of old operator and new operator are same.
1968   //
1969   // Returns Status::OK() if setting up inputs is successful, otherwise
1970   // returns appropriate status code.
1971   Status CopyInputs(const Node* orig_node,
1972                     const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1973                     NodeBuilder* nb);
1974 
1975   // Add workspace edge on the input or output side of Node 'orig_node' by using
1976   // NodeBuilder 'nb' for the new node provided. If 'orig_node' does not dictate
1977   // adding workspace edge then do not add it. Workspace Tensorflow and Mkl
1978   // tensors, if they need to be added, will be set into these tensors.
  // '*are_ws_tensors_added' is set to true if workspace tensors were added to
  // 'ws_tensors', and to false otherwise.
1980   void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
1981                                 const Node* orig_node, NodeBuilder* nb,
1982                                 std::vector<NodeBuilder::NodeOut>* ws_tensors,
1983                                 bool* are_ws_tensors_added);
1984 
1985   // Helper function used by FixMklMetaDataEdges. Fixes the metadata edge
1986   // pointed by 'e_metadata' corresponding to the data edge 'e_data' in graph
1987   // 'g'. Returns true if fixup was done; otherwise, it returns false.
1988   bool FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g, const Edge* e_data,
1989                                   const Edge* e_metadata);
1990 
1991   // Are the input Mkl metadata edges for node 'n' in graph 'g' correctly
1992   // connected? If not, then fix them. This is needed because a graph may have
1993   // some input Mkl metadata edges incorrectly setup after node merge and
1994   // rewrite passes. This could happen because GetReversePostOrder function may
1995   // not provide topologically sorted order if a graph contains cycles. The
1996   // function returns true if at least one Mkl metadata edge for node 'n' was
1997   // fixed. Otherwise, it returns false.
1998   //
1999   // Example:
2000   //
2001   // X = MklConv2D(_, _, _)
2002   // Y = MklConv2DWithBias(_, _, _, _, _, _)
2003   // Z = MklAdd(X, Y, DummyMklTensor, Y:1)
2004   //
2005   // For a graph such as shown above, note that 3rd argument of MklAdd contains
2006   // DummyMklTensor. Actually, it should be getting the Mkl metadata from
2007   // MklConv2D op (specifically, X:2). This incorrect plumbing could be possible
2008   // (although rare) if the Mkl NodeMerge + NodeRewrite passes visit Z before X
2009   // (possible if X, Y, Z are part of a loop.) This function fixes the Mkl
2010   // metadata edges only - it does not rewrite nodes nor does it modify the Mkl
2011   // data edges (1st and 2nd arguments of MklAdd).
2012   bool FixMklMetaDataEdges(std::unique_ptr<Graph>* g, Node* n);
2013 
2014   // Functions specific to operators to copy attributes
2015   // We need operator-specific function to copy attributes because the framework
2016   // does not provide any generic function for it.
  // NOTE: names are meant to be kept alphabetically sorted; the current order
  // is only approximately alphabetical.
2018   static void CopyAttrsAll(const Node* orig_node, NodeBuilder* nb,
2019                            bool change_format = false);
2020   static void CopyAttrsAllCheckConstFilter(const Node* orig_node,
2021                                            NodeBuilder* nb,
2022                                            bool change_format = false);
2023 
2024   static void CopyAttrsConv(const Node* orig_node, NodeBuilder* nb,
2025                             bool change_format = false);
2026   static void CopyAttrsConvCheckConstFilter(const Node* orig_node,
2027                                             NodeBuilder* nb,
2028                                             bool change_format = false);
2029   static void CopyAttrsFromPadAndConv2D(const Node* orig_node1,
2030                                         const Node* orig_node2, NodeBuilder* nb,
2031                                         bool change_format = false);
2032   static void CopyAttrsFromPadAndFusedConv2D(const Node* orig_node1,
2033                                              const Node* orig_node2,
2034                                              NodeBuilder* nb,
2035                                              bool change_format = false);
2036   static void CopyAttrsQuantizedConv2D(const Node* orig_node, NodeBuilder* nb,
2037                                        bool change_format = false);
2038   static void CopyFormatAttrsConv(const Node* orig_node, NodeBuilder* nb,
2039                                   const std::vector<int32>& strides,
2040                                   const std::vector<int32>& dilations,
2041                                   bool change_format = false);
2042 
2043   static void CopyAttrsQuantizedMatMulWithBias(const Node* orig_node,
2044                                                NodeBuilder* nb,
2045                                                bool change_format = false);
2046   static void CopyAttrsQuantizedMatMulWithBiasAndDequantize(
2047       const Node* orig_node, NodeBuilder* nb, bool change_format = false);
2048   static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb,
2049                                bool change_format = false);
2050 
2051   // Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
2052   // using node for original node 'orig_node' and return it in '*out'.
2053   // TODO(nhasabni) We should move this to mkl_util.h
2054   void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out,
2055                              const Node* orig_node);
2056   void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out,
2057                                    const Node* orig_node);
2058 };
2059 
// Out-of-class definition of the static data member 'csinfo_' of
// MklLayoutRewritePass.
MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;

// We register Mkl rewrite pass for phase 1 in post partitioning group.
// We register it here so that we get a complete picture of all users of Mkl
// nodes. Do not change the ordering of the Mkl passes.
const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
    OptimizationPassRegistry::POST_PARTITIONING;
// The registration itself is compiled in only when ENABLE_MKL is defined; the
// grouping constant above is defined either way.
#ifdef ENABLE_MKL
REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
#endif  // ENABLE_MKL
2070 
2071 //////////////////////////////////////////////////////////////////////////
2072 //           Helper functions for creating new node
2073 //////////////////////////////////////////////////////////////////////////
2074 
FillInputs(const Node * n,gtl::InlinedVector<Node *,4> * control_edges,gtl::InlinedVector<std::pair<Node *,int>,4> * in)2075 static void FillInputs(const Node* n,
2076                        gtl::InlinedVector<Node*, 4>* control_edges,
2077                        gtl::InlinedVector<std::pair<Node*, int>, 4>* in) {
2078   control_edges->clear();
2079   for (const Edge* e : n->in_edges()) {
2080     if (e->IsControlEdge()) {
2081       control_edges->push_back(e->src());
2082     } else {
2083       (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output());
2084     }
2085   }
2086   std::sort(control_edges->begin(), control_edges->end());
2087 }
2088 
GetNodesProducingTFTensorList(const gtl::InlinedVector<std::pair<Node *,int>,4> & inputs,int * input_idx,int list_length,std::vector<NodeBuilder::NodeOut> * output_nodes)2089 void MklLayoutRewritePass::GetNodesProducingTFTensorList(
2090     const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
2091     int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
2092   CHECK_LT(*input_idx, inputs.size());
2093   CHECK_GT(list_length, 0);
2094   DCHECK(output_nodes);
2095   output_nodes->reserve(list_length);
2096 
2097   while (list_length != 0) {
2098     CHECK_GT(list_length, 0);
2099     CHECK_LT(*input_idx, inputs.size());
2100     Node* n = inputs[*input_idx].first;
2101     int slot = inputs[*input_idx].second;
2102     // If input node 'n' is just producing a single tensor at
2103     // output slot 'slot' then we just add that single node.
2104     output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
2105     (*input_idx)++;
2106     list_length--;
2107   }
2108 }
2109 
2110 // TODO(nhasabni) We should move this to mkl_util.h.
GetDummyMklTensorNode(std::unique_ptr<Graph> * g,Node ** out,const Node * orig_node)2111 void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
2112                                                  Node** out,
2113                                                  const Node* orig_node) {
2114   // We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent
2115   // dummy Mkl tensor. 8 = 2*size_t.
2116   const DataType dt = DataTypeToEnum<uint8>::v();
2117   TensorProto proto;
2118   proto.set_dtype(dt);
2119   uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0};
2120   proto.set_tensor_content(string(reinterpret_cast<char*>(&zero), 8));
2121   TensorShape dummy_shape({8});
2122   dummy_shape.AsProto(proto.mutable_tensor_shape());
2123   TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
2124                   .Attr("value", proto)
2125                   .Attr("dtype", dt)
2126                   .Device(orig_node->def().device())  // We place this node on
2127                                                       // the same device as the
2128                                                       // device of the original
2129                                                       // node.
2130                   .Finalize(&**g, out));
2131   DCHECK(*out);  // Make sure we got a valid object before using it
2132 
2133   // If number of inputs to the original node is > 0, then we add
2134   // control dependency between 1st input (index 0) of the original node and
2135   // the dummy Mkl node. This is needed because control-flow ops such as Enter,
2136   // Merge, etc, require frame_name of the dummy Mkl node to be same as the
2137   // rewritten node. Adding control edge between 1st input of the original node
2138   // and the dummy Mkl node ensures that the dummy node is in the same frame
2139   // as the original node. Choosing 1st input is not necessary - any input of
2140   // the original node is fine because all the inputs of a node are always in
2141   // the same frame.
2142   if (orig_node->num_inputs() > 0) {
2143     Node* orig_input0 = nullptr;
2144     TF_CHECK_OK(
2145         orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
2146     auto edge = (*g)->AddControlEdge(orig_input0, *out, false);
2147     DCHECK(edge != nullptr || DoesControlEdgeExist(orig_input0, *out));
2148   }
2149 
2150   (*out)->set_assigned_device_name(orig_node->assigned_device_name());
2151 }
2152 
GetNodesProducingMklTensorList(std::unique_ptr<Graph> * g,const Node * orig_node,const gtl::InlinedVector<std::pair<Node *,int>,4> & inputs,int * input_idx,int list_length,std::vector<NodeBuilder::NodeOut> * output_nodes)2153 void MklLayoutRewritePass::GetNodesProducingMklTensorList(
2154     std::unique_ptr<Graph>* g, const Node* orig_node,
2155     const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
2156     int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
2157   CHECK_LT(*input_idx, inputs.size());
2158   CHECK_GT(list_length, 0);
2159   DCHECK(output_nodes);
2160   output_nodes->reserve(list_length);
2161 
2162   while (list_length != 0) {
2163     CHECK_GT(list_length, 0);
2164     CHECK_LT(*input_idx, inputs.size());
2165     Node* n = inputs[*input_idx].first;
2166     int slot = inputs[*input_idx].second;
2167     // If 'n' is producing a single tensor, then create a single Mkl tensor
2168     // node.
2169     Node* mkl_node = nullptr;
2170     int mkl_node_output_slot = 0;
2171     GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node,
2172                               &mkl_node_output_slot);
2173     output_nodes->push_back(
2174         NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
2175     (*input_idx)++;
2176     list_length--;
2177   }
2178 }
2179 
2180 // Get an input node that will feed Mkl tensor to the new
2181 // node that we are constructing. An input node could be (1) 'n'
2182 // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
2183 // if 'n' is not an Mkl layer.
GetNodeProducingMklTensor(std::unique_ptr<Graph> * g,const Node * orig_node,Node * n,int n_output_slot,Node ** mkl_node,int * mkl_node_output_slot)2184 void MklLayoutRewritePass::GetNodeProducingMklTensor(
2185     std::unique_ptr<Graph>* g, const Node* orig_node, Node* n,
2186     int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) {
2187   DCHECK(n);
2188   DCHECK(mkl_node);
2189   DCHECK(mkl_node_output_slot);
2190 
2191   // If this is an MKL op, then it will create extra output for MKL layout.
2192   DataType T;
2193   if (TryGetNodeAttr(n->def(), "T", &T) &&
2194       mkl_op_registry::IsMklLayoutDependentOp(n->type_string(), T)) {
2195     // If this is an MKL op, then it will generate an edge that will receive
2196     // Mkl tensor from a node.
2197     // output slot number for Mkl tensor would be N+slot number of TensorFlow
2198     // tensor, where N is total number of TensorFlow tensors.
2199     *mkl_node = n;
2200     *mkl_node_output_slot =
2201         GetTensorMetaDataIndex(n_output_slot, n->num_outputs());
2202   } else {
2203     // If we have not visited the node and rewritten it, then we need
2204     // to create a dummy node that will feed a dummy Mkl tensor to this node.
2205     // DummyMklTensor node has no input and generates only 1 output
2206     // (dummy Mkl tensor) as output slot number 0.
2207     GetDummyMklTensorNode(g, mkl_node, orig_node);
2208     DCHECK(*mkl_node);
2209     *mkl_node_output_slot = 0;
2210   }
2211 }
2212 
SetUpContiguousInputs(std::unique_ptr<Graph> * g,const gtl::InlinedVector<std::pair<Node *,int>,4> & old_node_inputs,NodeBuilder * nb,const Node * old_node,std::vector<NodeBuilder::NodeOut> * workspace_tensors,bool are_workspace_tensors_available)2213 int MklLayoutRewritePass::SetUpContiguousInputs(
2214     std::unique_ptr<Graph>* g,
2215     const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
2216     NodeBuilder* nb, const Node* old_node,
2217     std::vector<NodeBuilder::NodeOut>* workspace_tensors,
2218     bool are_workspace_tensors_available) {
2219   DCHECK(workspace_tensors);
2220   CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
2221 
2222   // TODO(nhasabni): Temporary solution to connect filter input of
2223   // BackpropInput with the converted filter from Conv2D.
2224   bool do_connect_conv2d_backprop_input_filter = false;
2225   Node* conv2d_node = nullptr;
2226   // Filter node is 2nd input (slot index 1) of Conv2D.
2227   int kConv2DFilterInputSlotIdx = 1;
2228   int kConv2DBackpropInputFilterInputSlotIdx = 1;
2229   int kConv2DFilterOutputSlotIdx = 1;
2230   if (old_node->type_string() == csinfo_.conv2d_grad_input) {
2231     // We need to find Conv2D node from Conv2DBackpropInput.
2232     // For that let's first find filter node that is 2nd input (slot 1)
2233     // of BackpropInput.
2234     Node* filter_node = nullptr;
2235     TF_CHECK_OK(old_node->input_node(kConv2DBackpropInputFilterInputSlotIdx,
2236                                      &filter_node));
2237     DCHECK(filter_node);
2238 
2239     // Now check which nodes receive from filter_node. Filter feeds as
2240     // 2nd input (slot 1) of _MklConv2D, _MklConv2DWithBias, and
2241     // _MklFusedConv2D.
2242     for (const Edge* e : filter_node->out_edges()) {
2243       if ((e->dst()->type_string() == csinfo_.mkl_conv2d ||
2244            e->dst()->type_string() == csinfo_.mkl_pad_with_conv2d ||
2245            e->dst()->type_string() == csinfo_.mkl_pad_with_fused_conv2d ||
2246            e->dst()->type_string() == csinfo_.mkl_conv2d_with_bias ||
2247            e->dst()->type_string() == csinfo_.mkl_fused_conv2d) &&
2248           e->dst_input() == kConv2DFilterInputSlotIdx
2249           /* filter is 2nd input of Conv2D and _MklConv2D. */) {
2250         if (conv2d_node != nullptr) {
2251           VLOG(1) << "MklLayoutRewritePass: unusual case of same filter"
2252                   << " feeding multiple Conv2D nodes: "
2253                   << filter_node->DebugString();
2254           // We will not connect filter input of Conv2DBackpropInput
2255           // to be safe here.
2256           do_connect_conv2d_backprop_input_filter = false;
2257           break;
2258         } else {
2259           conv2d_node = e->dst();
2260           do_connect_conv2d_backprop_input_filter = true;
2261         }
2262       }
2263     }
2264   }
2265 
2266   // Number of input slots to original op
2267   // Input slots are represented by .Input() calls in REGISTER_OP.
2268   int old_node_input_slots = old_node->op_def().input_arg_size();
2269   int nn_slot_idx = 0;  // slot index for inputs of new node
2270 
2271   // Let's copy all inputs (TF tensors) of original node to new node.
2272   int iidx = 0;
2273   for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
2274     // An input slot could be a single tensor or a list. We need
2275     // to handle this case accordingly.
2276     const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
2277     if (ArgIsList(arg)) {
2278       std::vector<NodeBuilder::NodeOut> new_node_inputs;
2279       int tensor_list_length = GetTensorListLength(arg, old_node);
2280       if (tensor_list_length != 0) {
2281         GetNodesProducingTFTensorList(old_node_inputs, &iidx,
2282                                       tensor_list_length, &new_node_inputs);
2283       }
2284       nb->Input(new_node_inputs);
2285       nn_slot_idx++;
2286     } else {
2287       // Special case for connecting filter input of Conv2DBackpropInput
2288       if (do_connect_conv2d_backprop_input_filter &&
2289           iidx == kConv2DBackpropInputFilterInputSlotIdx) {
2290         nb->Input(conv2d_node, kConv2DFilterOutputSlotIdx);
2291       } else {
2292         nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
2293       }
2294       iidx++;
2295       nn_slot_idx++;
2296     }
2297   }
2298 
2299   // If workspace tensors are available for this op and we are using
2300   // contiguous ordering then we need to add Tensorflow tensor for
2301   // workspace here because Tensorflow tensor for workspace is the
2302   // last tensor in the list of Tensorflow tensors.
2303   if (are_workspace_tensors_available) {
2304     CHECK_EQ(workspace_tensors->size(), 2);
2305     // Tensorflow tensor
2306     nb->Input((*workspace_tensors)[0].node, (*workspace_tensors)[0].index);
2307     nn_slot_idx++;
2308   }
2309 
2310   // Let's now setup all Mkl inputs to a new node.
2311   // Number of Mkl inputs must be same as number of TF inputs.
2312   iidx = 0;
2313   for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
2314     // An input slot could be a single tensor or a list. We need
2315     // to handle this case accordingly.
2316     const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
2317     if (ArgIsList(arg)) {
2318       std::vector<NodeBuilder::NodeOut> new_node_inputs;
2319       int tensor_list_length = GetTensorListLength(arg, old_node);
2320       if (tensor_list_length != 0) {
2321         GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx,
2322                                        tensor_list_length, &new_node_inputs);
2323       }
2324       nb->Input(new_node_inputs);
2325       nn_slot_idx++;
2326     } else {
2327       Node* mkl_node = nullptr;
2328       int mkl_node_output_slot = 0;
2329       // Special case for connecting filter input of Conv2DBackpropInput
2330       if (do_connect_conv2d_backprop_input_filter &&
2331           iidx == kConv2DBackpropInputFilterInputSlotIdx) {
2332         GetNodeProducingMklTensor(g, old_node, conv2d_node,
2333                                   kConv2DFilterOutputSlotIdx, &mkl_node,
2334                                   &mkl_node_output_slot);
2335       } else {
2336         GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first,
2337                                   old_node_inputs[iidx].second, &mkl_node,
2338                                   &mkl_node_output_slot);
2339       }
2340       nb->Input(mkl_node, mkl_node_output_slot);
2341       iidx++;
2342       nn_slot_idx++;
2343     }
2344   }
2345 
2346   // If workspace tensors are available for this op and we are using
2347   // contiguous ordering then we need to add Mkl tensor for
2348   // workspace here because Mkl tensor for workspace is the
2349   // last tensor in the list of Mkl tensors.
2350   if (are_workspace_tensors_available) {
2351     CHECK_EQ(workspace_tensors->size(), 2);
2352     // Mkl tensor
2353     nb->Input((*workspace_tensors)[1].node, (*workspace_tensors)[1].index);
2354     nn_slot_idx++;
2355   }
2356 
2357   return nn_slot_idx;
2358 }
2359 
SetUpInputs(std::unique_ptr<Graph> * g,const gtl::InlinedVector<std::pair<Node *,int>,4> & old_node_inputs,NodeBuilder * nb,const Node * old_node)2360 Status MklLayoutRewritePass::SetUpInputs(
2361     std::unique_ptr<Graph>* g,
2362     const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
2363     NodeBuilder* nb, const Node* old_node) {
2364   // Let's check if we need to add workspace tensors for this node.
2365   // We add workspace edge only for MaxPool, LRN and BatchNorm.
2366   std::vector<NodeBuilder::NodeOut> workspace_tensors;
2367   bool are_workspace_tensors_available = false;
2368 
2369   // Avoid workspace check for QuantizedConv2D and the fused
2370   // Ops as they don't have attribute: "T".
2371   std::vector<string> quant_ops{
2372       "Dequantize",
2373       "QuantizeV2",
2374       "QuantizedConv2D",
2375       "QuantizedConv2DWithBias",
2376       "QuantizedConv2DAndRelu",
2377       "QuantizedConv2DWithBiasAndRelu",
2378       "QuantizedConv2DWithBiasSumAndRelu",
2379       "QuantizedConv2DPerChannel",
2380       "QuantizedConv2DAndRequantize",
2381       "QuantizedConv2DWithBiasAndRequantize",
2382       "QuantizedConv2DAndReluAndRequantize",
2383       "QuantizedConv2DWithBiasAndReluAndRequantize",
2384       "QuantizedConv2DWithBiasSumAndReluAndRequantize",
2385       "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize",
2386       "QuantizedMatMulWithBias",
2387       "QuantizedMatMulWithBiasAndRequantize",
2388       "QuantizedMatMulWithBiasAndDequantize",
2389       "QuantizedMatMulWithBiasAndRelu",
2390       "QuantizedMatMulWithBiasAndReluAndRequantize",
2391       "QuantizedDepthwiseConv2D",
2392       "QuantizedDepthwiseConv2DWithBias",
2393       "QuantizedDepthwiseConv2DWithBiasAndRelu",
2394       "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize"};
2395   bool should_check_workspace =
2396       std::find(std::begin(quant_ops), std::end(quant_ops),
2397                 old_node->type_string()) == std::end(quant_ops);
2398   if (should_check_workspace)
2399     AddWorkSpaceEdgeIfNeeded(g, old_node, nb, &workspace_tensors,
2400                              &are_workspace_tensors_available);
2401 
2402   int new_node_input_slots = 0;
2403   if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
2404     // TODO(nhasabni): implement this function just for same of completion.
2405     // We do not use interleaved ordering right now.
2406     return Status(
2407         error::Code::UNIMPLEMENTED,
2408         "Interleaved ordering of tensors is currently not supported.");
2409   } else {
2410     CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
2411     new_node_input_slots = SetUpContiguousInputs(
2412         g, old_node_inputs, nb, old_node, &workspace_tensors,
2413         are_workspace_tensors_available);
2414   }
2415 
2416   // Sanity check
2417   int old_node_input_slots = old_node->op_def().input_arg_size();
2418   if (!are_workspace_tensors_available) {
2419     // If we are not adding workspace tensors for this op, then the total
2420     // number of input slots to the new node _must_ be 2 times the number
2421     // of input slots to the original node: N original Tensorflow tensors and
2422     // N for Mkl tensors corresponding to each Tensorflow tensors.
2423     CHECK_EQ(new_node_input_slots, old_node_input_slots * 2);
2424   } else {
2425     // If we are adding workspace tensors for this op, then the total
2426     // The total number of input slots to new node _must_ be 2 times the number
2427     // of input slots to the original node: N original Tensorflow tensors and
2428     // N for Mkl tensors corresponding to each Tensorflow tensors plus 2
2429     // (for workspace Tensorflow tensor and workspace Mkl tensor).
2430     CHECK_EQ(new_node_input_slots, old_node_input_slots * 2 + 2);
2431   }
2432 
2433   return Status::OK();
2434 }
2435 
CopyInputs(const Node * old_node,const gtl::InlinedVector<std::pair<Node *,int>,4> & old_node_inputs,NodeBuilder * nb)2436 Status MklLayoutRewritePass::CopyInputs(
2437     const Node* old_node,
2438     const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
2439     NodeBuilder* nb) {
2440   // Number of input slots to old node
2441   // Input slots are represented by .Input() calls in REGISTER_OP.
2442   int old_node_input_slots = old_node->op_def().input_arg_size();
2443   // Actual number of inputs can be greater than or equal to number
2444   // of Input slots because inputs of type list could be unfolded.
2445   auto old_node_input_size = old_node_inputs.size();
2446   DCHECK_GE(old_node_input_size, old_node_input_slots);
2447 
2448   // Let's copy all inputs of old node to new node.
2449   int iidx = 0;
2450   for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
2451     // An input slot could be a single tensor or a list. We need
2452     // to handle this case accordingly.
2453     DCHECK_LT(iidx, old_node_input_size);
2454     const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
2455     if (ArgIsList(arg)) {
2456       std::vector<NodeBuilder::NodeOut> new_node_inputs;
2457       int N = GetTensorListLength(arg, old_node);
2458       if (N != 0) {
2459         GetNodesProducingTFTensorList(old_node_inputs, &iidx, N,
2460                                       &new_node_inputs);
2461       }
2462       nb->Input(new_node_inputs);
2463     } else {
2464       nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
2465       iidx++;
2466     }
2467   }
2468   return Status::OK();
2469 }
2470 
2471 //////////////////////////////////////////////////////////////////////////
2472 //           Helper functions related to workspace pass
2473 //////////////////////////////////////////////////////////////////////////
2474 
2475 // TODO(nhasabni) We should move this to mkl_util.h.
GetDummyWorkspaceTensorNode(std::unique_ptr<Graph> * g,Node ** out,const Node * orig_node)2476 void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
2477     std::unique_ptr<Graph>* g, Node** out, const Node* orig_node) {
2478   // We use uint8 tensor of shape 8 with content {0,0,0,0,0,0,0,0} to represent
2479   // workspace tensor.
2480   GetDummyMklTensorNode(g, out, orig_node);
2481 }
2482 
AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph> * g,const Node * orig_node,NodeBuilder * nb,std::vector<NodeBuilder::NodeOut> * ws_tensors,bool * are_ws_tensors_added)2483 void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
2484     std::unique_ptr<Graph>* g, const Node* orig_node, NodeBuilder* nb,
2485     std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added) {
2486   bool workspace_edge_added = false;  // Default initializer
2487   DCHECK(are_ws_tensors_added);
2488   *are_ws_tensors_added = false;  // Default initializer
2489 
2490   DataType T;
2491   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2492   for (auto ws : wsinfo_) {
2493     if (orig_node->type_string() == ws.fwd_op &&
2494         mkl_op_registry::IsMklOp(
2495             mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) {
2496       // If this op is a fwd op, then we need to check if there is an
2497       // edge from this node's fwd_slot to bwdop's bwd_slot. If there is
2498       // an edge, then we just add an attribute on this node for setting
2499       // workspace_passed to true. We don't add actual workspace edge
2500       // in this node. Actual workspace edge gets added in the backward
2501       // op for this node.
2502       for (const Edge* e : orig_node->out_edges()) {
2503         if (e->src_output() == ws.fwd_slot &&
2504             e->dst()->type_string() == ws.bwd_op &&
2505             e->dst_input() == ws.bwd_slot) {
2506           nb->Attr("workspace_enabled", true);
2507           VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
2508                   << orig_node->type_string();
2509           workspace_edge_added = true;
2510           // We found the edge that we were looking for, so break.
2511           break;
2512         }
2513       }
2514 
2515       if (!workspace_edge_added) {
2516         // If we are here, then we did not find backward operator for this
2517         // node.
2518         nb->Attr("workspace_enabled", false);
2519       }
2520     } else if (orig_node->type_string() == ws.bwd_op &&
2521                mkl_op_registry::IsMklOp(
2522                    mkl_op_registry::GetMklOpName(orig_node->type_string()),
2523                    T)) {
2524       // If this op is a bwd op, then we need to add workspace edge and
2525       // it's Mkl tensor edge between its corresponding fwd op and this
2526       // op. Corresponding fwd op is specified in 'fwd_op' field of
2527       // workspace info. fwd_slot and bwd_slot in workspace info specify
2528       // an edge between which slots connect forward and backward op.
2529       // Once all these criteria match, we add a workspace edge between
2530       // ws_fwd_slot and ws_bwd_slot. Its corresponding Mkl tensor is
2531       // determined by interleaved/contiguous ordering. Function
2532       // DataIndexToMetaDataIndex tells us the location of Mkl tensor
2533       // from the location of the Tensorflow tensor.
2534       for (const Edge* e : orig_node->in_edges()) {
2535         if (e->src_output() == ws.fwd_slot &&
2536             // We would have rewritten the forward op, so we need to use
2537             // GetMklOpName call to get its Mkl name.
2538             e->src()->type_string() ==
2539                 mkl_op_registry::GetMklOpName(ws.fwd_op) &&
2540             e->dst_input() == ws.bwd_slot) {
2541           nb->Attr("workspace_enabled", true);
2542           DCHECK(ws_tensors);
2543           // Add workspace edge between fwd op and bwd op.
2544           ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot));
2545           // Check if we are running in native format mode. If so,
2546           // we don't need to have an Mkl metadata tensor for the workspace.
2547           if (!NativeFormatEnabled()) {
2548             // Add Mkl tensor edge for workspace edge between fwd op and bwd op.
2549             ws_tensors->push_back(NodeBuilder::NodeOut(
2550                 e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot,
2551                                                    e->src()->num_outputs())));
2552           }
2553           *are_ws_tensors_added = true;
2554           // In terms of input ordering, we add these calls to add Input
2555           // here because workspace edge (and its Mkl tensor) is the last
2556           // edge in the fwdop and bwdop. So all inputs before workspace
2557           // tensor have been added by SetUpInputs function.
2558           VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
2559                   << orig_node->type_string();
2560           workspace_edge_added = true;
2561           // We found the edge that we were looking for, so break.
2562           break;
2563         }
2564       }
2565 
2566       // If we are here means we did not find fwd op that feeds to this
2567       // bwd op. So in this case, we need to generate dummy tensors for
2568       // workspace input and Mkl tensor for workspace, and set
2569       // workspace_enabled to false.
2570       if (!workspace_edge_added) {
2571         nb->Attr("workspace_enabled", false);
2572         Node* dmt_ws = nullptr;      // Dummy tensor for workspace
2573         Node* dmt_mkl_ws = nullptr;  // Dummy Mkl tensor for workspace
2574         GetDummyWorkspaceTensorNode(g, &dmt_ws, orig_node);
2575         GetDummyMklTensorNode(g, &dmt_mkl_ws, orig_node);
2576         DCHECK(dmt_ws);
2577         DCHECK(dmt_mkl_ws);
2578         DCHECK(ws_tensors);
2579         // We add dummy tensor as workspace tensor.
2580         ws_tensors->push_back(NodeBuilder::NodeOut(dmt_ws, 0));
2581         // We add dummy tensor as Mkl tensor for workspace tensor.
2582         ws_tensors->push_back(NodeBuilder::NodeOut(dmt_mkl_ws, 0));
2583         *are_ws_tensors_added = true;
2584         VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for "
2585                 << orig_node->type_string();
2586       }
2587     } else {
2588       // If this node does not match any workspace info, then we do not
2589       // do anything special for workspace propagation for it.
2590     }
2591   }
2592 }
2593 
2594 //////////////////////////////////////////////////////////////////////////
2595 // Op-specific functions to copy attributes from old node to new node
2596 //////////////////////////////////////////////////////////////////////////
2597 
2598 // Generic function to copy all attributes from original node to target.
CopyAttrsAll(const Node * orig_node,NodeBuilder * nb,bool change_format)2599 void MklLayoutRewritePass::CopyAttrsAll(const Node* orig_node, NodeBuilder* nb,
2600                                         bool change_format) {
2601   string name;
2602   AttrSlice attr_list(orig_node->def());
2603 
2604   auto iter = attr_list.begin();
2605   while (iter != attr_list.end()) {
2606     name = iter->first;
2607     auto attr = iter->second;
2608     nb->Attr(name, attr);
2609     ++iter;
2610   }
2611 }
2612 
2613 // Generic function to copy all attributes and check if filter is const.
CopyAttrsAllCheckConstFilter(const Node * orig_node,NodeBuilder * nb,bool change_format)2614 void MklLayoutRewritePass::CopyAttrsAllCheckConstFilter(const Node* orig_node,
2615                                                         NodeBuilder* nb,
2616                                                         bool change_format) {
2617   CopyAttrsAll(orig_node, nb, change_format);
2618 
2619   // Check and set filter attribute.
2620   Node* filter_node = nullptr;
2621   TF_CHECK_OK(orig_node->input_node(1, &filter_node));
2622   nb->Attr("is_filter_const", filter_node->IsConstant());
2623 }
2624 
CopyAttrsConvCheckConstFilter(const Node * orig_node,NodeBuilder * nb,bool change_format)2625 void MklLayoutRewritePass::CopyAttrsConvCheckConstFilter(const Node* orig_node,
2626                                                          NodeBuilder* nb,
2627                                                          bool change_format) {
2628   CopyAttrsConv(orig_node, nb, change_format);
2629 
2630   // Check and set filter attribute.
2631   Node* filter_node = nullptr;
2632   TF_CHECK_OK(orig_node->input_node(1, &filter_node));
2633   nb->Attr("is_filter_const", filter_node->IsConstant());
2634 }
2635 
CopyAttrsConv(const Node * orig_node,NodeBuilder * nb,bool change_format)2636 void MklLayoutRewritePass::CopyAttrsConv(const Node* orig_node, NodeBuilder* nb,
2637                                          bool change_format) {
2638   DataType T;
2639   string padding;
2640   std::vector<int32> strides;
2641   std::vector<int32> dilations;
2642   std::vector<int32> explicit_paddings;
2643 
2644   // Get all attributes from old node.
2645   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2646   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2647   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2648   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2649 
2650   // Check `explicit_paddings` first because some Conv ops don't have
2651   // this attribute.
2652   if (TryGetNodeAttr(orig_node->def(), "explicit_paddings",
2653                      &explicit_paddings) &&
2654       !explicit_paddings.empty()) {
2655     nb->Attr("explicit_paddings", explicit_paddings);
2656   }
2657 
2658   // Add attributes to new node.
2659   nb->Attr("T", T);
2660   nb->Attr("padding", padding);
2661 
2662   // Add attributes related to `data_format`.
2663   CopyFormatAttrsConv(orig_node, nb, strides, dilations, change_format);
2664 }
2665 
2666 // Used with MergePadWithConv2D
CopyAttrsFromPadAndConv2D(const Node * orig_node1,const Node * orig_node2,NodeBuilder * nb,bool change_format)2667 void MklLayoutRewritePass::CopyAttrsFromPadAndConv2D(const Node* orig_node1,
2668                                                      const Node* orig_node2,
2669                                                      NodeBuilder* nb,
2670                                                      bool change_format) {
2671   DataType Tpaddings;
2672   DataType T;
2673   string data_format;
2674   string padding;
2675   std::vector<int32> strides;
2676   std::vector<int32> dilations;
2677   bool use_cudnn_on_gpu;
2678 
2679   // Get all attributes from old node 1.
2680   TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "T", &T));
2681   TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "strides", &strides));
2682   TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "dilations", &dilations));
2683   TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "padding", &padding));
2684   TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "data_format", &data_format));
2685   TF_CHECK_OK(
2686       GetNodeAttr(orig_node1->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
2687   // Get all attributes from old node 2.
2688   TF_CHECK_OK(GetNodeAttr(orig_node2->def(), "Tpaddings", &Tpaddings));
2689 
2690   // Add attributes to new node.
2691   nb->Attr("T", T);
2692   nb->Attr("strides", strides);
2693   nb->Attr("dilations", dilations);
2694   nb->Attr("padding", padding);
2695   nb->Attr("data_format", data_format);
2696   nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
2697   nb->Attr("Tpaddings", Tpaddings);
2698 }
2699 
CopyAttrsFromPadAndFusedConv2D(const Node * fused_conv2d,const Node * pad,NodeBuilder * nb,bool change_format)2700 void MklLayoutRewritePass::CopyAttrsFromPadAndFusedConv2D(
2701     const Node* fused_conv2d, const Node* pad, NodeBuilder* nb,
2702     bool change_format) {
2703   DataType T;
2704   int num_args;
2705   string data_format;
2706   string padding;
2707   std::vector<int32> strides;
2708   std::vector<int32> dilations;
2709   float epsilon;
2710   std::vector<string> fused_ops;
2711   DataType Tpaddings;
2712   float leakyrelu_alpha;
2713 
2714   // Get all attributes from old node.
2715   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "T", &T));
2716   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "num_args", &num_args));
2717   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "strides", &strides));
2718   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "padding", &padding));
2719   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "data_format", &data_format));
2720   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "dilations", &dilations));
2721   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "fused_ops", &fused_ops));
2722   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "epsilon", &epsilon));
2723   TF_CHECK_OK(
2724       GetNodeAttr(fused_conv2d->def(), "leakyrelu_alpha", &leakyrelu_alpha));
2725   TF_CHECK_OK(GetNodeAttr(pad->def(), "Tpaddings", &Tpaddings));
2726 
2727   // Add attributes to new node.
2728   nb->Attr("T", T);
2729   nb->Attr("num_args", num_args);
2730   nb->Attr("strides", strides);
2731   nb->Attr("padding", padding);
2732   nb->Attr("data_format", data_format);
2733   nb->Attr("dilations", dilations);
2734   nb->Attr("epsilon", epsilon);
2735   nb->Attr("Tpaddings", Tpaddings);
2736   nb->Attr("fused_ops", fused_ops);
2737   nb->Attr("leakyrelu_alpha", leakyrelu_alpha);
2738 }
2739 
CopyAttrsQuantizedConv2D(const Node * orig_node,NodeBuilder * nb,bool change_format)2740 void MklLayoutRewritePass::CopyAttrsQuantizedConv2D(const Node* orig_node,
2741                                                     NodeBuilder* nb,
2742                                                     bool change_format) {
2743   DataType Tinput, Tfilter, out_type;
2744   string padding;
2745   string data_format("NHWC");
2746   std::vector<int32> strides, dilations, padding_list;
2747   bool has_padding_list = HasNodeAttr(orig_node->def(), "padding_list");
2748 
2749   // Get all attributes from old node.
2750   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tinput", &Tinput));
2751   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tfilter", &Tfilter));
2752   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "out_type", &out_type));
2753   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2754   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2755   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2756   if (has_padding_list) {
2757     TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding_list", &padding_list));
2758   }
2759 
2760   Node* filter_node = nullptr;
2761   TF_CHECK_OK(orig_node->input_node(1, &filter_node));
2762 
2763   // Add attributes to new node.
2764   nb->Attr("Tinput", Tinput);
2765   nb->Attr("Tfilter", Tfilter);
2766   nb->Attr("out_type", out_type);
2767   nb->Attr("padding", padding);
2768   nb->Attr("is_filter_const", filter_node->IsConstant());
2769   nb->Attr("strides", strides);
2770   nb->Attr("dilations", dilations);
2771   nb->Attr("T", out_type);  // added "T" for facilitating MklToTf conversion.
2772   nb->Attr("data_format", data_format);
2773   if (has_padding_list) {
2774     nb->Attr("padding_list", padding_list);
2775   }
2776 
2777   // Requantization attr Tbias.
2778   DataType Tbias;
2779   Status bias_status = GetNodeAttr(orig_node->def(), "Tbias", &Tbias);
2780   if (bias_status.ToString() == "OK") nb->Attr("Tbias", Tbias);
2781 }
2782 
CopyAttrsQuantizedMatMulWithBiasAndDequantize(const Node * orig_node,NodeBuilder * nb,bool change_format)2783 void MklLayoutRewritePass::CopyAttrsQuantizedMatMulWithBiasAndDequantize(
2784     const Node* orig_node, NodeBuilder* nb, bool change_format) {
2785   DataType T1, T2, Toutput;
2786 
2787   // Get all attributes from old node.
2788   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T1", &T1));
2789   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T2", &T2));
2790   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Toutput", &Toutput));
2791 
2792   // Add attributes to new node.
2793   nb->Attr("T1", T1);
2794   nb->Attr("T2", T2);
2795   nb->Attr("Toutput", Toutput);
2796   nb->Attr("T", T1);  // added "T" for facilitating MklToTf conversion.
2797 
2798   // Requantization attr Tbias
2799   DataType Tbias;
2800   Status bias_status = GetNodeAttr(orig_node->def(), "Tbias", &Tbias);
2801   if (bias_status.ToString() == "OK") nb->Attr("Tbias", Tbias);
2802 }
2803 
CopyAttrsQuantizedMatMulWithBias(const Node * orig_node,NodeBuilder * nb,bool change_format)2804 void MklLayoutRewritePass::CopyAttrsQuantizedMatMulWithBias(
2805     const Node* orig_node, NodeBuilder* nb, bool change_format) {
2806   DataType T1, T2, Toutput;
2807 
2808   // Get all attributes from old node.
2809   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T1", &T1));
2810   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T2", &T2));
2811   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Toutput", &Toutput));
2812 
2813   Node* weight_node = nullptr;
2814   TF_CHECK_OK(orig_node->input_node(1, &weight_node));
2815 
2816   // Add attributes to new node.
2817   nb->Attr("T1", T1);
2818   nb->Attr("T2", T2);
2819   nb->Attr("Toutput", Toutput);
2820   nb->Attr("is_weight_const", weight_node->IsConstant());
2821   nb->Attr("T", Toutput);  // added "T" for facilitating MklToTf conversion.
2822 
2823   // Requantization attr Tbias
2824   DataType Tbias;
2825   Status bias_status = GetNodeAttr(orig_node->def(), "Tbias", &Tbias);
2826   if (bias_status.ToString() == "OK") nb->Attr("Tbias", Tbias);
2827 }
2828 
CopyFormatAttrsConv(const Node * orig_node,NodeBuilder * nb,const std::vector<int32> & strides,const std::vector<int32> & dilations,bool change_format)2829 void MklLayoutRewritePass::CopyFormatAttrsConv(
2830     const Node* orig_node, NodeBuilder* nb, const std::vector<int32>& strides,
2831     const std::vector<int32>& dilations, bool change_format) {
2832   string data_format;
2833 
2834   if (!change_format) {
2835     nb->Attr("strides", strides);
2836     nb->Attr("dilations", dilations);
2837 
2838     TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2839     nb->Attr("data_format", data_format);
2840   } else {
2841     std::vector<int32> new_strides;
2842     std::vector<int32> new_dilations;
2843     if (strides.size() == 5) {
2844       // `strides` and `dilations` also need to be changed according to
2845       // `data_format`. In this case, from `NDHWC` to `NCDHW`.
2846       new_strides = {strides[NDHWC::dim::N], strides[NDHWC::dim::C],
2847                      strides[NDHWC::dim::D], strides[NDHWC::dim::H],
2848                      strides[NDHWC::dim::W]};
2849 
2850       new_dilations = {dilations[NDHWC::dim::N], dilations[NDHWC::dim::C],
2851                        dilations[NDHWC::dim::D], dilations[NDHWC::dim::H],
2852                        dilations[NDHWC::dim::W]};
2853     } else {
2854       // `strides` and `dilations` also need to be changed according to
2855       // `data_format`. In this case, from `NHWC` to `NCHW`.
2856 
2857       new_strides = {strides[NHWC::dim::N], strides[NHWC::dim::C],
2858                      strides[NHWC::dim::H], strides[NHWC::dim::W]};
2859 
2860       new_dilations = {dilations[NHWC::dim::N], dilations[NHWC::dim::C],
2861                        dilations[NHWC::dim::H], dilations[NHWC::dim::W]};
2862     }
2863     nb->Attr("strides", new_strides);
2864     nb->Attr("dilations", new_dilations);
2865   }
2866 }
2867 
CopyAttrsPooling(const Node * orig_node,NodeBuilder * nb,bool change_format)2868 void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node,
2869                                             NodeBuilder* nb,
2870                                             bool change_format) {
2871   DataType T;
2872   string data_format;
2873   string padding;
2874   std::vector<int32> ksize, strides;
2875 
2876   // Get all attributes from old node.
2877   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2878   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize));
2879   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2880   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2881   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2882 
2883   // Add attributes to new node.
2884   nb->Attr("T", T);
2885   nb->Attr("padding", padding);
2886 
2887   if (!change_format) {
2888     nb->Attr("strides", strides);
2889     nb->Attr("ksize", ksize);
2890 
2891     nb->Attr("data_format", data_format);
2892   } else {
2893     std::vector<int32> new_strides;
2894     std::vector<int32> new_ksize;
2895     if (strides.size() == 5) {
2896       DCHECK(data_format == "NCDHW");
2897       // `strides` and `ksize` also need to be changed according to
2898       // `data_format`. In this case, from `NDHWC` to `NCDHW`.
2899       new_strides = {strides[NDHWC::dim::N], strides[NDHWC::dim::C],
2900                      strides[NDHWC::dim::D], strides[NDHWC::dim::H],
2901                      strides[NDHWC::dim::W]};
2902 
2903       new_ksize = {ksize[NDHWC::dim::N], ksize[NDHWC::dim::C],
2904                    ksize[NDHWC::dim::D], ksize[NDHWC::dim::H],
2905                    ksize[NDHWC::dim::W]};
2906 
2907     } else {
2908       // `strides` and `ksize` also need to be changed according to
2909       // `data_format`. In this case, from `NHWC` to `NCHW`.
2910       DCHECK(data_format == "NCHW");
2911       new_strides = {strides[NHWC::dim::N], strides[NHWC::dim::C],
2912                      strides[NHWC::dim::H], strides[NHWC::dim::W]};
2913 
2914       new_ksize = {ksize[NHWC::dim::N], ksize[NHWC::dim::C],
2915                    ksize[NHWC::dim::H], ksize[NHWC::dim::W]};
2916     }
2917     nb->Attr("strides", new_strides);
2918     nb->Attr("ksize", new_ksize);
2919   }
2920 }
2921 
2922 //////////////////////////////////////////////////////////////////////////
2923 //           Helper functions related to node merge pass
2924 //////////////////////////////////////////////////////////////////////////
2925 
// Returns the node with which `a` can be merged, or nullptr when there is
// none. A partner proposed by a matching merge rule is accepted only if its
// control edges, as collected by FillInputs, equal those of `a`.
Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const {
  // TODO(nhasabni) Add check for type of node similar to CheckForNodeRewrite
  // once we support BiasAddGrad as Mkl layer.

  // Collect every merge rule that names the op type of 'a' as either operand.
  // More than one match is allowed for extensibility.
  std::vector<const MergeInfo*> candidates;
  for (const auto& rule : minfo_) {
    const bool is_first_operand = a->type_string() == rule.op1;
    if (is_first_operand || a->type_string() == rule.op2) {
      candidates.push_back(&rule);
    }
  }

  for (const MergeInfo* rule : candidates) {
    // Ask the rule for the operand with which 'a' can be merged.
    Node* partner = rule->get_node_to_be_merged(a);
    if (partner == nullptr) {
      continue;
    }

    // Gather control edges and data inputs of both nodes.
    gtl::InlinedVector<Node*, 4> a_controls;
    gtl::InlinedVector<std::pair<Node*, int>, 4> a_data(a->num_inputs());
    FillInputs(a, &a_controls, &a_data);

    gtl::InlinedVector<Node*, 4> partner_controls;
    gtl::InlinedVector<std::pair<Node*, int>, 4> partner_data(
        partner->num_inputs());
    FillInputs(partner, &partner_controls, &partner_data);

    // Only nodes with identical control edges may be merged; the first such
    // partner wins.
    if (a_controls == partner_controls) {
      return partner;
    }
  }

  return nullptr;
}
2968 
MergeConv2DWithBiasAdd(std::unique_ptr<Graph> * g,Node * m,Node * n)2969 Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
2970                                                     Node* m, Node* n) {
2971   CHECK_EQ(((m->type_string() == csinfo_.bias_add &&
2972              n->type_string() == csinfo_.conv2d)) ||
2973                ((n->type_string() == csinfo_.bias_add &&
2974                  m->type_string() == csinfo_.conv2d)),
2975            true);
2976 
2977   // If 'm' is BiasAdd, then 'n' is Conv2D. Since Conv2D feeds BiasAdd,
2978   // BiasAdd is successor node, and Conv2D predecessor node.
2979   Node* pred = m->type_string() == csinfo_.bias_add ? n : m;
2980   Node* succ = m->type_string() == csinfo_.bias_add ? m : n;
2981 
2982   // 1. Get all attributes from input nodes.
2983   DataType T_pred, T_succ;
2984   string padding;
2985   std::vector<int32> strides;
2986   std::vector<int32> dilations;
2987   string data_format_pred, data_format_succ;
2988   bool use_cudnn_on_gpu;
2989   TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
2990   TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
2991   TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding));
2992   TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides));
2993   TF_CHECK_OK(GetNodeAttr(pred->def(), "dilations", &dilations));
2994   TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred));
2995   TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
2996   TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
2997   // We check to ensure that data formats of both succ and pred are same.
2998   // We expect them to be same, so we can enforce this as assert.
2999   // But assert can be too strict, so we enforce this as a check.
3000   // If the check fails, then we do not merge two nodes.
3001   // We also do same check for devices.
3002   if (data_format_pred != data_format_succ || T_pred != T_succ ||
3003       pred->assigned_device_name() != succ->assigned_device_name() ||
3004       pred->def().device() != succ->def().device()) {
3005     return Status(error::Code::INVALID_ARGUMENT,
3006                   "data_format or T attribute or devices of Conv2D and "
3007                   "BiasAdd do not match. Will skip node merge optimization");
3008   }
3009 
3010   const int succ_num = succ->num_inputs();
3011   gtl::InlinedVector<Node*, 4> succ_control_edges;
3012   gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
3013   FillInputs(succ, &succ_control_edges, &succ_in);
3014 
3015   const int pred_num = pred->num_inputs();
3016   gtl::InlinedVector<Node*, 4> pred_control_edges;
3017   gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
3018   FillInputs(pred, &pred_control_edges, &pred_in);
3019 
3020   // We need to ensure that Conv2D only feeds to BiasAdd (some other operator is
3021   // not expecting output of Conv2D). If this is not the case, then we cannot
3022   // merge Conv2D with BiasAdd.
3023   const int kFirstOutputSlot = 0;
3024   for (const Edge* e : pred->out_edges()) {
3025     if (e->src_output() == kFirstOutputSlot && e->dst() != succ) {
3026       return Status(error::Code::INVALID_ARGUMENT,
3027                     "Conv2D does not feed to BiasAdd, or "
3028                     "it feeds BiasAdd but has multiple outputs. "
3029                     "Will skip node merge optimization");
3030     }
3031   }
3032 
3033   // 2. Get inputs from both the nodes.
3034   // Find the 2 inputs from the conv and the bias from the add Bias.
3035   // Get operand 0, 1 of conv2D.
3036   CHECK_EQ(pred->in_edges().size(), 2);  // Conv2D must have 2 inputs.
3037   // Get operand 1 of add_bias
3038   // BiasAdd must have 2 inputs: Conv, bias
3039   CHECK_EQ(succ->in_edges().size(), 2);
3040 
3041   // We will use the node name of BiasAdd as the name of new node
3042   // Build new node. We use same name as original node, but change the op
3043   // name.
3044   NodeBuilder nb(succ->name(), csinfo_.conv2d_with_bias);
3045   nb.Input(pred_in[0].first, pred_in[0].second);  // In1 of Conv2D
3046   // pred_in[1] will be 2nd Tensorflow tensor for Conv2D.
3047   nb.Input(pred_in[1].first, pred_in[1].second);  // In2 of Conv2D
3048   // In1 of BiasAdd is same as output of Conv2D.
3049   nb.Input(succ_in[1].first, succ_in[1].second);  // In2 of BiasAdd
3050 
3051   // Copy attributes from Conv2D to Conv2DWithBias.
3052   CopyAttrsConvCheckConstFilter(const_cast<const Node*>(pred), &nb);
3053 
3054   // Copy the device assigned to old node to new node.
3055   nb.Device(succ->def().device());
3056 
3057   // Create node.
3058   Node* new_node;
3059   TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3060 
3061   // In the following code of this function, an unsorted set is used to make
3062   // sure no duplicated edges be added into the new node. Therefore, we can
3063   // pass allow_duplicates = true in AddControlEdge call to skip the O(#edges)
3064   // check in the routine.
3065 
3066   // Incoming data edges from 'pred' node and 'succ' node to new 'new_node'
3067   // node are already copied in BuildNode. We handle control edges now.
3068   std::unordered_set<Node*> unique_node;
3069   for (const Edge* e : pred->in_edges()) {
3070     if (e->IsControlEdge()) {
3071       auto result = unique_node.insert(e->src());
3072       if (result.second) {
3073         (*g)->AddControlEdge(e->src(), new_node, true);
3074       }
3075     }
3076   }
3077   unique_node.clear();
3078 
3079   for (const Edge* e : succ->in_edges()) {
3080     if (e->IsControlEdge()) {
3081       auto result = unique_node.insert(e->src());
3082       if (result.second) {
3083         (*g)->AddControlEdge(e->src(), new_node, true);
3084       }
3085     }
3086   }
3087   unique_node.clear();
3088 
3089   // Incoming edges are fixed, we will fix the outgoing edges now.
3090   // First, we will fix outgoing control edges from 'pred' node.
3091   for (const Edge* e : pred->out_edges()) {
3092     if (e->IsControlEdge()) {
3093       auto result = unique_node.insert(e->dst());
3094       if (result.second) {
3095         (*g)->AddControlEdge(new_node, e->dst(), true);
3096       }
3097     }
3098   }
3099   unique_node.clear();
3100 
3101   // Second, we will fix outgoing control and data edges from 'succ' node.
3102   for (const Edge* e : succ->out_edges()) {
3103     if (e->IsControlEdge()) {
3104       auto result = unique_node.insert(e->dst());
3105       if (result.second) {
3106         (*g)->AddControlEdge(new_node, e->dst(), true);
3107       }
3108     } else {
3109       // BiasAdd has only 1 output (at slot 0) and merged node also has only 1
3110       // output (at slot 0).
3111       const int kConv2DWithBiasOutputSlot = 0;
3112       auto new_edge = (*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot,
3113                                     e->dst(), e->dst_input());
3114       DCHECK(new_edge);
3115     }
3116   }
3117 
3118   // Copy device assigned to old node to new node.
3119   // It's ok to use pred or succ as we have enforced a check that
3120   // both have same device assigned.
3121   new_node->set_assigned_device_name(pred->assigned_device_name());
3122 
3123   VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
3124           << ", and node: " << succ->DebugString()
3125           << ", into node:" << new_node->DebugString();
3126 
3127   (*g)->RemoveNode(succ);
3128   (*g)->RemoveNode(pred);
3129 
3130   return Status::OK();
3131 }
3132 
MergePadWithConv2D(std::unique_ptr<Graph> * g,Node * m,Node * n)3133 Status MklLayoutRewritePass::MergePadWithConv2D(std::unique_ptr<Graph>* g,
3134                                                 Node* m, Node* n) {
3135   DCHECK((m->type_string() == csinfo_.pad &&
3136           (n->type_string() == csinfo_.conv2d ||
3137            n->type_string() == csinfo_.fused_conv2d)) ||
3138          (n->type_string() == csinfo_.pad &&
3139           (m->type_string() == csinfo_.conv2d ||
3140            m->type_string() == csinfo_.fused_conv2d)));
3141 
3142   bool is_fused_conv2d = n->type_string() == csinfo_.fused_conv2d ||
3143                          m->type_string() == csinfo_.fused_conv2d;
3144   // Conv2D is successor node, and Pad predecessor node.
3145   Node* pred = m->type_string() == csinfo_.pad ? m : n;
3146   Node* succ = m->type_string() == csinfo_.pad ? n : m;
3147 
3148   // 1. Get all attributes from input nodes.
3149   DataType T_pred, T_succ;
3150   string padding;
3151   std::vector<int32> strides;
3152   std::vector<int32> dilations;
3153   string data_format_pred, data_format_succ;
3154 
3155   TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
3156   TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
3157   TF_CHECK_OK(GetNodeAttr(succ->def(), "padding", &padding));
3158   TF_CHECK_OK(GetNodeAttr(succ->def(), "strides", &strides));
3159   TF_CHECK_OK(GetNodeAttr(succ->def(), "dilations", &dilations));
3160   // Check if the devices of both succ and pred are the same.
3161   // Assert is not used because it can be too strict.
3162   // Don't need to check for data formats because it is not available in Pad.
3163   if (T_pred != T_succ ||
3164       pred->assigned_device_name() != succ->assigned_device_name() ||
3165       pred->def().device() != succ->def().device()) {
3166     return Status(error::Code::INVALID_ARGUMENT,
3167                   "T attribute or devices of Conv2D and "
3168                   "Pad do not match. Will skip node merge optimization");
3169   }
3170 
3171   const int succ_num = succ->num_inputs();
3172   gtl::InlinedVector<Node*, 4> succ_control_edges;
3173   gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
3174   FillInputs(succ, &succ_control_edges, &succ_in);
3175 
3176   const int pred_num = pred->num_inputs();
3177   gtl::InlinedVector<Node*, 4> pred_control_edges;
3178   gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
3179   FillInputs(pred, &pred_control_edges, &pred_in);
3180 
3181   // We need to ensure that Pad only feeds to Conv2D (some other operator is
3182   // not expecting output of Pad). If this is not the case, then we cannot
3183   // merge Conv2D with Pad.
3184   const int kFirstOutputSlot = 0;
3185   for (const Edge* e : pred->out_edges()) {
3186     if (e->src_output() == kFirstOutputSlot && e->dst() != succ) {
3187       return Status(error::Code::INVALID_ARGUMENT,
3188                     "Pad does not feed to Conv2D, or "
3189                     "it feeds Conv2D but has multiple outputs. "
3190                     "Will skip node merge optimization");
3191     }
3192   }
3193 
3194   // 2. Get inputs from both the nodes.
3195 
3196   // Pad must have 2 data inputs: "input" and paddings.
3197   int PadDataInputEdges = 0;
3198   for (const Edge* e : pred->in_edges()) {
3199     if (!e->IsControlEdge()) {
3200       PadDataInputEdges++;
3201     }
3202   }
3203   DCHECK_EQ(PadDataInputEdges, 2);
3204 
3205   // Conv2D must have 2 data inputs: Pad output and Filter
3206   // FusedConv2D have 3 data inputs: Pad output, Filter and Args;
3207   int ConvDataInputEdges = 0;
3208   for (const Edge* e : succ->in_edges()) {
3209     if (!e->IsControlEdge()) {
3210       ConvDataInputEdges++;
3211     }
3212   }
3213 
3214   DCHECK_EQ(ConvDataInputEdges, is_fused_conv2d ? 3 : 2);
3215 
3216   // We will use the node name of Conv2D as the name of new node
3217   // Build new node. We use same name as original node, but change the op
3218   // name.
3219 
3220   NodeBuilder nb(succ->name(), is_fused_conv2d ? csinfo_.pad_with_fused_conv2d
3221                                                : csinfo_.pad_with_conv2d);
3222   nb.Input(pred_in[0].first, pred_in[0].second);  // In1 (input data)  of Pad
3223   // pred_in[1] will be 2nd Tensorflow tensor for Conv2D.
3224   nb.Input(succ_in[1].first, succ_in[1].second);  // In2 (filter) of conv2d
3225   // In1 of Conv2D is same as output of Pad.
3226   // Thus, only need to add In2 of Conv2D
3227 
3228   if (is_fused_conv2d) {
3229     // FusedConv2D has one additional input, args
3230     std::vector<NodeBuilder::NodeOut> args;
3231     args.emplace_back(succ_in[2].first, succ_in[2].second);
3232     nb.Input(gtl::ArraySlice<NodeBuilder::NodeOut>{
3233         args});                                     // In3 (args) of FusedConv2D
3234     nb.Input(pred_in[1].first, pred_in[1].second);  // In2 (paddings) of Pad
3235     // Copy attributes from Pad and FusedConv2D to PadWithFusedConv2D.
3236     CopyAttrsFromPadAndFusedConv2D(const_cast<const Node*>(succ),
3237                                    const_cast<const Node*>(pred), &nb);
3238   } else {
3239     nb.Input(pred_in[1].first, pred_in[1].second);  // In2 (paddings) of Pad
3240     // Copy attributes from Pad and conv2D to PadWithConv2D.
3241     CopyAttrsFromPadAndConv2D(const_cast<const Node*>(succ),
3242                               const_cast<const Node*>(pred), &nb);
3243   }
3244 
3245   // Copy the device assigned to old node to new node.
3246   nb.Device(succ->def().device());
3247 
3248   // Create node.
3249   Node* new_node;
3250   TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3251   // No need to check if new_node is null because it will be null only when
3252   // Finalize fails.
3253 
3254   // Incoming data edges from 'pred' node and 'succ' node to new 'new_node'
3255   // node are already copied in BuildNode.
3256   // We handle control edges now.
3257   for (const Edge* e : pred->in_edges()) {
3258     if (e->IsControlEdge()) {
3259       // Don't allow duplicate edge
3260       (*g)->AddControlEdge(e->src(), new_node, false);
3261     }
3262   }
3263   for (const Edge* e : succ->in_edges()) {
3264     if (e->IsControlEdge()) {
3265       // Don't allow duplicate edge
3266       (*g)->AddControlEdge(e->src(), new_node, false);
3267     }
3268   }
3269 
3270   // Incoming edges are fixed, we will fix the outgoing edges now.
3271   // First, we will fix outgoing control edges from 'pred' node.
3272   for (const Edge* e : pred->out_edges()) {
3273     if (e->IsControlEdge()) {
3274       // Don't allow duplicate edge
3275       (*g)->AddControlEdge(new_node, e->dst(), false);
3276     }
3277   }
3278 
3279   // Second, we will fix outgoing control and data edges from 'succ' node.
3280   for (const Edge* e : succ->out_edges()) {
3281     if (e->IsControlEdge()) {
3282       // Allow duplicate while adding control edge as it would fail (return
3283       // NULL) if we try to add duplicate edge.
3284       (*g)->AddControlEdge(new_node, e->dst(), false);
3285     } else {
3286       // Conv2D has only 1 output (at slot 0) and merged node also has only 1
3287       // output (at slot 0).
3288       const int kPadWithConv2DOutputSlot = 0;
3289       (*g)->AddEdge(new_node, kPadWithConv2DOutputSlot, e->dst(),
3290                     e->dst_input());
3291     }
3292   }
3293 
3294   // Copy device assigned to old node to new node.
3295   // It's ok to use pred or succ as we have enforced a check that
3296   // both have same device assigned.
3297   new_node->set_assigned_device_name(pred->assigned_device_name());
3298 
3299   VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
3300           << ", and node: " << succ->DebugString()
3301           << ", into node:" << new_node->DebugString();
3302 
3303   (*g)->RemoveNode(succ);
3304   (*g)->RemoveNode(pred);
3305 
3306   return Status::OK();
3307 }
3308 
MergeConv2DBackpropFilterWithBiasAddGrad(std::unique_ptr<Graph> * g,Node * m,Node * n)3309 Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad(
3310     std::unique_ptr<Graph>* g, Node* m, Node* n) {
3311   CHECK_EQ(((m->type_string() == csinfo_.bias_add_grad &&
3312              n->type_string() == csinfo_.conv2d_grad_filter)) ||
3313                ((n->type_string() == csinfo_.bias_add_grad &&
3314                  m->type_string() == csinfo_.conv2d_grad_filter)),
3315            true);
3316 
3317   // If 'm' is BiasAddGrad, then 'n' is BackpropFilter.
3318   Node* badd = m->type_string() == csinfo_.bias_add_grad ? m : n;
3319   Node* fltr = m->type_string() == csinfo_.bias_add_grad ? n : m;
3320 
3321   // Sanity check for attributes from input nodes.
3322   DataType T_b, T_f;
3323   string data_format_b, data_format_f;
3324   TF_CHECK_OK(GetNodeAttr(badd->def(), "T", &T_b));
3325   TF_CHECK_OK(GetNodeAttr(fltr->def(), "T", &T_f));
3326   TF_CHECK_OK(GetNodeAttr(badd->def(), "data_format", &data_format_b));
3327   TF_CHECK_OK(GetNodeAttr(fltr->def(), "data_format", &data_format_f));
3328   if (data_format_b != data_format_f || T_b != T_f ||
3329       badd->assigned_device_name() != fltr->assigned_device_name() ||
3330       badd->def().device() != fltr->def().device()) {
3331     return Status(error::Code::INVALID_ARGUMENT,
3332                   "data_format or T attribute or devices of "
3333                   "Conv2DBackpropFilter and BiasAddGrad do not match. "
3334                   "Will skip node merge optimization");
3335   }
3336 
3337   // We will use the node name of Conv2DBackpropFilter as the name of new node.
3338   // This is because BackpropFilterWithBias is going to emit bias output also.
3339   NodeBuilder nb(fltr->name(), csinfo_.conv2d_grad_filter_with_bias);
3340   // Since Conv2DBackpropFilterWithBias has same number of inputs as
3341   // Conv2DBackpropFilter, we can just copy input edges directly. We don't need
3342   // to copy any data input of BiasAddGrad because that input also goes to
3343   // Conv2DBackpropFilter.
3344   const int fltr_ins = fltr->num_inputs();
3345   gtl::InlinedVector<Node*, 4> fltr_control_edges;
3346   gtl::InlinedVector<std::pair<Node*, int>, 4> fltr_in_edges(fltr_ins);
3347   FillInputs(fltr, &fltr_control_edges, &fltr_in_edges);
3348   for (int idx = 0; idx < fltr_ins; idx++) {
3349     nb.Input(fltr_in_edges[idx].first, fltr_in_edges[idx].second);
3350   }
3351 
3352   // Copy attributes from Conv2DBackpropFilter.
3353   CopyAttrsConv(const_cast<const Node*>(fltr), &nb);
3354 
3355   // Copy the device assigned to old node to new node.
3356   nb.Device(fltr->def().device());
3357 
3358   // Create node.
3359   Node* new_node;
3360   TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3361 
3362   // In the following code of this function, an unsorted set is used to make
3363   // sure no duplicated edges be added into the new node. Therefore, we can
3364   // pass allow_duplicates = true in AddControlEdge call to skip the O(#edges)
3365   // check in the routine.
3366 
3367   // Incoming data edges from BiasAddGrad node and Conv2DBackpropFilter node to
3368   // new 'new_node' node are already copied in BuildNode. We handle control
3369   // edges now.
3370   std::unordered_set<Node*> unique_node;
3371   for (const Edge* e : badd->in_edges()) {
3372     if (e->IsControlEdge()) {
3373       auto result = unique_node.insert(e->src());
3374       if (result.second) {
3375         (*g)->AddControlEdge(e->src(), new_node, true);
3376       }
3377     }
3378   }
3379   unique_node.clear();
3380   for (const Edge* e : fltr->in_edges()) {
3381     if (e->IsControlEdge()) {
3382       auto result = unique_node.insert(e->src());
3383       if (result.second) {
3384         (*g)->AddControlEdge(e->src(), new_node, true);
3385       }
3386     }
3387   }
3388   unique_node.clear();
3389 
3390   // Incoming edges are fixed, we will fix the outgoing edges now.
3391   // First, we will fix outgoing control edges from 'badd' node.
3392   // Conv2DBackpropFilter has 1 output -- filter_grad.
3393   // Conv2DBackpropFilterWithBias has 2 outputs -- filter_grad and
3394   // bias_grad. But filter_grad is at same slot number (0) in both the
3395   // nodes. bias_grad is at slot number 1 in Conv2DBackpropFilterWithBias, while
3396   // it is at slot number 0 in BiasAddGrad.
3397   const int kMergedNodeFilterGradOutputIdx = 0;
3398   const int kMergedNodeBiasGradOutputIdx = 1;
3399 
3400   for (const Edge* e : badd->out_edges()) {
3401     if (e->IsControlEdge()) {
3402       auto result = unique_node.insert(e->dst());
3403       if (result.second) {
3404         (*g)->AddControlEdge(new_node, e->dst(), true);
3405       }
3406     } else {
3407       auto new_edge = (*g)->AddEdge(new_node, kMergedNodeBiasGradOutputIdx,
3408                                     e->dst(), e->dst_input());
3409       DCHECK(new_edge);
3410     }
3411   }
3412   unique_node.clear();
3413 
3414   // Second, we will fix outgoing control and data edges from 'fltr' node.
3415   for (const Edge* e : fltr->out_edges()) {
3416     if (e->IsControlEdge()) {
3417       auto result = unique_node.insert(e->dst());
3418       if (result.second) {
3419         (*g)->AddControlEdge(new_node, e->dst(), true);
3420       }
3421     } else {
3422       auto new_edge = (*g)->AddEdge(new_node, kMergedNodeFilterGradOutputIdx,
3423                                     e->dst(), e->dst_input());
3424       DCHECK(new_edge);
3425     }
3426   }
3427 
3428   // Copy device assigned to old node to new node.
3429   // It's ok to use badd or fltr as we have enforced a check that
3430   // both have same device assigned.
3431   new_node->set_assigned_device_name(badd->assigned_device_name());
3432 
3433   VLOG(1) << "MklLayoutRewritePass: Merged old node:" << badd->DebugString()
3434           << ", and node: " << fltr->DebugString()
3435           << ", into node:" << new_node->DebugString();
3436 
3437   (*g)->RemoveNode(badd);
3438   (*g)->RemoveNode(fltr);
3439 
3440   return Status::OK();
3441 }
3442 
MergeNode(std::unique_ptr<Graph> * g,Node * m,Node * n)3443 Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* m,
3444                                        Node* n) {
3445   DCHECK(m);
3446   DCHECK(n);
3447 
3448   if (((m->type_string() == csinfo_.bias_add &&
3449         n->type_string() == csinfo_.conv2d)) ||
3450       ((n->type_string() == csinfo_.bias_add &&
3451         m->type_string() == csinfo_.conv2d))) {
3452     return this->MergeConv2DWithBiasAdd(g, m, n);
3453   }
3454   if ((m->type_string() == csinfo_.pad &&
3455        (n->type_string() == csinfo_.conv2d ||
3456         (n->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(n)))) ||
3457       (n->type_string() == csinfo_.pad &&
3458        (m->type_string() == csinfo_.conv2d ||
3459         (m->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(m))))) {
3460     return this->MergePadWithConv2D(g, m, n);
3461   }
3462 
3463   if (((m->type_string() == csinfo_.bias_add_grad &&
3464         n->type_string() == csinfo_.conv2d_grad_filter)) ||
3465       ((n->type_string() == csinfo_.bias_add_grad &&
3466         m->type_string() == csinfo_.conv2d_grad_filter))) {
3467     return this->MergeConv2DBackpropFilterWithBiasAddGrad(g, m, n);
3468   }
3469 
3470   return Status(error::Code::UNIMPLEMENTED,
3471                 "Unimplemented case for node merge optimization.");
3472 }
3473 
3474 //////////////////////////////////////////////////////////////////////////
3475 //           Helper functions for node rewrite
3476 //////////////////////////////////////////////////////////////////////////
3477 
RewriteNodeForLayoutPropagation(std::unique_ptr<Graph> * g,const Node * orig_node,Node ** new_node,const RewriteInfo * ri)3478 Status MklLayoutRewritePass::RewriteNodeForLayoutPropagation(
3479     std::unique_ptr<Graph>* g, const Node* orig_node, Node** new_node,
3480     const RewriteInfo* ri) {
3481   // Get all data inputs.
3482   int num_data_inputs = orig_node->in_edges().size();
3483   // Drop count for control edges from inputs
3484   for (const Edge* e : orig_node->in_edges()) {
3485     if (e->IsControlEdge()) {
3486       num_data_inputs--;
3487     }
3488   }
3489 
3490   gtl::InlinedVector<Node*, 4> control_edges;
3491   gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_data_inputs);
3492   FillInputs(orig_node, &control_edges, &inputs);
3493 
3494   // Build new node. We use same name as original node, but change the op name.
3495   NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str());
3496   // Copy user-specified device assigned to original node to new node.
3497   nb.Device(orig_node->def().device());
3498   // Set up new inputs to the rewritten node.
3499   Status s = SetUpInputs(g, inputs, &nb, orig_node);
3500   if (s != Status::OK()) {
3501     return s;
3502   }
3503 
3504   const bool kPartialCopyAttrs = false;
3505   ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, kPartialCopyAttrs);
3506 
3507   // Set the Mkl layer label for this op.
3508   if (DataTypeIsQuantized(orig_node->input_type(0)) ||
3509       DataTypeIsQuantized(orig_node->output_type(0))) {
3510     nb.Attr("_kernel", mkl_op_registry::kMklQuantizedOpLabel);
3511   } else {
3512     nb.Attr("_kernel", mkl_op_registry::kMklLayoutDependentOpLabel);
3513   }
3514   // Finalize graph and get new node.
3515   s = nb.Finalize(&**g, new_node);
3516   if (s != Status::OK()) {
3517     return s;
3518   }
3519 
3520   // In the following code of this function, an unsorted set is used to make
3521   // sure no duplicated edges be added into the new node. Therefore, we can
3522   // pass allow_duplicates = true in AddControlEdge call to skip the O(#edges)
3523   // check in the routine.
3524 
3525   // Incoming data edges from 'orig_node' node to new 'new_node' node are
3526   // already copied in BuildNode. We need to handle control edges now.
3527   std::unordered_set<Node*> unique_node;
3528   for (const Edge* e : orig_node->in_edges()) {
3529     if (e->IsControlEdge()) {
3530       auto result = unique_node.insert(e->src());
3531       if (result.second) {
3532         (*g)->AddControlEdge(e->src(), *new_node, true);
3533       }
3534     }
3535   }
3536   unique_node.clear();
3537 
3538   // Copy outgoing edges from 'orig_node' node to new
3539   // 'new_node' node, since the output also follows same ordering among
3540   // Tensorflow tensors and Mkl tensors. We need to connect Tensorflow
3541   // tensors appropriately. Specifically, nth output of the original node
3542   // will become 2*nth output of the Mkl node for the interleaved ordering
3543   // of the tensors. For the contiguous ordering of the tensors, it will be n.
3544   // GetTensorDataIndex provides this mapping function.
3545   for (const Edge* e : orig_node->out_edges()) {
3546     if (e->IsControlEdge()) {
3547       auto result = unique_node.insert(e->dst());
3548       if (result.second) {
3549         (*g)->AddControlEdge(*new_node, e->dst(), true);
3550       }
3551     } else {
3552       auto new_edge = (*g)->AddEdge(
3553           *new_node,
3554           GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
3555           e->dst(), e->dst_input());
3556       DCHECK(new_edge);
3557     }
3558   }
3559   return Status::OK();
3560 }
3561 
RewriteNodeForJustOpNameChange(std::unique_ptr<Graph> * g,const Node * orig_node,Node ** new_node,const RewriteInfo * ri)3562 Status MklLayoutRewritePass::RewriteNodeForJustOpNameChange(
3563     std::unique_ptr<Graph>* g, const Node* orig_node, Node** new_node,
3564     const RewriteInfo* ri) {
3565   // Get all data inputs.
3566   int num_data_inputs = orig_node->in_edges().size();
3567   // Drop count for control edges from inputs
3568   for (const Edge* e : orig_node->in_edges()) {
3569     if (e->IsControlEdge()) {
3570       num_data_inputs--;
3571     }
3572   }
3573   gtl::InlinedVector<Node*, 4> control_edges;
3574   gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_data_inputs);
3575   FillInputs(orig_node, &control_edges, &inputs);
3576 
3577   // Build new node. We use same name as original node, but change the op name.
3578   NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str());
3579   // Copy user-specified device assigned to original node to new node.
3580   nb.Device(orig_node->def().device());
3581 
3582   Status s = CopyInputs(orig_node, inputs, &nb);
3583   if (s != Status::OK()) {
3584     return s;
3585   }
3586 
3587   std::vector<NodeBuilder::NodeOut> workspace_tensors;
3588   bool are_workspace_tensors_available = false;
3589   AddWorkSpaceEdgeIfNeeded(g, orig_node, &nb, &workspace_tensors,
3590                            &are_workspace_tensors_available);
3591   if (are_workspace_tensors_available) {
3592     DCHECK_EQ(workspace_tensors.size(), 1);
3593     nb.Input(workspace_tensors[0].node, workspace_tensors[0].index);
3594   }
3595 
3596   if (!NativeFormatEnabled()) {
3597     ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, true);
3598   } else {
3599     ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, false);
3600   }
3601 
3602   nb.Attr("_kernel", mkl_op_registry::kMklNameChangeOpLabel);
3603 
3604   // Finalize graph and get new node.
3605   s = nb.Finalize(&**g, new_node);
3606   if (s != Status::OK()) {
3607     return s;
3608   }
3609 
3610   // In the following code of this function, an unsorted set is used to make
3611   // sure no duplicated edges be added into the new node. Therefore, we can
3612   // pass allow_duplicates = true in AddControlEdge call to skip the O(#edges)
3613   // check in the routine.
3614 
3615   // Incoming data edges from 'orig_node' node to new 'new_node' node are
3616   // already copied in BuildNode. We need to handle control edges now.
3617   std::unordered_set<Node*> unique_node;
3618   for (const Edge* e : orig_node->in_edges()) {
3619     if (e->IsControlEdge()) {
3620       auto result = unique_node.insert(e->src());
3621       if (result.second) {
3622         (*g)->AddControlEdge(e->src(), *new_node, true);
3623       }
3624     }
3625   }
3626   unique_node.clear();
3627 
3628   // Transfer outgoing edges from 'orig_node' node to new 'new_node' node.
3629   for (const Edge* e : orig_node->out_edges()) {
3630     if (e->IsControlEdge()) {
3631       auto result = unique_node.insert(e->dst());
3632       if (result.second) {
3633         (*g)->AddControlEdge(*new_node, e->dst(), true);
3634       }
3635     } else {
3636       auto result =
3637           (*g)->AddEdge(*new_node, e->src_output(), e->dst(), e->dst_input());
3638       DCHECK(result != nullptr);
3639     }
3640   }
3641 
3642   return Status::OK();
3643 }
3644 
RewriteNode(std::unique_ptr<Graph> * g,Node * orig_node,const RewriteInfo * ri)3645 Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
3646                                          Node* orig_node,
3647                                          const RewriteInfo* ri) {
3648   DCHECK(ri != nullptr);
3649   DCHECK(orig_node != nullptr);
3650 
3651   VLOG(1) << "MklLayoutRewritePass: Original node:" << orig_node->DebugString();
3652 
3653   Status ret_status = Status::OK();
3654   Node* new_node = nullptr;
3655   if (ri->rewrite_cause == kRewriteForLayoutPropagation) {
3656     ret_status = RewriteNodeForLayoutPropagation(g, orig_node, &new_node, ri);
3657   } else if (ri->rewrite_cause == kRewriteForOpNameChange) {
3658     ret_status = RewriteNodeForJustOpNameChange(g, orig_node, &new_node, ri);
3659   } else {
3660     ret_status = Status(error::Code::INVALID_ARGUMENT,
3661                         "Unsupported rewrite cause found."
3662                         "RewriteNode will fail.");
3663   }
3664   TF_CHECK_OK(ret_status);
3665 
3666   // Copy the runtime device assigned from original code to new node.
3667   new_node->set_assigned_device_name(orig_node->assigned_device_name());
3668 
3669   // Delete original node and mark new node as rewritten.
3670   (*g)->RemoveNode(orig_node);
3671 
3672   VLOG(1) << "MklLayoutRewritePass: New node:" << new_node->DebugString();
3673   return ret_status;
3674 }
3675 
3676 // TODO(mdfaijul): Is there any other elegant way to check for quantized ops
3677 // having attributes other than "T"?
3678 // Current implementation reflects only QuantizedConv2D and its fused Ops.
3679 const MklLayoutRewritePass::RewriteInfo*
CheckForQuantizedNodeRewrite(const Node * n) const3680 MklLayoutRewritePass::CheckForQuantizedNodeRewrite(const Node* n) const {
3681   DataType T1, T2;
3682   DataType Tinput, Tfilter;
3683   bool type_attrs_present = false;
3684 
3685   if (TryGetNodeAttr(n->def(), "Tinput", &Tinput) &&
3686       TryGetNodeAttr(n->def(), "Tfilter", &Tfilter) &&
3687       mkl_op_registry::IsMklLayoutDependentOp(
3688           mkl_op_registry::GetMklOpName(n->type_string()), Tinput, Tfilter)) {
3689     type_attrs_present = true;
3690   } else if (TryGetNodeAttr(n->def(), "T1", &T1) &&
3691              TryGetNodeAttr(n->def(), "T2", &T2) &&
3692              mkl_op_registry::IsMklLayoutDependentOp(
3693                  mkl_op_registry::GetMklOpName(n->type_string()), T1, T2)) {
3694     type_attrs_present = true;
3695   }
3696 
3697   if (type_attrs_present) {
3698     for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
3699       if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
3700         return &*ri;
3701       }
3702     }
3703   }
3704 
3705   return nullptr;
3706 }
3707 
3708 const MklLayoutRewritePass::RewriteInfo*
CheckForNodeRewrite(const Node * n) const3709 MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
3710   DCHECK(n);
3711 
3712   // QuantizedOps may have attributes other than "T", so decoupled the check
3713   // with a function, CheckForQuantizedNodeRewrite(const Node*).
3714   const RewriteInfo* ri = CheckForQuantizedNodeRewrite(n);
3715   if (ri != nullptr) return ri;
3716 
3717   // First check if node along with its type is supported by MKL layer.
3718   // We do not want to rewrite an op into Mkl op if types are not supported.
3719   // E.g., MklRelu does not support INT32. So we cannot rewrite Relu to
3720   // MklRelu if type is INT32.
3721   DataType T;
3722   if (!TryGetNodeAttr(n->def(), "T", &T)) {
3723     return nullptr;
3724   }
3725 
3726   // We make an exception for Conv2DGrad and MaxPool related ops as
3727   // the corresponding MKL ops currently do not support the case
3728   // of padding == EXPLICIT yet.
3729   // TODO(intel): support `EXPLICIT` padding for ConvGrad
3730   if (n->type_string() == csinfo_.conv2d_grad_input ||
3731       n->type_string() == csinfo_.conv2d_grad_filter ||
3732       n->type_string() == csinfo_.max_pool ||
3733       n->type_string() == csinfo_.max_pool_grad ||
3734       n->type_string() == csinfo_.max_pool3d ||
3735       n->type_string() == csinfo_.max_pool3d_grad) {
3736     string padding;
3737     TF_CHECK_OK(GetNodeAttr(n->def(), "padding", &padding));
3738     if (padding == "EXPLICIT") return nullptr;
3739   }
3740 
3741   // We make an exception for __MklDummyConv2DWithBias,
3742   // __MklConv2DBackpropFilterWithBias, and __MklDummyPadWithConv2D since their
3743   // names do not match Mkl node names.
3744   if (n->type_string() != csinfo_.conv2d_with_bias &&
3745       n->type_string() != csinfo_.pad_with_conv2d &&
3746       n->type_string() != csinfo_.pad_with_fused_conv2d &&
3747       n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
3748       n->type_string() != csinfo_.fused_batch_norm_ex &&
3749       n->type_string() != csinfo_.fused_conv2d &&
3750       n->type_string() != csinfo_.fused_depthwise_conv2d &&
3751       n->type_string() != csinfo_.fused_matmul &&
3752       !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
3753                                 T)) {
3754     return nullptr;
3755   }
3756 
3757   // We now check if rewrite rule applies for this op. If rewrite rule passes
3758   // for this op, then we rewrite it to Mkl op.
3759   // Find matching RewriteInfo and then check that rewrite rule applies.
3760   for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
3761     if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
3762       return &*ri;
3763     }
3764   }
3765 
3766   // Else return not found.
3767   return nullptr;
3768 }
3769 
3770 //////////////////////////////////////////////////////////////////////////
3771 //           Helper functions for node fusion
3772 //////////////////////////////////////////////////////////////////////////
FuseTransposeMklOpTranspose(std::unique_ptr<Graph> * g,std::vector<Node * > & nodes,std::function<void (const Node *,NodeBuilder * nb,bool)> copy_attrs,string data_format)3773 Status MklLayoutRewritePass::FuseTransposeMklOpTranspose(
3774     std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
3775     std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs,
3776     string data_format) {
3777   Node* transpose_to_nhwc = nodes[0];
3778   Node* mklop = nodes[1];
3779   Node* transpose_to_nchw = nodes[2];
3780 
3781   const int transpose_nhwc_num_inputs = transpose_to_nhwc->num_inputs();
3782   gtl::InlinedVector<Node*, 4> transpose_nhwc_control_edges;
3783   gtl::InlinedVector<std::pair<Node*, int>, 4> transpose_nhwc_in(
3784       transpose_nhwc_num_inputs);
3785   FillInputs(transpose_to_nhwc, &transpose_nhwc_control_edges,
3786              &transpose_nhwc_in);
3787 
3788   const int mklop_num_inputs = mklop->num_inputs();
3789   gtl::InlinedVector<Node*, 4> mklop_control_edges;
3790   gtl::InlinedVector<std::pair<Node*, int>, 4> mklop_in(mklop_num_inputs);
3791   FillInputs(mklop, &mklop_control_edges, &mklop_in);
3792 
3793   const int transpose_nchw_num_inputs = transpose_to_nchw->num_inputs();
3794   gtl::InlinedVector<Node*, 4> transpose_nchw_control_edges;
3795   gtl::InlinedVector<std::pair<Node*, int>, 4> transpose_nchw_in(
3796       transpose_nchw_num_inputs);
3797   FillInputs(transpose_to_nchw, &transpose_nchw_control_edges,
3798              &transpose_nchw_in);
3799 
3800   // We use same name as original node, but change the op
3801   // type.
3802   NodeBuilder nb(mklop->name(), mklop->type_string());
3803 
3804   // Storing the output slots of the input nodes.
3805   for (int i = 0; i < mklop_num_inputs; i++) {
3806     if (mklop_in[i].first == transpose_to_nhwc) {
3807       // Fill "x":
3808       nb.Input(transpose_nhwc_in[0].first, transpose_nhwc_in[0].second);
3809     } else {
3810       // Fill inputs other than "x":
3811       nb.Input(mklop_in[i].first, mklop_in[i].second);
3812     }
3813   }
3814 
3815   copy_attrs(const_cast<const Node*>(mklop), &nb, true);
3816   nb.Attr("data_format", data_format);
3817 
3818   // Copy the device assigned to old node to new node.
3819   nb.Device(mklop->def().device());
3820 
3821   // Create node.
3822   Node* new_node;
3823   TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3824   // No need to check if new_node is null because it will be null only when
3825   // Finalize fails.
3826 
3827   // Fill outputs.
3828   for (const Edge* e : transpose_to_nchw->out_edges()) {
3829     if (!e->IsControlEdge()) {
3830       const int kTransposeWithMklOpOutputSlot = 0;
3831       auto new_edge = (*g)->AddEdge(new_node, kTransposeWithMklOpOutputSlot,
3832                                     e->dst(), e->dst_input());
3833       DCHECK(new_edge);
3834     }
3835   }
3836 
3837   // Copy device assigned to old node to new node.
3838   new_node->set_assigned_device_name(mklop->assigned_device_name());
3839 
3840   // Copy requested_device and assigned_device_name_index
3841   new_node->set_requested_device(mklop->requested_device());
3842   new_node->set_assigned_device_name_index(mklop->assigned_device_name_index());
3843 
3844   (*g)->RemoveNode(transpose_to_nhwc);
3845   (*g)->RemoveNode(mklop);
3846   (*g)->RemoveNode(transpose_to_nchw);
3847 
3848   return Status::OK();
3849 }
3850 
FuseNode(std::unique_ptr<Graph> * g,std::vector<Node * > & nodes,const MklLayoutRewritePass::FusionInfo fi)3851 Status MklLayoutRewritePass::FuseNode(
3852     std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
3853     const MklLayoutRewritePass::FusionInfo fi) {
3854   return fi.fuse_func(g, nodes, fi.copy_attrs);
3855 }
3856 
// Checks whether a chain of nodes starting at 'a' matches one of the fusion
// patterns in finfo_.
//
// Patterns are tried in the order they appear in finfo_. For each pattern a
// depth-first search with backtracking is run along the outgoing edges of
// 'a': 'a' has to satisfy node_checkers[0], one of its successors
// node_checkers[1], and so on. Every out-edge is followed; no IsControlEdge()
// filter is applied here.
//
// Returns (true, matched nodes in checker order, copy of the matching
// FusionInfo) for the first pattern that matches completely, and
// (false, empty vector, default-constructed FusionInfo) if no pattern matches
// or 'a' is null.
//
// NOTE(review): the code dereferences fi->node_checkers.begin() without
// checking for an empty list, and after matching 'a' it dereferences the next
// checker without comparing against end(). It therefore appears to assume that
// every pattern has at least two checkers -- confirm against the code that
// populates finfo_.
std::tuple<bool, std::vector<Node*>, const MklLayoutRewritePass::FusionInfo>
MklLayoutRewritePass::CheckForNodeFusion(Node* a) const {
  // Stores matched nodes, in the same order as node_checkers.
  std::vector<Node*> nodes;

  for (auto fi = finfo_.begin(); fi != finfo_.end(); ++fi) {
    //
    // Make sure node "a" and its succeeding nodes (b, c ...), match the pattern
    // defined in fusion info (ops[0], ops[1], ...),
    // a.k.a. "a->b->c" matches "op1->op2->op3"
    //

    // Stores the first unvisited outgoing edge of each matched node in "nodes".
    // It is kept in lock-step with "nodes": entry i is the edge cursor of
    // nodes[i].
    std::stack<EdgeSet::const_iterator> current_neighbor_stack;
    nodes.clear();

    // "node_checker" always points at the checker for the next node to match.
    auto node_checker = fi->node_checkers.begin();
    if (a != nullptr && (*node_checker)(a)) {
      nodes.push_back(a);
      current_neighbor_stack.push(a->out_edges().begin());
      ++node_checker;
    }

    // If "a" did not match the first checker, "nodes" is empty and the search
    // for this pattern is skipped.
    while (!nodes.empty()) {
      auto& current_neighbor_iter = current_neighbor_stack.top();

      if (current_neighbor_iter != nodes.back()->out_edges().end()) {
        // Found an unvisited edge. Goes through the edge to get the neighbor.
        Node* neighbor_node = (*current_neighbor_iter)->dst();
        ++current_neighbor_stack.top();  // Retrieves the next unvisited edge.

        if ((*node_checker)(neighbor_node)) {
          // Found a match. Stores the node and moves to the next checker.
          nodes.push_back(neighbor_node);
          current_neighbor_stack.push(neighbor_node->out_edges().begin());
          if (++node_checker == fi->node_checkers.end()) {
            // Every checker of this pattern has been satisfied.
            return make_tuple(true, nodes, *fi);
          }
        }
      } else {
        // Removes the current node since none of its neighbor leads to a
        // further match. Stepping the checker back makes the search resume
        // with the remaining edges of the previous node.
        nodes.pop_back();
        current_neighbor_stack.pop();
        --node_checker;
      }
    }
  }

  return make_tuple(false, std::vector<Node*>(), FusionInfo());
}
3908 
3909 ///////////////////////////////////////////////////////////////////////////////
3910 //              Post-rewrite Mkl metadata fixup pass
3911 ///////////////////////////////////////////////////////////////////////////////
FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph> * g,const Edge * e_data,const Edge * e_metadata)3912 bool MklLayoutRewritePass::FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g,
3913                                                       const Edge* e_data,
3914                                                       const Edge* e_metadata) {
3915   if (g == nullptr || e_data == nullptr || e_metadata == nullptr) {
3916     return false;
3917   }
3918 
3919   Node* n_data = e_data->src();
3920   int n_data_op_slot = e_data->src_output();
3921   int n_metadata_op_slot =
3922       GetTensorMetaDataIndex(n_data_op_slot, n_data->num_outputs());
3923 
3924   // If the source of meta edge is a constant node (producing dummy Mkl metadata
3925   // tensor), then we will need to fix.
3926   if (IsConstant(e_metadata->src())) {
3927     Node* e_metadata_dst = e_metadata->dst();
3928     int e_metadata_in_slot = e_metadata->dst_input();
3929     auto new_edge = (*g)->AddEdge(n_data, n_metadata_op_slot, e_metadata_dst,
3930                                   e_metadata_in_slot);
3931     DCHECK(new_edge);
3932 
3933     (*g)->RemoveEdge(e_metadata);
3934     return true;
3935   }
3936 
3937   return false;
3938 }
3939 
FixMklMetaDataEdges(std::unique_ptr<Graph> * g,Node * n)3940 bool MklLayoutRewritePass::FixMklMetaDataEdges(std::unique_ptr<Graph>* g,
3941                                                Node* n) {
3942   bool result = false;
3943 
3944   // If graph node is not Mkl node, then return.
3945   DataType T = DT_INVALID;
3946   if (!TryGetNodeAttr(n->def(), "T", &T) ||
3947       !mkl_op_registry::IsMklLayoutDependentOp(n->type_string(), T)) {
3948     return result;
3949   }
3950 
3951   // If it is Mkl node, then check if the input edges to this node that carry
3952   // Mkl metadata are linked up correctly with the source node.
3953 
3954   // For Mkl nodes, we generate twice the number of input tensors (n for Mkl
3955   // data tensors + n for Mkl metadata tensors). We need to check for correct
3956   // connection of n metadata tensors only.
3957   int num_data_inputs = n->num_inputs() / 2;
3958   for (int idx = 0; idx < num_data_inputs; idx++) {
3959     // Get the edge connecting input slot with index (idx).
3960     const Edge* e = nullptr;
3961     TF_CHECK_OK(n->input_edge(idx, &e));
3962 
3963     // If e is control edge, then skip.
3964     if (e->IsControlEdge()) {
3965       continue;
3966     }
3967 
3968     // Check that the source node for edge 'e' is Mkl node. If it is not an Mkl
3969     // node, then we don't need to do anything.
3970     Node* e_src = e->src();
3971     if (TryGetNodeAttr(e_src->def(), "T", &T) &&
3972         mkl_op_registry::IsMklLayoutDependentOp(e_src->type_string(), T)) {
3973       // Source node for edge 'e' is Mkl node.
3974       // Destination node and destination input slot of e is node 'n' and 'idx'
3975       // resp.
3976       CHECK_EQ(e->dst(), n);
3977       CHECK_EQ(e->dst_input(), idx);
3978 
3979       // Let's get edge that carries Mkl metadata corresponding to Mkl data edge
3980       // 'e'. For that, let's first get the input slot of 'n' where the meta
3981       // edge will feed the value.
3982       int e_meta_in_slot =
3983           GetTensorMetaDataIndex(e->dst_input(), n->num_inputs());
3984       const Edge* e_meta = nullptr;
3985       TF_CHECK_OK(n->input_edge(e_meta_in_slot, &e_meta));
3986 
3987       // Let's check if we need to fix this meta edge.
3988       if (FixMklMetaDataEdgeIfNeeded(g, e, e_meta)) {
3989         result = true;
3990       }
3991     }
3992   }
3993 
3994   return result;
3995 }
3996 
3997 ///////////////////////////////////////////////////////////////////////////////
3998 //              Run function for the pass
3999 ///////////////////////////////////////////////////////////////////////////////
4000 
RunPass(std::unique_ptr<Graph> * g)4001 bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {
4002   bool result = false;
4003   DCHECK(g);
4004 
4005   DumpGraph("Before running MklLayoutRewritePass", &**g);
4006 
4007   std::vector<Node*> order;
4008   GetReversePostOrder(**g, &order);  // This will give us topological sort.
4009   for (Node* n : order) {
4010     // If node is not an op or it cannot run on CPU device, then skip.
4011     if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
4012       continue;
4013     }
4014 
4015     Node* m = nullptr;
4016     if ((m = CheckForNodeMerge(n)) != nullptr && CanOpRunOnCPUDevice(m)) {
4017       // Check if the node 'n' can be merged with any other node. If it can
4018       // be 'm' contains the node with which it can be merged.
4019       string n1_name = n->name();
4020       string n2_name = m->name();
4021 
4022       VLOG(1) << "MklLayoutRewritePass: Scheduled nodes " << n1_name << " and "
4023               << n2_name << " for merging";
4024 
4025       if (MergeNode(g, n, m) == Status::OK()) {
4026         VLOG(1) << "MklLayoutRewritePass: Merged nodes " << n1_name << " and "
4027                 << n2_name;
4028         result = true;
4029       }
4030     }
4031   }
4032 
4033   DumpGraph("After running MklLayoutRewritePass(NodeMerge)", &**g);
4034 
4035   order.clear();
4036   GetReversePostOrder(**g, &order);  // This will give us topological sort.
4037   for (Node* n : order) {
4038     // If node is not an op or it cannot run on CPU device, then skip.
4039     if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
4040       continue;
4041     }
4042 
4043     auto check_result = CheckForNodeFusion(n);
4044     bool found_pattern = std::get<0>(check_result);
4045     std::vector<Node*> nodes = std::get<1>(check_result);
4046     const FusionInfo fi = std::get<2>(check_result);
4047 
4048     // if "found_pattern" is true, we can do the fusion.
4049     if (found_pattern) {
4050       if (FuseNode(g, nodes, fi) == Status::OK()) {
4051         result = true;
4052       }
4053     }
4054   }
4055   DumpGraph("After running MklLayoutRewritePass(NodeFusion)", &**g);
4056 
4057   order.clear();
4058   GetReversePostOrder(**g, &order);  // This will give us topological sort.
4059   for (Node* n : order) {
4060     // If node is not an op or it cannot run on CPU device, then skip.
4061     if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
4062       continue;
4063     }
4064 
4065     const RewriteInfo* ri = nullptr;
4066     // We will first search if node is to be rewritten.
4067     if ((ri = CheckForNodeRewrite(n)) != nullptr) {
4068       string node_name = n->name();
4069       string op_name = n->type_string();
4070 
4071       VLOG(1) << "MklLayoutRewritePass: Scheduled node " << node_name
4072               << " with op " << op_name << " for rewrite using"
4073               << " layout optimization.";
4074 
4075       if (RewriteNode(g, n, ri) == Status::OK()) {
4076         VLOG(1) << "MklLayoutRewritePass: rewrote node " << node_name
4077                 << " with op " << op_name << " for Mkl layout optimization.";
4078         result = true;
4079       }
4080     }
4081   }
4082 
4083   DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite)", &**g);
4084 
4085   order.clear();
4086   GetReversePostOrder(**g, &order);  // This will give us topological sort.
4087   for (Node* n : order) {
4088     // If node is not an op or it cannot run on CPU device, then skip.
4089     if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
4090       continue;
4091     }
4092     if (FixMklMetaDataEdges(g, n)) {
4093       string node_name = n->name();
4094       string op_name = n->type_string();
4095 
4096       VLOG(1) << "MklLayoutRewritePass: fixed metadata edges for node "
4097               << node_name << " with op " << op_name;
4098       result = true;
4099     }
4100   }
4101   DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite+Fixup)",
4102             &**g);
4103 
4104   return result;
4105 }
4106 
RunMklLayoutRewritePass(std::unique_ptr<Graph> * g)4107 bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) {
4108   return MklLayoutRewritePass().RunPass(g);
4109 }
4110 
Run(const GraphOptimizationPassOptions & options)4111 Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
4112   if (options.graph == nullptr && options.partition_graphs == nullptr) {
4113     return Status::OK();
4114   }
4115   if (DisableMKL()) {
4116     VLOG(2) << "TF-MKL: Disabling MKL";
4117     return Status::OK();
4118   }
4119 
4120   auto process_graph = [&](std::unique_ptr<Graph>* g) {
4121     // Get the ownership of a graph
4122     std::unique_ptr<Graph>* ng = std::move(g);
4123     RunPass(ng);
4124     // Return the ownership of a graph back
4125     g->reset(ng->release());
4126   };
4127 
4128   if (kMklLayoutRewritePassGroup !=
4129       OptimizationPassRegistry::POST_PARTITIONING) {
4130     // For any pre-partitioning phase, a graph is stored in options.graph.
4131     process_graph(options.graph);
4132   } else {
4133     // For post partitioning phase, graphs are stored in
4134     // options.partition_graphs.
4135     for (auto& pg : *options.partition_graphs) {
4136       process_graph(&pg.second);
4137     }
4138   }
4139 
4140   return Status::OK();
4141 }
4142 
4143 }  // namespace tensorflow
4144 
4145 #endif
4146