1# Copyright (c) 2013 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import dpkt
6import os
7import select
8import struct
9import sys
10import threading
11import time
12import traceback
13
14
class SimulatorError(Exception):
    """Generic error raised by the simulator classes in this module."""
17
18
class NullContext(object):
    """A do-nothing context manager.

    Entering yields the manager itself; exiting performs no cleanup and never
    suppresses an exception raised inside the ``with`` body.
    """

    def __exit__(self, exc_type, exc_val, exc_tb):
        # A false return value lets any in-flight exception propagate.
        suppress_exception = False
        return suppress_exception


    def __enter__(self):
        return self
27
28
class Simulator(object):
    """A TUN/TAP network interface simulator class.

    This class allows several implementations of different fake hosts to
    coexist on the same TUN/TAP interface. Every packet read from the
    interface is dispatched to each of the registered hosts, with some basic
    filtering (see add_match()) to simplify these implementations.
    """

    def __init__(self, iface):
        """Initialize the instance.

        @param tuntap.TunTap iface: the interface over which this simulator
        runs. Should not be shared with other modules.
        """
        self._iface = iface
        # List of (matcher, callback) pairs registered with add_match(). The
        # matcher is a callable taking a dpkt packet and returning whether the
        # callback should receive it.
        self._rules = []
        # _events maps a timestamp (in time.time() units) to the list of
        # callback functions to fire once the simulation reaches that
        # timestamp. This is used to fire time-based events.
        self._events = {}
        # Frames (4-byte header + packet bytes) queued by write() until the
        # interface is writable.
        self._write_queue = []
        # A pipe used to wake up the run() method from a different thread
        # calling stop(). See the stop() method for details.
        self._pipe_rd, self._pipe_wr = os.pipe()
        self._running = False
        # Lock object used for _events if multithreading is required.
        self._lock = NullContext()


    def __del__(self):
        # The pipe may be missing if __init__ failed before creating it.
        for fd in (getattr(self, '_pipe_rd', None),
                   getattr(self, '_pipe_wr', None)):
            if fd is not None:
                os.close(fd)


    def add_match(self, rule, callback):
        """Add a new match rule to the outbound traffic.

        This function adds a new rule that will be matched against each packet
        that the host sends through the interface and will call a callback if
        it matches. The rule can be specified in the following ways:
          * A python function that takes a packet as a single argument and
            returns True when the packet matches.
          * A dictionary of key=value pairs that all of them need to be matched.
            A pair matches when the packet has the provided chain of attributes
            and its value is equal to the provided value. For example, this will
            match any DNS traffic sent to the host 192.168.0.1:
            {"ip.dst": socket.inet_aton("192.168.0.1"),
             "ip.udp.dport": 53}

        @param rule: The rule description.
        @param callback: A callback function that receives the dpkt packet as
        the only argument.
        @raise SimulatorError: if |callback| is not callable or |rule| is
        neither a dict nor a callable.
        """
        if not callable(callback):
            raise SimulatorError("|callback| must be a callable object.")

        if isinstance(rule, dict):
            # Copy the dict (but not its values) so later changes made by the
            # caller don't alter the registered rule.
            rule_copy = dict(rule)
            matcher = lambda pkt: self._dict_rule(rule_copy, pkt)
        elif callable(rule):
            matcher = rule
        else:
            # |rule| is wrapped in a tuple so a tuple rule is formatted instead
            # of being taken as the format arguments.
            raise SimulatorError("Unknown rule format: %r" % (rule,))
        self._rules.append((matcher, callback))


    def add_timeout(self, timeout, callback):
        """Add a new callback function to be called after a timeout.

        This method schedules the given |callback| to be called after |timeout|
        seconds. The callback will be called at most once while the simulator
        is running (see the run() method). To have a repetitive event call again
        add_timeout() from the callback.

        @param timeout: Delay in seconds from now. Can be a float value.
        @param callback: A callback function that doesn't receive any argument.
        @raise SimulatorError: if |callback| is not callable.
        """
        if not callable(callback):
            raise SimulatorError("|callback| must be a callable object.")
        timestamp = time.time() + timeout
        with self._lock:
            self._events.setdefault(timestamp, []).append(callback)


    def remove_timeout(self, callback):
        """Removes every scheduled timeout call to the passed callback.

        When a callable object is passed to add_timeout() it is scheduled to be
        called once the timeout is reached. This method removes all the
        scheduled calls to that object. It takes the same lock as
        add_timeout(), so it can be called from a thread other than the one
        running run().

        @param callback: The callable object passed to add_timeout().
        @return: Whether the callback was found and removed at least once.
        """
        removed = False
        with self._lock:
            # Iterate over a snapshot of the keys since slots may be deleted.
            for timestamp in list(self._events):
                callbacks = self._events[timestamp]
                while callback in callbacks:
                    callbacks.remove(callback)
                    removed = True
                # Drop emptied slots so run() doesn't wake up for nothing.
                if not callbacks:
                    del self._events[timestamp]
        return removed


    def _dict_rule(self, rules, pkt):
        """Returns whether a given packet matches a set of rules.

        The matching rules passed in |rules| need to be a dict() as described
        in the add_match() method. The packet |pkt| is any dpkt packet.
        """
        for key, value in rules.items():
            obj = pkt
            for member in key.split('.'):
                if not hasattr(obj, member):
                    return False
                obj = getattr(obj, member)
            if obj != value:
                return False
        return True


    def write(self, pkt):
        """Queues a packet to be written to the network interface.

        The packet is actually written by run() once the interface is
        writable.

        @param pkt: The dpkt.Packet to be received on the network interface.
        """
        # Prepends the (flags, proto) header: flags are 0 and proto is taken
        # from pkt.type. bytes() equals str() on Python 2.
        self._write_queue.append(struct.pack("!HH", 0, pkt.type) + bytes(pkt))


    def run(self, timeout=None, until=None):
        """Runs the Simulator.

        This method blocks the caller thread until one of these happens,
        whichever occurs first:
          * the timeout is reached (if a timeout is passed);
          * stop() is called;
          * the function passed in |until| returns a True value (if a
            function is passed);
          * select() reports an exceptional condition on the interface.
        stop() can be called from any other thread or from a callback called
        from this thread.

        @param timeout: The timeout in seconds. Can be a float value, or None
        for no timeout.
        @param until: A callable object called during the loop returning True
        when the loop should stop. It is polled at least once a second.
        @raise SimulatorError: if the interface is down.
        """
        if not self._iface.is_up():
            raise SimulatorError("Interface is down.")

        stop_callback = None
        if timeout is not None:
            # We use a newly created callable object to avoid removing another
            # scheduled call to self.stop.
            stop_callback = lambda: self.stop()
            self.add_timeout(timeout, stop_callback)

        self._running = True
        try:
            iface_fd = self._iface.fileno()
            while not (until and until()):
                # Wait (block) until the next event needs attention: a packet
                # is received, a queued packet can be sent, a time-based event
                # is due, or the simulator was stopped from another thread.
                wait = self._fire_due_events()
                # Poll the until() function at least once a second.
                if wait is None or wait > 1.0:
                    wait = 1.0

                # rlist: the interface is readable, or another thread wants to
                #        wake up the loop through the pipe.
                # wlist: the interface is writable, only if a packet is queued.
                # xlist: an error on the interface fd; likely the TAP
                #        interface was closed.
                rlist = (iface_fd, self._pipe_rd)
                wlist = (iface_fd,) if self._write_queue else ()
                xlist = (iface_fd,)
                rlist, wlist, xlist = select.select(rlist, wlist, xlist, wait)

                if self._pipe_rd in rlist:
                    # stop() sends a '*'; any other byte is only a wake-up.
                    if b'*' in os.read(self._pipe_rd, 1):
                        break

                if xlist:
                    break

                if iface_fd in wlist:
                    self._iface.write(self._write_queue.pop(0))
                    # Attempt to send all the queued packets before reading.
                    continue

                if iface_fd in rlist:
                    self._dispatch_packet(self._iface.read())
        finally:
            # Clean up even when a callback or the interface raised.
            if stop_callback:
                self.remove_timeout(stop_callback)
            self._running = False


    def _fire_due_events(self):
        """Fires all the time-based events whose timestamp has been reached.

        Callbacks are called while self._lock is held.

        @return: The number of seconds until the next pending event, or None
        if no event is pending.
        """
        with self._lock:
            now = time.time()
            while self._events:
                next_ts = min(self._events)
                if next_ts > now:
                    return next_ts - now
                # Unschedule the slot before firing it, so a callback may
                # schedule itself again.
                for callback in self._events.pop(next_ts):
                    callback()
                now = time.time()
        return None


    def _dispatch_packet(self, raw):
        """Parses a frame read from the interface and runs matching callbacks.

        @param raw: the bytes read from the interface: a 4-byte (flags, proto)
        header followed by the ethernet frame.
        """
        # The header values are unused, but unpacking rejects short reads
        # with struct.error.
        _flags, _proto = struct.unpack("!HH", raw[:4])
        frame = raw[4:]
        pkt = dpkt.ethernet.Ethernet(frame)
        # Iterate over a snapshot: rules added by a callback apply from the
        # next packet on, not to the one being dispatched.
        for rule, callback in list(self._rules):
            if rule(pkt):
                # Parse the packet again to allow callbacks to modify it.
                callback(dpkt.ethernet.Ethernet(frame))


    def stop(self):
        """Stops the run() method if it is running.

        This writes a '*' byte to the wake-up pipe, which run() reads before
        returning. If run() is not running, the byte stays in the pipe and the
        next run() call returns right after its first select().
        """
        os.write(self._pipe_wr, b'*')
