1#!/usr/bin/env python3
2#
3#   Copyright 2020 - The Android Open Source Project
4#
5#   Licensed under the Apache License, Version 2.0 (the "License");
6#   you may not use this file except in compliance with the License.
7#   You may obtain a copy of the License at
8#
9#       http://www.apache.org/licenses/LICENSE-2.0
10#
11#   Unless required by applicable law or agreed to in writing, software
12#   distributed under the License is distributed on an "AS IS" BASIS,
13#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14#   See the License for the specific language governing permissions and
15#   limitations under the License.
16
17from datetime import timedelta
18
19from mobly.asserts import assert_true
20from mobly.asserts import assert_false
21from mobly import signals
22
23from cert.event_stream import IEventStream
24from cert.event_stream import NOT_FOR_YOU_assert_event_occurs
25from cert.event_stream import NOT_FOR_YOU_assert_all_events_occur
26from cert.event_stream import NOT_FOR_YOU_assert_none_matching
27from cert.event_stream import NOT_FOR_YOU_assert_none
28
29
class ObjectSubject(object):
    """Fluent assertions over an arbitrary value.

    Every check returns ``None`` when it holds and raises
    ``signals.TestFailure`` (with ``extras=None``) when it does not.
    """

    def __init__(self, value):
        # The value under test; subclasses read it back as ``self._value``.
        self._value = value

    def isEqualTo(self, other):
        """Fail unless the value compares equal (``==``/``!=``) to ``other``."""
        if self._value != other:
            raise signals.TestFailure("Expected \"%s\" to be equal to \"%s\"" % (self._value, other), extras=None)

    def isNotEqualTo(self, other):
        """Fail if the value compares equal (``==``) to ``other``."""
        if self._value == other:
            raise signals.TestFailure("Expected \"%s\" to not be equal to \"%s\"" % (self._value, other), extras=None)

    def isNone(self):
        """Fail unless the value is ``None`` (identity check)."""
        if self._value is not None:
            # Wrap in a 1-tuple: a bare tuple value would otherwise be
            # unpacked by ``%`` and raise TypeError instead of TestFailure.
            raise signals.TestFailure("Expected \"%s\" to be None" % (self._value,), extras=None)

    def isNotNone(self):
        """Fail if the value is ``None`` (identity check)."""
        if self._value is None:
            # 1-tuple for the same reason as in isNone, kept consistent.
            raise signals.TestFailure("Expected \"%s\" to not be None" % (self._value,), extras=None)
50
51
# Wait used by the event-stream assertion methods (emits/emitsNone/then/
# thenNone) when the caller does not pass an explicit ``timeout``: 3 seconds.
DEFAULT_TIMEOUT = timedelta(milliseconds=3000)
53
54
class EventStreamSubject(ObjectSubject):
    """Assertions over an ``IEventStream`` held in ``self._value``.

    The construction signature is inherited unchanged from ObjectSubject.
    """

    def emits(self, *match_fns, at_least_times=1, timeout=DEFAULT_TIMEOUT):
        """Assert that the stream emits events matching ``match_fns``.

        With exactly one match function, the check is delegated to
        ``NOT_FOR_YOU_assert_event_occurs`` immediately and a continuation
        subject is returned for chaining.  With several match functions, no
        check happens yet: a MultiMatchStreamSubject is returned and the caller
        must pick ``inOrder()`` or ``inAnyOrder()``.  Note that
        ``at_least_times`` is only forwarded in the single-matcher case.

        Raises:
            signals.TestFailure: if no match function is given.
        """
        matcher_count = len(match_fns)
        if matcher_count == 0:
            raise signals.TestFailure("Must specify a match function")
        if matcher_count > 1:
            return MultiMatchStreamSubject(self._value, match_fns, timeout)
        only_matcher, = match_fns
        NOT_FOR_YOU_assert_event_occurs(self._value, only_matcher, at_least_times=at_least_times, timeout=timeout)
        return EventStreamContinuationSubject(self._value)

    def emitsNone(self, *match_fns, timeout=DEFAULT_TIMEOUT):
        """Assert the stream emits nothing (or nothing matching one matcher).

        With no match function, delegates to ``NOT_FOR_YOU_assert_none``;
        with one, to ``NOT_FOR_YOU_assert_none_matching``.  Either way a
        continuation subject is returned for chaining.

        Raises:
            signals.TestFailure: if more than one match function is given.
        """
        if len(match_fns) > 1:
            raise signals.TestFailure("Cannot specify multiple match functions")
        if match_fns:
            NOT_FOR_YOU_assert_none_matching(self._value, match_fns[0], timeout=timeout)
        else:
            NOT_FOR_YOU_assert_none(self._value, timeout=timeout)
        return EventStreamContinuationSubject(self._value)
