1 /*
2  * Copyright (c) 2008-2020 Stefan Krah. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  *
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  *
11  * 2. Redistributions in binary form must reproduce the above copyright
12  *    notice, this list of conditions and the following disclaimer in the
13  *    documentation and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS "AS IS" AND
16  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
19  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
21  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
22  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
23  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
24  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
25  * SUCH DAMAGE.
26  */
27 
28 
29 #include "mpdecimal.h"
30 
31 #include <assert.h>
32 
33 #include "constants.h"
34 #include "fourstep.h"
35 #include "numbertheory.h"
36 #include "sixstep.h"
37 #include "umodarith.h"
38 
39 
40 /* Bignum: Cache efficient Matrix Fourier Transform for arrays of the
41    form 3 * 2**n (See literature/matrix-transform.txt). */
42 
43 
44 #ifndef PPRO
45 static inline void
std_size3_ntt(mpd_uint_t * x1,mpd_uint_t * x2,mpd_uint_t * x3,mpd_uint_t w3table[3],mpd_uint_t umod)46 std_size3_ntt(mpd_uint_t *x1, mpd_uint_t *x2, mpd_uint_t *x3,
47               mpd_uint_t w3table[3], mpd_uint_t umod)
48 {
49     mpd_uint_t r1, r2;
50     mpd_uint_t w;
51     mpd_uint_t s, tmp;
52 
53 
54     /* k = 0 -> w = 1 */
55     s = *x1;
56     s = addmod(s, *x2, umod);
57     s = addmod(s, *x3, umod);
58 
59     r1 = s;
60 
61     /* k = 1 */
62     s = *x1;
63 
64     w = w3table[1];
65     tmp = MULMOD(*x2, w);
66     s = addmod(s, tmp, umod);
67 
68     w = w3table[2];
69     tmp = MULMOD(*x3, w);
70     s = addmod(s, tmp, umod);
71 
72     r2 = s;
73 
74     /* k = 2 */
75     s = *x1;
76 
77     w = w3table[2];
78     tmp = MULMOD(*x2, w);
79     s = addmod(s, tmp, umod);
80 
81     w = w3table[1];
82     tmp = MULMOD(*x3, w);
83     s = addmod(s, tmp, umod);
84 
85     *x3 = s;
86     *x2 = r2;
87     *x1 = r1;
88 }
89 #else /* PPRO */
90 static inline void
ppro_size3_ntt(mpd_uint_t * x1,mpd_uint_t * x2,mpd_uint_t * x3,mpd_uint_t w3table[3],mpd_uint_t umod,double * dmod,uint32_t dinvmod[3])91 ppro_size3_ntt(mpd_uint_t *x1, mpd_uint_t *x2, mpd_uint_t *x3, mpd_uint_t w3table[3],
92                mpd_uint_t umod, double *dmod, uint32_t dinvmod[3])
93 {
94     mpd_uint_t r1, r2;
95     mpd_uint_t w;
96     mpd_uint_t s, tmp;
97 
98 
99     /* k = 0 -> w = 1 */
100     s = *x1;
101     s = addmod(s, *x2, umod);
102     s = addmod(s, *x3, umod);
103 
104     r1 = s;
105 
106     /* k = 1 */
107     s = *x1;
108 
109     w = w3table[1];
110     tmp = ppro_mulmod(*x2, w, dmod, dinvmod);
111     s = addmod(s, tmp, umod);
112 
113     w = w3table[2];
114     tmp = ppro_mulmod(*x3, w, dmod, dinvmod);
115     s = addmod(s, tmp, umod);
116 
117     r2 = s;
118 
119     /* k = 2 */
120     s = *x1;
121 
122     w = w3table[2];
123     tmp = ppro_mulmod(*x2, w, dmod, dinvmod);
124     s = addmod(s, tmp, umod);
125 
126     w = w3table[1];
127     tmp = ppro_mulmod(*x3, w, dmod, dinvmod);
128     s = addmod(s, tmp, umod);
129 
130     *x3 = s;
131     *x2 = r2;
132     *x1 = r1;
133 }
134 #endif
135 
136 
137 /* forward transform, sign = -1; transform length = 3 * 2**n */
138 int
four_step_fnt(mpd_uint_t * a,mpd_size_t n,int modnum)139 four_step_fnt(mpd_uint_t *a, mpd_size_t n, int modnum)
140 {
141     mpd_size_t R = 3; /* number of rows */
142     mpd_size_t C = n / 3; /* number of columns */
143     mpd_uint_t w3table[3];
144     mpd_uint_t kernel, w0, w1, wstep;
145     mpd_uint_t *s, *p0, *p1, *p2;
146     mpd_uint_t umod;
147 #ifdef PPRO
148     double dmod;
149     uint32_t dinvmod[3];
150 #endif
151     mpd_size_t i, k;
152 
153 
154     assert(n >= 48);
155     assert(n <= 3*MPD_MAXTRANSFORM_2N);
156 
157 
158     /* Length R transform on the columns. */
159     SETMODULUS(modnum);
160     _mpd_init_w3table(w3table, -1, modnum);
161     for (p0=a, p1=p0+C, p2=p0+2*C; p0<a+C; p0++,p1++,p2++) {
162 
163         SIZE3_NTT(p0, p1, p2, w3table);
164     }
165 
166     /* Multiply each matrix element (addressed by i*C+k) by r**(i*k). */
167     kernel = _mpd_getkernel(n, -1, modnum);
168     for (i = 1; i < R; i++) {
169         w0 = 1;                  /* r**(i*0): initial value for k=0 */
170         w1 = POWMOD(kernel, i);  /* r**(i*1): initial value for k=1 */
171         wstep = MULMOD(w1, w1);  /* r**(2*i) */
172         for (k = 0; k < C-1; k += 2) {
173             mpd_uint_t x0 = a[i*C+k];
174             mpd_uint_t x1 = a[i*C+k+1];
175             MULMOD2(&x0, w0, &x1, w1);
176             MULMOD2C(&w0, &w1, wstep);  /* r**(i*(k+2)) = r**(i*k) * r**(2*i) */
177             a[i*C+k] = x0;
178             a[i*C+k+1] = x1;
179         }
180     }
181 
182     /* Length C transform on the rows. */
183     for (s = a; s < a+n; s += C) {
184         if (!six_step_fnt(s, C, modnum)) {
185             return 0;
186         }
187     }
188 
189 #if 0
190     /* An unordered transform is sufficient for convolution. */
191     /* Transpose the matrix. */
192     #include "transpose.h"
193     transpose_3xpow2(a, R, C);
194 #endif
195 
196     return 1;
197 }
198 
199 /* backward transform, sign = 1; transform length = 3 * 2**n */
200 int
inv_four_step_fnt(mpd_uint_t * a,mpd_size_t n,int modnum)201 inv_four_step_fnt(mpd_uint_t *a, mpd_size_t n, int modnum)
202 {
203     mpd_size_t R = 3; /* number of rows */
204     mpd_size_t C = n / 3; /* number of columns */
205     mpd_uint_t w3table[3];
206     mpd_uint_t kernel, w0, w1, wstep;
207     mpd_uint_t *s, *p0, *p1, *p2;
208     mpd_uint_t umod;
209 #ifdef PPRO
210     double dmod;
211     uint32_t dinvmod[3];
212 #endif
213     mpd_size_t i, k;
214 
215 
216     assert(n >= 48);
217     assert(n <= 3*MPD_MAXTRANSFORM_2N);
218 
219 
220 #if 0
221     /* An unordered transform is sufficient for convolution. */
222     /* Transpose the matrix, producing an R*C matrix. */
223     #include "transpose.h"
224     transpose_3xpow2(a, C, R);
225 #endif
226 
227     /* Length C transform on the rows. */
228     for (s = a; s < a+n; s += C) {
229         if (!inv_six_step_fnt(s, C, modnum)) {
230             return 0;
231         }
232     }
233 
234     /* Multiply each matrix element (addressed by i*C+k) by r**(i*k). */
235     SETMODULUS(modnum);
236     kernel = _mpd_getkernel(n, 1, modnum);
237     for (i = 1; i < R; i++) {
238         w0 = 1;
239         w1 = POWMOD(kernel, i);
240         wstep = MULMOD(w1, w1);
241         for (k = 0; k < C; k += 2) {
242             mpd_uint_t x0 = a[i*C+k];
243             mpd_uint_t x1 = a[i*C+k+1];
244             MULMOD2(&x0, w0, &x1, w1);
245             MULMOD2C(&w0, &w1, wstep);
246             a[i*C+k] = x0;
247             a[i*C+k+1] = x1;
248         }
249     }
250 
251     /* Length R transform on the columns. */
252     _mpd_init_w3table(w3table, 1, modnum);
253     for (p0=a, p1=p0+C, p2=p0+2*C; p0<a+C; p0++,p1++,p2++) {
254 
255         SIZE3_NTT(p0, p1, p2, w3table);
256     }
257 
258     return 1;
259 }
260