1from __future__ import print_function
2
3import base64
4import contextlib
5import copy
6import email.utils
7import functools
8import gzip
9import hashlib
10import httplib2
11import os
12import random
13import re
14import shutil
15import six
16import socket
17import ssl
18import struct
19import sys
20import threading
21import time
22import traceback
23import zlib
24from six.moves import http_client, queue
25
26
# Loopback placeholder endpoints on ports 1 and 2; presumably chosen so that
# nothing listens there (TODO confirm against the tests that use them).
DUMMY_URL, DUMMY_HTTPS_URL = "http://127.0.0.1:1", "https://127.0.0.1:2"

# TLS fixture files live in the "tls" directory next to this module.
tls_dir = os.path.join(os.path.dirname(__file__), "tls")
(
    CA_CERTS,
    CA_UNUSED_CERTS,
    CLIENT_PEM,
    CLIENT_ENCRYPTED_PEM,
    SERVER_PEM,
    SERVER_CHAIN,
) = (
    os.path.join(tls_dir, fixture_name)
    for fixture_name in (
        "ca.pem",
        "ca_unused.pem",
        "client.pem",
        "client_encrypted.pem",
        "server.pem",
        "server_chain.pem",
    )
)
37
38
@contextlib.contextmanager
def assert_raises(exc_type):
    """Context manager that fails unless its body raises ``exc_type``.

    :param exc_type: exception class, or tuple of classes, that is expected.
    :raises AssertionError: if the body completes without raising.
        Exceptions of any other type propagate unchanged.
    """

    def _name(t):
        return getattr(t, "__name__", None) or str(t)

    if not isinstance(exc_type, tuple):
        exc_type = (exc_type,)
    names = ", ".join(map(_name, exc_type))

    try:
        yield
    except exc_type:
        pass
    else:
        # Explicit raise rather than ``assert False``: assert statements are
        # stripped under ``python -O``, which would make this check a no-op.
        raise AssertionError("Expected exception(s) {0}".format(names))
54
55
class BufferedReader(object):
    """Small buffered reader over a socket, or over a fixed bytes payload,
    whose readline() splits on CRLF (b"\\r\\n") rather than LF.
    """

    def __init__(self, sock):
        self._newline = b"\r\n"
        self._end = False
        if isinstance(sock, bytes):
            # Bytes input (used by HttpMessage.from_bytes): the payload is the
            # whole buffer and there is no socket to recv() from.
            self._sock, self._buf = None, sock
        else:
            self._sock, self._buf = sock, b""

    def _fill(self, target=1, more=None, untilend=False):
        """Grow the buffer to ``target`` bytes (or ``more`` extra bytes).

        With ``untilend`` keep reading until EOF and return quietly;
        otherwise hitting EOF first raises EOFError. Either way EOF sets
        ``_end``.
        """
        if more:
            target = len(self._buf) + more
        while untilend or len(self._buf) < target:
            chunk = self._sock.recv(8 * 1024) if self._sock is not None else b""
            if chunk:
                self._buf += chunk
                continue
            self._end = True
            if not untilend:
                raise EOFError
            return

    def peek(self, size):
        """Return up to ``size`` buffered bytes without consuming them."""
        self._fill(target=size)
        return self._buf[:size]

    def read(self, size):
        """Consume and return exactly ``size`` bytes (EOFError if short)."""
        self._fill(target=size)
        head = self._buf[:size]
        self._buf = self._buf[size:]
        return head

    def readall(self):
        """Consume and return everything up to EOF."""
        self._fill(untilend=True)
        rest = self._buf
        self._buf = b""
        return rest

    def readline(self):
        """Consume and return one line including its CRLF terminator."""
        idx = self._buf.find(self._newline)
        while idx < 0:
            self._fill(more=1)
            idx = self._buf.find(self._newline)
        cut = idx + len(self._newline)
        line = self._buf[:cut]
        self._buf = self._buf[cut:]
        return line
