1# Copyright 2008 Google Inc. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# 15# This is a fork of the pymox library intended to work with Python 3. 16# The file was modified by quermit@gmail.com and dawid.fatyga@gmail.com 17 18"""Mox, an object-mocking framework for Python. 19 20Mox works in the record-replay-verify paradigm. When you first create 21a mock object, it is in record mode. You then programmatically set 22the expected behavior of the mock object (what methods are to be 23called on it, with what parameters, what they should return, and in 24what order). 25 26Once you have set up the expected mock behavior, you put it in replay 27mode. Now the mock responds to method calls just as you told it to. 28If an unexpected method (or an expected method with unexpected 29parameters) is called, then an exception will be raised. 30 31Once you are done interacting with the mock, you need to verify that 32all the expected interactions occured. (Maybe your code exited 33prematurely without calling some cleanup method!) The verify phase 34ensures that every expected method was called; otherwise, an exception 35will be raised. 36 37WARNING! Mock objects created by Mox are not thread-safe. If you are 38call a mock in multiple threads, it should be guarded by a mutex. 39 40TODO(stevepm): Add the option to make mocks thread-safe! 41 42Suggested usage / workflow: 43 44 # Create Mox factory 45 my_mox = Mox() 46 47 # Create a mock data access object 48 mock_dao = my_mox.CreateMock(DAOClass) 49 50 # Set up expected behavior 51 mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person) 52 mock_dao.DeletePerson(person) 53 54 # Put mocks in replay mode 55 my_mox.ReplayAll() 56 57 # Inject mock object and run test 58 controller.SetDao(mock_dao) 59 controller.DeletePersonById('1') 60 61 # Verify all methods were called as expected 62 my_mox.VerifyAll() 63""" 64 65import collections 66import difflib 67import inspect 68import re 69import types 70import unittest 71 72from mox3 import stubout 73 74 75class Error(AssertionError): 76 """Base exception for this module.""" 77 78 pass 79 80 81class ExpectedMethodCallsError(Error): 82 """Raised when an expected method wasn't called. 83 84 This can occur if Verify() is called before all expected methods have been 85 called. 86 """ 87 88 def __init__(self, expected_methods): 89 """Init exception. 90 91 Args: 92 # expected_methods: A sequence of MockMethod objects that should 93 # have been called. 94 expected_methods: [MockMethod] 95 96 Raises: 97 ValueError: if expected_methods contains no methods. 98 """ 99 100 if not expected_methods: 101 raise ValueError("There must be at least one expected method") 102 Error.__init__(self) 103 self._expected_methods = expected_methods 104 105 def __str__(self): 106 calls = "\n".join(["%3d. %s" % (i, m) 107 for i, m in enumerate(self._expected_methods)]) 108 return "Verify: Expected methods never called:\n%s" % (calls,) 109 110 111class UnexpectedMethodCallError(Error): 112 """Raised when an unexpected method is called. 113 114 This can occur if a method is called with incorrect parameters, or out of 115 the specified order. 116 """ 117 118 def __init__(self, unexpected_method, expected): 119 """Init exception. 120 121 Args: 122 # unexpected_method: MockMethod that was called but was not at the 123 # head of the expected_method queue. 124 # expected: MockMethod or UnorderedGroup the method should have 125 # been in. 126 unexpected_method: MockMethod 127 expected: MockMethod or UnorderedGroup 128 """ 129 130 Error.__init__(self) 131 if expected is None: 132 self._str = "Unexpected method call %s" % (unexpected_method,) 133 else: 134 differ = difflib.Differ() 135 diff = differ.compare(str(unexpected_method).splitlines(True), 136 str(expected).splitlines(True)) 137 self._str = ("Unexpected method call." 138 " unexpected:- expected:+\n%s" 139 % ("\n".join(line.rstrip() for line in diff),)) 140 141 def __str__(self): 142 return self._str 143 144 145class UnknownMethodCallError(Error): 146 """Raised if an unknown method is requested of the mock object.""" 147 148 def __init__(self, unknown_method_name): 149 """Init exception. 150 151 Args: 152 # unknown_method_name: Method call that is not part of the mocked 153 # class's public interface. 154 unknown_method_name: str 155 """ 156 157 Error.__init__(self) 158 self._unknown_method_name = unknown_method_name 159 160 def __str__(self): 161 return ("Method called is not a member of the object: %s" % 162 self._unknown_method_name) 163 164 165class PrivateAttributeError(Error): 166 """Raised if a MockObject is passed a private additional attribute name.""" 167 168 def __init__(self, attr): 169 Error.__init__(self) 170 self._attr = attr 171 172 def __str__(self): 173 return ("Attribute '%s' is private and should not be available" 174 "in a mock object." % self._attr) 175 176 177class ExpectedMockCreationError(Error): 178 """Raised if mocks should have been created by StubOutClassWithMocks.""" 179 180 def __init__(self, expected_mocks): 181 """Init exception. 182 183 Args: 184 # expected_mocks: A sequence of MockObjects that should have been 185 # created 186 187 Raises: 188 ValueError: if expected_mocks contains no methods. 189 """ 190 191 if not expected_mocks: 192 raise ValueError("There must be at least one expected method") 193 Error.__init__(self) 194 self._expected_mocks = expected_mocks 195 196 def __str__(self): 197 mocks = "\n".join(["%3d. %s" % (i, m) 198 for i, m in enumerate(self._expected_mocks)]) 199 return "Verify: Expected mocks never created:\n%s" % (mocks,) 200 201 202class UnexpectedMockCreationError(Error): 203 """Raised if too many mocks were created by StubOutClassWithMocks.""" 204 205 def __init__(self, instance, *params, **named_params): 206 """Init exception. 207 208 Args: 209 # instance: the type of obejct that was created 210 # params: parameters given during instantiation 211 # named_params: named parameters given during instantiation 212 """ 213 214 Error.__init__(self) 215 self._instance = instance 216 self._params = params 217 self._named_params = named_params 218 219 def __str__(self): 220 args = ", ".join(["%s" % v for i, v in enumerate(self._params)]) 221 error = "Unexpected mock creation: %s(%s" % (self._instance, args) 222 223 if self._named_params: 224 error += ", " + ", ".join(["%s=%s" % (k, v) for k, v in 225 self._named_params.items()]) 226 227 error += ")" 228 return error 229 230 231class Mox(object): 232 """Mox: a factory for creating mock objects.""" 233 234 # A list of types that should be stubbed out with MockObjects (as 235 # opposed to MockAnythings). 236 _USE_MOCK_OBJECT = [types.FunctionType, types.ModuleType, types.MethodType] 237 238 def __init__(self): 239 """Initialize a new Mox.""" 240 241 self._mock_objects = [] 242 self.stubs = stubout.StubOutForTesting() 243 244 def CreateMock(self, class_to_mock, attrs=None, bounded_to=None): 245 """Create a new mock object. 246 247 Args: 248 # class_to_mock: the class to be mocked 249 class_to_mock: class 250 attrs: dict of attribute names to values that will be 251 set on the mock object. Only public attributes may be set. 252 bounded_to: optionally, when class_to_mock is not a class, 253 it points to a real class object, to which 254 attribute is bound 255 256 Returns: 257 MockObject that can be used as the class_to_mock would be. 258 """ 259 if attrs is None: 260 attrs = {} 261 new_mock = MockObject(class_to_mock, attrs=attrs, 262 class_to_bind=bounded_to) 263 self._mock_objects.append(new_mock) 264 return new_mock 265 266 def CreateMockAnything(self, description=None): 267 """Create a mock that will accept any method calls. 268 269 This does not enforce an interface. 270 271 Args: 272 description: str. Optionally, a descriptive name for the mock object 273 being created, for debugging output purposes. 274 """ 275 new_mock = MockAnything(description=description) 276 self._mock_objects.append(new_mock) 277 return new_mock 278 279 def ReplayAll(self): 280 """Set all mock objects to replay mode.""" 281 282 for mock_obj in self._mock_objects: 283 mock_obj._Replay() 284 285 def VerifyAll(self): 286 """Call verify on all mock objects created.""" 287 288 for mock_obj in self._mock_objects: 289 mock_obj._Verify() 290 291 def ResetAll(self): 292 """Call reset on all mock objects. This does not unset stubs.""" 293 294 for mock_obj in self._mock_objects: 295 mock_obj._Reset() 296 297 def StubOutWithMock(self, obj, attr_name, use_mock_anything=False): 298 """Replace a method, attribute, etc. with a Mock. 299 300 This will replace a class or module with a MockObject, and everything 301 else (method, function, etc) with a MockAnything. This can be 302 overridden to always use a MockAnything by setting use_mock_anything 303 to True. 304 305 Args: 306 obj: A Python object (class, module, instance, callable). 307 attr_name: str. The name of the attribute to replace with a mock. 308 use_mock_anything: bool. True if a MockAnything should be used 309 regardless of the type of attribute. 310 """ 311 312 if inspect.isclass(obj): 313 class_to_bind = obj 314 else: 315 class_to_bind = None 316 317 attr_to_replace = getattr(obj, attr_name) 318 attr_type = type(attr_to_replace) 319 320 if attr_type == MockAnything or attr_type == MockObject: 321 raise TypeError('Cannot mock a MockAnything! Did you remember to ' 322 'call UnsetStubs in your previous test?') 323 324 type_check = ( 325 attr_type in self._USE_MOCK_OBJECT or 326 inspect.isclass(attr_to_replace) or 327 isinstance(attr_to_replace, object)) 328 if type_check and not use_mock_anything: 329 stub = self.CreateMock(attr_to_replace, bounded_to=class_to_bind) 330 else: 331 stub = self.CreateMockAnything( 332 description='Stub for %s' % attr_to_replace) 333 stub.__name__ = attr_name 334 335 self.stubs.Set(obj, attr_name, stub) 336 337 def StubOutClassWithMocks(self, obj, attr_name): 338 """Replace a class with a "mock factory" that will create mock objects. 339 340 This is useful if the code-under-test directly instantiates 341 dependencies. Previously some boilder plate was necessary to 342 create a mock that would act as a factory. Using 343 StubOutClassWithMocks, once you've stubbed out the class you may 344 use the stubbed class as you would any other mock created by mox: 345 during the record phase, new mock instances will be created, and 346 during replay, the recorded mocks will be returned. 347 348 In replay mode 349 350 # Example using StubOutWithMock (the old, clunky way): 351 352 mock1 = mox.CreateMock(my_import.FooClass) 353 mock2 = mox.CreateMock(my_import.FooClass) 354 foo_factory = mox.StubOutWithMock(my_import, 'FooClass', 355 use_mock_anything=True) 356 foo_factory(1, 2).AndReturn(mock1) 357 foo_factory(9, 10).AndReturn(mock2) 358 mox.ReplayAll() 359 360 my_import.FooClass(1, 2) # Returns mock1 again. 361 my_import.FooClass(9, 10) # Returns mock2 again. 362 mox.VerifyAll() 363 364 # Example using StubOutClassWithMocks: 365 366 mox.StubOutClassWithMocks(my_import, 'FooClass') 367 mock1 = my_import.FooClass(1, 2) # Returns a new mock of FooClass 368 mock2 = my_import.FooClass(9, 10) # Returns another mock instance 369 mox.ReplayAll() 370 371 my_import.FooClass(1, 2) # Returns mock1 again. 372 my_import.FooClass(9, 10) # Returns mock2 again. 373 mox.VerifyAll() 374 """ 375 attr_to_replace = getattr(obj, attr_name) 376 attr_type = type(attr_to_replace) 377 378 if attr_type == MockAnything or attr_type == MockObject: 379 raise TypeError('Cannot mock a MockAnything! Did you remember to ' 380 'call UnsetStubs in your previous test?') 381 382 if not inspect.isclass(attr_to_replace): 383 raise TypeError('Given attr is not a Class. Use StubOutWithMock.') 384 385 factory = _MockObjectFactory(attr_to_replace, self) 386 self._mock_objects.append(factory) 387 self.stubs.Set(obj, attr_name, factory) 388 389 def UnsetStubs(self): 390 """Restore stubs to their original state.""" 391 392 self.stubs.UnsetAll() 393 394 395def Replay(*args): 396 """Put mocks into Replay mode. 397 398 Args: 399 # args is any number of mocks to put into replay mode. 400 """ 401 402 for mock in args: 403 mock._Replay() 404 405 406def Verify(*args): 407 """Verify mocks. 408 409 Args: 410 # args is any number of mocks to be verified. 411 """ 412 413 for mock in args: 414 mock._Verify() 415 416 417def Reset(*args): 418 """Reset mocks. 419 420 Args: 421 # args is any number of mocks to be reset. 422 """ 423 424 for mock in args: 425 mock._Reset() 426 427 428class MockAnything(object): 429 """A mock that can be used to mock anything. 430 431 This is helpful for mocking classes that do not provide a public interface. 432 """ 433 434 def __init__(self, description=None): 435 """Initialize a new MockAnything. 436 437 Args: 438 description: str. Optionally, a descriptive name for the mock 439 object being created, for debugging output purposes. 440 """ 441 self._description = description 442 self._Reset() 443 444 def __repr__(self): 445 if self._description: 446 return '<MockAnything instance of %s>' % self._description 447 else: 448 return '<MockAnything instance>' 449 450 def __getattr__(self, method_name): 451 """Intercept method calls on this object. 452 453 A new MockMethod is returned that is aware of the MockAnything's 454 state (record or replay). The call will be recorded or replayed 455 by the MockMethod's __call__. 456 457 Args: 458 # method name: the name of the method being called. 459 method_name: str 460 461 Returns: 462 A new MockMethod aware of MockAnything's state (record or replay). 463 """ 464 if method_name == '__dir__': 465 return self.__class__.__dir__.__get__(self, self.__class__) 466 467 return self._CreateMockMethod(method_name) 468 469 def __str__(self): 470 return self._CreateMockMethod('__str__')() 471 472 def __call__(self, *args, **kwargs): 473 return self._CreateMockMethod('__call__')(*args, **kwargs) 474 475 def __getitem__(self, i): 476 return self._CreateMockMethod('__getitem__')(i) 477 478 def _CreateMockMethod(self, method_name, method_to_mock=None, 479 class_to_bind=object): 480 """Create a new mock method call and return it. 481 482 Args: 483 # method_name: the name of the method being called. 484 # method_to_mock: The actual method being mocked, used for 485 # introspection. 486 # class_to_bind: Class to which method is bounded 487 # (object by default) 488 method_name: str 489 method_to_mock: a method object 490 491 Returns: 492 A new MockMethod aware of MockAnything's state (record or replay). 493 """ 494 495 return MockMethod(method_name, self._expected_calls_queue, 496 self._replay_mode, method_to_mock=method_to_mock, 497 description=self._description, 498 class_to_bind=class_to_bind) 499 500 def __nonzero__(self): 501 """Return 1 for nonzero so the mock can be used as a conditional.""" 502 503 return 1 504 505 def __bool__(self): 506 """Return True for nonzero so the mock can be used as a conditional.""" 507 return True 508 509 def __eq__(self, rhs): 510 """Provide custom logic to compare objects.""" 511 512 return (isinstance(rhs, MockAnything) and 513 self._replay_mode == rhs._replay_mode and 514 self._expected_calls_queue == rhs._expected_calls_queue) 515 516 def __ne__(self, rhs): 517 """Provide custom logic to compare objects.""" 518 519 return not self == rhs 520 521 def _Replay(self): 522 """Start replaying expected method calls.""" 523 524 self._replay_mode = True 525 526 def _Verify(self): 527 """Verify that all of the expected calls have been made. 528 529 Raises: 530 ExpectedMethodCallsError: if there are still more method calls in 531 the expected queue. 532 """ 533 534 # If the list of expected calls is not empty, raise an exception 535 if self._expected_calls_queue: 536 # The last MultipleTimesGroup is not popped from the queue. 537 if (len(self._expected_calls_queue) == 1 and 538 isinstance(self._expected_calls_queue[0], 539 MultipleTimesGroup) and 540 self._expected_calls_queue[0].IsSatisfied()): 541 pass 542 else: 543 raise ExpectedMethodCallsError(self._expected_calls_queue) 544 545 def _Reset(self): 546 """Reset the state of this mock to record mode with an empty queue.""" 547 548 # Maintain a list of method calls we are expecting 549 self._expected_calls_queue = collections.deque() 550 551 # Make sure we are in setup mode, not replay mode 552 self._replay_mode = False 553 554 555class MockObject(MockAnything): 556 """Mock object that simulates the public/protected interface of a class.""" 557 558 def __init__(self, class_to_mock, attrs=None, class_to_bind=None): 559 """Initialize a mock object. 560 561 Determines the methods and properties of the class and stores them. 562 563 Args: 564 # class_to_mock: class to be mocked 565 class_to_mock: class 566 attrs: dict of attribute names to values that will be set on the 567 mock object. Only public attributes may be set. 568 class_to_bind: optionally, when class_to_mock is not a class at 569 all, it points to a real class 570 571 Raises: 572 PrivateAttributeError: if a supplied attribute is not public. 573 ValueError: if an attribute would mask an existing method. 574 """ 575 if attrs is None: 576 attrs = {} 577 578 # Used to hack around the mixin/inheritance of MockAnything, which 579 # is not a proper object (it can be anything. :-) 580 MockAnything.__dict__['__init__'](self) 581 582 # Get a list of all the public and special methods we should mock. 583 self._known_methods = set() 584 self._known_vars = set() 585 self._class_to_mock = class_to_mock 586 587 if inspect.isclass(class_to_mock): 588 self._class_to_bind = self._class_to_mock 589 else: 590 self._class_to_bind = class_to_bind 591 592 try: 593 if inspect.isclass(self._class_to_mock): 594 self._description = class_to_mock.__name__ 595 else: 596 self._description = type(class_to_mock).__name__ 597 except Exception: 598 pass 599 600 for method in dir(class_to_mock): 601 attr = getattr(class_to_mock, method) 602 if callable(attr): 603 self._known_methods.add(method) 604 elif not (type(attr) is property): 605 # treating properties as class vars makes little sense. 606 self._known_vars.add(method) 607 608 # Set additional attributes at instantiation time; this is quicker 609 # than manually setting attributes that are normally created in 610 # __init__. 611 for attr, value in attrs.items(): 612 if attr.startswith("_"): 613 raise PrivateAttributeError(attr) 614 elif attr in self._known_methods: 615 raise ValueError("'%s' is a method of '%s' objects." % (attr, 616 class_to_mock)) 617 else: 618 setattr(self, attr, value) 619 620 def _CreateMockMethod(self, *args, **kwargs): 621 """Overridden to provide self._class_to_mock to class_to_bind.""" 622 kwargs.setdefault("class_to_bind", self._class_to_bind) 623 return super(MockObject, self)._CreateMockMethod(*args, **kwargs) 624 625 def __getattr__(self, name): 626 """Intercept attribute request on this object. 627 628 If the attribute is a public class variable, it will be returned and 629 not recorded as a call. 630 631 If the attribute is not a variable, it is handled like a method 632 call. The method name is checked against the set of mockable 633 methods, and a new MockMethod is returned that is aware of the 634 MockObject's state (record or replay). The call will be recorded 635 or replayed by the MockMethod's __call__. 636 637 Args: 638 # name: the name of the attribute being requested. 639 name: str 640 641 Returns: 642 Either a class variable or a new MockMethod that is aware of the 643 state of the mock (record or replay). 644 645 Raises: 646 UnknownMethodCallError if the MockObject does not mock the 647 requested method. 648 """ 649 650 if name in self._known_vars: 651 return getattr(self._class_to_mock, name) 652 653 if name in self._known_methods: 654 return self._CreateMockMethod( 655 name, 656 method_to_mock=getattr(self._class_to_mock, name)) 657 658 raise UnknownMethodCallError(name) 659 660 def __eq__(self, rhs): 661 """Provide custom logic to compare objects.""" 662 663 return (isinstance(rhs, MockObject) and 664 self._class_to_mock == rhs._class_to_mock and 665 self._replay_mode == rhs._replay_mode and 666 self._expected_calls_queue == rhs._expected_calls_queue) 667 668 def __setitem__(self, key, value): 669 """Custom logic for mocking classes that support item assignment. 670 671 Args: 672 key: Key to set the value for. 673 value: Value to set. 674 675 Returns: 676 Expected return value in replay mode. A MockMethod object for the 677 __setitem__ method that has already been called if not in replay 678 mode. 679 680 Raises: 681 TypeError if the underlying class does not support item assignment. 682 UnexpectedMethodCallError if the object does not expect the call to 683 __setitem__. 684 685 """ 686 # Verify the class supports item assignment. 687 if '__setitem__' not in dir(self._class_to_mock): 688 raise TypeError('object does not support item assignment') 689 690 # If we are in replay mode then simply call the mock __setitem__ method 691 if self._replay_mode: 692 return MockMethod('__setitem__', self._expected_calls_queue, 693 self._replay_mode)(key, value) 694 695 # Otherwise, create a mock method __setitem__. 696 return self._CreateMockMethod('__setitem__')(key, value) 697 698 def __getitem__(self, key): 699 """Provide custom logic for mocking classes that are subscriptable. 700 701 Args: 702 key: Key to return the value for. 703 704 Returns: 705 Expected return value in replay mode. A MockMethod object for the 706 __getitem__ method that has already been called if not in replay 707 mode. 708 709 Raises: 710 TypeError if the underlying class is not subscriptable. 711 UnexpectedMethodCallError if the object does not expect the call to 712 __getitem__. 713 714 """ 715 # Verify the class supports item assignment. 716 if '__getitem__' not in dir(self._class_to_mock): 717 raise TypeError('unsubscriptable object') 718 719 # If we are in replay mode then simply call the mock __getitem__ method 720 if self._replay_mode: 721 return MockMethod('__getitem__', self._expected_calls_queue, 722 self._replay_mode)(key) 723 724 # Otherwise, create a mock method __getitem__. 725 return self._CreateMockMethod('__getitem__')(key) 726 727 def __iter__(self): 728 """Provide custom logic for mocking classes that are iterable. 729 730 Returns: 731 Expected return value in replay mode. A MockMethod object for the 732 __iter__ method that has already been called if not in replay mode. 733 734 Raises: 735 TypeError if the underlying class is not iterable. 736 UnexpectedMethodCallError if the object does not expect the call to 737 __iter__. 738 739 """ 740 methods = dir(self._class_to_mock) 741 742 # Verify the class supports iteration. 743 if '__iter__' not in methods: 744 # If it doesn't have iter method and we are in replay method, 745 # then try to iterate using subscripts. 746 if '__getitem__' not in methods or not self._replay_mode: 747 raise TypeError('not iterable object') 748 else: 749 results = [] 750 index = 0 751 try: 752 while True: 753 results.append(self[index]) 754 index += 1 755 except IndexError: 756 return iter(results) 757 758 # If we are in replay mode then simply call the mock __iter__ method. 759 if self._replay_mode: 760 return MockMethod('__iter__', self._expected_calls_queue, 761 self._replay_mode)() 762 763 # Otherwise, create a mock method __iter__. 764 return self._CreateMockMethod('__iter__')() 765 766 def __contains__(self, key): 767 """Provide custom logic for mocking classes that contain items. 768 769 Args: 770 key: Key to look in container for. 771 772 Returns: 773 Expected return value in replay mode. A MockMethod object for the 774 __contains__ method that has already been called if not in replay 775 mode. 776 777 Raises: 778 TypeError if the underlying class does not implement __contains__ 779 UnexpectedMethodCaller if the object does not expect the call to 780 __contains__. 781 782 """ 783 contains = self._class_to_mock.__dict__.get('__contains__', None) 784 785 if contains is None: 786 raise TypeError('unsubscriptable object') 787 788 if self._replay_mode: 789 return MockMethod('__contains__', self._expected_calls_queue, 790 self._replay_mode)(key) 791 792 return self._CreateMockMethod('__contains__')(key) 793 794 def __call__(self, *params, **named_params): 795 """Provide custom logic for mocking classes that are callable.""" 796 797 # Verify the class we are mocking is callable. 798 is_callable = hasattr(self._class_to_mock, '__call__') 799 if not is_callable: 800 raise TypeError('Not callable') 801 802 # Because the call is happening directly on this object instead of 803 # a method, the call on the mock method is made right here 804 805 # If we are mocking a Function, then use the function, and not the 806 # __call__ method 807 method = None 808 if type(self._class_to_mock) in (types.FunctionType, types.MethodType): 809 method = self._class_to_mock 810 else: 811 method = getattr(self._class_to_mock, '__call__') 812 mock_method = self._CreateMockMethod('__call__', method_to_mock=method) 813 814 return mock_method(*params, **named_params) 815 816 @property 817 def __name__(self): 818 """Return the name that is being mocked.""" 819 return self._description 820 821 # TODO(dejw): this property stopped to work after I introduced changes with 822 # binding classes. Fortunately I found a solution in the form of 823 # __getattribute__ method below, but this issue should be investigated 824 @property 825 def __class__(self): 826 return self._class_to_mock 827 828 def __dir__(self): 829 """Return only attributes of a class to mock.""" 830 return dir(self._class_to_mock) 831 832 def __getattribute__(self, name): 833 """Return _class_to_mock on __class__ attribute.""" 834 if name == "__class__": 835 return super(MockObject, self).__getattribute__("_class_to_mock") 836 837 return super(MockObject, self).__getattribute__(name) 838 839 840class _MockObjectFactory(MockObject): 841 """A MockObjectFactory creates mocks and verifies __init__ params. 842 843 A MockObjectFactory removes the boiler plate code that was previously 844 necessary to stub out direction instantiation of a class. 845 846 The MockObjectFactory creates new MockObjects when called and verifies the 847 __init__ params are correct when in record mode. When replaying, 848 existing mocks are returned, and the __init__ params are verified. 849 850 See StubOutWithMock vs StubOutClassWithMocks for more detail. 851 """ 852 853 def __init__(self, class_to_mock, mox_instance): 854 MockObject.__init__(self, class_to_mock) 855 self._mox = mox_instance 856 self._instance_queue = collections.deque() 857 858 def __call__(self, *params, **named_params): 859 """Instantiate and record that a new mock has been created.""" 860 861 method = getattr(self._class_to_mock, '__init__') 862 mock_method = self._CreateMockMethod('__init__', method_to_mock=method) 863 # Note: calling mock_method() is deferred in order to catch the 864 # empty instance_queue first. 865 866 if self._replay_mode: 867 if not self._instance_queue: 868 raise UnexpectedMockCreationError(self._class_to_mock, *params, 869 **named_params) 870 871 mock_method(*params, **named_params) 872 873 return self._instance_queue.pop() 874 else: 875 mock_method(*params, **named_params) 876 877 instance = self._mox.CreateMock(self._class_to_mock) 878 self._instance_queue.appendleft(instance) 879 return instance 880 881 def _Verify(self): 882 """Verify that all mocks have been created.""" 883 if self._instance_queue: 884 raise ExpectedMockCreationError(self._instance_queue) 885 super(_MockObjectFactory, self)._Verify() 886 887 888class MethodSignatureChecker(object): 889 """Ensures that methods are called correctly.""" 890 891 _NEEDED, _DEFAULT, _GIVEN = range(3) 892 893 def __init__(self, method, class_to_bind=None): 894 """Creates a checker. 895 896 Args: 897 # method: A method to check. 898 # class_to_bind: optionally, a class used to type check first 899 # method parameter, only used with unbound methods 900 method: function 901 class_to_bind: type or None 902 903 Raises: 904 ValueError: method could not be inspected, so checks aren't 905 possible. Some methods and functions like built-ins 906 can't be inspected. 907 """ 908 try: 909 self._args, varargs, varkw, defaults = inspect.getargspec(method) 910 except TypeError: 911 raise ValueError('Could not get argument specification for %r' 912 % (method,)) 913 if inspect.ismethod(method) or class_to_bind: 914 self._args = self._args[1:] # Skip 'self'. 915 self._method = method 916 self._instance = None # May contain the instance this is bound to. 917 self._instance = getattr(method, "__self__", None) 918 919 # _bounded_to determines whether the method is bound or not 920 if self._instance: 921 self._bounded_to = self._instance.__class__ 922 else: 923 self._bounded_to = class_to_bind or getattr(method, "im_class", 924 None) 925 926 self._has_varargs = varargs is not None 927 self._has_varkw = varkw is not None 928 if defaults is None: 929 self._required_args = self._args 930 self._default_args = [] 931 else: 932 self._required_args = self._args[:-len(defaults)] 933 self._default_args = self._args[-len(defaults):] 934 935 def _RecordArgumentGiven(self, arg_name, arg_status): 936 """Mark an argument as being given. 937 938 Args: 939 # arg_name: The name of the argument to mark in arg_status. 940 # arg_status: Maps argument names to one of 941 # _NEEDED, _DEFAULT, _GIVEN. 942 arg_name: string 943 arg_status: dict 944 945 Raises: 946 AttributeError: arg_name is already marked as _GIVEN. 947 """ 948 if arg_status.get(arg_name, None) == MethodSignatureChecker._GIVEN: 949 raise AttributeError('%s provided more than once' % (arg_name,)) 950 arg_status[arg_name] = MethodSignatureChecker._GIVEN 951 952 def Check(self, params, named_params): 953 """Ensures that the parameters used while recording a call are valid. 954 955 Args: 956 # params: A list of positional parameters. 957 # named_params: A dict of named parameters. 958 params: list 959 named_params: dict 960 961 Raises: 962 AttributeError: the given parameters don't work with the given 963 method. 964 """ 965 arg_status = dict((a, MethodSignatureChecker._NEEDED) 966 for a in self._required_args) 967 for arg in self._default_args: 968 arg_status[arg] = MethodSignatureChecker._DEFAULT 969 970 # WARNING: Suspect hack ahead. 971 # 972 # Check to see if this is an unbound method, where the instance 973 # should be bound as the first argument. We try to determine if 974 # the first argument (param[0]) is an instance of the class, or it 975 # is equivalent to the class (used to account for Comparators). 976 # 977 # NOTE: If a Func() comparator is used, and the signature is not 978 # correct, this will cause extra executions of the function. 979 if inspect.ismethod(self._method) or self._bounded_to: 980 # The extra param accounts for the bound instance. 981 if len(params) > len(self._required_args): 982 expected = self._bounded_to 983 984 # Check if the param is an instance of the expected class, 985 # or check equality (useful for checking Comparators). 986 987 # This is a hack to work around the fact that the first 988 # parameter can be a Comparator, and the comparison may raise 989 # an exception during this comparison, which is OK. 990 try: 991 param_equality = (params[0] == expected) 992 except Exception: 993 param_equality = False 994 995 if isinstance(params[0], expected) or param_equality: 996 params = params[1:] 997 # If the IsA() comparator is being used, we need to check the 998 # inverse of the usual case - that the given instance is a 999 # subclass of the expected class. For example, the code under 1000 # test does late binding to a subclass. 1001 elif (isinstance(params[0], IsA) and 1002 params[0]._IsSubClass(expected)): 1003 params = params[1:] 1004 1005 # Check that each positional param is valid. 1006 for i in range(len(params)): 1007 try: 1008 arg_name = self._args[i] 1009 except IndexError: 1010 if not self._has_varargs: 1011 raise AttributeError( 1012 '%s does not take %d or more positional ' 1013 'arguments' % (self._method.__name__, i)) 1014 else: 1015 self._RecordArgumentGiven(arg_name, arg_status) 1016 1017 # Check each keyword argument. 1018 for arg_name in named_params: 1019 if arg_name not in arg_status and not self._has_varkw: 1020 raise AttributeError('%s is not expecting keyword argument %s' 1021 % (self._method.__name__, arg_name)) 1022 self._RecordArgumentGiven(arg_name, arg_status) 1023 1024 # Ensure all the required arguments have been given. 1025 still_needed = [k for k, v in arg_status.items() 1026 if v == MethodSignatureChecker._NEEDED] 1027 if still_needed: 1028 raise AttributeError('No values given for arguments: %s' 1029 % (' '.join(sorted(still_needed)))) 1030 1031 1032class MockMethod(object): 1033 """Callable mock method. 1034 1035 A MockMethod should act exactly like the method it mocks, accepting 1036 parameters and returning a value, or throwing an exception (as specified). 1037 When this method is called, it can optionally verify whether the called 1038 method (name and signature) matches the expected method. 1039 """ 1040 1041 def __init__(self, method_name, call_queue, replay_mode, 1042 method_to_mock=None, description=None, class_to_bind=None): 1043 """Construct a new mock method. 1044 1045 Args: 1046 # method_name: the name of the method 1047 # call_queue: deque of calls, verify this call against the head, 1048 # or add this call to the queue. 1049 # replay_mode: False if we are recording, True if we are verifying 1050 # calls against the call queue. 1051 # method_to_mock: The actual method being mocked, used for 1052 # introspection. 1053 # description: optionally, a descriptive name for this method. 1054 # Typically this is equal to the descriptive name of 1055 # the method's class. 1056 # class_to_bind: optionally, a class that is used for unbound 1057 # methods (or functions in Python3) to which method 1058 # is bound, in order not to loose binding 1059 # information. If given, it will be used for 1060 # checking the type of first method parameter 1061 method_name: str 1062 call_queue: list or deque 1063 replay_mode: bool 1064 method_to_mock: a method object 1065 description: str or None 1066 class_to_bind: type or None 1067 """ 1068 1069 self._name = method_name 1070 self.__name__ = method_name 1071 self._call_queue = call_queue 1072 if not isinstance(call_queue, collections.deque): 1073 self._call_queue = collections.deque(self._call_queue) 1074 self._replay_mode = replay_mode 1075 self._description = description 1076 1077 self._params = None 1078 self._named_params = None 1079 self._return_value = None 1080 self._exception = None 1081 self._side_effects = None 1082 1083 try: 1084 self._checker = MethodSignatureChecker(method_to_mock, 1085 class_to_bind=class_to_bind) 1086 except ValueError: 1087 self._checker = None 1088 1089 def __call__(self, *params, **named_params): 1090 """Log parameters and return the specified return value. 1091 1092 If the Mock(Anything/Object) associated with this call is in record 1093 mode, this MockMethod will be pushed onto the expected call queue. 1094 If the mock is in replay mode, this will pop a MockMethod off the 1095 top of the queue and verify this call is equal to the expected call. 1096 1097 Raises: 1098 UnexpectedMethodCall if this call is supposed to match an expected 1099 method call and it does not. 1100 """ 1101 1102 self._params = params 1103 self._named_params = named_params 1104 1105 if not self._replay_mode: 1106 if self._checker is not None: 1107 self._checker.Check(params, named_params) 1108 self._call_queue.append(self) 1109 return self 1110 1111 expected_method = self._VerifyMethodCall() 1112 1113 if expected_method._side_effects: 1114 result = expected_method._side_effects(*params, **named_params) 1115 if expected_method._return_value is None: 1116 expected_method._return_value = result 1117 1118 if expected_method._exception: 1119 raise expected_method._exception 1120 1121 return expected_method._return_value 1122 1123 def __getattr__(self, name): 1124 """Raise an AttributeError with a helpful message.""" 1125 1126 raise AttributeError( 1127 'MockMethod has no attribute "%s". ' 1128 'Did you remember to put your mocks in replay mode?' % name) 1129 1130 def __iter__(self): 1131 """Raise a TypeError with a helpful message.""" 1132 raise TypeError( 1133 'MockMethod cannot be iterated. ' 1134 'Did you remember to put your mocks in replay mode?') 1135 1136 def next(self): 1137 """Raise a TypeError with a helpful message.""" 1138 raise TypeError( 1139 'MockMethod cannot be iterated. ' 1140 'Did you remember to put your mocks in replay mode?') 1141 1142 def __next__(self): 1143 """Raise a TypeError with a helpful message.""" 1144 raise TypeError( 1145 'MockMethod cannot be iterated. ' 1146 'Did you remember to put your mocks in replay mode?') 1147 1148 def _PopNextMethod(self): 1149 """Pop the next method from our call queue.""" 1150 try: 1151 return self._call_queue.popleft() 1152 except IndexError: 1153 raise UnexpectedMethodCallError(self, None) 1154 1155 def _VerifyMethodCall(self): 1156 """Verify the called method is expected. 1157 1158 This can be an ordered method, or part of an unordered set. 1159 1160 Returns: 1161 The expected mock method. 1162 1163 Raises: 1164 UnexpectedMethodCall if the method called was not expected. 1165 """ 1166 1167 expected = self._PopNextMethod() 1168 1169 # Loop here, because we might have a MethodGroup followed by another 1170 # group. 1171 while isinstance(expected, MethodGroup): 1172 expected, method = expected.MethodCalled(self) 1173 if method is not None: 1174 return method 1175 1176 # This is a mock method, so just check equality. 1177 if expected != self: 1178 raise UnexpectedMethodCallError(self, expected) 1179 1180 return expected 1181 1182 def __str__(self): 1183 params = ', '.join( 1184 [repr(p) for p in self._params or []] + 1185 ['%s=%r' % x for x in sorted((self._named_params or {}).items())]) 1186 full_desc = "%s(%s) -> %r" % (self._name, params, self._return_value) 1187 if self._description: 1188 full_desc = "%s.%s" % (self._description, full_desc) 1189 return full_desc 1190 1191 def __hash__(self): 1192 return id(self) 1193 1194 def __eq__(self, rhs): 1195 """Test whether this MockMethod is equivalent to another MockMethod. 1196 1197 Args: 1198 # rhs: the right hand side of the test 1199 rhs: MockMethod 1200 """ 1201 1202 return (isinstance(rhs, MockMethod) and 1203 self._name == rhs._name and 1204 self._params == rhs._params and 1205 self._named_params == rhs._named_params) 1206 1207 def __ne__(self, rhs): 1208 """Test if this MockMethod is not equivalent to another MockMethod. 1209 1210 Args: 1211 # rhs: the right hand side of the test 1212 rhs: MockMethod 1213 """ 1214 1215 return not self == rhs 1216 1217 def GetPossibleGroup(self): 1218 """Returns a possible group from the end of the call queue. 1219 1220 Return None if no other methods are on the stack. 1221 """ 1222 1223 # Remove this method from the tail of the queue so we can add it 1224 # to a group. 1225 this_method = self._call_queue.pop() 1226 assert this_method == self 1227 1228 # Determine if the tail of the queue is a group, or just a regular 1229 # ordered mock method. 1230 group = None 1231 try: 1232 group = self._call_queue[-1] 1233 except IndexError: 1234 pass 1235 1236 return group 1237 1238 def _CheckAndCreateNewGroup(self, group_name, group_class): 1239 """Checks if the last method (a possible group) is an instance of our 1240 group_class. Adds the current method to this group or creates a 1241 new one. 1242 1243 Args: 1244 1245 group_name: the name of the group. 1246 group_class: the class used to create instance of this new group 1247 """ 1248 group = self.GetPossibleGroup() 1249 1250 # If this is a group, and it is the correct group, add the method. 1251 if isinstance(group, group_class) and group.group_name() == group_name: 1252 group.AddMethod(self) 1253 return self 1254 1255 # Create a new group and add the method. 1256 new_group = group_class(group_name) 1257 new_group.AddMethod(self) 1258 self._call_queue.append(new_group) 1259 return self 1260 1261 def InAnyOrder(self, group_name="default"): 1262 """Move this method into a group of unordered calls. 1263 1264 A group of unordered calls must be defined together, and must be 1265 executed in full before the next expected method can be called. 1266 There can be multiple groups that are expected serially, if they are 1267 given different group names. The same group name can be reused if there 1268 is a standard method call, or a group with a different name, spliced 1269 between usages. 1270 1271 Args: 1272 group_name: the name of the unordered group. 1273 1274 Returns: 1275 self 1276 """ 1277 return self._CheckAndCreateNewGroup(group_name, UnorderedGroup) 1278 1279 def MultipleTimes(self, group_name="default"): 1280 """Move method into group of calls which may be called multiple times. 1281 1282 A group of repeating calls must be defined together, and must be 1283 executed in full before the next expected method can be called. 1284 1285 Args: 1286 group_name: the name of the unordered group. 1287 1288 Returns: 1289 self 1290 """ 1291 return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup) 1292 1293 def AndReturn(self, return_value): 1294 """Set the value to return when this method is called. 1295 1296 Args: 1297 # return_value can be anything. 1298 """ 1299 1300 self._return_value = return_value 1301 return return_value 1302 1303 def AndRaise(self, exception): 1304 """Set the exception to raise when this method is called. 1305 1306 Args: 1307 # exception: the exception to raise when this method is called. 1308 exception: Exception 1309 """ 1310 1311 self._exception = exception 1312 1313 def WithSideEffects(self, side_effects): 1314 """Set the side effects that are simulated when this method is called. 1315 1316 Args: 1317 side_effects: A callable which modifies the parameters or other 1318 relevant state which a given test case depends on. 1319 1320 Returns: 1321 Self for chaining with AndReturn and AndRaise. 1322 """ 1323 self._side_effects = side_effects 1324 return self 1325 1326 1327class Comparator: 1328 """Base class for all Mox comparators. 1329 1330 A Comparator can be used as a parameter to a mocked method when the exact 1331 value is not known. For example, the code you are testing might build up 1332 a long SQL string that is passed to your mock DAO. You're only interested 1333 that the IN clause contains the proper primary keys, so you can set your 1334 mock up as follows: 1335 1336 mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result) 1337 1338 Now whatever query is passed in must contain the string 'IN (1, 2, 4, 5)'. 1339 1340 A Comparator may replace one or more parameters, for example: 1341 # return at most 10 rows 1342 mock_dao.RunQuery(StrContains('SELECT'), 10) 1343 1344 or 1345 1346 # Return some non-deterministic number of rows 1347 mock_dao.RunQuery(StrContains('SELECT'), IsA(int)) 1348 """ 1349 1350 def equals(self, rhs): 1351 """Special equals method that all comparators must implement. 1352 1353 Args: 1354 rhs: any python object 1355 """ 1356 1357 raise NotImplementedError('method must be implemented by a subclass.') 1358 1359 def __eq__(self, rhs): 1360 return self.equals(rhs) 1361 1362 def __ne__(self, rhs): 1363 return not self.equals(rhs) 1364 1365 1366class Is(Comparator): 1367 """Comparison class used to check identity, instead of equality.""" 1368 1369 def __init__(self, obj): 1370 self._obj = obj 1371 1372 def equals(self, rhs): 1373 return rhs is self._obj 1374 1375 def __repr__(self): 1376 return "<is %r (%s)>" % (self._obj, id(self._obj)) 1377 1378 1379class IsA(Comparator): 1380 """This class wraps a basic Python type or class. It is used to verify 1381 that a parameter is of the given type or class. 1382 1383 Example: 1384 mock_dao.Connect(IsA(DbConnectInfo)) 1385 """ 1386 1387 def __init__(self, class_name): 1388 """Initialize IsA 1389 1390 Args: 1391 class_name: basic python type or a class 1392 """ 1393 1394 self._class_name = class_name 1395 1396 def equals(self, rhs): 1397 """Check to see if the RHS is an instance of class_name. 1398 1399 Args: 1400 # rhs: the right hand side of the test 1401 rhs: object 1402 1403 Returns: 1404 bool 1405 """ 1406 1407 try: 1408 return isinstance(rhs, self._class_name) 1409 except TypeError: 1410 # Check raw types if there was a type error. This is helpful for 1411 # things like cStringIO.StringIO. 1412 return type(rhs) == type(self._class_name) 1413 1414 def _IsSubClass(self, clazz): 1415 """Check to see if the IsA comparators class is a subclass of clazz. 1416 1417 Args: 1418 # clazz: a class object 1419 1420 Returns: 1421 bool 1422 """ 1423 1424 try: 1425 return issubclass(self._class_name, clazz) 1426 except TypeError: 1427 # Check raw types if there was a type error. This is helpful for 1428 # things like cStringIO.StringIO. 1429 return type(clazz) == type(self._class_name) 1430 1431 def __repr__(self): 1432 return 'mox.IsA(%s) ' % str(self._class_name) 1433 1434 1435class IsAlmost(Comparator): 1436 """Comparison class used to check whether a parameter is nearly equal 1437 to a given value. Generally useful for floating point numbers. 1438 1439 Example mock_dao.SetTimeout((IsAlmost(3.9))) 1440 """ 1441 1442 def __init__(self, float_value, places=7): 1443 """Initialize IsAlmost. 1444 1445 Args: 1446 float_value: The value for making the comparison. 1447 places: The number of decimal places to round to. 1448 """ 1449 1450 self._float_value = float_value 1451 self._places = places 1452 1453 def equals(self, rhs): 1454 """Check to see if RHS is almost equal to float_value 1455 1456 Args: 1457 rhs: the value to compare to float_value 1458 1459 Returns: 1460 bool 1461 """ 1462 1463 try: 1464 return round(rhs - self._float_value, self._places) == 0 1465 except Exception: 1466 # Probably because either float_value or rhs is not a number. 1467 return False 1468 1469 def __repr__(self): 1470 return str(self._float_value) 1471 1472 1473class StrContains(Comparator): 1474 """Comparison class used to check whether a substring exists in a 1475 string parameter. This can be useful in mocking a database with SQL 1476 passed in as a string parameter, for example. 1477 1478 Example: 1479 mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result) 1480 """ 1481 1482 def __init__(self, search_string): 1483 """Initialize. 1484 1485 Args: 1486 # search_string: the string you are searching for 1487 search_string: str 1488 """ 1489 1490 self._search_string = search_string 1491 1492 def equals(self, rhs): 1493 """Check to see if the search_string is contained in the rhs string. 1494 1495 Args: 1496 # rhs: the right hand side of the test 1497 rhs: object 1498 1499 Returns: 1500 bool 1501 """ 1502 1503 try: 1504 return rhs.find(self._search_string) > -1 1505 except Exception: 1506 return False 1507 1508 def __repr__(self): 1509 return '<str containing \'%s\'>' % self._search_string 1510 1511 1512class Regex(Comparator): 1513 """Checks if a string matches a regular expression. 1514 1515 This uses a given regular expression to determine equality. 1516 """ 1517 1518 def __init__(self, pattern, flags=0): 1519 """Initialize. 1520 1521 Args: 1522 # pattern is the regular expression to search for 1523 pattern: str 1524 # flags passed to re.compile function as the second argument 1525 flags: int 1526 """ 1527 self.flags = flags 1528 self.regex = re.compile(pattern, flags=flags) 1529 1530 def equals(self, rhs): 1531 """Check to see if rhs matches regular expression pattern. 1532 1533 Returns: 1534 bool 1535 """ 1536 1537 try: 1538 return self.regex.search(rhs) is not None 1539 except Exception: 1540 return False 1541 1542 def __repr__(self): 1543 s = '<regular expression \'%s\'' % self.regex.pattern 1544 if self.flags: 1545 s += ', flags=%d' % self.flags 1546 s += '>' 1547 return s 1548 1549 1550class In(Comparator): 1551 """Checks whether an item (or key) is in a list (or dict) parameter. 1552 1553 Example: 1554 mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result) 1555 """ 1556 1557 def __init__(self, key): 1558 """Initialize. 1559 1560 Args: 1561 # key is any thing that could be in a list or a key in a dict 1562 """ 1563 1564 self._key = key 1565 1566 def equals(self, rhs): 1567 """Check to see whether key is in rhs. 1568 1569 Args: 1570 rhs: dict 1571 1572 Returns: 1573 bool 1574 """ 1575 1576 try: 1577 return self._key in rhs 1578 except Exception: 1579 return False 1580 1581 def __repr__(self): 1582 return '<sequence or map containing \'%s\'>' % str(self._key) 1583 1584 1585class Not(Comparator): 1586 """Checks whether a predicates is False. 1587 1588 Example: 1589 mock_dao.UpdateUsers(Not(ContainsKeyValue('stevepm', 1590 stevepm_user_info))) 1591 """ 1592 1593 def __init__(self, predicate): 1594 """Initialize. 1595 1596 Args: 1597 # predicate: a Comparator instance. 1598 """ 1599 1600 assert isinstance(predicate, Comparator), ("predicate %r must be a" 1601 " Comparator." % predicate) 1602 self._predicate = predicate 1603 1604 def equals(self, rhs): 1605 """Check to see whether the predicate is False. 1606 1607 Args: 1608 rhs: A value that will be given in argument of the predicate. 1609 1610 Returns: 1611 bool 1612 """ 1613 1614 try: 1615 return not self._predicate.equals(rhs) 1616 except Exception: 1617 return False 1618 1619 def __repr__(self): 1620 return '<not \'%s\'>' % self._predicate 1621 1622 1623class ContainsKeyValue(Comparator): 1624 """Checks whether a key/value pair is in a dict parameter. 1625 1626 Example: 1627 mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info)) 1628 """ 1629 1630 def __init__(self, key, value): 1631 """Initialize. 1632 1633 Args: 1634 # key: a key in a dict 1635 # value: the corresponding value 1636 """ 1637 1638 self._key = key 1639 self._value = value 1640 1641 def equals(self, rhs): 1642 """Check whether the given key/value pair is in the rhs dict. 1643 1644 Returns: 1645 bool 1646 """ 1647 1648 try: 1649 return rhs[self._key] == self._value 1650 except Exception: 1651 return False 1652 1653 def __repr__(self): 1654 return '<map containing the entry \'%s: %s\'>' % (str(self._key), 1655 str(self._value)) 1656 1657 1658class ContainsAttributeValue(Comparator): 1659 """Checks whether passed parameter contains attributes with a given value. 1660 1661 Example: 1662 mock_dao.UpdateSomething(ContainsAttribute('stevepm', stevepm_user_info)) 1663 """ 1664 1665 def __init__(self, key, value): 1666 """Initialize. 1667 1668 Args: 1669 # key: an attribute name of an object 1670 # value: the corresponding value 1671 """ 1672 1673 self._key = key 1674 self._value = value 1675 1676 def equals(self, rhs): 1677 """Check if the given attribute has a matching value in the rhs object. 1678 1679 Returns: 1680 bool 1681 """ 1682 1683 try: 1684 return getattr(rhs, self._key) == self._value 1685 except Exception: 1686 return False 1687 1688 1689class SameElementsAs(Comparator): 1690 """Checks whether sequences contain the same elements (ignoring order). 1691 1692 Example: 1693 mock_dao.ProcessUsers(SameElementsAs('stevepm', 'salomaki')) 1694 """ 1695 1696 def __init__(self, expected_seq): 1697 """Initialize. 1698 1699 Args: 1700 expected_seq: a sequence 1701 """ 1702 # Store in case expected_seq is an iterator. 1703 self._expected_list = list(expected_seq) 1704 1705 def equals(self, actual_seq): 1706 """Check to see whether actual_seq has same elements as expected_seq. 1707 1708 Args: 1709 actual_seq: sequence 1710 1711 Returns: 1712 bool 1713 """ 1714 try: 1715 # Store in case actual_seq is an iterator. We potentially iterate 1716 # twice: once to make the dict, once in the list fallback. 1717 actual_list = list(actual_seq) 1718 except TypeError: 1719 # actual_seq cannot be read as a sequence. 1720 # 1721 # This happens because Mox uses __eq__ both to check object 1722 # equality (in MethodSignatureChecker) and to invoke Comparators. 1723 return False 1724 1725 try: 1726 return set(self._expected_list) == set(actual_list) 1727 except TypeError: 1728 # Fall back to slower list-compare if any of the objects 1729 # are unhashable. 1730 if len(self._expected_list) != len(actual_list): 1731 return False 1732 for el in actual_list: 1733 if el not in self._expected_list: 1734 return False 1735 return True 1736 1737 def __repr__(self): 1738 return '<sequence with same elements as \'%s\'>' % self._expected_list 1739 1740 1741class And(Comparator): 1742 """Evaluates one or more Comparators on RHS, returns an AND of the results. 1743 """ 1744 1745 def __init__(self, *args): 1746 """Initialize. 1747 1748 Args: 1749 *args: One or more Comparator 1750 """ 1751 1752 self._comparators = args 1753 1754 def equals(self, rhs): 1755 """Checks whether all Comparators are equal to rhs. 1756 1757 Args: 1758 # rhs: can be anything 1759 1760 Returns: 1761 bool 1762 """ 1763 1764 for comparator in self._comparators: 1765 if not comparator.equals(rhs): 1766 return False 1767 1768 return True 1769 1770 def __repr__(self): 1771 return '<AND %s>' % str(self._comparators) 1772 1773 1774class Or(Comparator): 1775 """Evaluates one or more Comparators on RHS; returns OR of the results.""" 1776 1777 def __init__(self, *args): 1778 """Initialize. 1779 1780 Args: 1781 *args: One or more Mox comparators 1782 """ 1783 1784 self._comparators = args 1785 1786 def equals(self, rhs): 1787 """Checks whether any Comparator is equal to rhs. 1788 1789 Args: 1790 # rhs: can be anything 1791 1792 Returns: 1793 bool 1794 """ 1795 1796 for comparator in self._comparators: 1797 if comparator.equals(rhs): 1798 return True 1799 1800 return False 1801 1802 def __repr__(self): 1803 return '<OR %s>' % str(self._comparators) 1804 1805 1806class Func(Comparator): 1807 """Call a function that should verify the parameter passed in is correct. 1808 1809 You may need the ability to perform more advanced operations on the 1810 parameter in order to validate it. You can use this to have a callable 1811 validate any parameter. The callable should return either True or False. 1812 1813 1814 Example: 1815 1816 def myParamValidator(param): 1817 # Advanced logic here 1818 return True 1819 1820 mock_dao.DoSomething(Func(myParamValidator), true) 1821 """ 1822 1823 def __init__(self, func): 1824 """Initialize. 1825 1826 Args: 1827 func: callable that takes one parameter and returns a bool 1828 """ 1829 1830 self._func = func 1831 1832 def equals(self, rhs): 1833 """Test whether rhs passes the function test. 1834 1835 rhs is passed into func. 1836 1837 Args: 1838 rhs: any python object 1839 1840 Returns: 1841 the result of func(rhs) 1842 """ 1843 1844 return self._func(rhs) 1845 1846 def __repr__(self): 1847 return str(self._func) 1848 1849 1850class IgnoreArg(Comparator): 1851 """Ignore an argument. 1852 1853 This can be used when we don't care about an argument of a method call. 1854 1855 Example: 1856 # Check if CastMagic is called with 3 as first arg and 1857 # 'disappear' as third. 1858 mymock.CastMagic(3, IgnoreArg(), 'disappear') 1859 """ 1860 1861 def equals(self, unused_rhs): 1862 """Ignores arguments and returns True. 1863 1864 Args: 1865 unused_rhs: any python object 1866 1867 Returns: 1868 always returns True 1869 """ 1870 1871 return True 1872 1873 def __repr__(self): 1874 return '<IgnoreArg>' 1875 1876 1877class Value(Comparator): 1878 """Compares argument against a remembered value. 1879 1880 To be used in conjunction with Remember comparator. See Remember() 1881 for example. 1882 """ 1883 1884 def __init__(self): 1885 self._value = None 1886 self._has_value = False 1887 1888 def store_value(self, rhs): 1889 self._value = rhs 1890 self._has_value = True 1891 1892 def equals(self, rhs): 1893 if not self._has_value: 1894 return False 1895 else: 1896 return rhs == self._value 1897 1898 def __repr__(self): 1899 if self._has_value: 1900 return "<Value %r>" % self._value 1901 else: 1902 return "<Value>" 1903 1904 1905class Remember(Comparator): 1906 """Remembers the argument to a value store. 1907 1908 To be used in conjunction with Value comparator. 1909 1910 Example: 1911 # Remember the argument for one method call. 1912 users_list = Value() 1913 mock_dao.ProcessUsers(Remember(users_list)) 1914 1915 # Check argument against remembered value. 1916 mock_dao.ReportUsers(users_list) 1917 """ 1918 1919 def __init__(self, value_store): 1920 if not isinstance(value_store, Value): 1921 raise TypeError( 1922 "value_store is not an instance of the Value class") 1923 self._value_store = value_store 1924 1925 def equals(self, rhs): 1926 self._value_store.store_value(rhs) 1927 return True 1928 1929 def __repr__(self): 1930 return "<Remember %d>" % id(self._value_store) 1931 1932 1933class MethodGroup(object): 1934 """Base class containing common behaviour for MethodGroups.""" 1935 1936 def __init__(self, group_name): 1937 self._group_name = group_name 1938 1939 def group_name(self): 1940 return self._group_name 1941 1942 def __str__(self): 1943 return '<%s "%s">' % (self.__class__.__name__, self._group_name) 1944 1945 def AddMethod(self, mock_method): 1946 raise NotImplementedError 1947 1948 def MethodCalled(self, mock_method): 1949 raise NotImplementedError 1950 1951 def IsSatisfied(self): 1952 raise NotImplementedError 1953 1954 1955class UnorderedGroup(MethodGroup): 1956 """UnorderedGroup holds a set of method calls that may occur in any order. 1957 1958 This construct is helpful for non-deterministic events, such as iterating 1959 over the keys of a dict. 1960 """ 1961 1962 def __init__(self, group_name): 1963 super(UnorderedGroup, self).__init__(group_name) 1964 self._methods = [] 1965 1966 def __str__(self): 1967 return '%s "%s" pending calls:\n%s' % ( 1968 self.__class__.__name__, 1969 self._group_name, 1970 "\n".join(str(method) for method in self._methods)) 1971 1972 def AddMethod(self, mock_method): 1973 """Add a method to this group. 1974 1975 Args: 1976 mock_method: A mock method to be added to this group. 1977 """ 1978 1979 self._methods.append(mock_method) 1980 1981 def MethodCalled(self, mock_method): 1982 """Remove a method call from the group. 1983 1984 If the method is not in the set, an UnexpectedMethodCallError will be 1985 raised. 1986 1987 Args: 1988 mock_method: a mock method that should be equal to a method in the 1989 group. 1990 1991 Returns: 1992 The mock method from the group 1993 1994 Raises: 1995 UnexpectedMethodCallError if the mock_method was not in the group. 1996 """ 1997 1998 # Check to see if this method exists, and if so, remove it from the set 1999 # and return it. 2000 for method in self._methods: 2001 if method == mock_method: 2002 # Remove the called mock_method instead of the method in the 2003 # group. The called method will match any comparators when 2004 # equality is checked during removal. The method in the group 2005 # could pass a comparator to another comparator during the 2006 # equality check. 2007 self._methods.remove(mock_method) 2008 2009 # If group is not empty, put it back at the head of the queue. 2010 if not self.IsSatisfied(): 2011 mock_method._call_queue.appendleft(self) 2012 2013 return self, method 2014 2015 raise UnexpectedMethodCallError(mock_method, self) 2016 2017 def IsSatisfied(self): 2018 """Return True if there are not any methods in this group.""" 2019 2020 return len(self._methods) == 0 2021 2022 2023class MultipleTimesGroup(MethodGroup): 2024 """MultipleTimesGroup holds methods that may be called any number of times. 2025 2026 Note: Each method must be called at least once. 2027 2028 This is helpful, if you don't know or care how many times a method is 2029 called. 2030 """ 2031 2032 def __init__(self, group_name): 2033 super(MultipleTimesGroup, self).__init__(group_name) 2034 self._methods = set() 2035 self._methods_left = set() 2036 2037 def AddMethod(self, mock_method): 2038 """Add a method to this group. 2039 2040 Args: 2041 mock_method: A mock method to be added to this group. 2042 """ 2043 2044 self._methods.add(mock_method) 2045 self._methods_left.add(mock_method) 2046 2047 def MethodCalled(self, mock_method): 2048 """Remove a method call from the group. 2049 2050 If the method is not in the set, an UnexpectedMethodCallError will be 2051 raised. 2052 2053 Args: 2054 mock_method: a mock method that should be equal to a method in the 2055 group. 2056 2057 Returns: 2058 The mock method from the group 2059 2060 Raises: 2061 UnexpectedMethodCallError if the mock_method was not in the group. 2062 """ 2063 2064 # Check to see if this method exists, and if so add it to the set of 2065 # called methods. 2066 for method in self._methods: 2067 if method == mock_method: 2068 self._methods_left.discard(method) 2069 # Always put this group back on top of the queue, 2070 # because we don't know when we are done. 2071 mock_method._call_queue.appendleft(self) 2072 return self, method 2073 2074 if self.IsSatisfied(): 2075 next_method = mock_method._PopNextMethod() 2076 return next_method, None 2077 else: 2078 raise UnexpectedMethodCallError(mock_method, self) 2079 2080 def IsSatisfied(self): 2081 """Return True if all methods in group are called at least once.""" 2082 return len(self._methods_left) == 0 2083 2084 2085class MoxMetaTestBase(type): 2086 """Metaclass to add mox cleanup and verification to every test. 2087 2088 As the mox unit testing class is being constructed (MoxTestBase or a 2089 subclass), this metaclass will modify all test functions to call the 2090 CleanUpMox method of the test class after they finish. This means that 2091 unstubbing and verifying will happen for every test with no additional 2092 code, and any failures will result in test failures as opposed to errors. 2093 """ 2094 2095 def __init__(cls, name, bases, d): 2096 type.__init__(cls, name, bases, d) 2097 2098 # also get all the attributes from the base classes to account 2099 # for a case when test class is not the immediate child of MoxTestBase 2100 for base in bases: 2101 for attr_name in dir(base): 2102 if attr_name not in d: 2103 d[attr_name] = getattr(base, attr_name) 2104 2105 for func_name, func in d.items(): 2106 if func_name.startswith('test') and callable(func): 2107 2108 setattr(cls, func_name, MoxMetaTestBase.CleanUpTest(cls, func)) 2109 2110 @staticmethod 2111 def CleanUpTest(cls, func): 2112 """Adds Mox cleanup code to any MoxTestBase method. 2113 2114 Always unsets stubs after a test. Will verify all mocks for tests that 2115 otherwise pass. 2116 2117 Args: 2118 cls: MoxTestBase or subclass; the class whose method we are 2119 altering. 2120 func: method; the method of the MoxTestBase test class we wish to 2121 alter. 2122 2123 Returns: 2124 The modified method. 2125 """ 2126 def new_method(self, *args, **kwargs): 2127 mox_obj = getattr(self, 'mox', None) 2128 stubout_obj = getattr(self, 'stubs', None) 2129 cleanup_mox = False 2130 cleanup_stubout = False 2131 if mox_obj and isinstance(mox_obj, Mox): 2132 cleanup_mox = True 2133 if stubout_obj and isinstance(stubout_obj, 2134 stubout.StubOutForTesting): 2135 cleanup_stubout = True 2136 try: 2137 func(self, *args, **kwargs) 2138 finally: 2139 if cleanup_mox: 2140 mox_obj.UnsetStubs() 2141 if cleanup_stubout: 2142 stubout_obj.UnsetAll() 2143 stubout_obj.SmartUnsetAll() 2144 if cleanup_mox: 2145 mox_obj.VerifyAll() 2146 new_method.__name__ = func.__name__ 2147 new_method.__doc__ = func.__doc__ 2148 new_method.__module__ = func.__module__ 2149 return new_method 2150 2151 2152_MoxTestBase = MoxMetaTestBase('_MoxTestBase', (unittest.TestCase, ), {}) 2153 2154 2155class MoxTestBase(_MoxTestBase): 2156 """Convenience test class to make stubbing easier. 2157 2158 Sets up a "mox" attribute which is an instance of Mox (any mox tests will 2159 want this), and a "stubs" attribute that is an instance of 2160 StubOutForTesting (needed at times). Also automatically unsets any stubs 2161 and verifies that all mock methods have been called at the end of each 2162 test, eliminating boilerplate code. 2163 """ 2164 2165 def setUp(self): 2166 super(MoxTestBase, self).setUp() 2167 self.mox = Mox() 2168 self.stubs = stubout.StubOutForTesting() 2169