1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Provides a pw_rpc client for Python."""
15
16import abc
17from dataclasses import dataclass
18import logging
19from typing import (Any, Collection, Dict, Iterable, Iterator, NamedTuple,
20                    Optional)
21
22from google.protobuf.message import DecodeError
23from pw_status import Status
24
25from pw_rpc import descriptors, packets
26from pw_rpc.descriptors import Channel, Service, Method
27from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket
28
# Module-level logger used by the classes and helper functions in this file.
_LOG = logging.getLogger(__name__)
30
31
class Error(Exception):
    """Raised when the RPC client classes are used incorrectly.

    For example, PendingRpcs.open() raises this error when asked to open an
    RPC that is already pending and overriding was not requested.
    """
34
35
class PendingRpc(NamedTuple):
    """Key that identifies one RPC call: a channel, a service, and a method.

    Being a NamedTuple, instances are hashable (as long as their members are)
    and are used as dictionary keys by PendingRpcs.
    """
    channel: Channel
    service: Service
    method: Method

    def __str__(self) -> str:
        # Only the channel ID and the method are shown; the service is omitted.
        channel_id = self.channel.id
        return f'PendingRpc(channel={channel_id}, method={self.method})'
44
45
class _PendingRpcMetadata:
    """Per-RPC bookkeeping stored by PendingRpcs.

    Attributes are set directly from the constructor arguments:
      context: arbitrary caller-provided object associated with the RPC
      keep_open: if True, PendingRpcs.get_pending() keeps the RPC registered
          even after a status has been received
    """
    def __init__(self, context: Any, keep_open: bool):
        self.context, self.keep_open = context, keep_open
50
51
class PendingRpcs:
    """Registry of pending RPCs that also encodes their outgoing packets.

    Every pending RPC is keyed by its PendingRpc tuple and maps to a
    _PendingRpcMetadata object holding the caller's context and the keep_open
    flag.
    """
    def __init__(self):
        self._pending: Dict[PendingRpc, _PendingRpcMetadata] = {}

    def request(self,
                rpc: PendingRpc,
                request,
                context,
                override_pending: bool = True,
                keep_open: bool = False) -> bytes:
        """Registers the RPC as pending and returns its encoded request packet.

        Args:
          rpc: the RPC to start
          request: the request passed to packets.encode_request()
          context: arbitrary object to associate with the pending RPC
          override_pending: if True, silently replace an already-pending entry
              for this RPC; note that the default here is True, unlike open()
              and send_request()
          keep_open: forwarded to open()

        Raises:
          Error: the RPC is already pending and override_pending is False
        """
        # open() stores the context inside a freshly created
        # _PendingRpcMetadata object.
        self.open(rpc, context, override_pending, keep_open)
        _LOG.debug('Starting %s', rpc)
        encoded_packet = packets.encode_request(rpc, request)
        return encoded_packet

    def send_request(self,
                     rpc: PendingRpc,
                     request,
                     context,
                     override_pending: bool = False,
                     keep_open: bool = False) -> None:
        """Calls request() and writes the resulting packet to the channel.

        Takes the same arguments as request(), except that override_pending
        defaults to False.

        Raises:
          Error: the RPC is already pending and override_pending is False
        """
        # TODO(hepler): Remove `type: ignore` on this and similar lines when
        #     https://github.com/python/mypy/issues/5485 is fixed
        output = rpc.channel.output  # type: ignore
        output(self.request(rpc, request, context, override_pending,
                            keep_open))

    def open(self,
             rpc: PendingRpc,
             context,
             override_pending: bool = False,
             keep_open: bool = False) -> None:
        """Registers a context for an RPC without invoking the RPC.

        open() can be used to receive streaming responses to an RPC that was not
        invoked by this client. For example, a server may stream logs with a
        server streaming RPC prior to any clients invoking it.

        Args:
          rpc: the RPC to mark as pending
          context: arbitrary object to associate with the pending RPC
          override_pending: if True, replace any existing entry for this RPC
          keep_open: if True, get_pending() leaves the RPC registered even
              after it receives a status

        Raises:
          Error: the RPC is already pending and override_pending is False
        """
        new_entry = _PendingRpcMetadata(context, keep_open)

        if override_pending:
            self._pending[rpc] = new_entry
            return

        # setdefault() hands back the entry that was already stored, if any. If
        # that is not the entry created above, the RPC was already pending.
        stored_entry = self._pending.setdefault(rpc, new_entry)
        if stored_entry is not new_entry:
            raise Error(f'Sent request for {rpc}, but it is already pending! '
                        'Cancel the RPC before invoking it again')

    def cancel(self, rpc: PendingRpc) -> Optional[bytes]:
        """Removes the RPC from the pending set and builds its CANCEL packet.

        Returns:
          The encoded CANCEL packet to send, or None if the method is unary (no
          packet is produced for unary RPCs)

        Raises:
          KeyError if the RPC is not pending
        """
        _LOG.debug('Cancelling %s', rpc)
        del self._pending[rpc]

        is_unary = rpc.method.type is Method.Type.UNARY
        return None if is_unary else packets.encode_cancel(rpc)

    def send_cancel(self, rpc: PendingRpc) -> bool:
        """Calls cancel() and sends the cancel packet, if any, to the channel.

        Returns:
          True if the RPC was cancelled; False if cancel() raised KeyError
          (i.e. the RPC was not pending)
        """
        try:
            cancel_packet = self.cancel(rpc)
        except KeyError:
            return False

        if cancel_packet:
            rpc.channel.output(cancel_packet)  # type: ignore

        return True

    def get_pending(self, rpc: PendingRpc, status: Optional[Status]):
        """Returns the pending RPC's context, clearing the RPC if appropriate.

        The RPC is removed from the pending set only when status is not None
        and the RPC was not opened with keep_open=True.

        Raises:
          KeyError if the RPC is not pending
        """
        entry = self._pending[rpc]

        if status is None:
            return entry.context

        if entry.keep_open:
            _LOG.debug('%s finished with status %s; keeping open', rpc, status)
            return entry.context

        _LOG.debug('%s finished with status %s', rpc, status)
        del self._pending[rpc]
        return entry.context
141
142
class ClientImpl(abc.ABC):
    """The internal interface of the RPC client.

    This interface defines the semantics for invoking an RPC on a particular
    client. Subclasses decide what a "method client" object is and how
    responses, completions, and errors are delivered to the user.
    """
    def __init__(self):
        # Both attributes start out as None; Client.__init__ assigns them when
        # a Client is constructed around this ClientImpl.
        self.client: 'Client' = None
        self.rpcs: PendingRpcs = None

    @abc.abstractmethod
    def method_client(self, channel: Channel, method: Method) -> Any:
        """Returns an object that invokes a method using the given channel.

        ServiceClient calls this once for each method of a service and exposes
        the returned objects to the user.
        """

    @abc.abstractmethod
    def handle_response(self,
                        rpc: PendingRpc,
                        context: Any,
                        payload: Any,
                        *,
                        args: tuple = (),
                        kwargs: Optional[dict] = None) -> Any:
        """Handles a response from the RPC server.

        Client.process_packet calls this when a response payload was
        successfully decoded for a pending RPC.

        Args:
          rpc: Information about the pending RPC
          context: Arbitrary context object associated with the pending RPC
          payload: A protobuf message
          args, kwargs: Arbitrary arguments passed to the ClientImpl
        """

    @abc.abstractmethod
    def handle_completion(self,
                          rpc: PendingRpc,
                          context: Any,
                          status: Status,
                          *,
                          args: tuple = (),
                          kwargs: Optional[dict] = None) -> Any:
        """Handles the successful completion of an RPC.

        Client.process_packet calls this, after handle_response if there was a
        payload, when the packet carried a status and was not a server error.

        Args:
          rpc: Information about the pending RPC
          context: Arbitrary context object associated with the pending RPC
          status: Status returned from the RPC
          args, kwargs: Arbitrary arguments passed to the ClientImpl
        """

    @abc.abstractmethod
    def handle_error(self,
                     rpc: PendingRpc,
                     context,
                     status: Status,
                     *,
                     args: tuple = (),
                     kwargs: Optional[dict] = None):
        """Handles the abnormal termination of an RPC.

        Client.process_packet calls this when a SERVER_ERROR packet arrives for
        a pending RPC.

        Args:
          rpc: Information about the pending RPC
          context: Arbitrary context object associated with the pending RPC
          status: which error occurred
          args, kwargs: Arbitrary arguments passed to the ClientImpl
        """
207
208
class ServiceClient(descriptors.ServiceAccessor):
    """Gives access to the method clients of one service on one channel."""
    def __init__(self, client_impl: ClientImpl, channel: Channel,
                 service: Service):
        # Ask the ClientImpl for one method client per method in the service.
        method_clients = {
            method: client_impl.method_client(channel, method)
            for method in service.methods
        }
        super().__init__(method_clients, as_attrs='members')

        self._channel = channel
        self._service = service

    def __repr__(self) -> str:
        full_name = self._service.full_name
        method_names = [m.name for m in self._service.methods]
        channel_id = self._channel.id
        return (f'Service({full_name!r}, '
                f'methods={method_names}, '
                f'channel={channel_id})')

    def __str__(self) -> str:
        # Defer to the string form of the underlying Service descriptor.
        return str(self._service)
230
231
class Services(descriptors.ServiceAccessor[ServiceClient]):
    """Gives access to the ServiceClients available on one channel."""
    def __init__(self, client_impl, channel: Channel,
                 services: Collection[Service]):
        # Build one ServiceClient per service, all bound to the same channel.
        service_clients = {
            service: ServiceClient(client_impl, channel, service)
            for service in services
        }
        super().__init__(service_clients, as_attrs='packages')

        self._channel = channel
        self._services = services

    def __repr__(self) -> str:
        channel_id = self._channel.id
        service_names = [s.full_name for s in self._services]
        return (f'Services(channel={channel_id}, '
                f'services={service_names})')
247
248
def _decode_status(rpc: PendingRpc, packet) -> Optional[Status]:
    """Returns the packet's Status, or None if it has no usable status.

    None is returned for RESPONSE packets of server streaming methods and for
    status codes that Status() rejects with ValueError (a warning is logged in
    the latter case).
    """
    # Server streaming RPC packets never have a status; all other packets do.
    is_stream_response = (packet.type == PacketType.RESPONSE
                          and rpc.method.server_streaming)

    if not is_stream_response:
        try:
            return Status(packet.status)
        except ValueError:
            _LOG.warning('Illegal status code %d for %s', packet.status, rpc)

    return None
260
261
def _decode_payload(rpc: PendingRpc, packet):
    """Decodes a RESPONSE packet's payload as the method's response type.

    Returns None for any other packet type, or if decoding raises DecodeError
    (in which case a warning is logged).
    """
    if packet.type != PacketType.RESPONSE:
        return None

    response_type = rpc.method.response_type
    try:
        return packets.decode_payload(packet, response_type)
    except DecodeError as err:
        _LOG.warning('Failed to decode %s response for %s: %s',
                     rpc.method.response_type.DESCRIPTOR.full_name,
                     rpc.method.full_name, err)

    return None
271
272
@dataclass(frozen=True, eq=False)
class ChannelClient:
    """The RPC services and methods available on one particular channel.

    RPCs are invoked through service method clients, which are reached via the
    `rpcs` member. Service methods use a fully qualified name: package,
    service, method. A service method may be selected as an attribute or by
    indexing the rpcs member by service and method name or ID.

      # Access the service method client as an attribute
      rpc = client.channel(1).rpcs.the.package.FooService.SomeMethod

      # Access the service method client by string name
      rpc = client.channel(1).rpcs[foo_service_id]['SomeMethod']

    RPCs may also be looked up by their canonical name.

      # Access the service method client from its full name:
      rpc = client.channel(1).method('the.package.FooService/SomeMethod')

      # Using a . instead of a / is also supported:
      rpc = client.channel(1).method('the.package.FooService.SomeMethod')

    The type of the service method client is determined by the ClientImpl
    class. A synchronous RPC client might return a callable object, so an RPC
    could be invoked directly (e.g. rpc(field1=123, field2=b'456')).
    """
    client: 'Client'
    channel: Channel
    rpcs: Services

    def method(self, method_name: str):
        """Looks up the method client with the given name.

        Args:
          method_name: name as package.Service/Method or package.Service.Method.

        Raises:
          ValueError: the method name is not properly formatted
          KeyError: the method is not present
        """
        method_client = descriptors.get_method(self.rpcs, method_name)
        return method_client

    def services(self) -> Iterator:
        """Returns an iterator over the members of the rpcs accessor."""
        return iter(self.rpcs)

    def methods(self) -> Iterator:
        """Yields every method client of every service in this ChannelClient."""
        for service in self.rpcs:
            yield from service

    def __repr__(self) -> str:
        channel_id = self.channel.id
        service_names = [str(service) for service in self.services()]
        return (f'ChannelClient(channel={channel_id}, '
                f'services={service_names})')
327
328
class Client:
    """Sends requests and handles responses for a set of channels.

    RPC invocations occur through a ChannelClient.
    """
    @classmethod
    def from_modules(cls, impl: ClientImpl, channels: Iterable[Channel],
                     modules: Iterable):
        """Creates a Client from modules that expose protobuf descriptors.

        Args:
          impl: the ClientImpl that handles RPC invocation and responses
          channels: the channels this client may use
          modules: objects with a DESCRIPTOR attribute; a Service is created
              from every entry of DESCRIPTOR.services_by_name
        """
        return cls(
            impl, channels,
            (Service.from_descriptor(service) for module in modules
             for service in module.DESCRIPTOR.services_by_name.values()))

    def __init__(self, impl: ClientImpl, channels: Iterable[Channel],
                 services: Iterable[Service]):
        # Bind the implementation to this client and give it a fresh registry
        # of pending RPCs.
        self._impl = impl
        self._impl.client = self
        self._impl.rpcs = PendingRpcs()

        self.services = descriptors.Services(services)

        # One ChannelClient per channel, keyed by channel ID.
        self._channels_by_id = {
            channel.id:
            ChannelClient(self, channel,
                          Services(self._impl, channel, self.services))
            for channel in channels
        }

    def channel(self, channel_id: Optional[int] = None) -> ChannelClient:
        """Returns a ChannelClient, which is used to call RPCs on a channel.

        If no channel is provided, the first channel is used.

        Raises:
          KeyError: channel_id is not one of this client's channels
          StopIteration: channel_id is None and the client has no channels
        """
        if channel_id is None:
            return next(iter(self._channels_by_id.values()))

        return self._channels_by_id[channel_id]

    def channels(self) -> Iterable[ChannelClient]:
        """Accesses the ChannelClients in this client."""
        return self._channels_by_id.values()

    def method(self, method_name: str) -> Method:
        """Returns a Method matching the given name.

        Args:
          method_name: name as package.Service/Method or package.Service.Method.

        Raises:
          ValueError: the method name is not properly formatted
          KeyError: the method is not present
        """
        return descriptors.get_method(self.services, method_name)

    def methods(self) -> Iterator[Method]:
        """Iterates over all Methods supported by this client."""
        for service in self.services:
            yield from service.methods

    def process_packet(self, pw_rpc_raw_packet_data: bytes, *impl_args,
                       **impl_kwargs) -> Status:
        """Processes an incoming packet.

        Args:
          pw_rpc_raw_packet_data: raw binary data for exactly one RPC packet
          impl_args: optional positional arguments passed to the ClientImpl
          impl_kwargs: optional keyword arguments passed to the ClientImpl

        Returns:
          OK - the packet was processed by this client; this includes packets
              for an unrecognized service or method, packets of an unexpected
              type, and responses for RPCs that are not pending
          DATA_LOSS - the packet could not be decoded
          INVALID_ARGUMENT - the packet is for a server, not a client
          NOT_FOUND - the packet's channel ID is not known to this client
        """
        try:
            packet = packets.decode(pw_rpc_raw_packet_data)
        except DecodeError as err:
            _LOG.warning('Failed to decode packet: %s', err)
            _LOG.debug('Raw packet: %r', pw_rpc_raw_packet_data)
            return Status.DATA_LOSS

        if packets.for_server(packet):
            return Status.INVALID_ARGUMENT

        try:
            channel_client = self._channels_by_id[packet.channel_id]
        except KeyError:
            _LOG.warning('Unrecognized channel ID %d', packet.channel_id)
            return Status.NOT_FOUND

        try:
            rpc = self._look_up_service_and_method(packet, channel_client)
        except ValueError as err:
            # Unknown service or method: send a NOT_FOUND client error packet
            # out through the channel.
            channel_client.channel.output(  # type: ignore
                packets.encode_client_error(packet, Status.NOT_FOUND))
            _LOG.warning('%s', err)
            return Status.OK

        status = _decode_status(rpc, packet)

        if packet.type not in (PacketType.RESPONSE,
                               PacketType.SERVER_STREAM_END,
                               PacketType.SERVER_ERROR):
            _LOG.error('%s: unexpected PacketType %s', rpc, packet.type)
            _LOG.debug('Packet:\n%s', packet)
            return Status.OK

        payload = _decode_payload(rpc, packet)

        try:
            # A non-None status clears the RPC from the pending set unless it
            # was opened with keep_open.
            context = self._impl.rpcs.get_pending(rpc, status)
        except KeyError:
            # The RPC is not pending: send a FAILED_PRECONDITION client error
            # packet out through the channel.
            channel_client.channel.output(  # type: ignore
                packets.encode_client_error(packet,
                                            Status.FAILED_PRECONDITION))
            _LOG.debug('Discarding response for %s, which is not pending', rpc)
            return Status.OK

        if packet.type == PacketType.SERVER_ERROR:
            # NOTE(review): this assert checks a value decoded from the
            # incoming packet. _decode_status() returns None for an illegal
            # status code, so such a SERVER_ERROR packet (or one carrying an
            # OK status) raises AssertionError here, and the check disappears
            # entirely under `python -O`. Consider explicit handling.
            assert status is not None and not status.ok()
            _LOG.warning('%s: invocation failed with %s', rpc, status)
            self._impl.handle_error(rpc,
                                    context,
                                    status,
                                    args=impl_args,
                                    kwargs=impl_kwargs)
            return Status.OK

        # A single packet may carry a payload, a status, or both.
        if payload is not None:
            self._impl.handle_response(rpc,
                                       context,
                                       payload,
                                       args=impl_args,
                                       kwargs=impl_kwargs)
        if status is not None:
            self._impl.handle_completion(rpc,
                                         context,
                                         status,
                                         args=impl_args,
                                         kwargs=impl_kwargs)

        return Status.OK

    def _look_up_service_and_method(
            self, packet: RpcPacket,
            channel_client: ChannelClient) -> PendingRpc:
        """Resolves the packet's service and method IDs into a PendingRpc.

        Raises:
          ValueError: the service ID or method ID is not recognized; the
              original KeyError is attached as the exception's cause
        """
        try:
            service = self.services[packet.service_id]
        except KeyError as err:
            raise ValueError(
                f'Unrecognized service ID {packet.service_id}') from err

        try:
            method = service.methods[packet.method_id]
        except KeyError as err:
            raise ValueError(
                f'No method ID {packet.method_id} in service {service.name}'
            ) from err

        return PendingRpc(channel_client.channel, service, method)

    def __repr__(self) -> str:
        return (f'pw_rpc.Client(channels={list(self._channels_by_id)}, '
                f'services={[s.full_name for s in self.services]})')
491