1// Copyright 2017 syzkaller project authors. All rights reserved.
2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
3
4package osutil
5
6import (
7	"bytes"
8	"fmt"
9	"io/ioutil"
10	"os"
11	"os/exec"
12	"path/filepath"
13	"time"
14)
15
// Permission bits used by the file-creation helpers in this package.
const (
	// DefaultFilePerm is used for regular files (rw-r--r--).
	DefaultFilePerm = 0644

	// DefaultDirPerm is used for directories (rwxr-xr-x).
	DefaultDirPerm = 0755

	// DefaultExecPerm is used for executable files (rwxr-xr-x).
	DefaultExecPerm = 0755
)
21
// RunCmd executes bin with args inside working directory dir, subject to
// timeout, and returns whatever Run returns for it (the captured output on
// success, or an error).
func RunCmd(timeout time.Duration, dir, bin string, args ...string) ([]byte, error) {
	c := Command(bin, args...)
	c.Dir = dir
	out, err := Run(timeout, c)
	return out, err
}
28
29// Run runs cmd with the specified timeout.
30// Returns combined output. If the command fails, err includes output.
31func Run(timeout time.Duration, cmd *exec.Cmd) ([]byte, error) {
32	output := new(bytes.Buffer)
33	if cmd.Stdout == nil {
34		cmd.Stdout = output
35	}
36	if cmd.Stderr == nil {
37		cmd.Stderr = output
38	}
39	if err := cmd.Start(); err != nil {
40		return nil, fmt.Errorf("failed to start %v %+v: %v", cmd.Path, cmd.Args, err)
41	}
42	done := make(chan bool)
43	timedout := make(chan bool, 1)
44	timer := time.NewTimer(timeout)
45	go func() {
46		select {
47		case <-timer.C:
48			timedout <- true
49			cmd.Process.Kill()
50		case <-done:
51			timedout <- false
52			timer.Stop()
53		}
54	}()
55	err := cmd.Wait()
56	close(done)
57	if err != nil {
58		text := fmt.Sprintf("failed to run %q: %v", cmd.Args, err)
59		if <-timedout {
60			text = fmt.Sprintf("timedout %q", cmd.Args)
61		}
62		return nil, &VerboseError{
63			Title:  text,
64			Output: output.Bytes(),
65		}
66	}
67	return output.Bytes(), nil
68}
69
// Command builds an *exec.Cmd exactly like os/exec.Command does and then
// passes it through setPdeathsig (which, per this package, arranges
// PDEATHSIG on linux) before returning it.
func Command(bin string, args ...string) *exec.Cmd {
	c := exec.Command(bin, args...)
	setPdeathsig(c)
	return c
}
76
77type VerboseError struct {
78	Title  string
79	Output []byte
80}
81
82func (err *VerboseError) Error() string {
83	if len(err.Output) == 0 {
84		return err.Title
85	}
86	return fmt.Sprintf("%v\n%s", err.Title, err.Output)
87}
88
89func PrependContext(ctx string, err error) error {
90	switch err1 := err.(type) {
91	case *VerboseError:
92		err1.Title = fmt.Sprintf("%v: %v", ctx, err1.Title)
93		return err1
94	default:
95		return fmt.Errorf("%v: %v", ctx, err)
96	}
97}
98
// IsExist reports whether os.Stat succeeds on name. Any Stat failure
// (not just "does not exist", e.g. also permission errors) yields false.
func IsExist(name string) bool {
	if _, statErr := os.Stat(name); statErr != nil {
		return false
	}
	return true
}
104
// IsAccessible returns nil if name exists and can be opened for reading.
// Otherwise it returns an error saying whether the file is missing or why
// opening it failed.
func IsAccessible(name string) error {
	if !IsExist(name) {
		return fmt.Errorf("%v does not exist", name)
	}
	fd, openErr := os.Open(name)
	if openErr == nil {
		// Only openability is being probed; the handle itself is not needed.
		fd.Close()
		return nil
	}
	return fmt.Errorf("%v can't be opened (%v)", name, openErr)
}
117
// FilesExist reports whether every file marked as required (mapped to true)
// in files exists under dir. Entries mapped to false are not checked.
// Names are relative paths in slash notation and are converted to the
// host's path separator before being joined with dir.
func FilesExist(dir string, files map[string]bool) bool {
	for name, mustExist := range files {
		if mustExist && !IsExist(filepath.Join(dir, filepath.FromSlash(name))) {
			return false
		}
	}
	return true
}
131
// CopyFiles copies files from srcDir to dstDir as atomically as possible.
// Files are assumed to be relative names in slash notation. A file mapped to
// true is required, so failing to copy it is an error. A file mapped to false
// is copied only if it exists in srcDir.
// All other files in dstDir are removed.
func CopyFiles(srcDir, dstDir string, files map[string]bool) error {
	// Normalize dstDir first. With a trailing separator ("out/"), the naive
	// dstDir+".tmp" would be "out/.tmp". That is *inside* dstDir, so it would
	// be deleted together with dstDir right before the final rename.
	dstDir = filepath.Clean(dstDir)
	// Linux does not support atomic dir replace, so we copy to tmp dir first.
	// Then remove dst dir and rename tmp to dst (as atomic as can get on Linux).
	tmpDir := dstDir + ".tmp"
	if err := os.RemoveAll(tmpDir); err != nil {
		return err
	}
	// abort discards a partially populated tmpDir (best-effort) so that a
	// failed copy does not leave stale data next to dstDir.
	abort := func(err error) error {
		os.RemoveAll(tmpDir)
		return err
	}
	if err := MkdirAll(tmpDir); err != nil {
		return abort(err)
	}
	for f, required := range files {
		src := filepath.Join(srcDir, filepath.FromSlash(f))
		if !required && !IsExist(src) {
			continue
		}
		dst := filepath.Join(tmpDir, filepath.FromSlash(f))
		if err := MkdirAll(filepath.Dir(dst)); err != nil {
			return abort(err)
		}
		if err := CopyFile(src, dst); err != nil {
			return abort(err)
		}
	}
	// From here on tmpDir holds a complete copy. It is deliberately kept if
	// removing dstDir or the rename fails, so the copied files are not lost.
	if err := os.RemoveAll(dstDir); err != nil {
		return err
	}
	return os.Rename(tmpDir, dstDir)
}
163
// LinkFiles recreates dstDir and fills it with hard links pointing at the
// given files in srcDir. The link in dstDir targets the file in srcDir.
// Files are relative names in slash notation. A file mapped to true is
// required, so a failure to link it is returned. A file mapped to false is
// linked only if it exists in srcDir.
// dstDir is removed up front, so everything else previously in it is lost.
func LinkFiles(srcDir, dstDir string, files map[string]bool) error {
	if err := os.RemoveAll(dstDir); err != nil {
		return err
	}
	if err := MkdirAll(dstDir); err != nil {
		return err
	}
	for rel, required := range files {
		from := filepath.Join(srcDir, filepath.FromSlash(rel))
		if required || IsExist(from) {
			to := filepath.Join(dstDir, filepath.FromSlash(rel))
			if err := MkdirAll(filepath.Dir(to)); err != nil {
				return err
			}
			if err := os.Link(from, to); err != nil {
				return err
			}
		}
	}
	return nil
}
189
190func MkdirAll(dir string) error {
191	return os.MkdirAll(dir, DefaultDirPerm)
192}
193
194func WriteFile(filename string, data []byte) error {
195	return ioutil.WriteFile(filename, data, DefaultFilePerm)
196}
197
// WriteExecFile writes data to filename and then chmods it to
// DefaultExecPerm. The explicit Chmod is needed because the mode passed to
// WriteFile is filtered by umask and ignored for files that already exist.
func WriteExecFile(filename string, data []byte) error {
	err := ioutil.WriteFile(filename, data, DefaultExecPerm)
	if err == nil {
		err = os.Chmod(filename, DefaultExecPerm)
	}
	return err
}
204
// TempFile returns the path of a newly created, uniquely named, empty file in
// the default temp directory whose base name starts with prefix.
// The file already exists (and is closed) when the function returns.
// Removing it is the caller's responsibility.
func TempFile(prefix string) (string, error) {
	tmp, err := ioutil.TempFile("", prefix)
	if err != nil {
		return "", fmt.Errorf("failed to create temp file: %v", err)
	}
	path := tmp.Name()
	// Only the reserved path is wanted, not the open handle.
	tmp.Close()
	return path, nil
}
215
// ListDir returns the names of all entries (files, directories, etc.)
// directly inside dir, in the order the OS reports them (unsorted).
func ListDir(dir string) ([]string, error) {
	d, err := os.Open(dir)
	if err != nil {
		return nil, err
	}
	names, err := d.Readdirnames(-1)
	d.Close()
	return names, err
}
225
226var wd string
227
228func init() {
229	var err error
230	wd, err = os.Getwd()
231	if err != nil {
232		panic(fmt.Sprintf("failed to get wd: %v", err))
233	}
234}
235
236func Abs(path string) string {
237	if wd1, err := os.Getwd(); err == nil && wd1 != wd {
238		panic("don't mess with wd in a concurrent program")
239	}
240	if path == "" || filepath.IsAbs(path) {
241		return path
242	}
243	return filepath.Join(wd, path)
244}
245