1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
9 #include <assert.h>
10 #include <math.h>
11 #include <stddef.h>
12 #include <stdint.h>
13 #include <stdlib.h>
14 
15 #include <xnnpack.h>
16 #include <xnnpack/allocator.h>
17 #include <xnnpack/operator.h>
18 #include <xnnpack/log.h>
19 #include <xnnpack/params.h>
20 
21 
// Shared implementation of xnn_create_channel_shuffle_nc_x8/_x32: validates the
// requested shape and strides, then allocates and initializes a channel shuffle
// operator descriptor tagged with operator_type.
//
// groups                 - number of channel groups; must be at least 2.
// group_channels         - number of channels in each group; must be non-zero.
// input_stride           - distance, in elements, between consecutive input
//                          pixels; must be >= groups * group_channels.
// output_stride          - same as input_stride, for the output.
// flags                  - not used by this function.
// operator_type          - type tag stored in the new operator.
// channel_shuffle_op_out - receives the new operator on success; it is not
//                          written on failure.
//
// Returns xnn_status_success, xnn_status_uninitialized if XNNPACK is not
// initialized, xnn_status_invalid_parameter for an invalid shape/stride, or
// xnn_status_out_of_memory if the descriptor cannot be allocated.
static enum xnn_status create_channel_shuffle_nc(
  size_t groups,
  size_t group_channels,
  size_t input_stride,
  size_t output_stride,
  uint32_t flags,
  enum xnn_operator_type operator_type,
  xnn_operator_t* channel_shuffle_op_out)
{
  xnn_operator_t channel_shuffle_op = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    goto error;
  }

  status = xnn_status_invalid_parameter;

  if (groups <= 1) {
    xnn_log_error(
      "failed to create %s operator with %zu groups: at least two groups required",
      xnn_operator_type_to_string(operator_type), groups);
    goto error;
  }

  if (group_channels == 0) {
    xnn_log_error(
      "failed to create %s operator with %zu group channels: number of group channels must be non-zero",
      xnn_operator_type_to_string(operator_type), group_channels);
    goto error;
  }

  // groups >= 2 here, so the division is safe. Without this check the product
  // below could wrap around and let an undersized stride pass validation.
  if (group_channels > SIZE_MAX / groups) {
    xnn_log_error(
      "failed to create %s operator with %zux%zu channels: total number of channels overflows size_t",
      xnn_operator_type_to_string(operator_type), groups, group_channels);
    goto error;
  }

  const size_t channels = groups * group_channels;
  if (input_stride < channels) {
    xnn_log_error(
      "failed to create %s operator with input element stride of %zu: "
      "stride must be at least as large as the number of channels (%zux%zu)",
      xnn_operator_type_to_string(operator_type), input_stride, groups, group_channels);
    goto error;
  }

  if (output_stride < channels) {
    xnn_log_error(
      "failed to create %s operator with output element stride of %zu: "
      "stride must be at least as large as the number of channels (%zux%zu)",
      xnn_operator_type_to_string(operator_type), output_stride, groups, group_channels);
    goto error;
  }

  status = xnn_status_out_of_memory;

  channel_shuffle_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (channel_shuffle_op == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator descriptor",
      sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
    goto error;
  }

  channel_shuffle_op->groups = groups;
  channel_shuffle_op->group_channels = group_channels;
  channel_shuffle_op->input_pixel_stride = input_stride;
  channel_shuffle_op->output_pixel_stride = output_stride;

  channel_shuffle_op->type = operator_type;

  // Not runnable until a successful setup call.
  channel_shuffle_op->state = xnn_run_state_invalid;

  *channel_shuffle_op_out = channel_shuffle_op;
  return xnn_status_success;

