1#!/usr/bin/env python3
2#
3#   Copyright 2018 - The Android Open Source Project
4#
5#   Licensed under the Apache License, Version 2.0 (the "License");
6#   you may not use this file except in compliance with the License.
7#   You may obtain a copy of the License at
8#
9#       http://www.apache.org/licenses/LICENSE-2.0
10#
11#   Unless required by applicable law or agreed to in writing, software
12#   distributed under the License is distributed on an "AS IS" BASIS,
13#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14#   See the License for the specific language governing permissions and
15#   limitations under the License.
16import unittest
17from unittest import TestCase
18
19import sys
20from acts.event import subscription_bundle
21from acts.event.decorators import subscribe
22from acts.event.decorators import subscribe_static
23from acts.event.event import Event
24from acts.event.subscription_bundle import SubscriptionBundle
25from mock import Mock
26from mock import patch
27
28
class SubscriptionBundleTest(TestCase):
    """Tests the SubscriptionBundle class."""

    def test_add_calls_add_subscription_properly(self):
        """Tests that the convenience function add() calls add_subscription."""
        event = object()
        func = object()
        event_filter = object()
        order = object()
        package = SubscriptionBundle()
        package.add_subscription = Mock()

        package.add(event, func, event_filter=event_filter, order=order)

        self.assertEqual(package.add_subscription.call_count, 1)
        subscription = package.add_subscription.call_args[0][0]
        self.assertEqual(subscription._event_type, event)
        self.assertEqual(subscription._func, func)
        self.assertEqual(subscription._event_filter, event_filter)
        self.assertEqual(subscription.order, order)

    @patch('acts.event.event_bus.register_subscription')
    def test_add_subscription_registers_sub_if_package_is_registered(
            self, register_subscription):
        """Tests that add_subscription registers the subscription if the
        SubscriptionBundle is already registered."""
        package = SubscriptionBundle()
        package._registered = True
        mock_subscription = Mock()

        package.add_subscription(mock_subscription)

        register_subscription.assert_called_once_with(mock_subscription)

    def test_add_subscription_adds_to_subscriptions(self):
        """Tests add_subscription adds the subscription to subscriptions."""
        mock_subscription = Mock()
        package = SubscriptionBundle()

        package.add_subscription(mock_subscription)

        self.assertIn(mock_subscription, package.subscriptions)

    def test_remove_subscription_removes_subscription_from_subscriptions(self):
        """Tests remove_subscription removes the given subscription from the
        subscriptions dictionary."""
        mock_subscription = Mock()
        package = SubscriptionBundle()
        package.subscriptions[mock_subscription] = id(mock_subscription)

        package.remove_subscription(mock_subscription)

        self.assertNotIn(mock_subscription, package.subscriptions)

    @patch('acts.event.event_bus.unregister')
    def test_remove_subscription_unregisters_subscription(self, unregister):
        """Tests that removing a subscription will also unregister it if the
        SubscriptionBundle is registered."""
        mock_subscription = Mock()
        package = SubscriptionBundle()
        package._registered = True
        package.subscriptions[mock_subscription] = id(mock_subscription)

        package.remove_subscription(mock_subscription)

        self.assertEqual(unregister.call_count, 1)
        unregistered_obj = unregister.call_args[0][0]
        # Either the subscription object or the value stored for it in
        # ``subscriptions`` (here its id) is accepted as the unregister arg.
        self.assertIn(unregistered_obj,
                      (id(mock_subscription), mock_subscription))

    @patch('acts.event.event_bus.register_subscription')
    def test_register_registers_all_subscriptions(self, register_subscription):
        """Tests register() registers all subscriptions within the bundle."""
        mock_subscription_list = [Mock(), Mock(), Mock()]
        package = SubscriptionBundle()
        package._registered = False
        for subscription in mock_subscription_list:
            package.subscriptions[subscription] = None

        package.register()

        self.assertEqual(register_subscription.call_count,
                         len(mock_subscription_list))
        # First positional argument of every register_subscription call.
        registered = {
            positional[0]
            for positional, _ in register_subscription.call_args_list
        }
        for subscription in mock_subscription_list:
            self.assertTrue(
                subscription in registered or id(subscription) in registered,
                '%r was not registered' % subscription)

    # Previously this method reused the name
    # test_register_registers_all_subscriptions, which silently replaced the
    # register() test above so that it never ran.
    @patch('acts.event.event_bus.unregister')
    def test_unregister_unregisters_all_subscriptions(self, unregister):
        """Tests unregister() unregisters all subscriptions within the bundle.
        """
        mock_subscription_list = [Mock(), Mock(), Mock()]
        package = SubscriptionBundle()
        package._registered = True
        for subscription in mock_subscription_list:
            package.subscriptions[subscription] = id(subscription)

        package.unregister()

        self.assertEqual(unregister.call_count, len(mock_subscription_list))
        # First positional argument of every unregister call.
        unregistered = {
            positional[0] for positional, _ in unregister.call_args_list
        }
        for subscription in mock_subscription_list:
            self.assertTrue(
                subscription in unregistered
                or id(subscription) in unregistered,
                '%r was not unregistered' % subscription)
132
133
class SubscriptionBundleStaticFunctions(TestCase):
    """Tests the static functions found in subscription_bundle.py.

    The four listener methods are fixtures: two are decorated with
    subscribe_static and two with subscribe, and each test checks that the
    matching factory function picks up exactly its two handles.
    """

    @staticmethod
    @subscribe_static(Event)
    def static_listener_1():
        """Static fixture listener; intentionally a no-op."""

    @staticmethod
    @subscribe_static(Event)
    def static_listener_2():
        """Static fixture listener; intentionally a no-op."""

    @subscribe(Event)
    def instance_listener_1(self):
        """Instance fixture listener; intentionally a no-op."""

    @subscribe(Event)
    def instance_listener_2(self):
        """Instance fixture listener; intentionally a no-op."""

    def test_create_from_static(self):
        """Tests create_from_static gets all StaticSubscriptionHandles."""
        test_class = type(self)
        found = subscription_bundle.create_from_static(test_class).subscriptions

        self.assertEqual(len(found), 2)
        for handle in (test_class.static_listener_1,
                       test_class.static_listener_2):
            self.assertIn(handle.subscription, found)

    def test_create_from_instance(self):
        """Tests create_from_instance gets all InstanceSubscriptionHandles."""
        found = subscription_bundle.create_from_instance(self).subscriptions

        self.assertEqual(len(found), 2)
        for handle in (self.instance_listener_1, self.instance_listener_2):
            self.assertIn(handle.subscription, found)
173
174
@subscribe_static(Event)
def static_listener_1():
    """Module-level static fixture listener; intentionally a no-op."""
178
179
class SubscribeStaticModuleLevelTest(TestCase):
    """Tests create_from_static when it is given a module object."""

    def test_create_from_static(self):
        """Tests create_from_static gets all StaticSubscriptionHandles."""
        this_module = sys.modules[type(self).__module__]
        found = subscription_bundle.create_from_static(this_module).subscriptions

        # Only the module-level static_listener_1 should be discovered.
        self.assertEqual(len(found), 1)
        self.assertIn(static_listener_1.subscription, found)
189
190
if __name__ == '__main__':
    # Run this module's tests when the file is executed directly.
    unittest.main(module='__main__')
193