1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Utilities for using HDLC with pw_rpc."""
15
16from concurrent.futures import ThreadPoolExecutor
17import logging
18import sys
19import threading
20import time
21from typing import (Any, BinaryIO, Callable, Dict, Iterable, List, NoReturn,
22                    Optional, Union)
23
24from pw_protobuf_compiler import python_protos
25import pw_rpc
26from pw_rpc import callback_client
27
28from pw_hdlc.decode import Frame, FrameDecoder
29from pw_hdlc import encode
30
_LOG = logging.getLogger(__name__)

# HDLC address whose frames HdlcRpcClient hands to its `output` function
# (write_to_file by default, which writes to stdout).
STDOUT_ADDRESS = 1
# HDLC address used for pw_rpc packets: channel_output encodes outgoing frames
# with it by default, and HdlcRpcClient routes incoming frames on it to the
# RPC client. ord('R') == 82.
DEFAULT_ADDRESS = ord('R')
35
36
def channel_output(writer: Callable[[bytes], Any],
                   address: int = DEFAULT_ADDRESS,
                   delay_s: float = 0) -> Callable[[bytes], None]:
    """Returns a function that can be used as a channel output for pw_rpc.

    The returned function wraps each packet it receives in an HDLC UI frame
    addressed to ``address`` and passes the encoded frame to ``writer``.

    Args:
      writer: Function that writes raw bytes, e.g. a serial device's write.
      address: HDLC address placed in every frame.
      delay_s: If nonzero, the frame is written one byte at a time, sleeping
          ``delay_s`` seconds before each byte. This slows down writes in case
          unbuffered serial is in use.

    Returns:
      A function that takes a packet and writes it out as an HDLC frame.
    """
    def write_hdlc(data: bytes) -> None:
        frame = encode.ui_frame(address, data)
        # Log in both the immediate and the paced path so debug output shows
        # every frame written, regardless of delay_s.
        _LOG.debug('Write %2d B: %s', len(frame), frame)

        if not delay_s:
            writer(frame)
            return

        # Paced path: one byte per write call, with a sleep before each byte.
        for byte in frame:
            time.sleep(delay_s)
            writer(bytes([byte]))

    return write_hdlc
58
59
def _handle_error(frame: Frame) -> None:
    """Default error handler: logs a frame that failed to decode.

    The frame's status value is logged at error level; its raw data is logged
    at debug level.
    """
    status_value = frame.status.value
    _LOG.error('Failed to parse frame: %s', status_value)
    _LOG.debug('%s', frame.data)
63
64
# Maps an HDLC frame address to the callable that handles frames received on
# that address (consumed by read_and_process_data).
FrameHandlers = Dict[int, Callable[[Frame], Any]]
66
67
def read_and_process_data(read: Callable[[], bytes],
                          on_read_error: Callable[[Exception], Any],
                          frame_handlers: FrameHandlers,
                          error_handler: Callable[[Frame],
                                                  Any] = _handle_error,
                          handler_threads: Optional[int] = 1) -> NoReturn:
    """Continuously reads and handles HDLC frames.

    Passes frames to an executor that calls frame handler functions in other
    threads. This function loops forever; it only exits if ``on_read_error``
    raises.

    Args:
      read: Returns the next chunk of input bytes. Empty results are ignored.
      on_read_error: Called with any exception raised by ``read``. Reading
          continues afterwards unless this callback raises.
      frame_handlers: Maps an HDLC address to the function that handles
          frames received on that address.
      error_handler: Called for frames whose ``ok()`` check fails.
      handler_threads: Maximum number of executor worker threads. With the
          default of 1, handlers run one at a time in decode order.
    """
    def handle_frame(frame: Frame) -> None:
        try:
            if not frame.ok():
                error_handler(frame)
                return

            # Only the lookup is guarded: a KeyError raised inside a handler
            # must not be misreported as an unhandled address. Subscripting
            # (rather than .get) keeps __missing__ support for dict subclasses.
            try:
                handler = frame_handlers[frame.address]
            except KeyError:
                _LOG.warning('Unhandled frame for address %d: %s',
                             frame.address, frame)
                return

            handler(frame)
        except:  # pylint: disable=bare-except
            # The executor's future is discarded, so anything not logged here
            # would be lost silently.
            _LOG.exception('Exception in HDLC frame handler thread')

    decoder = FrameDecoder()

    # Execute callbacks in a ThreadPoolExecutor to decouple reading the input
    # stream from handling the data. That way, if a handler function takes a
    # long time or crashes, this reading thread is not interrupted.
    with ThreadPoolExecutor(max_workers=handler_threads) as executor:
        while True:
            try:
                data = read()
            except Exception as exc:  # pylint: disable=broad-except
                on_read_error(exc)
                continue

            if data:
                _LOG.debug('Read %2d B: %s', len(data), data)

                for frame in decoder.process_valid_frames(data):
                    executor.submit(handle_frame, frame)
111
112
def write_to_file(data: bytes, output: Optional[BinaryIO] = None) -> None:
    """Writes data followed by a newline to a binary stream, then flushes it.

    Args:
      data: Bytes to write.
      output: Binary stream to write to. Defaults to ``sys.stdout.buffer``,
          resolved at call time. This honors a redirected or replaced
          ``sys.stdout``, and importing this module does not require
          ``sys.stdout`` to have a ``buffer`` attribute (it may be None,
          e.g. under pythonw).
    """
    if output is None:
        output = sys.stdout.buffer
    output.write(data + b'\n')
    output.flush()
116
117
def default_channels(write: Callable[[bytes], Any]) -> List[pw_rpc.Channel]:
    """Builds the default channel list: a single channel with ID 1.

    Args:
      write: Function that writes raw bytes. It is wrapped with
          channel_output, so packets go out HDLC-encoded at the default
          address.
    """
    hdlc_output = channel_output(write)
    channel = pw_rpc.Channel(1, hdlc_output)
    return [channel]
120
121
class HdlcRpcClient:
    """An RPC client configured to run over HDLC.

    Incoming frames are dispatched by HDLC address. Frames on DEFAULT_ADDRESS
    go to the pw_rpc client; frames on STDOUT_ADDRESS go to ``output``.
    """
    def __init__(self,
                 read: Callable[[], bytes],
                 paths_or_modules: Union[Iterable[python_protos.PathOrModule],
                                         python_protos.Library],
                 channels: Iterable[pw_rpc.Channel],
                 output: Callable[[bytes], Any] = write_to_file,
                 client_impl: Optional[pw_rpc.client.ClientImpl] = None):
        """Creates an RPC client configured to communicate using HDLC.

        Starts a daemon thread that reads from ``read`` and dispatches decoded
        frames for the lifetime of the process.

        Args:
          read: Function that reads bytes; e.g serial_device.read.
          paths_or_modules: paths to .proto files or proto modules, or an
              already-built python_protos.Library
          channels: RPC channels to use for output
          output: where to write "stdout" output from the device
          client_impl: pw_rpc client implementation; defaults to
              callback_client.Impl()
        """
        if isinstance(paths_or_modules, python_protos.Library):
            self.protos = paths_or_modules
        else:
            self.protos = python_protos.Library.from_paths(paths_or_modules)

        if client_impl is None:
            client_impl = callback_client.Impl()

        self.client = pw_rpc.Client.from_modules(client_impl, channels,
                                                 self.protos.modules())
        frame_handlers: FrameHandlers = {
            DEFAULT_ADDRESS: self._handle_rpc_packet,
            STDOUT_ADDRESS: lambda frame: output(frame.data),
        }

        # Start background thread that reads and processes RPC packets.
        # Read errors are deliberately ignored (the callback is a no-op), so
        # the reader immediately retries.
        threading.Thread(target=read_and_process_data,
                         daemon=True,
                         args=(read, lambda exc: None,
                               frame_handlers)).start()

    def rpcs(self, channel_id: Optional[int] = None) -> Any:
        """Returns object for accessing services on the specified channel.

        This skips some intermediate layers to make it simpler to invoke RPCs
        from an HdlcRpcClient. If only one channel is in use, the channel ID is
        not necessary; when it is omitted, the first channel returned by the
        client is used.
        """
        if channel_id is None:
            return next(iter(self.client.channels())).rpcs

        return self.client.channel(channel_id).rpcs

    def _handle_rpc_packet(self, frame: Frame) -> None:
        """Feeds an RPC frame's payload to the client; logs it if unhandled."""
        if not self.client.process_packet(frame.data):
            _LOG.error('Packet not handled by RPC client: %s', frame.data)
175