1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 #include <math.h>
8 #include <stddef.h>
9 #include <stdint.h>
10 #include <stdlib.h>
11 
12 #include <xnnpack.h>
13 #include <xnnpack/allocator.h>
14 #include <xnnpack/log.h>
15 #include <xnnpack/operator.h>
16 #include <xnnpack/params-init.h>
17 #include <xnnpack/params.h>
18 
19 
// Allocates and initializes the operator descriptor shared by every unary
// elementwise NC operator in this file.
//
//   channels      - number of channels; must be non-zero.
//   input_stride  - input element stride; must be at least `channels`.
//   output_stride - output element stride; must be at least `channels`.
//   flags         - accepted for signature uniformity; not read here.
//   params        - micro-kernel parameters copied into the operator; only
//                   dereferenced when `params_size` is non-zero.
//   params_size   - number of bytes to copy from `params`.
//   operator_type - type tag stored in the operator and used in log messages.
//   ukernel       - micro-kernel function stored in the operator.
//   unary_elementwise_op_out - receives the new operator; written only on
//                   success.
//
// Returns xnn_status_uninitialized if XNNPACK is not initialized,
// xnn_status_invalid_parameter on a bad channel count or stride,
// xnn_status_out_of_memory if the descriptor cannot be allocated, and
// xnn_status_success otherwise.
static enum xnn_status create_unary_elementwise_nc(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    const void* params,
    size_t params_size,
    enum xnn_operator_type operator_type,
    xnn_univector_ukernel_function ukernel,
    xnn_operator_t* unary_elementwise_op_out)
{
  xnn_operator_t unary_elementwise_op = NULL;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_uninitialized;
  }

  // Validation messages report the operator actually being created rather
  // than a hard-coded operator type.
  if (channels == 0) {
    xnn_log_error(
      "failed to create %s operator with %zu channels: number of channels must be non-zero",
      xnn_operator_type_to_string(operator_type), channels);
    return xnn_status_invalid_parameter;
  }

  if (input_stride < channels) {
    xnn_log_error(
      "failed to create %s operator with input element stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      xnn_operator_type_to_string(operator_type), input_stride, channels);
    return xnn_status_invalid_parameter;
  }

  if (output_stride < channels) {
    xnn_log_error(
      "failed to create %s operator with output element stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      xnn_operator_type_to_string(operator_type), output_stride, channels);
    return xnn_status_invalid_parameter;
  }

  unary_elementwise_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (unary_elementwise_op == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator descriptor",
      sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
    return xnn_status_out_of_memory;
  }

  unary_elementwise_op->channels = channels;
  unary_elementwise_op->input_pixel_stride = input_stride;
  unary_elementwise_op->output_pixel_stride = output_stride;
  if (params_size != 0) {
    // Parameters are stored at the start of the operator's params storage.
    memcpy(&unary_elementwise_op->params, params, params_size);
  }

  unary_elementwise_op->ukernel.vunary.function = ukernel;
  unary_elementwise_op->type = operator_type;

  // The operator cannot be run until one of the setup functions succeeds.
  unary_elementwise_op->state = xnn_run_state_invalid;

  *unary_elementwise_op_out = unary_elementwise_op;
  return xnn_status_success;
}
85 
// Fills in the compute context of a unary elementwise operator for one run.
//
//   unary_elementwise_op - operator to configure.
//   batch_size           - batch size; zero marks the operator as skipped.
//   input, output        - data pointers stored in the compute context.
//   log2_element_size    - log2 of the element size in bytes; used to scale
//                          element counts and strides to bytes.
//   params, params_size  - micro-kernel parameters copied into the context
//                          when `params_size` is non-zero.
//
// Returns xnn_status_uninitialized if XNNPACK is not initialized, otherwise
// xnn_status_success.
static enum xnn_status setup_unary_elementwise_nc(
    xnn_operator_t unary_elementwise_op,
    size_t batch_size,
    const void* input,
    void* output,
    uint32_t log2_element_size,
    const void* params,
    size_t params_size)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(unary_elementwise_op->type));
    return xnn_status_uninitialized;
  }

  if (batch_size == 0) {
    // Empty batch: nothing to compute.
    unary_elementwise_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  const size_t num_channels = unary_elementwise_op->channels;
  const size_t x_stride = unary_elementwise_op->input_pixel_stride;
  const size_t y_stride = unary_elementwise_op->output_pixel_stride;
  xnn_univector_ukernel_function vunary = unary_elementwise_op->ukernel.vunary.function;

  // Both layouts are dispatched as a tiled 1D parallel loop.
  unary_elementwise_op->compute.type = xnn_parallelization_type_1d_tile_1d;

  if ((x_stride != num_channels || y_stride != num_channels) && batch_size != 1) {
    // Strided layout: one task per batch element.
    unary_elementwise_op->context.univector_strided = (struct univector_strided_context) {
      .n = num_channels << log2_element_size,
      .x = input,
      .x_stride = x_stride << log2_element_size,
      .y = output,
      .y_stride = y_stride << log2_element_size,
      .ukernel = vunary,
    };
    if (params_size != 0) {
      memcpy(&unary_elementwise_op->context.univector_strided.params, params, params_size);
    }
    unary_elementwise_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_univector_strided;
    unary_elementwise_op->compute.range[0] = batch_size;
    unary_elementwise_op->compute.tile[0] = 1;
  } else {
    // Dense layout (or a single batch element): treat the data as one byte
    // range split into fixed-size tiles.
    const size_t tile_bytes = 4096;
    unary_elementwise_op->context.univector_contiguous = (struct univector_contiguous_context) {
      .x = input,
      .x_stride = x_stride << log2_element_size,
      .y = output,
      .y_stride = y_stride << log2_element_size,
      .ukernel = vunary,
    };
    if (params_size != 0) {
      memcpy(&unary_elementwise_op->context.univector_contiguous.params, params, params_size);
    }
    unary_elementwise_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_univector_contiguous;
    unary_elementwise_op->compute.range[0] = (batch_size * num_channels) << log2_element_size;
    unary_elementwise_op->compute.tile[0] = tile_bytes;
  }

  unary_elementwise_op->state = xnn_run_state_ready;
  return xnn_status_success;
}
149 
// Creates a Clamp operator for uint8_t data.
//
// `output_min` must be strictly below `output_max`; otherwise
// xnn_status_invalid_parameter is returned. The remaining arguments are
// forwarded to create_unary_elementwise_nc, whose status is returned.
enum xnn_status xnn_create_clamp_nc_u8(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_operator_t* clamp_op_out)
{
  if (output_min < output_max) {
    const union xnn_u8_minmax_params minmax_params = xnn_init_u8_minmax_params(output_min, output_max);
    return create_unary_elementwise_nc(
      channels, input_stride, output_stride, flags,
      &minmax_params, sizeof(minmax_params),
      xnn_operator_type_clamp_nc_u8, xnn_params.u8.clamp, clamp_op_out);
  }

