1// Copyright 2017 syzkaller project authors. All rights reserved.
2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
3
4package isolated
5
6import (
7	"fmt"
8	"io"
9	"io/ioutil"
10	"os"
11	"path/filepath"
12	"strconv"
13	"strings"
14	"time"
15
16	"github.com/google/syzkaller/pkg/config"
17	"github.com/google/syzkaller/pkg/log"
18	"github.com/google/syzkaller/pkg/osutil"
19	"github.com/google/syzkaller/vm/vmimpl"
20)
21
22func init() {
23	vmimpl.Register("isolated", ctor)
24}
25
// Config holds the "isolated" backend settings; ctor parses it from env.Config.
type Config struct {
	// Targets lists the target machines, each as (hostname|ip)(:port)?.
	// When the port is omitted, 22 is used.
	Targets []string `json:"targets"`

	// TargetDir is the directory on the target that files are copied into
	// and that commands are run from.
	TargetDir string `json:"target_dir"`

	// TargetReboot makes repair reboot the target before it is used.
	TargetReboot bool `json:"target_reboot"`
}
31
32type Pool struct {
33	env *vmimpl.Env
34	cfg *Config
35}
36
37type instance struct {
38	cfg         *Config
39	os          string
40	targetAddr  string
41	targetPort  int
42	closed      chan bool
43	debug       bool
44	sshUser     string
45	sshKey      string
46	forwardPort int
47}
48
49func ctor(env *vmimpl.Env) (vmimpl.Pool, error) {
50	cfg := &Config{}
51	if err := config.LoadData(env.Config, cfg); err != nil {
52		return nil, err
53	}
54	if len(cfg.Targets) == 0 {
55		return nil, fmt.Errorf("config param targets is empty")
56	}
57	if cfg.TargetDir == "" {
58		return nil, fmt.Errorf("config param target_dir is empty")
59	}
60	for _, target := range cfg.Targets {
61		if _, _, err := splitTargetPort(target); err != nil {
62			return nil, fmt.Errorf("bad target %q: %v", target, err)
63		}
64	}
65	if env.Debug {
66		cfg.Targets = cfg.Targets[:1]
67	}
68	pool := &Pool{
69		cfg: cfg,
70		env: env,
71	}
72	return pool, nil
73}
74
75func (pool *Pool) Count() int {
76	return len(pool.cfg.Targets)
77}
78
79func (pool *Pool) Create(workdir string, index int) (vmimpl.Instance, error) {
80	targetAddr, targetPort, _ := splitTargetPort(pool.cfg.Targets[index])
81	inst := &instance{
82		cfg:        pool.cfg,
83		os:         pool.env.OS,
84		targetAddr: targetAddr,
85		targetPort: targetPort,
86		closed:     make(chan bool),
87		debug:      pool.env.Debug,
88		sshUser:    pool.env.SSHUser,
89		sshKey:     pool.env.SSHKey,
90	}
91	closeInst := inst
92	defer func() {
93		if closeInst != nil {
94			closeInst.Close()
95		}
96	}()
97	if err := inst.repair(); err != nil {
98		return nil, err
99	}
100
101	// Create working dir if doesn't exist.
102	inst.ssh("mkdir -p '" + inst.cfg.TargetDir + "'")
103
104	// Remove temp files from previous runs.
105	inst.ssh("rm -rf '" + filepath.Join(inst.cfg.TargetDir, "*") + "'")
106
107	closeInst = nil
108	return inst, nil
109}
110
111func (inst *instance) Forward(port int) (string, error) {
112	if inst.forwardPort != 0 {
113		return "", fmt.Errorf("isolated: Forward port already set")
114	}
115	if port == 0 {
116		return "", fmt.Errorf("isolated: Forward port is zero")
117	}
118	inst.forwardPort = port
119	return fmt.Sprintf("127.0.0.1:%v", port), nil
120}
121
122func (inst *instance) ssh(command string) error {
123	if inst.debug {
124		log.Logf(0, "executing ssh %+v", command)
125	}
126
127	rpipe, wpipe, err := osutil.LongPipe()
128	if err != nil {
129		return err
130	}
131	// TODO(dvyukov): who is closing rpipe?
132
133	args := append(vmimpl.SSHArgs(inst.debug, inst.sshKey, inst.targetPort),
134		inst.sshUser+"@"+inst.targetAddr, command)
135	if inst.debug {
136		log.Logf(0, "running command: ssh %#v", args)
137	}
138	cmd := osutil.Command("ssh", args...)
139	cmd.Stdout = wpipe
140	cmd.Stderr = wpipe
141	if err := cmd.Start(); err != nil {
142		wpipe.Close()
143		return err
144	}
145	wpipe.Close()
146
147	done := make(chan bool)
148	go func() {
149		select {
150		case <-time.After(time.Second * 30):
151			if inst.debug {
152				log.Logf(0, "ssh hanged")
153			}
154			cmd.Process.Kill()
155		case <-done:
156		}
157	}()
158	if err := cmd.Wait(); err != nil {
159		close(done)
160		out, _ := ioutil.ReadAll(rpipe)
161		if inst.debug {
162			log.Logf(0, "ssh failed: %v\n%s", err, out)
163		}
164		return fmt.Errorf("ssh %+v failed: %v\n%s", args, err, out)
165	}
166	close(done)
167	if inst.debug {
168		log.Logf(0, "ssh returned")
169	}
170	return nil
171}
172
173func (inst *instance) repair() error {
174	log.Logf(2, "isolated: trying to ssh")
175	if err := inst.waitForSSH(30 * time.Minute); err == nil {
176		if inst.cfg.TargetReboot {
177			log.Logf(2, "isolated: trying to reboot")
178			inst.ssh("reboot") // reboot will return an error, ignore it
179			if err := inst.waitForReboot(5 * 60); err != nil {
180				log.Logf(2, "isolated: machine did not reboot")
181				return err
182			}
183			log.Logf(2, "isolated: rebooted wait for comeback")
184			if err := inst.waitForSSH(30 * time.Minute); err != nil {
185				log.Logf(2, "isolated: machine did not comeback")
186				return err
187			}
188			log.Logf(2, "isolated: reboot succeeded")
189		} else {
190			log.Logf(2, "isolated: ssh succeeded")
191		}
192	} else {
193		log.Logf(2, "isolated: ssh failed")
194		return fmt.Errorf("SSH failed")
195	}
196
197	return nil
198}
199
200func (inst *instance) waitForSSH(timeout time.Duration) error {
201	return vmimpl.WaitForSSH(inst.debug, timeout, inst.targetAddr, inst.sshKey, inst.sshUser,
202		inst.os, inst.targetPort)
203}
204
205func (inst *instance) waitForReboot(timeout int) error {
206	var err error
207	start := time.Now()
208	for {
209		if !vmimpl.SleepInterruptible(time.Second) {
210			return fmt.Errorf("shutdown in progress")
211		}
212		// If it fails, then the reboot started
213		if err = inst.ssh("pwd"); err != nil {
214			return nil
215		}
216		if time.Since(start).Seconds() > float64(timeout) {
217			break
218		}
219	}
220	return fmt.Errorf("isolated: the machine did not reboot on repair")
221}
222
223func (inst *instance) Close() {
224	close(inst.closed)
225}
226
227func (inst *instance) Copy(hostSrc string) (string, error) {
228	baseName := filepath.Base(hostSrc)
229	vmDst := filepath.Join(inst.cfg.TargetDir, baseName)
230	inst.ssh("pkill -9 '" + baseName + "'; rm -f '" + vmDst + "'")
231	args := append(vmimpl.SCPArgs(inst.debug, inst.sshKey, inst.targetPort),
232		hostSrc, inst.sshUser+"@"+inst.targetAddr+":"+vmDst)
233	cmd := osutil.Command("scp", args...)
234	if inst.debug {
235		log.Logf(0, "running command: scp %#v", args)
236		cmd.Stdout = os.Stdout
237		cmd.Stderr = os.Stdout
238	}
239	if err := cmd.Start(); err != nil {
240		return "", err
241	}
242	done := make(chan bool)
243	go func() {
244		select {
245		case <-time.After(3 * time.Minute):
246			cmd.Process.Kill()
247		case <-done:
248		}
249	}()
250	err := cmd.Wait()
251	close(done)
252	if err != nil {
253		return "", err
254	}
255	return vmDst, nil
256}
257
258func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command string) (
259	<-chan []byte, <-chan error, error) {
260	args := append(vmimpl.SSHArgs(inst.debug, inst.sshKey, inst.targetPort), inst.sshUser+"@"+inst.targetAddr)
261	dmesg, err := vmimpl.OpenRemoteConsole("ssh", args...)
262	if err != nil {
263		return nil, nil, err
264	}
265
266	rpipe, wpipe, err := osutil.LongPipe()
267	if err != nil {
268		dmesg.Close()
269		return nil, nil, err
270	}
271
272	args = vmimpl.SSHArgs(inst.debug, inst.sshKey, inst.targetPort)
273	// Forward target port as part of the ssh connection (reverse proxy)
274	if inst.forwardPort != 0 {
275		proxy := fmt.Sprintf("%v:127.0.0.1:%v", inst.forwardPort, inst.forwardPort)
276		args = append(args, "-R", proxy)
277	}
278	args = append(args, inst.sshUser+"@"+inst.targetAddr, "cd "+inst.cfg.TargetDir+" && exec "+command)
279	log.Logf(0, "running command: ssh %#v", args)
280	if inst.debug {
281		log.Logf(0, "running command: ssh %#v", args)
282	}
283	cmd := osutil.Command("ssh", args...)
284	cmd.Stdout = wpipe
285	cmd.Stderr = wpipe
286	if err := cmd.Start(); err != nil {
287		dmesg.Close()
288		rpipe.Close()
289		wpipe.Close()
290		return nil, nil, err
291	}
292	wpipe.Close()
293
294	var tee io.Writer
295	if inst.debug {
296		tee = os.Stdout
297	}
298	merger := vmimpl.NewOutputMerger(tee)
299	merger.Add("dmesg", dmesg)
300	merger.Add("ssh", rpipe)
301
302	return vmimpl.Multiplex(cmd, merger, dmesg, timeout, stop, inst.closed, inst.debug)
303}
304
305func (inst *instance) Diagnose() bool {
306	return false
307}
308
// splitTargetPort splits a target of the form (hostname|ip)(:port)? into host
// and port. The port defaults to 22 when omitted. When given, it must be a
// decimal number in the range 1-65535; port 0 is rejected because ssh cannot
// connect to it. An empty host is an error.
func splitTargetPort(addr string) (string, int, error) {
	target := addr
	port := 22
	if colonPos := strings.Index(addr, ":"); colonPos != -1 {
		p, err := strconv.ParseUint(addr[colonPos+1:], 10, 16)
		if err != nil {
			return "", 0, err
		}
		if p == 0 {
			return "", 0, fmt.Errorf("port 0 is invalid")
		}
		target = addr[:colonPos]
		port = int(p)
	}
	if target == "" {
		return "", 0, fmt.Errorf("target is empty")
	}
	return target, port, nil
}
325