#
# Copyright (C) 2016 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
import socket
import socketserver
import threading

from vts.runners.host import errors
from vts.proto import AndroidSystemControlMessage_pb2 as SysMsg
from vts.proto import ComponentSpecificationMessage_pb2 as CompSpecMsg
from vts.utils.python.mirror import pb2py

# Registry of host-side callbacks, shared by every CallbackServer in this
# process. Maps a callback ID (a decimal string such as "0", as produced by
# CallbackServer.RegisterCallback) to the callable to invoke for that ID.
_functions = dict()


class CallbackServerError(errors.VtsError):
    """Raised when an error occurs in VTS TCP server."""


class CallbackRequestHandler(socketserver.StreamRequestHandler):
    """Handles one callback request posted by the target side.

    Wire format, as read and written by handle():
      request:  one line holding the payload length in decimal, followed by
                exactly that many bytes of a serialized
                AndroidSystemCallbackRequestMessage.
      response: a serialized AndroidSystemCallbackResponseMessage, written
                without any length prefix.
    """

    def handle(self):
        """Receives one request, runs its callback and sends the response.

        When a callback happens on the target side, a request message is
        posted to the host side and is handled here. The message is parsed
        and the host-side function registered under the message ID is called
        with the message arguments converted via pb2py.Convert. The response
        code is SUCCESS when the callback returns normally, and FAIL when the
        ID is not registered or the callback (or argument conversion) raises.

        Malformed input is logged and skipped without sending a response:
        an empty header, a negative length, or a body shorter than announced.

        Raises:
            ValueError: if the header line is non-empty but not an integer,
                because the message body cannot be located in that case.
        """
        header = self.rfile.readline().strip()
        try:
            msg_len = int(header)
        except ValueError:
            if header:
                logging.exception(
                    "Unable to convert '%s' into an integer, which "
                    "is required for reading the next message.", header)
                raise
            logging.error('CallbackRequestHandler received empty message '
                          'header. Skipping...')
            return

        if msg_len < 0:
            # read() with a negative size would read until EOF instead.
            logging.error('Invalid callback message length %d. Skipping...',
                          msg_len)
            return

        # Read the request message. read() returns fewer bytes than asked for
        # if the client closes the connection early; never parse a truncated
        # body as if it were complete.
        received_data = self.rfile.read(msg_len)
        if len(received_data) != msg_len:
            logging.error(
                'Expected %d bytes of callback message but received %d. '
                'Skipping...', msg_len, len(received_data))
            return
        logging.debug("Received callback message: %s", received_data)
        request_message = SysMsg.AndroidSystemCallbackRequestMessage()
        request_message.ParseFromString(received_data)
        logging.debug('Handling callback ID: %s', request_message.id)

        response_message = SysMsg.AndroidSystemCallbackResponseMessage()
        # Fetch the callback once so a concurrent UnregisterCallback cannot
        # remove it between a membership test and the call.
        callback = _functions.get(request_message.id)
        if callback is None:
            logging.error("Callback function ID %s is not registered!",
                          request_message.id)
            response_message.response_code = SysMsg.FAIL
        else:
            try:
                args = tuple(
                    pb2py.Convert(arg) for arg in request_message.arg)
                callback(*args)
            except Exception:
                # handle() is the per-request boundary: report the failure to
                # the target rather than closing the connection unanswered.
                logging.exception(
                    "Callback function ID %s raised an exception.",
                    request_message.id)
                response_message.response_code = SysMsg.FAIL
            else:
                response_message.response_code = SysMsg.SUCCESS

        # Send the response back to the client. self.request is the TCP
        # socket connected to the client.
        self.request.sendall(response_message.SerializeToString())


class CallbackServer(object):
    """Runs a socketserver.TCPServer on a background daemon thread.

    Requests are dispatched to CallbackRequestHandler. socketserver.TCPServer
    is not a threading server, so requests are served one at a time on that
    single background thread.

    Attributes:
        _server: the running socketserver.TCPServer, or None when not running.
        _port: port number the server is bound to (0 until Start succeeds).
        _ip: IP address the server is bound to ("" until Start succeeds).
        _hostname: host name the server socket binds to.
    """

    def __init__(self):
        self._server = None
        self._port = 0  # Port 0 means to select an arbitrary unused port
        self._ip = ""  # Used to store the IP address for the server
        self._hostname = "localhost"  # Host name the server socket binds to

    def RegisterCallback(self, callback_func):
        """Registers a callback function.

        IDs are assigned as one more than the largest registered ID, starting
        at "0" when nothing is registered.

        Args:
            callback_func: The function to register.

        Returns:
            string, Id of the registered callback function.

        Raises:
            CallbackServerError is raised if callback_func is already
            registered.
        """
        if self.GetCallbackId(callback_func) is not None:
            raise CallbackServerError("Function is already registered")
        new_id = max(map(int, _functions)) + 1 if _functions else 0
        _functions[str(new_id)] = callback_func
        return str(new_id)

    def UnregisterCallback(self, func_id):
        """Removes a callback function from the registry.

        Args:
            func_id: The ID of the callback function to remove.

        Raises:
            CallbackServerError is raised if the func_id is not registered.
        """
        try:
            _functions.pop(func_id)
        except KeyError:
            raise CallbackServerError(
                "Can't remove function ID '%s', which is not registered." %
                func_id) from None

    def GetCallbackId(self, callback_func):
        """Looks up the ID under which a callback function is registered.

        The lookup compares by identity (`is`), not equality.

        Args:
            callback_func: The function to look up.

        Returns:
            string, Id of the callback function if found, None otherwise.
        """
        for func_id, func in _functions.items():
            if func is callback_func:
                return func_id
        return None

    def Start(self, port=0):
        """Starts the server on a background daemon thread.

        Args:
            port: integer, number of the port on which the server listens.
                  Default is 0, which means auto-select a port available.

        Returns:
            tuple (IP Address, port number) the server is bound to.

        Raises:
            CallbackServerError is raised if the server fails to start.
        """
        server = None
        try:
            server = socketserver.TCPServer(
                (self._hostname, port), CallbackRequestHandler)
            ip, bound_port = server.server_address
            # All requests are served on this one daemon thread.
            server_thread = threading.Thread(target=server.serve_forever)
            server_thread.daemon = True
            server_thread.start()
        except (RuntimeError, IOError, socket.error) as e:
            logging.exception(e)
            if server is not None:
                # Release the bound socket if the serving thread failed.
                server.server_close()
            raise CallbackServerError(
                'Failed to start CallbackServer on (%s:%s).' %
                (self._hostname, port)) from e

        self._server = server
        self._ip, self._port = ip, bound_port
        logging.debug('TcpServer %s started (%s:%s)', server_thread.name,
                      self._ip, self._port)
        return self._ip, self._port

    def Stop(self):
        """Stops the server and closes its socket.

        Does nothing if the server is not running, so calling it before
        Start or more than once is harmless.
        """
        if self._server is None:
            return
        self._server.shutdown()
        self._server.server_close()
        self._server = None

    @property
    def ip(self):
        """IP address the server is bound to ("" before Start)."""
        return self._ip

    @property
    def port(self):
        """Port the server is bound to (0 before Start)."""
        return self._port