1// Mostly copied from Go's src/cmd/gofmt:
2// Copyright 2009 The Go Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style
4// license that can be found in the LICENSE file.
5
6package main
7
8import (
9	"bytes"
10	"flag"
11	"fmt"
12	"io"
13	"io/ioutil"
14	"os"
15	"os/exec"
16	"path/filepath"
17	"strings"
18	"unicode"
19
20	"github.com/google/blueprint/parser"
21)
22
23var (
24	// main operation modes
25	list             = flag.Bool("l", false, "list files that would be modified by bpmodify")
26	write            = flag.Bool("w", false, "write result to (source) file instead of stdout")
27	doDiff           = flag.Bool("d", false, "display diffs instead of rewriting files")
28	sortLists        = flag.Bool("s", false, "sort touched lists, even if they were unsorted")
29	targetedModules  = new(identSet)
30	targetedProperty = new(qualifiedProperty)
31	addIdents        = new(identSet)
32	removeIdents     = new(identSet)
33
34	setString *string
35)
36
37func init() {
38	flag.Var(targetedModules, "m", "comma or whitespace separated list of modules on which to operate")
39	flag.Var(targetedProperty, "parameter", "alias to -property=`name`")
40	flag.Var(targetedProperty, "property", "fully qualified `name` of property to modify (default \"deps\")")
41	flag.Var(addIdents, "a", "comma or whitespace separated list of identifiers to add")
42	flag.Var(removeIdents, "r", "comma or whitespace separated list of identifiers to remove")
43	flag.Var(stringPtrFlag{&setString}, "str", "set a string property")
44	flag.Usage = usage
45}
46
// exitCode is the status main exits with. It starts at zero and is raised by
// report.
var exitCode int

