1#!/usr/bin/env python
2#
3# Copyright 2016 - The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16"""Tests for acloud.internal.lib.utils."""
17
18import errno
19import getpass
20import grp
21import os
22import shutil
23import subprocess
24import tempfile
25import time
26
27import unittest
28import mock
29import six
30
31from acloud import errors
32from acloud.internal.lib import driver_test_lib
33from acloud.internal.lib import utils
34
35
# Tkinter may not be importable in this environment, so fall back to a mock;
# tests that need the real module are skipped when the mock is in use.
37try:
38    import Tkinter
39except ImportError:
40    Tkinter = mock.Mock()
41
42
class FakeTkinter(object):
    """Stand-in for a Tkinter.Tk() root window with fixed screen dimensions.

    Args:
        width: Value reported by winfo_screenwidth().
        height: Value reported by winfo_screenheight().
    """

    def __init__(self, width=None, height=None):
        self.height = height
        self.width = width

    def winfo_screenwidth(self):  # pylint: disable=invalid-name
        """Report the configured screen width."""
        return self.width

    def winfo_screenheight(self):  # pylint: disable=invalid-name
        """Report the configured screen height."""
        return self.height
59
60
61# pylint: disable=too-many-public-methods
class UtilsTest(driver_test_lib.BaseDriverTest):
    """Unit tests for acloud.internal.lib.utils.

    Note: unittest only collects methods whose names start with the lowercase
    prefix "test", so every test method in this class must follow that
    convention or it will silently never run.
    """

    def testTempDirSuccess(self):
        """Test create a temp dir and remove it when the with-block exits."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        self.Patch(shutil, "rmtree")
        with utils.TempDir():
            pass
        # Verify.
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirExceptionRaised(self):
        """Test create a temp dir and exception is raised within with-clause."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        self.Patch(shutil, "rmtree")

        class ExpectedException(Exception):
            """Expected exception."""

        def _Call():
            with utils.TempDir():
                raise ExpectedException("Expected exception.")

        # Verify. ExpectedException should be raised, and the temp dir
        # should still be cleaned up.
        self.assertRaises(ExpectedException, _Call)
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirWhenDeleteTempDirNoLongerExist(self):  # pylint: disable=invalid-name
        """Test create a temp dir and dir no longer exists during deletion."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        expected_error = EnvironmentError()
        expected_error.errno = errno.ENOENT
        self.Patch(shutil, "rmtree", side_effect=expected_error)

        def _Call():
            with utils.TempDir():
                pass

        # Verify no exception should be raised when rmtree raises
        # EnvironmentError with errno.ENOENT, i.e.
        # directory no longer exists.
        _Call()
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirWhenDeleteEncounterError(self):
        """Test create a temp dir and encountered error during deletion."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        expected_error = OSError("Expected OS Error")
        self.Patch(shutil, "rmtree", side_effect=expected_error)

        def _Call():
            with utils.TempDir():
                pass

        # Verify OSError should be raised.
        self.assertRaises(OSError, _Call)
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirOriginalErrorRaised(self):
        """Test original error is raised even if tmp dir deletion failed."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        expected_error = OSError("Expected OS Error")
        self.Patch(shutil, "rmtree", side_effect=expected_error)

        class ExpectedException(Exception):
            """Expected exception."""

        def _Call():
            with utils.TempDir():
                raise ExpectedException("Expected Exception")

        # Verify.
        # ExpectedException should be raised, and OSError
        # should not be raised.
        self.assertRaises(ExpectedException, _Call)
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testCreateSshKeyPairKeyAlreadyExists(self):  # pylint: disable=invalid-name
        """Test when the key pair already exists no key is generated."""
        public_key = "/fake/public_key"
        private_key = "/fake/private_key"
        self.Patch(os.path, "exists", side_effect=[True, True])
        self.Patch(subprocess, "check_call")
        self.Patch(os, "makedirs", return_value=True)
        utils.CreateSshKeyPairIfNotExist(private_key, public_key)
        self.assertEqual(subprocess.check_call.call_count, 0)  # pylint: disable=no-member

    def testCreateSshKeyPairKeyAreCreated(self):
        """Test when the key pair is missing, ssh-keygen is invoked once."""
        public_key = "/fake/public_key"
        private_key = "/fake/private_key"
        self.Patch(os.path, "exists", return_value=False)
        self.Patch(os, "makedirs", return_value=True)
        self.Patch(subprocess, "check_call")
        self.Patch(os, "rename")
        utils.CreateSshKeyPairIfNotExist(private_key, public_key)
        self.assertEqual(subprocess.check_call.call_count, 1)  # pylint: disable=no-member
        subprocess.check_call.assert_called_with(  # pylint: disable=no-member
            utils.SSH_KEYGEN_CMD +
            ["-C", getpass.getuser(), "-f", private_key],
            stdout=mock.ANY,
            stderr=mock.ANY)

    def testCreatePublicKeyAreCreated(self):
        """Test when only the public key is missing, it is derived once."""
        public_key = "/fake/public_key"
        private_key = "/fake/private_key"
        self.Patch(os.path, "exists", side_effect=[False, True, True])
        self.Patch(os, "makedirs", return_value=True)
        mock_open = mock.mock_open(read_data=public_key)
        self.Patch(subprocess, "check_output")
        self.Patch(os, "rename")
        with mock.patch.object(six.moves.builtins, "open", mock_open):
            utils.CreateSshKeyPairIfNotExist(private_key, public_key)
        self.assertEqual(subprocess.check_output.call_count, 1)  # pylint: disable=no-member
        subprocess.check_output.assert_called_with(  # pylint: disable=no-member
            utils.SSH_KEYGEN_PUB_CMD + ["-f", private_key])

    def testRetryOnException(self):
        """Test the RetryOnException decorator."""

        def _IsValueError(exc):
            return isinstance(exc, ValueError)

        num_retry = 5

        @utils.RetryOnException(_IsValueError, num_retry)
        def _RaiseAndRetry(sentinel):
            sentinel.alert()
            raise ValueError("Fake error.")

        sentinel = mock.MagicMock()
        self.assertRaises(ValueError, _RaiseAndRetry, sentinel)
        # One initial attempt plus num_retry retries.
        self.assertEqual(1 + num_retry, sentinel.alert.call_count)

    def testRetryExceptionType(self):
        """Test RetryExceptionType function."""

        def _RaiseAndRetry(sentinel):
            sentinel.alert()
            raise ValueError("Fake error.")

        num_retry = 5
        sentinel = mock.MagicMock()
        self.assertRaises(
            ValueError,
            utils.RetryExceptionType, (KeyError, ValueError),
            num_retry,
            _RaiseAndRetry,
            0,  # sleep_multiplier
            1,  # retry_backoff_factor
            sentinel=sentinel)
        self.assertEqual(1 + num_retry, sentinel.alert.call_count)

    def testRetry(self):
        """Test Retry sleeps with an exponential backoff between attempts."""
        mock_sleep = self.Patch(time, "sleep")

        def _RaiseAndRetry(sentinel):
            sentinel.alert()
            raise ValueError("Fake error.")

        num_retry = 5
        sentinel = mock.MagicMock()
        self.assertRaises(
            ValueError,
            utils.RetryExceptionType, (ValueError, KeyError),
            num_retry,
            _RaiseAndRetry,
            1,  # sleep_multiplier
            2,  # retry_backoff_factor
            sentinel=sentinel)

        self.assertEqual(1 + num_retry, sentinel.alert.call_count)
        # With multiplier 1 and backoff factor 2 the sleeps double each time.
        mock_sleep.assert_has_calls(
            [
                mock.call(1),
                mock.call(2),
                mock.call(4),
                mock.call(8),
                mock.call(16)
            ])

    @mock.patch.object(six.moves, "input")
    def testGetAnswerFromList(self, mock_raw_input):
        """Test GetAnswerFromList.

        Answer 0 should exit; answers 1..N pick the matching entry; when
        enable_choose_all is set, answer N+1 returns the whole list.
        """
        answer_list = ["image1.zip", "image2.zip", "image3.zip"]
        mock_raw_input.return_value = 0
        with self.assertRaises(SystemExit):
            utils.GetAnswerFromList(answer_list)
        mock_raw_input.side_effect = [1, 2, 3, 4]
        self.assertEqual(utils.GetAnswerFromList(answer_list),
                         ["image1.zip"])
        self.assertEqual(utils.GetAnswerFromList(answer_list),
                         ["image2.zip"])
        self.assertEqual(utils.GetAnswerFromList(answer_list),
                         ["image3.zip"])
        self.assertEqual(utils.GetAnswerFromList(answer_list,
                                                 enable_choose_all=True),
                         answer_list)

    @unittest.skipIf(isinstance(Tkinter, mock.Mock), "Tkinter mocked out, test case not needed.")
    @mock.patch.object(Tkinter, "Tk")
    def testCalculateVNCScreenRatio(self, mock_tk):
        """Test Calculating the scale ratio of VNC display."""
        # Get scale-down ratio if screen height is smaller than AVD height.
        mock_tk.return_value = FakeTkinter(height=800, width=1200)
        avd_h = 1920
        avd_w = 1080
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.4)

        # Get scale-down ratio if screen width is smaller than AVD width.
        mock_tk.return_value = FakeTkinter(height=800, width=1200)
        avd_h = 900
        avd_w = 1920
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.6)

        # Scale ratio = 1 if screen is larger than AVD.
        mock_tk.return_value = FakeTkinter(height=1080, width=1920)
        avd_h = 800
        avd_w = 1280
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 1)

        # Get the scale if ratio of width is smaller than the
        # ratio of height.
        mock_tk.return_value = FakeTkinter(height=1200, width=800)
        avd_h = 1920
        avd_w = 1080
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.6)

    # pylint: disable=protected-access
    def testCheckUserInGroups(self):
        """Test CheckUserInGroups."""
        self.Patch(os, "getgroups", return_value=[1, 2, 3])
        gr1 = mock.MagicMock()
        gr1.gr_name = "fake_gr_1"
        gr2 = mock.MagicMock()
        gr2.gr_name = "fake_gr_2"
        gr3 = mock.MagicMock()
        gr3.gr_name = "fake_gr_3"
        self.Patch(grp, "getgrgid", side_effect=[gr1, gr2, gr3])

        # User in all required groups should return True.
        self.assertTrue(
            utils.CheckUserInGroups(
                ["fake_gr_1", "fake_gr_2"]))

        # User not in all required groups should return False.
        # Re-patch so the side_effect iterator starts fresh.
        self.Patch(grp, "getgrgid", side_effect=[gr1, gr2, gr3])
        self.assertFalse(
            utils.CheckUserInGroups(
                ["fake_gr_1", "fake_gr_4"]))

    @mock.patch.object(utils, "CheckUserInGroups")
    def testAddUserGroupsToCmd(self, mock_user_group):
        """Test AddUserGroupsToCmd."""
        command = "test_command"
        groups = ["group1", "group2"]
        # Don't add user group in command
        mock_user_group.return_value = True
        expected_value = "test_command"
        self.assertEqual(expected_value, utils.AddUserGroupsToCmd(command,
                                                                  groups))

        # Add user group in command
        mock_user_group.return_value = False
        expected_value = "sg group1 <<EOF\nsg group2\ntest_command\nEOF"
        self.assertEqual(expected_value, utils.AddUserGroupsToCmd(command,
                                                                  groups))

    # pylint: disable=invalid-name
    def testTimeoutException(self):
        """Test TimeoutException."""
        @utils.TimeoutException(1, "should time out")
        def functionThatWillTimeOut():
            """Test decorator of @utils.TimeoutException should timeout."""
            time.sleep(5)

        self.assertRaises(errors.FunctionTimeoutError,
                          functionThatWillTimeOut)

    def testTimeoutExceptionNoTimeout(self):
        """Test No TimeoutException."""
        @utils.TimeoutException(5, "shouldn't time out")
        def functionThatShouldNotTimeout():
            """Test decorator of @utils.TimeoutException shouldn't timeout."""
            return None
        try:
            functionThatShouldNotTimeout()
        except errors.FunctionTimeoutError:
            self.fail("shouldn't timeout")

    def testAutoConnectCreateSSHTunnelFail(self):
        """Test AutoConnect returns empty ports when the SSH tunnel fails."""
        fake_ip_addr = "1.1.1.1"
        fake_rsa_key_file = "/tmp/rsa_file"
        fake_target_vnc_port = 8888
        target_adb_port = 9999
        ssh_user = "fake_user"
        call_side_effect = subprocess.CalledProcessError(123, "fake",
                                                         "fake error")
        result = utils.ForwardedPorts(vnc_port=None, adb_port=None)
        self.Patch(subprocess, "check_call", side_effect=call_side_effect)
        self.assertEqual(result, utils.AutoConnect(fake_ip_addr,
                                                   fake_rsa_key_file,
                                                   fake_target_vnc_port,
                                                   target_adb_port,
                                                   ssh_user))

    # pylint: disable=protected-access,no-member
    def testExtraArgsSSHTunnel(self):
        """Test extra args will be the same with expanded args."""
        fake_ip_addr = "1.1.1.1"
        fake_rsa_key_file = "/tmp/rsa_file"
        fake_target_vnc_port = 8888
        target_adb_port = 9999
        ssh_user = "fake_user"
        fake_port = 12345
        self.Patch(utils, "PickFreePort", return_value=fake_port)
        self.Patch(utils, "_ExecuteCommand")
        self.Patch(subprocess, "check_call", return_value=True)
        extra_args_ssh_tunnel = "-o command='shell %s %h' -o command1='ls -la'"
        utils.AutoConnect(ip_addr=fake_ip_addr,
                          rsa_key_file=fake_rsa_key_file,
                          target_vnc_port=fake_target_vnc_port,
                          target_adb_port=target_adb_port,
                          ssh_user=ssh_user,
                          client_adb_port=fake_port,
                          extra_args_ssh_tunnel=extra_args_ssh_tunnel)
        args_list = ["-i", "/tmp/rsa_file",
                     "-o", "UserKnownHostsFile=/dev/null",
                     "-o", "StrictHostKeyChecking=no",
                     "-L", "12345:127.0.0.1:8888",
                     "-L", "12345:127.0.0.1:9999",
                     "-N", "-f", "-l", "fake_user", "1.1.1.1",
                     "-o", "command=shell %s %h",
                     "-o", "command1=ls -la"]
        first_call_args = utils._ExecuteCommand.call_args_list[0][0]
        self.assertEqual(first_call_args[1], args_list)

    # pylint: disable=protected-access, no-member
    def testCleanupSSVncviewer(self):
        """Test CleanupSSVncviewer kills the viewer only when it is running."""
        fake_vnc_port = 9999
        fake_ss_vncviewer_pattern = utils._SSVNC_VIEWER_PATTERN % {
            "vnc_port": fake_vnc_port}
        # A running ssvnc viewer should be killed via pkill.
        self.Patch(utils, "IsCommandRunning", return_value=True)
        self.Patch(subprocess, "check_call", return_value=True)
        utils.CleanupSSVncviewer(fake_vnc_port)
        subprocess.check_call.assert_called_with(
            ["pkill", "-9", "-f", fake_ss_vncviewer_pattern])

        # No viewer running: nothing should be killed. reset_mock() clears all
        # calls recorded by the previous scenario (assigning call_count alone
        # would leave call_args_list stale).
        subprocess.check_call.reset_mock()
        self.Patch(utils, "IsCommandRunning", return_value=False)
        utils.CleanupSSVncviewer(fake_vnc_port)
        subprocess.check_call.assert_not_called()
429
430
if __name__ == "__main__":
    # Allow running this test module directly (python utils_test.py).
    unittest.main()
433