#!/usr/bin/env python
#
# Copyright 2019 - The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for acloud.internal.lib.ssh."""

import subprocess
import unittest
import threading
import time

from unittest import mock

from acloud import errors
from acloud.internal import constants
from acloud.internal.lib import driver_test_lib
from acloud.internal.lib import ssh


class SshTest(driver_test_lib.BaseDriverTest):
    """Test ssh class."""

    FAKE_SSH_PRIVATE_KEY_PATH = "/fake/acloud_rea"
    FAKE_SSH_USER = "fake_user"
    FAKE_IP = ssh.IP(external="1.1.1.1", internal="10.1.1.1")
    FAKE_EXTRA_ARGS_SSH = "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22'"
    FAKE_REPORT_INTERNAL_IP = True

    def setUp(self):
        """Set up the test."""
        super().setUp()
        # Fake process object returned by the patched subprocess.Popen in the
        # tests below: exit status 0 and empty output on every read.
        self.created_subprocess = mock.MagicMock()
        self.created_subprocess.stdout = mock.MagicMock()
        self.created_subprocess.stdout.readline = mock.MagicMock(return_value=b"")
        self.created_subprocess.poll = mock.MagicMock(return_value=0)
        self.created_subprocess.returncode = 0
        self.created_subprocess.communicate = mock.MagicMock(
            return_value=('', ''))

    def testSSHExecuteWithRetry(self):
        """Test ShellCmdWithRetry raises CalledProcessError when Popen fails."""
        # Patch sleep so any wait between attempts does not slow the test.
        self.Patch(time, "sleep")
        self.Patch(subprocess, "Popen",
                   side_effect=subprocess.CalledProcessError(
                       None, "ssh command fail."))
        self.assertRaises(subprocess.CalledProcessError,
                          ssh.ShellCmdWithRetry,
                          "fake cmd")

    def testGetBaseCmdWithInternalIP(self):
        """Test get base command with internal ip."""
        ssh_object = ssh.Ssh(ip=self.FAKE_IP,
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH,
                             report_internal_ip=self.FAKE_REPORT_INTERNAL_IP)
        expected_ssh_cmd = ("/usr/bin/ssh -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                            "-o StrictHostKeyChecking=no -l fake_user 10.1.1.1")
        self.assertEqual(ssh_object.GetBaseCmd(constants.SSH_BIN), expected_ssh_cmd)

    def testGetBaseCmd(self):
        """Test get base command."""
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        expected_ssh_cmd = ("/usr/bin/ssh -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                            "-o StrictHostKeyChecking=no -l fake_user 1.1.1.1")
        self.assertEqual(ssh_object.GetBaseCmd(constants.SSH_BIN), expected_ssh_cmd)

        expected_scp_cmd = ("/usr/bin/scp -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                            "-o StrictHostKeyChecking=no")
        self.assertEqual(ssh_object.GetBaseCmd(constants.SCP_BIN), expected_scp_cmd)

    # pylint: disable=no-member
    def _AssertPopenCalledWith(self, expected_cmd):
        """Assert the patched subprocess.Popen was last called for expected_cmd.

        Every ssh/scp test in this class expects the same Popen keyword
        arguments, so they are checked in one place.

        Args:
            expected_cmd: String, the full shell command Popen should receive.
        """
        subprocess.Popen.assert_called_with(expected_cmd,
                                            shell=True,
                                            stderr=subprocess.STDOUT,
                                            stdin=None,
                                            stdout=subprocess.PIPE,
                                            universal_newlines=True)

    def testSshRunCmd(self):
        """Test ssh run command."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        ssh_object.Run("command")
        expected_cmd = ("exec /usr/bin/ssh -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no -l fake_user 1.1.1.1 command")
        self._AssertPopenCalledWith(expected_cmd)

    def testSshRunCmdwithExtraArgs(self):
        """Test ssh run command with extra args."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP,
                             self.FAKE_SSH_USER,
                             self.FAKE_SSH_PRIVATE_KEY_PATH,
                             self.FAKE_EXTRA_ARGS_SSH)
        ssh_object.Run("command")
        expected_cmd = ("exec /usr/bin/ssh -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no "
                        "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22' "
                        "-l fake_user 1.1.1.1 command")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPullFileCmd(self):
        """Test scp pull file command."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        ssh_object.ScpPullFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = ("exec /usr/bin/scp -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no fake_user@1.1.1.1:/tmp/test /tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPullFileCmdwithExtraArgs(self):
        """Test scp pull file command with extra args."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP,
                             self.FAKE_SSH_USER,
                             self.FAKE_SSH_PRIVATE_KEY_PATH,
                             self.FAKE_EXTRA_ARGS_SSH)
        ssh_object.ScpPullFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = ("exec /usr/bin/scp -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no "
                        "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22' "
                        "fake_user@1.1.1.1:/tmp/test /tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPushFileCmd(self):
        """Test scp push file command."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        ssh_object.ScpPushFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = ("exec /usr/bin/scp -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no /tmp/test fake_user@1.1.1.1:/tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPushFileCmdwithExtraArgs(self):
        """Test scp push file command with extra args."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP,
                             self.FAKE_SSH_USER,
                             self.FAKE_SSH_PRIVATE_KEY_PATH,
                             self.FAKE_EXTRA_ARGS_SSH)
        ssh_object.ScpPushFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = ("exec /usr/bin/scp -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no "
                        "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22' "
                        "/tmp/test fake_user@1.1.1.1:/tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    # pylint: disable=protected-access
    def testIPAddress(self):
        """Test IP class to get ip address."""
        # Internal ip case.
        ssh_object = ssh.Ssh(ip=ssh.IP(external="1.1.1.1", internal="10.1.1.1"),
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH,
                             report_internal_ip=True)
        expected_ip = "10.1.1.1"
        self.assertEqual(ssh_object._ip, expected_ip)

        # External ip case.
        ssh_object = ssh.Ssh(ip=ssh.IP(external="1.1.1.1", internal="10.1.1.1"),
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH)
        expected_ip = "1.1.1.1"
        self.assertEqual(ssh_object._ip, expected_ip)

        # Only one ip case.
        ssh_object = ssh.Ssh(ip=ssh.IP(ip="1.1.1.1"),
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH)
        expected_ip = "1.1.1.1"
        self.assertEqual(ssh_object._ip, expected_ip)

    def testWaitForSsh(self):
        """Test WaitForSsh raises DeviceConnectionError when _SshCall returns -1."""
        ssh_object = ssh.Ssh(ip=self.FAKE_IP,
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH,
                             report_internal_ip=self.FAKE_REPORT_INTERNAL_IP)
        self.Patch(ssh, "_SshCall", return_value=-1)
        self.assertRaises(errors.DeviceConnectionError,
                          ssh_object.WaitForSsh,
                          timeout=1,
                          max_retry=1)

    def testSshCallWait(self):
        """Test _SshCallWait without timeout does not create a Timer."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        fake_cmd = "fake command"
        ssh._SshCallWait(fake_cmd)
        threading.Timer.assert_not_called()

    def testSshCallWaitTimeout(self):
        """Test _SshCallWait with timeout creates a Timer once."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        fake_cmd = "fake command"
        fake_timeout = 30
        ssh._SshCallWait(fake_cmd, fake_timeout)
        threading.Timer.assert_called_once()

    def testSshCall(self):
        """Test _SshCall without timeout does not create a Timer."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        fake_cmd = "fake command"
        ssh._SshCall(fake_cmd)
        threading.Timer.assert_not_called()

    def testSshCallTimeout(self):
        """Test _SshCall with timeout creates a Timer once."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        fake_cmd = "fake command"
        fake_timeout = 30
        ssh._SshCall(fake_cmd, fake_timeout)
        threading.Timer.assert_called_once()

    def testSshLogOutput(self):
        """Test _SshLogOutput without timeout does not create a Timer."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        fake_cmd = "fake command"
        ssh._SshLogOutput(fake_cmd)
        threading.Timer.assert_not_called()

    def testSshLogOutputTimeout(self):
        """Test _SshLogOutput with timeout creates a Timer once."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        fake_cmd = "fake command"
        fake_timeout = 30
        ssh._SshLogOutput(fake_cmd, fake_timeout)
        threading.Timer.assert_called_once()


if __name__ == "__main__":
    unittest.main()