// report prints err to standard error and marks the run as failed by setting
// the exit code to 2.
func report(err error) {
	exitCode = 2
	fmt.Fprintln(os.Stderr, err)
}
55
// usage prints a one-line synopsis followed by the defaults of every
// registered flag to the flag package's configured output.
func usage() {
	w := flag.CommandLine.Output()
	fmt.Fprintf(w, "Usage: %s [flags] [path ...]\n", os.Args[0])
	flag.PrintDefaults()
}
60
61// If in == nil, the source is the contents of the file with the given filename.
62func processFile(filename string, in io.Reader, out io.Writer) error {
63	if in == nil {
64		f, err := os.Open(filename)
65		if err != nil {
66			return err
67		}
68		defer f.Close()
69		in = f
70	}
71
72	src, err := ioutil.ReadAll(in)
73	if err != nil {
74		return err
75	}
76
77	r := bytes.NewBuffer(src)
78
79	file, errs := parser.Parse(filename, r, parser.NewScope(nil))
80	if len(errs) > 0 {
81		for _, err := range errs {
82			fmt.Fprintln(os.Stderr, err)
83		}
84		return fmt.Errorf("%d parsing errors", len(errs))
85	}
86
87	modified, errs := findModules(file)
88	if len(errs) > 0 {
89		for _, err := range errs {
90			fmt.Fprintln(os.Stderr, err)
91		}
92		fmt.Fprintln(os.Stderr, "continuing...")
93	}
94
95	if modified {
96		res, err := parser.Print(file)
97		if err != nil {
98			return err
99		}
100
101		if *list {
102			fmt.Fprintln(out, filename)
103		}
104		if *write {
105			err = ioutil.WriteFile(filename, res, 0644)
106			if err != nil {
107				return err
108			}
109		}
110		if *doDiff {
111			data, err := diff(src, res)
112			if err != nil {
113				return fmt.Errorf("computing diff: %s", err)
114			}
115			fmt.Printf("diff %s bpfmt/%s\n", filename, filename)
116			out.Write(data)
117		}
118
119		if !*list && !*write && !*doDiff {
120			_, err = out.Write(res)
121		}
122	}
123
124	return err
125}
126
127func findModules(file *parser.File) (modified bool, errs []error) {
128
129	for _, def := range file.Defs {
130		if module, ok := def.(*parser.Module); ok {
131			for _, prop := range module.Properties {
132				if prop.Name == "name" && prop.Value.Type() == parser.StringType {
133					if targetedModule(prop.Value.Eval().(*parser.String).Value) {
134						m, newErrs := processModule(module, prop.Name, file)
135						errs = append(errs, newErrs...)
136						modified = modified || m
137					}
138				}
139			}
140		}
141	}
142
143	return modified, errs
144}
145
146func processModule(module *parser.Module, moduleName string,
147	file *parser.File) (modified bool, errs []error) {
148	prop, err := getRecursiveProperty(module, targetedProperty.name(), targetedProperty.prefixes())
149	if err != nil {
150		return false, []error{err}
151	}
152	if prop == nil {
153		if len(addIdents.idents) > 0 {
154			// We are adding something to a non-existing list prop, so we need to create it first.
155			prop, modified, err = createRecursiveProperty(module, targetedProperty.name(), targetedProperty.prefixes(), &parser.List{})
156		} else if setString != nil {
157			// We setting a non-existent string property, so we need to create it first.
158			prop, modified, err = createRecursiveProperty(module, targetedProperty.name(), targetedProperty.prefixes(), &parser.String{})
159		} else {
160			// We cannot find an existing prop, and we aren't adding anything to the prop,
161			// which means we must be removing something from a non-existing prop,
162			// which means this is a noop.
163			return false, nil
164		}
165		if err != nil {
166			// Here should be unreachable, but still handle it for completeness.
167			return false, []error{err}
168		}
169	}
170	m, errs := processParameter(prop.Value, targetedProperty.String(), moduleName, file)
171	modified = modified || m
172	return modified, errs
173}
174
175func getRecursiveProperty(module *parser.Module, name string, prefixes []string) (prop *parser.Property, err error) {
176	prop, _, err = getOrCreateRecursiveProperty(module, name, prefixes, nil)
177	return prop, err
178}
179
180func createRecursiveProperty(module *parser.Module, name string, prefixes []string,
181	empty parser.Expression) (prop *parser.Property, modified bool, err error) {
182
183	return getOrCreateRecursiveProperty(module, name, prefixes, empty)
184}
185
186func getOrCreateRecursiveProperty(module *parser.Module, name string, prefixes []string,
187	empty parser.Expression) (prop *parser.Property, modified bool, err error) {
188	m := &module.Map
189	for i, prefix := range prefixes {
190		if prop, found := m.GetProperty(prefix); found {
191			if mm, ok := prop.Value.Eval().(*parser.Map); ok {
192				m = mm
193			} else {
194				// We've found a property in the AST and such property is not of type
195				// *parser.Map, which must mean we didn't modify the AST.
196				return nil, false, fmt.Errorf("Expected property %q to be a map, found %s",
197					strings.Join(prefixes[:i+1], "."), prop.Value.Type())
198			}
199		} else if empty != nil {
200			mm := &parser.Map{}
201			m.Properties = append(m.Properties, &parser.Property{Name: prefix, Value: mm})
202			m = mm
203			// We've created a new node in the AST. This means the m.GetProperty(name)
204			// check after this for loop must fail, because the node we inserted is an
205			// empty parser.Map, thus this function will return |modified| is true.
206		} else {
207			return nil, false, nil
208		}
209	}
210	if prop, found := m.GetProperty(name); found {
211		// We've found a property in the AST, which must mean we didn't modify the AST.
212		return prop, false, nil
213	} else if empty != nil {
214		prop = &parser.Property{Name: name, Value: empty}
215		m.Properties = append(m.Properties, prop)
216		return prop, true, nil
217	} else {
218		return nil, false, nil
219	}
220}
221
222func processParameter(value parser.Expression, paramName, moduleName string,
223	file *parser.File) (modified bool, errs []error) {
224	if _, ok := value.(*parser.Variable); ok {
225		return false, []error{fmt.Errorf("parameter %s in module %s is a variable, unsupported",
226			paramName, moduleName)}
227	}
228
229	if _, ok := value.(*parser.Operator); ok {
230		return false, []error{fmt.Errorf("parameter %s in module %s is an expression, unsupported",
231			paramName, moduleName)}
232	}
233
234	if len(addIdents.idents) > 0 || len(removeIdents.idents) > 0 {
235		list, ok := value.(*parser.List)
236		if !ok {
237			return false, []error{fmt.Errorf("expected parameter %s in module %s to be list, found %s",
238				paramName, moduleName, value.Type().String())}
239		}
240
241		wasSorted := parser.ListIsSorted(list)
242
243		for _, a := range addIdents.idents {
244			m := parser.AddStringToList(list, a)
245			modified = modified || m
246		}
247
248		for _, r := range removeIdents.idents {
249			m := parser.RemoveStringFromList(list, r)
250			modified = modified || m
251		}
252
253		if (wasSorted || *sortLists) && modified {
254			parser.SortList(file, list)
255		}
256	} else if setString != nil {
257		str, ok := value.(*parser.String)
258		if !ok {
259			return false, []error{fmt.Errorf("expected parameter %s in module %s to be string, found %s",
260				paramName, moduleName, value.Type().String())}
261		}
262
263		str.Value = *setString
264		modified = true
265	}
266
267	return modified, nil
268}
269
270func targetedModule(name string) bool {
271	if targetedModules.all {
272		return true
273	}
274	for _, m := range targetedModules.idents {
275		if m == name {
276			return true
277		}
278	}
279
280	return false
281}
282
283func visitFile(path string, f os.FileInfo, err error) error {
284	if err == nil && f.Name() == "Blueprints" {
285		err = processFile(path, nil, os.Stdout)
286	}
287	if err != nil {
288		report(err)
289	}
290	return nil
291}
292
293func walkDir(path string) {
294	filepath.Walk(path, visitFile)
295}
296
297func main() {
298	defer func() {
299		if err := recover(); err != nil {
300			report(fmt.Errorf("error: %s", err))
301		}
302		os.Exit(exitCode)
303	}()
304
305	flag.Parse()
306
307	if len(targetedProperty.parts) == 0 {
308		targetedProperty.Set("deps")
309	}
310
311	if flag.NArg() == 0 {
312		if *write {
313			report(fmt.Errorf("error: cannot use -w with standard input"))
314			return
315		}
316		if err := processFile("<standard input>", os.Stdin, os.Stdout); err != nil {
317			report(err)
318		}
319		return
320	}
321
322	if len(targetedModules.idents) == 0 {
323		report(fmt.Errorf("-m parameter is required"))
324		return
325	}
326
327	if len(addIdents.idents) == 0 && len(removeIdents.idents) == 0 && setString == nil {
328		report(fmt.Errorf("-a, -r or -str parameter is required"))
329		return
330	}
331
332	for i := 0; i < flag.NArg(); i++ {
333		path := flag.Arg(i)
334		switch dir, err := os.Stat(path); {
335		case err != nil:
336			report(err)
337		case dir.IsDir():
338			walkDir(path)
339		default:
340			if err := processFile(path, nil, os.Stdout); err != nil {
341				report(err)
342			}
343		}
344	}
345}
346
// diff returns the output of "diff -uw" run on two temporary files holding b1
// and b2.
//
// diff exits with a non-zero status when the inputs differ, so an error from
// the command is ignored whenever it produced output. An error is returned
// when a temporary file cannot be created or fully written, or when the
// command fails without printing anything (for example because the diff
// binary is missing).
func diff(b1, b2 []byte) (data []byte, err error) {
	name1, err := writeDiffInput(b1)
	if err != nil {
		return nil, err
	}
	defer os.Remove(name1)

	name2, err := writeDiffInput(b2)
	if err != nil {
		return nil, err
	}
	defer os.Remove(name2)

	data, err = exec.Command("diff", "-uw", name1, name2).CombinedOutput()
	if len(data) > 0 {
		// diff exits with a non-zero status when the files don't match.
		// Ignore that failure as long as we get output.
		err = nil
	}
	return data, err
}

