1// Copyright 2016 syzkaller project authors. All rights reserved.
2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
3
4package state
5
import (
	"fmt"
	"io/ioutil"
	"os"
	"path/filepath"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/google/syzkaller/pkg/db"
	"github.com/google/syzkaller/pkg/hash"
	"github.com/google/syzkaller/pkg/log"
	"github.com/google/syzkaller/pkg/osutil"
	"github.com/google/syzkaller/prog"
)
21
22// State holds all internal syz-hub state including corpus,
23// reproducers and information about managers.
24// It is persisted to and can be restored from a directory.
25type State struct {
26	corpusSeq uint64
27	reproSeq  uint64
28	dir       string
29	Corpus    *db.DB
30	Repros    *db.DB
31	Managers  map[string]*Manager
32}
33
34// Manager represents one syz-manager instance.
35type Manager struct {
36	name          string
37	corpusSeq     uint64
38	reproSeq      uint64
39	corpusFile    string
40	corpusSeqFile string
41	reproSeqFile  string
42	ownRepros     map[string]bool
43	Connected     time.Time
44	Added         int
45	Deleted       int
46	New           int
47	SentRepros    int
48	RecvRepros    int
49	Calls         map[string]struct{}
50	Corpus        *db.DB
51}
52
53// Make creates State and initializes it from dir.
54func Make(dir string) (*State, error) {
55	st := &State{
56		dir:      dir,
57		Managers: make(map[string]*Manager),
58	}
59
60	osutil.MkdirAll(st.dir)
61	st.Corpus, st.corpusSeq = loadDB(filepath.Join(st.dir, "corpus.db"), "corpus")
62	st.Repros, st.reproSeq = loadDB(filepath.Join(st.dir, "repro.db"), "repro")
63
64	managersDir := filepath.Join(st.dir, "manager")
65	osutil.MkdirAll(managersDir)
66	managers, err := ioutil.ReadDir(managersDir)
67	if err != nil {
68		return nil, fmt.Errorf("failed to read %v dir: %v", managersDir, err)
69	}
70	for _, manager := range managers {
71		_, err := st.createManager(manager.Name())
72		if err != nil {
73			return nil, err
74		}
75	}
76	log.Logf(0, "purging corpus...")
77	st.purgeCorpus()
78	log.Logf(0, "done, %v programs", len(st.Corpus.Records))
79
80	return st, err
81}
82
83func loadDB(file, name string) (*db.DB, uint64) {
84	log.Logf(0, "reading %v...", name)
85	db, err := db.Open(file)
86	if err != nil {
87		log.Fatalf("failed to open %v database: %v", name, err)
88	}
89	log.Logf(0, "read %v programs", len(db.Records))
90	var maxSeq uint64
91	for key, rec := range db.Records {
92		if _, err := prog.CallSet(rec.Val); err != nil {
93			log.Logf(0, "bad file: can't parse call set: %v", err)
94			db.Delete(key)
95			continue
96		}
97		if sig := hash.Hash(rec.Val); sig.String() != key {
98			log.Logf(0, "bad file: hash %v, want hash %v", key, sig.String())
99			db.Delete(key)
100			continue
101		}
102		if maxSeq < rec.Seq {
103			maxSeq = rec.Seq
104		}
105	}
106	if err := db.Flush(); err != nil {
107		log.Fatalf("failed to flush corpus database: %v", err)
108	}
109	return db, maxSeq
110}
111
112func (st *State) createManager(name string) (*Manager, error) {
113	dir := filepath.Join(st.dir, "manager", name)
114	osutil.MkdirAll(dir)
115	mgr := &Manager{
116		name:          name,
117		corpusFile:    filepath.Join(dir, "corpus.db"),
118		corpusSeqFile: filepath.Join(dir, "seq"),
119		reproSeqFile:  filepath.Join(dir, "repro.seq"),
120		ownRepros:     make(map[string]bool),
121	}
122	mgr.corpusSeq = loadSeqFile(mgr.corpusSeqFile)
123	if st.corpusSeq < mgr.corpusSeq {
124		st.corpusSeq = mgr.corpusSeq
125	}
126	mgr.reproSeq = loadSeqFile(mgr.reproSeqFile)
127	if mgr.reproSeq == 0 {
128		mgr.reproSeq = st.reproSeq
129	}
130	if st.reproSeq < mgr.reproSeq {
131		st.reproSeq = mgr.reproSeq
132	}
133	var err error
134	mgr.Corpus, err = db.Open(mgr.corpusFile)
135	if err != nil {
136		return nil, fmt.Errorf("failed to open manager corpus %v: %v", mgr.corpusFile, err)
137	}
138	log.Logf(0, "created manager %v: corpus=%v, corpusSeq=%v, reproSeq=%v",
139		mgr.name, len(mgr.Corpus.Records), mgr.corpusSeq, mgr.reproSeq)
140	st.Managers[name] = mgr
141	return mgr, nil
142}
143
144func (st *State) Connect(name string, fresh bool, calls []string, corpus [][]byte) error {
145	mgr := st.Managers[name]
146	if mgr == nil {
147		var err error
148		mgr, err = st.createManager(name)
149		if err != nil {
150			return err
151		}
152	}
153	mgr.Connected = time.Now()
154	if fresh {
155		mgr.corpusSeq = 0
156		mgr.reproSeq = st.reproSeq
157	}
158	saveSeqFile(mgr.corpusSeqFile, mgr.corpusSeq)
159	saveSeqFile(mgr.reproSeqFile, mgr.reproSeq)
160
161	mgr.Calls = make(map[string]struct{})
162	for _, c := range calls {
163		mgr.Calls[c] = struct{}{}
164	}
165
166	os.Remove(mgr.corpusFile)
167	var err error
168	mgr.Corpus, err = db.Open(mgr.corpusFile)
169	if err != nil {
170		log.Logf(0, "failed to open corpus database: %v", err)
171		return err
172	}
173	st.addInputs(mgr, corpus)
174	st.purgeCorpus()
175	return nil
176}
177
178func (st *State) Sync(name string, add [][]byte, del []string) ([][]byte, int, error) {
179	mgr := st.Managers[name]
180	if mgr == nil || mgr.Connected.IsZero() {
181		return nil, 0, fmt.Errorf("unconnected manager %v", name)
182	}
183	if len(del) != 0 {
184		for _, sig := range del {
185			mgr.Corpus.Delete(sig)
186		}
187		if err := mgr.Corpus.Flush(); err != nil {
188			log.Logf(0, "failed to flush corpus database: %v", err)
189		}
190		st.purgeCorpus()
191	}
192	st.addInputs(mgr, add)
193	progs, more, err := st.pendingInputs(mgr)
194	mgr.Added += len(add)
195	mgr.Deleted += len(del)
196	mgr.New += len(progs)
197	return progs, more, err
198}
199
200func (st *State) AddRepro(name string, repro []byte) error {
201	mgr := st.Managers[name]
202	if mgr == nil || mgr.Connected.IsZero() {
203		return fmt.Errorf("unconnected manager %v", name)
204	}
205	if _, err := prog.CallSet(repro); err != nil {
206		log.Logf(0, "manager %v: failed to extract call set: %v, program:\n%v",
207			mgr.name, err, string(repro))
208		return nil
209	}
210	sig := hash.String(repro)
211	if _, ok := st.Repros.Records[sig]; ok {
212		return nil
213	}
214	mgr.ownRepros[sig] = true
215	mgr.SentRepros++
216	if mgr.reproSeq == st.reproSeq {
217		mgr.reproSeq++
218		saveSeqFile(mgr.reproSeqFile, mgr.reproSeq)
219	}
220	st.reproSeq++
221	st.Repros.Save(sig, repro, st.reproSeq)
222	if err := st.Repros.Flush(); err != nil {
223		log.Logf(0, "failed to flush repro database: %v", err)
224	}
225	return nil
226}
227
228func (st *State) PendingRepro(name string) ([]byte, error) {
229	mgr := st.Managers[name]
230	if mgr == nil || mgr.Connected.IsZero() {
231		return nil, fmt.Errorf("unconnected manager %v", name)
232	}
233	if mgr.reproSeq == st.reproSeq {
234		return nil, nil
235	}
236	var repro []byte
237	minSeq := ^uint64(0)
238	for key, rec := range st.Repros.Records {
239		if mgr.reproSeq >= rec.Seq {
240			continue
241		}
242		if mgr.ownRepros[key] {
243			continue
244		}
245		calls, err := prog.CallSet(rec.Val)
246		if err != nil {
247			return nil, fmt.Errorf("failed to extract call set: %v\nprogram: %s", err, rec.Val)
248		}
249		if !managerSupportsAllCalls(mgr.Calls, calls) {
250			continue
251		}
252		if minSeq > rec.Seq {
253			minSeq = rec.Seq
254			repro = rec.Val
255		}
256	}
257	if repro == nil {
258		mgr.reproSeq = st.reproSeq
259		saveSeqFile(mgr.reproSeqFile, mgr.reproSeq)
260		return nil, nil
261	}
262	mgr.RecvRepros++
263	mgr.reproSeq = minSeq
264	saveSeqFile(mgr.reproSeqFile, mgr.reproSeq)
265	return repro, nil
266}
267
268func (st *State) pendingInputs(mgr *Manager) ([][]byte, int, error) {
269	if mgr.corpusSeq == st.corpusSeq {
270		return nil, 0, nil
271	}
272	var records []db.Record
273	for key, rec := range st.Corpus.Records {
274		if mgr.corpusSeq >= rec.Seq {
275			continue
276		}
277		if _, ok := mgr.Corpus.Records[key]; ok {
278			continue
279		}
280		calls, err := prog.CallSet(rec.Val)
281		if err != nil {
282			return nil, 0, fmt.Errorf("failed to extract call set: %v\nprogram: %s", err, rec.Val)
283		}
284		if !managerSupportsAllCalls(mgr.Calls, calls) {
285			continue
286		}
287		records = append(records, rec)
288	}
289	maxSeq := st.corpusSeq
290	more := 0
291	// Send at most that many records (rounded up to next seq number).
292	const maxRecords = 100
293	if len(records) > maxRecords {
294		sort.Sort(recordSeqSorter(records))
295		pos := maxRecords
296		maxSeq = records[pos].Seq
297		for pos+1 < len(records) && records[pos+1].Seq == maxSeq {
298			pos++
299		}
300		pos++
301		more = len(records) - pos
302		records = records[:pos]
303	}
304	progs := make([][]byte, len(records))
305	for _, rec := range records {
306		progs = append(progs, rec.Val)
307	}
308	mgr.corpusSeq = maxSeq
309	saveSeqFile(mgr.corpusSeqFile, mgr.corpusSeq)
310	return progs, more, nil
311}
312
313func (st *State) addInputs(mgr *Manager, inputs [][]byte) {
314	if len(inputs) == 0 {
315		return
316	}
317	st.corpusSeq++
318	for _, input := range inputs {
319		st.addInput(mgr, input)
320	}
321	if err := mgr.Corpus.Flush(); err != nil {
322		log.Logf(0, "failed to flush corpus database: %v", err)
323	}
324	if err := st.Corpus.Flush(); err != nil {
325		log.Logf(0, "failed to flush corpus database: %v", err)
326	}
327}
328
329func (st *State) addInput(mgr *Manager, input []byte) {
330	if _, err := prog.CallSet(input); err != nil {
331		log.Logf(0, "manager %v: failed to extract call set: %v, program:\n%v", mgr.name, err, string(input))
332		return
333	}
334	sig := hash.String(input)
335	mgr.Corpus.Save(sig, nil, 0)
336	if _, ok := st.Corpus.Records[sig]; !ok {
337		st.Corpus.Save(sig, input, st.corpusSeq)
338	}
339}
340
341func (st *State) purgeCorpus() {
342	used := make(map[string]bool)
343	for _, mgr := range st.Managers {
344		for sig := range mgr.Corpus.Records {
345			used[sig] = true
346		}
347	}
348	for key := range st.Corpus.Records {
349		if used[key] {
350			continue
351		}
352		st.Corpus.Delete(key)
353	}
354	if err := st.Corpus.Flush(); err != nil {
355		log.Logf(0, "failed to flush corpus database: %v", err)
356	}
357}
358
// managerSupportsAllCalls reports whether every syscall in prog is also in
// mgr, i.e. whether prog is a subset of mgr. An empty prog is always
// supported.
func managerSupportsAllCalls(mgr, prog map[string]struct{}) bool {
	for call := range prog {
		_, supported := mgr[call]
		if !supported {
			return false
		}
	}
	return true
}
367
368func writeFile(name string, data []byte) {
369	if err := osutil.WriteFile(name, data); err != nil {
370		log.Logf(0, "failed to write file %v: %v", name, err)
371	}
372}
373
374func saveSeqFile(filename string, seq uint64) {
375	writeFile(filename, []byte(fmt.Sprint(seq)))
376}
377
// loadSeqFile reads a sequence number previously written by saveSeqFile.
//
// Surrounding whitespace (e.g. a trailing newline from a manual edit) is
// ignored. A missing, unreadable or malformed file yields 0, the "no state
// yet" value. This includes numbers too large for uint64: strconv.ParseUint
// would otherwise return MaxUint64 for them, and a later increment of the
// hub-wide counter would wrap to 0.
func loadSeqFile(filename string) uint64 {
	data, err := ioutil.ReadFile(filename)
	if err != nil {
		return 0
	}
	seq, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
	if err != nil {
		return 0
	}
	return seq
}
383
384type recordSeqSorter []db.Record
385
386func (a recordSeqSorter) Len() int {
387	return len(a)
388}
389
390func (a recordSeqSorter) Less(i, j int) bool {
391	return a[i].Seq < a[j].Seq
392}
393
394func (a recordSeqSorter) Swap(i, j int) {
395	a[i], a[j] = a[j], a[i]
396}
397