// Copyright (c) 2016, Google Inc.
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

package main

import (
	"bufio"
	"errors"
	"fmt"
	"io"
	"math/big"
	"os"
	"strings"
)

// test is one test case read from the input: a group of "Key = HexValue"
// attributes.
type test struct {
	// LineNumber is the line on which the test's first attribute appeared.
	LineNumber int
	// Type is the key of the test's first attribute.
	Type string
	// Values maps every attribute key of the test (including Type) to its
	// value, parsed as a base-16 integer.
	Values map[string]*big.Int
}

// testScanner reads tests one at a time from an underlying line scanner.
type testScanner struct {
	scanner *bufio.Scanner
	lineNo  int   // number of lines consumed so far
	err     error // first syntax error recorded by setError, if any
	test    test  // the test produced by the most recent call to Scan
}

// newTestScanner returns a testScanner that reads tests from r.
func newTestScanner(r io.Reader) *testScanner {
	return &testScanner{scanner: bufio.NewScanner(r)}
}

// scanLine advances the underlying scanner by one line and bumps the line
// counter. It returns false when no further line is available.
func (s *testScanner) scanLine() bool {
	if !s.scanner.Scan() {
		return false
	}
	s.lineNo++
	return true
}

// addAttribute parses line as "Key = HexValue" and stores the value in the
// current test. It returns the trimmed key and true on success. On a syntax
// error, an unparsable value or a duplicate key it records the error via
// setError and returns "", false.
func (s *testScanner) addAttribute(line string) (key string, ok bool) {
	fields := strings.SplitN(line, "=", 2)
	if len(fields) != 2 {
		s.setError(errors.New("invalid syntax"))
		return "", false
	}

	key = strings.TrimSpace(fields[0])
	value := strings.TrimSpace(fields[1])

	valueInt, parsed := new(big.Int).SetString(value, 16)
	if !parsed {
		s.setError(fmt.Errorf("could not parse %q", value))
		return "", false
	}
	if _, dup := s.test.Values[key]; dup {
		s.setError(fmt.Errorf("duplicate key %q", key))
		return "", false
	}
	s.test.Values[key] = valueInt
	return key, true
}

// Scan reads the next test from the input. Leading blank lines and lines
// starting with '#' are skipped; the first attribute determines the test's
// Type, and further attributes are read until a blank line or the end of the
// input. It returns false when there are no more tests or when an error
// occurred; use Err to tell the two apart.
func (s *testScanner) Scan() bool {
	s.test = test{
		Values: make(map[string]*big.Int),
	}

	// Scan until the first attribute.
	for {
		if !s.scanLine() {
			return false
		}
		if line := s.scanner.Text(); len(line) != 0 && line[0] != '#' {
			break
		}
	}

	var ok bool
	s.test.Type, ok = s.addAttribute(s.scanner.Text())
	if !ok {
		return false
	}
	s.test.LineNumber = s.lineNo

	for s.scanLine() {
		line := s.scanner.Text()
		if len(line) == 0 {
			// A blank line terminates the test.
			break
		}

		if line[0] == '#' {
			continue
		}

		if _, ok := s.addAttribute(line); !ok {
			return false
		}
	}
	return s.scanner.Err() == nil
}

// Test returns the test produced by the most recent call to Scan.
func (s *testScanner) Test() test {
	return s.test
}

// Err returns the recorded syntax error if there is one, and otherwise the
// underlying scanner's error.
func (s *testScanner) Err() error {
	if s.err != nil {
		return s.err
	}
	return s.scanner.Err()
}

// setError records err, prefixed with the current line number.
func (s *testScanner) setError(err error) {
	s.err = fmt.Errorf("line %d: %s", s.lineNo, err)
}

// checkKeys verifies that t has exactly the attributes named in keys. Every
// missing or unexpected key is reported on stderr. It returns true if no
// problems were found.
func checkKeys(t test, keys ...string) bool {
	var foundErrors bool

	for _, k := range keys {
		if _, ok := t.Values[k]; !ok {
			fmt.Fprintf(os.Stderr, "Line %d: missing key %q.\n", t.LineNumber, k)
			foundErrors = true
		}
	}

	for k := range t.Values {
		var found bool
		for _, k2 := range keys {
			if k == k2 {
				found = true
				break
			}
		}
		if !found {
			fmt.Fprintf(os.Stderr, "Line %d: unexpected key %q.\n", t.LineNumber, k)
			foundErrors = true
		}
	}

	return !foundErrors
}

// checkNonZero reports whether t.Values[key] is non-zero, printing a
// diagnostic to stderr when it is zero. It guards operations that would
// otherwise panic with a division by zero.
func checkNonZero(t test, key string) bool {
	if t.Values[key].Sign() == 0 {
		fmt.Fprintf(os.Stderr, "Line %d: %s must be non-zero.\n", t.LineNumber, key)
		return false
	}
	return true
}

// checkResult compares the computed value r of expr against t.Values[key] and
// reports a mismatch on stderr. A nil r (for example from ModInverse when no
// inverse exists) is reported as a mismatch rather than dereferenced.
func checkResult(t test, expr, key string, r *big.Int) {
	if r != nil && t.Values[key].Cmp(r) == 0 {
		return
	}
	got := "<nil>"
	if r != nil {
		got = r.Text(16)
	}
	fmt.Fprintf(os.Stderr, "Line %d: %s did not match %s.\n\tGot %s\n", t.LineNumber, expr, key, got)
}