78
79
class MultiMatchStreamSubject(object):
    """Deferred check that a stream satisfies several match functions.

    Produced by ``emits``/``then`` when more than one match function is
    passed.  Nothing is asserted until ``inOrder`` or ``inAnyOrder`` is
    called.
    """

    def __init__(self, stream, match_fns, timeout):
        self._stream, self._match_fns, self._timeout = stream, match_fns, timeout

    def _check_all_occur(self, ordered):
        # Shared body of inAnyOrder/inOrder; they differ only in whether the
        # matchers must be satisfied in the given sequence.
        NOT_FOR_YOU_assert_all_events_occur(self._stream, self._match_fns, order_matters=ordered, timeout=self._timeout)
        return EventStreamContinuationSubject(self._stream)

    def inAnyOrder(self):
        """Assert every match function is satisfied, in any order."""
        return self._check_all_occur(ordered=False)

    def inOrder(self):
        """Assert every match function is satisfied, in the given order."""
        return self._check_all_occur(ordered=True)
94
95
class EventStreamContinuationSubject(ObjectSubject):
    """Chaining subject returned after a stream assertion has been made.

    Wraps the same ``IEventStream`` (in ``self._value``) so that further
    expectations can be stated with ``then``/``thenNone``.
    """

    def then(self, *match_fns, at_least_times=1, timeout=DEFAULT_TIMEOUT):
        """Assert the stream next emits events matching ``match_fns``.

        One match function: checked immediately via
        ``NOT_FOR_YOU_assert_event_occurs`` and a new continuation subject is
        returned.  Several: a MultiMatchStreamSubject is returned (the caller
        must choose ``inOrder``/``inAnyOrder``); ``at_least_times`` is not
        forwarded in that case.

        Raises:
            signals.TestFailure: if no match function is given.
        """
        if not match_fns:
            raise signals.TestFailure("Must specify a match function")
        if len(match_fns) != 1:
            return MultiMatchStreamSubject(self._value, match_fns, timeout)
        NOT_FOR_YOU_assert_event_occurs(self._value, match_fns[0], at_least_times=at_least_times, timeout=timeout)
        return EventStreamContinuationSubject(self._value)

    def thenNone(self, *match_fns, timeout=DEFAULT_TIMEOUT):
        """Assert the stream next emits nothing (or nothing matching one matcher).

        Raises:
            signals.TestFailure: if more than one match function is given.
        """
        stream = self._value
        if len(match_fns) == 0:
            NOT_FOR_YOU_assert_none(stream, timeout=timeout)
        elif len(match_fns) == 1:
            NOT_FOR_YOU_assert_none_matching(stream, match_fns[0], timeout=timeout)
        else:
            raise signals.TestFailure("Cannot specify multiple match functions")
        return EventStreamContinuationSubject(stream)
119
120
class BooleanSubject(ObjectSubject):
    """Assertions over a ``bool`` value (created by ``assertThat`` for bools)."""

    def isTrue(self):
        """Fail unless the value is truthy.

        Previously the failure message was empty; it now names the value,
        matching the descriptive messages used by ObjectSubject.
        """
        assert_true(self._value, "Expected \"%s\" to be True" % (self._value,))

    def isFalse(self):
        """Fail unless the value is falsy (message names the offending value)."""
        assert_false(self._value, "Expected \"%s\" to be False" % (self._value,))
131
132
class TimeDeltaSubject(ObjectSubject):
    """Assertions over a ``datetime.timedelta`` value."""

    def isWithin(self, time_bound):
        """Fail unless the value is strictly less than ``time_bound``.

        The comparison is ``<`` (an exactly-equal duration fails).  The
        failure message now reports both durations instead of being empty.
        """
        assert_true(self._value < time_bound,
                    "Expected \"%s\" to be within \"%s\"" % (self._value, time_bound))
140
141
def assertThat(subject):
    """Wrap ``subject`` in the most specific Subject type available.

    Exact ``bool`` values get BooleanSubject, ``IEventStream`` instances get
    EventStreamSubject, ``timedelta`` instances get TimeDeltaSubject, and
    anything else falls back to ObjectSubject.
    """
    # Exact type check (not isinstance) so only genuine bools take this path.
    if type(subject) is bool:
        wrapper = BooleanSubject
    elif isinstance(subject, IEventStream):
        wrapper = EventStreamSubject
    elif isinstance(subject, timedelta):
        wrapper = TimeDeltaSubject
    else:
        wrapper = ObjectSubject
    return wrapper(subject)
151