1#!/usr/bin/python2.4
2#
3# Copyright 2008 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#      http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17# This file is used for testing.  The original is at:
18#   http://code.google.com/p/pymox/
19
20"""Mox, an object-mocking framework for Python.
21
22Mox works in the record-replay-verify paradigm.  When you first create
23a mock object, it is in record mode.  You then programmatically set
24the expected behavior of the mock object (what methods are to be
25called on it, with what parameters, what they should return, and in
26what order).
27
28Once you have set up the expected mock behavior, you put it in replay
29mode.  Now the mock responds to method calls just as you told it to.
30If an unexpected method (or an expected method with unexpected
31parameters) is called, then an exception will be raised.
32
33Once you are done interacting with the mock, you need to verify that
34all the expected interactions occurred.  (Maybe your code exited
35prematurely without calling some cleanup method!)  The verify phase
36ensures that every expected method was called; otherwise, an exception
37will be raised.
38
39Suggested usage / workflow:
40
41  # Create Mox factory
42  my_mox = Mox()
43
44  # Create a mock data access object
45  mock_dao = my_mox.CreateMock(DAOClass)
46
47  # Set up expected behavior
48  mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person)
49  mock_dao.DeletePerson(person)
50
51  # Put mocks in replay mode
52  my_mox.ReplayAll()
53
54  # Inject mock object and run test
55  controller.SetDao(mock_dao)
56  controller.DeletePersonById('1')
57
58  # Verify all methods were called as expected
59  my_mox.VerifyAll()
60"""
61
62from collections import deque
63import re
64import types
65import unittest
66
67import stubout
68
class Error(AssertionError):
  """Root of the mox exception hierarchy.

  Deriving from AssertionError lets test frameworks report mox failures as
  ordinary assertion failures.
  """
73
74
class ExpectedMethodCallsError(Error):
  """Signals that verification found expected calls that were never made."""

  def __init__(self, expected_methods):
    """Remember the outstanding expectations.

    Args:
      # expected_methods: the MockMethod objects (or groups) still waiting to
      #   be called when verification ran.
      expected_methods: [MockMethod]

    Raises:
      ValueError: if expected_methods is empty.
    """

    if not expected_methods:
      raise ValueError("There must be at least one expected method")
    Error.__init__(self)
    self._expected_methods = expected_methods

  def __str__(self):
    numbered = []
    for index, method in enumerate(self._expected_methods):
      numbered.append("%3d.  %s" % (index, method))
    listing = "\n".join(numbered)
    return "Verify: Expected methods never called:\n%s" % (listing,)
100
101
class UnexpectedMethodCallError(Error):
  """Signals a call that does not match what the mock was expecting.

  Raised both for calls with the wrong parameters and for calls made out of
  the recorded order.
  """

  def __init__(self, unexpected_method, expected):
    """Store the offending call and what was expected instead.

    Args:
      # unexpected_method: the MockMethod that was actually called but did
      #   not match the head of the expected-call queue.
      # expected: the MockMethod or UnorderedGroup that was expected (may be
      #   None when nothing more was expected).
      unexpected_method: MockMethod
      expected: MockMethod or UnorderedGroup
    """

    Error.__init__(self)
    self._unexpected_method = unexpected_method
    self._expected = expected

  def __str__(self):
    template = "Unexpected method call: %s.  Expecting: %s"
    return template % (self._unexpected_method, self._expected)
128
129
class UnknownMethodCallError(Error):
  """Signals access to a method the mocked interface does not provide."""

  def __init__(self, unknown_method_name):
    """Store the name that could not be resolved.

    Args:
      # unknown_method_name: the requested name, which is not part of the
      #   mocked class's interface.
      unknown_method_name: str
    """

    Error.__init__(self)
    self._unknown_method_name = unknown_method_name

  def __str__(self):
    prefix = "Method called is not a member of the object: %s"
    return prefix % self._unknown_method_name
