1#
2# Copyright (C) 2016 The Android Open Source Project
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16
17import socket
18import unittest
19import logging
20import errno
21from socket import error as socket_error
22
23from vts.runners.host import errors
24from vts.proto import AndroidSystemControlMessage_pb2 as SysMsg_pb2
25from vts.runners.host.tcp_server import callback_server
26
# Default server address. Port 0 conventionally asks the OS for an
# ephemeral port when bound.
HOST = "localhost"
PORT = 0

# Port used by the error-case test; a connection to it is expected to be
# refused.
ERROR_PORT = 380
29
30
class TestMethods(unittest.TestCase):
    """Unit tests for callback_server.CallbackServer.

    The common scenarios are testing whether we are able to receive the
    expected data from the server, and whether we receive the correct error
    when we try to connect to the server on a wrong port.

    Attributes:
        _callback_server: an instance of CallbackServer that is used to
                          start and stop the TCP server.
        _counter: This is used to keep track of number of calls made to the
                  callback functions registered by the tests.
    """
    _callback_server = None
    _counter = 0

    def setUp(self):
        """Creates a CallbackServer and starts it before each test."""
        self._callback_server = callback_server.CallbackServer()
        self._callback_server.Start()

    def tearDown(self):
        """Shuts down the server started in setUp().

        This function calls callback_server.CallbackServer.Stop, which
        shuts down the server.
        """
        self._callback_server.Stop()

    def DoErrorCase(self):
        """Connects to the server's host on ERROR_PORT, expecting a refusal.

        This function tests the case that throws an exception, i.e. sending
        a request to a port (ERROR_PORT) on which nothing should be
        listening. If the connection unexpectedly succeeds, this function
        returns normally.

        Raises:
            errors.ConnectionRefusedError: the connection was refused, which
                is the expected (passing) outcome for test_DoErrorCase().
            socket.error: any other socket error is re-raised unchanged.
        """
        host = self._callback_server.ip

        # Create a socket (SOCK_STREAM means a TCP socket)
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

        try:
            # Connect to server; this should result in Connection refused error
            sock.connect((host, ERROR_PORT))
        except socket_error as e:
            # Compare the errno we got against the one we expect.
            if e.errno == errno.ECONNREFUSED:
                raise errors.ConnectionRefusedError  # Test is a success here
            # Any other error is unexpected; a bare raise keeps the original
            # traceback (``raise e`` discards it on Python 2).
            raise
        finally:
            sock.close()

    def ConnectToServer(self, func_id):
        """Connects to the TCP server, sends a request and reads the reply.

        The request is framed as the decimal length of the serialized
        request message, a newline, and then the serialized message itself.

        Args:
            func_id: This is the unique key corresponding to a function and
                also the id field of the request_message that we send to the
                server.

        Returns:
            An AndroidSystemCallbackResponseMessage parsed from the bytes the
            server sent back.

        Raises:
            TcpServerConnectionError: a socket error occurred while
                connecting to, sending to, or receiving from the server.
        """
        # This object is sent to the TCP host
        request_message = SysMsg_pb2.AndroidSystemCallbackRequestMessage()
        request_message.id = func_id

        # The final object that this function returns
        response_message = SysMsg_pb2.AndroidSystemCallbackResponseMessage()

        # Create a socket (SOCK_STREAM means a TCP socket)
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        host = self._callback_server.ip
        port = self._callback_server.port
        logging.debug('Sending Request to host %s using port %s', host, port)

        try:
            # Connect to server and send request_message
            sock.connect((host, port))

            message = request_message.SerializeToString()
            # SerializeToString() returns bytes on Python 3 (str on Python 2),
            # so build the length header as bytes; this concatenation works
            # on both, whereas mixing str and bytes fails on Python 3.
            header = str(len(message)).encode('ascii') + b"\n"
            sock.sendall(header + message)
            logging.debug("Sent: %s", message)

            # Receive the response from the server.
            # NOTE(review): a single recv() assumes the whole response fits in
            # one read of at most 1024 bytes -- confirm against the server.
            received_message = sock.recv(1024)
            response_message.ParseFromString(received_message)
            logging.debug('Received: %s', received_message)
        except socket_error as e:
            logging.error(e)
            raise errors.TcpServerConnectionError('Exception occurred.')
        finally:
            sock.close()

        return response_message

    def testDoErrorCase(self):
        """Unit test for error cases."""
        with self.assertRaises(errors.ConnectionRefusedError):
            self.DoErrorCase()

    def testCallback(self):
        """Tests two callback use cases."""
        # Order matters: TestNormalCase expects to register the first
        # callback on the freshly started server.
        self.TestNormalCase()
        self.TestDoRegisterCallback()

    def TestNormalCase(self):
        """Tests a normal, successful request to the TCP server.

        This function sends a request to the TCP server that should succeed.
        It also checks the register callback feature by ensuring that
        callback_func() is called and self._counter is increased by one.
        """
        def callback_func():
            self._counter += 1

        # Function should be registered with RegisterCallback; as the first
        # registration on this server it is expected to receive id '0'.
        func_id = self._callback_server.RegisterCallback(callback_func)
        self.assertEqual(func_id, '0')

        # Capture the previous value of the counter
        prev_value = self._counter

        # Connect to server
        response_message = self.ConnectToServer(func_id)

        # Confirm that callback_func() was called, increasing the counter by 1
        self.assertEqual(self._counter, prev_value + 1)

        # Also confirm that the request resulted in a success
        self.assertEqual(response_message.response_code, SysMsg_pb2.SUCCESS)

    def TestDoRegisterCallback(self):
        """Checks the register/unregister callback functionality of the server.

        First registers a callback and checks that a request with its id
        invokes it (counter incremented by 1) and returns SUCCESS. Then
        unregisters it and checks that a request with the same id does not
        invoke it (counter unchanged) and returns FAIL.
        """
        def callback_func():
            self._counter += 1

        # Capture the previous value of the counter
        prev_value = self._counter

        # Function should be registered with RegisterCallback, and looking it
        # up by function should give back the same id.
        func_id = self._callback_server.RegisterCallback(callback_func)
        found_func_id = self._callback_server.GetCallbackId(callback_func)
        self.assertEqual(func_id, found_func_id)

        # Connect to server
        response_message = self.ConnectToServer(func_id)

        # Confirm that callback_func() was called exactly once.
        self.assertEqual(self._counter, prev_value + 1)

        # Also confirm that the request succeeded.
        self.assertEqual(response_message.response_code, SysMsg_pb2.SUCCESS)

        # Now unregister the function with UnregisterCallback and send a
        # request with the same id again.
        self._callback_server.UnregisterCallback(func_id)

        # Capture the previous value of the counter
        prev_value = self._counter

        # Connect to server
        response_message = self.ConnectToServer(func_id)

        # Confirm that callback_func() was not called.
        self.assertEqual(self._counter, prev_value)

        # Also confirm that the server reported a failure for the
        # unregistered id.
        self.assertEqual(response_message.response_code, SysMsg_pb2.FAIL)
224
if __name__ == "__main__":
    # Run the TestCase classes defined in this module when executed directly.
    unittest.main()
227