1#
2# Copyright (C) 2016 The Android Open Source Project
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16
17import logging
18import socket
19import socketserver
20import threading
21
22from vts.runners.host import errors
23from vts.proto import AndroidSystemControlMessage_pb2 as SysMsg
24from vts.proto import ComponentSpecificationMessage_pb2 as CompSpecMsg
25from vts.utils.python.mirror import pb2py
26
# Registry of host-side callbacks shared by every CallbackServer in this
# process, mapping a string ID (e.g. "0", "1") to the registered callable.
_functions = {}
28
29
class CallbackServerError(errors.VtsError):
    """Error raised by the VTS callback TCP server.

    Signals failures such as registering the same function twice,
    unregistering an unknown callback ID, or being unable to start the server.
    """
32
33
class CallbackRequestHandler(socketserver.StreamRequestHandler):
    """The request handler class for our server.

    Wire format of an incoming request, as read by handle(): a decimal
    length on its own line, followed by exactly that many bytes holding a
    serialized AndroidSystemCallbackRequestMessage. The reply is a serialized
    AndroidSystemCallbackResponseMessage written to the socket as-is (no
    length prefix).
    """

    def handle(self):
        """Receives one callback request from a client and answers it.

        When a callback happens on the target side, a request message is posted
        to the host side and is handled here. The message is parsed, the
        registered host-side callback with the matching ID is called with the
        converted arguments, and a response carrying SysMsg.SUCCESS or
        SysMsg.FAIL is sent back.

        Malformed framing (empty header, negative length, truncated body) is
        logged and the request is dropped without a response.

        Raises:
            ValueError: if the header line is non-empty but not an integer.
        """
        header = self.rfile.readline().strip()
        try:
            length = int(header)
        except ValueError:
            if header:
                logging.exception("Unable to convert '%s' into an integer, which "
                                  "is required for reading the next message.",
                                  header)
                raise
            logging.error('CallbackRequestHandler received empty message '
                          'header. Skipping...')
            return

        if length < 0:
            # rfile.read() with a negative size reads until EOF, which would
            # block until the client closes the connection.
            logging.error('CallbackRequestHandler received negative message '
                          'length %d. Skipping...', length)
            return

        # Read the request message. A buffered read only returns fewer bytes
        # than requested at EOF, i.e. when the client sent a truncated body.
        received_data = self.rfile.read(length)
        if len(received_data) != length:
            logging.error('CallbackRequestHandler expected %d bytes but '
                          'received %d. Skipping...', length,
                          len(received_data))
            return
        logging.debug("Received callback message: %s", received_data)

        request_message = SysMsg.AndroidSystemCallbackRequestMessage()
        request_message.ParseFromString(received_data)
        logging.debug('Handling callback ID: %s', request_message.id)

        response_message = SysMsg.AndroidSystemCallbackResponseMessage()
        response_message.response_code = self._InvokeCallback(request_message)

        # Send the response back to the client; self.request is the TCP
        # socket connected to the client.
        self.request.sendall(response_message.SerializeToString())

    def _InvokeCallback(self, request_message):
        """Calls the callback registered under request_message.id.

        Args:
            request_message: a parsed AndroidSystemCallbackRequestMessage.

        Returns:
            SysMsg.SUCCESS if the callback ran to completion, SysMsg.FAIL if
            the ID is not registered or the call (including argument
            conversion) raised.
        """
        func_id = request_message.id
        if func_id not in _functions:
            logging.error("Callback function ID %s is not registered!",
                          func_id)
            return SysMsg.FAIL
        try:
            args = tuple(pb2py.Convert(arg) for arg in request_message.arg)
            _functions[func_id](*args)
        except Exception:
            # Report the failure to the target instead of dropping the
            # connection without any response.
            logging.exception("Callback function ID %s raised an exception.",
                              func_id)
            return SysMsg.FAIL
        return SysMsg.SUCCESS
81
82
class CallbackServer(object):
    """This class creates a TCPServer that runs in a separate thread.

    Callback registrations are stored in the module-level _functions dict,
    so they are shared by every CallbackServer instance in the process.

    Attributes:
        _server: an instance of socketserver.TCPServer, or None when the
                 server is not running.
        _port: this variable maintains the port number used in creating
               the server connection.
        _ip: variable to hold the IP Address of the host.
        _hostname: IP Address to which initial connection is made.
    """

    def __init__(self):
        self._server = None
        self._port = 0  # Port 0 means to select an arbitrary unused port
        self._ip = ""  # Used to store the IP address for the server
        self._hostname = "localhost"  # IP address to which initial connection is made

    def RegisterCallback(self, callback_func):
        """Registers a callback function.

        IDs are decimal strings; a new function gets one more than the
        largest ID currently registered (or "0" if none are registered).

        Args:
            callback_func: The function to register.

        Returns:
            string, Id of the registered callback function.

        Raises:
            CallbackServerError is raised if callback_func is already
            registered.
        """
        if self.GetCallbackId(callback_func) is not None:
            raise CallbackServerError("Function is already registered")
        func_id = str(max((int(key) for key in _functions), default=-1) + 1)
        _functions[func_id] = callback_func
        return func_id

    def UnregisterCallback(self, func_id):
        """Removes a callback function from the registry.

        Args:
            func_id: The ID of the callback function to remove.

        Raises:
            CallbackServerError is raised if the func_id is not registered.
        """
        try:
            _functions.pop(func_id)
        except KeyError:
            raise CallbackServerError(
                "Can't remove function ID '%s', which is not registered." %
                func_id)

    def GetCallbackId(self, callback_func):
        """Gets the ID under which a callback function is registered.

        Args:
            callback_func: The function to look up (compared by identity).

        Returns:
            string, Id of the callback function if found, None otherwise.
        """
        # dict _functions is { id : func }
        for func_id, func in _functions.items():
            if func is callback_func:
                return func_id
        return None

    def Start(self, port=0):
        """Starts the server.

        Args:
            port: integer, number of the port on which the server listens.
                  Default is 0, which means auto-select a port available.

        Returns:
            IP Address, port number

        Raises:
            CallbackServerError is raised if the server fails to start.
        """
        server = None
        try:
            server = socketserver.TCPServer(
                (self._hostname, port), CallbackRequestHandler)
            self._ip, self._port = server.server_address

            # Run serve_forever() in a daemon thread so it neither blocks the
            # caller nor keeps the process alive. TCPServer is not a
            # threading server: requests are handled one at a time on this
            # single thread.
            server_thread = threading.Thread(target=server.serve_forever)
            server_thread.daemon = True
            server_thread.start()
        except (RuntimeError, IOError, socket.error) as e:
            logging.exception(e)
            if server is not None:
                # Release the bound socket; the server never started serving.
                server.server_close()
            raise CallbackServerError(
                'Failed to start CallbackServer on (%s:%s).' %
                (self._hostname, port)) from e

        # Only publish the server once it is actually serving, so Stop() never
        # calls shutdown() on a server whose serve_forever() is not running.
        self._server = server
        logging.debug('TcpServer %s started (%s:%s)', server_thread.name,
                      self._ip, self._port)
        return self._ip, self._port

    def Stop(self):
        """Stops the server and closes its listening socket.

        Does nothing if the server was never started or is already stopped.
        """
        if self._server is None:
            return
        self._server.shutdown()  # Waits for serve_forever() to return.
        self._server.server_close()
        self._server = None

    @property
    def ip(self):
        """IP address the server is bound to ("" before Start())."""
        return self._ip

    @property
    def port(self):
        """Port the server is bound to (0 before Start())."""
        return self._port
195