148
149
class Mox(object):
  """Factory that creates mock objects and drives them as a group."""

  # Attribute types that StubOutWithMock replaces with a MockObject; values of
  # any other type are replaced with a MockAnything.
  _USE_MOCK_OBJECT = [types.ClassType, types.InstanceType, types.ModuleType,
                      types.ObjectType, types.TypeType]

  def __init__(self):
    """Start with no tracked mocks and a fresh stub manager."""

    self._mock_objects = []
    self.stubs = stubout.StubOutForTesting()

  def _Register(self, mock):
    """Track mock so the *All() methods reach it, and hand it back."""

    self._mock_objects.append(mock)
    return mock

  def CreateMock(self, class_to_mock):
    """Build a tracked MockObject that enforces class_to_mock's interface.

    Args:
      # class_to_mock: the class whose interface the mock should follow.
      class_to_mock: class

    Returns:
      A MockObject usable wherever class_to_mock would be.
    """

    return self._Register(MockObject(class_to_mock))

  def CreateMockAnything(self):
    """Build a tracked MockAnything, which accepts any method call.

    No interface is enforced on the returned mock.
    """

    return self._Register(MockAnything())

  def ReplayAll(self):
    """Switch every tracked mock into replay mode."""

    for tracked in self._mock_objects:
      tracked._Replay()

  def VerifyAll(self):
    """Verify every tracked mock."""

    for tracked in self._mock_objects:
      tracked._Verify()

  def ResetAll(self):
    """Reset every tracked mock.  Installed stubs are left in place."""

    for tracked in self._mock_objects:
      tracked._Reset()

  def StubOutWithMock(self, obj, attr_name, use_mock_anything=False):
    """Replace obj.attr_name with a fresh, tracked mock.

    Classes, modules, instances and types become MockObjects; anything else
    (functions, methods, ...) becomes a MockAnything.  Passing
    use_mock_anything=True forces a MockAnything in every case.

    Args:
      obj: A Python object (class, module, instance, callable).
      attr_name: str.  The name of the attribute to replace with a mock.
      use_mock_anything: bool. True if a MockAnything should be used regardless
        of the type of attribute.
    """

    original = getattr(obj, attr_name)
    wants_mock_object = (not use_mock_anything and
                         type(original) in self._USE_MOCK_OBJECT)
    if wants_mock_object:
      replacement = self.CreateMock(original)
    else:
      replacement = self.CreateMockAnything()

    self.stubs.Set(obj, attr_name, replacement)

  def UnsetStubs(self):
    """Put every stubbed attribute back to its original value."""

    self.stubs.UnsetAll()
234
def Replay(*args):
  """Switch each of the given mocks into replay mode.

  Args:
    # args: any number of mocks; each one's _Replay() is called in order.
  """

  for each_mock in args:
    each_mock._Replay()
244
245
def Verify(*args):
  """Verify each of the given mocks.

  Args:
    # args: any number of mocks; each one's _Verify() is called in order.
  """

  for each_mock in args:
    each_mock._Verify()
255
256
def Reset(*args):
  """Reset each of the given mocks.

  Args:
    # args: any number of mocks; each one's _Reset() is called in order.
  """

  for each_mock in args:
    each_mock._Reset()
266
267
class MockAnything:
  """A mock that accepts any method call.

  No interface is enforced: every attribute that normal lookup cannot find is
  treated as a mocked method.  This is useful for collaborators that have no
  declared public interface.
  """

  def __init__(self):
    """Create the mock in record mode with an empty expectation queue."""
    self._Reset()

  def __getattr__(self, method_name):
    """Treat any unknown attribute as a mocked method.

    Args:
      # method_name: the attribute (method) name that was requested.
      method_name: str

    Returns:
      A fresh MockMethod bound to this mock's queue and current mode.  Its
      __call__ either records or replays the call.
    """

    return self._CreateMockMethod(method_name)

  def _CreateMockMethod(self, method_name):
    """Build a MockMethod tied to this mock's queue and current mode.

    Args:
      # method_name: the name of the mocked method.
      method_name: str

    Returns:
      A new MockMethod aware of MockAnything's state (record or replay).
    """

    queue = self._expected_calls_queue
    mode = self._replay_mode
    return MockMethod(method_name, queue, mode)

  def __nonzero__(self):
    """Always truthy, so the mock can stand in for an object in conditionals."""

    return 1

  def __eq__(self, rhs):
    """Equal to another MockAnything with the same mode and expectations."""

    if not isinstance(rhs, MockAnything):
      return False
    if self._replay_mode != rhs._replay_mode:
      return False
    return self._expected_calls_queue == rhs._expected_calls_queue

  def __ne__(self, rhs):
    """Negation of __eq__."""

    return not self == rhs

  def _Replay(self):
    """Switch from recording expectations to replaying them."""

    self._replay_mode = True

  def _Verify(self):
    """Check that every recorded expectation has been consumed.

    Raises:
      ExpectedMethodCallsError: if expectations remain in the queue.
    """

    queue = self._expected_calls_queue
    if not queue:
      return

    # A MultipleTimesGroup is never popped from the queue, so a single
    # remaining group that reports itself satisfied is not an error.
    only_satisfied_group_left = (
        len(queue) == 1 and
        isinstance(queue[0], MultipleTimesGroup) and
        queue[0].IsSatisfied())
    if not only_satisfied_group_left:
      raise ExpectedMethodCallsError(queue)

  def _Reset(self):
    """Return to record mode with no expectations."""

    # Expected calls, in recording order.
    self._expected_calls_queue = deque()

    # Start (again) in record mode.
    self._replay_mode = False