110
111
def parse_http_message(kind, buf):
    """Parse one HTTP/1.x message from ``buf``.

    :param kind: HttpRequest or HttpResponse; selects the start-line grammar
        and the attributes set (method/uri or status/reason, plus proto,
        version, headers, body, body_raw and raw).
    :param buf: BufferedReader positioned at the start of a message.
    :returns: the populated ``kind`` instance, or None when ``buf`` is
        already exhausted before a start line could be read.
    :raises NotImplementedError: for ``transfer-encoding: chunked`` bodies.
    :raises Exception: if ``kind`` is neither HttpRequest nor HttpResponse.
    """
    if buf._end:
        return None
    try:
        start_line = buf.readline()
    except EOFError:
        return None
    msg = kind()
    msg.raw = start_line
    if kind is HttpRequest:
        assert re.match(
            br".+ HTTP/\d\.\d\r\n$", start_line
        ), "Start line does not look like HTTP request: " + repr(start_line)
        msg.method, msg.uri, msg.proto = start_line.rstrip().decode().split(" ", 2)
        assert msg.proto.startswith("HTTP/"), repr(start_line)
    elif kind is HttpResponse:
        assert re.match(
            br"^HTTP/\d\.\d \d+ .+\r\n$", start_line
        ), "Start line does not look like HTTP response: " + repr(start_line)
        msg.proto, msg.status, msg.reason = start_line.rstrip().decode().split(" ", 2)
        msg.status = int(msg.status)
        assert msg.proto.startswith("HTTP/"), repr(start_line)
    else:
        raise Exception("Use HttpRequest or HttpResponse .from_{bytes,buffered}")
    # "HTTP/1.1" -> "1.1"
    msg.version = msg.proto[5:]

    # Header block ends with an empty line; names are stored lower-cased.
    while True:
        line = buf.readline()
        msg.raw += line
        line = line.rstrip()
        if not line:
            break
        t = line.decode().split(":", 1)
        msg.headers[t[0].lower()] = t[1].lstrip()

    content_length_string = msg.headers.get("content-length", "")
    if content_length_string.isdigit():
        content_length = int(content_length_string)
        msg.body = msg.body_raw = buf.read(content_length)
    elif msg.headers.get("transfer-encoding") == "chunked":
        # NotImplemented is a constant, not an exception: raising it produced
        # a confusing TypeError instead of a clear "unsupported" error.
        raise NotImplementedError(
            "parse_http_message: chunked transfer-encoding is not supported"
        )
    elif msg.version == "1.0":
        # HTTP/1.0 without content-length: body runs until EOF.
        msg.body = msg.body_raw = buf.readall()
    else:
        msg.body = msg.body_raw = b""

    msg.raw += msg.body_raw
    return msg
160
161
class HttpMessage(object):
    """Base class for parsed HTTP messages.

    Subclasses (HttpRequest / HttpResponse) choose which start-line grammar
    parse_http_message applies.
    """

    def __init__(self):
        # Lower-cased header name -> value, filled in by parse_http_message.
        self.headers = {}

    @classmethod
    def from_bytes(cls, bs):
        """Parse a complete message held in the bytes object ``bs``."""
        return parse_http_message(cls, BufferedReader(bs))

    @classmethod
    def from_buffered(cls, buf):
        """Parse the next message from an existing BufferedReader."""
        return parse_http_message(cls, buf)

    def __repr__(self):
        return "{0} {1!r}".format(self.__class__, vars(self))
177
178
class HttpRequest(HttpMessage):
    """HTTP request as produced by parse_http_message (method, uri, proto,
    version, headers, body, raw); the server_* helpers additionally set
    ``client_sock`` and ``number`` while handling it."""
181
182
class HttpResponse(HttpMessage):
    """HTTP response as produced by parse_http_message (proto, integer
    status, reason, version, headers, body, raw)."""
