1// Copyright 2021 Google Inc. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package main
16
17import (
18	"bytes"
19	"io"
20	"testing"
21	"time"
22)
23
24func Test_runWithTimeout(t *testing.T) {
25	type args struct {
26		command      string
27		args         []string
28		timeout      time.Duration
29		onTimeoutCmd string
30		stdin        io.Reader
31	}
32	tests := []struct {
33		name       string
34		args       args
35		wantStdout string
36		wantStderr string
37		wantErr    bool
38	}{
39		{
40			name: "no timeout",
41			args: args{
42				command: "echo",
43				args:    []string{"foo"},
44			},
45			wantStdout: "foo\n",
46		},
47		{
48			name: "timeout not reached",
49			args: args{
50				command: "echo",
51				args:    []string{"foo"},
52				timeout: 1 * time.Second,
53			},
54			wantStdout: "foo\n",
55		},
56		{
57			name: "timed out",
58			args: args{
59				command: "sh",
60				args:    []string{"-c", "sleep 1 && echo foo"},
61				timeout: 1 * time.Millisecond,
62			},
63			wantErr: true,
64		},
65		{
66			name: "on_timeout command",
67			args: args{
68				command:      "sh",
69				args:         []string{"-c", "sleep 1 && echo foo"},
70				timeout:      1 * time.Millisecond,
71				onTimeoutCmd: "echo bar",
72			},
73			wantStdout: "bar\n",
74			wantErr:    true,
75		},
76	}
77	for _, tt := range tests {
78		t.Run(tt.name, func(t *testing.T) {
79			stdout := &bytes.Buffer{}
80			stderr := &bytes.Buffer{}
81			err := runWithTimeout(tt.args.command, tt.args.args, tt.args.timeout, tt.args.onTimeoutCmd, tt.args.stdin, stdout, stderr)
82			if (err != nil) != tt.wantErr {
83				t.Errorf("runWithTimeout() error = %v, wantErr %v", err, tt.wantErr)
84				return
85			}
86			if gotStdout := stdout.String(); gotStdout != tt.wantStdout {
87				t.Errorf("runWithTimeout() gotStdout = %v, want %v", gotStdout, tt.wantStdout)
88			}
89			if gotStderr := stderr.String(); gotStderr != tt.wantStderr {
90				t.Errorf("runWithTimeout() gotStderr = %v, want %v", gotStderr, tt.wantStderr)
91			}
92		})
93	}
94}
95