357
358
class MockObject(MockAnything, object):
  """A mock object that simulates the public/protected interface of a class."""

  def __init__(self, class_to_mock):
    """Initialize a mock object.

    This determines the methods and properties of the class and stores them.

    Args:
      # class_to_mock: class to be mocked
      class_to_mock: class
    """

    # This is used to hack around the mixin/inheritance of MockAnything, which
    # is not a proper object (it can be anything. :-)
    MockAnything.__dict__['__init__'](self)

    # Split everything dir() reports (including inherited members) into
    # callables, which are mocked, and plain values, which are passed through.
    self._known_methods = set()
    self._known_vars = set()
    self._class_to_mock = class_to_mock
    for method in dir(class_to_mock):
      if callable(getattr(class_to_mock, method)):
        self._known_methods.add(method)
      else:
        self._known_vars.add(method)

  def __getattr__(self, name):
    """Intercept attribute request on this object.

    If the attribute is a public class variable, it will be returned and not
    recorded as a call.

    If the attribute is not a variable, it is handled like a method
    call. The method name is checked against the set of mockable
    methods, and a new MockMethod is returned that is aware of the
    MockObject's state (record or replay).  The call will be recorded
    or replayed by the MockMethod's __call__.

    Args:
      # name: the name of the attribute being requested.
      name: str

    Returns:
      Either a class variable or a new MockMethod that is aware of the state
      of the mock (record or replay).

    Raises:
      UnknownMethodCallError if the MockObject does not mock the requested
          method.
    """

    if name in self._known_vars:
      return getattr(self._class_to_mock, name)

    if name in self._known_methods:
      return self._CreateMockMethod(name)

    raise UnknownMethodCallError(name)

  def __eq__(self, rhs):
    """Equal to another MockObject of the same class, mode and expectations."""

    return (isinstance(rhs, MockObject) and
            self._class_to_mock == rhs._class_to_mock and
            self._replay_mode == rhs._replay_mode and
            self._expected_calls_queue == rhs._expected_calls_queue)

  def __setitem__(self, key, value):
    """Provide custom logic for mocking classes that support item assignment.

    Args:
      key: Key to set the value for.
      value: Value to set.

    Returns:
      Expected return value in replay mode.  A MockMethod object for the
      __setitem__ method that has already been called if not in replay mode.

    Raises:
      TypeError if the underlying class does not support item assignment.
      UnexpectedMethodCallError if the object does not expect the call to
        __setitem__.

    """
    # Verify the class supports item assignment.  _known_methods is built
    # from dir(), so a __setitem__ inherited from a base class counts too.
    if '__setitem__' not in self._known_methods:
      raise TypeError('object does not support item assignment')

    # The MockMethod records the call in record mode and verifies it against
    # the expected queue in replay mode.
    return self._CreateMockMethod('__setitem__')(key, value)

  def __getitem__(self, key):
    """Provide custom logic for mocking classes that are subscriptable.

    Args:
      key: Key to return the value for.

    Returns:
      Expected return value in replay mode.  A MockMethod object for the
      __getitem__ method that has already been called if not in replay mode.

    Raises:
      TypeError if the underlying class is not subscriptable.
      UnexpectedMethodCallError if the object does not expect the call to
        __getitem__.

    """
    # Verify the class is subscriptable.  _known_methods is built from dir(),
    # so a __getitem__ inherited from a base class counts too.
    if '__getitem__' not in self._known_methods:
      raise TypeError('unsubscriptable object')

    # The MockMethod records the call in record mode and verifies it against
    # the expected queue in replay mode.
    return self._CreateMockMethod('__getitem__')(key)

  def __call__(self, *params, **named_params):
    """Provide custom logic for mocking classes that are callable.

    Raises:
      TypeError if the mocked class neither defines nor inherits __call__.
    """

    # Verify the class we are mocking is callable (defined or inherited).
    if '__call__' not in self._known_methods:
      raise TypeError('Not callable')

    # Because the call is happening directly on this object instead of a method,
    # the call on the mock method is made right here
    mock_method = self._CreateMockMethod('__call__')
    return mock_method(*params, **named_params)

  @property
  def __class__(self):
    """Return the class that is being mocked."""

    return self._class_to_mock
