1// Copyright 2017 syzkaller project authors. All rights reserved.
2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
3
4// Package db implements a simple key-value database.
5// The database is cached in memory and mirrored on disk.
6// It is used to store corpus in syz-manager and syz-hub.
// The database strives to minimize the number of disk accesses,
// as they can be slow in virtualized environments (GCE).
9package db
10
11import (
12	"bufio"
13	"bytes"
14	"compress/flate"
15	"encoding/binary"
16	"fmt"
17	"io"
18	"io/ioutil"
19	"os"
20
21	"github.com/google/syzkaller/pkg/log"
22	"github.com/google/syzkaller/pkg/osutil"
23)
24
25type DB struct {
26	Version uint64            // arbitrary user version (0 for new database)
27	Records map[string]Record // in-memory cache, must not be modified directly
28
29	filename    string
30	uncompacted int           // number of records in the file
31	pending     *bytes.Buffer // pending writes to the file
32}
33
// Record is a single database entry as kept in the in-memory cache.
type Record struct {
	// Val is the value stored for the key.
	Val []byte

	// Seq is the sequence number stored alongside the value.
	Seq uint64
}
38
39func Open(filename string) (*DB, error) {
40	db := &DB{
41		filename: filename,
42	}
43	f, err := os.OpenFile(db.filename, os.O_RDONLY|os.O_CREATE, osutil.DefaultFilePerm)
44	if err != nil {
45		return nil, err
46	}
47	db.Version, db.Records, db.uncompacted = deserializeDB(bufio.NewReader(f))
48	f.Close()
49	if len(db.Records) == 0 || db.uncompacted/10*9 > len(db.Records) {
50		if err := db.compact(); err != nil {
51			return nil, err
52		}
53	}
54	return db, nil
55}
56
57func (db *DB) Save(key string, val []byte, seq uint64) {
58	if seq == seqDeleted {
59		panic("reserved seq")
60	}
61	if rec, ok := db.Records[key]; ok && seq == rec.Seq && bytes.Equal(val, rec.Val) {
62		return
63	}
64	db.Records[key] = Record{val, seq}
65	db.serialize(key, val, seq)
66	db.uncompacted++
67}
68
69func (db *DB) Delete(key string) {
70	if _, ok := db.Records[key]; !ok {
71		return
72	}
73	delete(db.Records, key)
74	db.serialize(key, nil, seqDeleted)
75	db.uncompacted++
76}
77
78func (db *DB) Flush() error {
79	if db.uncompacted/10*9 > len(db.Records) {
80		return db.compact()
81	}
82	if db.pending == nil {
83		return nil
84	}
85	f, err := os.OpenFile(db.filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, osutil.DefaultFilePerm)
86	if err != nil {
87		return err
88	}
89	defer f.Close()
90	if _, err := f.Write(db.pending.Bytes()); err != nil {
91		return err
92	}
93	db.pending = nil
94	return nil
95}
96
97func (db *DB) BumpVersion(version uint64) error {
98	if db.Version == version {
99		return db.Flush()
100	}
101	db.Version = version
102	return db.compact()
103}
104
105func (db *DB) compact() error {
106	buf := new(bytes.Buffer)
107	serializeHeader(buf, db.Version)
108	for key, rec := range db.Records {
109		serializeRecord(buf, key, rec.Val, rec.Seq)
110	}
111	f, err := os.Create(db.filename + ".tmp")
112	if err != nil {
113		return err
114	}
115	defer f.Close()
116	if _, err := f.Write(buf.Bytes()); err != nil {
117		return err
118	}
119	f.Close()
120	if err := os.Rename(f.Name(), db.filename); err != nil {
121		return err
122	}
123	db.uncompacted = len(db.Records)
124	db.pending = nil
125	return nil
126}
127
128func (db *DB) serialize(key string, val []byte, seq uint64) {
129	if db.pending == nil {
130		db.pending = new(bytes.Buffer)
131	}
132	serializeRecord(db.pending, key, val, seq)
133}
134
const (
	// dbMagic is the first field of the file header.
	dbMagic uint32 = 0xbaddb

	// recMagic is the first field of every serialized record.
	recMagic uint32 = 0xfee1bad

	// curVersion is the file format version written into the header.
	curVersion uint32 = 2

	// seqDeleted is the reserved sequence number (all bits set) that marks
	// a record as a deletion.
	seqDeleted uint64 = 1<<64 - 1
)
141
142func serializeHeader(w *bytes.Buffer, version uint64) {
143	binary.Write(w, binary.LittleEndian, dbMagic)
144	binary.Write(w, binary.LittleEndian, curVersion)
145	binary.Write(w, binary.LittleEndian, version)
146}
147
148func serializeRecord(w *bytes.Buffer, key string, val []byte, seq uint64) {
149	binary.Write(w, binary.LittleEndian, recMagic)
150	binary.Write(w, binary.LittleEndian, uint32(len(key)))
151	w.WriteString(key)
152	binary.Write(w, binary.LittleEndian, seq)
153	if seq == seqDeleted {
154		if len(val) != 0 {
155			panic("deleting record with value")
156		}
157		return
158	}
159	if len(val) == 0 {
160		binary.Write(w, binary.LittleEndian, uint32(len(val)))
161	} else {
162		lenPos := len(w.Bytes())
163		binary.Write(w, binary.LittleEndian, uint32(0))
164		startPos := len(w.Bytes())
165		fw, err := flate.NewWriter(w, flate.BestCompression)
166		if err != nil {
167			panic(err)
168		}
169		if _, err := fw.Write(val); err != nil {
170			panic(err)
171		}
172		fw.Close()
173		binary.Write(bytes.NewBuffer(w.Bytes()[lenPos:lenPos:lenPos+8]), binary.LittleEndian, uint32(len(w.Bytes())-startPos))
174	}
175}
176
177func deserializeDB(r *bufio.Reader) (version uint64, records map[string]Record, uncompacted int) {
178	records = make(map[string]Record)
179	ver, err := deserializeHeader(r)
180	if err != nil {
181		log.Logf(0, "failed to deserialize database header: %v", err)
182		return
183	}
184	version = ver
185	for {
186		key, val, seq, err := deserializeRecord(r)
187		if err == io.EOF {
188			return
189		}
190		if err != nil {
191			log.Logf(0, "failed to deserialize database record: %v", err)
192			return
193		}
194		uncompacted++
195		if seq == seqDeleted {
196			delete(records, key)
197		} else {
198			records[key] = Record{val, seq}
199		}
200	}
201}
202
203func deserializeHeader(r *bufio.Reader) (uint64, error) {
204	var magic, ver uint32
205	if err := binary.Read(r, binary.LittleEndian, &magic); err != nil {
206		if err == io.EOF {
207			return 0, nil
208		}
209		return 0, err
210	}
211	if magic != dbMagic {
212		return 0, fmt.Errorf("bad db header: 0x%x", magic)
213	}
214	if err := binary.Read(r, binary.LittleEndian, &ver); err != nil {
215		return 0, err
216	}
217	if ver == 0 || ver > curVersion {
218		return 0, fmt.Errorf("bad db version: %v", ver)
219	}
220	var userVer uint64
221	if ver >= 2 {
222		if err := binary.Read(r, binary.LittleEndian, &userVer); err != nil {
223			return 0, err
224		}
225	}
226	return userVer, nil
227}
228
229func deserializeRecord(r *bufio.Reader) (key string, val []byte, seq uint64, err error) {
230	var magic uint32
231	if err = binary.Read(r, binary.LittleEndian, &magic); err != nil {
232		return
233	}
234	if magic != recMagic {
235		err = fmt.Errorf("bad record header: 0x%x", magic)
236		return
237	}
238	var keyLen uint32
239	if err = binary.Read(r, binary.LittleEndian, &keyLen); err != nil {
240		return
241	}
242	keyBuf := make([]byte, keyLen)
243	if _, err = io.ReadFull(r, keyBuf); err != nil {
244		return
245	}
246	key = string(keyBuf)
247	if err = binary.Read(r, binary.LittleEndian, &seq); err != nil {
248		return
249	}
250	if seq == seqDeleted {
251		return
252	}
253	var valLen uint32
254	if err = binary.Read(r, binary.LittleEndian, &valLen); err != nil {
255		return
256	}
257	if valLen != 0 {
258		fr := flate.NewReader(&io.LimitedReader{R: r, N: int64(valLen)})
259		if val, err = ioutil.ReadAll(fr); err != nil {
260			return
261		}
262		fr.Close()
263	}
264	return
265}
266