1# Lint as: python2, python3
2import sys, socket, errno, logging
3from time import time, sleep
4from autotest_lib.client.common_lib import error, utils
5
6# default barrier port
7_DEFAULT_PORT = 11922
8
9def _get_host_from_id(hostid):
10    # Remove any trailing local identifier following a #.
11    # This allows multiple members per host which is particularly
12    # helpful in testing.
13    if not hostid.startswith('#'):
14        return hostid.split('#')[0]
15    else:
16        raise error.BarrierError(
17            "Invalid Host id: Host Address should be specified")
18
19
class BarrierAbortError(error.BarrierError):
    """BarrierError subclass signalling a deliberate abort.

    Raised by the barrier when some member joined the rendezvous with
    abort=True, so callers can distinguish an intentional abort from other
    barrier failures such as timeouts or lost members.
    """
22
23
class listen_server(object):
    """
    Manages a listening socket for barrier.

    Can be used to run multiple barrier instances with the same listening
    socket (if they were going to listen on the same port).  Instances can
    also be used as context managers, which close the socket on exit.

    Attributes:

    @attr address: Address to bind to (string).
    @attr port: Port to bind to.
    @attr socket: Listening socket object.
    """
    def __init__(self, address='', port=_DEFAULT_PORT):
        """
        Create a listen_server instance for the given address/port.

        @param address: The address to listen on ('' means all interfaces).
        @param port: The port to listen on.
        @raises socket.error: If the socket cannot be bound or put into
                listening mode.
        """
        self.address = address
        self.port = port
        # Open the port in the firewall so that the listening server can
        # accept incoming connections.
        utils.run('iptables -A INPUT -p tcp -m tcp --dport %d -j ACCEPT' %
                  port)
        self.socket = self._setup()


    def _setup(self):
        """Create, bind and listen on the listening socket.

        @return: The listening socket object.
        @raises Exception: Whatever setsockopt/bind/listen raised; the
                partially configured socket is closed first so it does not
                leak.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            # Allow re-binding the port while connections from a previous
            # barrier are still lingering.
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind((self.address, self.port))
            sock.listen(10)
        except Exception:
            sock.close()
            raise

        return sock


    def close(self):
        """Close the listening socket."""
        self.socket.close()


    def __enter__(self):
        """Allow `with listen_server(...) as server:` usage."""
        return self


    def __exit__(self, exc_type, exc_value, traceback):
        """Close the listening socket when leaving a `with` block."""
        self.close()
        return False
66
67
class barrier(object):
    """Multi-machine barrier support.

    Provides multi-machine barrier mechanism.
    Execution stops until all members arrive at the barrier.

    Implementation Details:
    .......................

    When a barrier is forming the main node (first in sort order) in the
    set accepts connections from each member of the set.  As they arrive
    they indicate the barrier they are joining and their identifier (their
    hostname or IP address and optional tag).  They are then asked to wait.
    When all members are present the main node then checks that each
    member is still responding via a ping/pong exchange.  If this is
    successful then everyone has checked in at the barrier.  We then tell
    everyone they may continue via a rlse message.

    Where the main is not the first to reach the barrier the client
    connects will fail.  Client will retry until they either succeed in
    connecting to main or the overall timeout is exceeded.

    As an example here is the exchange for a three node barrier called
    'TAG'

      MAIN                        CLIENT1         CLIENT2
        <-------------TAG C1-------------
        --------------wait-------------->
                      [...]
        <-------------TAG C2-----------------------------
        --------------wait------------------------------>
                      [...]
        --------------ping-------------->
        <-------------pong---------------
        --------------ping------------------------------>
        <-------------pong-------------------------------
                ----- BARRIER conditions MET -----
        --------------rlse-------------->
        --------------rlse------------------------------>

    Note that once the last client has responded with pong the barrier is
    implicitly deemed satisfied, they have all acknowledged their presence.
    If we fail to send any of the rlse messages the barrier is still a
    success, the failed host has effectively broken 'right at the beginning'
    of the post barrier execution window.

    In addition, there is another rendezvous, that makes each node a server
    and the main a client.  The connection process and usage is still the
    same but allows barriers from machines that only have a one-way
    connection initiation.  This is called rendezvous_servers.

    For example:
        if ME == SERVER:
            server start

        b = job.barrier(ME, 'server-up', 120)
        b.rendezvous(CLIENT, SERVER)

        if ME == CLIENT:
            client run

        b = job.barrier(ME, 'test-complete', 3600)
        b.rendezvous(CLIENT, SERVER)

        if ME == SERVER:
            server stop

    Any client can also request an abort of the job by setting
    abort=True in the rendezvous arguments.
    """

    def __init__(self, hostid, tag, timeout=None, port=None,
                 listen_server=None):
        """
        @param hostid: My hostname/IP address + optional tag.
        @param tag: Symbolic name of the barrier in progress.
        @param timeout: Maximum seconds to wait for the barrier to meet
                (None waits forever).
        @param port: Port number to listen on.
        @param listen_server: External listen_server instance to use instead
                of creating our own.  Create a listen_server instance and
                reuse it across multiple barrier instances so that the
                barrier code doesn't try to quickly re-bind on the same port
                (packets still in transit for the previous barrier may
                reset new connections).
        @raises error.BarrierError: If both port and listen_server are given.
        """
        self._hostid = hostid
        self._tag = tag
        if listen_server:
            if port:
                raise error.BarrierError(
                        '"port" and "listen_server" are mutually exclusive.')
            self._port = listen_server.port
        else:
            self._port = port or _DEFAULT_PORT
        self._server = listen_server  # A listen_server instance or None.
        self._members = []  # List of hosts we expect to find at the barrier.
        self._timeout_secs = timeout
        self._start_time = None  # Timestamp of when we started waiting.
        self._mainid = None  # Host/IP + optional tag of selected main.
        # Whether this member requested an abort; set by rendezvous*().
        self._abort = False
        logging.info("tag=%s port=%d timeout=%r",
                     self._tag, self._port, self._timeout_secs)

        # Number of clients seen (should be the length of self._waiting).
        self._seen = 0

        # Clients who have checked in and are waiting (if we are a main).
        self._waiting = {}  # Maps from hostname -> (client, addr) tuples.


    @staticmethod
    def _send(sock, msg):
        """Send a protocol message, encoding text to bytes when needed.

        Python 3 sockets only accept bytes; on Python 2 a plain str already
        is bytes and is sent unchanged.  sendall() is used so a short write
        cannot truncate a message.
        """
        if not isinstance(msg, bytes):
            msg = msg.encode('utf-8')
        sock.sendall(msg)


    @staticmethod
    def _recv(sock, size):
        """Receive up to `size` bytes and return them as a native str.

        On Python 3 the bytes are decoded (undecodable bytes are replaced
        instead of raising) so they compare correctly against the protocol
        keywords; on Python 2 recv() already returns str, which is kept.
        """
        data = sock.recv(size)
        if not isinstance(data, str):
            data = data.decode('utf-8', 'replace')
        return data


    def _update_timeout(self, timeout):
        """Arrange for `timeout` more seconds to remain from now.

        If waiting has not started yet (or `timeout` is None) the value is
        stored as the overall timeout instead.
        """
        if timeout is not None and self._start_time is not None:
            self._timeout_secs = (time() - self._start_time) + timeout
        else:
            self._timeout_secs = timeout


    def _remaining(self):
        """Return the seconds left before the barrier times out.

        @return: The remaining seconds, or the configured timeout as-is
                (None means wait forever) when waiting has not started.
        @raises error.BarrierError: If the timeout has already expired.
        """
        if self._timeout_secs is not None and self._start_time is not None:
            timeout = self._timeout_secs - (time() - self._start_time)
            if timeout <= 0:
                errmsg = "timeout waiting for barrier: %s" % self._tag
                logging.error(errmsg)
                raise error.BarrierError(errmsg)
        else:
            timeout = self._timeout_secs

        if self._timeout_secs is not None:
            logging.info("seconds remaining: %d", timeout)
        return timeout


    def _main_welcome(self, connection):
        """Main side of the handshake: validate and register one member.

        Reads the "<tag> <hostid>" introduction, rejects wrong-barrier or
        duplicate members (replying "!tag" / "!dup" and closing), otherwise
        replies "wait" and records the connection in self._waiting.

        @param connection: A (socket, address) tuple.
        """
        client, addr = connection

        client.settimeout(5)
        try:
            # Get the client's name.
            intro = self._recv(client, 1024).strip("\r\n")

            intro_parts = intro.split(' ', 2)
            if len(intro_parts) != 2:
                # Use the known address; getpeername() may itself fail if
                # the peer already went away.
                logging.warning("Ignoring invalid data from %s: %r",
                                addr, intro)
                client.close()
                return
            tag, name = intro_parts

            logging.info("new client tag=%s, name=%s", tag, name)

            # Ok, we know who is trying to attach.  Confirm that
            # they are coming to the same meeting.  Also, everyone
            # should be using a unique handle (their IP address).
            # If we see a duplicate, something _bad_ has happened
            # so drop them now.
            if self._tag != tag:
                logging.warning(
                        "client arriving for the wrong barrier: %s != %s",
                        self._tag, tag)
                self._send(client, "!tag")
                client.close()
                return
            if name in self._waiting:
                logging.warning("duplicate client")
                self._send(client, "!dup")
                client.close()
                return

            # Acknowledge the client.
            self._send(client, "wait")

        except socket.timeout:
            # This is nominally an error, but as we do not know
            # who that was we cannot do anything sane other
            # than report it and let the normal timeout kill
            # us when that's appropriate.
            logging.warning("client handshake timeout: (%s:%d)",
                            addr[0], addr[1])
            client.close()
            return

        logging.info("client now waiting: %s (%s:%d)",
                     name, addr[0], addr[1])

        # They seem to be valid, record them.
        self._waiting[name] = connection
        self._seen += 1


    def _node_hello(self, connection):
        """Node side of the handshake: introduce ourselves to the main.

        Sends "<tag> <hostid>" and expects "wait" back; on success the
        connection is recorded under our own id and self._seen becomes 1.

        @param connection: A (socket, address) tuple connected to the main.
        """
        client, addr = connection

        client.settimeout(5)
        try:
            self._send(client, self._tag + " " + self._hostid)

            reply = self._recv(client, 4).strip("\r\n")
            logging.info("main said: %s", reply)
            # Confirm the main accepted the connection.
            if reply != "wait":
                logging.warning("Bad connection request to main")
                client.close()
                return

        except socket.timeout:
            # This is nominally an error, but as we do not know
            # who that was we cannot do anything sane other
            # than report it and let the normal timeout kill
            # us when that's appropriate.
            logging.error("main handshake timeout: (%s:%d)",
                          addr[0], addr[1])
            client.close()
            return

        logging.info("node now waiting: (%s:%d)", addr[0], addr[1])

        # They seem to be valid, record them.
        self._waiting[self._hostid] = connection
        self._seen = 1


    def _main_release(self):
        """Ping every waiting member, then release (or abort) them all.

        @raises error.BarrierError: If any member failed to answer the ping.
        @raises BarrierAbortError: If we or any member requested an abort
                (the abort message is still sent to every member first).
        """
        # Check everyone is still there, that they have not
        # crashed or disconnected in the meantime.
        allpresent = True
        abort = self._abort
        for name, (client, _addr) in self._waiting.items():
            logging.info("checking client present: %s", name)

            client.settimeout(5)
            reply = 'none'
            try:
                self._send(client, "ping")
                reply = self._recv(client, 1024)
            except socket.timeout:
                logging.warning("ping/pong timeout: %s", name)

            if reply == 'abrt':
                logging.warning("Client %s requested abort", name)
                abort = True
            elif reply != "pong":
                allpresent = False

        if not allpresent:
            raise error.BarrierError("main lost client")

        if abort:
            logging.info("Aborting the clients")
            msg = 'abrt'
        else:
            logging.info("Releasing clients")
            msg = 'rlse'

        # Everyone checked in, so commit the release (or abort).
        for name, (client, _addr) in self._waiting.items():
            client.settimeout(5)
            try:
                self._send(client, msg)
            except socket.timeout:
                # A failed release still counts as success; see class doc.
                logging.warning("release timeout: %s", name)

        if abort:
            raise BarrierAbortError("Client requested abort")


    def _waiting_close(self):
        """Close every recorded connection (best effort, errors logged).

        Members that were never released see the connection close and
        treat that as a barrier failure.
        """
        for name, (client, _addr) in self._waiting.items():
            logging.info("closing client: %s", name)
            try:
                client.close()
            except Exception:
                # One failed close must not prevent closing the others.
                logging.debug("error closing client: %s", name,
                              exc_info=True)


    def _try_complete(self, is_main):
        """Finish the barrier if this member now has everything it needs.

        @param is_main: True when acting as the main.
        @return: True if the barrier was completed (the caller should stop
                its connection loop), False to keep waiting.
        """
        if is_main:
            # Check if everyone is here.
            logging.info("main seen %d of %d",
                         self._seen, len(self._members))
            if self._seen == len(self._members):
                self._main_release()
                return True
        elif self._seen:
            # The main connected to us (or we to it).
            logging.info("node connected to main")
            self._node_wait()
            return True
        return False


    def _run_server(self, is_main):
        """Accept connections until the barrier completes or times out.

        @param is_main: True when acting as the main (rendezvous), False when
                acting as a listening node (rendezvous_servers).
        """
        server = self._server or listen_server(port=self._port)
        try:
            while True:
                try:
                    # Wait for callers, welcoming each.
                    server.socket.settimeout(self._remaining())
                    connection = server.socket.accept()
                    if is_main:
                        self._main_welcome(connection)
                    else:
                        self._node_hello(connection)
                except socket.timeout:
                    logging.warning("timeout waiting for remaining clients")

                if self._try_complete(is_main):
                    break
        finally:
            self._waiting_close()
            # If we created the listen_server at the beginning of this
            # function then close the listening socket here.
            if not self._server:
                server.close()


    def _run_client(self, is_main):
        """Connect out, retrying, until the barrier completes or times out.

        @param is_main: True when acting as the connecting main
                (rendezvous_servers), False when acting as a node connecting
                to the main (rendezvous).
        @raises socket.error: On connection errors other than refused/timed
                out, which are retried after a pause.
        """
        try:
            while True:
                remaining = self._remaining()
                if remaining is not None and remaining <= 0:
                    break

                if is_main:
                    # Connect to the next node that has not checked in yet.
                    host = _get_host_from_id(self._members[self._seen])
                    logging.info("calling node: %s", host)
                else:
                    # Just connect to the main.
                    host = _get_host_from_id(self._mainid)
                    logging.info("calling main")

                remote = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                remote.settimeout(30)
                connection = (remote, (host, self._port))
                try:
                    remote.connect(connection[1])
                    if is_main:
                        self._main_welcome(connection)
                    else:
                        self._node_hello(connection)
                except socket.timeout:
                    logging.warning("timeout calling host, retry")
                    remote.close()
                    sleep(10)
                except socket.error as err:
                    # The failed socket is never reused; don't leak it.
                    remote.close()
                    code = getattr(err, 'errno', None)
                    if code not in (errno.ECONNREFUSED, errno.ETIMEDOUT):
                        raise
                    sleep(10)

                if self._try_complete(is_main):
                    break
        finally:
            self._waiting_close()


    def _node_wait(self):
        """Node side after the handshake: answer pings until closed.

        Replies "pong" (or "abrt" if we requested an abort) to each "ping",
        then interprets the last message received before the main closed
        the connection.

        @raises BarrierAbortError: If the main sent "abrt".
        @raises error.BarrierError: If the connection closed before a
                release, or the main sent an unexpected message.
        """
        remote = self._waiting[self._hostid][0]
        mode = "wait"
        while True:
            # All control messages are the same size to allow
            # us to split individual messages easily.
            remote.settimeout(self._remaining())
            reply = self._recv(remote, 4)
            if not reply:
                break

            reply = reply.strip("\r\n")
            logging.info("main said: %s", reply)

            mode = reply
            if reply == "ping":
                # Ensure we have sufficient time for the
                # ping/pong/rlse cycle to complete normally.
                self._update_timeout(10 + 10 * len(self._members))

                msg = "abrt" if self._abort else "pong"
                logging.info(msg)
                remote.settimeout(self._remaining())
                self._send(remote, msg)

            elif reply in ("rlse", "abrt"):
                # Ensure we have sufficient time for the
                # ping/pong/rlse cycle to complete normally.
                self._update_timeout(10 + 10 * len(self._members))

                logging.info("was released, waiting for close")

        if mode == "rlse":
            return
        if mode == "abrt":
            raise BarrierAbortError("Client requested abort")
        failures = {
            "wait": "main abort -- barrier timeout",
            "ping": "main abort -- client lost",
            "!tag": "main abort -- incorrect tag",
            "!dup": "main abort -- duplicate client",
        }
        raise error.BarrierError(
                failures.get(mode, "main handshake failure: " + mode))


    def rendezvous(self, *hosts, **dargs):
        """Block until every listed member has reached this barrier.

        The first id in sorted order becomes the main and listens; every
        other member connects to it.

        @param hosts: Ids of the barrier members (normally including this
                member's own id).
        @param dargs: Optional abort=True, which makes the barrier raise an
                exception on all the clients instead of releasing them.
        @raises error.BarrierError: On timeout or a lost/misbehaving member.
        @raises BarrierAbortError: If any member requested an abort.
        """
        self._start_time = time()
        self._members = sorted(hosts)
        self._mainid = self._members.pop(0)
        self._abort = dargs.get('abort', False)

        logging.info("mainid: %s", self._mainid)
        if self._abort:
            logging.debug("%s is aborting", self._hostid)
        if not self._members:
            logging.info("No other members listed.")
            return
        logging.info("members: %s", ",".join(self._members))

        self._seen = 0
        self._waiting = {}

        # Figure out who is the main in this barrier.
        if self._hostid == self._mainid:
            logging.info("selected as main")
            self._run_server(is_main=True)
        else:
            logging.info("selected as node")
            self._run_client(is_main=False)


    def rendezvous_servers(self, mainid, *hosts, **dargs):
        """Like rendezvous(), but with the connection direction reversed.

        Every member in `hosts` listens and `mainid` connects out to each
        of them, for setups where only the main can initiate connections.

        @param mainid: Id of the member acting as main (the connector).
        @param hosts: Ids of the other (listening) members.
        @param dargs: Optional abort=True, which makes the barrier raise an
                exception on all the clients instead of releasing them.
        @raises error.BarrierError: On timeout or a lost/misbehaving member.
        @raises BarrierAbortError: If any member requested an abort.
        """
        self._start_time = time()
        self._members = sorted(hosts)
        self._mainid = mainid
        self._abort = dargs.get('abort', False)

        logging.info("mainid: %s", self._mainid)
        if self._abort:
            logging.debug("%s is aborting", self._hostid)
        if not self._members:
            logging.info("No other members listed.")
            return
        logging.info("members: %s", ",".join(self._members))

        self._seen = 0
        self._waiting = {}

        # Figure out who is the main in this barrier.
        if self._hostid == self._mainid:
            logging.info("selected as main")
            self._run_client(is_main=True)
        else:
            logging.info("selected as node")
            self._run_server(is_main=False)
548