508
509
class MockMethod(object):
  """Callable mock method.

  A MockMethod should act exactly like the method it mocks, accepting parameters
  and returning a value, or throwing an exception (as specified).  When this
  method is called, it can optionally verify whether the called method (name and
  signature) matches the expected method.
  """

  def __init__(self, method_name, call_queue, replay_mode):
    """Construct a new mock method.

    Args:
      # method_name: the name of the method
      # call_queue: deque of calls, verify this call against the head, or add
      #     this call to the queue.  NOTE: a non-deque (e.g. a list) is copied
      #     into a new deque, so calls recorded here are NOT visible through
      #     the caller's original object; pass a deque to share the queue.
      # replay_mode: False if we are recording, True if we are verifying calls
      #     against the call queue.
      method_name: str
      call_queue: list or deque
      replay_mode: bool
    """

    self._name = method_name
    self._call_queue = call_queue
    if not isinstance(call_queue, deque):
      self._call_queue = deque(self._call_queue)
    self._replay_mode = replay_mode

    self._params = None
    self._named_params = None
    self._return_value = None
    self._exception = None
    self._side_effects = None

  def __call__(self, *params, **named_params):
    """Log parameters and return the specified return value.

    If the Mock(Anything/Object) associated with this call is in record mode,
    this MockMethod will be pushed onto the expected call queue.  If the mock
    is in replay mode, this will pop a MockMethod off the top of the queue and
    verify this call is equal to the expected call.

    Returns:
      In record mode, this MockMethod (so expectations can be chained).  In
      replay mode, the return value configured on the matching expectation.

    Raises:
      UnexpectedMethodCallError if this call is supposed to match an expected
        method call and it does not.
      The exception configured with AndRaise() on the matching expectation.
    """

    self._params = params
    self._named_params = named_params

    if not self._replay_mode:
      self._call_queue.append(self)
      return self

    expected_method = self._VerifyMethodCall()

    # Test against None rather than truthiness: a configured callable or
    # exception object whose boolean value is false must still be honoured.
    if expected_method._side_effects is not None:
      expected_method._side_effects(*params, **named_params)

    if expected_method._exception is not None:
      raise expected_method._exception

    return expected_method._return_value

  def __getattr__(self, name):
    """Raise an AttributeError with a helpful message."""

    raise AttributeError('MockMethod has no attribute "%s". '
        'Did you remember to put your mocks in replay mode?' % name)

  def _PopNextMethod(self):
    """Pop the next method from our call queue.

    Raises:
      UnexpectedMethodCallError if the queue is empty (nothing was expected).
    """
    try:
      return self._call_queue.popleft()
    except IndexError:
      raise UnexpectedMethodCallError(self, None)

  def _VerifyMethodCall(self):
    """Verify the called method is expected.

    This can be an ordered method, or part of an unordered set.

    Returns:
      The expected mock method.

    Raises:
      UnexpectedMethodCallError if the method called was not expected.
    """

    expected = self._PopNextMethod()

    # Loop here, because we might have a MethodGroup followed by another
    # group.
    while isinstance(expected, MethodGroup):
      expected, method = expected.MethodCalled(self)
      if method is not None:
        return method

    # This is a mock method, so just check equality.
    if expected != self:
      raise UnexpectedMethodCallError(self, expected)

    return expected

  def __str__(self):
    params = ', '.join(
        [repr(p) for p in self._params or []] +
        ['%s=%r' % x for x in sorted((self._named_params or {}).items())])
    desc = "%s(%s) -> %r" % (self._name, params, self._return_value)
    return desc

  def __eq__(self, rhs):
    """Test whether this MockMethod is equivalent to another MockMethod.

    Two MockMethods are equal when name, positional and keyword parameters
    all compare equal (parameters may be Comparators).

    Args:
      # rhs: the right hand side of the test
      rhs: MockMethod
    """

    return (isinstance(rhs, MockMethod) and
            self._name == rhs._name and
            self._params == rhs._params and
            self._named_params == rhs._named_params)

  def __ne__(self, rhs):
    """Test whether this MockMethod is not equivalent to another MockMethod.

    Args:
      # rhs: the right hand side of the test
      rhs: MockMethod
    """

    return not self == rhs

  def GetPossibleGroup(self):
    """Returns a possible group from the end of the call queue or None if no
    other methods are on the stack.

    Side effect: removes this method (which must be the tail) from the queue.
    """

    # Remove this method from the tail of the queue so we can add it to a group.
    this_method = self._call_queue.pop()
    assert this_method == self

    # Determine if the tail of the queue is a group, or just a regular ordered
    # mock method.
    group = None
    try:
      group = self._call_queue[-1]
    except IndexError:
      pass

    return group

  def _CheckAndCreateNewGroup(self, group_name, group_class):
    """Checks if the last method (a possible group) is an instance of our
    group_class. Adds the current method to this group or creates a new one.

    Args:
      group_name: the name of the group.
      group_class: the class used to create instance of this new group

    Returns:
      self
    """
    group = self.GetPossibleGroup()

    # If this is a group, and it is the correct group, add the method.
    if isinstance(group, group_class) and group.group_name() == group_name:
      group.AddMethod(self)
      return self

    # Create a new group and add the method.
    new_group = group_class(group_name)
    new_group.AddMethod(self)
    self._call_queue.append(new_group)
    return self

  def InAnyOrder(self, group_name="default"):
    """Move this method into a group of unordered calls.

    A group of unordered calls must be defined together, and must be executed
    in full before the next expected method can be called.  There can be
    multiple groups that are expected serially, if they are given
    different group names.  The same group name can be reused if there is a
    standard method call, or a group with a different name, spliced between
    usages.

    Args:
      group_name: the name of the unordered group.

    Returns:
      self
    """
    return self._CheckAndCreateNewGroup(group_name, UnorderedGroup)

  def MultipleTimes(self, group_name="default"):
    """Move this method into group of calls which may be called multiple times.

    A group of repeating calls must be defined together, and must be executed in
    full before the next expected method can be called.

    Args:
      group_name: the name of the repeating group.

    Returns:
      self
    """
    return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup)

  def AndReturn(self, return_value):
    """Set the value to return when this method is called.

    Args:
      # return_value can be anything.

    Returns:
      return_value, unchanged.
    """

    self._return_value = return_value
    return return_value

  def AndRaise(self, exception):
    """Set the exception to raise when this method is called.

    Args:
      # exception: the exception to raise when this method is called.
      exception: Exception
    """

    self._exception = exception

  def WithSideEffects(self, side_effects):
    """Set the side effects that are simulated when this method is called.

    Args:
      side_effects: A callable which modifies the parameters or other relevant
        state which a given test case depends on.  It receives the same
        arguments as the mocked call.

    Returns:
      Self for chaining with AndReturn and AndRaise.
    """
    self._side_effects = side_effects
    return self