// writeDiffInput stores content in a new temporary file, closes it, and
// returns its name. On failure the file is removed and an error is returned,
// so that diff never runs on a missing or truncated input.
func writeDiffInput(content []byte) (name string, err error) {
	f, err := ioutil.TempFile("", "bpfmt")
	if err != nil {
		return "", err
	}
	name = f.Name()

	_, err = f.Write(content)
	if closeErr := f.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		os.Remove(name)
		return "", err
	}
	return name, nil
}
374
// stringPtrFlag is a flag.Value that stores its argument through a pointer to
// a *string, so that the pointed-to *string stays nil until the flag is set.
type stringPtrFlag struct {
	s **string
}

// Set points the target *string at a copy of s.
func (f stringPtrFlag) Set(s string) error {
	value := s
	*f.s = &value
	return nil
}

// String returns the current value, or the empty string when the flag has no
// target or has not been set.
func (f stringPtrFlag) String() string {
	if f.s != nil {
		if current := *f.s; current != nil {
			return *current
		}
	}
	return ""
}
390
// identSet is a flag.Getter holding a list of identifiers given as a comma or
// whitespace separated string. all is true when the most recent value was the
// single wildcard "*".
type identSet struct {
	idents []string
	all    bool
}

// String returns the identifiers joined by commas.
func (m *identSet) String() string {
	return strings.Join(m.idents, ",")
}

