1// Copyright 2021 Google Inc. All rights reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package main 16 17import ( 18 "bytes" 19 "io" 20 "testing" 21 "time" 22) 23 24func Test_runWithTimeout(t *testing.T) { 25 type args struct { 26 command string 27 args []string 28 timeout time.Duration 29 onTimeoutCmd string 30 stdin io.Reader 31 } 32 tests := []struct { 33 name string 34 args args 35 wantStdout string 36 wantStderr string 37 wantErr bool 38 }{ 39 { 40 name: "no timeout", 41 args: args{ 42 command: "echo", 43 args: []string{"foo"}, 44 }, 45 wantStdout: "foo\n", 46 }, 47 { 48 name: "timeout not reached", 49 args: args{ 50 command: "echo", 51 args: []string{"foo"}, 52 timeout: 1 * time.Second, 53 }, 54 wantStdout: "foo\n", 55 }, 56 { 57 name: "timed out", 58 args: args{ 59 command: "sh", 60 args: []string{"-c", "sleep 1 && echo foo"}, 61 timeout: 1 * time.Millisecond, 62 }, 63 wantErr: true, 64 }, 65 { 66 name: "on_timeout command", 67 args: args{ 68 command: "sh", 69 args: []string{"-c", "sleep 1 && echo foo"}, 70 timeout: 1 * time.Millisecond, 71 onTimeoutCmd: "echo bar", 72 }, 73 wantStdout: "bar\n", 74 wantErr: true, 75 }, 76 } 77 for _, tt := range tests { 78 t.Run(tt.name, func(t *testing.T) { 79 stdout := &bytes.Buffer{} 80 stderr := &bytes.Buffer{} 81 err := runWithTimeout(tt.args.command, tt.args.args, tt.args.timeout, tt.args.onTimeoutCmd, tt.args.stdin, stdout, stderr) 82 if (err != nil) != tt.wantErr { 83 t.Errorf("runWithTimeout() error = %v, wantErr %v", err, tt.wantErr) 84 return 85 } 86 if gotStdout := stdout.String(); gotStdout != tt.wantStdout { 87 t.Errorf("runWithTimeout() gotStdout = %v, want %v", gotStdout, tt.wantStdout) 88 } 89 if gotStderr := stderr.String(); gotStderr != tt.wantStderr { 90 t.Errorf("runWithTimeout() gotStderr = %v, want %v", gotStderr, tt.wantStderr) 91 } 92 }) 93 } 94} 95