1#!/usr/bin/env python
2#
3# Copyright 2019 - The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Tests for acloud.internal.lib.ssh."""
18
19import subprocess
20import unittest
21import threading
22import time
23
24from unittest import mock
25
26from acloud import errors
27from acloud.internal import constants
28from acloud.internal.lib import driver_test_lib
29from acloud.internal.lib import ssh
30
31
class SshTest(driver_test_lib.BaseDriverTest):
    """Unit tests for the ssh module (Ssh class and module-level helpers)."""

    FAKE_SSH_PRIVATE_KEY_PATH = "/fake/acloud_rea"
    FAKE_SSH_USER = "fake_user"
    FAKE_IP = ssh.IP(external="1.1.1.1", internal="10.1.1.1")
    FAKE_EXTRA_ARGS_SSH = "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22'"
    FAKE_REPORT_INTERNAL_IP = True

    def setUp(self):
        """Build a fake subprocess that exits with code 0 and no output."""
        super().setUp()
        self.created_subprocess = mock.MagicMock()
        self.created_subprocess.stdout = mock.MagicMock()
        # Empty readline() result; presumably treated as EOF by the helpers
        # that stream process output — confirm against ssh.py.
        self.created_subprocess.stdout.readline = mock.MagicMock(return_value=b"")
        self.created_subprocess.poll = mock.MagicMock(return_value=0)
        self.created_subprocess.returncode = 0
        self.created_subprocess.communicate = mock.MagicMock(return_value=('', ''))

    # pylint: disable=no-member
    @staticmethod
    def _AssertPopenCalledWith(expected_cmd):
        """Assert the most recent subprocess.Popen call ran expected_cmd.

        Every ssh/scp helper under test is expected to launch its command
        through the shell with stderr merged into a piped stdout and text
        mode enabled.

        Args:
            expected_cmd: String, the full shell command Popen should receive.
        """
        subprocess.Popen.assert_called_with(expected_cmd,
                                            shell=True,
                                            stderr=subprocess.STDOUT,
                                            stdin=None,
                                            stdout=subprocess.PIPE,
                                            universal_newlines=True)

    def testSSHExecuteWithRetry(self):
        """Test ShellCmdWithRetry propagates CalledProcessError from Popen."""
        # Presumably avoids real back-off delays between retries.
        self.Patch(time, "sleep")
        self.Patch(subprocess, "Popen",
                   side_effect=subprocess.CalledProcessError(
                       None, "ssh command fail."))
        self.assertRaises(subprocess.CalledProcessError,
                          ssh.ShellCmdWithRetry,
                          "fake cmd")

    def testGetBaseCmdWithInternalIP(self):
        """Test GetBaseCmd targets the internal ip when report_internal_ip is set."""
        ssh_object = ssh.Ssh(ip=self.FAKE_IP,
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH,
                             report_internal_ip=self.FAKE_REPORT_INTERNAL_IP)
        expected_ssh_cmd = ("/usr/bin/ssh -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                            "-o StrictHostKeyChecking=no -l fake_user 10.1.1.1")
        self.assertEqual(ssh_object.GetBaseCmd(constants.SSH_BIN), expected_ssh_cmd)

    def testGetBaseCmd(self):
        """Test GetBaseCmd for both the ssh and scp binaries."""
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        expected_ssh_cmd = ("/usr/bin/ssh -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                            "-o StrictHostKeyChecking=no -l fake_user 1.1.1.1")
        self.assertEqual(ssh_object.GetBaseCmd(constants.SSH_BIN), expected_ssh_cmd)

        # The scp base command carries no "-l user host" suffix.
        expected_scp_cmd = ("/usr/bin/scp -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                            "-o StrictHostKeyChecking=no")
        self.assertEqual(ssh_object.GetBaseCmd(constants.SCP_BIN), expected_scp_cmd)

    def testSshRunCmd(self):
        """Test Run builds the expected ssh command."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        ssh_object.Run("command")
        expected_cmd = ("exec /usr/bin/ssh -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no -l fake_user 1.1.1.1 command")
        self._AssertPopenCalledWith(expected_cmd)

    def testSshRunCmdwithExtraArgs(self):
        """Test Run inserts extra ssh args before the user/host."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP,
                             self.FAKE_SSH_USER,
                             self.FAKE_SSH_PRIVATE_KEY_PATH,
                             self.FAKE_EXTRA_ARGS_SSH)
        ssh_object.Run("command")
        expected_cmd = ("exec /usr/bin/ssh -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no "
                        "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22' "
                        "-l fake_user 1.1.1.1 command")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPullFileCmd(self):
        """Test ScpPullFile builds the expected scp command."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        ssh_object.ScpPullFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = ("exec /usr/bin/scp -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no fake_user@1.1.1.1:/tmp/test /tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPullFileCmdwithExtraArgs(self):
        """Test ScpPullFile inserts extra ssh args before the source/target."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP,
                             self.FAKE_SSH_USER,
                             self.FAKE_SSH_PRIVATE_KEY_PATH,
                             self.FAKE_EXTRA_ARGS_SSH)
        ssh_object.ScpPullFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = ("exec /usr/bin/scp -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no "
                        "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22' "
                        "fake_user@1.1.1.1:/tmp/test /tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPushFileCmd(self):
        """Test ScpPushFile builds the expected scp command."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        ssh_object.ScpPushFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = ("exec /usr/bin/scp -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no /tmp/test fake_user@1.1.1.1:/tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPushFileCmdwithExtraArgs(self):
        """Test ScpPushFile inserts extra ssh args before the source/target."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP,
                             self.FAKE_SSH_USER,
                             self.FAKE_SSH_PRIVATE_KEY_PATH,
                             self.FAKE_EXTRA_ARGS_SSH)
        ssh_object.ScpPushFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = ("exec /usr/bin/scp -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no "
                        "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22' "
                        "/tmp/test fake_user@1.1.1.1:/tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    # pylint: disable=protected-access
    def testIPAddress(self):
        """Test which ip the Ssh object selects from an IP value."""
        # Internal ip case.
        ssh_object = ssh.Ssh(ip=ssh.IP(external="1.1.1.1", internal="10.1.1.1"),
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH,
                             report_internal_ip=True)
        self.assertEqual(ssh_object._ip, "10.1.1.1")

        # External ip case.
        ssh_object = ssh.Ssh(ip=ssh.IP(external="1.1.1.1", internal="10.1.1.1"),
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH)
        self.assertEqual(ssh_object._ip, "1.1.1.1")

        # Only one ip case.
        ssh_object = ssh.Ssh(ip=ssh.IP(ip="1.1.1.1"),
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH)
        self.assertEqual(ssh_object._ip, "1.1.1.1")

    def testWaitForSsh(self):
        """Test WaitForSsh raises DeviceConnectionError when _SshCall keeps failing."""
        ssh_object = ssh.Ssh(ip=self.FAKE_IP,
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH,
                             report_internal_ip=self.FAKE_REPORT_INTERNAL_IP)
        # A non-zero return code from every ssh probe.
        self.Patch(ssh, "_SshCall", return_value=-1)
        self.assertRaises(errors.DeviceConnectionError,
                          ssh_object.WaitForSsh,
                          timeout=1,
                          max_retry=1)

    def testSshCallWait(self):
        """Test _SshCallWait starts no Timer when no timeout is given."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        ssh._SshCallWait("fake command")
        threading.Timer.assert_not_called()

    def testSshCallWaitTimeout(self):
        """Test _SshCallWait starts exactly one Timer when a timeout is given."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        ssh._SshCallWait("fake command", 30)
        threading.Timer.assert_called_once()

    def testSshCall(self):
        """Test _SshCall starts no Timer when no timeout is given."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        ssh._SshCall("fake command")
        threading.Timer.assert_not_called()

    def testSshCallTimeout(self):
        """Test _SshCall starts exactly one Timer when a timeout is given."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        ssh._SshCall("fake command", 30)
        threading.Timer.assert_called_once()

    def testSshLogOutput(self):
        """Test _SshLogOutput starts no Timer when no timeout is given."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        ssh._SshLogOutput("fake command")
        threading.Timer.assert_not_called()

    def testSshLogOutputTimeout(self):
        """Test _SshLogOutput starts exactly one Timer when a timeout is given."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        ssh._SshLogOutput("fake command", 30)
        threading.Timer.assert_called_once()
270
if __name__ == "__main__":
    # Run every test case in this file when it is executed directly as a script.
    unittest.main(module="__main__")
273