1// Copyright 2019 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$assert CHANNEL_TILE % 4 == 0
7$assert KERNEL_TILE >= 2
8$assert ACCUMULATORS >= 1
9$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
10#include <assert.h>
11
12#include <xmmintrin.h>
13
14#include <xnnpack/dwconv.h>
15
16
17void xnn_f32_dwconv_minmax_ukernel_up${CHANNEL_TILE}x${KERNEL_TILE}__sse${"" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS}(
18    size_t channels,
19    size_t output_width,
20    const float** input,
21    const float* weights,
22    float* output,
23    size_t input_stride,
24    size_t output_increment,
25    size_t input_offset,
26    const float* zero,
27    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
28{
29  assert(channels != 0);
30  assert(output_width != 0);
31
32  const __m128 vmax = _mm_load_ps(params->sse.max);
33  const __m128 vmin = _mm_load_ps(params->sse.min);
34  do {
35    $for K in range(KERNEL_TILE):
36      const float* i${K} = input[${K}];
37      assert(i${K} != NULL);
38      if XNN_UNPREDICTABLE(i${K} != zero) {
39        i${K} = (const float*) ((uintptr_t) i${K} + input_offset);
40      }
41    input = (const float**) ((uintptr_t) input + input_stride);
42
43    size_t c = channels;
44    const float* w = weights;
45    for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) {
46      __m128 vacc${ABC[0:4]}p0 = _mm_load_ps(w);
47      $for C in range(4, CHANNEL_TILE, 4):
48        __m128 vacc${ABC[C:C+4]}p0 = _mm_load_ps(w + ${C});
49
50      $for K in range(KERNEL_TILE):
51
52        const __m128 vi${K}x${ABC[0:4]} = _mm_loadu_ps(i${K});
53        $for C in range(4, CHANNEL_TILE, 4):
54          const __m128 vi${K}x${ABC[C:C+4]} = _mm_loadu_ps(i${K} + ${C});
55        i${K} += ${CHANNEL_TILE};
56
57        $for C in range(0, CHANNEL_TILE, 4):
58          const __m128 vk${K}x${ABC[C:C+4]} = _mm_load_ps(w + ${(K + 1) * CHANNEL_TILE + C});
59        $for C in range(0, CHANNEL_TILE, 4):
60          $if 1 <= K < ACCUMULATORS:
61            __m128 vacc${ABC[C:C+4]}p${K} = _mm_mul_ps(vi${K}x${ABC[C:C+4]}, vk${K}x${ABC[C:C+4]});
62          $else:
63            vacc${ABC[C:C+4]}p${K % ACCUMULATORS} = _mm_add_ps(vacc${ABC[C:C+4]}p${K % ACCUMULATORS}, _mm_mul_ps(vi${K}x${ABC[C:C+4]}, vk${K}x${ABC[C:C+4]}));
64
65      w += ${(KERNEL_TILE + 1) * CHANNEL_TILE};
66
67      $if ACCUMULATORS > 1:
68        // Add up all accumulators to vacc${ABC[0:CHANNEL_TILE]}p0
69        $ACC_SLICE = 1
70        $while ACC_SLICE < ACCUMULATORS:
71          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
72            $if A + ACC_SLICE < ACCUMULATORS:
73              $for C in range(0, CHANNEL_TILE, 4):
74                vacc${ABC[C:C+4]}p${A} = _mm_add_ps(vacc${ABC[C:C+4]}p${A}, vacc${ABC[C:C+4]}p${A + ACC_SLICE});
75          $ACC_SLICE *= 2
76
77      $for C in range(0, CHANNEL_TILE, 4):
78        __m128 vacc${ABC[C:C+4]} = _mm_max_ps(vacc${ABC[C:C+4]}p0, vmin);
79      $for C in range(0, CHANNEL_TILE, 4):
80        vacc${ABC[C:C+4]} = _mm_min_ps(vacc${ABC[C:C+4]}, vmax);
81
82      _mm_storeu_ps(output, vacc${ABC[0:4]});
83      $for C in range(4, CHANNEL_TILE, 4):
84        _mm_storeu_ps(output + ${C}, vacc${ABC[C:C+4]});
85      output += ${CHANNEL_TILE};
86    }
87    $if CHANNEL_TILE > 4:
88      for (; c >= 4; c -= 4) {
89        __m128 vacc0123p0 = _mm_load_ps(w);
90        $for K in range(KERNEL_TILE):
91
92          const __m128 vi${K}x0123 = _mm_loadu_ps(i${K});
93          i${K} += 4;
94
95          const __m128 vk${K}x0123 = _mm_load_ps(w + ${(K + 1) * CHANNEL_TILE});
96          $if 1 <= K < ACCUMULATORS:
97            __m128 vacc0123p${K} = _mm_mul_ps(vi${K}x0123, vk${K}x0123);
98          $else:
99            vacc0123p${K % ACCUMULATORS} = _mm_add_ps(vacc0123p${K % ACCUMULATORS}, _mm_mul_ps(vi${K}x0123, vk${K}x0123));
100
101        w += 4;
102
103        $if ACCUMULATORS > 1:
104          // Add up all accumulators to vacc${ABC[0:CHANNEL_TILE]}p0
105          $ACC_SLICE = 1
106          $while ACC_SLICE < ACCUMULATORS:
107            $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
108              $if A + ACC_SLICE < ACCUMULATORS:
109                vacc0123p${A} = _mm_add_ps(vacc0123p${A}, vacc0123p${A + ACC_SLICE});
110            $ACC_SLICE *= 2
111
112        __m128 vacc0123 = _mm_max_ps(vacc0123p0, vmin);
113        vacc0123 = _mm_min_ps(vacc0123, vmax);
114
115        _mm_storeu_ps(output, vacc0123);
116        output += 4;
117      }
118    if XNN_UNLIKELY(c != 0) {
119      __m128 vacc0123p0 = _mm_load_ps(w);
120      $for K in range(KERNEL_TILE):
121
122        const __m128 vi${K}x0123 = _mm_loadu_ps(i${K});
123        const __m128 vk${K}x0123 = _mm_load_ps(w + ${(K + 1) * CHANNEL_TILE});
124        $if 1 <= K < ACCUMULATORS:
125          __m128 vacc0123p${K} = _mm_mul_ps(vi${K}x0123, vk${K}x0123);
126        $else:
127          vacc0123p${K % ACCUMULATORS} = _mm_add_ps(vacc0123p${K % ACCUMULATORS}, _mm_mul_ps(vi${K}x0123, vk${K}x0123));
128
129      $if ACCUMULATORS > 1:
130        // Add up all accumulators to vacc${ABC[0:CHANNEL_TILE]}p0
131        $ACC_SLICE = 1
132        $while ACC_SLICE < ACCUMULATORS:
133          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
134            $if A + ACC_SLICE < ACCUMULATORS:
135              vacc0123p${A} = _mm_add_ps(vacc0123p${A}, vacc0123p${A + ACC_SLICE});
136          $ACC_SLICE *= 2
137
138      __m128 vacc0123 = _mm_max_ps(vacc0123p0, vmin);
139      vacc0123 = _mm_min_ps(vacc0123, vmax);
140
141      if (c & 2) {
142        _mm_storel_pi((__m64*) output, vacc0123);
143        vacc0123 = _mm_movehl_ps(vacc0123, vacc0123);
144        output += 2;
145      }
146      if (c & 1) {
147        _mm_store_ss(output, vacc0123);
148        output += 1;
149      }
150    }
151
152    output = (float*) ((uintptr_t) output + output_increment);
153  } while (--output_width != 0);
154}
155