  // Empty or inverted range.
  xnn_log_error(
    "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: range min must be below range max",
    xnn_operator_type_to_string(xnn_operator_type_clamp_nc_u8), output_min, output_max);
  return xnn_status_invalid_parameter;
}
174 
// Creates a Clamp operator for float data.
//
// Both bounds must be non-NaN and `output_min` must be strictly below
// `output_max`; otherwise xnn_status_invalid_parameter is returned. For the
// [0, +inf] range the dedicated ReLU micro-kernel is used when one is
// available. The remaining arguments are forwarded to
// create_unary_elementwise_nc, whose status is returned.
enum xnn_status xnn_create_clamp_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* clamp_op_out)
{
  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_clamp_nc_f32));
    return xnn_status_invalid_parameter;
  }
  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_clamp_nc_f32));
    return xnn_status_invalid_parameter;
  }
  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_clamp_nc_f32), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  // Default to the generic clamp kernel; switch to ReLU only for the exact
  // [0, +inf] range and only if a ReLU kernel is present.
  xnn_univector_ukernel_function selected_ukernel = xnn_params.f32.clamp;
  if (output_max == INFINITY && output_min == 0.0f && xnn_params.f32.relu != NULL) {
    selected_ukernel = xnn_params.f32.relu;
  }

  const union xnn_f32_minmax_params minmax_params = xnn_init_f32_minmax_params(output_min, output_max);
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &minmax_params, sizeof(minmax_params),
    xnn_operator_type_clamp_nc_f32, selected_ukernel, clamp_op_out);
}
217 
// Creates an Abs operator for float data by forwarding to
// create_unary_elementwise_nc with freshly initialized abs parameters and the
// f32 abs micro-kernel. Returns the status of that call.
enum xnn_status xnn_create_abs_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* abs_op_out)
{
  const union xnn_f32_abs_params abs_params = xnn_init_f32_abs_params();
  const enum xnn_status status = create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &abs_params, sizeof(abs_params),
    xnn_operator_type_abs_nc_f32, xnn_params.f32.abs, abs_op_out);
  return status;
}
233 
// Creates a Bankers' Rounding operator for float data by forwarding to
// create_unary_elementwise_nc with freshly initialized rounding parameters
// and the f32 rndne micro-kernel. Returns the status of that call.
enum xnn_status xnn_create_bankers_rounding_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* rounding_op_out)
{
  const union xnn_f32_rnd_params rnd_params = xnn_init_f32_rnd_params();
  const enum xnn_status status = create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &rnd_params, sizeof(rnd_params),
    xnn_operator_type_bankers_rounding_nc_f32, xnn_params.f32.rndne, rounding_op_out);
  return status;
}
249 
// Creates a Ceiling operator for float data by forwarding to
// create_unary_elementwise_nc with freshly initialized rounding parameters
// and the f32 rndu micro-kernel. Returns the status of that call.
enum xnn_status xnn_create_ceiling_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* ceiling_op_out)
{
  const union xnn_f32_rnd_params rnd_params = xnn_init_f32_rnd_params();
  const enum xnn_status status = create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &rnd_params, sizeof(rnd_params),
    xnn_operator_type_ceiling_nc_f32, xnn_params.f32.rndu, ceiling_op_out);
  return status;
}
265 
// Creates a Copy operator for 32-bit elements. No micro-kernel parameters
// are stored; the call forwards to create_unary_elementwise_nc with the
// xx copy micro-kernel and returns its status.
enum xnn_status xnn_create_copy_nc_x32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* copy_op_out)
{
  const enum xnn_status status = create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    NULL /* params */, 0 /* params size */,
    xnn_operator_type_copy_nc_x32, xnn_params.xx.copy, copy_op_out);
  return status;
}
280 
// Creates an ELU operator for float data.
//
// `alpha` must be positive and a normal floating-point value; otherwise
// xnn_status_invalid_parameter is returned. The ELU parameters are
// initialized with a prescale and beta of 1.0f, and the remaining arguments
// are forwarded to create_unary_elementwise_nc, whose status is returned.
enum xnn_status xnn_create_elu_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    float alpha,
    uint32_t flags,
    xnn_operator_t* elu_op_out)
{
  if (alpha <= 0.0f || !isnormal(alpha)) {
    xnn_log_error(
      "failed to create %s operator with %.7g alpha parameter: alpha must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_elu_nc_f32), alpha);
    return xnn_status_invalid_parameter;
  }

  const union xnn_f32_elu_params elu_params =
    xnn_init_f32_elu_params(1.0f /* prescale */, alpha, 1.0f /* beta */);
  const enum xnn_status status = create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &elu_params, sizeof(elu_params),
    xnn_operator_type_elu_nc_f32, xnn_params.f32.elu, elu_op_out);
  return status;
}
304 
// Creates a Floor operator for float data by forwarding to
// create_unary_elementwise_nc with freshly initialized rounding parameters
// and the f32 rndd micro-kernel. Returns the status of that call.
enum xnn_status xnn_create_floor_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* floor_op_out)
{
  const union xnn_f32_rnd_params rnd_params = xnn_init_f32_rnd_params();
  const enum xnn_status status = create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &rnd_params, sizeof(rnd_params),
    xnn_operator_type_floor_nc_f32, xnn_params.f32.rndd, floor_op_out);
  return status;
}
320 
// Creates a HardSwish operator for half-precision data.
//
// Returns xnn_status_uninitialized if XNNPACK is not initialized and
// xnn_status_unsupported_hardware if the F16 init flag is not set. Otherwise
// forwards to create_unary_elementwise_nc with freshly initialized f16
// hswish parameters and returns its status.
enum xnn_status xnn_create_hardswish_nc_f16(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* hardswish_op_out)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(xnn_operator_type_hardswish_nc_f16));
    return xnn_status_uninitialized;
  }

  // Every bit of the F16 flag must be present.
  if ((xnn_params.init_flags & XNN_INIT_FLAG_F16) != XNN_INIT_FLAG_F16) {
    xnn_log_error("failed to create %s operator: operations on data type are not supported",
      xnn_operator_type_to_string(xnn_operator_type_hardswish_nc_f16));
    return xnn_status_unsupported_hardware;
  }

  const struct xnn_f16_hswish_params hswish_params = xnn_init_f16_hswish_params();
  const enum xnn_status status = create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &hswish_params, sizeof(hswish_params),
    xnn_operator_type_hardswish_nc_f16, xnn_params.f16.hswish, hardswish_op_out);
  return status;
}
348 
// Creates a HardSwish operator for float data by forwarding to
// create_unary_elementwise_nc with freshly initialized hswish parameters and
// the f32 hswish micro-kernel. Returns the status of that call.
enum xnn_status xnn_create_hardswish_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* hardswish_op_out)
{
  const union xnn_f32_hswish_params hswish_params = xnn_init_f32_hswish_params();
  const enum xnn_status status = create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &hswish_params, sizeof(hswish_params),
    xnn_operator_type_hardswish_nc_f32, xnn_params.f32.hswish, hardswish_op_out);
  return status;
}
364 
// Creates a Leaky ReLU operator for float data.
//
// `negative_slope` must be finite; otherwise xnn_status_invalid_parameter is
// returned. The remaining arguments are forwarded to
// create_unary_elementwise_nc, whose status is returned.
enum xnn_status xnn_create_leaky_relu_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    float negative_slope,
    uint32_t flags,
    xnn_operator_t* leaky_relu_op_out)
{
  if (isfinite(negative_slope)) {
    const union xnn_f32_lrelu_params lrelu_params = xnn_init_f32_lrelu_params(negative_slope);
    return create_unary_elementwise_nc(
      channels, input_stride, output_stride, flags,
      &lrelu_params, sizeof(lrelu_params),
      xnn_operator_type_leaky_relu_nc_f32, xnn_params.f32.lrelu, leaky_relu_op_out);
  }

  // Infinite or NaN slope.
  xnn_log_error(
    "failed to create %s operator with %f negative slope: finite number expected",
    xnn_operator_type_to_string(xnn_operator_type_leaky_relu_nc_f32),
    negative_slope);
  return xnn_status_invalid_parameter;
}
389 
// Creates a Negate operator for float data by forwarding to
// create_unary_elementwise_nc with freshly initialized neg parameters and the
// f32 neg micro-kernel. Returns the status of that call.
enum xnn_status xnn_create_negate_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* negate_op_out)
{
  const union xnn_f32_neg_params neg_params = xnn_init_f32_neg_params();
  const enum xnn_status status = create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &neg_params, sizeof(neg_params),
    xnn_operator_type_negate_nc_f32, xnn_params.f32.neg, negate_op_out);
  return status;
}
405 
// Creates a Sigmoid operator for float data. No micro-kernel parameters are
// stored; the call forwards to create_unary_elementwise_nc with the f32
// sigmoid micro-kernel and returns its status.
enum xnn_status xnn_create_sigmoid_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* sigmoid_op_out)
{
  const enum xnn_status status = create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    NULL /* params */, 0 /* params size */,
    xnn_operator_type_sigmoid_nc_f32, xnn_params.f32.sigmoid, sigmoid_op_out);
  return status;
}
420 
// Creates a Square operator for float data. No micro-kernel parameters are
// stored; the call forwards to create_unary_elementwise_nc with the f32 sqr
// micro-kernel and returns its status.
enum xnn_status xnn_create_square_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* square_op_out)
{
  const enum xnn_status status = create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    NULL /* params */, 0 /* params size */,
    xnn_operator_type_square_nc_f32, xnn_params.f32.sqr, square_op_out);
  return status;
}
435 
// Creates a Square Root operator for float data by forwarding to
// create_unary_elementwise_nc with freshly initialized sqrt parameters and
// the f32 sqrt micro-kernel. Returns the status of that call.
enum xnn_status xnn_create_square_root_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* sqrt_op_out)
{
  const union xnn_f32_sqrt_params sqrt_params = xnn_init_f32_sqrt_params();
  const enum xnn_status status = create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &sqrt_params, sizeof(sqrt_params),
    xnn_operator_type_square_root_nc_f32, xnn_params.f32.sqrt, sqrt_op_out);
  return status;
}
451 
// Creates a Truncation operator for float data by forwarding to
// create_unary_elementwise_nc with freshly initialized rounding parameters
// and the f32 rndz micro-kernel. Returns the status of that call.
enum xnn_status xnn_create_truncation_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* truncation_op_out)
{
  const union xnn_f32_rnd_params rnd_params = xnn_init_f32_rnd_params();
  const enum xnn_status status = create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &rnd_params, sizeof(rnd_params),
    xnn_operator_type_truncation_nc_f32, xnn_params.f32.rndz, truncation_op_out);
  return status;
}
467 
// Sets up an Abs (f32) operator for the given batch size and buffers.
//
// Returns xnn_status_invalid_parameter if `abs_op` is not an Abs f32
// operator. Otherwise marks the operator invalid and returns the result of
// setup_unary_elementwise_nc. `threadpool` is not used here.
enum xnn_status xnn_setup_abs_nc_f32(
    xnn_operator_t abs_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (abs_op->type == xnn_operator_type_abs_nc_f32) {
    abs_op->state = xnn_run_state_invalid;
    return setup_unary_elementwise_nc(
      abs_op, batch_size, input, output,
      2 /* log2(sizeof(float)) */,
      &abs_op->params.f32_abs, sizeof(abs_op->params.f32_abs));
  }

  xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
    xnn_operator_type_to_string(xnn_operator_type_abs_nc_f32),
    xnn_operator_type_to_string(abs_op->type));
  return xnn_status_invalid_parameter;
}
489 
// Sets up a Bankers' Rounding (f32) operator for the given batch size and
// buffers.
//
// Returns xnn_status_invalid_parameter if `rounding_op` has a different
// operator type. Otherwise marks the operator invalid and returns the result
// of setup_unary_elementwise_nc. `threadpool` is not used here.
enum xnn_status xnn_setup_bankers_rounding_nc_f32(
    xnn_operator_t rounding_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (rounding_op->type == xnn_operator_type_bankers_rounding_nc_f32) {
    rounding_op->state = xnn_run_state_invalid;
    return setup_unary_elementwise_nc(
      rounding_op, batch_size, input, output,
      2 /* log2(sizeof(float)) */,
      &rounding_op->params.f32_rnd, sizeof(rounding_op->params.f32_rnd));
  }

  xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
    xnn_operator_type_to_string(xnn_operator_type_bankers_rounding_nc_f32),
    xnn_operator_type_to_string(rounding_op->type));
  return xnn_status_invalid_parameter;
}
511 
// Sets up a Ceiling (f32) operator for the given batch size and buffers.
//
// Returns xnn_status_invalid_parameter if `ceiling_op` has a different
// operator type. Otherwise marks the operator invalid and returns the result
// of setup_unary_elementwise_nc. `threadpool` is not used here.
enum xnn_status xnn_setup_ceiling_nc_f32(
    xnn_operator_t ceiling_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (ceiling_op->type == xnn_operator_type_ceiling_nc_f32) {
    ceiling_op->state = xnn_run_state_invalid;
    return setup_unary_elementwise_nc(
      ceiling_op, batch_size, input, output,
      2 /* log2(sizeof(float)) */,
      &ceiling_op->params.f32_rnd, sizeof(ceiling_op->params.f32_rnd));
  }

  xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
    xnn_operator_type_to_string(xnn_operator_type_ceiling_nc_f32),
    xnn_operator_type_to_string(ceiling_op->type));
  return xnn_status_invalid_parameter;
}
533 
// Sets up a Clamp (u8) operator for the given batch size and buffers.
//
// Returns xnn_status_invalid_parameter if `clamp_op` has a different
// operator type. Otherwise marks the operator invalid and returns the result
// of setup_unary_elementwise_nc. `threadpool` is not used here.
enum xnn_status xnn_setup_clamp_nc_u8(
    xnn_operator_t clamp_op,
    size_t batch_size,
    const uint8_t* input,
    uint8_t* output,
    pthreadpool_t threadpool)
{
  if (clamp_op->type == xnn_operator_type_clamp_nc_u8) {
    clamp_op->state = xnn_run_state_invalid;
    return setup_unary_elementwise_nc(
      clamp_op, batch_size, input, output,
      0 /* log2(sizeof(uint8_t)) */,
      &clamp_op->params.u8_minmax, sizeof(clamp_op->params.u8_minmax));
  }

  xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
    xnn_operator_type_to_string(xnn_operator_type_clamp_nc_u8),
    xnn_operator_type_to_string(clamp_op->type));
  return xnn_status_invalid_parameter;
}
555 
// Sets up a Clamp (f32) operator for the given batch size and buffers.
//
// Returns xnn_status_invalid_parameter if `clamp_op` has a different
// operator type. Otherwise marks the operator invalid and returns the result
// of setup_unary_elementwise_nc. `threadpool` is not used here.
enum xnn_status xnn_setup_clamp_nc_f32(
    xnn_operator_t clamp_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (clamp_op->type == xnn_operator_type_clamp_nc_f32) {
    clamp_op->state = xnn_run_state_invalid;
    return setup_unary_elementwise_nc(
      clamp_op, batch_size, input, output,
      2 /* log2(sizeof(float)) */,
      &clamp_op->params.f32_minmax, sizeof(clamp_op->params.f32_minmax));
  }

  xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
    xnn_operator_type_to_string(xnn_operator_type_clamp_nc_f32),
    xnn_operator_type_to_string(clamp_op->type));
  return xnn_status_invalid_parameter;
}
577 
// Sets up a Copy (x32) operator for the given batch size and buffers.
//
// Returns xnn_status_invalid_parameter if `copy_op` has a different operator
// type. Otherwise marks the operator invalid and returns the result of
// setup_unary_elementwise_nc, called without micro-kernel parameters.
// `threadpool` is not used here.
enum xnn_status xnn_setup_copy_nc_x32(
    xnn_operator_t copy_op,
    size_t batch_size,
    const void* input,
    void* output,
    pthreadpool_t threadpool)
{
  if (copy_op->type == xnn_operator_type_copy_nc_x32) {
    copy_op->state = xnn_run_state_invalid;
    return setup_unary_elementwise_nc(
      copy_op, batch_size, input, output,
      2 /* log2(sizeof(uint32_t)) */,
      NULL /* params */, 0 /* params size */);
  }

  xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
    xnn_operator_type_to_string(xnn_operator_type_copy_nc_x32),
    xnn_operator_type_to_string(copy_op->type));
  return xnn_status_invalid_parameter;
}
599 
// Sets up an ELU (f32) operator for the given batch size and buffers.
//
// Returns xnn_status_invalid_parameter if `elu_op` has a different operator
// type. Otherwise marks the operator invalid and returns the result of
// setup_unary_elementwise_nc. `threadpool` is not used here.
enum xnn_status xnn_setup_elu_nc_f32(
    xnn_operator_t elu_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (elu_op->type == xnn_operator_type_elu_nc_f32) {
    elu_op->state = xnn_run_state_invalid;
    return setup_unary_elementwise_nc(
      elu_op, batch_size, input, output,
      2 /* log2(sizeof(float)) */,
      &elu_op->params.f32_elu, sizeof(elu_op->params.f32_elu));
  }

  xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
    xnn_operator_type_to_string(xnn_operator_type_elu_nc_f32),
    xnn_operator_type_to_string(elu_op->type));
  return xnn_status_invalid_parameter;
}
621 
// Sets up a Floor (f32) operator for the given batch size and buffers.
//
// Returns xnn_status_invalid_parameter if `floor_op` has a different
// operator type. Otherwise marks the operator invalid and returns the result
// of setup_unary_elementwise_nc. `threadpool` is not used here.
enum xnn_status xnn_setup_floor_nc_f32(
    xnn_operator_t floor_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (floor_op->type == xnn_operator_type_floor_nc_f32) {
    floor_op->state = xnn_run_state_invalid;
    return setup_unary_elementwise_nc(
      floor_op, batch_size, input, output,
      2 /* log2(sizeof(float)) */,
      &floor_op->params.f32_rnd, sizeof(floor_op->params.f32_rnd));
  }

  xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
    xnn_operator_type_to_string(xnn_operator_type_floor_nc_f32),
    xnn_operator_type_to_string(floor_op->type));
  return xnn_status_invalid_parameter;
}
643 
// Sets up a HardSwish (f16) operator for the given batch size and buffers.
//
// Returns xnn_status_invalid_parameter if `hardswish_op` has a different
// operator type. Otherwise marks the operator invalid and returns the result
// of setup_unary_elementwise_nc. `threadpool` is not used here.
enum xnn_status xnn_setup_hardswish_nc_f16(
    xnn_operator_t hardswish_op,
    size_t batch_size,
    const void* input,
    void* output,
    pthreadpool_t threadpool)
{
  if (hardswish_op->type == xnn_operator_type_hardswish_nc_f16) {
    hardswish_op->state = xnn_run_state_invalid;
    return setup_unary_elementwise_nc(
      hardswish_op, batch_size, input, output,
      1 /* log2(sizeof(half)) */,
      &hardswish_op->params.f16_hswish, sizeof(hardswish_op->params.f16_hswish));
  }

  xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
    xnn_operator_type_to_string(xnn_operator_type_hardswish_nc_f16),
    xnn_operator_type_to_string(hardswish_op->type));
  return xnn_status_invalid_parameter;
}
665 
// Sets up a HardSwish (f32) operator for the given batch size and buffers.
//
// Returns xnn_status_invalid_parameter if `hardswish_op` has a different
// operator type. Otherwise marks the operator invalid and returns the result
// of setup_unary_elementwise_nc. `threadpool` is not used here.
enum xnn_status xnn_setup_hardswish_nc_f32(
    xnn_operator_t hardswish_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (hardswish_op->type == xnn_operator_type_hardswish_nc_f32) {
    hardswish_op->state = xnn_run_state_invalid;
    return setup_unary_elementwise_nc(
      hardswish_op, batch_size, input, output,
      2 /* log2(sizeof(float)) */,
      &hardswish_op->params.f32_hswish, sizeof(hardswish_op->params.f32_hswish));
  }

  xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
    xnn_operator_type_to_string(xnn_operator_type_hardswish_nc_f32),
    xnn_operator_type_to_string(hardswish_op->type));
  return xnn_status_invalid_parameter;
}
687 
// Sets up a Leaky ReLU (f32) operator for the given batch size and buffers.
//
// Returns xnn_status_invalid_parameter if `leaky_relu_op` has a different
// operator type. Otherwise marks the operator invalid and returns the result
// of setup_unary_elementwise_nc. `threadpool` is not used here.
enum xnn_status xnn_setup_leaky_relu_nc_f32(
    xnn_operator_t leaky_relu_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (leaky_relu_op->type == xnn_operator_type_leaky_relu_nc_f32) {
    leaky_relu_op->state = xnn_run_state_invalid;
    return setup_unary_elementwise_nc(
      leaky_relu_op, batch_size, input, output,
      2 /* log2(sizeof(float)) */,
      &leaky_relu_op->params.f32_lrelu, sizeof(leaky_relu_op->params.f32_lrelu));
  }

  xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
    xnn_operator_type_to_string(xnn_operator_type_leaky_relu_nc_f32),
    xnn_operator_type_to_string(leaky_relu_op->type));
  return xnn_status_invalid_parameter;
}
709 
// Sets up a Negate (f32) operator for the given batch size and buffers.
//
// Returns xnn_status_invalid_parameter if `negate_op` has a different
// operator type. Otherwise marks the operator invalid and returns the result
// of setup_unary_elementwise_nc. `threadpool` is not used here.
enum xnn_status xnn_setup_negate_nc_f32(
    xnn_operator_t negate_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (negate_op->type == xnn_operator_type_negate_nc_f32) {
    negate_op->state = xnn_run_state_invalid;
    return setup_unary_elementwise_nc(
      negate_op, batch_size, input, output,
      2 /* log2(sizeof(float)) */,
      &negate_op->params.f32_neg, sizeof(negate_op->params.f32_neg));
  }

  xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
    xnn_operator_type_to_string(xnn_operator_type_negate_nc_f32),
    xnn_operator_type_to_string(negate_op->type));
  return xnn_status_invalid_parameter;
}
731 
// Sets up a Sigmoid (f32) operator for the given batch size and buffers.
//
// Returns xnn_status_invalid_parameter if `sigmoid_op` has a different
// operator type. Otherwise marks the operator invalid and returns the result
// of setup_unary_elementwise_nc, called without micro-kernel parameters.
// `threadpool` is not used here.
enum xnn_status xnn_setup_sigmoid_nc_f32(
    xnn_operator_t sigmoid_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (sigmoid_op->type == xnn_operator_type_sigmoid_nc_f32) {
    sigmoid_op->state = xnn_run_state_invalid;
    return setup_unary_elementwise_nc(
      sigmoid_op, batch_size, input, output,
      2 /* log2(sizeof(float)) */,
      NULL /* params */, 0 /* params size */);
  }

  xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
    xnn_operator_type_to_string(xnn_operator_type_sigmoid_nc_f32),
    xnn_operator_type_to_string(sigmoid_op->type));
  return xnn_status_invalid_parameter;
}
753 
// Sets up a Square (f32) operator for the given batch size and buffers.
//
// Returns xnn_status_invalid_parameter if `square_op` has a different
// operator type. Otherwise marks the operator invalid and returns the result
// of setup_unary_elementwise_nc, called without micro-kernel parameters.
// `threadpool` is not used here.
enum xnn_status xnn_setup_square_nc_f32(
    xnn_operator_t square_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (square_op->type == xnn_operator_type_square_nc_f32) {
    square_op->state = xnn_run_state_invalid;
    return setup_unary_elementwise_nc(
      square_op, batch_size, input, output,
      2 /* log2(sizeof(float)) */,
      NULL /* params */, 0 /* params size */);
  }

  xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
    xnn_operator_type_to_string(xnn_operator_type_square_nc_f32),
    xnn_operator_type_to_string(square_op->type));
  return xnn_status_invalid_parameter;
}
775 
// Sets up a Square Root (f32) operator for the given batch size and buffers.
//
// Returns xnn_status_invalid_parameter if `sqrt_op` has a different operator
// type. Otherwise marks the operator invalid and returns the result of
// setup_unary_elementwise_nc. `threadpool` is not used here.
enum xnn_status xnn_setup_square_root_nc_f32(
    xnn_operator_t sqrt_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (sqrt_op->type != xnn_operator_type_square_root_nc_f32) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_square_root_nc_f32),
      xnn_operator_type_to_string(sqrt_op->type));
    return xnn_status_invalid_parameter;
  }
  sqrt_op->state = xnn_run_state_invalid;

  // xnn_create_square_root_nc_f32 stores initialized sqrt parameters at the
  // start of the operator's params storage. Forward those bytes so the
  // micro-kernel context does not keep zeroed parameters.
  return setup_unary_elementwise_nc(
    sqrt_op,
    batch_size, input, output,
    2 /* log2(sizeof(float)) */,
    &sqrt_op->params, sizeof(union xnn_f32_sqrt_params));
}
797 
// Sets up a Truncation (f32) operator for the given batch size and buffers.
//
// Returns xnn_status_invalid_parameter if `truncation_op` has a different
// operator type. Otherwise marks the operator invalid and returns the result
// of setup_unary_elementwise_nc. `threadpool` is not used here.
enum xnn_status xnn_setup_truncation_nc_f32(
    xnn_operator_t truncation_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (truncation_op->type == xnn_operator_type_truncation_nc_f32) {
    truncation_op->state = xnn_run_state_invalid;
    return setup_unary_elementwise_nc(
      truncation_op, batch_size, input, output,
      2 /* log2(sizeof(float)) */,
      &truncation_op->params.f32_rnd, sizeof(truncation_op->params.f32_rnd));
  }

  xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
    xnn_operator_type_to_string(xnn_operator_type_truncation_nc_f32),
    xnn_operator_type_to_string(truncation_op->type));
  return xnn_status_invalid_parameter;
}