1// Copyright 2021 Google Inc. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// run_with_timeout is a utility that can kill a wrapped command after a configurable timeout,
16// optionally running a command to collect debugging information first.
17
18package main
19
20import (
21	"flag"
22	"fmt"
23	"io"
24	"os"
25	"os/exec"
26	"sync"
27	"syscall"
28	"time"
29)
30
31var (
32	timeout      = flag.Duration("timeout", 0, "time after which to kill command (example: 60s)")
33	onTimeoutCmd = flag.String("on_timeout", "", "command to run with `PID=<pid> sh -c` after timeout.")
34)
35
36func usage() {
37	fmt.Fprintf(os.Stderr, "usage: %s [--timeout N] [--on_timeout CMD] -- command [args...]\n", os.Args[0])
38	flag.PrintDefaults()
39	fmt.Fprintln(os.Stderr, "run_with_timeout is a utility that can kill a wrapped command after a configurable timeout,")
40	fmt.Fprintln(os.Stderr, "optionally running a command to collect debugging information first.")
41
42	os.Exit(2)
43}
44
45func main() {
46	flag.Usage = usage
47	flag.Parse()
48
49	if flag.NArg() < 1 {
50		fmt.Fprintln(os.Stderr, "command is required")
51		usage()
52	}
53
54	err := runWithTimeout(flag.Arg(0), flag.Args()[1:], *timeout, *onTimeoutCmd,
55		os.Stdin, os.Stdout, os.Stderr)
56	if err != nil {
57		if exitErr, ok := err.(*exec.ExitError); ok {
58			fmt.Fprintln(os.Stderr, "process exited with error:", exitErr.Error())
59		} else {
60			fmt.Fprintln(os.Stderr, "error:", err.Error())
61		}
62		os.Exit(1)
63	}
64}
65
66// concurrentWriter wraps a writer to make it thread-safe to call Write.
67type concurrentWriter struct {
68	w io.Writer
69	sync.Mutex
70}
71
72// Write writes the data to the wrapped writer with a lock to allow for concurrent calls.
73func (c *concurrentWriter) Write(data []byte) (n int, err error) {
74	c.Lock()
75	defer c.Unlock()
76	if c.w == nil {
77		return 0, nil
78	}
79	return c.w.Write(data)
80}
81
82// Close ends the concurrentWriter, causing future calls to Write to be no-ops.  It does not close
83// the underlying writer.
84func (c *concurrentWriter) Close() {
85	c.Lock()
86	defer c.Unlock()
87	c.w = nil
88}
89
90func runWithTimeout(command string, args []string, timeout time.Duration, onTimeoutCmdStr string,
91	stdin io.Reader, stdout, stderr io.Writer) error {
92	cmd := exec.Command(command, args...)
93
94	// Wrap the writers in a locking writer so that cmd and onTimeoutCmd don't try to write to
95	// stdout or stderr concurrently.
96	concurrentStdout := &concurrentWriter{w: stdout}
97	concurrentStderr := &concurrentWriter{w: stderr}
98	defer concurrentStdout.Close()
99	defer concurrentStderr.Close()
100
101	cmd.Stdin, cmd.Stdout, cmd.Stderr = stdin, concurrentStdout, concurrentStderr
102	err := cmd.Start()
103	if err != nil {
104		return err
105	}
106
107	// waitCh will signal the subprocess exited.
108	waitCh := make(chan error)
109	go func() {
110		waitCh <- cmd.Wait()
111	}()
112
113	// timeoutCh will signal the subprocess timed out if timeout was set.
114	var timeoutCh <-chan time.Time = make(chan time.Time)
115	if timeout > 0 {
116		timeoutCh = time.After(timeout)
117	}
118
119	select {
120	case err := <-waitCh:
121		if exitErr, ok := err.(*exec.ExitError); ok {
122			return fmt.Errorf("process exited with error: %w", exitErr)
123		}
124		return err
125	case <-timeoutCh:
126		// Continue below.
127	}
128
129	// Process timed out before exiting.
130	defer cmd.Process.Signal(syscall.SIGKILL)
131
132	if onTimeoutCmdStr != "" {
133		onTimeoutCmd := exec.Command("sh", "-c", onTimeoutCmdStr)
134		onTimeoutCmd.Stdin, onTimeoutCmd.Stdout, onTimeoutCmd.Stderr = stdin, concurrentStdout, concurrentStderr
135		onTimeoutCmd.Env = append(os.Environ(), fmt.Sprintf("PID=%d", cmd.Process.Pid))
136		err := onTimeoutCmd.Run()
137		if err != nil {
138			return fmt.Errorf("on_timeout command %q exited with error: %w", onTimeoutCmdStr, err)
139		}
140	}
141
142	return fmt.Errorf("timed out after %s", timeout.String())
143}
144