1// Copyright (c) 2017, Google Inc.
2//
3// Permission to use, copy, modify, and/or distribute this software for any
4// purpose with or without fee is hereby granted, provided that the above
5// copyright notice and this permission notice appear in all copies.
6//
7// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14
15package main
16
import (
	"bufio"
	"crypto/sha1"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"math/big"
	"os"
	"path/filepath"
	"sort"
	"strings"
)
30
// test is one record read from the input file: a group of "KEY = HEX"
// attribute lines, terminated by an empty line or end of input.
type test struct {
	// LineNumber is the 1-based input line holding the record's first
	// attribute.
	LineNumber int
	// Type is the key of the record's first attribute, which selects how the
	// record is interpreted (e.g. "Quotient" or "ModExp").
	Type       string
	// Values maps every attribute key in the record, including the one named
	// by Type, to its value parsed as a base-16 integer.
	Values     map[string]*big.Int
}
36
// testScanner reads test records one at a time, in the style of
// bufio.Scanner: call Scan to advance, Test to fetch the current record, and
// Err once Scan has returned false.
type testScanner struct {
	// scanner supplies the input line by line.
	scanner *bufio.Scanner
	// lineNo counts the lines read so far, i.e. the 1-based number of the
	// line most recently read.
	lineNo  int
	// err holds the parse error recorded via setError (prefixed with its
	// line number), or nil if none has occurred.
	err     error
	// test is the record being (or most recently) populated by Scan.
	test    test
}
43
44func newTestScanner(r io.Reader) *testScanner {
45	return &testScanner{scanner: bufio.NewScanner(r)}
46}
47
48func (s *testScanner) scanLine() bool {
49	if !s.scanner.Scan() {
50		return false
51	}
52	s.lineNo++
53	return true
54}
55
56func (s *testScanner) addAttribute(line string) (key string, ok bool) {
57	fields := strings.SplitN(line, "=", 2)
58	if len(fields) != 2 {
59		s.setError(errors.New("invalid syntax"))
60		return "", false
61	}
62
63	key = strings.TrimSpace(fields[0])
64	value := strings.TrimSpace(fields[1])
65
66	valueInt, ok := new(big.Int).SetString(value, 16)
67	if !ok {
68		s.setError(fmt.Errorf("could not parse %q", value))
69		return "", false
70	}
71	if _, dup := s.test.Values[key]; dup {
72		s.setError(fmt.Errorf("duplicate key %q", key))
73		return "", false
74	}
75	s.test.Values[key] = valueInt
76	return key, true
77}
78
79func (s *testScanner) Scan() bool {
80	s.test = test{
81		Values: make(map[string]*big.Int),
82	}
83
84	// Scan until the first attribute.
85	for {
86		if !s.scanLine() {
87			return false
88		}
89		if len(s.scanner.Text()) != 0 && s.scanner.Text()[0] != '#' {
90			break
91		}
92	}
93
94	var ok bool
95	s.test.Type, ok = s.addAttribute(s.scanner.Text())
96	if !ok {
97		return false
98	}
99	s.test.LineNumber = s.lineNo
100
101	for s.scanLine() {
102		if len(s.scanner.Text()) == 0 {
103			break
104		}
105
106		if s.scanner.Text()[0] == '#' {
107			continue
108		}
109
110		if _, ok := s.addAttribute(s.scanner.Text()); !ok {
111			return false
112		}
113	}
114	return s.scanner.Err() == nil
115}
116
117func (s *testScanner) Test() test {
118	return s.test
119}
120
121func (s *testScanner) Err() error {
122	if s.err != nil {
123		return s.err
124	}
125	return s.scanner.Err()
126}
127
128func (s *testScanner) setError(err error) {
129	s.err = fmt.Errorf("line %d: %s", s.lineNo, err)
130}
131
132func checkKeys(t test, keys ...string) bool {
133	var foundErrors bool
134
135	for _, k := range keys {
136		if _, ok := t.Values[k]; !ok {
137			fmt.Fprintf(os.Stderr, "Line %d: missing key %q.\n", t.LineNumber, k)
138			foundErrors = true
139		}
140	}
141
142	for k, _ := range t.Values {
143		var found bool
144		for _, k2 := range keys {
145			if k == k2 {
146				found = true
147				break
148			}
149		}
150		if !found {
151			fmt.Fprintf(os.Stderr, "Line %d: unexpected key %q.\n", t.LineNumber, k)
152			foundErrors = true
153		}
154	}
155
156	return !foundErrors
157}
158
// appendLengthPrefixed appends b to ret, preceded by len(b) encoded as a
// 16-bit big-endian integer, and returns the extended slice.
//
// It panics if b is longer than 65535 bytes: such a length does not fit in
// the prefix, and silently truncating it would emit a corrupt encoding.
func appendLengthPrefixed(ret, b []byte) []byte {
	if len(b) > 0xffff {
		panic("appendLengthPrefixed: value too long for 16-bit length prefix")
	}
	ret = append(ret, byte(len(b)>>8), byte(len(b)))
	return append(ret, b...)
}
164
165func appendUnsigned(ret []byte, n *big.Int) []byte {
166	b := n.Bytes()
167	if n.Sign() == 0 {
168		b = []byte{0}
169	}
170	return appendLengthPrefixed(ret, b)
171}
172
173func appendSigned(ret []byte, n *big.Int) []byte {
174	var sign byte
175	if n.Sign() < 0 {
176		sign = 1
177	}
178	b := []byte{sign}
179	b = append(b, n.Bytes()...)
180	if n.Sign() == 0 {
181		b = append(b, 0)
182	}
183	return appendLengthPrefixed(ret, b)
184}
185
186func main() {
187	if len(os.Args) != 3 {
188		fmt.Fprintf(os.Stderr, "Usage: %s TESTS FUZZ_DIR\n", os.Args[0])
189		os.Exit(1)
190	}
191
192	in, err := os.Open(os.Args[1])
193	if err != nil {
194		fmt.Fprintf(os.Stderr, "Error opening %s: %s.\n", os.Args[0], err)
195		os.Exit(1)
196	}
197	defer in.Close()
198
199	fuzzerDir := os.Args[2]
200
201	scanner := newTestScanner(in)
202	for scanner.Scan() {
203		var fuzzer string
204		var b []byte
205		test := scanner.Test()
206		switch test.Type {
207		case "Quotient":
208			if checkKeys(test, "A", "B", "Quotient", "Remainder") {
209				fuzzer = "bn_div"
210				b = appendSigned(b, test.Values["A"])
211				b = appendSigned(b, test.Values["B"])
212			}
213		case "ModExp":
214			if checkKeys(test, "A", "E", "M", "ModExp") {
215				fuzzer = "bn_mod_exp"
216				b = appendSigned(b, test.Values["A"])
217				b = appendUnsigned(b, test.Values["E"])
218				b = appendUnsigned(b, test.Values["M"])
219			}
220		}
221
222		if len(fuzzer) != 0 {
223			hash := sha1.Sum(b)
224			path := filepath.Join(fuzzerDir, fuzzer + "_corpus", hex.EncodeToString(hash[:]))
225			if err := ioutil.WriteFile(path, b, 0666); err != nil {
226				fmt.Fprintf(os.Stderr, "Error writing to %s: %s.\n", path, err)
227				os.Exit(1)
228			}
229		}
230	}
231	if scanner.Err() != nil {
232		fmt.Fprintf(os.Stderr, "Error reading tests: %s.\n", scanner.Err())
233	}
234}
235