185
186
class MockResponse(six.BytesIO):
    """In-memory response: readable body plus keyword arguments kept as
    a header mapping exposed through items()/iteritems()."""

    def __init__(self, body, **headers):
        six.BytesIO.__init__(self, body)
        self.headers = headers

    def iteritems(self):
        return six.iteritems(self.headers)

    def items(self):
        return self.headers.items()
197
198
class MockHTTPConnection(object):
    """Network-free stand-in for httplib.HTTPConnection.

    Every getresponse() returns a MockResponse with body b"the body" and
    status "200".
    """

    def __init__(
        self,
        host,
        port=None,
        key_file=None,
        cert_file=None,
        strict=None,
        timeout=None,
        proxy_info=None,
    ):
        # key_file, cert_file, strict and proxy_info are accepted only for
        # signature compatibility; they are ignored.
        self.host, self.port, self.timeout = host, port, timeout
        self.log = ""
        self.sock = None

    def set_debuglevel(self, level):
        """No-op."""

    def connect(self):
        """No-op: there is no real host to connect to."""

    def close(self):
        """No-op."""

    def request(self, method, request_uri, body, headers):
        """Discard the request."""

    def getresponse(self):
        return MockResponse(b"the body", status="200")
234
235
class MockHTTPBadStatusConnection(object):
    """httplib.HTTPConnection stand-in whose getresponse() always raises
    BadStatusLine; ``num_calls`` counts getresponse() attempts since the
    most recent construction."""

    num_calls = 0

    def __init__(
        self,
        host,
        port=None,
        key_file=None,
        cert_file=None,
        strict=None,
        timeout=None,
        proxy_info=None,
    ):
        # Unused TLS/proxy arguments are accepted for signature compatibility.
        self.host, self.port, self.timeout = host, port, timeout
        self.log = ""
        self.sock = None
        # Class-level counter is reset per connection object.
        MockHTTPBadStatusConnection.num_calls = 0

    def set_debuglevel(self, level):
        """No-op."""

    def connect(self):
        """No-op."""

    def close(self):
        """No-op."""

    def request(self, method, request_uri, body, headers):
        """Discard the request."""

    def getresponse(self):
        """Count the attempt, then fail with an empty BadStatusLine."""
        MockHTTPBadStatusConnection.num_calls += 1
        raise http_client.BadStatusLine("")