error:
  xnn_delete_operator(channel_shuffle_op);
  return status;
}
99 
100 
// Creates a channel shuffle operator tagged xnn_operator_type_channel_shuffle_nc_x8,
// i.e. one that xnn_setup_channel_shuffle_nc_x8 will accept. All validation and
// allocation is delegated to create_channel_shuffle_nc.
enum xnn_status xnn_create_channel_shuffle_nc_x8(
    size_t groups,
    size_t group_channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* channel_shuffle_op_out)
{
  const enum xnn_operator_type x8_type = xnn_operator_type_channel_shuffle_nc_x8;
  return create_channel_shuffle_nc(
    groups, group_channels, input_stride, output_stride, flags, x8_type, channel_shuffle_op_out);
}
118 
// Creates a channel shuffle operator tagged xnn_operator_type_channel_shuffle_nc_x32,
// i.e. one that xnn_setup_channel_shuffle_nc_x32 will accept. All validation and
// allocation is delegated to create_channel_shuffle_nc.
enum xnn_status xnn_create_channel_shuffle_nc_x32(
    size_t groups,
    size_t group_channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* channel_shuffle_op_out)
{
  const enum xnn_operator_type x32_type = xnn_operator_type_channel_shuffle_nc_x32;
  return create_channel_shuffle_nc(
    groups, group_channels, input_stride, output_stride, flags, x32_type, channel_shuffle_op_out);
}
136 
// Shared setup for the x8/x32 channel shuffle operators.
//
// Records batch_size/input/output on the operator and fills in its
// channel_shuffle context and 1-D compute description. log2_element_size
// converts the element-based strides and group_channels into byte counts
// (the callers pass 0 for x8 and 2 for x32). zip supplies the microkernels:
// x2/x3/x4 handle exactly 2, 3 or 4 groups, xm handles any larger count.
//
// Returns xnn_status_uninitialized if XNNPACK is not initialized, otherwise
// xnn_status_success. A zero batch_size marks the operator as skipped.
static enum xnn_status setup_channel_shuffle_nc(
    xnn_operator_t channel_shuffle_op,
    size_t batch_size,
    const void* input,
    void* output,
    uint32_t log2_element_size,
    const struct zip_parameters zip[restrict XNN_MIN_ELEMENTS(1)])
{
  // Invalidate up front so any early return leaves the operator non-runnable.
  channel_shuffle_op->state = xnn_run_state_invalid;

  if (!(xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK)) {
    xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(channel_shuffle_op->type));
    return xnn_status_uninitialized;
  }

  if (batch_size == 0) {
    channel_shuffle_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  channel_shuffle_op->batch_size = batch_size;
  channel_shuffle_op->input = input;
  channel_shuffle_op->output = output;

  const size_t num_groups = channel_shuffle_op->groups;
  struct channel_shuffle_context* context = &channel_shuffle_op->context.channel_shuffle;
  *context = (struct channel_shuffle_context) {
    .x = input,
    .x_stride = channel_shuffle_op->input_pixel_stride << log2_element_size,
    .y = output,
    .y_stride = channel_shuffle_op->output_pixel_stride << log2_element_size,
    .n = channel_shuffle_op->group_channels << log2_element_size,
    .m = num_groups,
  };

  // The 1-D compute range spans the batch dimension.
  channel_shuffle_op->compute.type = xnn_parallelization_type_1d;
  channel_shuffle_op->compute.range[0] = batch_size;

  switch (num_groups) {
    case 0:
    case 1:
      // create_channel_shuffle_nc rejects fewer than two groups.
      XNN_UNREACHABLE;
      break;
    case 2:
      channel_shuffle_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_fixed;
      context->fixed_ukernel = zip->x2;
      break;
    case 3:
      channel_shuffle_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_fixed;
      context->fixed_ukernel = zip->x3;
      break;
    case 4:
      channel_shuffle_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_fixed;
      context->fixed_ukernel = zip->x4;
      break;
    default:
      // More than four groups: fall back to the variable-group kernel.
      channel_shuffle_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_variable;
      context->variable_ukernel = zip->xm;
      break;
  }

  channel_shuffle_op->state = xnn_run_state_ready;
  return xnn_status_success;
}
198 
// Sets up an operator created by xnn_create_channel_shuffle_nc_x8 to shuffle
// batch_size pixels of 1-byte elements from input to output, using the x8 zip
// microkernels. Returns xnn_status_invalid_parameter if the operator has a
// different type. threadpool is not referenced here.
enum xnn_status xnn_setup_channel_shuffle_nc_x8(
    xnn_operator_t channel_shuffle_op,
    size_t batch_size,
    const void* input,
    void* output,
    pthreadpool_t threadpool)
{
  const enum xnn_operator_type expected_type = xnn_operator_type_channel_shuffle_nc_x8;
  if (channel_shuffle_op->type == expected_type) {
    // log2(sizeof(uint8_t)) == 0
    return setup_channel_shuffle_nc(
      channel_shuffle_op, batch_size, input, output,
      0 /* log2_element_size */, &xnn_params.x8.zip);
  }

  xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
    xnn_operator_type_to_string(expected_type),
    xnn_operator_type_to_string(channel_shuffle_op->type));
  return xnn_status_invalid_parameter;
}
221 
// Sets up an operator created by xnn_create_channel_shuffle_nc_x32 to shuffle
// batch_size pixels of 4-byte elements from input to output, using the x32 zip
// microkernels. Returns xnn_status_invalid_parameter if the operator has a
// different type. threadpool is not referenced here.
enum xnn_status xnn_setup_channel_shuffle_nc_x32(
    xnn_operator_t channel_shuffle_op,
    size_t batch_size,
    const void* input,
    void* output,
    pthreadpool_t threadpool)
{
  const enum xnn_operator_type expected_type = xnn_operator_type_channel_shuffle_nc_x32;
  if (channel_shuffle_op->type == expected_type) {
    // log2(sizeof(uint32_t)) == 2
    return setup_channel_shuffle_nc(
      channel_shuffle_op, batch_size, input, output,
      2 /* log2_element_size */, &xnn_params.x32.zip);
  }

  xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
    xnn_operator_type_to_string(expected_type),
    xnn_operator_type_to_string(channel_shuffle_op->type));
  return xnn_status_invalid_parameter;
}
244