#!/usr/bin/env python
#
# Copyright 2016 - The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for acloud.internal.lib.utils."""

import errno
import getpass
import grp
import os
import shutil
import subprocess
import tempfile
import time

import unittest
import mock
import six

from acloud import errors
from acloud.internal.lib import driver_test_lib
from acloud.internal.lib import utils


# Tkinter may not be supported so mock it out.
try:
    import Tkinter
except ImportError:
    Tkinter = mock.Mock()


class FakeTkinter(object):
    """Fake implementation of Tkinter.Tk().

    Only exposes the two screen-size queries used by these tests and answers
    them with the values given at construction time.
    """

    def __init__(self, width=None, height=None):
        self.width = width
        self.height = height

    # pylint: disable=invalid-name
    def winfo_screenheight(self):
        """Return the screen height."""
        return self.height

    # pylint: disable=invalid-name
    def winfo_screenwidth(self):
        """Return the screen width."""
        return self.width


# pylint: disable=too-many-public-methods
class UtilsTest(driver_test_lib.BaseDriverTest):
    """Test Utils.

    Method names must start with lowercase "test": that is the prefix the
    default unittest loader uses to discover test methods.
    """

    # pylint: disable=invalid-name
    def testTempDirSuccess(self):
        """Test create a temp dir."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        self.Patch(shutil, "rmtree")
        with utils.TempDir():
            pass
        # Verify.
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    # pylint: disable=invalid-name
    def testTempDirExceptionRaised(self):
        """Test create a temp dir and exception is raised within with-clause."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        self.Patch(shutil, "rmtree")

        class ExpectedException(Exception):
            """Expected exception."""

        def _Call():
            with utils.TempDir():
                raise ExpectedException("Expected exception.")

        # Verify. ExpectedException should be raised.
        self.assertRaises(ExpectedException, _Call)
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirWhenDeleteTempDirNoLongerExist(self):  # pylint: disable=invalid-name
        """Test create a temp dir and dir no longer exists during deletion."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        expected_error = EnvironmentError()
        expected_error.errno = errno.ENOENT
        self.Patch(shutil, "rmtree", side_effect=expected_error)

        def _Call():
            with utils.TempDir():
                pass

        # Verify no exception should be raised when rmtree raises
        # EnvironmentError with errno.ENOENT, i.e.
        # directory no longer exists.
        _Call()
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirWhenDeleteEncounterError(self):
        """Test create a temp dir and encountered error during deletion."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        expected_error = OSError("Expected OS Error")
        self.Patch(shutil, "rmtree", side_effect=expected_error)

        def _Call():
            with utils.TempDir():
                pass

        # Verify OSError should be raised.
        self.assertRaises(OSError, _Call)
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirOrininalErrorRaised(self):
        """Test original error is raised even if tmp dir deletion failed."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        expected_error = OSError("Expected OS Error")
        self.Patch(shutil, "rmtree", side_effect=expected_error)

        class ExpectedException(Exception):
            """Expected exception."""

        def _Call():
            with utils.TempDir():
                raise ExpectedException("Expected Exception")

        # Verify.
        # ExpectedException should be raised, and OSError
        # should not be raised.
        self.assertRaises(ExpectedException, _Call)
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testCreateSshKeyPairKeyAlreadyExists(self):  # pylint: disable=invalid-name
        """Test when the key pair already exists."""
        public_key = "/fake/public_key"
        private_key = "/fake/private_key"
        self.Patch(os.path, "exists", side_effect=[True, True])
        self.Patch(subprocess, "check_call")
        self.Patch(os, "makedirs", return_value=True)
        utils.CreateSshKeyPairIfNotExist(private_key, public_key)
        # No key generation command may be issued when both keys exist.
        self.assertEqual(subprocess.check_call.call_count, 0)  # pylint: disable=no-member

    def testCreateSshKeyPairKeyAreCreated(self):
        """Test when the key pair created."""
        public_key = "/fake/public_key"
        private_key = "/fake/private_key"
        self.Patch(os.path, "exists", return_value=False)
        self.Patch(os, "makedirs", return_value=True)
        self.Patch(subprocess, "check_call")
        self.Patch(os, "rename")
        utils.CreateSshKeyPairIfNotExist(private_key, public_key)
        self.assertEqual(subprocess.check_call.call_count, 1)  # pylint: disable=no-member
        subprocess.check_call.assert_called_with(  # pylint: disable=no-member
            utils.SSH_KEYGEN_CMD +
            ["-C", getpass.getuser(), "-f", private_key],
            stdout=mock.ANY,
            stderr=mock.ANY)

    def testCreatePublicKeyAreCreated(self):
        """Test when the PublicKey created."""
        public_key = "/fake/public_key"
        private_key = "/fake/private_key"
        self.Patch(os.path, "exists", side_effect=[False, True, True])
        self.Patch(os, "makedirs", return_value=True)
        mock_open = mock.mock_open(read_data=public_key)
        self.Patch(subprocess, "check_output")
        self.Patch(os, "rename")
        with mock.patch.object(six.moves.builtins, "open", mock_open):
            utils.CreateSshKeyPairIfNotExist(private_key, public_key)
        self.assertEqual(subprocess.check_output.call_count, 1)  # pylint: disable=no-member
        subprocess.check_output.assert_called_with(  # pylint: disable=no-member
            utils.SSH_KEYGEN_PUB_CMD + ["-f", private_key])

    # pylint: disable=invalid-name
    def testRetryOnException(self):
        """Test the RetryOnException decorator retries num_retry times."""

        def _IsValueError(exc):
            return isinstance(exc, ValueError)

        num_retry = 5

        @utils.RetryOnException(_IsValueError, num_retry)
        def _RaiseAndRetry(sentinel):
            sentinel.alert()
            raise ValueError("Fake error.")

        sentinel = mock.MagicMock()
        self.assertRaises(ValueError, _RaiseAndRetry, sentinel)
        # One initial attempt plus num_retry retries.
        self.assertEqual(1 + num_retry, sentinel.alert.call_count)

    def testRetryExceptionType(self):
        """Test RetryExceptionType function."""

        def _RaiseAndRetry(sentinel):
            sentinel.alert()
            raise ValueError("Fake error.")

        num_retry = 5
        sentinel = mock.MagicMock()
        self.assertRaises(
            ValueError,
            utils.RetryExceptionType, (KeyError, ValueError),
            num_retry,
            _RaiseAndRetry,
            0,  # sleep_multiplier
            1,  # retry_backoff_factor
            sentinel=sentinel)
        self.assertEqual(1 + num_retry, sentinel.alert.call_count)

    def testRetry(self):
        """Test Retry sleeps between attempts with the expected intervals."""
        mock_sleep = self.Patch(time, "sleep")

        def _RaiseAndRetry(sentinel):
            sentinel.alert()
            raise ValueError("Fake error.")

        num_retry = 5
        sentinel = mock.MagicMock()
        self.assertRaises(
            ValueError,
            utils.RetryExceptionType, (ValueError, KeyError),
            num_retry,
            _RaiseAndRetry,
            1,  # sleep_multiplier
            2,  # retry_backoff_factor
            sentinel=sentinel)

        self.assertEqual(1 + num_retry, sentinel.alert.call_count)
        mock_sleep.assert_has_calls(
            [
                mock.call(1),
                mock.call(2),
                mock.call(4),
                mock.call(8),
                mock.call(16)
            ])

    @mock.patch.object(six.moves, "input")
    def testGetAnswerFromList(self, mock_raw_input):
        """Test GetAnswerFromList."""
        answer_list = ["image1.zip", "image2.zip", "image3.zip"]
        mock_raw_input.return_value = 0
        with self.assertRaises(SystemExit):
            utils.GetAnswerFromList(answer_list)
        # Each subsequent call consumes the next fake user input.
        mock_raw_input.side_effect = [1, 2, 3, 4]
        self.assertEqual(utils.GetAnswerFromList(answer_list),
                         ["image1.zip"])
        self.assertEqual(utils.GetAnswerFromList(answer_list),
                         ["image2.zip"])
        self.assertEqual(utils.GetAnswerFromList(answer_list),
                         ["image3.zip"])
        self.assertEqual(utils.GetAnswerFromList(answer_list,
                                                 enable_choose_all=True),
                         answer_list)

    @unittest.skipIf(isinstance(Tkinter, mock.Mock),
                     "Tkinter mocked out, test case not needed.")
    @mock.patch.object(Tkinter, "Tk")
    def testCalculateVNCScreenRatio(self, mock_tk):
        """Test Calculating the scale ratio of VNC display."""
        # Get scale-down ratio if screen height is smaller than AVD height.
        mock_tk.return_value = FakeTkinter(height=800, width=1200)
        avd_h = 1920
        avd_w = 1080
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.4)

        # Get scale-down ratio if screen width is smaller than AVD width.
        mock_tk.return_value = FakeTkinter(height=800, width=1200)
        avd_h = 900
        avd_w = 1920
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.6)

        # Scale ratio = 1 if screen is larger than AVD.
        mock_tk.return_value = FakeTkinter(height=1080, width=1920)
        avd_h = 800
        avd_w = 1280
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 1)

        # Get the scale if ratio of width is smaller than the
        # ratio of height.
        mock_tk.return_value = FakeTkinter(height=1200, width=800)
        avd_h = 1920
        avd_w = 1080
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.6)

    # pylint: disable=protected-access
    def testCheckUserInGroups(self):
        """Test CheckUserInGroups."""
        self.Patch(os, "getgroups", return_value=[1, 2, 3])
        gr1 = mock.MagicMock()
        gr1.gr_name = "fake_gr_1"
        gr2 = mock.MagicMock()
        gr2.gr_name = "fake_gr_2"
        gr3 = mock.MagicMock()
        gr3.gr_name = "fake_gr_3"
        self.Patch(grp, "getgrgid", side_effect=[gr1, gr2, gr3])

        # User in all required groups should return true.
        self.assertTrue(
            utils.CheckUserInGroups(
                ["fake_gr_1", "fake_gr_2"]))

        # User not in all required groups should return False.
        # The side_effect list is consumed by the call above, so patch again.
        self.Patch(grp, "getgrgid", side_effect=[gr1, gr2, gr3])
        self.assertFalse(
            utils.CheckUserInGroups(
                ["fake_gr_1", "fake_gr_4"]))

    @mock.patch.object(utils, "CheckUserInGroups")
    def testAddUserGroupsToCmd(self, mock_user_group):
        """Test AddUserGroupsToCmd."""
        command = "test_command"
        groups = ["group1", "group2"]
        # Don't add user group in command
        mock_user_group.return_value = True
        expected_value = "test_command"
        self.assertEqual(expected_value, utils.AddUserGroupsToCmd(command,
                                                                  groups))

        # Add user group in command
        mock_user_group.return_value = False
        expected_value = "sg group1 <<EOF\nsg group2\ntest_command\nEOF"
        self.assertEqual(expected_value, utils.AddUserGroupsToCmd(command,
                                                                  groups))

    # pylint: disable=invalid-name
    def testTimeoutException(self):
        """Test TimeoutException."""
        @utils.TimeoutException(1, "should time out")
        def functionThatWillTimeOut():
            """Test decorator of @utils.TimeoutException should timeout."""
            time.sleep(5)

        self.assertRaises(errors.FunctionTimeoutError,
                          functionThatWillTimeOut)

    def testTimeoutExceptionNoTimeout(self):
        """Test No TimeoutException."""
        @utils.TimeoutException(5, "shouldn't time out")
        def functionThatShouldNotTimeout():
            """Test decorator of @utils.TimeoutException shouldn't timeout."""
            return None
        try:
            functionThatShouldNotTimeout()
        except errors.FunctionTimeoutError:
            self.fail("shouldn't timeout")

    def testAutoConnectCreateSSHTunnelFail(self):
        """Test auto connect returns empty ports when the ssh tunnel fails."""
        fake_ip_addr = "1.1.1.1"
        fake_rsa_key_file = "/tmp/rsa_file"
        fake_target_vnc_port = 8888
        target_adb_port = 9999
        ssh_user = "fake_user"
        call_side_effect = subprocess.CalledProcessError(123, "fake",
                                                         "fake error")
        result = utils.ForwardedPorts(vnc_port=None, adb_port=None)
        self.Patch(subprocess, "check_call", side_effect=call_side_effect)
        self.assertEqual(result, utils.AutoConnect(fake_ip_addr,
                                                   fake_rsa_key_file,
                                                   fake_target_vnc_port,
                                                   target_adb_port,
                                                   ssh_user))

    # pylint: disable=protected-access,no-member
    def testExtraArgsSSHTunnel(self):
        """Test extra args will be the same with expanded args."""
        fake_ip_addr = "1.1.1.1"
        fake_rsa_key_file = "/tmp/rsa_file"
        fake_target_vnc_port = 8888
        target_adb_port = 9999
        ssh_user = "fake_user"
        fake_port = 12345
        self.Patch(utils, "PickFreePort", return_value=fake_port)
        self.Patch(utils, "_ExecuteCommand")
        self.Patch(subprocess, "check_call", return_value=True)
        extra_args_ssh_tunnel = "-o command='shell %s %h' -o command1='ls -la'"
        utils.AutoConnect(ip_addr=fake_ip_addr,
                          rsa_key_file=fake_rsa_key_file,
                          target_vnc_port=fake_target_vnc_port,
                          target_adb_port=target_adb_port,
                          ssh_user=ssh_user,
                          client_adb_port=fake_port,
                          extra_args_ssh_tunnel=extra_args_ssh_tunnel)
        args_list = ["-i", "/tmp/rsa_file",
                     "-o", "UserKnownHostsFile=/dev/null",
                     "-o", "StrictHostKeyChecking=no",
                     "-L", "12345:127.0.0.1:8888",
                     "-L", "12345:127.0.0.1:9999",
                     "-N", "-f", "-l", "fake_user", "1.1.1.1",
                     "-o", "command=shell %s %h",
                     "-o", "command1=ls -la"]
        # Positional args of the first _ExecuteCommand call; index 1 is the
        # argument list compared against the expected expansion.
        first_call_args = utils._ExecuteCommand.call_args_list[0][0]
        self.assertEqual(first_call_args[1], args_list)

    # pylint: disable=protected-access, no-member
    def testCleanupSSVncviwer(self):
        """Test cleanup ssvnc viewer."""
        fake_vnc_port = 9999
        fake_ss_vncviewer_pattern = utils._SSVNC_VIEWER_PATTERN % {
            "vnc_port": fake_vnc_port}
        self.Patch(utils, "IsCommandRunning", return_value=True)
        self.Patch(subprocess, "check_call", return_value=True)
        utils.CleanupSSVncviewer(fake_vnc_port)
        subprocess.check_call.assert_called_with(
            ["pkill", "-9", "-f", fake_ss_vncviewer_pattern])

        # Clear the recorded calls so the second scenario starts clean;
        # reset_mock() wipes call_count together with call_args/call_args_list.
        subprocess.check_call.reset_mock()
        self.Patch(utils, "IsCommandRunning", return_value=False)
        utils.CleanupSSVncviewer(fake_vnc_port)
        subprocess.check_call.assert_not_called()


if __name__ == "__main__":
    unittest.main()