274
275
@contextlib.contextmanager
def server_socket(fun, request_count=1, timeout=5, scheme="", tls=None):
    """Base socket server for tests.
    Likely you want to use server_request or other higher level helpers.
    All arguments except fun can be passed to other server_* helpers.

    :param fun: fun(client_sock, tick) called after successful accept().
    :param request_count: test succeeds after exactly this number of requests, triggered by tick(request)
    :param timeout: seconds.
    :param scheme: affects yielded value
        "" - build normal http/https URI.
        string - build normal URI using supplied scheme.
        None - yield (addr, port) tuple.
    :param tls:
        None (default) - plain HTTP.
        True - HTTPS with reasonable defaults. Likely you want httplib2.Http(ca_certs=tests.CA_CERTS)
        string - path to custom server cert+key PEM file.
        callable - function(context, listener, skip_errors) -> ssl_wrapped_listener
    :raises: after a with-body that completed normally, re-raises any
        exception recorded by the server thread (including a request count
        above ``request_count``).
    """
    gresult = [None]
    gcounter = [0]
    tls_skip_errors = [
        "TLSV1_ALERT_UNKNOWN_CA",
    ]

    def tick(request):
        # Count one handled request; return True if the handler should keep
        # serving on this connection (more requests expected and the client
        # did not ask for "Connection: close").
        gcounter[0] += 1
        keep = True
        keep &= gcounter[0] < request_count
        if request is not None:
            keep &= request.headers.get("connection", "").lower() != "close"
        return keep

    def server_socket_thread(srv):
        try:
            while gcounter[0] < request_count:
                try:
                    client, _ = srv.accept()
                except ssl.SSLError as e:
                    if e.reason in tls_skip_errors:
                        return
                    raise

                try:
                    client.settimeout(timeout)
                    fun(client, tick)
                finally:
                    try:
                        client.shutdown(socket.SHUT_RDWR)
                    except (IOError, socket.error):
                        pass
                    # FIXME: client.close() introduces connection reset by peer
                    # at least in other/connection_close test
                    # should not be a problem since socket would close upon garbage collection
            if gcounter[0] > request_count:
                gresult[0] = Exception(
                    "Request count expected={0} actual={1}".format(
                        request_count, gcounter[0]
                    )
                )
        except Exception as e:
            # traceback.print_exc caused IOError: concurrent operation on sys.stderr.close() under setup.py test
            print(traceback.format_exc(), file=sys.stderr)
            gresult[0] = e

    bind_hostname = "localhost"
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # SO_REUSEADDR only affects a socket if it is set before bind().
    try:
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    except socket.error as ex:
        print("non critical error on SO_REUSEADDR", ex)
    server.bind((bind_hostname, 0))
    server.listen(10)
    server.settimeout(timeout)
    server_port = server.getsockname()[1]
    if tls is True:
        tls = SERVER_CHAIN
    if tls:
        context = ssl_context()
        if callable(tls):
            context.load_cert_chain(SERVER_CHAIN)
            server = tls(context, server, tls_skip_errors)
        else:
            context.load_cert_chain(tls)
            server = context.wrap_socket(server, server_side=True)
    if scheme == "":
        scheme = "https" if tls else "http"

    t = threading.Thread(target=server_socket_thread, args=(server,))
    t.daemon = True
    t.start()
    try:
        if scheme is None:
            yield (bind_hostname, server_port)
        else:
            yield u"{scheme}://{host}:{port}/".format(scheme=scheme, host=bind_hostname, port=server_port)
    finally:
        # Release the listening socket even when the with-body raised; the
        # body's exception then propagates without waiting on the thread.
        server.close()
    t.join()
    if gresult[0] is not None:
        raise gresult[0]
375
376
def server_yield(fun, **kwargs):
    """Serve responses produced by a generator-style callback.

    ``fun`` is called once with a blocking ``get`` that returns the next
    parsed request; it must return an iterator whose successive items are
    the raw response bytes sent back, one per request.
    """
    pending = queue.Queue(1)
    responses = fun(pending.get)

    def server_yield_socket_handler(sock, tick):
        reader = BufferedReader(sock)
        number = 0
        request = HttpRequest.from_buffered(reader)
        while request is not None:
            number += 1
            request.client_sock = sock
            request.number = number
            pending.put(request)
            sock.sendall(six.next(responses))
            request.client_sock = None
            if not tick(request):
                break
            request = HttpRequest.from_buffered(reader)

    return server_socket(server_yield_socket_handler, **kwargs)
399
400
def server_request(request_handler, **kwargs):
    """Run a test server that answers each parsed request with the bytes
    returned by ``request_handler(request=...)``.

    Remaining keyword arguments are forwarded to server_socket.
    """

    def server_request_socket_handler(sock, tick):
        reader = BufferedReader(sock)
        served = 0
        while True:
            request = HttpRequest.from_buffered(reader)
            if request is None:
                return
            served += 1
            request.client_sock, request.number = sock, served
            sock.sendall(request_handler(request=request))
            request.client_sock = None
            if not tick(request):
                return

    return server_socket(server_request_socket_handler, **kwargs)
421
422
def server_const_bytes(response_content, **kwargs):
    """Serve the same pre-built ``response_content`` bytes to every request."""

    def const_handler(request):
        return response_content

    return server_request(const_handler, **kwargs)
