1# Copyright 2019 - The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for pull."""
15import unittest
16
17import os
18import tempfile
19
20from unittest import mock
21
22from acloud import errors
23from acloud.internal import constants
24from acloud.internal.lib import driver_test_lib
25from acloud.internal.lib import ssh
26from acloud.internal.lib import utils
27from acloud.list import list as list_instances
28from acloud.public import config
29from acloud.pull import pull
30
31
32class PullTest(driver_test_lib.BaseDriverTest):
33    """Test pull."""
34
35    # pylint: disable=no-member
36    def testPullFileFromInstance(self):
37        """test PullFileFromInstance."""
38        cfg = mock.MagicMock()
39        cfg.ssh_private_key_path = "fake_ssh_path"
40        cfg.extra_args_ssh_tunnel = ""
41        instance = mock.MagicMock()
42        instance.ip = "1.1.1.1"
43        # Multiple selected files case.
44        selected_files = ["file1.log", "file2.log"]
45        self.Patch(pull, "SelectLogFileToPull", return_value=selected_files)
46        self.Patch(pull, "GetDownloadLogFolder", return_value="fake_folder")
47        self.Patch(pull, "PullLogs")
48        self.Patch(pull, "DisplayLog")
49        pull.PullFileFromInstance(cfg, instance)
50        self.assertEqual(pull.DisplayLog.call_count, 0)
51
52        # Only one file selected case.
53        selected_files = ["file1.log"]
54        self.Patch(pull, "SelectLogFileToPull", return_value=selected_files)
55        pull.PullFileFromInstance(cfg, instance)
56        self.assertEqual(pull.DisplayLog.call_count, 1)
57
58    # pylint: disable=no-member
59    def testPullLogs(self):
60        """test PullLogs."""
61        _ssh = mock.MagicMock()
62        self.Patch(utils, "PrintColorString")
63        log_files = ["file1.log", "file2.log"]
64        download_folder = "/fake_folder"
65        pull.PullLogs(_ssh, log_files, download_folder)
66        self.assertEqual(_ssh.ScpPullFile.call_count, 2)
67        utils.PrintColorString.assert_called_once()
68
69    @mock.patch.object(ssh.Ssh, "Run")
70    def testDisplayLog(self, mock_ssh_run):
71        """Test DisplayLog."""
72        fake_ip = ssh.IP(external="1.1.1.1", internal="10.1.1.1")
73        _ssh = ssh.Ssh(ip=fake_ip,
74                       user=constants.GCE_USER,
75                       ssh_private_key_path="/fake/acloud_rea")
76        self.Patch(utils, "GetUserAnswerYes", return_value="Y")
77        log_file = "file1.log"
78        pull.DisplayLog(_ssh, log_file)
79        expected_cmd = "tail -f -n +1 %s" % log_file
80        mock_ssh_run.assert_has_calls([
81            mock.call(expected_cmd, show_output=True)])
82
83    def testGetDownloadLogFolder(self):
84        """test GetDownloadLogFolder."""
85        self.Patch(tempfile, "gettempdir", return_value="/tmp")
86        self.Patch(os.path, "exists", return_value=True)
87        instance = "instance"
88        expected_path = "/tmp/instance"
89        self.assertEqual(pull.GetDownloadLogFolder(instance), expected_path)
90
91    def testSelectLogFileToPull(self):
92        """test choose log files from the remote instance."""
93        _ssh = mock.MagicMock()
94
95        # Test only one log file case
96        log_files = ["file1.log"]
97        self.Patch(pull, "GetAllLogFilePaths", return_value=log_files)
98        expected_result = ["file1.log"]
99        self.assertEqual(pull.SelectLogFileToPull(_ssh), expected_result)
100
101        # Test no log files case
102        self.Patch(pull, "GetAllLogFilePaths", return_value=[])
103        with self.assertRaises(errors.CheckPathError):
104            pull.SelectLogFileToPull(_ssh)
105
106        # Test two log files case.
107        log_files = ["file1.log", "file2.log"]
108        choose_log = ["file2.log"]
109        self.Patch(pull, "GetAllLogFilePaths", return_value=log_files)
110        self.Patch(utils, "GetAnswerFromList", return_value=choose_log)
111        expected_result = ["file2.log"]
112        self.assertEqual(pull.SelectLogFileToPull(_ssh), expected_result)
113
114        # Test user provided file name exist.
115        log_files = ["/home/vsoc-01/cuttlefish_runtime/file1.log",
116                     "/home/vsoc-01/cuttlefish_runtime/file2.log"]
117        input_file = "file1.log"
118        self.Patch(pull, "GetAllLogFilePaths", return_value=log_files)
119        expected_result = ["/home/vsoc-01/cuttlefish_runtime/file1.log"]
120        self.assertEqual(pull.SelectLogFileToPull(_ssh, input_file), expected_result)
121
122        # Test user provided file name not exist.
123        log_files = ["/home/vsoc-01/cuttlefish_runtime/file1.log",
124                     "/home/vsoc-01/cuttlefish_runtime/file2.log"]
125        input_file = "not_exist.log"
126        self.Patch(pull, "GetAllLogFilePaths", return_value=log_files)
127        with self.assertRaises(errors.CheckPathError):
128            pull.SelectLogFileToPull(_ssh, input_file)
129
130    def testFilterLogfiles(self):
131        """test filer log file from black list."""
132        # Filter out file name is "kernel".
133        files = ["kernel.log", "logcat", "kernel"]
134        expected_result = ["kernel.log", "logcat"]
135        self.assertEqual(pull.FilterLogfiles(files), expected_result)
136
137        # Filter out file extension is ".img".
138        files = ["kernel.log", "system.img", "userdata.img", "launcher.log"]
139        expected_result = ["kernel.log", "launcher.log"]
140        self.assertEqual(pull.FilterLogfiles(files), expected_result)
141
142    @mock.patch.object(pull, "PullFileFromInstance")
143    def testRun(self, mock_pull_file):
144        """test Run."""
145        cfg = mock.MagicMock()
146        args = mock.MagicMock()
147        instance_obj = mock.MagicMock()
148        # Test case with provided instance name.
149        args.instance_name = "instance_1"
150        args.file_name = "file1.log"
151        args.no_prompt = True
152        self.Patch(config, "GetAcloudConfig", return_value=cfg)
153        self.Patch(list_instances, "GetInstancesFromInstanceNames",
154                   return_value=[instance_obj])
155        pull.Run(args)
156        mock_pull_file.assert_has_calls([
157            mock.call(cfg, instance_obj, args.file_name, args.no_prompt)])
158
159        # Test case for user select one instance to pull log.
160        selected_instance = mock.MagicMock()
161        self.Patch(list_instances, "ChooseOneRemoteInstance",
162                   return_value=selected_instance)
163        args.instance_name = None
164        pull.Run(args)
165        mock_pull_file.assert_has_calls([
166            mock.call(cfg, selected_instance, args.file_name, args.no_prompt)])
167
168
169if __name__ == '__main__':
170    unittest.main()
171