1#!/usr/bin/env python
2#
3# Copyright 2016 - The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16"""Tests for acloud.internal.lib.utils."""
17
18import collections
19import errno
20import getpass
21import grp
22import os
23import shutil
24import subprocess
25import tempfile
26import time
27import webbrowser
28
29import unittest
30
31from unittest import mock
32import six
33
34from acloud import errors
35from acloud.internal.lib import driver_test_lib
36from acloud.internal.lib import utils
37
38
# Lightweight stand-in for the group records returned by grp.getgrall() and
# grp.getgrnam(); testCheckUserInGroups patches those calls to return these.
GroupInfo = collections.namedtuple(
    "GroupInfo", "gr_name gr_passwd gr_gid gr_mem")
44
# Tkinter may not be importable here (e.g. no Tk support is installed, or the
# interpreter is Python 3, where the module is named `tkinter`), so fall back
# to a Mock. testCalculateVNCScreenRatio detects this Mock and skips itself.
# NOTE(review): this file already relies on `unittest.mock` (Python 3), so
# this import likely always fails and the VNC ratio test is always skipped —
# consider whether a `tkinter` fallback is wanted; confirm against utils.
try:
    import Tkinter
except ImportError:
    Tkinter = mock.Mock()
50
51
class FakeTkinter:
    """Minimal replacement for the object created by Tkinter.Tk().

    It stores a fixed screen size given at construction time. The two
    winfo_* queries echo that size back, which lets tests control the
    screen dimensions seen by the code under test.
    """

    def __init__(self, width=None, height=None):
        self.width, self.height = width, height

    # pylint: disable=invalid-name
    def winfo_screenheight(self):
        """Report the fake screen height supplied to the constructor."""
        screen_height = self.height
        return screen_height

    # pylint: disable=invalid-name
    def winfo_screenwidth(self):
        """Report the fake screen width supplied to the constructor."""
        screen_width = self.width
        return screen_width