425
426
# Keyword arguments that belong to http_response_bytes(). server_const_http and
# server_reflect pop these out of their **kwargs and forward the remainder to
# the server_* helpers.
_http_kwargs = (
    "proto",
    "status",
    "headers",
    "body",
    "add_content_length",
    "add_date",
    "add_etag",
    "undefined_body_length",
)
437
438
def http_response_bytes(
    proto="HTTP/1.1",
    status="200 OK",
    headers=None,
    body=b"",
    add_content_length=True,
    add_date=False,
    add_etag=False,
    undefined_body_length=False,
    **kwargs
):
    """Serialize an HTTP response to bytes.

    :param proto: protocol token for the status line.
    :param status: full "code reason" string, or a bare numeric code (int or
        digit string) whose reason phrase is looked up in
        http_client.responses.
    :param headers: mapping of header name -> value; the caller's mapping is
        never modified.
    :param body: body bytes appended after the header block.
    :param add_content_length: add content-length if absent.
    :param add_date: add a date header if absent.
    :param add_etag: add an etag (md5 hex of body) if absent.
    :param undefined_body_length: send no content-length at all; also
        disables add_content_length.
    :param kwargs: extra keyword arguments are accepted and ignored.
    :raises Exception: if the body length cannot be determined by the client
        (no content-length, not HTTP/1.0, and not undefined_body_length).
    """
    if undefined_body_length:
        add_content_length = False
    # Work on a copy: setdefault() below must not leak generated headers into
    # the caller's dict, or a reused dict would carry a stale content-length.
    headers = dict(headers or {})
    if add_content_length:
        headers.setdefault("content-length", str(len(body)))
    if add_date:
        headers.setdefault("date", email.utils.formatdate())
    if add_etag:
        headers.setdefault("etag", '"{0}"'.format(hashlib.md5(body).hexdigest()))
    header_string = "".join("{0}: {1}\r\n".format(k, v) for k, v in headers.items())
    if (
        not undefined_body_length
        and proto != "HTTP/1.0"
        and "content-length" not in headers
    ):
        raise Exception(
            "httplib2.tests.http_response_bytes: client could not figure response body length"
        )
    if str(status).isdigit():
        # responses is keyed by int; a digit string such as "200" would
        # otherwise raise KeyError.
        status = "{} {}".format(status, http_client.responses[int(status)])
    response = (
        "{proto} {status}\r\n{headers}\r\n".format(
            proto=proto, status=status, headers=header_string
        ).encode()
        + body
    )
    return response
478
479
def make_http_reflect(**kwargs):
    """Return a handler whose response body is the raw bytes of the request.

    ``kwargs`` go to http_response_bytes for every response; ``body`` is
    reserved because the handler fills it in.
    """
    assert "body" not in kwargs, "make_http_reflect will overwrite response body"

    def reflect(request):
        # Deep copy so per-request edits never touch the shared kwargs.
        response_kwargs = copy.deepcopy(kwargs)
        response_kwargs["body"] = request.raw
        return http_response_bytes(**response_kwargs)

    return reflect
490
491
def server_route(routes, **kwargs):
    """Serve responses chosen by request URI.

    ``routes`` maps URI -> response bytes or callable(request=...); the
    "" key is the fallback for unknown URIs, and a missing or falsy target
    yields a 404.
    """
    not_found = http_response_bytes(status="404 Not Found")
    fallback = routes.get("")

    def handler(request):
        target = routes.get(request.uri, fallback) or not_found
        return target(request=request) if callable(target) else target

    return server_request(handler, **kwargs)
505
506
def server_const_http(**kwargs):
    """Serve one fixed response built by http_response_bytes.

    Keywords named in _http_kwargs build the response; the rest configure
    the server.
    """
    response_kwargs = {}
    for key in list(kwargs):
        if key in _http_kwargs:
            response_kwargs[key] = kwargs.pop(key)
    return server_const_bytes(http_response_bytes(**response_kwargs), **kwargs)
511
512
def server_list_http(responses, **kwargs):
    """Answer successive requests with successive items of ``responses``.

    ``request_count`` defaults to ``len(responses)``.
    """
    remaining = iter(responses)
    kwargs.setdefault("request_count", len(responses))
    return server_request(lambda request: next(remaining), **kwargs)
