1#!/usr/bin/env python 2# 3# Copyright 2018 - The Android Open Source Project 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16"""Tests for acloud.public.actions.common_operations.""" 17 18from __future__ import absolute_import 19from __future__ import division 20 21import shlex 22import unittest 23 24from unittest import mock 25 26from acloud import errors 27from acloud.internal.lib import android_build_client 28from acloud.internal.lib import android_compute_client 29from acloud.internal.lib import auth 30from acloud.internal.lib import driver_test_lib 31from acloud.internal.lib import utils 32from acloud.internal.lib import ssh 33from acloud.public import report 34from acloud.public.actions import common_operations 35 36 37class CommonOperationsTest(driver_test_lib.BaseDriverTest): 38 """Test Common Operations.""" 39 IP = ssh.IP(external="127.0.0.1", internal="10.0.0.1") 40 INSTANCE = "fake-instance" 41 CMD = "test-cmd" 42 AVD_TYPE = "fake-type" 43 BRANCH = "fake-branch" 44 BUILD_TARGET = "fake-target" 45 BUILD_ID = "fake-build-id" 46 47 # pylint: disable=protected-access 48 def setUp(self): 49 """Set up the test.""" 50 super().setUp() 51 self.build_client = mock.MagicMock() 52 self.device_factory = mock.MagicMock() 53 self.Patch( 54 android_build_client, 55 "AndroidBuildClient", 56 return_value=self.build_client) 57 self.compute_client = mock.MagicMock() 58 self.Patch( 59 android_compute_client, 60 "AndroidComputeClient", 61 return_value=self.compute_client) 62 self.Patch(auth, "CreateCredentials", return_value=mock.MagicMock()) 63 self.Patch(self.compute_client, "GetInstanceIP", return_value=self.IP) 64 self.Patch( 65 self.device_factory, "CreateInstance", return_value=self.INSTANCE) 66 self.Patch( 67 self.device_factory, 68 "GetComputeClient", 69 return_value=self.compute_client) 70 self.Patch(self.device_factory, "GetBuildInfoDict", 71 return_value={"branch": self.BRANCH, 72 "build_id": self.BUILD_ID, 73 "build_target": self.BUILD_TARGET, 74 "gcs_bucket_build_id": self.BUILD_ID}) 75 self.Patch(self.device_factory, "GetBuildInfoDict", 76 return_value={"branch": self.BRANCH, 77 "build_id": self.BUILD_ID, 78 "build_target": self.BUILD_TARGET, 79 "gcs_bucket_build_id": self.BUILD_ID}) 80 81 @staticmethod 82 def _CreateCfg(): 83 """A helper method that creates a mock configuration object.""" 84 cfg = mock.MagicMock() 85 cfg.service_account_name = "fake@service.com" 86 cfg.service_account_private_key_path = "/fake/path/to/key" 87 cfg.zone = "fake_zone" 88 cfg.disk_image_name = "fake_image.tar.gz" 89 cfg.disk_image_mime_type = "fake/type" 90 cfg.ssh_private_key_path = "" 91 cfg.ssh_public_key_path = "" 92 return cfg 93 94 def testDevicePoolCreateDevices(self): 95 """Test Device Pool Create Devices.""" 96 pool = common_operations.DevicePool(self.device_factory) 97 pool.CreateDevices(5) 98 self.assertEqual(self.device_factory.CreateInstance.call_count, 5) 99 self.assertEqual(len(pool.devices), 5) 100 101 def testCreateDevices(self): 102 """Test Create Devices.""" 103 cfg = self._CreateCfg() 104 _report = common_operations.CreateDevices(self.CMD, cfg, 105 self.device_factory, 1, 106 self.AVD_TYPE) 107 self.assertEqual(_report.command, self.CMD) 108 self.assertEqual(_report.status, report.Status.SUCCESS) 109 self.assertEqual( 110 _report.data, 111 {"devices": [{ 112 "ip": self.IP.external, 113 "instance_name": self.INSTANCE, 114 "branch": self.BRANCH, 115 "build_id": self.BUILD_ID, 116 "build_target": self.BUILD_TARGET, 117 "gcs_bucket_build_id": self.BUILD_ID, 118 }]}) 119 120 def testCreateDevicesWithAdbPort(self): 121 """Test Create Devices with adb port for cuttlefish avd type.""" 122 self.Patch(utils, "_ExecuteCommand") 123 self.Patch(utils, "PickFreePort", return_value=56789) 124 self.Patch(shlex, "split", return_value=[]) 125 cfg = self._CreateCfg() 126 _report = common_operations.CreateDevices(self.CMD, cfg, 127 self.device_factory, 1, 128 "cuttlefish", 129 autoconnect=True, 130 client_adb_port=12345) 131 self.assertEqual(_report.command, self.CMD) 132 self.assertEqual(_report.status, report.Status.SUCCESS) 133 self.assertEqual( 134 _report.data, 135 {"devices": [{ 136 "ip": self.IP.external, 137 "instance_name": self.INSTANCE, 138 "branch": self.BRANCH, 139 "build_id": self.BUILD_ID, 140 "adb_port": 12345, 141 "device_serial": "127.0.0.1:12345", 142 "vnc_port": 56789, 143 "build_target": self.BUILD_TARGET, 144 "gcs_bucket_build_id": self.BUILD_ID, 145 }]}) 146 147 def testCreateDevicesInternalIP(self): 148 """Test Create Devices and report internal IP.""" 149 cfg = self._CreateCfg() 150 _report = common_operations.CreateDevices(self.CMD, cfg, 151 self.device_factory, 1, 152 self.AVD_TYPE, 153 report_internal_ip=True) 154 self.assertEqual(_report.command, self.CMD) 155 self.assertEqual(_report.status, report.Status.SUCCESS) 156 self.assertEqual( 157 _report.data, 158 {"devices": [{ 159 "ip": self.IP.internal, 160 "instance_name": self.INSTANCE, 161 "branch": self.BRANCH, 162 "build_id": self.BUILD_ID, 163 "build_target": self.BUILD_TARGET, 164 "gcs_bucket_build_id": self.BUILD_ID, 165 }]}) 166 167 def testGetErrorType(self): 168 """Test GetErrorType.""" 169 # Test with CheckGCEZonesQuotaError() 170 error = errors.CheckGCEZonesQuotaError() 171 expected_result = common_operations._GCE_QUOTA_ERROR 172 self.assertEqual(common_operations._GetErrorType(error), expected_result) 173 174 # Test with DownloadArtifactError() 175 error = errors.DownloadArtifactError() 176 expected_result = common_operations._ACLOUD_DOWNLOAD_ARTIFACT_ERROR 177 self.assertEqual(common_operations._GetErrorType(error), expected_result) 178 179 # Test with DeviceConnectionError() 180 error = errors.DeviceConnectionError() 181 expected_result = common_operations._ACLOUD_SSH_CONNECT_ERROR 182 self.assertEqual(common_operations._GetErrorType(error), expected_result) 183 184 # Test with ACLOUD_GENERIC_ERROR 185 error = errors.DriverError() 186 expected_result = common_operations._ACLOUD_GENERIC_ERROR 187 self.assertEqual(common_operations._GetErrorType(error), expected_result) 188 189 # Test with error message about GCE quota issue 190 error = errors.DriverError("Quota exceeded for quota read group.") 191 expected_result = common_operations._GCE_QUOTA_ERROR 192 self.assertEqual(common_operations._GetErrorType(error), expected_result) 193 194 195if __name__ == "__main__": 196 unittest.main() 197