// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/log.h>
#include <xnnpack/math.h>
#include <xnnpack/params.h>
#include <xnnpack/subgraph.h>


enum xnn_status xnn_create_subgraph(
    uint32_t external_value_ids,
    uint32_t flags,
    xnn_subgraph_t* subgraph_out)
{
  struct xnn_subgraph* subgraph = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create subgraph: XNNPACK is not initialized");
    goto error;
  }

  status = xnn_status_out_of_memory;

  subgraph = xnn_allocate_zero_memory(sizeof(struct xnn_subgraph));
  if (subgraph == NULL) {
    xnn_log_error("failed to allocate %zu bytes for subgraph descriptor", sizeof(struct xnn_subgraph));
    goto error;
  }

  subgraph->external_value_ids = external_value_ids;

  subgraph->values = xnn_allocate_zero_memory(external_value_ids * sizeof(struct xnn_value));
  if (subgraph->values == NULL) {
    xnn_log_error("failed to allocate %zu bytes for subgraph values", external_value_ids * sizeof(struct xnn_value));
    goto error;
  }
  for (size_t i = 0; i < external_value_ids; i++) {
    subgraph->values[i].id = i;
  }
  subgraph->num_values = external_value_ids;
  subgraph->num_reserved_values = external_value_ids;

  *subgraph_out = subgraph;
  return xnn_status_success;

error:
  xnn_delete_subgraph(subgraph);
  return status;
}

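// A minimal usage sketch (an illustration, not code from this file): assuming
// XNNPACK has been initialized, create a subgraph, populate it through the
// xnn_define_* APIs, and release it with xnn_delete_subgraph:
//
//   xnn_initialize(NULL /* use the default allocator */);
//   xnn_subgraph_t subgraph = NULL;
//   if (xnn_create_subgraph(/*external_value_ids=*/2, /*flags=*/0, &subgraph) == xnn_status_success) {
//     // ... define Values and Nodes ...
//     xnn_delete_subgraph(subgraph);
//   }
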
struct xnn_value* xnn_subgraph_new_internal_value(xnn_subgraph_t subgraph)
{
  struct xnn_value* values = subgraph->values;
  const size_t size = subgraph->num_values;
  const size_t capacity = subgraph->num_reserved_values;
  if (capacity < size + 1) {
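    // Grow geometrically (2x), adding at most 512 extra elements per step but
    // always at least 64, so small graphs don't reallocate on every new Value
    // and large graphs don't over-allocate.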
    const size_t new_capacity = max(min(capacity * 2, capacity + 512), capacity + 64);
    assert(new_capacity >= size + 1);
    values = xnn_reallocate_memory(values, new_capacity * sizeof(struct xnn_value));
    if (values == NULL) {
      xnn_log_error("failed to allocate %zu bytes for subgraph values",
        new_capacity * sizeof(struct xnn_value));
      return NULL;
    }

    memset(values + size, 0, (new_capacity - size) * sizeof(struct xnn_value));
    subgraph->num_reserved_values = new_capacity;
    subgraph->values = values;
  }
  subgraph->num_values = size + 1;
  struct xnn_value* new_value = values + size;
  new_value->id = size;
  return new_value;
}

void xnn_node_clear(struct xnn_node* node) {
  assert(node != NULL);
  assert(node->type != xnn_node_type_invalid);
  memset(node, 0, sizeof(struct xnn_node));
}

void xnn_value_clear(struct xnn_value* value) {
  assert(value != NULL);
  assert(value->type != xnn_value_type_invalid);
  memset(value, 0, sizeof(struct xnn_value));
}

struct xnn_node* xnn_subgraph_new_node(xnn_subgraph_t subgraph)
{
  struct xnn_node* nodes = subgraph->nodes;
  const size_t size = subgraph->num_nodes;
  const size_t capacity = subgraph->num_reserved_nodes;

  if (capacity < size + 1) {
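    // Same growth policy as in xnn_subgraph_new_internal_value above.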
    const size_t new_capacity = max(min(capacity * 2, capacity + 512), capacity + 64);
    assert(new_capacity >= size + 1);
    nodes = xnn_reallocate_memory(nodes, new_capacity * sizeof(struct xnn_node));
    if (nodes == NULL) {
      xnn_log_error("failed to allocate %zu bytes for subgraph nodes",
        new_capacity * sizeof(struct xnn_node));
      return NULL;
    }

    memset(nodes + size, 0, (new_capacity - size) * sizeof(struct xnn_node));
    subgraph->num_reserved_nodes = new_capacity;
    subgraph->nodes = nodes;
  }
  subgraph->num_nodes = size + 1;
  struct xnn_node* new_node = nodes + size;
  new_node->id = size;
  return new_node;
}

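// Per-Node layout flags used by the NCHW rewriting pass below: a Node may
// compute directly in NCHW layout, act as an NHWC->NCHW or NCHW->NHWC layout
// transition, or poison its whole cluster as incompatible.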
#define XNN_LAYOUT_FLAG_COMPATIBLE_NCHW      1
#define XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW 2
#define XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC 4
#define XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER 8

uint32_t xnn_check_nchw_compatibility(xnn_subgraph_t subgraph, struct xnn_node* node) {
  switch (node->type) {
    case xnn_node_type_convolution_2d:
      // Supported cases:
      // - 1x1 convolution (no stride, no dilation, no padding, no groups)
      // - 3x3 stride-2 convolution (no dilation, padding 1 on each side, no groups, 3 input channels)
      if (node->params.convolution_2d.groups != 1) {
        return 0;
      }
      if ((node->params.convolution_2d.dilation_height | node->params.convolution_2d.dilation_width) != 1) {
        return 0;
      }
      if ((node->params.convolution_2d.kernel_height | node->params.convolution_2d.kernel_width) == 1) {
        if ((node->params.convolution_2d.input_padding_top | node->params.convolution_2d.input_padding_right |
             node->params.convolution_2d.input_padding_bottom | node->params.convolution_2d.input_padding_left) != 0)
        {
          return 0;
        }
        if ((node->params.convolution_2d.subsampling_height | node->params.convolution_2d.subsampling_width) != 1) {
          return 0;
        }
        return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
      } else if (node->params.convolution_2d.kernel_height == 3 && node->params.convolution_2d.kernel_width == 3) {
        if (node->params.convolution_2d.input_padding_top != 1 || node->params.convolution_2d.input_padding_right != 1 ||
            node->params.convolution_2d.input_padding_bottom != 1 || node->params.convolution_2d.input_padding_left != 1)
        {
          return 0;
        }
        if ((node->params.convolution_2d.subsampling_height | node->params.convolution_2d.subsampling_width) != 2) {
          return 0;
        }
        if (node->params.convolution_2d.group_input_channels != 3) {
          return 0;
        }
        return XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW;
      }
      return 0;
    case xnn_node_type_depthwise_convolution_2d:
      // Supported cases:
      // - 3x3 stride-1 convolution (no dilation, padding 1 on each side)
      // - 3x3 stride-2 convolution (no dilation, padding 1 on each side)
      // - 5x5 stride-1 convolution (no dilation, padding 2 on each side)
      // - 5x5 stride-2 convolution (no dilation, padding 2 on each side)
      if ((node->params.depthwise_convolution_2d.dilation_height | node->params.depthwise_convolution_2d.dilation_width) != 1) {
        return 0;
      }
      if (node->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING) {
        return 0;
      }
      if (node->params.depthwise_convolution_2d.depth_multiplier != 1) {
        return 0;
      }
      if (node->params.depthwise_convolution_2d.subsampling_height != node->params.depthwise_convolution_2d.subsampling_width) {
        return 0;
      }
      switch (node->params.depthwise_convolution_2d.subsampling_height) {
        case 1:
        case 2:
          break;
        default:
          return 0;
      }
      if (node->params.depthwise_convolution_2d.kernel_height != node->params.depthwise_convolution_2d.kernel_width) {
        return 0;
      }
      switch (node->params.depthwise_convolution_2d.kernel_height) {
        case 3:
          return node->params.depthwise_convolution_2d.input_padding_top == 1 &&
                 node->params.depthwise_convolution_2d.input_padding_right == 1 &&
                 node->params.depthwise_convolution_2d.input_padding_bottom == 1 &&
                 node->params.depthwise_convolution_2d.input_padding_left == 1 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0;
        case 5:
          return node->params.depthwise_convolution_2d.input_padding_top == 2 &&
                 node->params.depthwise_convolution_2d.input_padding_right == 2 &&
                 node->params.depthwise_convolution_2d.input_padding_bottom == 2 &&
                 node->params.depthwise_convolution_2d.input_padding_left == 2 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0;
        default:
          return 0;
      }
    case xnn_node_type_depth_to_space:
      return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
    case xnn_node_type_global_average_pooling_2d:
      return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
    case xnn_node_type_add2:
    case xnn_node_type_multiply2:
      assert(node->num_inputs == 2);
      assert(node->num_outputs == 1);
      if (subgraph->values[node->inputs[0]].shape.num_dims != 4 ||
          subgraph->values[node->inputs[1]].shape.num_dims != 4)
      {
        return 0;
      }

      if (subgraph->values[node->inputs[0]].data != NULL) {
        // Check that the first input is representable as either a scalar, or a vector
        size_t num_nonunit_dims = 0;
        for (uint32_t i = 0; i < subgraph->values[node->inputs[0]].shape.num_dims; i++) {
          if (subgraph->values[node->inputs[0]].shape.dim[i] != 1) {
            num_nonunit_dims += 1;
          }
        }
        if (num_nonunit_dims > 1) {
          return 0;
        }
      }

      if (subgraph->values[node->inputs[1]].data != NULL) {
        // Check that the second input is representable as either a scalar, or a vector
        size_t num_nonunit_dims = 0;
        for (uint32_t i = 0; i < subgraph->values[node->inputs[1]].shape.num_dims; i++) {
          if (subgraph->values[node->inputs[1]].shape.dim[i] != 1) {
            num_nonunit_dims += 1;
          }
        }
        if (num_nonunit_dims > 1) {
          return 0;
        }
      }

      return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
    case xnn_node_type_static_resize_bilinear_2d:
      return subgraph->values[node->inputs[0]].shape.dim[1] > 1 &&
             subgraph->values[node->inputs[0]].shape.dim[2] > 1 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0;
    case xnn_node_type_abs:
    case xnn_node_type_bankers_rounding:
    case xnn_node_type_ceiling:
    case xnn_node_type_clamp:
    case xnn_node_type_elu:
    case xnn_node_type_floor:
    case xnn_node_type_hardswish:
    case xnn_node_type_leaky_relu:
    case xnn_node_type_negate:
    case xnn_node_type_sigmoid:
    case xnn_node_type_square:
      assert(node->num_inputs == 1);
      assert(node->num_outputs == 1);
      return subgraph->values[node->inputs[0]].shape.num_dims == 4 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0;
    default:
      return 0;
  }
}

void xnn_subgraph_rewrite_for_nchw(xnn_subgraph_t subgraph)
{
  // Convert parts of the subgraph to NCHW for sparse inference
  // Step 1: detect NCHW-compatible Nodes
  // Step 2: detect NCHW-compatible clusters (run connected components graph algorithm)
  // Step 3: check that all NCHW-compatible Values are consumed only by NCHW-compatible Nodes
  // Step 4: switch Values' layout to NCHW
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    node->layout_flags = xnn_check_nchw_compatibility(subgraph, node);
    xnn_log_debug("Node #%" PRIu32 ": %s (NCHW: %s, NHWC->NCHW: %s, NCHW->NHWC: %s)",
      n, xnn_node_type_to_string(node->type),
      node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW ? "yes" : "no",
      node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW ? "yes" : "no",
      node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC ? "yes" : "no");
  }

  // Run the Shiloach-Vishkin connected components algorithm: find all
  // XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC Nodes and propagate them as cluster
  // leaders to their producer Nodes.
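  // For example, in a chain Convolution 3x3s2 (NHWC->NCHW) -> Depthwise
  // Convolution (NCHW) -> Global Average Pooling (NCHW->NHWC), the pooling
  // Node seeds the cluster, and the leader ID then spreads backwards through
  // the producer Nodes until the whole chain shares a single leader.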
  bool update = false;
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    node->cluster_leader = n;
    if (node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC) {
      for (uint32_t i = 0; i < node->num_inputs; i++) {
        const struct xnn_value* value = &subgraph->values[node->inputs[i]];
        if (value->data != NULL) {
          // Static data, skip this input value. Compatibility of this static input with NCHW layout was validated
          // during the initial NCHW compatibility check for the Node.
          continue;
        }
        if ((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) != 0) {
          // External value, invalid cluster
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
          continue;
        }
        const uint32_t producer_id = value->producer;
        assert(producer_id != XNN_INVALID_NODE_ID);
        assert(producer_id < n);
        struct xnn_node* producer_node = &subgraph->nodes[producer_id];
        if ((producer_node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) != 0 &&
            (producer_node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) == 0)
        {
          producer_node->layout_flags &= ~XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
          if (producer_node->cluster_leader != node->cluster_leader) {
            producer_node->cluster_leader = node->cluster_leader = math_max_u32(producer_node->cluster_leader, node->cluster_leader);
            update = true;
          }
        } else {
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
        }
      }
    }
  }
  // If no NCHW2NHWC-compatible Nodes were found, the graph rewriting cannot
  // apply, so bail out early.
  if (!update) {
    return;
  }
  // Propagate cluster leaders through the graph until a full pass over the
  // Nodes makes no further updates, i.e. the leader assignment has converged.
  while (update) {
    update = false;
    for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
      struct xnn_node* node = &subgraph->nodes[n];
      if (node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) {
        continue;
      }

      if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC)) == 0) {
        continue;
      }

      for (uint32_t i = 0; i < node->num_inputs; i++) {
        const struct xnn_value* value = &subgraph->values[node->inputs[i]];
        if (value->data != NULL) {
          // Static data, skip this input value. Compatibility of this static input with NCHW layout was validated
          // during the initial NCHW compatibility check for the Node.
          continue;
        }
        if ((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) != 0) {
          // External value, invalid cluster
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
          continue;
        }
        const uint32_t producer_id = value->producer;
        assert(producer_id != XNN_INVALID_NODE_ID);
        assert(producer_id < n);
        struct xnn_node* producer_node = &subgraph->nodes[producer_id];
        if ((producer_node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) != 0 &&
            (producer_node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) == 0)
        {
          producer_node->layout_flags &= ~XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
          if (producer_node->cluster_leader != node->cluster_leader) {
            producer_node->cluster_leader = node->cluster_leader = math_max_u32(producer_node->cluster_leader, node->cluster_leader);
            update = true;
          }
        } else {
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
        }
      }
    }
  }
  // Propagate XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER flags up to the cluster leaders
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    subgraph->nodes[node->cluster_leader].layout_flags |= node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
  }
  // Check that no Value consumed by an NCHW-compatible cluster has an NCHW-incompatible consumer
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
      continue;
    }

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->data != NULL) {
        // Static data, skip this input value because it doesn't have a producer Node.
        continue;
      }
      assert((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) == 0);
      value->num_nchw_compatible_consumers += 1;
    }
  }
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
      continue;
    }

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->data != NULL) {
        // Static data, skip this input value because it doesn't have a producer Node.
        continue;
      }
      assert((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) == 0);
      assert(value->num_nchw_compatible_consumers > 0);
      if (value->num_nchw_compatible_consumers != value->num_consumers) {
        subgraph->nodes[node->cluster_leader].layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
      }
    }
  }
  // Evaluate whether it is profitable to run the model as sparse:
  // - Compute the number of parameters and zeroes in 1x1 Convolution weights
  // - Disable sparse rewriting for clusters without 1x1 Convolutions (num_params == 0)
  //   or with fewer than 2/3 zero weights in their 1x1 Convolution filters
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if (node->type == xnn_node_type_convolution_2d &&
        max(node->params.convolution_2d.kernel_height, node->params.convolution_2d.kernel_width) == 1)
    {
      assert(node->num_inputs >= 2);

      const struct xnn_value* filter = &subgraph->values[node->inputs[1]];
      assert(filter->data != NULL);
      assert(filter->shape.num_dims == 4);

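      // The filter is stored as [output channels, kernel height, kernel width,
      // input channels], so for a 1x1 kernel dim[0] * dim[3] is the total
      // number of weights.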
      const size_t num_params = filter->shape.dim[0] * filter->shape.dim[3];
      subgraph->nodes[node->cluster_leader].num_params += num_params;

      const float* data = (const float*) filter->data;
      size_t num_zeroes = 0;
      for (size_t i = 0; i < num_params; i++) {
        num_zeroes += (size_t) (data[i] == 0.0f);
      }
      xnn_log_debug("1x1 Convolution 2D Node #%" PRIu32 ": %zu / %zu sparsity", n, num_zeroes, num_params);
      subgraph->nodes[node->cluster_leader].num_zeroes += num_zeroes;
    }
  }
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
      continue;
    }

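    // Keep sparse inference only when more than 2/3 of the cluster's 1x1
    // Convolution weights are zero: num_zeroes / num_params > 2/3, evaluated
    // in integer arithmetic as num_zeroes * 3 > num_params * 2.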
    if (subgraph->nodes[node->cluster_leader].num_zeroes * 3 <= subgraph->nodes[node->cluster_leader].num_params * 2) {
      xnn_log_info("Node #%" PRIu32 ": sparse inference disabled: 1x1 Convolutions contain %zu / %zu zero weights",
        n, subgraph->nodes[node->cluster_leader].num_zeroes, subgraph->nodes[node->cluster_leader].num_params);
      continue;
    }

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->data != NULL) {
        // Static data, skip this input value because it doesn't have a producer Node.
        continue;
      }
      assert((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) == 0);
      assert(value->num_nchw_compatible_consumers > 0);
      assert(value->num_nchw_compatible_consumers == value->num_consumers);
      if (value->layout != xnn_layout_type_nchw) {
        value->layout = xnn_layout_type_nchw;
        xnn_log_info("set Value #%"PRIu32" layout to NCHW", node->inputs[i]);
      }
    }
  }
}

enum xnn_status xnn_subgraph_optimize(
  xnn_subgraph_t subgraph,
  uint32_t flags)
{
  // Initialize producer/consumer fields to safe defaults.
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    value->producer = XNN_INVALID_NODE_ID;
    value->first_consumer = XNN_INVALID_NODE_ID;
    value->num_consumers = 0;
  }

  // Analyse Nodes' inputs and outputs and update Values' producer/consumer fields
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const uint32_t input_id = node->inputs[i];
      assert(input_id < subgraph->num_values);

      if (subgraph->values[input_id].num_consumers++ == 0) {
        assert(subgraph->values[input_id].first_consumer == XNN_INVALID_NODE_ID);
        subgraph->values[input_id].first_consumer = n;
      }
    }

    for (uint32_t o = 0; o < node->num_outputs; o++) {
      const uint32_t output_id = node->outputs[o];
      assert(output_id < subgraph->num_values);

      assert(subgraph->values[output_id].producer == XNN_INVALID_NODE_ID);
      subgraph->values[output_id].producer = n;
    }
  }

  // Count extra consumer for Values which are external outputs.
  // Remove unreferenced values.
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    if (value->type == xnn_value_type_invalid) {
      continue;
    }

    if (value->flags & XNN_VALUE_FLAG_EXTERNAL_OUTPUT) {
      value->num_consumers += 1;
    }
    if ((value->flags & XNN_VALUE_FLAG_EXTERNAL_INPUT) == 0 && value->num_consumers == 0) {
      xnn_value_clear(value);
    }
  }

  // Fuse Nodes where possible
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    if (value->num_consumers == 1) {
      const uint32_t producer_id = value->producer;
      if (producer_id == XNN_INVALID_NODE_ID) {
        continue;
      }
      assert(producer_id < subgraph->num_nodes);

      const uint32_t consumer_id = value->first_consumer;
      if (consumer_id == XNN_INVALID_NODE_ID) {
        continue;
      }
      assert(consumer_id < subgraph->num_nodes);

      struct xnn_node* producer = &subgraph->nodes[producer_id];
      assert(producer->type != xnn_node_type_invalid);
      struct xnn_node* consumer = &subgraph->nodes[consumer_id];
      assert(consumer->type != xnn_node_type_invalid);

      // Try to fuse Clamp Node upstream into producer Node
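      // e.g. Convolution 2D -> Clamp(0, 6) collapses into a single Convolution
      // 2D whose activation range is intersected with [0, 6] (a fused ReLU6).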
      if (consumer->type == xnn_node_type_clamp) {
        switch (producer->type) {
          case xnn_node_type_add2:
          case xnn_node_type_average_pooling_2d:
          case xnn_node_type_clamp:
          case xnn_node_type_convolution_2d:
          case xnn_node_type_divide:
          case xnn_node_type_deconvolution_2d:
          case xnn_node_type_depthwise_convolution_2d:
          case xnn_node_type_fully_connected:
          case xnn_node_type_multiply2:
          case xnn_node_type_max_pooling_2d:
          case xnn_node_type_subtract:
            xnn_log_info("fuse Clamp Node #%"PRIu32" into upstream Node #%"PRIu32, consumer_id, producer_id);
            assert(producer->num_outputs == 1);
            assert(consumer->num_inputs == 1);
            assert(consumer->num_outputs == 1);

            const uint32_t fused_output_id = consumer->outputs[0];
            assert(fused_output_id < subgraph->num_values);
            subgraph->values[fused_output_id].producer = producer_id;
            producer->outputs[0] = fused_output_id;

            producer->activation.output_min =
              math_max_f32(producer->activation.output_min, consumer->activation.output_min);
            producer->activation.output_max =
              math_min_f32(producer->activation.output_max, consumer->activation.output_max);

            xnn_node_clear(consumer);
            xnn_value_clear(value);
            break;
          default:
            break;
        }
      }
      // Try to fuse Constant Pad node downstream into [Depthwise] Convolution 2D Node
      if (producer->type == xnn_node_type_static_constant_pad) {
        assert(producer->num_inputs == 1);
        assert(producer->num_outputs == 1);
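        // The Pad is fusable only when it zero-pads just the two spatial
        // dimensions (H and W) of a 4D NHWC tensor, leaving the batch (dim 0)
        // and channel (dim 3) dimensions untouched.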
        const bool is_spatial_2d_zero_padding = value->shape.num_dims == 4 &&
          (producer->params.static_pad.pre_paddings[0] | producer->params.static_pad.post_paddings[0] |
           producer->params.static_pad.pre_paddings[3] | producer->params.static_pad.post_paddings[3]) == 0 &&
           producer->params.static_pad.padding_value == 0;
        switch (consumer->type) {
          case xnn_node_type_convolution_2d:
            if (is_spatial_2d_zero_padding && !(consumer->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING)) {
              xnn_log_info("fuse Constant Pad Node #%"PRIu32" into Convolution 2D Node #%"PRIu32,
                consumer_id, producer_id);
              assert(consumer->num_inputs >= 1);
              assert(consumer->inputs[0] == producer->outputs[0]);

              consumer->params.convolution_2d.input_padding_top    += producer->params.static_pad.pre_paddings[1];
              consumer->params.convolution_2d.input_padding_right  += producer->params.static_pad.post_paddings[2];
              consumer->params.convolution_2d.input_padding_bottom += producer->params.static_pad.post_paddings[1];
              consumer->params.convolution_2d.input_padding_left   += producer->params.static_pad.pre_paddings[2];

              consumer->inputs[0] = producer->inputs[0];

              const uint32_t fused_input_id = producer->inputs[0];
              assert(fused_input_id < subgraph->num_values);
              if (subgraph->values[fused_input_id].first_consumer == producer_id) {
                subgraph->values[fused_input_id].first_consumer = consumer_id;
              }

              xnn_node_clear(producer);
              xnn_value_clear(value);
            }
            break;
          case xnn_node_type_depthwise_convolution_2d:
            if (is_spatial_2d_zero_padding && !(consumer->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING)) {
              xnn_log_info("fuse Constant Pad Node #%"PRIu32" into Depthwise Convolution 2D Node #%"PRIu32,
                consumer_id, producer_id);
              assert(consumer->num_inputs >= 1);
              assert(consumer->inputs[0] == producer->outputs[0]);

              consumer->params.depthwise_convolution_2d.input_padding_top +=
                producer->params.static_pad.pre_paddings[1];
              consumer->params.depthwise_convolution_2d.input_padding_right +=
                producer->params.static_pad.post_paddings[2];
              consumer->params.depthwise_convolution_2d.input_padding_bottom +=
                producer->params.static_pad.post_paddings[1];
              consumer->params.depthwise_convolution_2d.input_padding_left +=
                producer->params.static_pad.pre_paddings[2];

              consumer->inputs[0] = producer->inputs[0];

              const uint32_t fused_input_id = producer->inputs[0];
              assert(fused_input_id < subgraph->num_values);
              if (subgraph->values[fused_input_id].first_consumer == producer_id) {
                subgraph->values[fused_input_id].first_consumer = consumer_id;
              }

              xnn_node_clear(producer);
              xnn_value_clear(value);
            }
            break;
          default:
            break;
        }
      }
    }
  }

  #if XNN_ENABLE_SPARSE
    if ((flags & XNN_FLAG_SPARSE_INFERENCE) && (xnn_params.init_flags & XNN_INIT_FLAG_CHW_OPT)) {
      xnn_subgraph_rewrite_for_nchw(subgraph);
    }
  #endif

  return xnn_status_success;
}

enum xnn_status xnn_delete_subgraph(
  xnn_subgraph_t subgraph)
{
  if (subgraph != NULL) {
    memset(subgraph->nodes, 0, sizeof(struct xnn_node) * subgraph->num_nodes);
    xnn_release_memory(subgraph->nodes);

    memset(subgraph->values, 0, sizeof(struct xnn_value) * subgraph->num_values);
    xnn_release_memory(subgraph->values);

    memset(subgraph, 0, sizeof(struct xnn_subgraph));
    xnn_release_memory(subgraph);
  }
  return xnn_status_success;
}