1// Copyright 2017 syzkaller project authors. All rights reserved.
2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
3
4package main
5
6import (
7	"flag"
8	"fmt"
9	"io/ioutil"
10	"os"
11	"path/filepath"
12	"strconv"
13	"strings"
14
15	"github.com/google/syzkaller/pkg/db"
16	"github.com/google/syzkaller/pkg/hash"
17	"github.com/google/syzkaller/pkg/osutil"
18	"github.com/google/syzkaller/prog"
19	_ "github.com/google/syzkaller/sys"
20)
21
22func main() {
23	var (
24		flagVersion = flag.Uint64("version", 0, "database version")
25		flagOS      = flag.String("os", "", "target OS")
26		flagArch    = flag.String("arch", "", "target arch")
27	)
28	flag.Parse()
29	args := flag.Args()
30	if len(args) != 3 {
31		usage()
32	}
33	var target *prog.Target
34	if *flagOS != "" || *flagArch != "" {
35		var err error
36		target, err = prog.GetTarget(*flagOS, *flagArch)
37		if err != nil {
38			failf("failed to find target: %v", err)
39		}
40	}
41	switch args[0] {
42	case "pack":
43		pack(args[1], args[2], target, *flagVersion)
44	case "unpack":
45		unpack(args[1], args[2])
46	default:
47		usage()
48	}
49}
50
// usage prints the supported invocations to stderr and exits the process
// with status 1. It never returns.
func usage() {
	const text = "usage:\n" +
		"  syz-db pack dir corpus.db\n" +
		"  syz-db unpack corpus.db dir\n"
	fmt.Fprint(os.Stderr, text)
	os.Exit(1)
}
57
58func pack(dir, file string, target *prog.Target, version uint64) {
59	files, err := ioutil.ReadDir(dir)
60	if err != nil {
61		failf("failed to read dir: %v", err)
62	}
63	os.Remove(file)
64	db, err := db.Open(file)
65	if err != nil {
66		failf("failed to open database file: %v", err)
67	}
68	if err := db.BumpVersion(version); err != nil {
69		failf("failed to bump database version: %v", err)
70	}
71	for _, file := range files {
72		data, err := ioutil.ReadFile(filepath.Join(dir, file.Name()))
73		if err != nil {
74			failf("failed to read file %v: %v", file.Name(), err)
75		}
76		var seq uint64
77		key := file.Name()
78		if parts := strings.Split(file.Name(), "-"); len(parts) == 2 {
79			var err error
80			if seq, err = strconv.ParseUint(parts[1], 10, 64); err == nil {
81				key = parts[0]
82			}
83		}
84		if sig := hash.String(data); key != sig {
85			if target != nil {
86				p, err := target.Deserialize(data)
87				if err != nil {
88					failf("failed to deserialize %v: %v", file.Name(), err)
89				}
90				data = p.Serialize()
91				sig = hash.String(data)
92			}
93			fmt.Fprintf(os.Stderr, "fixing hash %v -> %v\n", key, sig)
94			key = sig
95		}
96		db.Save(key, data, seq)
97	}
98	if err := db.Flush(); err != nil {
99		failf("failed to save database file: %v", err)
100	}
101}
102
103func unpack(file, dir string) {
104	db, err := db.Open(file)
105	if err != nil {
106		failf("failed to open database: %v", err)
107	}
108	osutil.MkdirAll(dir)
109	for key, rec := range db.Records {
110		fname := filepath.Join(dir, key)
111		if rec.Seq != 0 {
112			fname += fmt.Sprintf("-%v", rec.Seq)
113		}
114		if err := osutil.WriteFile(fname, rec.Val); err != nil {
115			failf("failed to output file: %v", err)
116		}
117	}
118}
119
// failf formats msg with args printf-style, writes the result plus a trailing
// newline to stderr, and terminates the process with exit status 1.
// It never returns.
func failf(msg string, args ...interface{}) {
	format := msg + "\n"
	fmt.Fprintf(os.Stderr, format, args...)
	os.Exit(1)
}
124