1// Copyright 2017 syzkaller project authors. All rights reserved. 2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 3 4// Package db implements a simple key-value database. 5// The database is cached in memory and mirrored on disk. 6// It is used to store corpus in syz-manager and syz-hub. 7// The database strives to minimize number of disk accesses 8// as they can be slow in virtualized environments (GCE). 9package db 10 11import ( 12 "bufio" 13 "bytes" 14 "compress/flate" 15 "encoding/binary" 16 "fmt" 17 "io" 18 "io/ioutil" 19 "os" 20 21 "github.com/google/syzkaller/pkg/log" 22 "github.com/google/syzkaller/pkg/osutil" 23) 24 25type DB struct { 26 Version uint64 // arbitrary user version (0 for new database) 27 Records map[string]Record // in-memory cache, must not be modified directly 28 29 filename string 30 uncompacted int // number of records in the file 31 pending *bytes.Buffer // pending writes to the file 32} 33 34type Record struct { 35 Val []byte 36 Seq uint64 37} 38 39func Open(filename string) (*DB, error) { 40 db := &DB{ 41 filename: filename, 42 } 43 f, err := os.OpenFile(db.filename, os.O_RDONLY|os.O_CREATE, osutil.DefaultFilePerm) 44 if err != nil { 45 return nil, err 46 } 47 db.Version, db.Records, db.uncompacted = deserializeDB(bufio.NewReader(f)) 48 f.Close() 49 if len(db.Records) == 0 || db.uncompacted/10*9 > len(db.Records) { 50 if err := db.compact(); err != nil { 51 return nil, err 52 } 53 } 54 return db, nil 55} 56 57func (db *DB) Save(key string, val []byte, seq uint64) { 58 if seq == seqDeleted { 59 panic("reserved seq") 60 } 61 if rec, ok := db.Records[key]; ok && seq == rec.Seq && bytes.Equal(val, rec.Val) { 62 return 63 } 64 db.Records[key] = Record{val, seq} 65 db.serialize(key, val, seq) 66 db.uncompacted++ 67} 68 69func (db *DB) Delete(key string) { 70 if _, ok := db.Records[key]; !ok { 71 return 72 } 73 delete(db.Records, key) 74 db.serialize(key, nil, seqDeleted) 75 db.uncompacted++ 76} 77 78func (db *DB) Flush() error { 79 if db.uncompacted/10*9 > len(db.Records) { 80 return db.compact() 81 } 82 if db.pending == nil { 83 return nil 84 } 85 f, err := os.OpenFile(db.filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, osutil.DefaultFilePerm) 86 if err != nil { 87 return err 88 } 89 defer f.Close() 90 if _, err := f.Write(db.pending.Bytes()); err != nil { 91 return err 92 } 93 db.pending = nil 94 return nil 95} 96 97func (db *DB) BumpVersion(version uint64) error { 98 if db.Version == version { 99 return db.Flush() 100 } 101 db.Version = version 102 return db.compact() 103} 104 105func (db *DB) compact() error { 106 buf := new(bytes.Buffer) 107 serializeHeader(buf, db.Version) 108 for key, rec := range db.Records { 109 serializeRecord(buf, key, rec.Val, rec.Seq) 110 } 111 f, err := os.Create(db.filename + ".tmp") 112 if err != nil { 113 return err 114 } 115 defer f.Close() 116 if _, err := f.Write(buf.Bytes()); err != nil { 117 return err 118 } 119 f.Close() 120 if err := os.Rename(f.Name(), db.filename); err != nil { 121 return err 122 } 123 db.uncompacted = len(db.Records) 124 db.pending = nil 125 return nil 126} 127 128func (db *DB) serialize(key string, val []byte, seq uint64) { 129 if db.pending == nil { 130 db.pending = new(bytes.Buffer) 131 } 132 serializeRecord(db.pending, key, val, seq) 133} 134 135const ( 136 dbMagic = uint32(0xbaddb) 137 recMagic = uint32(0xfee1bad) 138 curVersion = uint32(2) 139 seqDeleted = ^uint64(0) 140) 141 142func serializeHeader(w *bytes.Buffer, version uint64) { 143 binary.Write(w, binary.LittleEndian, dbMagic) 144 binary.Write(w, binary.LittleEndian, curVersion) 145 binary.Write(w, binary.LittleEndian, version) 146} 147 148func serializeRecord(w *bytes.Buffer, key string, val []byte, seq uint64) { 149 binary.Write(w, binary.LittleEndian, recMagic) 150 binary.Write(w, binary.LittleEndian, uint32(len(key))) 151 w.WriteString(key) 152 binary.Write(w, binary.LittleEndian, seq) 153 if seq == seqDeleted { 154 if len(val) != 0 { 155 panic("deleting record with value") 156 } 157 return 158 } 159 if len(val) == 0 { 160 binary.Write(w, binary.LittleEndian, uint32(len(val))) 161 } else { 162 lenPos := len(w.Bytes()) 163 binary.Write(w, binary.LittleEndian, uint32(0)) 164 startPos := len(w.Bytes()) 165 fw, err := flate.NewWriter(w, flate.BestCompression) 166 if err != nil { 167 panic(err) 168 } 169 if _, err := fw.Write(val); err != nil { 170 panic(err) 171 } 172 fw.Close() 173 binary.Write(bytes.NewBuffer(w.Bytes()[lenPos:lenPos:lenPos+8]), binary.LittleEndian, uint32(len(w.Bytes())-startPos)) 174 } 175} 176 177func deserializeDB(r *bufio.Reader) (version uint64, records map[string]Record, uncompacted int) { 178 records = make(map[string]Record) 179 ver, err := deserializeHeader(r) 180 if err != nil { 181 log.Logf(0, "failed to deserialize database header: %v", err) 182 return 183 } 184 version = ver 185 for { 186 key, val, seq, err := deserializeRecord(r) 187 if err == io.EOF { 188 return 189 } 190 if err != nil { 191 log.Logf(0, "failed to deserialize database record: %v", err) 192 return 193 } 194 uncompacted++ 195 if seq == seqDeleted { 196 delete(records, key) 197 } else { 198 records[key] = Record{val, seq} 199 } 200 } 201} 202 203func deserializeHeader(r *bufio.Reader) (uint64, error) { 204 var magic, ver uint32 205 if err := binary.Read(r, binary.LittleEndian, &magic); err != nil { 206 if err == io.EOF { 207 return 0, nil 208 } 209 return 0, err 210 } 211 if magic != dbMagic { 212 return 0, fmt.Errorf("bad db header: 0x%x", magic) 213 } 214 if err := binary.Read(r, binary.LittleEndian, &ver); err != nil { 215 return 0, err 216 } 217 if ver == 0 || ver > curVersion { 218 return 0, fmt.Errorf("bad db version: %v", ver) 219 } 220 var userVer uint64 221 if ver >= 2 { 222 if err := binary.Read(r, binary.LittleEndian, &userVer); err != nil { 223 return 0, err 224 } 225 } 226 return userVer, nil 227} 228 229func deserializeRecord(r *bufio.Reader) (key string, val []byte, seq uint64, err error) { 230 var magic uint32 231 if err = binary.Read(r, binary.LittleEndian, &magic); err != nil { 232 return 233 } 234 if magic != recMagic { 235 err = fmt.Errorf("bad record header: 0x%x", magic) 236 return 237 } 238 var keyLen uint32 239 if err = binary.Read(r, binary.LittleEndian, &keyLen); err != nil { 240 return 241 } 242 keyBuf := make([]byte, keyLen) 243 if _, err = io.ReadFull(r, keyBuf); err != nil { 244 return 245 } 246 key = string(keyBuf) 247 if err = binary.Read(r, binary.LittleEndian, &seq); err != nil { 248 return 249 } 250 if seq == seqDeleted { 251 return 252 } 253 var valLen uint32 254 if err = binary.Read(r, binary.LittleEndian, &valLen); err != nil { 255 return 256 } 257 if valLen != 0 { 258 fr := flate.NewReader(&io.LimitedReader{R: r, N: int64(valLen)}) 259 if val, err = ioutil.ReadAll(fr); err != nil { 260 return 261 } 262 fr.Close() 263 } 264 return 265} 266