1#!/usr/bin/env python
2#
3# Copyright 2018 - The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16"""Tests for acloud.public.actions.common_operations."""
17
18from __future__ import absolute_import
19from __future__ import division
20
21import shlex
22import unittest
23
24from unittest import mock
25
26from acloud import errors
27from acloud.internal.lib import android_build_client
28from acloud.internal.lib import android_compute_client
29from acloud.internal.lib import auth
30from acloud.internal.lib import driver_test_lib
31from acloud.internal.lib import utils
32from acloud.internal.lib import ssh
33from acloud.public import report
34from acloud.public.actions import common_operations
35
36
class CommonOperationsTest(driver_test_lib.BaseDriverTest):
    """Test Common Operations.

    Exercises DevicePool, CreateDevices and _GetErrorType against a mocked
    device factory and mocked build/compute clients.
    """
    IP = ssh.IP(external="127.0.0.1", internal="10.0.0.1")
    INSTANCE = "fake-instance"
    CMD = "test-cmd"
    AVD_TYPE = "fake-type"
    BRANCH = "fake-branch"
    BUILD_TARGET = "fake-target"
    BUILD_ID = "fake-build-id"

    # pylint: disable=protected-access
    def setUp(self):
        """Set up the test.

        Replaces AndroidBuildClient, AndroidComputeClient and
        auth.CreateCredentials with mocks, and configures a mock device
        factory whose CreateInstance, GetComputeClient and GetBuildInfoDict
        return canned values.
        """
        super().setUp()
        self.build_client = mock.MagicMock()
        self.device_factory = mock.MagicMock()
        self.Patch(
            android_build_client,
            "AndroidBuildClient",
            return_value=self.build_client)
        self.compute_client = mock.MagicMock()
        self.Patch(
            android_compute_client,
            "AndroidComputeClient",
            return_value=self.compute_client)
        self.Patch(auth, "CreateCredentials", return_value=mock.MagicMock())
        self.Patch(self.compute_client, "GetInstanceIP", return_value=self.IP)
        self.Patch(
            self.device_factory, "CreateInstance", return_value=self.INSTANCE)
        self.Patch(
            self.device_factory,
            "GetComputeClient",
            return_value=self.compute_client)
        # Patch once. A second identical patch would only shadow this one.
        self.Patch(self.device_factory, "GetBuildInfoDict",
                   return_value=self._BuildInfo())

    def _BuildInfo(self):
        """Return the build info dict reported by the mocked device factory.

        Returns:
            A new dict on every call, so expected values built from it are
            independent of whatever the code under test does with the dict
            returned by GetBuildInfoDict.
        """
        return {"branch": self.BRANCH,
                "build_id": self.BUILD_ID,
                "build_target": self.BUILD_TARGET,
                "gcs_bucket_build_id": self.BUILD_ID}

    def _ExpectedDeviceData(self, ip, **extra_fields):
        """Build the expected report data for a single created device.

        Args:
            ip: String, the IP address the report is expected to contain.
            **extra_fields: Additional per-device keys expected in the report,
                e.g. adb_port or vnc_port.

        Returns:
            Dict shaped like report.data: {"devices": [device_dict]}.
        """
        device = {"ip": ip, "instance_name": self.INSTANCE}
        device.update(self._BuildInfo())
        device.update(extra_fields)
        return {"devices": [device]}

    def _AssertSuccessfulReport(self, result, expected_data):
        """Assert that a report carries CMD, SUCCESS status and expected data.

        Args:
            result: The report returned by common_operations.CreateDevices.
            expected_data: Dict expected to equal result.data.
        """
        self.assertEqual(result.command, self.CMD)
        self.assertEqual(result.status, report.Status.SUCCESS)
        self.assertEqual(result.data, expected_data)

    @staticmethod
    def _CreateCfg():
        """A helper method that creates a mock configuration object."""
        cfg = mock.MagicMock()
        cfg.service_account_name = "fake@service.com"
        cfg.service_account_private_key_path = "/fake/path/to/key"
        cfg.zone = "fake_zone"
        cfg.disk_image_name = "fake_image.tar.gz"
        cfg.disk_image_mime_type = "fake/type"
        cfg.ssh_private_key_path = ""
        cfg.ssh_public_key_path = ""
        return cfg

    def testDevicePoolCreateDevices(self):
        """Test Device Pool Create Devices."""
        pool = common_operations.DevicePool(self.device_factory)
        pool.CreateDevices(5)
        self.assertEqual(self.device_factory.CreateInstance.call_count, 5)
        self.assertEqual(len(pool.devices), 5)

    def testCreateDevices(self):
        """Test Create Devices."""
        cfg = self._CreateCfg()
        _report = common_operations.CreateDevices(self.CMD, cfg,
                                                  self.device_factory, 1,
                                                  self.AVD_TYPE)
        self._AssertSuccessfulReport(
            _report, self._ExpectedDeviceData(self.IP.external))

    def testCreateDevicesWithAdbPort(self):
        """Test Create Devices with adb port for cuttlefish avd type."""
        self.Patch(utils, "_ExecuteCommand")
        self.Patch(utils, "PickFreePort", return_value=56789)
        self.Patch(shlex, "split", return_value=[])
        cfg = self._CreateCfg()
        _report = common_operations.CreateDevices(self.CMD, cfg,
                                                  self.device_factory, 1,
                                                  "cuttlefish",
                                                  autoconnect=True,
                                                  client_adb_port=12345)
        # The requested client adb port is used as-is. The vnc port comes
        # from the patched PickFreePort.
        self._AssertSuccessfulReport(
            _report,
            self._ExpectedDeviceData(self.IP.external,
                                     adb_port=12345,
                                     device_serial="127.0.0.1:12345",
                                     vnc_port=56789))

    def testCreateDevicesInternalIP(self):
        """Test Create Devices and report internal IP."""
        cfg = self._CreateCfg()
        _report = common_operations.CreateDevices(self.CMD, cfg,
                                                  self.device_factory, 1,
                                                  self.AVD_TYPE,
                                                  report_internal_ip=True)
        self._AssertSuccessfulReport(
            _report, self._ExpectedDeviceData(self.IP.internal))

    def testGetErrorType(self):
        """Test GetErrorType maps each error to the expected error type."""
        cases = [
            (errors.CheckGCEZonesQuotaError(),
             common_operations._GCE_QUOTA_ERROR),
            (errors.DownloadArtifactError(),
             common_operations._ACLOUD_DOWNLOAD_ARTIFACT_ERROR),
            (errors.DeviceConnectionError(),
             common_operations._ACLOUD_SSH_CONNECT_ERROR),
            # A plain DriverError falls back to the generic error type.
            (errors.DriverError(),
             common_operations._ACLOUD_GENERIC_ERROR),
            # A DriverError whose message mentions a GCE quota issue.
            (errors.DriverError("Quota exceeded for quota read group."),
             common_operations._GCE_QUOTA_ERROR),
        ]
        for error, expected_result in cases:
            # repr(error) identifies which case failed.
            self.assertEqual(common_operations._GetErrorType(error),
                             expected_result, msg=repr(error))
193
194
if __name__ == "__main__":
    # When executed as a script, discover and run the test cases defined in
    # this module. "__main__" is unittest.main's default module, spelled out.
    unittest.main(module="__main__")
197