1// Copyright (c) 2016, Google Inc.
2//
3// Permission to use, copy, modify, and/or distribute this software for any
4// purpose with or without fee is hereby granted, provided that the above
5// copyright notice and this permission notice appear in all copies.
6//
7// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14
15package main
16
import (
	"bufio"
	"errors"
	"fmt"
	"io"
	"math/big"
	"os"
	"sort"
	"strings"
)
26
// test is one test case read from the input: a group of "KEY = VALUE" lines
// terminated by a blank line or end of input.
type test struct {
	// LineNumber is the input line number of the test's first attribute.
	LineNumber int
	// Type is the key of the first attribute; main dispatches on it.
	Type string
	// Values maps every attribute key (including the one named by Type) to
	// its value, parsed as a hexadecimal integer.
	Values map[string]*big.Int
}
32
// testScanner reads test cases one at a time. Usage mirrors bufio.Scanner:
// call Scan until it returns false, read each case with Test, then check Err.
type testScanner struct {
	// scanner supplies the raw input lines.
	scanner *bufio.Scanner
	// lineNo is the number of lines consumed so far, i.e. the 1-based line
	// number of the line currently held by scanner.
	lineNo int
	// err is the most recent parse error, already prefixed with its line
	// number by setError.
	err error
	// test is the case most recently assembled by Scan.
	test test
}
39
40func newTestScanner(r io.Reader) *testScanner {
41	return &testScanner{scanner: bufio.NewScanner(r)}
42}
43
44func (s *testScanner) scanLine() bool {
45	if !s.scanner.Scan() {
46		return false
47	}
48	s.lineNo++
49	return true
50}
51
52func (s *testScanner) addAttribute(line string) (key string, ok bool) {
53	fields := strings.SplitN(line, "=", 2)
54	if len(fields) != 2 {
55		s.setError(errors.New("invalid syntax"))
56		return "", false
57	}
58
59	key = strings.TrimSpace(fields[0])
60	value := strings.TrimSpace(fields[1])
61
62	valueInt, ok := new(big.Int).SetString(value, 16)
63	if !ok {
64		s.setError(fmt.Errorf("could not parse %q", value))
65		return "", false
66	}
67	if _, dup := s.test.Values[key]; dup {
68		s.setError(fmt.Errorf("duplicate key %q", key))
69		return "", false
70	}
71	s.test.Values[key] = valueInt
72	return key, true
73}
74
75func (s *testScanner) Scan() bool {
76	s.test = test{
77		Values: make(map[string]*big.Int),
78	}
79
80	// Scan until the first attribute.
81	for {
82		if !s.scanLine() {
83			return false
84		}
85		if len(s.scanner.Text()) != 0 && s.scanner.Text()[0] != '#' {
86			break
87		}
88	}
89
90	var ok bool
91	s.test.Type, ok = s.addAttribute(s.scanner.Text())
92	if !ok {
93		return false
94	}
95	s.test.LineNumber = s.lineNo
96
97	for s.scanLine() {
98		if len(s.scanner.Text()) == 0 {
99			break
100		}
101
102		if s.scanner.Text()[0] == '#' {
103			continue
104		}
105
106		if _, ok := s.addAttribute(s.scanner.Text()); !ok {
107			return false
108		}
109	}
110	return s.scanner.Err() == nil
111}
112
113func (s *testScanner) Test() test {
114	return s.test
115}
116
117func (s *testScanner) Err() error {
118	if s.err != nil {
119		return s.err
120	}
121	return s.scanner.Err()
122}
123
124func (s *testScanner) setError(err error) {
125	s.err = fmt.Errorf("line %d: %s", s.lineNo, err)
126}
127
128func checkKeys(t test, keys ...string) bool {
129	var foundErrors bool
130
131	for _, k := range keys {
132		if _, ok := t.Values[k]; !ok {
133			fmt.Fprintf(os.Stderr, "Line %d: missing key %q.\n", t.LineNumber, k)
134			foundErrors = true
135		}
136	}
137
138	for k, _ := range t.Values {
139		var found bool
140		for _, k2 := range keys {
141			if k == k2 {
142				found = true
143				break
144			}
145		}
146		if !found {
147			fmt.Fprintf(os.Stderr, "Line %d: unexpected key %q.\n", t.LineNumber, k)
148			foundErrors = true
149		}
150	}
151
152	return !foundErrors
153}
154
155func checkResult(t test, expr, key string, r *big.Int) {
156	if t.Values[key].Cmp(r) != 0 {
157		fmt.Fprintf(os.Stderr, "Line %d: %s did not match %s.\n\tGot %s\n", t.LineNumber, expr, key, r.Text(16))
158	}
159}
160
161func main() {
162	if len(os.Args) != 2 {
163		fmt.Fprintf(os.Stderr, "Usage: %s bn_tests.txt\n", os.Args[0])
164		os.Exit(1)
165	}
166
167	in, err := os.Open(os.Args[1])
168	if err != nil {
169		fmt.Fprintf(os.Stderr, "Error opening %s: %s.\n", os.Args[0], err)
170		os.Exit(1)
171	}
172	defer in.Close()
173
174	scanner := newTestScanner(in)
175	for scanner.Scan() {
176		test := scanner.Test()
177		switch test.Type {
178		case "Sum":
179			if checkKeys(test, "A", "B", "Sum") {
180				r := new(big.Int).Add(test.Values["A"], test.Values["B"])
181				checkResult(test, "A + B", "Sum", r)
182			}
183		case "LShift1":
184			if checkKeys(test, "A", "LShift1") {
185				r := new(big.Int).Add(test.Values["A"], test.Values["A"])
186				checkResult(test, "A + A", "LShift1", r)
187			}
188		case "LShift":
189			if checkKeys(test, "A", "N", "LShift") {
190				r := new(big.Int).Lsh(test.Values["A"], uint(test.Values["N"].Uint64()))
191				checkResult(test, "A << N", "LShift", r)
192			}
193		case "RShift":
194			if checkKeys(test, "A", "N", "RShift") {
195				r := new(big.Int).Rsh(test.Values["A"], uint(test.Values["N"].Uint64()))
196				checkResult(test, "A >> N", "RShift", r)
197			}
198		case "Square":
199			if checkKeys(test, "A", "Square") {
200				r := new(big.Int).Mul(test.Values["A"], test.Values["A"])
201				checkResult(test, "A * A", "Square", r)
202			}
203		case "Product":
204			if checkKeys(test, "A", "B", "Product") {
205				r := new(big.Int).Mul(test.Values["A"], test.Values["B"])
206				checkResult(test, "A * B", "Product", r)
207			}
208		case "Quotient":
209			if checkKeys(test, "A", "B", "Quotient", "Remainder") {
210				q, r := new(big.Int).QuoRem(test.Values["A"], test.Values["B"], new(big.Int))
211				checkResult(test, "A / B", "Quotient", q)
212				checkResult(test, "A % B", "Remainder", r)
213			}
214		case "ModMul":
215			if checkKeys(test, "A", "B", "M", "ModMul") {
216				r := new(big.Int).Mul(test.Values["A"], test.Values["B"])
217				r = r.Mod(r, test.Values["M"])
218				checkResult(test, "A * B (mod M)", "ModMul", r)
219			}
220		case "ModExp":
221			if checkKeys(test, "A", "E", "M", "ModExp") {
222				r := new(big.Int).Exp(test.Values["A"], test.Values["E"], test.Values["M"])
223				checkResult(test, "A ^ E (mod M)", "ModExp", r)
224			}
225		case "Exp":
226			if checkKeys(test, "A", "E", "Exp") {
227				r := new(big.Int).Exp(test.Values["A"], test.Values["E"], nil)
228				checkResult(test, "A ^ E", "Exp", r)
229			}
230		case "ModSqrt":
231			bigOne := new(big.Int).SetInt64(1)
232			bigTwo := new(big.Int).SetInt64(2)
233
234			if checkKeys(test, "A", "P", "ModSqrt") {
235				test.Values["A"].Mod(test.Values["A"], test.Values["P"])
236
237				r := new(big.Int).Mul(test.Values["ModSqrt"], test.Values["ModSqrt"])
238				r = r.Mod(r, test.Values["P"])
239				checkResult(test, "ModSqrt ^ 2 (mod P)", "A", r)
240
241				if test.Values["P"].Cmp(bigTwo) > 0 {
242					pMinus1Over2 := new(big.Int).Sub(test.Values["P"], bigOne)
243					pMinus1Over2.Rsh(pMinus1Over2, 1)
244
245					if test.Values["ModSqrt"].Cmp(pMinus1Over2) > 0 {
246						fmt.Fprintf(os.Stderr, "Line %d: ModSqrt should be minimal.\n", test.LineNumber)
247					}
248				}
249			}
250		case "ModInv":
251			if checkKeys(test, "A", "M", "ModInv") {
252				r := new(big.Int).ModInverse(test.Values["A"], test.Values["M"])
253				checkResult(test, "A ^ -1 (mod M)", "ModInv", r)
254			}
255		default:
256			fmt.Fprintf(os.Stderr, "Line %d: unknown test type %q.\n", test.LineNumber, test.Type)
257		}
258	}
259	if scanner.Err() != nil {
260		fmt.Fprintf(os.Stderr, "Error reading tests: %s.\n", scanner.Err())
261	}
262}
263