68
69
70# pylint: disable=too-many-public-methods
class UtilsTest(driver_test_lib.BaseDriverTest):
    """Unit tests for acloud.internal.lib.utils."""

    # NOTE: test method names must start with lowercase "test"; unittest's
    # default loader ignores methods such as "TestFoo", so they never run.
    def testTempDirSuccess(self):
        """Test TempDir creates a temp dir and removes it on normal exit."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        self.Patch(shutil, "rmtree")
        with utils.TempDir():
            pass
        # Verify.
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirExceptionRaised(self):
        """Test create a temp dir and exception is raised within with-clause."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        self.Patch(shutil, "rmtree")

        class ExpectedException(Exception):
            """Expected exception."""

        def _Call():
            with utils.TempDir():
                raise ExpectedException("Expected exception.")

        # Verify. ExpectedException should be raised and the dir still removed.
        self.assertRaises(ExpectedException, _Call)
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  #pylint: disable=no-member

    def testTempDirWhenDeleteTempDirNoLongerExist(self):  # pylint: disable=invalid-name
        """Test create a temp dir and dir no longer exists during deletion."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        expected_error = EnvironmentError()
        expected_error.errno = errno.ENOENT
        self.Patch(shutil, "rmtree", side_effect=expected_error)

        def _Call():
            with utils.TempDir():
                pass

        # Verify no exception should be raised when rmtree raises
        # EnvironmentError with errno.ENOENT, i.e.
        # directory no longer exists.
        _Call()
        tempfile.mkdtemp.assert_called_once()  #pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  #pylint: disable=no-member

    def testTempDirWhenDeleteEncounterError(self):
        """Test create a temp dir and encountered error during deletion."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        expected_error = OSError("Expected OS Error")
        self.Patch(shutil, "rmtree", side_effect=expected_error)

        def _Call():
            with utils.TempDir():
                pass

        # Verify OSError should be raised.
        self.assertRaises(OSError, _Call)
        tempfile.mkdtemp.assert_called_once()  #pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  #pylint: disable=no-member

    def testTempDirOrininalErrorRaised(self):
        """Test original error is raised even if tmp dir deletion failed."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        expected_error = OSError("Expected OS Error")
        self.Patch(shutil, "rmtree", side_effect=expected_error)

        class ExpectedException(Exception):
            """Expected exception."""

        def _Call():
            with utils.TempDir():
                raise ExpectedException("Expected Exception")

        # Verify.
        # ExpectedException should be raised, and OSError
        # should not be raised.
        self.assertRaises(ExpectedException, _Call)
        tempfile.mkdtemp.assert_called_once()  #pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  #pylint: disable=no-member

    def testCreateSshKeyPairKeyAlreadyExists(self):  #pylint: disable=invalid-name
        """Test no ssh-keygen call is made when the key pair already exists."""
        public_key = "/fake/public_key"
        private_key = "/fake/private_key"
        self.Patch(os.path, "exists", side_effect=[True, True])
        self.Patch(subprocess, "check_call")
        self.Patch(os, "makedirs", return_value=True)
        utils.CreateSshKeyPairIfNotExist(private_key, public_key)
        self.assertEqual(subprocess.check_call.call_count, 0)  #pylint: disable=no-member

    def testCreateSshKeyPairKeyAreCreated(self):
        """Test ssh-keygen is invoked once when neither key exists."""
        public_key = "/fake/public_key"
        private_key = "/fake/private_key"
        self.Patch(os.path, "exists", return_value=False)
        self.Patch(os, "makedirs", return_value=True)
        self.Patch(subprocess, "check_call")
        self.Patch(os, "rename")
        utils.CreateSshKeyPairIfNotExist(private_key, public_key)
        self.assertEqual(subprocess.check_call.call_count, 1)  #pylint: disable=no-member
        subprocess.check_call.assert_called_with(  #pylint: disable=no-member
            utils.SSH_KEYGEN_CMD +
            ["-C", getpass.getuser(), "-f", private_key],
            stdout=mock.ANY,
            stderr=mock.ANY)

    def testCreatePublicKeyAreCreated(self):
        """Test the public key is regenerated from an existing private key."""
        public_key = "/fake/public_key"
        private_key = "/fake/private_key"
        self.Patch(os.path, "exists", side_effect=[False, True, True])
        self.Patch(os, "makedirs", return_value=True)
        mock_open = mock.mock_open(read_data=public_key)
        self.Patch(subprocess, "check_output")
        self.Patch(os, "rename")
        with mock.patch.object(six.moves.builtins, "open", mock_open):
            utils.CreateSshKeyPairIfNotExist(private_key, public_key)
        self.assertEqual(subprocess.check_output.call_count, 1)  #pylint: disable=no-member
        subprocess.check_output.assert_called_with(  #pylint: disable=no-member
            utils.SSH_KEYGEN_PUB_CMD + ["-f", private_key])

    def testRetryOnException(self):
        """Test RetryOnException retries num_retry times before re-raising."""

        def _IsValueError(exc):
            return isinstance(exc, ValueError)

        num_retry = 5

        @utils.RetryOnException(_IsValueError, num_retry)
        def _RaiseAndRetry(sentinel):
            sentinel.alert()
            raise ValueError("Fake error.")

        sentinel = mock.MagicMock()
        self.assertRaises(ValueError, _RaiseAndRetry, sentinel)
        # One initial attempt plus num_retry retries.
        self.assertEqual(1 + num_retry, sentinel.alert.call_count)

    def testRetryExceptionType(self):
        """Test RetryExceptionType retries on any of the listed exception types."""

        def _RaiseAndRetry(sentinel):
            sentinel.alert()
            raise ValueError("Fake error.")

        num_retry = 5
        sentinel = mock.MagicMock()
        self.assertRaises(
            ValueError,
            utils.RetryExceptionType, (KeyError, ValueError),
            num_retry,
            _RaiseAndRetry,
            0, # sleep_multiplier
            1, # retry_backoff_factor
            sentinel=sentinel)
        self.assertEqual(1 + num_retry, sentinel.alert.call_count)

    def testRetry(self):
        """Test RetryExceptionType sleeps with exponential backoff between retries."""
        mock_sleep = self.Patch(time, "sleep")

        def _RaiseAndRetry(sentinel):
            sentinel.alert()
            raise ValueError("Fake error.")

        num_retry = 5
        sentinel = mock.MagicMock()
        self.assertRaises(
            ValueError,
            utils.RetryExceptionType, (ValueError, KeyError),
            num_retry,
            _RaiseAndRetry,
            1, # sleep_multiplier
            2, # retry_backoff_factor
            sentinel=sentinel)

        self.assertEqual(1 + num_retry, sentinel.alert.call_count)
        # sleep_multiplier * retry_backoff_factor ** attempt -> 1, 2, 4, 8, 16.
        mock_sleep.assert_has_calls(
            [
                mock.call(1),
                mock.call(2),
                mock.call(4),
                mock.call(8),
                mock.call(16)
            ])

    @mock.patch.object(six.moves, "input")
    def testGetAnswerFromList(self, mock_raw_input):
        """Test GetAnswerFromList.

        An answer of 0 exits; 1-based answers select a single entry; the
        answer one past the list (4 here) with enable_choose_all=True
        returns the whole list.
        """
        answer_list = ["image1.zip", "image2.zip", "image3.zip"]
        mock_raw_input.return_value = 0
        with self.assertRaises(SystemExit):
            utils.GetAnswerFromList(answer_list)
        mock_raw_input.side_effect = [1, 2, 3, 4]
        self.assertEqual(utils.GetAnswerFromList(answer_list),
                         ["image1.zip"])
        self.assertEqual(utils.GetAnswerFromList(answer_list),
                         ["image2.zip"])
        self.assertEqual(utils.GetAnswerFromList(answer_list),
                         ["image3.zip"])
        self.assertEqual(utils.GetAnswerFromList(answer_list,
                                                 enable_choose_all=True),
                         answer_list)

    @unittest.skipIf(isinstance(Tkinter, mock.Mock), "Tkinter mocked out, test case not needed.")
    @mock.patch.object(Tkinter, "Tk")
    def testCalculateVNCScreenRatio(self, mock_tk):
        """Test Calculating the scale ratio of VNC display."""
        # Get scale-down ratio if screen height is smaller than AVD height.
        mock_tk.return_value = FakeTkinter(height=800, width=1200)
        avd_h = 1920
        avd_w = 1080
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.4)

        # Get scale-down ratio if screen width is smaller than AVD width.
        mock_tk.return_value = FakeTkinter(height=800, width=1200)
        avd_h = 900
        avd_w = 1920
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.6)

        # Scale ratio = 1 if screen is larger than AVD.
        mock_tk.return_value = FakeTkinter(height=1080, width=1920)
        avd_h = 800
        avd_w = 1280
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 1)

        # Get the scale if ratio of width is smaller than the
        # ratio of height.
        mock_tk.return_value = FakeTkinter(height=1200, width=800)
        avd_h = 1920
        avd_w = 1080
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.6)

    def testCheckUserInGroups(self):
        """Test CheckUserInGroups."""
        self.Patch(getpass, "getuser", return_value="user_0")
        self.Patch(grp, "getgrall", return_value=[
            GroupInfo("fake_group1", "passwd_1", 0, ["user_1", "user_2"]),
            GroupInfo("fake_group2", "passwd_2", 1, ["user_1", "user_2"])])
        self.Patch(grp, "getgrnam", return_value=GroupInfo(
            "fake_group1", "passwd_1", 0, ["user_1", "user_2"]))
        # Test Group name doesn't exist.
        self.assertFalse(utils.CheckUserInGroups(["Non_exist_group"]))

        # Test User isn't in group.
        self.assertFalse(utils.CheckUserInGroups(["fake_group1"]))

        # Test User is in group.
        self.Patch(getpass, "getuser", return_value="user_1")
        self.assertTrue(utils.CheckUserInGroups(["fake_group1"]))

    @mock.patch.object(utils, "CheckUserInGroups")
    def testAddUserGroupsToCmd(self, mock_user_group):
        """Test AddUserGroupsToCmd."""
        command = "test_command"
        groups = ["group1", "group2"]
        # Don't add user group in command
        mock_user_group.return_value = True
        expected_value = "test_command"
        self.assertEqual(expected_value, utils.AddUserGroupsToCmd(command,
                                                                  groups))

        # Add user group in command
        mock_user_group.return_value = False
        expected_value = "sg group1 <<EOF\nsg group2\ntest_command\nEOF"
        self.assertEqual(expected_value, utils.AddUserGroupsToCmd(command,
                                                                  groups))

    # pylint: disable=invalid-name
    def testTimeoutException(self):
        """Test TimeoutException."""
        @utils.TimeoutException(1, "should time out")
        def functionThatWillTimeOut():
            """Test decorator of @utils.TimeoutException should timeout."""
            time.sleep(5)

        self.assertRaises(errors.FunctionTimeoutError,
                          functionThatWillTimeOut)


    def testTimeoutExceptionNoTimeout(self):
        """Test No TimeoutException."""
        @utils.TimeoutException(5, "shouldn't time out")
        def functionThatShouldNotTimeout():
            """Test decorator of @utils.TimeoutException shouldn't timeout."""
            return None
        try:
            functionThatShouldNotTimeout()
        except errors.FunctionTimeoutError:
            self.fail("shouldn't timeout")

    def testAutoConnectCreateSSHTunnelFail(self):
        """Test AutoConnect returns empty ports when the SSH tunnel fails."""
        fake_ip_addr = "1.1.1.1"
        fake_rsa_key_file = "/tmp/rsa_file"
        fake_target_vnc_port = 8888
        target_adb_port = 9999
        ssh_user = "fake_user"
        call_side_effect = subprocess.CalledProcessError(123, "fake",
                                                         "fake error")
        result = utils.ForwardedPorts(vnc_port=None, adb_port=None)
        self.Patch(subprocess, "check_call", side_effect=call_side_effect)
        self.assertEqual(result, utils.AutoConnect(fake_ip_addr,
                                                   fake_rsa_key_file,
                                                   fake_target_vnc_port,
                                                   target_adb_port,
                                                   ssh_user))

    # pylint: disable=protected-access,no-member
    def testExtraArgsSSHTunnel(self):
        """Test extra args will be the same with expanded args."""
        fake_ip_addr = "1.1.1.1"
        fake_rsa_key_file = "/tmp/rsa_file"
        fake_target_vnc_port = 8888
        target_adb_port = 9999
        ssh_user = "fake_user"
        fake_port = 12345
        self.Patch(utils, "PickFreePort", return_value=fake_port)
        self.Patch(utils, "_ExecuteCommand")
        self.Patch(subprocess, "check_call", return_value=True)
        extra_args_ssh_tunnel = "-o command='shell %s %h' -o command1='ls -la'"
        utils.AutoConnect(ip_addr=fake_ip_addr,
                          rsa_key_file=fake_rsa_key_file,
                          target_vnc_port=fake_target_vnc_port,
                          target_adb_port=target_adb_port,
                          ssh_user=ssh_user,
                          client_adb_port=fake_port,
                          extra_args_ssh_tunnel=extra_args_ssh_tunnel)
        # The quoted extra args are expected to be shell-split and appended.
        args_list = ["-i", "/tmp/rsa_file",
                     "-o", "UserKnownHostsFile=/dev/null",
                     "-o", "StrictHostKeyChecking=no",
                     "-L", "12345:127.0.0.1:9999",
                     "-L", "12345:127.0.0.1:8888",
                     "-N", "-f", "-l", "fake_user", "1.1.1.1",
                     "-o", "command=shell %s %h",
                     "-o", "command1=ls -la"]
        first_call_args = utils._ExecuteCommand.call_args_list[0][0]
        self.assertEqual(first_call_args[1], args_list)

    # pylint: disable=protected-access,no-member
    def testEstablishWebRTCSshTunnel(self):
        """Test establish WebRTC ssh tunnel, with and without extra args."""
        fake_ip_addr = "1.1.1.1"
        fake_rsa_key_file = "/tmp/rsa_file"
        ssh_user = "fake_user"
        self.Patch(utils, "ReleasePort")
        self.Patch(utils, "_ExecuteCommand")
        self.Patch(subprocess, "check_call", return_value=True)
        # First call: no extra ssh args.
        utils.EstablishWebRTCSshTunnel(
            ip_addr=fake_ip_addr, rsa_key_file=fake_rsa_key_file,
            ssh_user=ssh_user, extra_args_ssh_tunnel=None)
        args_list = ["-i", "/tmp/rsa_file",
                     "-o", "UserKnownHostsFile=/dev/null",
                     "-o", "StrictHostKeyChecking=no",
                     "-L", "8443:127.0.0.1:8443",
                     "-L", "15550:127.0.0.1:15550",
                     "-L", "15551:127.0.0.1:15551",
                     "-N", "-f", "-l", "fake_user", "1.1.1.1"]
        first_call_args = utils._ExecuteCommand.call_args_list[0][0]
        self.assertEqual(first_call_args[1], args_list)

        # Second call: extra ssh args are appended after the host.
        extra_args_ssh_tunnel = "-o command='shell %s %h'"
        utils.EstablishWebRTCSshTunnel(
            ip_addr=fake_ip_addr, rsa_key_file=fake_rsa_key_file,
            ssh_user=ssh_user, extra_args_ssh_tunnel=extra_args_ssh_tunnel)
        args_list_with_extra_args = ["-i", "/tmp/rsa_file",
                                     "-o", "UserKnownHostsFile=/dev/null",
                                     "-o", "StrictHostKeyChecking=no",
                                     "-L", "8443:127.0.0.1:8443",
                                     "-L", "15550:127.0.0.1:15550",
                                     "-L", "15551:127.0.0.1:15551",
                                     "-N", "-f", "-l", "fake_user", "1.1.1.1",
                                     "-o", "command=shell %s %h"]
        second_call_args = utils._ExecuteCommand.call_args_list[1][0]
        self.assertEqual(second_call_args[1], args_list_with_extra_args)

    # pylint: disable=protected-access, no-member
    def testCleanupSSVncviwer(self):
        """test cleanup ssvnc viewer."""
        fake_vnc_port = 9999
        fake_ss_vncviewer_pattern = utils._SSVNC_VIEWER_PATTERN % {
            "vnc_port": fake_vnc_port}
        self.Patch(utils, "IsCommandRunning", return_value=True)
        self.Patch(subprocess, "check_call", return_value=True)
        utils.CleanupSSVncviewer(fake_vnc_port)
        subprocess.check_call.assert_called_with(["pkill", "-9", "-f", fake_ss_vncviewer_pattern])

        # reset_mock() clears call_count and recorded call args together;
        # assigning call_count = 0 would leave call_args stale.
        subprocess.check_call.reset_mock()
        self.Patch(utils, "IsCommandRunning", return_value=False)
        utils.CleanupSSVncviewer(fake_vnc_port)
        subprocess.check_call.assert_not_called()

    def testLaunchBrowserFromReport(self):
        """test launch browser from report."""
        self.Patch(webbrowser, "open_new_tab")
        fake_report = mock.MagicMock(data={})

        # test remote instance
        self.Patch(os.environ, "get", return_value=True)
        fake_report.data = {
            "devices": [{"instance_name": "remote_cf_instance_name",
                         "ip": "192.168.1.1",},],}

        utils.LaunchBrowserFromReport(fake_report)
        webbrowser.open_new_tab.assert_called_once_with("https://localhost:8443")
        # Fully reset recorded calls (not just call_count) between scenarios.
        webbrowser.open_new_tab.reset_mock()

        # test local instance
        fake_report.data = {
            "devices": [{"instance_name": "local-instance1",
                         "ip": "127.0.0.1:6250",},],}
        utils.LaunchBrowserFromReport(fake_report)
        webbrowser.open_new_tab.assert_called_once_with("https://localhost:8443")
        webbrowser.open_new_tab.reset_mock()

        # verify terminal can't support launch webbrowser.
        self.Patch(os.environ, "get", return_value=False)
        utils.LaunchBrowserFromReport(fake_report)
        self.assertEqual(webbrowser.open_new_tab.call_count, 0)

    def testSetExecutable(self):
        """test setting a file to be executable."""
        with tempfile.NamedTemporaryFile(delete=True) as temp_file:
            utils.SetExecutable(temp_file.name)
            self.assertEqual(os.stat(temp_file.name).st_mode & 0o777, 0o755)

    def testSetDirectoryTreeExecutable(self):
        """test setting a file in a directory to be executable."""
        with tempfile.TemporaryDirectory() as temp_dir:
            subdir = os.path.join(temp_dir, "subdir")
            file_path = os.path.join(subdir, "file")
            os.makedirs(subdir)
            with open(file_path, "w"):
                pass
            utils.SetDirectoryTreeExecutable(temp_dir)
            self.assertEqual(os.stat(file_path).st_mode & 0o777, 0o755)
516
517
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
520