1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <stdint.h>
18 #include <stdbool.h>
19 #include <string.h>
20 #include <nanohub/rsa.h>
21 
22 
biModIterative(uint32_t * num,const uint32_t * denum,uint32_t * tmp,uint32_t * state1,uint32_t * state2,uint32_t step)23 static bool biModIterative(uint32_t *num, const uint32_t *denum, uint32_t *tmp, uint32_t *state1, uint32_t *state2, uint32_t step)
24 //num %= denum where num is RSA_LEN * 2 and denum is RSA_LEN and tmp is RSA_LEN + limb_sz
25 //will need to be called till it returns true (up to RSA_LEN * 2 + 2 times)
26 {
27     uint32_t bitsh = *state1, limbsh = *state2;
28     bool ret = false;
29     int64_t t;
30     int32_t i;
31 
32     //first step is init
33     if (!step) {
34         //initially set it up left shifted as far as possible
35         memcpy(tmp + 1, denum, RSA_BYTES);
36         tmp[0] = 0;
37         bitsh = 32;
38         limbsh = RSA_LIMBS - 1;
39         goto out;
40     }
41 
42     //second is shifting denum
43     if (step == 1) {
44         while (!(tmp[RSA_LIMBS] & 0x80000000)) {
45             for (i = RSA_LIMBS; i > 0; i--) {
46                 tmp[i] <<= 1;
47                 if (tmp[i - 1] & 0x80000000)
48                     tmp[i]++;
49             }
50             //no need to adjust tmp[0] as it is still zero
51             bitsh++;
52         }
53         goto out;
54     }
55 
56     //all future steps do the division
57 
58     //check if we should subtract (uses less space than subtracting and unroling it later)
59     for (i = RSA_LIMBS; i >= 0; i--) {
60         if (num[limbsh + i] < tmp[i])
61             goto dont_subtract;
62         if (num[limbsh + i] > tmp[i])
63             break;
64     }
65 
66     //subtract
67     t = 0;
68     for (i = 0; i <= RSA_LIMBS; i++) {
69         t += (uint64_t)num[limbsh + i];
70         t -= (uint64_t)tmp[i];
71         num[limbsh + i] = t;
72         t >>= 32;
73     }
74 
75     //carry the subtraction's carry to the end
76     for (i = RSA_LIMBS + limbsh + 1; i < RSA_LIMBS * 2; i++) {
77         t += (uint64_t)num[i];
78         num[i] = t;
79         t >>= 32;
80     }
81 
82 dont_subtract:
83     //handle bitshifts/refills
84     if (!bitsh) {                          // tmp = denum << 32
85         if (!limbsh) {
86             ret = true;
87             goto out;
88         }
89 
90         memcpy(tmp + 1, denum, RSA_BYTES);
91         tmp[0] = 0;
92         bitsh = 32;
93         limbsh--;
94     }
95     else {                                 // tmp >>= 1
96         for (i = 0; i < RSA_LIMBS; i++) {
97             tmp[i] >>= 1;
98             if (tmp[i + 1] & 1)
99                 tmp[i] += 0x80000000;
100         }
101         tmp[i] >>= 1;
102         bitsh--;
103     }
104 
105 
106 out:
107     *state1 = bitsh;
108     *state2 = limbsh;
109     return ret;
110 }
111 
biMulIterative(uint32_t * ret,const uint32_t * a,const uint32_t * b,uint32_t step)112 static void biMulIterative(uint32_t *ret, const uint32_t *a, const uint32_t *b, uint32_t step) //ret = a * b, call with step = [0..RSA_LIMBS)
113 {
114     uint32_t j, c;
115     uint64_t r;
116 
117     //zero the result on first call
118     if (!step)
119         memset(ret, 0, RSA_BYTES * 2);
120 
121     //produce a partial sum & add it in
122     c = 0;
123     for (j = 0; j < RSA_LIMBS; j++) {
124         r = (uint64_t)a[step] * b[j] + c + ret[step + j];
125         ret[step + j] = r;
126         c = r >> 32;
127     }
128 
129     //carry the carry to the end
130     for (j = step + RSA_LIMBS; j < RSA_LIMBS * 2; j++) {
131         r = (uint64_t)ret[j] + c;
132         ret[j] = r;
133         c = r >> 32;
134     }
135 }
136 
137 /*
138  * Piecewise RSA:
139  * normal RSA public op with 65537 exponent does 34 operations. 17 muls and 17 mods, as follows:
140  * 16x {mul, mod} to calculate a ^ 65536 mod c
141  * 1x {mul, mod} to calculate a ^ 65537 mod c
142  * we break up each mul and mod itself into more steps. mul needs RSA_LIMBS steps, and mod needs up to RSA_LEN * 2 + 2 steps
143  * so if we allocate RSA_LEN * 3 step values to mod, each mul-mod pair will use <= RSA_LEN * 4 step values
144  * and the whole operation will need <= RSA_LEN * 4 * 34 step values, which fits into a uint32. cool. In fact
145  * some values will be skipped, but this makes life easier, really. Call this func with *stepP = 0, and keep calling till
146  * output stepP is zero. We'll call each of the RSA_LEN * 4 pieces a gigastep, and have 17 of them as seen above. Each
147  * will be logically separated into 4 megasteps. First will contain the MUL, last 3 the MOD and maybe the memcpy.
148  * In the first 16 gigasteps, the very last step of the gigastep will be used for the memcpy call.
149  *
150  * The initial non-iterative RSA logic looks as follows, shown here for clarity:
151  *
152  *   memcpy(state->tmpB, a, RSA_BYTES);
153  *   for (i = 0; i < 16; i++) {
154  *       biMul(state->tmpA, state->tmpB, state->tmpB);
155  *       biMod(state->tmpA, c, state->tmpB);
156  *       memcpy(state->tmpB, state->tmpA, RSA_BYTES);
157  *   }
158  *
159  *   //calculate a ^ 65537 mod c into state->tmpA [ at this point this means do state->tmpA = (state->tmpB * a) % c ]
160  *   biMul(state->tmpA, state->tmpB, a);
161  *   biMod(state->tmpA, c, state->tmpB);
162  *
163  *   //return result
164  *   return state->tmpA;
165  *
166  */
167 
rsaPubOpIterative(struct RsaState * state,const uint32_t * a,const uint32_t * c,uint32_t * state1,uint32_t * state2,uint32_t * stepP)168 const uint32_t* rsaPubOpIterative(struct RsaState* state, const uint32_t *a, const uint32_t *c, uint32_t *state1, uint32_t *state2, uint32_t *stepP)
169 {
170     uint32_t step = *stepP, gigastep, gigastepBase, gigastepSubstep, megaSubstep;
171 
172     //step 0: copy a -> tmpB
173     if (!step) {
174         memcpy(state->tmpB, a, RSA_BYTES);
175         step = 1;
176     }
177     else { //subsequent steps: do real work
178 
179 
180         gigastep = (step - 1) / (RSA_LEN * 4);
181         gigastepSubstep = (step - 1) % (RSA_LEN * 4);
182         gigastepBase = gigastep * (RSA_LEN * 4);
183         megaSubstep = gigastepSubstep / RSA_LEN;
184 
185         if (!megaSubstep) { // first megastep of the gigastep - MUL
186             biMulIterative(state->tmpA, state->tmpB, gigastep == 16 ? a : state->tmpB, gigastepSubstep);
187             if (gigastepSubstep == RSA_LIMBS - 1) //MUL is done - do mod next
188                 step = gigastepBase + RSA_LEN + 1;
189             else                                  //More of MUL is left to do
190                 step++;
191         }
192         else if (gigastepSubstep != RSA_LEN * 4 - 1){   // second part of gigastep - MOD
193             if (biModIterative(state->tmpA, c, state->tmpB, state1, state2, gigastepSubstep - RSA_LEN)) { //MOD is done
194                 if (gigastep == 16) // we're done
195                     step = 0;
196                 else              // last part of the gigastep is a copy
197                     step = gigastepBase + RSA_LEN * 4 - 1 + 1;
198             }
199             else
200                 step++;
201         }
202         else {   //last part - memcpy
203             memcpy(state->tmpB, state->tmpA, RSA_BYTES);
204             step++;
205         }
206     }
207 
208     *stepP = step;
209     return state->tmpA;
210 }
211 
212 #if defined(RSA_SUPPORT_PRIV_OP_LOWRAM) || defined (RSA_SUPPORT_PRIV_OP_BIGRAM)
213 #include <stdio.h>
/*
 * Blocking version of the public-key operation: keeps calling
 * rsaPubOpIterative() till it reports completion.
 *
 *   state - working buffers
 *   a     - the base
 *   c     - the modulus
 *
 * Returns the pointer handed back by the last rsaPubOpIterative() call.
 */