750
class Comparator:
  """Base class for all Mox comparators.

  A Comparator can be used as a parameter to a mocked method when the exact
  value is not known.  For example, the code you are testing might build up a
  long SQL string that is passed to your mock DAO. You're only interested that
  the IN clause contains the proper primary keys, so you can set your mock
  up as follows:

  mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)

  Now whatever query is passed in must contain the string 'IN (1, 2, 4, 5)'.

  A Comparator may replace one or more parameters, for example:
  # return at most 10 rows
  mock_dao.RunQuery(StrContains('SELECT'), 10)

  or

  # Return some non-deterministic number of rows
  mock_dao.RunQuery(StrContains('SELECT'), IsA(int))
  """

  def equals(self, rhs):
    """Special equals method that all comparators must implement.

    Args:
      rhs: any python object

    Raises:
      NotImplementedError: always; subclasses must override this method.
    """

    # Call form of raise: valid from Python 2.4 onward and in Python 3, unlike
    # the removed "raise E, msg" statement form.
    raise NotImplementedError('method must be implemented by a subclass.')

  def __eq__(self, rhs):
    # == delegates to the subclass's equals().
    return self.equals(rhs)

  def __ne__(self, rhs):
    return not self.equals(rhs)
788
789
class IsA(Comparator):
  """Matches any argument that is an instance of a given type or class.

  Example:
  mock_dao.Connect(IsA(DbConnectInfo))
  """

  def __init__(self, class_name):
    """Remember the type to match against.

    Args:
      class_name: basic python type or a class
    """

    self._class_name = class_name

  def equals(self, rhs):
    """Report whether rhs is an instance of the stored type.

    Args:
      # rhs: the right hand side of the test
      rhs: object

    Returns:
      bool
    """

    try:
      matched = isinstance(rhs, self._class_name)
    except TypeError:
      # isinstance() rejects a non-type second argument (e.g. a factory
      # function such as cStringIO.StringIO); compare raw types instead.
      matched = type(rhs) == type(self._class_name)
    return matched

  def __repr__(self):
    return str(self._class_name)
827
class IsAlmost(Comparator):
  """Matches a number that rounds to the expected value.

  Mainly intended for floating point arguments.

  Example mock_dao.SetTimeout((IsAlmost(3.9)))
  """

  def __init__(self, float_value, places=7):
    """Remember the target value and the rounding precision.

    Args:
      float_value: The value for making the comparison.
      places: The number of decimal places to round to.
    """

    self._float_value = float_value
    self._places = places

  def equals(self, rhs):
    """Report whether rhs minus float_value rounds to zero at self._places.

    Args:
      rhs: the value to compare to float_value

    Returns:
      bool; False when the subtraction or rounding is not possible.
    """

    try:
      difference = rhs - self._float_value
      return round(difference, self._places) == 0
    except TypeError:
      # Most likely rhs (or float_value) is not a number.
      return False

  def __repr__(self):
    return str(self._float_value)
864
class StrContains(Comparator):
  """Matches a string argument that contains a given substring.

  Handy when a mocked database receives SQL as a string parameter.

  Example:
  mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
  """

  def __init__(self, search_string):
    """Remember the substring to look for.

    Args:
      # search_string: the string you are searching for
      search_string: str
    """

    self._search_string = search_string

  def equals(self, rhs):
    """Report whether rhs.find() locates the search string.

    Args:
      # rhs: the right hand side of the test
      rhs: object

    Returns:
      bool; False if rhs does not support find() or the lookup fails.
    """

    try:
      position = rhs.find(self._search_string)
      found = position > -1
    except Exception:
      found = False
    return found

  def __repr__(self):
    return '<str containing \'%s\'>' % self._search_string
902
903
class Regex(Comparator):
  """Checks if a string matches a regular expression.

  The comparator equals any string in which the pattern can be found
  (re.search semantics, so the match need not start at position 0).
  """

  def __init__(self, pattern, flags=0):
    """Initialize.

    Args:
      # pattern is the regular expression to search for
      pattern: str
      # flags passed to re.compile function as the second argument
      flags: int
    """

    self.regex = re.compile(pattern, flags=flags)

  def equals(self, rhs):
    """Check to see if rhs matches regular expression pattern.

    Returns:
      bool; False when rhs is not a string (e.g. None or a number).
    """

    try:
      return self.regex.search(rhs) is not None
    except TypeError:
      # re raises TypeError for non-string input.  Report a mismatch, as the
      # other comparators do, so the mocked call fails verification with
      # UnexpectedMethodCallError instead of crashing inside __eq__.
      return False

  def __repr__(self):
    s = '<regular expression \'%s\'' % self.regex.pattern
    if self.regex.flags:
      s += ', flags=%d' % self.regex.flags
    s += '>'
    return s