521
522
def server_reflect(**kwargs):
    """Serve responses whose body echoes the raw request.

    Keywords named in _http_kwargs shape the response; the rest configure
    the server.
    """
    reflect_kwargs = dict(
        (key, kwargs.pop(key)) for key in list(kwargs) if key in _http_kwargs
    )
    return server_request(make_http_reflect(**reflect_kwargs), **kwargs)
527
528
def http_parse_auth(s):
    """Parse auth parameters from a credentials/challenge header value.

    https://tools.ietf.org/html/rfc7235#section-2.1

    ``s`` is expected to start with the auth-scheme token (e.g. "Digest"),
    which is skipped; the remaining parameters are extracted with
    httplib2.WWW_AUTH_RELAXED and unquoted with httplib2.UNQUOTE_PAIRS.

    :returns: dict of lower-cased parameter name -> value; empty when ``s``
        holds only a scheme.
    """
    # partition() instead of split(): a scheme-only value gives {} rather
    # than a ValueError from unpacking.
    rest = s.partition(" ")[2]
    result = {}
    while True:
        m = httplib2.WWW_AUTH_RELAXED.search(rest)
        if not m:
            break
        # The pattern captures (key, value, remainder). Unpacking directly
        # fails loudly if that ever changes, instead of spinning forever on
        # an unchanged ``rest``.
        key, value, rest = m.groups()
        result[key.lower()] = httplib2.UNQUOTE_PAIRS.sub(r"\1", value)
    return result
542
543
def store_request_response(out):
    """Decorator factory: record (request, parsed HttpResponse) pairs.

    When ``out`` is a list, each call of the wrapped handler appends the
    request and its response bytes parsed with HttpResponse.from_bytes;
    when ``out`` is None the handler runs unchanged. The raw response bytes
    are always returned.
    """

    def decorate(handler):
        @functools.wraps(handler)
        def recording(request, *args, **kwargs):
            raw = handler(request, *args, **kwargs)
            if out is None:
                return raw
            out.append((request, HttpResponse.from_bytes(raw)))
            return raw

        return recording

    return decorate
