1 package org.bouncycastle.math.ec;
2 
3 import java.math.BigInteger;
4 
5 public abstract class WNafUtil
6 {
7     public static final String PRECOMP_NAME = "bc_wnaf";
8 
9     private static final int[] DEFAULT_WINDOW_SIZE_CUTOFFS = new int[]{ 13, 41, 121, 337, 897, 2305 };
10 
11     private static final byte[] EMPTY_BYTES = new byte[0];
12     private static final int[] EMPTY_INTS = new int[0];
13     private static final ECPoint[] EMPTY_POINTS = new ECPoint[0];
14 
generateCompactNaf(BigInteger k)15     public static int[] generateCompactNaf(BigInteger k)
16     {
17         if ((k.bitLength() >>> 16) != 0)
18         {
19             throw new IllegalArgumentException("'k' must have bitlength < 2^16");
20         }
21         if (k.signum() == 0)
22         {
23             return EMPTY_INTS;
24         }
25 
26         BigInteger _3k = k.shiftLeft(1).add(k);
27 
28         int bits = _3k.bitLength();
29         int[] naf = new int[bits >> 1];
30 
31         BigInteger diff = _3k.xor(k);
32 
33         int highBit = bits - 1, length = 0, zeroes = 0;
34         for (int i = 1; i < highBit; ++i)
35         {
36             if (!diff.testBit(i))
37             {
38                 ++zeroes;
39                 continue;
40             }
41 
42             int digit  = k.testBit(i) ? -1 : 1;
43             naf[length++] = (digit << 16) | zeroes;
44             zeroes = 1;
45             ++i;
46         }
47 
48         naf[length++] = (1 << 16) | zeroes;
49 
50         if (naf.length > length)
51         {
52             naf = trim(naf, length);
53         }
54 
55         return naf;
56     }
57 
generateCompactWindowNaf(int width, BigInteger k)58     public static int[] generateCompactWindowNaf(int width, BigInteger k)
59     {
60         if (width == 2)
61         {
62             return generateCompactNaf(k);
63         }
64 
65         if (width < 2 || width > 16)
66         {
67             throw new IllegalArgumentException("'width' must be in the range [2, 16]");
68         }
69         if ((k.bitLength() >>> 16) != 0)
70         {
71             throw new IllegalArgumentException("'k' must have bitlength < 2^16");
72         }
73         if (k.signum() == 0)
74         {
75             return EMPTY_INTS;
76         }
77 
78         int[] wnaf = new int[k.bitLength() / width + 1];
79 
80         // 2^width and a mask and sign bit set accordingly
81         int pow2 = 1 << width;
82         int mask = pow2 - 1;
83         int sign = pow2 >>> 1;
84 
85         boolean carry = false;
86         int length = 0, pos = 0;
87 
88         while (pos <= k.bitLength())
89         {
90             if (k.testBit(pos) == carry)
91             {
92                 ++pos;
93                 continue;
94             }
95 
96             k = k.shiftRight(pos);
97 
98             int digit = k.intValue() & mask;
99             if (carry)
100             {
101                 ++digit;
102             }
103 
104             carry = (digit & sign) != 0;
105             if (carry)
106             {
107                 digit -= pow2;
108             }
109 
110             int zeroes = length > 0 ? pos - 1 : pos;
111             wnaf[length++] = (digit << 16) | zeroes;
112             pos = width;
113         }
114 
115         // Reduce the WNAF array to its actual length
116         if (wnaf.length > length)
117         {
118             wnaf = trim(wnaf, length);
119         }
120 
121         return wnaf;
122     }
123 
generateJSF(BigInteger g, BigInteger h)124     public static byte[] generateJSF(BigInteger g, BigInteger h)
125     {
126         int digits = Math.max(g.bitLength(), h.bitLength()) + 1;
127         byte[] jsf = new byte[digits];
128 
129         BigInteger k0 = g, k1 = h;
130         int j = 0, d0 = 0, d1 = 0;
131 
132         int offset = 0;
133         while ((d0 | d1) != 0 || k0.bitLength() > offset || k1.bitLength() > offset)
134         {
135             int n0 = ((k0.intValue() >>> offset) + d0) & 7, n1 = ((k1.intValue() >>> offset) + d1) & 7;
136 
137             int u0 = n0 & 1;
138             if (u0 != 0)
139             {
140                 u0 -= (n0 & 2);
141                 if ((n0 + u0) == 4 && (n1 & 3) == 2)
142                 {
143                     u0 = -u0;
144                 }
145             }
146 
147             int u1 = n1 & 1;
148             if (u1 != 0)
149             {
150                 u1 -= (n1 & 2);
151                 if ((n1 + u1) == 4 && (n0 & 3) == 2)
152                 {
153                     u1 = -u1;
154                 }
155             }
156 
157             if ((d0 << 1) == 1 + u0)
158             {
159                 d0 ^= 1;
160             }
161             if ((d1 << 1) == 1 + u1)
162             {
163                 d1 ^= 1;
164             }
165 
166             if (++offset == 30)
167             {
168                 offset = 0;
169                 k0 = k0.shiftRight(30);
170                 k1 = k1.shiftRight(30);
171             }
172 
173             jsf[j++] = (byte)((u0 << 4) | (u1 & 0xF));
174         }
175 
176         // Reduce the JSF array to its actual length
177         if (jsf.length > j)
178         {
179             jsf = trim(jsf, j);
180         }
181 
182         return jsf;
183     }
184 
generateNaf(BigInteger k)185     public static byte[] generateNaf(BigInteger k)
186     {
187         if (k.signum() == 0)
188         {
189             return EMPTY_BYTES;
190         }
191 
192         BigInteger _3k = k.shiftLeft(1).add(k);
193 
194         int digits = _3k.bitLength() - 1;
195         byte[] naf = new byte[digits];
196 
197         BigInteger diff = _3k.xor(k);
198 
199         for (int i = 1; i < digits; ++i)
200         {
201             if (diff.testBit(i))
202             {
203                 naf[i - 1] = (byte)(k.testBit(i) ? -1 : 1);
204                 ++i;
205             }
206         }
207 
208         naf[digits - 1] = 1;
209 
210         return naf;
211     }
212 
213     /**
214      * Computes the Window NAF (non-adjacent Form) of an integer.
215      * @param width The width <code>w</code> of the Window NAF. The width is
216      * defined as the minimal number <code>w</code>, such that for any
217      * <code>w</code> consecutive digits in the resulting representation, at
218      * most one is non-zero.
219      * @param k The integer of which the Window NAF is computed.
220      * @return The Window NAF of the given width, such that the following holds:
221      * <code>k = &sum;<sub>i=0</sub><sup>l-1</sup> k<sub>i</sub>2<sup>i</sup>
222      * </code>, where the <code>k<sub>i</sub></code> denote the elements of the
223      * returned <code>byte[]</code>.
224      */
generateWindowNaf(int width, BigInteger k)225     public static byte[] generateWindowNaf(int width, BigInteger k)
226     {
227         if (width == 2)
228         {
229             return generateNaf(k);
230         }
231 
232         if (width < 2 || width > 8)
233         {
234             throw new IllegalArgumentException("'width' must be in the range [2, 8]");
235         }
236         if (k.signum() == 0)
237         {
238             return EMPTY_BYTES;
239         }
240 
241         byte[] wnaf = new byte[k.bitLength() + 1];
242 
243         // 2^width and a mask and sign bit set accordingly
244         int pow2 = 1 << width;
245         int mask = pow2 - 1;
246         int sign = pow2 >>> 1;
247 
248         boolean carry = false;
249         int length = 0, pos = 0;
250 
251         while (pos <= k.bitLength())
252         {
253             if (k.testBit(pos) == carry)
254             {
255                 ++pos;
256                 continue;
257             }
258 
259             k = k.shiftRight(pos);
260 
261             int digit = k.intValue() & mask;
262             if (carry)
263             {
264                 ++digit;
265             }
266 
267             carry = (digit & sign) != 0;
268             if (carry)
269             {
270                 digit -= pow2;
271             }
272 
273             length += (length > 0) ? pos - 1 : pos;
274             wnaf[length++] = (byte)digit;
275             pos = width;
276         }
277 
278         // Reduce the WNAF array to its actual length
279         if (wnaf.length > length)
280         {
281             wnaf = trim(wnaf, length);
282         }
283 
284         return wnaf;
285     }
286 
getNafWeight(BigInteger k)287     public static int getNafWeight(BigInteger k)
288     {
289         if (k.signum() == 0)
290         {
291             return 0;
292         }
293 
294         BigInteger _3k = k.shiftLeft(1).add(k);
295         BigInteger diff = _3k.xor(k);
296 
297         return diff.bitCount();
298     }
299 
getWNafPreCompInfo(ECPoint p)300     public static WNafPreCompInfo getWNafPreCompInfo(ECPoint p)
301     {
302         return getWNafPreCompInfo(p.getCurve().getPreCompInfo(p, PRECOMP_NAME));
303     }
304 
getWNafPreCompInfo(PreCompInfo preCompInfo)305     public static WNafPreCompInfo getWNafPreCompInfo(PreCompInfo preCompInfo)
306     {
307         if ((preCompInfo != null) && (preCompInfo instanceof WNafPreCompInfo))
308         {
309             return (WNafPreCompInfo)preCompInfo;
310         }
311 
312         return new WNafPreCompInfo();
313     }
314 
315     /**
316      * Determine window width to use for a scalar multiplication of the given size.
317      *
318      * @param bits the bit-length of the scalar to multiply by
319      * @return the window size to use
320      */
getWindowSize(int bits)321     public static int getWindowSize(int bits)
322     {
323         return getWindowSize(bits, DEFAULT_WINDOW_SIZE_CUTOFFS);
324     }
325 
326     /**
327      * Determine window width to use for a scalar multiplication of the given size.
328      *
329      * @param bits the bit-length of the scalar to multiply by
330      * @param windowSizeCutoffs a monotonically increasing list of bit sizes at which to increment the window width
331      * @return the window size to use
332      */
getWindowSize(int bits, int[] windowSizeCutoffs)333     public static int getWindowSize(int bits, int[] windowSizeCutoffs)
334     {
335         int w = 0;
336         for (; w < windowSizeCutoffs.length; ++w)
337         {
338             if (bits < windowSizeCutoffs[w])
339             {
340                 break;
341             }
342         }
343         return w + 2;
344     }
345 
mapPointWithPrecomp(ECPoint p, int width, boolean includeNegated, ECPointMap pointMap)346     public static ECPoint mapPointWithPrecomp(ECPoint p, int width, boolean includeNegated,
347         ECPointMap pointMap)
348     {
349         ECCurve c = p.getCurve();
350         WNafPreCompInfo wnafPreCompP = precompute(p, width, includeNegated);
351 
352         ECPoint q = pointMap.map(p);
353         WNafPreCompInfo wnafPreCompQ = getWNafPreCompInfo(c.getPreCompInfo(q, PRECOMP_NAME));
354 
355         ECPoint twiceP = wnafPreCompP.getTwice();
356         if (twiceP != null)
357         {
358             ECPoint twiceQ = pointMap.map(twiceP);
359             wnafPreCompQ.setTwice(twiceQ);
360         }
361 
362         ECPoint[] preCompP = wnafPreCompP.getPreComp();
363         ECPoint[] preCompQ = new ECPoint[preCompP.length];
364         for (int i = 0; i < preCompP.length; ++i)
365         {
366             preCompQ[i] = pointMap.map(preCompP[i]);
367         }
368         wnafPreCompQ.setPreComp(preCompQ);
369 
370         if (includeNegated)
371         {
372             ECPoint[] preCompNegQ = new ECPoint[preCompQ.length];
373             for (int i = 0; i < preCompNegQ.length; ++i)
374             {
375                 preCompNegQ[i] = preCompQ[i].negate();
376             }
377             wnafPreCompQ.setPreCompNeg(preCompNegQ);
378         }
379 
380         c.setPreCompInfo(q, PRECOMP_NAME, wnafPreCompQ);
381 
382         return q;
383     }
384 
precompute(ECPoint p, int width, boolean includeNegated)385     public static WNafPreCompInfo precompute(ECPoint p, int width, boolean includeNegated)
386     {
387         ECCurve c = p.getCurve();
388         WNafPreCompInfo wnafPreCompInfo = getWNafPreCompInfo(c.getPreCompInfo(p, PRECOMP_NAME));
389 
390         int iniPreCompLen = 0, reqPreCompLen = 1 << Math.max(0, width - 2);
391 
392         ECPoint[] preComp = wnafPreCompInfo.getPreComp();
393         if (preComp == null)
394         {
395             preComp = EMPTY_POINTS;
396         }
397         else
398         {
399             iniPreCompLen = preComp.length;
400         }
401 
402         if (iniPreCompLen < reqPreCompLen)
403         {
404             preComp = resizeTable(preComp, reqPreCompLen);
405 
406             if (reqPreCompLen == 1)
407             {
408                 preComp[0] = p.normalize();
409             }
410             else
411             {
412                 int curPreCompLen = iniPreCompLen;
413                 if (curPreCompLen == 0)
414                 {
415                     preComp[0] = p;
416                     curPreCompLen = 1;
417                 }
418 
419                 ECFieldElement iso = null;
420 
421                 if (reqPreCompLen == 2)
422                 {
423                     preComp[1] = p.threeTimes();
424                 }
425                 else
426                 {
427                     ECPoint twiceP = wnafPreCompInfo.getTwice(), last = preComp[curPreCompLen - 1];
428                     if (twiceP == null)
429                     {
430                         twiceP = preComp[0].twice();
431                         wnafPreCompInfo.setTwice(twiceP);
432 
433                         /*
434                          * For Fp curves with Jacobian projective coordinates, use a (quasi-)isomorphism
435                          * where 'twiceP' is "affine", so that the subsequent additions are cheaper. This
436                          * also requires scaling the initial point's X, Y coordinates, and reversing the
437                          * isomorphism as part of the subsequent normalization.
438                          *
439                          *  NOTE: The correctness of this optimization depends on:
440                          *      1) additions do not use the curve's A, B coefficients.
441                          *      2) no special cases (i.e. Q +/- Q) when calculating 1P, 3P, 5P, ...
442                          */
443                         if (ECAlgorithms.isFpCurve(c) && c.getFieldSize() >= 64)
444                         {
445                             switch (c.getCoordinateSystem())
446                             {
447                             case ECCurve.COORD_JACOBIAN:
448                             case ECCurve.COORD_JACOBIAN_CHUDNOVSKY:
449                             case ECCurve.COORD_JACOBIAN_MODIFIED:
450                             {
451                                 iso = twiceP.getZCoord(0);
452                                 twiceP = c.createPoint(twiceP.getXCoord().toBigInteger(), twiceP.getYCoord()
453                                     .toBigInteger());
454 
455                                 ECFieldElement iso2 = iso.square(), iso3 = iso2.multiply(iso);
456                                 last = last.scaleX(iso2).scaleY(iso3);
457 
458                                 if (iniPreCompLen == 0)
459                                 {
460                                     preComp[0] = last;
461                                 }
462                                 break;
463                             }
464                             }
465                         }
466                     }
467 
468                     while (curPreCompLen < reqPreCompLen)
469                     {
470                         /*
471                          * Compute the new ECPoints for the precomputation array. The values 1, 3,
472                          * 5, ..., 2^(width-1)-1 times p are computed
473                          */
474                         preComp[curPreCompLen++] = last = last.add(twiceP);
475                     }
476                 }
477 
478                 /*
479                  * Having oft-used operands in affine form makes operations faster.
480                  */
481                 c.normalizeAll(preComp, iniPreCompLen, reqPreCompLen - iniPreCompLen, iso);
482             }
483         }
484 
485         wnafPreCompInfo.setPreComp(preComp);
486 
487         if (includeNegated)
488         {
489             ECPoint[] preCompNeg = wnafPreCompInfo.getPreCompNeg();
490 
491             int pos;
492             if (preCompNeg == null)
493             {
494                 pos = 0;
495                 preCompNeg = new ECPoint[reqPreCompLen];
496             }
497             else
498             {
499                 pos = preCompNeg.length;
500                 if (pos < reqPreCompLen)
501                 {
502                     preCompNeg = resizeTable(preCompNeg, reqPreCompLen);
503                 }
504             }
505 
506             while (pos < reqPreCompLen)
507             {
508                 preCompNeg[pos] = preComp[pos].negate();
509                 ++pos;
510             }
511 
512             wnafPreCompInfo.setPreCompNeg(preCompNeg);
513         }
514 
515         c.setPreCompInfo(p, PRECOMP_NAME, wnafPreCompInfo);
516 
517         return wnafPreCompInfo;
518     }
519 
trim(byte[] a, int length)520     private static byte[] trim(byte[] a, int length)
521     {
522         byte[] result = new byte[length];
523         System.arraycopy(a, 0, result, 0, result.length);
524         return result;
525     }
526 
trim(int[] a, int length)527     private static int[] trim(int[] a, int length)
528     {
529         int[] result = new int[length];
530         System.arraycopy(a, 0, result, 0, result.length);
531         return result;
532     }
533 
resizeTable(ECPoint[] a, int length)534     private static ECPoint[] resizeTable(ECPoint[] a, int length)
535     {
536         ECPoint[] result = new ECPoint[length];
537         System.arraycopy(a, 0, result, 0, a.length);
538         return result;
539     }
540 }
541