// runTest recomputes the expected result of t with math/big according to
// t.Type and reports any discrepancy on stderr.
func runTest(t test) {
	switch t.Type {
	case "Sum":
		if checkKeys(t, "A", "B", "Sum") {
			r := new(big.Int).Add(t.Values["A"], t.Values["B"])
			checkResult(t, "A + B", "Sum", r)
		}
	case "LShift1":
		if checkKeys(t, "A", "LShift1") {
			r := new(big.Int).Add(t.Values["A"], t.Values["A"])
			checkResult(t, "A + A", "LShift1", r)
		}
	case "LShift":
		if checkKeys(t, "A", "N", "LShift") {
			r := new(big.Int).Lsh(t.Values["A"], uint(t.Values["N"].Uint64()))
			checkResult(t, "A << N", "LShift", r)
		}
	case "RShift":
		if checkKeys(t, "A", "N", "RShift") {
			r := new(big.Int).Rsh(t.Values["A"], uint(t.Values["N"].Uint64()))
			checkResult(t, "A >> N", "RShift", r)
		}
	case "Square":
		if checkKeys(t, "A", "Square") {
			r := new(big.Int).Mul(t.Values["A"], t.Values["A"])
			checkResult(t, "A * A", "Square", r)
		}
	case "Product":
		if checkKeys(t, "A", "B", "Product") {
			r := new(big.Int).Mul(t.Values["A"], t.Values["B"])
			checkResult(t, "A * B", "Product", r)
		}
	case "Quotient":
		// QuoRem panics on a zero divisor, so reject it up front.
		if checkKeys(t, "A", "B", "Quotient", "Remainder") && checkNonZero(t, "B") {
			q, r := new(big.Int).QuoRem(t.Values["A"], t.Values["B"], new(big.Int))
			checkResult(t, "A / B", "Quotient", q)
			checkResult(t, "A % B", "Remainder", r)
		}
	case "ModMul":
		// Mod panics on a zero modulus, so reject it up front.
		if checkKeys(t, "A", "B", "M", "ModMul") && checkNonZero(t, "M") {
			r := new(big.Int).Mul(t.Values["A"], t.Values["B"])
			r = r.Mod(r, t.Values["M"])
			checkResult(t, "A * B (mod M)", "ModMul", r)
		}
	case "ModExp":
		if checkKeys(t, "A", "E", "M", "ModExp") {
			r := new(big.Int).Exp(t.Values["A"], t.Values["E"], t.Values["M"])
			checkResult(t, "A ^ E (mod M)", "ModExp", r)
		}
	case "Exp":
		if checkKeys(t, "A", "E", "Exp") {
			r := new(big.Int).Exp(t.Values["A"], t.Values["E"], nil)
			checkResult(t, "A ^ E", "Exp", r)
		}
	case "ModSqrt":
		bigOne := new(big.Int).SetInt64(1)
		bigTwo := new(big.Int).SetInt64(2)

		// Mod panics on a zero modulus, so reject it up front.
		if checkKeys(t, "A", "P", "ModSqrt") && checkNonZero(t, "P") {
			// Reduce A in place so it can be compared with the square
			// of ModSqrt modulo P.
			t.Values["A"].Mod(t.Values["A"], t.Values["P"])

			r := new(big.Int).Mul(t.Values["ModSqrt"], t.Values["ModSqrt"])
			r = r.Mod(r, t.Values["P"])
			checkResult(t, "ModSqrt ^ 2 (mod P)", "A", r)

			if t.Values["P"].Cmp(bigTwo) > 0 {
				pMinus1Over2 := new(big.Int).Sub(t.Values["P"], bigOne)
				pMinus1Over2.Rsh(pMinus1Over2, 1)

				// The test expects ModSqrt to be at most (P-1)/2.
				if t.Values["ModSqrt"].Cmp(pMinus1Over2) > 0 {
					fmt.Fprintf(os.Stderr, "Line %d: ModSqrt should be minimal.\n", t.LineNumber)
				}
			}
		}
	case "ModInv":
		// ModInverse returns nil when A has no inverse modulo M;
		// checkResult reports that as a mismatch.
		if checkKeys(t, "A", "M", "ModInv") && checkNonZero(t, "M") {
			r := new(big.Int).ModInverse(t.Values["A"], t.Values["M"])
			checkResult(t, "A ^ -1 (mod M)", "ModInv", r)
		}
	default:
		fmt.Fprintf(os.Stderr, "Line %d: unknown test type %q.\n", t.LineNumber, t.Type)
	}
}

// main reads the test file named by the single command-line argument and
// checks each test in it, writing diagnostics to stderr.
func main() {
	if len(os.Args) != 2 {
		fmt.Fprintf(os.Stderr, "Usage: %s bn_tests.txt\n", os.Args[0])
		os.Exit(1)
	}

	in, err := os.Open(os.Args[1])
	if err != nil {
		// Report the file that failed to open, not the program name.
		fmt.Fprintf(os.Stderr, "Error opening %s: %s.\n", os.Args[1], err)
		os.Exit(1)
	}
	defer in.Close()

	scanner := newTestScanner(in)
	for scanner.Scan() {
		runTest(scanner.Test())
	}
	if err := scanner.Err(); err != nil {
		fmt.Fprintf(os.Stderr, "Error reading tests: %s.\n", err)
	}
}