const uint32_t* rsaPubOp(struct RsaState* state, const uint32_t *a, const uint32_t *c)
{
    const uint32_t *ret;
    uint32_t state1 = 0, state2 = 0, step = 0, ns = 0;

    do {
        ret = rsaPubOpIterative(state, a, c, &state1, &state2, &step);
        ns++; //count the calls for the report below
    } while (step);

    //uint32_t is not necessarily unsigned int, so convert explicitly to match the format.
    //NOTE(review): this looks like leftover debug output - confirm whether it should stay
    fprintf(stderr, "steps: %lu\n", (unsigned long)ns);

    return ret;
}
228 
/*
 * Blocking modulo: num %= denum, using tmp as scratch space.
 * Drives biModIterative() with step = 0, 1, 2, ... till it returns true.
 */
static void biMod(uint32_t *num, const uint32_t *denum, uint32_t *tmp)
{
    uint32_t bitState = 0, limbState = 0;
    uint32_t step = 0;

    while (!biModIterative(num, denum, tmp, &bitState, &limbState, step))
        step++;
}
235 
biMul(uint32_t * ret,const uint32_t * a,const uint32_t * b)236 static void biMul(uint32_t *ret, const uint32_t *a, const uint32_t *b)
237 {
238     uint32_t step;
239 
240     for (step = 0; step < RSA_LIMBS; step++)
241         biMulIterative(ret, a, b, step);
242 }
243 
rsaPrivOp(struct RsaState * state,const uint32_t * a,const uint32_t * b,const uint32_t * c)244 const uint32_t* rsaPrivOp(struct RsaState* state, const uint32_t *a, const uint32_t *b, const uint32_t *c)
245 {
246     uint32_t i;
247 
248     memcpy(state->tmpC, a, RSA_BYTES);  //tC will hold our powers of a
249 
250     memset(state->tmpA, 0, RSA_BYTES * 2); //tA will hold result
251     state->tmpA[0] = 1;
252 
253     for (i = 0; i < RSA_LEN; i++) {
254         //if the bit is set, multiply the current power of A into result
255         if (b[i / 32] & (1 << (i % 32))) {
256             memcpy(state->tmpB, state->tmpA, RSA_BYTES);
257             biMul(state->tmpA, state->tmpB, state->tmpC);
258             biMod(state->tmpA, c, state->tmpB);
259         }
260 
261         //calculate the next power of a and modulus it
262 #if defined(RSA_SUPPORT_PRIV_OP_LOWRAM)
263         memcpy(state->tmpB, state->tmpA, RSA_BYTES); //save tA
264         biMul(state->tmpA, state->tmpC, state->tmpC);
265         biMod(state->tmpA, c, state->tmpC);
266         memcpy(state->tmpC, state->tmpA, RSA_BYTES);
267         memcpy(state->tmpA, state->tmpB, RSA_BYTES); //restore tA
268 #elif defined (RSA_SUPPORT_PRIV_OP_BIGRAM)
269         memcpy(state->tmpB, state->tmpC, RSA_BYTES);
270         biMul(state->tmpC, state->tmpB, state->tmpB);
271         biMod(state->tmpC, c, state->tmpB);
272 #endif
273     }
274 
275     return state->tmpA;
276 }
277 #endif
278 
279 
280 
281 
282 
283 
284 
285 
286