937
938
class In(Comparator):
  """Checks whether an item (or key) is in a list (or dict) parameter.

  Example:
  mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result)
  """

  def __init__(self, key):
    """Initialize.

    Args:
      # key is anything that could be in a list or a key in a dict
    """

    self._key = key

  def equals(self, rhs):
    """Check to see whether key is in rhs.

    Args:
      rhs: a sequence or dict

    Returns:
      bool; False when rhs does not support the "in" test (e.g. None or an
      int), or when the key is unhashable and rhs is a dict.
    """

    try:
      return self._key in rhs
    except TypeError:
      # rhs is not a container (or cannot hash the key); it therefore cannot
      # contain the key, so report a mismatch rather than crashing __eq__.
      return False

  def __repr__(self):
    # Wrap the key in a 1-tuple so that a tuple key is formatted as a single
    # value instead of being unpacked into several format arguments.
    return '<sequence or map containing \'%s\'>' % (self._key,)
969
970
class ContainsKeyValue(Comparator):
  """Matches a mapping argument that holds a specific key/value pair.

  Example:
  mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info))
  """

  def __init__(self, key, value):
    """Remember the pair to look for.

    Args:
      # key: a key in a dict
      # value: the corresponding value
    """

    self._key = key
    self._value = value

  def equals(self, rhs):
    """Report whether rhs[key] compares equal to value.

    Returns:
      bool; False if the lookup or the comparison raises.
    """

    try:
      stored = rhs[self._key]
      matched = stored == self._value
    except Exception:
      matched = False
    return matched

  def __repr__(self):
    return '<map containing the entry \'%s: %s\'>' % (self._key, self._value)
1003
1004
class SameElementsAs(Comparator):
  """Checks whether iterables contain the same elements (ignoring order).

  Example:
  mock_dao.ProcessUsers(SameElementsAs(['stevepm', 'salomaki']))
  """

  def __init__(self, expected_seq):
    """Initialize.

    Args:
      expected_seq: a sequence
    """

    self._expected_seq = expected_seq

  def equals(self, actual_seq):
    """Check to see whether actual_seq has same elements as expected_seq.

    When every element is hashable the elements are compared as dict keys,
    so duplicates are ignored.  Otherwise both sequences are sorted and
    compared as lists, so duplicate counts matter.

    Args:
      actual_seq: sequence

    Returns:
      bool; False if actual_seq is not iterable.
    """

    try:
      expected = dict([(element, None) for element in self._expected_seq])
      actual = dict([(element, None) for element in actual_seq])
    except TypeError:
      # Fall back to slower list-compare if any of the objects are unhashable.
      try:
        expected = sorted(self._expected_seq)
        actual = sorted(actual_seq)
      except TypeError:
        # actual_seq is not iterable (or its elements cannot be ordered), so
        # it cannot hold the same elements; report a mismatch instead of
        # letting the TypeError escape from the comparison.
        return False
    return expected == actual

  def __repr__(self):
    # Wrap in a 1-tuple so a tuple expected_seq is formatted as one value
    # instead of being unpacked into several format arguments.
    return '<sequence with same elements as \'%s\'>' % (self._expected_seq,)
1044
1045
class And(Comparator):
  """Matches only when every wrapped Comparator matches.
  """

  def __init__(self, *args):
    """Store the comparators to combine.

    Args:
      *args: One or more Comparator
    """

    self._comparators = args

  def equals(self, rhs):
    """Evaluate the comparators in order, stopping at the first mismatch.

    Args:
      # rhs: can be anything

    Returns:
      bool
    """

    for comparator in self._comparators:
      if not comparator.equals(rhs):
        break
    else:
      return True
    return False

  def __repr__(self):
    return '<AND %s>' % (self._comparators,)
1077
1078
class Or(Comparator):
  """Matches rhs when at least one wrapped Comparator matches it.
  """

  def __init__(self, *args):
    """Remember the Comparators to try.

    Args:
      *args: One or more Mox comparators, consulted in the given order.
    """

    self._comparators = args

  def equals(self, rhs):
    """Logical OR of the wrapped Comparators applied to rhs.

    Comparators are consulted in order. Evaluation stops at the first one
    that accepts rhs, so later Comparators may not be called at all.

    Args:
      rhs: the actual parameter; any object.

    Returns:
      bool. True if some Comparator's equals() is truthy, otherwise False
      (including when there are no Comparators).
    """

    matched = False
    for comparator in self._comparators:
      if comparator.equals(rhs):
        matched = True
        break
    return matched

  def __repr__(self):
    return '<OR %s>' % (self._comparators,)