557
558
def http_reflect_with_auth(
    allow_scheme, allow_credentials, out_renew_nonce=None, out_requests=None
):
    """Build a handler that reflects requests only after HTTP auth succeeds.

    allow_scheme - 'basic' or 'digest'.
    allow_credentials - sequence of ('name', 'password')
    out_renew_nonce - None | [function]
        Way to return nonce renew function to caller.
        Kind of `out` parameter in some programming languages.
        Allows to keep same signature for all handler builder functions.
    out_requests - None | []
        If set to list, every handled request is appended here as a
        (request, parsed response) pair.
    """
    glastnc = [None]
    gnextnonce = [None]
    gserver_nonce = [gen_digest_nonce(salt=b"n")]
    realm = "httplib2 test"
    server_opaque = gen_digest_nonce(salt=b"o")

    def renew_nonce():
        """Schedule a server nonce change that takes effect on the next
        digest request; returns (current nonce, upcoming nonce)."""
        if gnextnonce[0]:
            raise AssertionError(
                "previous nextnonce was not used, probably bug in test code"
            )
        gnextnonce[0] = gen_digest_nonce()
        return gserver_nonce[0], gnextnonce[0]

    if out_renew_nonce:
        out_renew_nonce[0] = renew_nonce

    def deny(**kwargs):
        """Build a 401 challenge for allow_scheme.

        nonce_stale=True adds stale=true (digest) and a default body; other
        kwargs go to http_response_bytes, and supplied headers override the
        generated www-authenticate.
        """
        nonce_stale = kwargs.pop("nonce_stale", False)
        if nonce_stale:
            kwargs.setdefault("body", b"nonce stale")
        if allow_scheme == "basic":
            authenticate = 'basic realm="{realm}"'.format(realm=realm)
        elif allow_scheme == "digest":
            authenticate = (
                'digest realm="{realm}", qop="auth"'
                + ', nonce="{nonce}", opaque="{opaque}"'
                + (", stale=true" if nonce_stale else "")
            ).format(realm=realm, nonce=gserver_nonce[0], opaque=server_opaque)
        else:
            raise Exception("unknown allow_scheme={0}".format(allow_scheme))
        deny_headers = {"www-authenticate": authenticate}
        kwargs.setdefault("status", 401)
        # supplied headers may overwrite generated ones
        deny_headers.update(kwargs.get("headers", {}))
        kwargs["headers"] = deny_headers
        kwargs.setdefault("body", b"HTTP authorization required")
        return http_response_bytes(**kwargs)

    @store_request_response(out_requests)
    def http_reflect_with_auth_handler(request):
        auth_header = request.headers.get("authorization", "")
        if not auth_header:
            return deny()
        if " " not in auth_header:
            return http_response_bytes(
                status=400, body=b"authorization header syntax error"
            )
        scheme, data = auth_header.split(" ", 1)
        scheme = scheme.lower()
        if scheme != allow_scheme:
            return deny(body=b"must use different auth scheme")
        if scheme == "basic":
            try:
                decoded = base64.b64decode(data).decode()
                username, password = decoded.split(":", 1)
            except (TypeError, ValueError):
                # Bad base64 / non-decodable bytes / no ':' separator: answer
                # with a client error instead of crashing the server thread.
                return http_response_bytes(
                    status=400, body=b"authorization header syntax error"
                )
            if (username, password) in allow_credentials:
                return make_http_reflect()(request)
            return deny(body=b"supplied credentials are not allowed")
        elif scheme == "digest":
            server_nonce_old = gserver_nonce[0]
            nextnonce = gnextnonce[0]
            if nextnonce:
                # server decided to change nonce, in this case, guided by caller test code
                gserver_nonce[0] = nextnonce
                gnextnonce[0] = None
            server_nonce_current = gserver_nonce[0]
            # http_parse_auth skips the leading scheme token itself, so it must
            # receive the whole header; passing ``data`` dropped the first
            # parameter (username=...).
            auth_info = http_parse_auth(auth_header)
            client_cnonce = auth_info.get("cnonce", "")
            client_nc = auth_info.get("nc", "")
            client_nonce = auth_info.get("nonce", "")
            client_opaque = auth_info.get("opaque", "")
            client_qop = auth_info.get("qop", "auth").strip('"')

            # TODO: auth_info.get('algorithm', 'md5')
            hasher = hashlib.md5

            # TODO: client_qop auth-int
            ha2 = hasher(":".join((request.method, request.uri)).encode()).hexdigest()

            if client_nonce != server_nonce_current:
                if client_nonce == server_nonce_old:
                    return deny(nonce_stale=True)
                return deny(body=b"invalid nonce")
            if not client_nc:
                return deny(body=b"auth-info nc missing")
            if client_opaque != server_opaque:
                return deny(
                    body="auth-info opaque mismatch expected={} actual={}".format(
                        server_opaque, client_opaque
                    ).encode()
                )
            for allow_username, allow_password in allow_credentials:
                ha1 = hasher(
                    ":".join((allow_username, realm, allow_password)).encode()
                ).hexdigest()
                allow_response = hasher(
                    ":".join(
                        (ha1, client_nonce, client_nc, client_cnonce, client_qop, ha2)
                    ).encode()
                ).hexdigest()
                if auth_info.get("response", "") != allow_response:
                    continue
                # TODO: fix or remove doubtful comment
                # do we need to save nc only on success?
                # NOTE(review): glastnc is recorded but never checked for replay.
                glastnc[0] = client_nc
                # rspauth is only needed once a credential matched.
                rspauth_ha2 = hasher(":{}".format(request.uri).encode()).hexdigest()
                rspauth = hasher(
                    ":".join(
                        (
                            ha1,
                            client_nonce,
                            client_nc,
                            client_cnonce,
                            client_qop,
                            rspauth_ha2,
                        )
                    ).encode()
                ).hexdigest()
                allow_headers = {
                    "authentication-info": " ".join(
                        (
                            'nextnonce="{}"'.format(nextnonce) if nextnonce else "",
                            "qop={}".format(client_qop),
                            'rspauth="{}"'.format(rspauth),
                            'cnonce="{}"'.format(client_cnonce),
                            "nc={}".format(client_nc),
                        )
                    ).strip()
                }
                return make_http_reflect(headers=allow_headers)(request)
            return deny(body=b"supplied credentials are not allowed")
        else:
            return http_response_bytes(
                status=400,
                body="unknown authorization scheme={0}".format(scheme).encode(),
            )

    return http_reflect_with_auth_handler
