1 package org.bouncycastle.math.ec;
2 
3 import java.math.BigInteger;
4 
5 import org.bouncycastle.math.ec.endo.ECEndomorphism;
6 import org.bouncycastle.math.ec.endo.GLVEndomorphism;
7 import org.bouncycastle.math.field.FiniteField;
8 import org.bouncycastle.math.field.PolynomialExtensionField;
9 
10 public class ECAlgorithms
11 {
isF2mCurve(ECCurve c)12     public static boolean isF2mCurve(ECCurve c)
13     {
14         FiniteField field = c.getField();
15         return field.getDimension() > 1 && field.getCharacteristic().equals(ECConstants.TWO)
16             && field instanceof PolynomialExtensionField;
17     }
18 
isFpCurve(ECCurve c)19     public static boolean isFpCurve(ECCurve c)
20     {
21         return c.getField().getDimension() == 1;
22     }
23 
sumOfMultiplies(ECPoint[] ps, BigInteger[] ks)24     public static ECPoint sumOfMultiplies(ECPoint[] ps, BigInteger[] ks)
25     {
26         if (ps == null || ks == null || ps.length != ks.length || ps.length < 1)
27         {
28             throw new IllegalArgumentException("point and scalar arrays should be non-null, and of equal, non-zero, length");
29         }
30 
31         int count = ps.length;
32         switch (count)
33         {
34         case 1:
35             return ps[0].multiply(ks[0]);
36         case 2:
37             return sumOfTwoMultiplies(ps[0], ks[0], ps[1], ks[1]);
38         default:
39             break;
40         }
41 
42         ECPoint p = ps[0];
43         ECCurve c = p.getCurve();
44 
45         ECPoint[] imported = new ECPoint[count];
46         imported[0] = p;
47         for (int i = 1; i < count; ++i)
48         {
49             imported[i] = importPoint(c, ps[i]);
50         }
51 
52         ECEndomorphism endomorphism = c.getEndomorphism();
53         if (endomorphism instanceof GLVEndomorphism)
54         {
55             return validatePoint(implSumOfMultipliesGLV(imported, ks, (GLVEndomorphism)endomorphism));
56         }
57 
58         return validatePoint(implSumOfMultiplies(imported, ks));
59     }
60 
sumOfTwoMultiplies(ECPoint P, BigInteger a, ECPoint Q, BigInteger b)61     public static ECPoint sumOfTwoMultiplies(ECPoint P, BigInteger a,
62         ECPoint Q, BigInteger b)
63     {
64         ECCurve cp = P.getCurve();
65         Q = importPoint(cp, Q);
66 
67         // Point multiplication for Koblitz curves (using WTNAF) beats Shamir's trick
68         if (cp instanceof ECCurve.F2m)
69         {
70             ECCurve.F2m f2mCurve = (ECCurve.F2m)cp;
71             if (f2mCurve.isKoblitz())
72             {
73                 return validatePoint(P.multiply(a).add(Q.multiply(b)));
74             }
75         }
76 
77         ECEndomorphism endomorphism = cp.getEndomorphism();
78         if (endomorphism instanceof GLVEndomorphism)
79         {
80             return validatePoint(
81                 implSumOfMultipliesGLV(new ECPoint[]{ P, Q }, new BigInteger[]{ a, b }, (GLVEndomorphism)endomorphism));
82         }
83 
84         return validatePoint(implShamirsTrickWNaf(P, a, Q, b));
85     }
86 
87     /*
88      * "Shamir's Trick", originally due to E. G. Straus
89      * (Addition chains of vectors. American Mathematical Monthly,
90      * 71(7):806-808, Aug./Sept. 1964)
91      * <pre>
92      * Input: The points P, Q, scalar k = (km?, ... , k1, k0)
93      * and scalar l = (lm?, ... , l1, l0).
94      * Output: R = k * P + l * Q.
95      * 1: Z <- P + Q
96      * 2: R <- O
97      * 3: for i from m-1 down to 0 do
98      * 4:        R <- R + R        {point doubling}
99      * 5:        if (ki = 1) and (li = 0) then R <- R + P end if
100      * 6:        if (ki = 0) and (li = 1) then R <- R + Q end if
101      * 7:        if (ki = 1) and (li = 1) then R <- R + Z end if
102      * 8: end for
103      * 9: return R
104      * </pre>
105      */
shamirsTrick(ECPoint P, BigInteger k, ECPoint Q, BigInteger l)106     public static ECPoint shamirsTrick(ECPoint P, BigInteger k,
107         ECPoint Q, BigInteger l)
108     {
109         ECCurve cp = P.getCurve();
110         Q = importPoint(cp, Q);
111 
112         return validatePoint(implShamirsTrickJsf(P, k, Q, l));
113     }
114 
importPoint(ECCurve c, ECPoint p)115     public static ECPoint importPoint(ECCurve c, ECPoint p)
116     {
117         ECCurve cp = p.getCurve();
118         if (!c.equals(cp))
119         {
120             throw new IllegalArgumentException("Point must be on the same curve");
121         }
122         return c.importPoint(p);
123     }
124 
montgomeryTrick(ECFieldElement[] zs, int off, int len)125     public static void montgomeryTrick(ECFieldElement[] zs, int off, int len)
126     {
127         montgomeryTrick(zs, off, len, null);
128     }
129 
montgomeryTrick(ECFieldElement[] zs, int off, int len, ECFieldElement scale)130     public static void montgomeryTrick(ECFieldElement[] zs, int off, int len, ECFieldElement scale)
131     {
132         /*
133          * Uses the "Montgomery Trick" to invert many field elements, with only a single actual
134          * field inversion. See e.g. the paper:
135          * "Fast Multi-scalar Multiplication Methods on Elliptic Curves with Precomputation Strategy Using Montgomery Trick"
136          * by Katsuyuki Okeya, Kouichi Sakurai.
137          */
138 
139         ECFieldElement[] c = new ECFieldElement[len];
140         c[0] = zs[off];
141 
142         int i = 0;
143         while (++i < len)
144         {
145             c[i] = c[i - 1].multiply(zs[off + i]);
146         }
147 
148         --i;
149 
150         if (scale != null)
151         {
152             c[i] = c[i].multiply(scale);
153         }
154 
155         ECFieldElement u = c[i].invert();
156 
157         while (i > 0)
158         {
159             int j = off + i--;
160             ECFieldElement tmp = zs[j];
161             zs[j] = c[i].multiply(u);
162             u = u.multiply(tmp);
163         }
164 
165         zs[off] = u;
166     }
167 
168     /**
169      * Simple shift-and-add multiplication. Serves as reference implementation
170      * to verify (possibly faster) implementations, and for very small scalars.
171      *
172      * @param p
173      *            The point to multiply.
174      * @param k
175      *            The multiplier.
176      * @return The result of the point multiplication <code>kP</code>.
177      */
referenceMultiply(ECPoint p, BigInteger k)178     public static ECPoint referenceMultiply(ECPoint p, BigInteger k)
179     {
180         BigInteger x = k.abs();
181         ECPoint q = p.getCurve().getInfinity();
182         int t = x.bitLength();
183         if (t > 0)
184         {
185             if (x.testBit(0))
186             {
187                 q = p;
188             }
189             for (int i = 1; i < t; i++)
190             {
191                 p = p.twice();
192                 if (x.testBit(i))
193                 {
194                     q = q.add(p);
195                 }
196             }
197         }
198         return k.signum() < 0 ? q.negate() : q;
199     }
200 
validatePoint(ECPoint p)201     public static ECPoint validatePoint(ECPoint p)
202     {
203         if (!p.isValid())
204         {
205             throw new IllegalArgumentException("Invalid point");
206         }
207 
208         return p;
209     }
210 
implShamirsTrickJsf(ECPoint P, BigInteger k, ECPoint Q, BigInteger l)211     static ECPoint implShamirsTrickJsf(ECPoint P, BigInteger k,
212         ECPoint Q, BigInteger l)
213     {
214         ECCurve curve = P.getCurve();
215         ECPoint infinity = curve.getInfinity();
216 
217         // TODO conjugate co-Z addition (ZADDC) can return both of these
218         ECPoint PaddQ = P.add(Q);
219         ECPoint PsubQ = P.subtract(Q);
220 
221         ECPoint[] points = new ECPoint[]{ Q, PsubQ, P, PaddQ };
222         curve.normalizeAll(points);
223 
224         ECPoint[] table = new ECPoint[] {
225             points[3].negate(), points[2].negate(), points[1].negate(),
226             points[0].negate(), infinity, points[0],
227             points[1], points[2], points[3] };
228 
229         byte[] jsf = WNafUtil.generateJSF(k, l);
230 
231         ECPoint R = infinity;
232 
233         int i = jsf.length;
234         while (--i >= 0)
235         {
236             int jsfi = jsf[i];
237 
238             // NOTE: The shifting ensures the sign is extended correctly
239             int kDigit = ((jsfi << 24) >> 28), lDigit = ((jsfi << 28) >> 28);
240 
241             int index = 4 + (kDigit * 3) + lDigit;
242             R = R.twicePlus(table[index]);
243         }
244 
245         return R;
246     }
247 
implShamirsTrickWNaf(ECPoint P, BigInteger k, ECPoint Q, BigInteger l)248     static ECPoint implShamirsTrickWNaf(ECPoint P, BigInteger k,
249         ECPoint Q, BigInteger l)
250     {
251         boolean negK = k.signum() < 0, negL = l.signum() < 0;
252 
253         k = k.abs();
254         l = l.abs();
255 
256         int widthP = Math.max(2, Math.min(16, WNafUtil.getWindowSize(k.bitLength())));
257         int widthQ = Math.max(2, Math.min(16, WNafUtil.getWindowSize(l.bitLength())));
258 
259         WNafPreCompInfo infoP = WNafUtil.precompute(P, widthP, true);
260         WNafPreCompInfo infoQ = WNafUtil.precompute(Q, widthQ, true);
261 
262         ECPoint[] preCompP = negK ? infoP.getPreCompNeg() : infoP.getPreComp();
263         ECPoint[] preCompQ = negL ? infoQ.getPreCompNeg() : infoQ.getPreComp();
264         ECPoint[] preCompNegP = negK ? infoP.getPreComp() : infoP.getPreCompNeg();
265         ECPoint[] preCompNegQ = negL ? infoQ.getPreComp() : infoQ.getPreCompNeg();
266 
267         byte[] wnafP = WNafUtil.generateWindowNaf(widthP, k);
268         byte[] wnafQ = WNafUtil.generateWindowNaf(widthQ, l);
269 
270         return implShamirsTrickWNaf(preCompP, preCompNegP, wnafP, preCompQ, preCompNegQ, wnafQ);
271     }
272 
273     static ECPoint implShamirsTrickWNaf(ECPoint P, BigInteger k, ECPointMap pointMapQ, BigInteger l)
274     {
275         boolean negK = k.signum() < 0, negL = l.signum() < 0;
276 
277         k = k.abs();
278         l = l.abs();
279 
280         int width = Math.max(2, Math.min(16, WNafUtil.getWindowSize(Math.max(k.bitLength(), l.bitLength()))));
281 
282         ECPoint Q = WNafUtil.mapPointWithPrecomp(P, width, true, pointMapQ);
283         WNafPreCompInfo infoP = WNafUtil.getWNafPreCompInfo(P);
284         WNafPreCompInfo infoQ = WNafUtil.getWNafPreCompInfo(Q);
285 
286         ECPoint[] preCompP = negK ? infoP.getPreCompNeg() : infoP.getPreComp();
287         ECPoint[] preCompQ = negL ? infoQ.getPreCompNeg() : infoQ.getPreComp();
288         ECPoint[] preCompNegP = negK ? infoP.getPreComp() : infoP.getPreCompNeg();
289         ECPoint[] preCompNegQ = negL ? infoQ.getPreComp() : infoQ.getPreCompNeg();
290 
291         byte[] wnafP = WNafUtil.generateWindowNaf(width, k);
292         byte[] wnafQ = WNafUtil.generateWindowNaf(width, l);
293 
294         return implShamirsTrickWNaf(preCompP, preCompNegP, wnafP, preCompQ, preCompNegQ, wnafQ);
295     }
296 
297     private static ECPoint implShamirsTrickWNaf(ECPoint[] preCompP, ECPoint[] preCompNegP, byte[] wnafP,
298         ECPoint[] preCompQ, ECPoint[] preCompNegQ, byte[] wnafQ)
299     {
300         int len = Math.max(wnafP.length, wnafQ.length);
301 
302         ECCurve curve = preCompP[0].getCurve();
303         ECPoint infinity = curve.getInfinity();
304 
305         ECPoint R = infinity;
306         int zeroes = 0;
307 
308         for (int i = len - 1; i >= 0; --i)
309         {
310             int wiP = i < wnafP.length ? wnafP[i] : 0;
311             int wiQ = i < wnafQ.length ? wnafQ[i] : 0;
312 
313             if ((wiP | wiQ) == 0)
314             {
315                 ++zeroes;
316                 continue;
317             }
318 
319             ECPoint r = infinity;
320             if (wiP != 0)
321             {
322                 int nP = Math.abs(wiP);
323                 ECPoint[] tableP = wiP < 0 ? preCompNegP : preCompP;
324                 r = r.add(tableP[nP >>> 1]);
325             }
326             if (wiQ != 0)
327             {
328                 int nQ = Math.abs(wiQ);
329                 ECPoint[] tableQ = wiQ < 0 ? preCompNegQ : preCompQ;
330                 r = r.add(tableQ[nQ >>> 1]);
331             }
332 
333             if (zeroes > 0)
334             {
335                 R = R.timesPow2(zeroes);
336                 zeroes = 0;
337             }
338 
339             R = R.twicePlus(r);
340         }
341 
342         if (zeroes > 0)
343         {
344             R = R.timesPow2(zeroes);
345         }
346 
347         return R;
348     }
349 
implSumOfMultiplies(ECPoint[] ps, BigInteger[] ks)350     static ECPoint implSumOfMultiplies(ECPoint[] ps, BigInteger[] ks)
351     {
352         int count = ps.length;
353         boolean[] negs = new boolean[count];
354         WNafPreCompInfo[] infos = new WNafPreCompInfo[count];
355         byte[][] wnafs = new byte[count][];
356 
357         for (int i = 0; i < count; ++i)
358         {
359             BigInteger ki = ks[i]; negs[i] = ki.signum() < 0; ki = ki.abs();
360 
361             int width = Math.max(2, Math.min(16, WNafUtil.getWindowSize(ki.bitLength())));
362             infos[i] = WNafUtil.precompute(ps[i], width, true);
363             wnafs[i] = WNafUtil.generateWindowNaf(width, ki);
364         }
365 
366         return implSumOfMultiplies(negs, infos, wnafs);
367     }
368 
369     static ECPoint implSumOfMultipliesGLV(ECPoint[] ps, BigInteger[] ks, GLVEndomorphism glvEndomorphism)
370     {
371         BigInteger n = ps[0].getCurve().getOrder();
372 
373         int len = ps.length;
374 
375         BigInteger[] abs = new BigInteger[len << 1];
376         for (int i = 0, j = 0; i < len; ++i)
377         {
378             BigInteger[] ab = glvEndomorphism.decomposeScalar(ks[i].mod(n));
379             abs[j++] = ab[0];
380             abs[j++] = ab[1];
381         }
382 
383         ECPointMap pointMap = glvEndomorphism.getPointMap();
384         if (glvEndomorphism.hasEfficientPointMap())
385         {
386             return ECAlgorithms.implSumOfMultiplies(ps, pointMap, abs);
387         }
388 
389         ECPoint[] pqs = new ECPoint[len << 1];
390         for (int i = 0, j = 0; i < len; ++i)
391         {
392             ECPoint p = ps[i], q = pointMap.map(p);
393             pqs[j++] = p;
394             pqs[j++] = q;
395         }
396 
397         return ECAlgorithms.implSumOfMultiplies(pqs, abs);
398 
399     }
400 
401     static ECPoint implSumOfMultiplies(ECPoint[] ps, ECPointMap pointMap, BigInteger[] ks)
402     {
403         int halfCount = ps.length, fullCount = halfCount << 1;
404 
405         boolean[] negs = new boolean[fullCount];
406         WNafPreCompInfo[] infos = new WNafPreCompInfo[fullCount];
407         byte[][] wnafs = new byte[fullCount][];
408 
409         for (int i = 0; i < halfCount; ++i)
410         {
411             int j0 = i << 1, j1 = j0 + 1;
412 
413             BigInteger kj0 = ks[j0]; negs[j0] = kj0.signum() < 0; kj0 = kj0.abs();
414             BigInteger kj1 = ks[j1]; negs[j1] = kj1.signum() < 0; kj1 = kj1.abs();
415 
416             int width = Math.max(2, Math.min(16, WNafUtil.getWindowSize(Math.max(kj0.bitLength(), kj1.bitLength()))));
417 
418             ECPoint P = ps[i], Q = WNafUtil.mapPointWithPrecomp(P, width, true, pointMap);
419             infos[j0] = WNafUtil.getWNafPreCompInfo(P);
420             infos[j1] = WNafUtil.getWNafPreCompInfo(Q);
421             wnafs[j0] = WNafUtil.generateWindowNaf(width, kj0);
422             wnafs[j1] = WNafUtil.generateWindowNaf(width, kj1);
423         }
424 
425         return implSumOfMultiplies(negs, infos, wnafs);
426     }
427 
428     private static ECPoint implSumOfMultiplies(boolean[] negs, WNafPreCompInfo[] infos, byte[][] wnafs)
429     {
430         int len = 0, count = wnafs.length;
431         for (int i = 0; i < count; ++i)
432         {
433             len = Math.max(len, wnafs[i].length);
434         }
435 
436         ECCurve curve = infos[0].getPreComp()[0].getCurve();
437         ECPoint infinity = curve.getInfinity();
438 
439         ECPoint R = infinity;
440         int zeroes = 0;
441 
442         for (int i = len - 1; i >= 0; --i)
443         {
444             ECPoint r = infinity;
445 
446             for (int j = 0; j < count; ++j)
447             {
448                 byte[] wnaf = wnafs[j];
449                 int wi = i < wnaf.length ? wnaf[i] : 0;
450                 if (wi != 0)
451                 {
452                     int n = Math.abs(wi);
453                     WNafPreCompInfo info = infos[j];
454                     ECPoint[] table = (wi < 0 == negs[j]) ? info.getPreComp() : info.getPreCompNeg();
455                     r = r.add(table[n >>> 1]);
456                 }
457             }
458 
459             if (r == infinity)
460             {
461                 ++zeroes;
462                 continue;
463             }
464 
465             if (zeroes > 0)
466             {
467                 R = R.timesPow2(zeroes);
468                 zeroes = 0;
469             }
470 
471             R = R.twicePlus(r);
472         }
473 
474         if (zeroes > 0)
475         {
476             R = R.timesPow2(zeroes);
477         }
478 
479         return R;
480     }
481 }
482