1// Mostly copied from Go's src/cmd/gofmt:
2// Copyright 2009 The Go Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style
4// license that can be found in the LICENSE file.
5
6package main
7
8import (
9	"bytes"
10	"flag"
11	"fmt"
12	"io"
13	"io/ioutil"
14	"os"
15	"os/exec"
16	"path/filepath"
17	"strings"
18	"unicode"
19
20	"github.com/google/blueprint/parser"
21)
22
23var (
24	// main operation modes
25	list            = flag.Bool("l", false, "list files that would be modified by bpmodify")
26	write           = flag.Bool("w", false, "write result to (source) file instead of stdout")
27	doDiff          = flag.Bool("d", false, "display diffs instead of rewriting files")
28	sortLists       = flag.Bool("s", false, "sort touched lists, even if they were unsorted")
29	parameter       = flag.String("parameter", "deps", "name of parameter to modify on each module")
30	targetedModules = new(identSet)
31	addIdents       = new(identSet)
32	removeIdents    = new(identSet)
33)
34
35func init() {
36	flag.Var(targetedModules, "m", "comma or whitespace separated list of modules on which to operate")
37	flag.Var(addIdents, "a", "comma or whitespace separated list of identifiers to add")
38	flag.Var(removeIdents, "r", "comma or whitespace separated list of identifiers to remove")
39}
40
// exitCode holds the status the process should exit with. It starts at 0
// and report sets it to 2 whenever an error is reported.
var exitCode int
44
45func report(err error) {
46	fmt.Fprintln(os.Stderr, err)
47	exitCode = 2
48}
49
// usage prints a one-line synopsis to standard error, then the flag
// defaults via flag.PrintDefaults, and terminates with exit status 2.
func usage() {
	const usageExitStatus = 2
	stderr := os.Stderr
	fmt.Fprintln(stderr, "usage: bpmodify [flags] [path ...]")
	flag.PrintDefaults()
	os.Exit(usageExitStatus)
}
55
56// If in == nil, the source is the contents of the file with the given filename.
57func processFile(filename string, in io.Reader, out io.Writer) error {
58	if in == nil {
59		f, err := os.Open(filename)
60		if err != nil {
61			return err
62		}
63		defer f.Close()
64		in = f
65	}
66
67	src, err := ioutil.ReadAll(in)
68	if err != nil {
69		return err
70	}
71
72	r := bytes.NewBuffer(src)
73
74	file, errs := parser.Parse(filename, r, parser.NewScope(nil))
75	if len(errs) > 0 {
76		for _, err := range errs {
77			fmt.Fprintln(os.Stderr, err)
78		}
79		return fmt.Errorf("%d parsing errors", len(errs))
80	}
81
82	modified, errs := findModules(file)
83	if len(errs) > 0 {
84		for _, err := range errs {
85			fmt.Fprintln(os.Stderr, err)
86		}
87		fmt.Fprintln(os.Stderr, "continuing...")
88	}
89
90	if modified {
91		res, err := parser.Print(file)
92		if err != nil {
93			return err
94		}
95
96		if *list {
97			fmt.Fprintln(out, filename)
98		}
99		if *write {
100			err = ioutil.WriteFile(filename, res, 0644)
101			if err != nil {
102				return err
103			}
104		}
105		if *doDiff {
106			data, err := diff(src, res)
107			if err != nil {
108				return fmt.Errorf("computing diff: %s", err)
109			}
110			fmt.Printf("diff %s bpfmt/%s\n", filename, filename)
111			out.Write(data)
112		}
113
114		if !*list && !*write && !*doDiff {
115			_, err = out.Write(res)
116		}
117	}
118
119	return err
120}
121
122func findModules(file *parser.File) (modified bool, errs []error) {
123
124	for _, def := range file.Defs {
125		if module, ok := def.(*parser.Module); ok {
126			for _, prop := range module.Properties {
127				if prop.Name == "name" && prop.Value.Type() == parser.StringType {
128					if targetedModule(prop.Value.Eval().(*parser.String).Value) {
129						m, newErrs := processModule(module, prop.Name, file)
130						errs = append(errs, newErrs...)
131						modified = modified || m
132					}
133				}
134			}
135		}
136	}
137
138	return modified, errs
139}
140
141func processModule(module *parser.Module, moduleName string,
142	file *parser.File) (modified bool, errs []error) {
143
144	for _, prop := range module.Properties {
145		if prop.Name == *parameter {
146			modified, errs = processParameter(prop.Value, *parameter, moduleName, file)
147			return
148		}
149	}
150
151	return false, nil
152}
153
154func processParameter(value parser.Expression, paramName, moduleName string,
155	file *parser.File) (modified bool, errs []error) {
156	if _, ok := value.(*parser.Variable); ok {
157		return false, []error{fmt.Errorf("parameter %s in module %s is a variable, unsupported",
158			paramName, moduleName)}
159	}
160
161	if _, ok := value.(*parser.Operator); ok {
162		return false, []error{fmt.Errorf("parameter %s in module %s is an expression, unsupported",
163			paramName, moduleName)}
164	}
165
166	list, ok := value.(*parser.List)
167	if !ok {
168		return false, []error{fmt.Errorf("expected parameter %s in module %s to be list, found %s",
169			paramName, moduleName, value.Type().String())}
170	}
171
172	wasSorted := parser.ListIsSorted(list)
173
174	for _, a := range addIdents.idents {
175		m := parser.AddStringToList(list, a)
176		modified = modified || m
177	}
178
179	for _, r := range removeIdents.idents {
180		m := parser.RemoveStringFromList(list, r)
181		modified = modified || m
182	}
183
184	if (wasSorted || *sortLists) && modified {
185		parser.SortList(file, list)
186	}
187
188	return modified, nil
189}
190
191func targetedModule(name string) bool {
192	if targetedModules.all {
193		return true
194	}
195	for _, m := range targetedModules.idents {
196		if m == name {
197			return true
198		}
199	}
200
201	return false
202}
203
204func visitFile(path string, f os.FileInfo, err error) error {
205	if err == nil && f.Name() == "Blueprints" {
206		err = processFile(path, nil, os.Stdout)
207	}
208	if err != nil {
209		report(err)
210	}
211	return nil
212}
213
214func walkDir(path string) {
215	filepath.Walk(path, visitFile)
216}
217
218func main() {
219	flag.Parse()
220
221	if flag.NArg() == 0 {
222		if *write {
223			fmt.Fprintln(os.Stderr, "error: cannot use -w with standard input")
224			exitCode = 2
225			return
226		}
227		if err := processFile("<standard input>", os.Stdin, os.Stdout); err != nil {
228			report(err)
229		}
230		return
231	}
232
233	if len(targetedModules.idents) == 0 {
234		report(fmt.Errorf("-m parameter is required"))
235		return
236	}
237
238	if len(addIdents.idents) == 0 && len(removeIdents.idents) == 0 {
239		report(fmt.Errorf("-a or -r parameter is required"))
240		return
241	}
242
243	for i := 0; i < flag.NArg(); i++ {
244		path := flag.Arg(i)
245		switch dir, err := os.Stat(path); {
246		case err != nil:
247			report(err)
248		case dir.IsDir():
249			walkDir(path)
250		default:
251			if err := processFile(path, nil, os.Stdout); err != nil {
252				report(err)
253			}
254		}
255	}
256}
257
// diff writes b1 and b2 to temporary files and returns the output of
// `diff -uw` run on them. An empty result with a nil error means diff
// printed nothing, i.e. it found no difference (whitespace is ignored).
// The temporary files are removed before returning.
func diff(b1, b2 []byte) (data []byte, err error) {
	name1, err := writeTempFile(b1)
	if err != nil {
		return nil, err
	}
	defer os.Remove(name1)

	name2, err := writeTempFile(b2)
	if err != nil {
		return nil, err
	}
	defer os.Remove(name2)

	data, err = exec.Command("diff", "-uw", name1, name2).CombinedOutput()
	if len(data) > 0 {
		// diff exits with a non-zero status when the files don't match.
		// Ignore that failure as long as we get output.
		err = nil
	}
	return data, err
}