1110
1111
class Func(Comparator):
  """Call a function that should verify the parameter passed in is correct.

  You may need the ability to perform more advanced operations on the parameter
  in order to validate it.  You can use this to have a callable validate any
  parameter. The callable should return either True or False.

  Example:

  def myParamValidator(param):
    # Advanced logic here
    return True

  mock_dao.DoSomething(Func(myParamValidator), True)
  """

  def __init__(self, func):
    """Initialize.

    Args:
      func: callable that takes one parameter and returns a bool
    """

    self._func = func

  def equals(self, rhs):
    """Test whether rhs passes the function test.

    rhs is passed into func.

    Args:
      rhs: any python object

    Returns:
      the result of func(rhs), unchanged (it is not coerced to bool).
    """

    return self._func(rhs)

  def __repr__(self):
    return str(self._func)
1154
1155
class IgnoreArg(Comparator):
  """Wildcard Comparator that accepts any value in an argument position.

  Use it when a particular argument of a method call is irrelevant to the test.

  Example:
  # Check if CastMagic is called with 3 as first arg and 'disappear' as third.
  mymock.CastMagic(3, IgnoreArg(), 'disappear')
  """

  def equals(self, unused_rhs):
    """Accept any argument.

    Args:
      unused_rhs: the actual parameter; never inspected.

    Returns:
      True, unconditionally.
    """

    matches_anything = True
    return matches_anything

  def __repr__(self):
    return '<IgnoreArg>'
1180
1181
class MethodGroup(object):
  """Abstract base for named groups of expected mock method calls.

  Subclasses decide how calls are added and matched, and when the group
  counts as complete; the hooks below must be overridden.
  """

  def __init__(self, group_name):
    """Args:
      group_name: label identifying this group.
    """
    self._group_name = group_name

  def group_name(self):
    """Return the label this group was created with."""
    return self._group_name

  def __str__(self):
    class_name = type(self).__name__
    return '<%s "%s">' % (class_name, self._group_name)

  def AddMethod(self, mock_method):
    """Hook: register an expected call. Must be overridden."""
    raise NotImplementedError

  def MethodCalled(self, mock_method):
    """Hook: handle an actual call. Must be overridden."""
    raise NotImplementedError

  def IsSatisfied(self):
    """Hook: report whether the group is complete. Must be overridden."""
    raise NotImplementedError
1202
class UnorderedGroup(MethodGroup):
  """UnorderedGroup holds a set of method calls that may occur in any order.

  This construct is helpful for non-deterministic events, such as iterating
  over the keys of a dict. Each expected call is consumed once it is matched;
  the group is satisfied when no expected calls remain.
  """

  def __init__(self, group_name):
    """Create an empty group.

    Args:
      group_name: label identifying this group.
    """
    super(UnorderedGroup, self).__init__(group_name)
    # Expected calls not yet matched. A list, so duplicates are kept and each
    # occurrence must be called once.
    self._methods = []

  def AddMethod(self, mock_method):
    """Add a method to this group.

    Args:
      mock_method: A mock method to be added to this group.
    """

    self._methods.append(mock_method)

  def MethodCalled(self, mock_method):
    """Remove a method call from the group.

    If the method is not in the set, an UnexpectedMethodCallError will be
    raised.

    Args:
      mock_method: a mock method that should be equal to a method in the group.

    Returns:
      A (group, method) tuple: this group and the expected method from the
      group that matched mock_method.

    Raises:
      UnexpectedMethodCallError if the mock_method was not in the group.
    """

    # Check to see if this method exists, and if so, remove it from the set
    # and return it.
    for method in self._methods:
      if method == mock_method:
        # Remove the called mock_method instead of the method in the group.
        # The called method will match any comparators when equality is checked
        # during removal.  The method in the group could pass a comparator to
        # another comparator during the equality check.
        # Returning right after the removal means the list is never iterated
        # again after being modified.
        self._methods.remove(mock_method)

        # If this group is not empty, put it back at the head of the queue.
        # NOTE(review): _call_queue is presumably a collections.deque owned by
        # the mock (appendleft is used); confirm against MockMethod.
        if not self.IsSatisfied():
          mock_method._call_queue.appendleft(self)

        return self, method

    raise UnexpectedMethodCallError(mock_method, self)

  def IsSatisfied(self):
    """Return True if there are not any methods in this group."""

    return len(self._methods) == 0