277
278
class SimulatorThread(threading.Thread, Simulator):
    """A threaded version of the Simulator.

    This class exposes a similar interface to the Simulator class, with the
    difference that it runs on its own thread. It exposes an extra method,
    start(), that should be called instead of Simulator.run(). start() makes
    the simulator run continuously until stop() is called (or until the
    timeout passed to the constructor is reached). After that, the simulator
    can't be restarted.

    The methods used to add new matches can be called from any thread *before*
    start() is called. After that point, only the callbacks running on this
    thread are allowed to create new matches and timeouts. Other threads can
    use run_on_simulator() to get code running on this thread.

    If Simulator.run() raises, the exception is stored in |error| and its
    formatted traceback in |traceback|. Both are None otherwise.

    Example:
        simu = SimulatorThread(tap_interface)
        simu.add_match({"ip.tcp.dport": 80}, some_callback)
        simu.start()
        time.sleep(100)
        simu.stop()
        simu.join() # Optional
    """

    def __init__(self, iface, timeout=None):
        """Initialize the instance.

        @param iface: the interface over which the simulator runs; see
        Simulator.__init__().
        @param timeout: seconds after which the thread stops by itself, or
        None to run until stop() is called.
        """
        threading.Thread.__init__(self)
        Simulator.__init__(self, iface)
        self._timeout = timeout
        # We allow the same thread to acquire the lock more than once. This is
        # useful if a callback wants to add itself.
        self._lock = threading.RLock()
        # Exception raised by Simulator.run(), if any, and its traceback text.
        self.error = None
        self.traceback = None


    def run_on_simulator(self, callback):
        """Runs the given callback on the SimulatorThread thread.

        Before calling start() on the SimulatorThread, all the calls setting up
        the simulator are allowed. Once the thread is running, concurrency
        problems should be considered. This method runs the provided callback
        on the simulator thread.

        @param callback: A callback function without arguments.
        """
        self.add_timeout(0, callback)
        # Wake up the main loop; run() ignores any byte other than '*'.
        os.write(self._pipe_wr, b' ')


    def wait_for_condition(self, condition, timeout=None):
        """Blocks until the condition is met or timeout is exceeded.

        This method should be called from a different thread while the simulator
        thread is running, as it blocks the calling thread's execution until a
        condition is met. The condition function is evaluated in a callback
        running on the simulator thread and thus can safely access objects owned
        by the simulator. The condition is only evaluated while the simulator
        loop runs, so with timeout=None this blocks until that happens.

        @param condition: A function called on the simulator thread that returns
        a value indicating if the condition is met.
        @param timeout: The timeout in seconds. None for no timeout.
        @return: The value returned by condition the last time it was called.
        This means that in the event of a timeout, this function will return a
        value that evaluates to False since the condition wasn't met the last
        time it was checked.
        """
        # Lock and Condition used to wait until the passed condition is met.
        lock_cond = threading.Lock()
        cond_var = threading.Condition(lock_cond)
        # A one-element list lets the simulator-thread callback hand the
        # latest condition value back to this thread.
        ret = [None]

        # The callback running on the simulator thread. It receives a
        # reference to itself so it can reschedule itself while the condition
        # is not met.
        callback = lambda: self._condition_poller(
                callback, ret, cond_var, condition)

        self.run_on_simulator(callback)

        # Condition variable waiting loop; also guards against spurious
        # wake-ups.
        start_time = time.time()
        with cond_var:
            while not ret[0]:
                if timeout is None:
                    cond_var.wait()
                else:
                    remaining = timeout - (time.time() - start_time)
                    if remaining < 0:
                        break
                    cond_var.wait(remaining)
        # Stop polling whether the condition was met or the wait timed out.
        self.remove_timeout(callback)

        return ret[0]


    def _condition_poller(self, callback, ref_value, cond_var, func):
        """Callback function used to poll for a condition.

        This method keeps scheduling itself in the simulator, once a second,
        until the passed condition evaluates to a True value. This effectively
        implements a polling mechanism. See wait_for_condition() for details.
        """
        with cond_var:
            ref_value[0] = func()
            if ref_value[0]:
                cond_var.notify()
            else:
                self.add_timeout(1., callback)


    def run(self):
        """Runs the simulation on the thread, called by start().

        This method wraps Simulator.run() to pass the timeout value given
        during construction. Any exception is stored in |error| and |traceback|
        instead of being propagated.
        """
        try:
            Simulator.run(self, self._timeout)
        except Exception as e:
            self.error = e
            self.traceback = traceback.format_exc()
405