// writeTempFile stores b in a new temporary file, closes it so the contents
// are complete for external readers, and returns its name. The caller is
// responsible for removing the file. On failure the file is removed.
func writeTempFile(b []byte) (string, error) {
	f, err := ioutil.TempFile("", "bpfmt")
	if err != nil {
		return "", err
	}
	_, werr := f.Write(b)
	cerr := f.Close()
	if werr == nil {
		werr = cerr
	}
	if werr != nil {
		os.Remove(f.Name())
		return "", werr
	}
	return f.Name(), nil
}
285
// identSet is a flag.Value holding identifiers parsed from a comma- and/or
// whitespace-separated string. When the value is exactly the single
// identifier "*", all is set; targetedModule treats that as "every module".
type identSet struct {
	idents []string
	all    bool
}

// String returns the identifiers joined with commas. The flag package may
// call String on a zero-valued receiver, including a nil pointer, so that
// case yields "".
func (m *identSet) String() string {
	if m == nil {
		return ""
	}
	return strings.Join(m.idents, ",")
}

// Set replaces the current identifiers with those parsed from s. It splits
// on commas and Unicode whitespace and drops empty fields. all is recomputed
// on every call, so a later non-"*" value clears an earlier "*".
func (m *identSet) Set(s string) error {
	m.idents = strings.FieldsFunc(s, func(c rune) bool {
		return unicode.IsSpace(c) || c == ','
	})
	m.all = len(m.idents) == 1 && m.idents[0] == "*"
	return nil
}

// Get returns the parsed identifiers as a []string, implementing flag.Getter.
func (m *identSet) Get() interface{} {
	return m.idents
}
308