1261
1262
class MultipleTimesGroup(MethodGroup):
  """MultipleTimesGroup holds methods that may be called any number of times.

  Note: Each method must be called at least once.

  This is helpful, if you don't know or care how many times a method is called.
  """

  def __init__(self, group_name):
    """Create an empty group.

    Args:
      group_name: label identifying this group.
    """
    super(MultipleTimesGroup, self).__init__(group_name)
    # Expected methods belonging to this group.
    self._methods = set()
    # Actual calls that matched one of the expected methods so far.
    self._methods_called = set()

  def AddMethod(self, mock_method):
    """Add a method to this group.

    Args:
      mock_method: A mock method to be added to this group.
    """

    self._methods.add(mock_method)

  def MethodCalled(self, mock_method):
    """Record a call against this group, or pass it on once satisfied.

    If mock_method matches an expected method of this group, the call is
    recorded and the group is put back at the head of the call queue. If it
    does not match and the group is already satisfied, the next queued item
    is popped for the caller to try instead.

    Args:
      mock_method: a mock method that should be equal to a method in the group.

    Returns:
      A tuple. It is (self, matching expected method) when mock_method belongs
      to this group. Otherwise it is (whatever mock_method._PopNextMethod()
      returns, None).

    Raises:
      UnexpectedMethodCallError if mock_method does not belong to this group
      and the group is not yet satisfied.
    """

    # Check to see if this method exists, and if so add it to the set of
    # called methods.
    for method in self._methods:
      if method == mock_method:
        self._methods_called.add(mock_method)
        # Always put this group back on top of the queue, because we don't know
        # when we are done.
        mock_method._call_queue.appendleft(self)
        return self, method

    if self.IsSatisfied():
      next_method = mock_method._PopNextMethod()
      return next_method, None
    else:
      raise UnexpectedMethodCallError(mock_method, self)

  def IsSatisfied(self):
    """Return True if all methods in this group are called at least once.

    An empty group is trivially satisfied.
    """
    # NOTE(psycho): We can't use the simple set difference here because we want
    # to match different parameters which are considered the same e.g. IsA(str)
    # and some string. This solution is O(n^2) but n should be small.
    remaining = self._methods.copy()
    for called in self._methods_called:
      if not remaining:
        break
      for expected in remaining:
        if called == expected:
          # Safe to mutate the set here: we break out of its iteration at once.
          remaining.remove(expected)
          break
    return not remaining
1332
1333
class MoxMetaTestBase(type):
  """Metaclass to add mox cleanup and verification to every test.

  As the mox unit testing class is being constructed (MoxTestBase or a
  subclass), this metaclass will modify all test functions to call the
  CleanUpMox method of the test class after they finish. This means that
  unstubbing and verifying will happen for every test with no additional code,
  and any failures will result in test failures as opposed to errors.
  """

  def __init__(cls, name, bases, d):
    type.__init__(cls, name, bases, d)

    # Also collect test attributes from the base classes, to account for a
    # test class that is not the immediate child of MoxTestBase. Work on a
    # copy so the caller's namespace dict is left untouched.
    attrs = dict(d)
    for base in bases:
      for attr_name in dir(base):
        # Names the class defines itself (or an earlier base supplied) take
        # precedence. Overwriting them would silently replace an overridden
        # test with the base class's version.
        if attr_name.startswith('test') and attr_name not in attrs:
          attrs[attr_name] = getattr(base, attr_name)

    for func_name, func in attrs.items():
      if func_name.startswith('test') and callable(func):
        setattr(cls, func_name, MoxMetaTestBase.CleanUpTest(cls, func))

  @staticmethod
  def CleanUpTest(cls, func):
    """Adds Mox cleanup code to any MoxTestBase method.

    Always unsets stubs after a test. Will verify all mocks for tests that
    otherwise pass.

    Args:
      cls: MoxTestBase or subclass; the class whose test method we are altering.
      func: method; the method of the MoxTestBase test class we wish to alter.

    Returns:
      The modified method.
    """
    def new_method(self, *args, **kwargs):
      # Only clean up when the test instance actually has a Mox in self.mox.
      mox_obj = getattr(self, 'mox', None)
      cleanup_mox = False
      if mox_obj and isinstance(mox_obj, Mox):
        cleanup_mox = True
      try:
        func(self, *args, **kwargs)
      finally:
        if cleanup_mox:
          mox_obj.UnsetStubs()
      # Reached only if the test did not raise, so a failing test reports its
      # own error rather than a verification error.
      if cleanup_mox:
        mox_obj.VerifyAll()
    new_method.__name__ = func.__name__
    new_method.__doc__ = func.__doc__
    new_method.__module__ = func.__module__
    return new_method
1387
1388
class MoxTestBase(unittest.TestCase):
  """Convenience test class to make stubbing easier.

  Sets up a "mox" attribute which is an instance of Mox - any mox tests will
  want this. Also automatically unsets any stubs and verifies that all mock
  methods have been called at the end of each test, eliminating boilerplate
  code.

  Subclasses that override setUp should call this setUp as well. The per-test
  cleanup only runs when self.mox is a Mox instance.
  """

  # Wraps every test* method with stub cleanup and mock verification.
  # NOTE(review): the __metaclass__ attribute is only honored by Python 2;
  # under Python 3 it has no effect.
  __metaclass__ = MoxMetaTestBase

  def setUp(self):
    """Give each test a fresh Mox instance as self.mox."""
    self.mox = Mox()
1402