1// Mostly copied from Go's src/cmd/gofmt:
2// Copyright 2009 The Go Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style
4// license that can be found in the LICENSE file.
5
6package main
7
8import (
9	"bytes"
10	"flag"
11	"fmt"
12	"io"
13	"io/ioutil"
14	"os"
15	"os/exec"
16	"path/filepath"
17
18	"github.com/google/blueprint/parser"
19)
20
// Output-mode and behavior flags. main refuses to run unless at least one of
// -d, -l, -o or -w is given.

// list (-l) prints the name of every input whose formatting would change.
var list = flag.Bool("l", false, "list files whose formatting differs from bpfmt's")

// overwriteSourceFile (-w) rewrites changed inputs in place.
var overwriteSourceFile = flag.Bool("w", false, "write result to (source) file")

// writeToStout (-o) selects printing the formatted result to standard output.
// NOTE(review): the name is a misspelling of "Stdout"; it is kept because main
// refers to it.
var writeToStout = flag.Bool("o", false, "write result to stdout")

// doDiff (-d) prints a unified diff for each changed input.
var doDiff = flag.Bool("d", false, "display diffs instead of rewriting files")

// sortLists (-s) sorts list values before printing.
var sortLists = flag.Bool("s", false, "sort arrays")
29
// exitCode is the status main passes to os.Exit; report raises it to 2
// whenever an error is reported.
var exitCode int
33
34func report(err error) {
35	fmt.Fprintln(os.Stderr, err)
36	exitCode = 2
37}
38
39func usage() {
40	usageViolation("")
41}
42
// usageViolation writes violation (when non-empty), the usage line, and the
// flag defaults to standard error, then exits the process with status 2.
//
// An empty violation is skipped entirely so that plain usage output (e.g. from
// -h via flag.Usage) does not start with a stray blank line.
func usageViolation(violation string) {
	if violation != "" {
		fmt.Fprintln(os.Stderr, violation)
	}
	fmt.Fprintln(os.Stderr, "usage: bpfmt [flags] [path ...]")
	flag.PrintDefaults()
	os.Exit(2)
}
49
50func processFile(filename string, out io.Writer) error {
51	f, err := os.Open(filename)
52	if err != nil {
53		return err
54	}
55	defer f.Close()
56
57	return processReader(filename, f, out)
58}
59
60func processReader(filename string, in io.Reader, out io.Writer) error {
61	src, err := ioutil.ReadAll(in)
62	if err != nil {
63		return err
64	}
65
66	r := bytes.NewBuffer(src)
67
68	file, errs := parser.Parse(filename, r, parser.NewScope(nil))
69	if len(errs) > 0 {
70		for _, err := range errs {
71			fmt.Fprintln(os.Stderr, err)
72		}
73		return fmt.Errorf("%d parsing errors", len(errs))
74	}
75
76	if *sortLists {
77		parser.SortLists(file)
78	}
79
80	res, err := parser.Print(file)
81	if err != nil {
82		return err
83	}
84
85	if !bytes.Equal(src, res) {
86		// formatting has changed
87		if *list {
88			fmt.Fprintln(out, filename)
89		}
90		if *overwriteSourceFile {
91			err = ioutil.WriteFile(filename, res, 0644)
92			if err != nil {
93				return err
94			}
95		}
96		if *doDiff {
97			data, err := diff(src, res)
98			if err != nil {
99				return fmt.Errorf("computing diff: %s", err)
100			}
101			fmt.Printf("diff %s bpfmt/%s\n", filename, filename)
102			out.Write(data)
103		}
104	}
105
106	if !*list && !*overwriteSourceFile && !*doDiff {
107		_, err = out.Write(res)
108	}
109
110	return err
111}
112
113func walkDir(path string) {
114	visitFile := func(path string, f os.FileInfo, err error) error {
115		if err == nil && f.Name() == "Blueprints" {
116			err = processFile(path, os.Stdout)
117		}
118		if err != nil {
119			report(err)
120		}
121		return nil
122	}
123
124	filepath.Walk(path, visitFile)
125}
126
127func main() {
128	flag.Usage = usage
129	flag.Parse()
130
131	if !*writeToStout && !*overwriteSourceFile && !*doDiff && !*list {
132		usageViolation("one of -d, -l, -o, or -w is required")
133	}
134
135	if flag.NArg() == 0 {
136		// file to parse is stdin
137		if *overwriteSourceFile {
138			fmt.Fprintln(os.Stderr, "error: cannot use -w with standard input")
139			os.Exit(2)
140		}
141		if err := processReader("<standard input>", os.Stdin, os.Stdout); err != nil {
142			report(err)
143		}
144		os.Exit(exitCode)
145	}
146
147	for i := 0; i < flag.NArg(); i++ {
148		path := flag.Arg(i)
149		switch dir, err := os.Stat(path); {
150		case err != nil:
151			report(err)
152		case dir.IsDir():
153			walkDir(path)
154		default:
155			if err := processFile(path, os.Stdout); err != nil {
156				report(err)
157			}
158		}
159	}
160
161	os.Exit(exitCode)
162}
163
// diff returns the output of the external "diff -u" command comparing b1 with
// b2. Each input is written to its own temporary file; both files are removed
// before diff returns.
//
// An error is returned if a temporary file cannot be created or fully written,
// or if the command fails without producing any output.
func diff(b1, b2 []byte) (data []byte, err error) {
	name1, err := writeTempFile(b1)
	if err != nil {
		return nil, err
	}
	defer os.Remove(name1)

	name2, err := writeTempFile(b2)
	if err != nil {
		return nil, err
	}
	defer os.Remove(name2)

	data, err = exec.Command("diff", "-u", name1, name2).CombinedOutput()
	if len(data) > 0 {
		// diff exits with a non-zero status when the files don't match.
		// Ignore that failure as long as we get output.
		err = nil
	}
	return data, err
}

// writeTempFile writes b to a new temporary file and returns its name. The
// file is closed before returning so that its contents are complete when an
// external command reads it. The caller must remove it. On failure, the
// partially written file is removed and an error is returned.
func writeTempFile(b []byte) (string, error) {
	f, err := ioutil.TempFile("", "bpfmt")
	if err != nil {
		return "", err
	}
	_, err = f.Write(b)
	if cerr := f.Close(); err == nil {
		err = cerr
	}
	if err != nil {
		os.Remove(f.Name())
		return "", err
	}
	return f.Name(), nil
}
191