# Copyright 2012 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""A very very simple mock object harness."""
from types import ModuleType

# Wildcard for expected arguments: an expected argument equal to this value
# matches any actual argument. NOTE: because the wildcard is the empty string,
# an expected argument of '' also behaves as a wildcard.
DONT_CARE = ''


class MockFunctionCall(object):
  """One call record: a function name, positional args and a return value.

  Used both for programmed expectations and for the call actually observed.
  The builder methods (WithArgs, WillReturn, WhenCalled) return self so they
  can be chained.
  """

  def __init__(self, name):
    self.name = name
    self.args = tuple()
    self.return_value = None
    self.when_called_handlers = []

  def WithArgs(self, *args):
    """Sets the positional arguments of this call. Returns self."""
    self.args = args
    return self

  def WillReturn(self, value):
    """Sets the value the mocked function returns. Returns self."""
    self.return_value = value
    return self

  def WhenCalled(self, handler):
    """Registers a callable invoked with the actual positional arguments.

    Returns self, like the other builder methods, so that
    ExpectCall(...).WhenCalled(h).WillReturn(v) can be chained.
    """
    self.when_called_handlers.append(handler)
    return self

  def VerifyEquals(self, got):
    """Raises Exception unless |got| matches this expected call.

    Names and argument counts must be equal; arguments are compared
    pairwise, skipping positions where the expected value is DONT_CARE.
    """
    def mismatch():
      # Built lazily so repr() is only evaluated on failure.
      return Exception('Self %s, got %s' % (repr(self), repr(got)))

    if self.name != got.name:
      raise mismatch()
    if len(self.args) != len(got.args):
      raise mismatch()
    for expected_arg, got_arg in zip(self.args, got.args):
      if expected_arg == DONT_CARE:
        continue
      if expected_arg != got_arg:
        raise mismatch()

  def __repr__(self):
    def arg_to_text(a):
      if a == DONT_CARE:
        return '_'
      return repr(a)
    args_text = ', '.join(arg_to_text(a) for a in self.args)
    # The return value is omitted when it is unset or a wildcard.
    if self.return_value in (None, DONT_CARE):
      return '%s(%s)' % (self.name, args_text)
    return '%s(%s)->%s' % (self.name, args_text, repr(self.return_value))


class MockTrace(object):
  """The ordered list of expected calls plus the index of the next one."""

  def __init__(self):
    self.expected_calls = []
    self.next_call_index = 0


class MockObject(object):
  """An object whose method calls are checked against a programmed trace.

  Args:
    parent_mock: optional MockObject whose trace this mock shares, so calls
        on both objects are verified against a single ordered sequence.
  """

  def __init__(self, parent_mock=None):
    if parent_mock:
      self._trace = parent_mock._trace  # pylint: disable=protected-access
    else:
      self._trace = MockTrace()

  def __setattr__(self, name, value):
    # Before _trace exists (i.e. during __init__), and for installed call
    # hooks, attributes are stored without any check.
    if (not hasattr(self, '_trace') or
        hasattr(value, 'is_hook')):
      object.__setattr__(self, name, value)
      return
    # Any other attribute assigned through normal syntax must be a mock.
    assert isinstance(value, MockObject)
    object.__setattr__(self, name, value)

  def SetAttribute(self, name, value):
    """Sets attribute |name|; subject to the same checks as __setattr__."""
    setattr(self, name, value)

  def ExpectCall(self, func_name, *args):
    """Appends an expected call of |func_name| with |args| to the trace.

    Expectations may only be added before the first mocked call is made.

    Returns:
      The MockFunctionCall, for chaining WillReturn / WhenCalled.
    """
    assert self._trace.next_call_index == 0
    if not hasattr(self, func_name):
      self._install_hook(func_name)

    call = MockFunctionCall(func_name)
    self._trace.expected_calls.append(call)
    call.WithArgs(*args)
    return call

  def _install_hook(self, func_name):
    """Installs attribute |func_name| that replays the programmed trace."""
    def handler(*args, **_):
      # Keyword arguments are accepted but ignored when matching.
      got_call = MockFunctionCall(
          func_name).WithArgs(*args).WillReturn(DONT_CARE)
      if self._trace.next_call_index >= len(self._trace.expected_calls):
        raise Exception(
            'Call to %s was not expected, at end of programmed trace.' %
            repr(got_call))
      expected_call = self._trace.expected_calls[
          self._trace.next_call_index]
      expected_call.VerifyEquals(got_call)
      self._trace.next_call_index += 1
      for h in expected_call.when_called_handlers:
        h(*args)
      return expected_call.return_value
    # Marker that lets __setattr__ accept this non-MockObject attribute.
    handler.is_hook = True
    setattr(self, func_name, handler)


class MockTimer(object):
  """ A mock timer to fake out the timing for a module.
  Args:
    module: module to fake out the time
  """
  def __init__(self, module=None):
    self._elapsed_time = 0
    # Stay "unhooked" until the module has been validated and its original
    # `time` attribute captured. Otherwise a failed construction would still
    # trigger __del__ -> Restore(), which would overwrite module.time with
    # None on an object that was never hooked.
    self._module = None
    self._actual_time = None
    if module:
      assert isinstance(module, ModuleType)
      self._actual_time = module.time
      self._module = module
      module.time = self

  def sleep(self, time):
    """Advances the fake clock by |time| without blocking."""
    self._elapsed_time += time

  def time(self):
    """Returns the current fake time."""
    return self._elapsed_time

  def SetTime(self, time):
    """Sets the fake clock to |time|."""
    self._elapsed_time = time

  def __del__(self):
    self.Restore()

  def Restore(self):
    """Puts the module's original `time` attribute back; safe to call twice."""
    if self._module is not None:
      self._module.time = self._actual_time
      self._module = None
      self._actual_time = None