1# Copyright 2021 - The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for restart."""
15import unittest
16
17from unittest import mock
18
19from acloud.internal.lib import driver_test_lib
20from acloud.list import list as list_instances
21from acloud.public import config
22from acloud.restart import restart
23
24
class RestartTest(driver_test_lib.BaseDriverTest):
    """Unit tests for acloud.restart.restart."""

    @mock.patch.object(restart, "RestartFromInstance")
    def testRun(self, mock_restart):
        """Test that Run forwards the chosen instance to RestartFromInstance.

        Two paths are exercised:
          1. An instance name is supplied in args.instance_name, and the
             instance is resolved via GetInstancesFromInstanceNames (patched).
          2. No instance name is supplied, and the instance comes from
             ChooseOneRemoteInstance (patched), i.e. the user selects one.

        In both paths RestartFromInstance must be called exactly once with
        the config, the resolved instance, args.instance_id and
        args.powerwash.
        """
        cfg = mock.MagicMock()
        args = mock.MagicMock()
        instance_obj = mock.MagicMock()
        # Case 1: the instance to restart is named explicitly.
        args.instance_name = "instance_1"
        args.instance_id = 1
        args.powerwash = False
        self.Patch(config, "GetAcloudConfig", return_value=cfg)
        self.Patch(list_instances, "GetInstancesFromInstanceNames",
                   return_value=[instance_obj])
        restart.Run(args)
        mock_restart.assert_called_once_with(
            cfg, instance_obj, args.instance_id, args.powerwash)

        # Case 2: no instance name, so the user selects one instance to
        # restart. Reset the mock first so the assertion below only sees the
        # calls made by this Run invocation, not the one from case 1.
        mock_restart.reset_mock()
        selected_instance = mock.MagicMock()
        self.Patch(list_instances, "ChooseOneRemoteInstance",
                   return_value=selected_instance)
        args.instance_name = None
        restart.Run(args)
        mock_restart.assert_called_once_with(
            cfg, selected_instance, args.instance_id, args.powerwash)
53
54
# Allow running this test module directly, e.g. `python restart_test.py`.
if __name__ == "__main__":
    unittest.main()
57