1// Copyright 2017 syzkaller project authors. All rights reserved.
2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
3
4package ast
5
6import (
7	"bytes"
8	"fmt"
9	"io"
10)
11
12func Format(desc *Description) []byte {
13	buf := new(bytes.Buffer)
14	FormatWriter(buf, desc)
15	return buf.Bytes()
16}
17
18func FormatWriter(w io.Writer, desc *Description) {
19	for _, n := range desc.Nodes {
20		s, ok := n.(serializer)
21		if !ok {
22			panic(fmt.Sprintf("unknown top level decl: %#v", n))
23		}
24		s.serialize(w)
25	}
26}
27
28func SerializeNode(n Node) string {
29	s, ok := n.(serializer)
30	if !ok {
31		panic(fmt.Sprintf("unknown node: %#v", n))
32	}
33	buf := new(bytes.Buffer)
34	s.serialize(buf)
35	return buf.String()
36}
37
38func FormatInt(v uint64, format IntFmt) string {
39	switch format {
40	case IntFmtDec:
41		return fmt.Sprint(v)
42	case IntFmtNeg:
43		return fmt.Sprint(int64(v))
44	case IntFmtHex:
45		return fmt.Sprintf("0x%x", v)
46	case IntFmtChar:
47		return fmt.Sprintf("'%c'", v)
48	default:
49		panic(fmt.Sprintf("unknown int format %v", format))
50	}
51}
52
// serializer is implemented by AST nodes that know how to print themselves
// back as description source text.
type serializer interface {
	// serialize writes the node's source form to out.
	serialize(out io.Writer)
}
56
57func (nl *NewLine) serialize(w io.Writer) {
58	fmt.Fprintf(w, "\n")
59}
60
61func (com *Comment) serialize(w io.Writer) {
62	fmt.Fprintf(w, "#%v\n", com.Text)
63}
64
65func (incl *Include) serialize(w io.Writer) {
66	fmt.Fprintf(w, "include <%v>\n", incl.File.Value)
67}
68
69func (inc *Incdir) serialize(w io.Writer) {
70	fmt.Fprintf(w, "incdir <%v>\n", inc.Dir.Value)
71}
72
73func (def *Define) serialize(w io.Writer) {
74	fmt.Fprintf(w, "define %v\t%v\n", def.Name.Name, fmtInt(def.Value))
75}
76
77func (res *Resource) serialize(w io.Writer) {
78	fmt.Fprintf(w, "resource %v[%v]", res.Name.Name, fmtType(res.Base))
79	for i, v := range res.Values {
80		fmt.Fprintf(w, "%v%v", comma(i, ": "), fmtInt(v))
81	}
82	fmt.Fprintf(w, "\n")
83}
84
85func (typedef *TypeDef) serialize(w io.Writer) {
86	fmt.Fprintf(w, "type %v%v", typedef.Name.Name, fmtIdentList(typedef.Args))
87	if typedef.Type != nil {
88		fmt.Fprintf(w, " %v\n", fmtType(typedef.Type))
89	}
90	if typedef.Struct != nil {
91		typedef.Struct.serialize(w)
92	}
93}
94
95func (c *Call) serialize(w io.Writer) {
96	fmt.Fprintf(w, "%v(", c.Name.Name)
97	for i, a := range c.Args {
98		fmt.Fprintf(w, "%v%v", comma(i, ""), fmtField(a))
99	}
100	fmt.Fprintf(w, ")")
101	if c.Ret != nil {
102		fmt.Fprintf(w, " %v", fmtType(c.Ret))
103	}
104	fmt.Fprintf(w, "\n")
105}
106
107func (str *Struct) serialize(w io.Writer) {
108	opening, closing := '{', '}'
109	if str.IsUnion {
110		opening, closing = '[', ']'
111	}
112	fmt.Fprintf(w, "%v %c\n", str.Name.Name, opening)
113	// Align all field types to the same column.
114	const tabWidth = 8
115	maxTabs := 0
116	for _, f := range str.Fields {
117		tabs := (len(f.Name.Name) + tabWidth) / tabWidth
118		if maxTabs < tabs {
119			maxTabs = tabs
120		}
121	}
122	for _, f := range str.Fields {
123		if f.NewBlock {
124			fmt.Fprintf(w, "\n")
125		}
126		for _, com := range f.Comments {
127			fmt.Fprintf(w, "#%v\n", com.Text)
128		}
129		fmt.Fprintf(w, "\t%v\t", f.Name.Name)
130		for tabs := len(f.Name.Name)/tabWidth + 1; tabs < maxTabs; tabs++ {
131			fmt.Fprintf(w, "\t")
132		}
133		fmt.Fprintf(w, "%v\n", fmtType(f.Type))
134	}
135	for _, com := range str.Comments {
136		fmt.Fprintf(w, "#%v\n", com.Text)
137	}
138	fmt.Fprintf(w, "%c", closing)
139	if attrs := fmtTypeList(str.Attrs); attrs != "" {
140		fmt.Fprintf(w, " %v", attrs)
141	}
142	fmt.Fprintf(w, "\n")
143}
144
145func (flags *IntFlags) serialize(w io.Writer) {
146	fmt.Fprintf(w, "%v = ", flags.Name.Name)
147	for i, v := range flags.Values {
148		fmt.Fprintf(w, "%v%v", comma(i, ""), fmtInt(v))
149	}
150	fmt.Fprintf(w, "\n")
151}
152
153func (flags *StrFlags) serialize(w io.Writer) {
154	fmt.Fprintf(w, "%v = ", flags.Name.Name)
155	for i, v := range flags.Values {
156		fmt.Fprintf(w, "%v\"%v\"", comma(i, ""), v.Value)
157	}
158	fmt.Fprintf(w, "\n")
159}
160
161func fmtField(f *Field) string {
162	return fmt.Sprintf("%v %v", f.Name.Name, fmtType(f.Type))
163}
164
165func (n *Type) serialize(w io.Writer) {
166	w.Write([]byte(fmtType(n)))
167}
168
169func fmtType(t *Type) string {
170	v := ""
171	switch {
172	case t.Ident != "":
173		v = t.Ident
174	case t.HasString:
175		v = fmt.Sprintf("\"%v\"", t.String)
176	default:
177		v = FormatInt(t.Value, t.ValueFmt)
178	}
179	if t.HasColon {
180		switch {
181		case t.Ident2 != "":
182			v += fmt.Sprintf(":%v", t.Ident2)
183		default:
184			v += fmt.Sprintf(":%v", FormatInt(t.Value2, t.Value2Fmt))
185		}
186	}
187	v += fmtTypeList(t.Args)
188	return v
189}
190
191func fmtTypeList(args []*Type) string {
192	if len(args) == 0 {
193		return ""
194	}
195	w := new(bytes.Buffer)
196	fmt.Fprintf(w, "[")
197	for i, t := range args {
198		fmt.Fprintf(w, "%v%v", comma(i, ""), fmtType(t))
199	}
200	fmt.Fprintf(w, "]")
201	return w.String()
202}
203
204func fmtIdentList(args []*Ident) string {
205	if len(args) == 0 {
206		return ""
207	}
208	w := new(bytes.Buffer)
209	fmt.Fprintf(w, "[")
210	for i, arg := range args {
211		fmt.Fprintf(w, "%v%v", comma(i, ""), arg.Name)
212	}
213	fmt.Fprintf(w, "]")
214	return w.String()
215}
216
217func fmtInt(i *Int) string {
218	switch {
219	case i.Ident != "":
220		return i.Ident
221	case i.CExpr != "":
222		return fmt.Sprintf("%v", i.CExpr)
223	default:
224		return FormatInt(i.Value, i.ValueFmt)
225	}
226}
227
// comma returns the separator to emit before the i-th element of a list.
// The first element (i == 0) gets the caller-supplied first string, and every
// later element gets ", ".
func comma(i int, first string) string {
	sep := ", "
	if i == 0 {
		sep = first
	}
	return sep
}
234