708
709
def get_cache_path():
    """Return a freshly emptied cache directory path for tests.

    Uses $httplib2_test_cache_path when set and non-empty, otherwise
    "./_httplib2_test_cache"; any existing tree at that path is removed
    (the directory itself is not recreated).
    """
    path = os.environ.get("httplib2_test_cache_path")
    if not path:
        path = "./_httplib2_test_cache"
    if os.path.exists(path):
        shutil.rmtree(path)
    return path
716
717
def gen_digest_nonce(salt=b""):
    """Return a base64 text nonce: 8-byte big-endian nanosecond timestamp,
    b":", then SHA-1 of (timestamp + salt)."""
    stamp = struct.pack(">Q", int(time.time() * 1e9))
    digest = hashlib.sha1(stamp + salt).digest()
    token = base64.b64encode(b":".join((stamp, digest)))
    return token.decode()
721
722
def gen_password():
    """Return a random 8-64 character string of code points 0-127
    (control characters included); not cryptographically secure."""
    size = random.randint(8, 64)
    chars = [six.unichr(random.randint(0, 127)) for _ in range(size)]
    return "".join(chars)
726
727
def gzip_compress(bs):
    """Return ``bs`` compressed as a single gzip member (level 6)."""
    import io

    buf = io.BytesIO()
    # The context manager writes the gzip trailer and releases the stream
    # even if write() raises.
    with gzip.GzipFile(fileobj=buf, mode="wb", compresslevel=6) as gf:
        gf.write(bs)
    return buf.getvalue()
736
737
def gzip_decompress(bs):
    """Decompress gzip-wrapped ``bs`` (zlib with the +16 gzip wbits flag)."""
    gzip_wbits = zlib.MAX_WBITS | 16
    return zlib.decompress(bs, gzip_wbits)
740
741
def deflate_compress(bs):
    """Compress ``bs`` as a raw deflate stream (no zlib header), level 9."""
    raw_wbits = -zlib.MAX_WBITS
    compressor = zlib.compressobj(9, zlib.DEFLATED, raw_wbits)
    head = compressor.compress(bs)
    return head + compressor.flush()
745
746
def deflate_decompress(bs):
    """Inflate a raw deflate stream (no zlib header or trailer)."""
    raw_wbits = -zlib.MAX_WBITS
    return zlib.decompress(bs, raw_wbits)
749
750
def ssl_context(protocol=None):
    """Return a new ``ssl.SSLContext``.

    Workaround for old SSLContext() required protocol argument: on
    interpreters older than 3.5.3 PROTOCOL_SSLv23 is passed explicitly.

    :param protocol: explicit ``ssl.PROTOCOL_*`` constant to use; when None
        the interpreter's default is used. (Previously this argument was
        accepted but silently ignored.)
    """
    if protocol is not None:
        return ssl.SSLContext(protocol)
    if sys.version_info < (3, 5, 3):
        return ssl.SSLContext(ssl.PROTOCOL_SSLv23)
    return ssl.SSLContext()
757