// Set replaces the identifiers with the fields of s, split on commas and
// whitespace. The wildcard state is recomputed on every call, so a later
// non-wildcard value overrides an earlier "*".
func (m *identSet) Set(s string) error {
	m.idents = strings.FieldsFunc(s, func(c rune) bool {
		return unicode.IsSpace(c) || c == ','
	})
	m.all = len(m.idents) == 1 && m.idents[0] == "*"
	return nil
}

// Get returns the identifiers as a []string.
func (m *identSet) Get() interface{} {
	return m.idents
}
413
// qualifiedProperty is a flag.Getter holding a dot-separated property path
// such as "arch.arm.deps", stored as its individual components.
type qualifiedProperty struct {
	parts []string
}

var _ flag.Getter = (*qualifiedProperty)(nil)

// name returns the last component of the path, or the empty string when no
// path has been set.
func (p *qualifiedProperty) name() string {
	if len(p.parts) == 0 {
		return ""
	}
	return p.parts[len(p.parts)-1]
}

// prefixes returns every component but the last, or nil when no path has been
// set. The result shares storage with the receiver and must not be modified.
func (p *qualifiedProperty) prefixes() []string {
	if len(p.parts) == 0 {
		return nil
	}
	return p.parts[:len(p.parts)-1]
}

// String returns the path with its components joined by dots.
func (p *qualifiedProperty) String() string {
	return strings.Join(p.parts, ".")
}

// Set parses s as a dot-separated path. Every component must be non-empty;
// otherwise an error is returned and the previous value is left untouched.
func (p *qualifiedProperty) Set(s string) error {
	// strings.Split never returns an empty slice for a non-empty separator,
	// so checking each component also rejects the empty string.
	parts := strings.Split(s, ".")
	for _, part := range parts {
		if part == "" {
			return fmt.Errorf("%q is not a valid property name", s)
		}
	}
	p.parts = parts
	return nil
}

// Get returns the path components as a []string.
func (p *qualifiedProperty) Get() interface{} {
	return p.parts
}
448