1// Copyright 2017 syzkaller project authors. All rights reserved. 2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 3 4package main 5 6import ( 7 "flag" 8 "fmt" 9 "io/ioutil" 10 "os" 11 "path/filepath" 12 "strconv" 13 "strings" 14 15 "github.com/google/syzkaller/pkg/db" 16 "github.com/google/syzkaller/pkg/hash" 17 "github.com/google/syzkaller/pkg/osutil" 18 "github.com/google/syzkaller/prog" 19 _ "github.com/google/syzkaller/sys" 20) 21 22func main() { 23 var ( 24 flagVersion = flag.Uint64("version", 0, "database version") 25 flagOS = flag.String("os", "", "target OS") 26 flagArch = flag.String("arch", "", "target arch") 27 ) 28 flag.Parse() 29 args := flag.Args() 30 if len(args) != 3 { 31 usage() 32 } 33 var target *prog.Target 34 if *flagOS != "" || *flagArch != "" { 35 var err error 36 target, err = prog.GetTarget(*flagOS, *flagArch) 37 if err != nil { 38 failf("failed to find target: %v", err) 39 } 40 } 41 switch args[0] { 42 case "pack": 43 pack(args[1], args[2], target, *flagVersion) 44 case "unpack": 45 unpack(args[1], args[2]) 46 default: 47 usage() 48 } 49} 50 51func usage() { 52 fmt.Fprintf(os.Stderr, "usage:\n") 53 fmt.Fprintf(os.Stderr, " syz-db pack dir corpus.db\n") 54 fmt.Fprintf(os.Stderr, " syz-db unpack corpus.db dir\n") 55 os.Exit(1) 56} 57 58func pack(dir, file string, target *prog.Target, version uint64) { 59 files, err := ioutil.ReadDir(dir) 60 if err != nil { 61 failf("failed to read dir: %v", err) 62 } 63 os.Remove(file) 64 db, err := db.Open(file) 65 if err != nil { 66 failf("failed to open database file: %v", err) 67 } 68 if err := db.BumpVersion(version); err != nil { 69 failf("failed to bump database version: %v", err) 70 } 71 for _, file := range files { 72 data, err := ioutil.ReadFile(filepath.Join(dir, file.Name())) 73 if err != nil { 74 failf("failed to read file %v: %v", file.Name(), err) 75 } 76 var seq uint64 77 key := file.Name() 78 if parts := strings.Split(file.Name(), "-"); len(parts) == 2 { 79 var err error 80 if seq, err = strconv.ParseUint(parts[1], 10, 64); err == nil { 81 key = parts[0] 82 } 83 } 84 if sig := hash.String(data); key != sig { 85 if target != nil { 86 p, err := target.Deserialize(data) 87 if err != nil { 88 failf("failed to deserialize %v: %v", file.Name(), err) 89 } 90 data = p.Serialize() 91 sig = hash.String(data) 92 } 93 fmt.Fprintf(os.Stderr, "fixing hash %v -> %v\n", key, sig) 94 key = sig 95 } 96 db.Save(key, data, seq) 97 } 98 if err := db.Flush(); err != nil { 99 failf("failed to save database file: %v", err) 100 } 101} 102 103func unpack(file, dir string) { 104 db, err := db.Open(file) 105 if err != nil { 106 failf("failed to open database: %v", err) 107 } 108 osutil.MkdirAll(dir) 109 for key, rec := range db.Records { 110 fname := filepath.Join(dir, key) 111 if rec.Seq != 0 { 112 fname += fmt.Sprintf("-%v", rec.Seq) 113 } 114 if err := osutil.WriteFile(fname, rec.Val); err != nil { 115 failf("failed to output file: %v", err) 116 } 117 } 118} 119 120func failf(msg string, args ...interface{}) { 121 fmt.Fprintf(os.Stderr, msg+"\n", args...) 122 os.Exit(1) 123} 124