1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
4
5from __future__ import absolute_import, division, print_function
6
7import binascii
8import collections
9import json
10import math
11import os
12import re
13from contextlib import contextmanager
14
15import pytest
16
17import six
18
19from cryptography.exceptions import UnsupportedAlgorithm
20
21import cryptography_vectors
22
23
# Records returned by load_hash_vectors: plain hash vectors carry a message
# and its digest, keyed (HMAC) vectors additionally carry the key.
HashVector = collections.namedtuple("HashVector", ("message", "digest"))
KeyedHashVector = collections.namedtuple(
    "KeyedHashVector", ("message", "digest", "key")
)
28
29
def check_backend_support(backend, item):
    """
    Skip the running test if a ``supported`` marker rejects ``backend``.

    Each ``supported`` marker found on ``item.node`` is examined in turn.
    Its ``only_if`` callable is invoked with ``backend``; on a falsy result
    ``pytest.skip`` is called with the marker's ``skip_message`` followed by
    the backend in parentheses.
    """
    for marker in item.node.iter_markers("supported"):
        options = marker.kwargs
        if options["only_if"](backend):
            continue
        reason = "{0} ({1})".format(options["skip_message"], backend)
        pytest.skip(reason)
36
37
@contextmanager
def raises_unsupported_algorithm(reason):
    """
    Context manager asserting that its body raises ``UnsupportedAlgorithm``.

    The object produced by ``pytest.raises`` is yielded to the ``with`` body.
    After the body has finished, the ``_reason`` attribute of the captured
    exception must be the very same object as ``reason``.
    """
    with pytest.raises(UnsupportedAlgorithm) as captured:
        yield captured

    # Identity rather than equality is checked against the expected reason.
    assert captured.value._reason is reason
44
45
def load_vectors_from_file(filename, loader, mode="r"):
    """
    Open ``filename`` through ``cryptography_vectors.open_vector_file`` with
    the given ``mode`` and return the result of calling ``loader`` on the
    opened file object.
    """
    with cryptography_vectors.open_vector_file(filename, mode) as handle:
        parsed = loader(handle)
    return parsed
49
50
def load_nist_vectors(vector_data):
    """
    Parse NIST style ``Name = Value`` vector data.

    Blank lines, ``#`` comments and ``[...]`` section headers are skipped.
    A ``COUNT`` line starts a new test case; each following
    ``Name = Value`` line is stored in that case under the lower-cased name
    with the value encoded as ASCII bytes. A bare ``FAIL`` line sets
    ``"fail"`` to ``True`` on the current case.

    Only the first ``=`` on a line separates the name from the value, so
    values that contain ``=`` themselves are kept intact.

    :param vector_data: An iterable of text lines.
    :return: A list of dicts, one per ``COUNT`` block.
    :raises ValueError: If a data line contains no ``=`` at all.
    :raises TypeError: If ``FAIL`` or a data line appears before the first
        ``COUNT`` line (there is no test case to store it in yet).
    """
    test_data = None
    data = []

    for line in vector_data:
        line = line.strip()

        # Blank lines, comments, and section headers are ignored
        if not line or line.startswith("#") or (line.startswith("[") and
                                                line.endswith("]")):
            continue

        if line == "FAIL":
            test_data["fail"] = True
            continue

        # Build our data using a simple Key = Value format. Split on the
        # first "=" only so that a value may itself contain "=".
        name, value = [c.strip() for c in line.split("=", 1)]

        # Some tests (PBKDF2) contain \0, which should be interpreted as a
        # null character rather than literal.
        value = value.replace("\\0", "\0")

        # COUNT is a special token that indicates a new block of data
        if name.upper() == "COUNT":
            test_data = {}
            data.append(test_data)
            continue
        # For all other tokens we simply want the name, value stored in
        # the dictionary
        else:
            test_data[name.lower()] = value.encode("ascii")

    return data
85
86
def load_cryptrec_vectors(vector_data):
    """
    Parse CRYPTREC style vector data.

    Lines beginning with ``K``, ``P`` and ``C`` carry a key, a plaintext and
    a ciphertext respectively; the payload is the text following ``" : "``
    with all spaces removed, encoded as ASCII bytes. Every ``C`` line emits
    a dict combining the most recent key and plaintext with that ciphertext.

    :param vector_data: An iterable of text lines.
    :return: A list of dicts with ``key``, ``plaintext`` and ``ciphertext``.
    :raises ValueError: For a non-blank, non-comment line that starts with
        none of ``K``, ``P`` or ``C``.
    """
    results = []

    for raw in vector_data:
        entry = raw.strip()

        # Nothing to do for blank lines and comments
        if not entry or entry.startswith("#"):
            continue

        kind = entry[0]
        if kind not in ("K", "P", "C"):
            raise ValueError("Invalid line in file '{}'".format(entry))

        payload = entry.split(" : ")[1].replace(" ", "").encode("ascii")

        if kind == "K":
            key = payload
        elif kind == "P":
            pt = payload
        else:
            # A C line completes the K+P+C tuple; one K is shared by many
            # P+C pairs
            results.append({
                "key": key,
                "plaintext": pt,
                "ciphertext": payload
            })
    return results
113
114
def load_hash_vectors(vector_data):
    """
    Parse hash and HMAC vector data.

    Blank lines and lines starting with ``#`` or ``[`` are skipped. ``Len``,
    ``Key`` and ``Msg`` lines are remembered until an ``MD`` or ``Output``
    line completes the vector. A ``KeyedHashVector`` is produced when a key
    was seen for the vector, a ``HashVector`` otherwise.

    :param vector_data: An iterable of text lines.
    :return: A list of ``HashVector`` / ``KeyedHashVector`` tuples.
    :raises ValueError: For any other kind of line.
    """
    vectors = []
    key = None
    msg = None
    md = None

    for raw in vector_data:
        text = raw.strip()

        if not text or text.startswith(("#", "[")):
            continue

        if text.startswith("Len"):
            length = int(text.split(" = ")[1])
        elif text.startswith("Key"):
            # Only HMAC vectors carry a key attribute.
            key = text.split(" = ")[1].encode("ascii")
        elif text.startswith("Msg"):
            # The NIST vectors represent an empty string as hex 00, which
            # is not actually empty, so the parsed length decides whether
            # the message text is used at all.
            if length > 0:
                msg = text.split(" = ")[1].encode("ascii")
            else:
                msg = b""
        elif text.startswith(("MD", "Output")):
            md = text.split(" = ")[1]
            # The digest is the last piece of a vector: emit it and reset.
            if key is None:
                record = HashVector(msg, md)
            else:
                record = KeyedHashVector(msg, md, key)
            vectors.append(record)
            key = msg = md = None
        else:
            raise ValueError("Unknown line in hash vector")
    return vectors
152
153
def load_pkcs1_vectors(vector_data):
    """
    Loads data out of RSA PKCS #1 vector files.

    The input is processed line by line as a small state machine. Header
    comments are matched with ``startswith`` against the raw (unstripped)
    line; data lines are stripped before being stored.

    Key material is collected after a ``# Example`` / ``# =====...`` header
    into a private and a public key dict. PSS, OAEP and PKCS#1 v1.5 examples
    are collected separately and attached, as a list under the
    ``"examples"`` key, to the private key dict when the key pair is
    finalized.

    :param vector_data: An iterable of text lines.
    :return: A list of ``(private_key_vector, public_key_vector)`` tuples.
        The key values are integers parsed from hex. Each example maps the
        names ``message``, ``salt``, ``seed``, ``signature`` and
        ``encryption`` (whichever were present) to ASCII-encoded hex bytes.
    """
    private_key_vector = None
    public_key_vector = None
    # Name of the field that subsequent data lines belong to, if any.
    attr = None
    # The key dict (private or public) currently being filled.
    # NOTE(review): several loops below reuse ``key`` as their loop variable
    # and thereby rebind it; see the note at ``if key:``.
    key = None
    example_vector = None
    examples = []
    vectors = []
    for line in vector_data:
        # A new example header finalizes the example collected so far and
        # starts a fresh one.
        if (
            line.startswith("# PSS Example") or
            line.startswith("# OAEP Example") or
            line.startswith("# PKCS#1 v1.5")
        ):
            if example_vector:
                # Join the collected fragments of every field into a single
                # hex byte string (this loop rebinds ``key``).
                for key, value in six.iteritems(example_vector):
                    hex_str = "".join(value).replace(" ", "").encode("ascii")
                    example_vector[key] = hex_str
                examples.append(example_vector)

            attr = None
            example_vector = collections.defaultdict(list)

        # Field markers inside an example select where data lines go.
        if line.startswith("# Message"):
            attr = "message"
            continue
        elif line.startswith("# Salt"):
            attr = "salt"
            continue
        elif line.startswith("# Seed"):
            attr = "seed"
            continue
        elif line.startswith("# Signature"):
            attr = "signature"
            continue
        elif line.startswith("# Encryption"):
            attr = "encryption"
            continue
        elif (
            example_vector and
            line.startswith("# =============================================")
        ):
            # A separator closes the current example. There is no
            # ``continue`` here, so the same line is also seen by the key
            # pair handling further down.
            for key, value in six.iteritems(example_vector):
                hex_str = "".join(value).replace(" ", "").encode("ascii")
                example_vector[key] = hex_str
            examples.append(example_vector)
            example_vector = None
            attr = None
        elif example_vector and line.startswith("#"):
            # Any other comment inside a non-empty example is ignored. A
            # freshly created (still empty) defaultdict is falsy, so this
            # branch only applies once some example data has been stored.
            continue
        else:
            # Data line belonging to the current example field.
            if attr is not None and example_vector is not None:
                example_vector[attr].append(line.strip())
                continue

        # A key pair header finalizes the key pair collected so far and
        # starts a fresh pair of key dicts.
        if (
            line.startswith("# Example") or
            line.startswith("# =============================================")
        ):
            # NOTE(review): at this point ``key`` may be a key dict or, if an
            # example loop above ran since, a field name string; either way
            # a truthy value triggers finalization - confirm this is the
            # intended condition.
            if key:
                assert private_key_vector
                assert public_key_vector

                # Convert the collected hex fragments to integers (these
                # loops rebind ``key`` as well; it is reset below).
                for key, value in six.iteritems(public_key_vector):
                    hex_str = "".join(value).replace(" ", "")
                    public_key_vector[key] = int(hex_str, 16)

                for key, value in six.iteritems(private_key_vector):
                    hex_str = "".join(value).replace(" ", "")
                    private_key_vector[key] = int(hex_str, 16)

                # Attach the examples gathered for this key pair and start
                # a new list for the next one.
                private_key_vector["examples"] = examples
                examples = []

                assert (
                    private_key_vector['public_exponent'] ==
                    public_key_vector['public_exponent']
                )

                assert (
                    private_key_vector['modulus'] ==
                    public_key_vector['modulus']
                )

                vectors.append(
                    (private_key_vector, public_key_vector)
                )

            public_key_vector = collections.defaultdict(list)
            private_key_vector = collections.defaultdict(list)
            key = None
            attr = None

        # Nothing below applies until the first key pair header was seen.
        if private_key_vector is None or public_key_vector is None:
            # Random garbage to defeat CPython's peephole optimizer so that
            # coverage records correctly: https://bugs.python.org/issue2506
            1 + 1
            continue

        # Markers selecting the key dict and the field within it.
        if line.startswith("# Private key"):
            key = private_key_vector
        elif line.startswith("# Public key"):
            key = public_key_vector
        elif line.startswith("# Modulus:"):
            attr = "modulus"
        elif line.startswith("# Public exponent:"):
            attr = "public_exponent"
        elif line.startswith("# Exponent:"):
            # The plain "Exponent" label means the public exponent in a
            # public key section and the private exponent otherwise.
            if key is public_key_vector:
                attr = "public_exponent"
            else:
                assert key is private_key_vector
                attr = "private_exponent"
        elif line.startswith("# Prime 1:"):
            attr = "p"
        elif line.startswith("# Prime 2:"):
            attr = "q"
        elif line.startswith("# Prime exponent 1:"):
            attr = "dmp1"
        elif line.startswith("# Prime exponent 2:"):
            attr = "dmq1"
        elif line.startswith("# Coefficient:"):
            attr = "iqmp"
        elif line.startswith("#"):
            # Any other comment ends the current field.
            attr = None
        else:
            # Data line belonging to the current key field.
            if key is not None and attr is not None:
                key[attr].append(line.strip())
    return vectors
286
287
def load_rsa_nist_vectors(vector_data):
    """
    Parse NIST RSA signature vector data.

    ``n``, ``e``, ``p`` and ``q`` lines are remembered as integers. Each
    ``SHAAlg`` line starts a new test case: if no ``p`` has been seen the
    case holds the modulus, public exponent, salt length, algorithm and a
    ``fail`` flag; otherwise it holds the modulus, ``p``, ``q`` and the
    algorithm (plus the salt length when one was announced) and subsequent
    ``e`` / ``d`` lines fill in the exponents. Remaining ``Name = Value``
    lines are stored as ASCII bytes under the lower-cased name.

    :param vector_data: An iterable of text lines.
    :return: A list of dicts, one per ``SHAAlg`` line.
    """
    test_data = None
    p = None
    salt_length = None
    data = []

    for raw in vector_data:
        entry = raw.strip()

        # Skip blank lines and section headers
        if not entry or entry.startswith("["):
            continue

        if entry.startswith("#"):
            # The salt length is the only comment that carries information
            if entry.startswith("# Salt len:"):
                salt_length = int(entry.split(":")[1].strip())
            continue

        # Everything else is in Key = Value form
        name, value = [part.strip() for part in entry.split("=")]

        if name == "n":
            n = int(value, 16)
        elif name == "e":
            # Before any p was seen the exponent is shared by the test
            # cases that follow, afterwards it belongs to the current one.
            if p is None:
                e = int(value, 16)
            else:
                test_data["public_exponent"] = int(value, 16)
        elif name == "p":
            p = int(value, 16)
        elif name == "q":
            q = int(value, 16)
        elif name == "SHAAlg":
            if p is None:
                test_data = {
                    "modulus": n,
                    "public_exponent": e,
                    "salt_length": salt_length,
                    "algorithm": value,
                    "fail": False
                }
            else:
                test_data = {
                    "modulus": n,
                    "p": p,
                    "q": q,
                    "algorithm": value
                }
                if salt_length is not None:
                    test_data["salt_length"] = salt_length
            data.append(test_data)
        elif name == "d":
            test_data["private_exponent"] = int(value, 16)
        elif name == "Result":
            test_data["fail"] = value.startswith("F")
        else:
            # Any other token is stored verbatim as bytes
            test_data[name.lower()] = value.encode("ascii")

    return data
349
350
def load_fips_dsa_key_pair_vectors(vector_data):
    """
    Loads data out of the FIPS DSA KeyPair vector files.

    A ``P`` line starts a new vector; ``Q``, ``G``, ``X`` and ``Y`` lines
    fill in the most recent one. When an ``X`` line arrives for a vector
    that already has an ``x``, a new vector is started which reuses the
    ``p``, ``q`` and ``g`` of the previous one.

    :param vector_data: An iterable of text lines.
    :return: A list of dicts of integers.
    """
    def hex_value(text):
        # The piece following the first "=" is a hexadecimal integer
        return int(text.split("=")[1], 16)

    vectors = []
    for raw in vector_data:
        entry = raw.strip()

        if not entry or entry.startswith(("#", "[mod")):
            continue

        if entry.startswith("P"):
            vectors.append({'p': hex_value(entry)})
        elif entry.startswith("X"):
            latest = vectors[-1]
            if 'x' not in latest:
                latest['x'] = hex_value(entry)
            else:
                # Another key pair for the same domain parameters
                vectors.append({
                    'p': latest['p'],
                    'q': latest['q'],
                    'g': latest['g'],
                    'x': hex_value(entry)
                })
        elif entry.startswith(("Q", "G", "Y")):
            vectors[-1][entry[0].lower()] = hex_value(entry)

    return vectors
380
381
def load_fips_dsa_sig_vectors(vector_data):
    """
    Loads data out of the FIPS DSA SigVer vector files.

    The digest algorithm is taken from the ``[mod = ...]`` section header.
    A ``P`` line starts a new vector; when a ``Msg`` line arrives for a
    vector that already has a message, a new vector is started which
    reuses ``p``, ``q``, ``g`` and the digest algorithm of the previous one.

    :param vector_data: An iterable of text lines.
    :return: A list of dicts. Messages are raw bytes, ``result`` is the
        text in front of any ``(`` and the other values are integers.
    """
    vectors = []
    sha_regex = re.compile(
        r"\[mod = L=...., N=..., SHA-(?P<sha>1|224|256|384|512)\]"
    )

    for raw in vector_data:
        entry = raw.strip()

        if not entry or entry.startswith("#"):
            continue

        header = sha_regex.match(entry)
        if header:
            digest_algorithm = "SHA-{0}".format(header.group("sha"))

        # Section headers carry nothing beyond the digest algorithm
        if entry.startswith("[mod"):
            continue

        name, value = [part.strip() for part in entry.split("=")]

        if name == "P":
            vectors.append({'p': int(value, 16),
                            'digest_algorithm': digest_algorithm})
        elif name == "Msg":
            latest = vectors[-1]
            hexmsg = value.strip().encode("ascii")
            if 'msg' not in latest:
                latest['msg'] = binascii.unhexlify(hexmsg)
            else:
                # Another signature for the same domain parameters
                vectors.append({'p': latest['p'],
                                'q': latest['q'],
                                'g': latest['g'],
                                'digest_algorithm':
                                latest['digest_algorithm'],
                                'msg': binascii.unhexlify(hexmsg)})
        elif name in ("Q", "G", "X", "Y", "R", "S"):
            vectors[-1][name.lower()] = int(value, 16)
        elif name == "Result":
            vectors[-1]['result'] = value.split("(")[0].strip()

    return vectors
436
437
# Maps the curve labels used in the vector files' section headers (e.g.
# "P-256") to the curve names used by the loaders' output.
# https://tools.ietf.org/html/rfc4492#appendix-A
_ECDSA_CURVE_NAMES = {
    # "P-" labels
    "P-192": "secp192r1",
    "P-224": "secp224r1",
    "P-256": "secp256r1",
    "P-384": "secp384r1",
    "P-521": "secp521r1",

    # "K-" labels
    "K-163": "sect163k1",
    "K-233": "sect233k1",
    "K-256": "secp256k1",
    "K-283": "sect283k1",
    "K-409": "sect409k1",
    "K-571": "sect571k1",

    # "B-" labels
    "B-163": "sect163r2",
    "B-233": "sect233r1",
    "B-283": "sect283r1",
    "B-409": "sect409r1",
    "B-571": "sect571r1",
}
459
460
def load_fips_ecdsa_key_pair_vectors(vector_data):
    """
    Loads data out of the FIPS ECDSA KeyPair vector files.

    A line whose text between its first and last character is a known
    curve label selects the curve. Each ``d = `` line starts a new key which
    is then completed by the ``Qx = `` and ``Qy = `` lines that follow.

    :param vector_data: An iterable of text lines.
    :return: A list of dicts with ``curve``, ``d``, ``x`` and ``y``.
    """
    vectors = []
    key_data = None
    for raw in vector_data:
        entry = raw.strip()

        if not entry or entry.startswith("#"):
            continue

        label = entry[1:-1]
        if label in _ECDSA_CURVE_NAMES:
            curve_name = _ECDSA_CURVE_NAMES[label]
            continue

        if entry.startswith("d = "):
            # A new private value closes the key collected before it
            if key_data is not None:
                vectors.append(key_data)

            key_data = {
                "curve": curve_name,
                "d": int(entry.split("=")[1], 16)
            }
            continue

        if key_data is None:
            continue

        if entry.startswith("Qx = "):
            key_data["x"] = int(entry.split("=")[1], 16)
        elif entry.startswith("Qy = "):
            key_data["y"] = int(entry.split("=")[1], 16)

    # The last key is not followed by another "d = " line
    assert key_data is not None
    vectors.append(key_data)

    return vectors
495
496
def load_fips_ecdsa_signing_vectors(vector_data):
    """
    Loads data out of the FIPS ECDSA SigGen vector files.

    A ``[<curve>,SHA-<n>]`` header selects the curve and digest. Each
    ``Msg = `` line starts a new vector which is then completed by the
    ``Qx``, ``Qy``, ``R``, ``S``, ``d`` and ``Result`` lines that follow.

    :param vector_data: An iterable of text lines.
    :return: A list of dicts. ``message`` is raw bytes, ``fail`` a bool and
        the remaining parsed values are integers.
    """
    vectors = []

    curve_rx = re.compile(
        r"\[(?P<curve>[PKB]-[0-9]{3}),SHA-(?P<sha>1|224|256|384|512)\]"
    )

    # Line prefix -> key under which the hex integer is stored
    int_fields = (
        ("Qx = ", "x"),
        ("Qy = ", "y"),
        ("R = ", "r"),
        ("S = ", "s"),
        ("d = ", "d"),
    )

    data = None
    for raw in vector_data:
        entry = raw.strip()

        header = curve_rx.match(entry)
        if header:
            curve_name = _ECDSA_CURVE_NAMES[header.group("curve")]
            digest_name = "SHA-{0}".format(header.group("sha"))

        elif entry.startswith("Msg = "):
            # A new message closes the vector collected before it
            if data is not None:
                vectors.append(data)

            hexmsg = entry.split("=")[1].strip().encode("ascii")

            data = {
                "curve": curve_name,
                "digest_algorithm": digest_name,
                "message": binascii.unhexlify(hexmsg)
            }

        elif data is None:
            continue

        elif entry.startswith("Result = "):
            data["fail"] = entry.split("=")[1].strip()[0] == "F"

        else:
            for prefix, field in int_fields:
                if entry.startswith(prefix):
                    data[field] = int(entry.split("=")[1], 16)
                    break

    # The last vector is not followed by another "Msg = " line
    assert data is not None
    vectors.append(data)
    return vectors
545
546
def load_kasvs_dh_vectors(vector_data):
    """
    Loads data out of the KASVS key exchange vector data

    Values are accumulated until a ``Result = `` line completes the vector.
    A failing result with code 5 or 10 sets ``fail_z``, any other failing
    result sets ``fail_agree``. The next vector starts out with the ``p``,
    ``q`` and ``g`` of the one just completed.

    :param vector_data: An iterable of text lines.
    :return: A list of dicts. ``z`` is raw bytes, the flags are bools and
        the remaining parsed values are integers.
    """

    result_rx = re.compile(r"([FP]) \(([0-9]+) -")

    # Line prefix -> key under which the hex integer is stored
    int_fields = (
        ("P = ", "p"),
        ("Q = ", "q"),
        ("G = ", "g"),
        ("XstatCAVS = ", "x1"),
        ("YstatCAVS = ", "y1"),
        ("XstatIUT = ", "x2"),
        ("YstatIUT = ", "y2"),
    )

    vectors = []
    data = {
        "fail_z": False,
        "fail_agree": False
    }

    for raw in vector_data:
        entry = raw.strip()

        if not entry or entry.startswith("#"):
            continue

        if entry.startswith("Z = "):
            z_hex = entry.split("=")[1].strip().encode("ascii")
            data["z"] = binascii.unhexlify(z_hex)
        elif entry.startswith("Result = "):
            match = result_rx.match(entry.split("=")[1].strip())

            if match.group(1) == "F":
                if int(match.group(2)) in (5, 10):
                    flag = "fail_z"
                else:
                    flag = "fail_agree"
                data[flag] = True

            vectors.append(data)

            # Domain parameters carry over to the next vector
            data = {
                "p": data["p"],
                "q": data["q"],
                "g": data["g"],
                "fail_z": False,
                "fail_agree": False
            }
        else:
            for prefix, field in int_fields:
                if entry.startswith(prefix):
                    data[field] = int(entry.split("=")[1], 16)
                    break

    return vectors
604
605
def load_kasvs_ecdh_vectors(vector_data):
    """
    Loads data out of the KASVS key exchange vector data

    ``vector_data`` is iterated three times in sequence: first to find the
    ``Parameter set(s) supported:`` header comment, then to learn which
    curve belongs to which parameter set, and finally to read the vectors
    themselves. With a one-shot iterator (such as an open file) each pass
    resumes where the previous one stopped; with a re-iterable sequence
    each pass starts over from the beginning.

    :param vector_data: An iterable of text lines.
    :return: A list of dicts, one per ``Result = `` line.
    """

    curve_name_map = {
        "P-192": "secp192r1",
        "P-224": "secp224r1",
        "P-256": "secp256r1",
        "P-384": "secp384r1",
        "P-521": "secp521r1",
    }

    result_rx = re.compile(r"([FP]) \(([0-9]+) -")

    # Line prefix -> (party, key) for the per-party hex integers
    party_fields = (
        ("dsCAVS = ", "CAVS", "d"),
        ("QsCAVSx = ", "CAVS", "x"),
        ("QsCAVSy = ", "CAVS", "y"),
        ("dsIUT = ", "IUT", "d"),
        ("QsIUTx = ", "IUT", "x"),
        ("QsIUTy = ", "IUT", "y"),
    )
    # Line prefix -> key for the remaining hex integers
    shared_fields = (
        ("OI = ", "OI"),
        ("Z = ", "Z"),
        ("DKM = ", "DKM"),
    )

    tags = []
    sets = {}
    vectors = []

    # Pass 1: the header comment lists the parameter set names
    for raw in vector_data:
        entry = raw.strip()

        if not entry.startswith("#"):
            continue

        pieces = entry.split("Parameter set(s) supported:")
        if len(pieces) != 2:
            continue

        for label in pieces[1].strip().split():
            tags.append("[%s]" % label)
        break

    # Pass 2: each parameter set tag is followed by its selected curve
    tag = None
    curve = None
    for raw in vector_data:
        entry = raw.strip()

        if not entry or entry.startswith("#"):
            continue

        if entry in tags:
            tag = entry
            curve = None
        elif entry.startswith("[Curve selected:"):
            curve = curve_name_map[entry.split(':')[1].strip()[:-1]]

        if tag is not None and curve is not None:
            sets[tag.strip("[]")] = curve
            tag = None
        if len(tags) == len(sets):
            break

    # Pass 3: the vectors themselves
    data = {
        "CAVS": {},
        "IUT": {},
    }
    tag = None
    for raw in vector_data:
        entry = raw.strip()

        if not entry or entry.startswith("#"):
            continue

        if entry.startswith("["):
            # Section header: its first word names the parameter set
            tag = entry.split()[0][1:]
            continue

        if entry.startswith("COUNT = "):
            data["COUNT"] = int(entry.split("=")[1])
            continue

        if entry.startswith("Result = "):
            match = result_rx.match(entry.split("=")[1].strip())

            data["fail"] = match.group(1) == "F"
            data["errno"] = int(match.group(2))

            data["curve"] = sets[tag]

            vectors.append(data)

            data = {
                "CAVS": {},
                "IUT": {},
            }
            continue

        for prefix, party, field in party_fields:
            if entry.startswith(prefix):
                data[party][field] = int(entry.split("=")[1], 16)
                break
        else:
            for prefix, field in shared_fields:
                if entry.startswith(prefix):
                    data[field] = int(entry.split("=")[1], 16)
                    break

    return vectors
712
713
def load_x963_vectors(vector_data):
    """
    Loads data out of the X9.63 vector data

    A ``[SHA...]`` header selects the hash and resets the three section
    lengths, which are then read from their own bracketed lines. Each
    ``COUNT`` line begins a vector and the ``key_data`` line completes it.
    The hex text of ``Z``, ``SharedInfo`` and ``key_data`` is asserted to
    match the announced bit lengths; ``SharedInfo`` is only stored when its
    announced length is non-zero.

    :param vector_data: An iterable of text lines.
    :return: A list of dicts, one per ``key_data`` line.
    """

    def header_int(text):
        # "[name = 123]" -> 123
        return int(text[1:-1].split("=")[1].strip())

    def hex_digits(bits):
        # Number of hex characters needed to hold ``bits`` bits
        return math.ceil(bits / 8) * 2

    vectors = []

    # Sets Metadata
    hashname = None
    vector = {}
    for raw in vector_data:
        entry = raw.strip()

        if entry.startswith("[SHA"):
            hashname = entry[1:-1]
            shared_secret_len = 0
            shared_info_len = 0
            key_data_len = 0
        elif entry.startswith("[shared secret length"):
            shared_secret_len = header_int(entry)
        elif entry.startswith("[SharedInfo length"):
            shared_info_len = header_int(entry)
        elif entry.startswith("[key data length"):
            key_data_len = header_int(entry)
        elif entry.startswith("COUNT"):
            count = int(entry.split("=")[1].strip())
            vector.update({
                "hash": hashname,
                "count": count,
                "shared_secret_length": shared_secret_len,
                "sharedinfo_length": shared_info_len,
                "key_data_length": key_data_len,
            })
        elif entry.startswith("Z"):
            vector["Z"] = entry.split("=")[1].strip()
            assert hex_digits(shared_secret_len) == len(vector["Z"])
        elif entry.startswith("SharedInfo"):
            if shared_info_len != 0:
                vector["sharedinfo"] = entry.split("=")[1].strip()
                silen = len(vector["sharedinfo"])
                assert hex_digits(shared_info_len) == silen
        elif entry.startswith("key_data"):
            vector["key_data"] = entry.split("=")[1].strip()
            assert hex_digits(key_data_len) == len(vector["key_data"])
            # key_data is the last line of a vector
            vectors.append(vector)
            vector = {}

    return vectors
760
761
def load_nist_kbkdf_vectors(vector_data):
    """
    Load NIST SP 800-108 KDF Vectors

    Bracketed ``[name=value]`` headers are accumulated as tags: a value
    ending in ``_BITS`` is stored as the integer in front of the underscore,
    any other value is lower-cased. Each ``COUNT=`` line starts a new test
    case seeded with the tags seen so far. Lines starting with ``L`` are
    stored as integers, everything else as ASCII bytes.

    :param vector_data: An iterable of text lines.
    :return: A list of dicts, one per ``COUNT=`` line.
    """
    vectors = []
    test_data = None
    tag = {}

    for raw in vector_data:
        entry = raw.strip()

        if not entry or entry.startswith("#"):
            continue

        if entry.startswith("[") and entry.endswith("]"):
            name, value = [part.strip() for part in entry[1:-1].split("=")]
            if value.endswith('_BITS'):
                tag[name.lower()] = int(value.split('_')[0])
            else:
                tag[name.lower()] = value.lower()
        elif entry.startswith("COUNT="):
            # Snapshot of the tags that apply to this test case
            test_data = dict(tag)
            vectors.append(test_data)
        else:
            name, value = [part.strip() for part in entry.split("=")]
            if entry.startswith("L"):
                test_data[name.lower()] = int(value)
            else:
                test_data[name.lower()] = value.encode("ascii")

    return vectors
797
798
def load_ed25519_vectors(vector_data):
    """
    Parse colon separated Ed25519 vector lines.

    Every line must consist of exactly five ``:`` separated fields; the
    fifth is ignored. The first field holds the secret key followed by the
    public key and the fourth the signature followed by the message, so
    only their leading 64 and 128 characters respectively are kept.

    :param vector_data: An iterable of text lines.
    :return: A list of dicts of text values.
    """
    records = []
    for raw in vector_data:
        key_pair, public_key, message, signed_message, _ = raw.split(':')
        records.append({
            "secret_key": key_pair[:64],
            "public_key": public_key,
            "message": message,
            "signature": signed_message[:128]
        })
    return records
814
815
def load_nist_ccm_vectors(vector_data):
    """
    Parse NIST CCM vector data.

    ``Alen``/``Plen``/``Nlen``/``Tlen`` lines set global integer values and
    ``[...]`` headers start a section with its own integer values. ``Key``
    and ``Nonce`` lines directly following a section header belong to the
    section. Each ``Count`` line starts a test case seeded with the global
    and section values; the remaining lines are stored in it as ASCII bytes,
    except ``Result`` which becomes the boolean ``fail``.

    :param vector_data: An iterable of text lines.
    :return: A list of dicts, one per ``Count`` line.
    """
    test_data = None
    section_data = None
    global_data = {}
    new_section = False
    data = []

    for raw in vector_data:
        entry = raw.strip()

        # Skip blank lines and comments
        if not entry or entry.startswith("#"):
            continue

        # Some of the CCM vectors have global values for this. They are always
        # at the top before the first section header (see: VADT, VNT, VPT)
        if entry.startswith(("Alen", "Plen", "Nlen", "Tlen")):
            name, value = [part.strip() for part in entry.split("=")]
            global_data[name.lower()] = int(value)
            continue

        # Section headers hold comma separated length values
        if entry.startswith("["):
            new_section = True
            section_data = {}
            for item in entry[1:-1].split(","):
                name, value = [part.strip() for part in item.split("=")]
                section_data[name.lower()] = int(value)
            continue

        name, value = [part.strip() for part in entry.split("=")]
        field = name.lower()

        if new_section and field in ("key", "nonce"):
            section_data[field] = value.encode("ascii")
            continue

        new_section = False

        # An example may or may not have a payload of its own. If it has
        # none, the previous example's payload applies. Writing the payload
        # into section_data achieves this: every new example starts out
        # with the section data and overwrites it when it brings its own
        # payload.
        if field == "payload":
            section_data[field] = value.encode("ascii")

        # Result tells us whether the test should pass or fail. It is only
        # present in the DVPT CCM tests
        if field == "result":
            test_data["fail"] = value.lower() != "pass"
            continue

        if field == "count":
            # Count indicates a new block of data
            test_data = dict(global_data)
            test_data.update(section_data)
            data.append(test_data)
        else:
            # All other tokens are stored by name
            test_data[field] = value.encode("ascii")

    return data
887
888
class WycheproofTest(object):
    """
    A single Wycheproof test case together with the group it belongs to.

    ``testcase`` is read through its ``"tcId"``, ``"result"`` and
    ``"flags"`` entries.
    """

    def __init__(self, testgroup, testcase):
        self.testgroup, self.testcase = testgroup, testcase

    def __repr__(self):
        template = "<WycheproofTest({!r}, {!r}, tcId={})>"
        return template.format(
            self.testgroup, self.testcase, self.testcase["tcId"],
        )

    def _has_result(self, expected):
        # Compare the test case's "result" entry with ``expected``
        return self.testcase["result"] == expected

    @property
    def valid(self):
        return self._has_result("valid")

    @property
    def acceptable(self):
        return self._has_result("acceptable")

    @property
    def invalid(self):
        return self._has_result("invalid")

    def has_flag(self, flag):
        flags = self.testcase["flags"]
        return flag in flags
913
914
def skip_if_wycheproof_none(wycheproof):
    """
    Call ``pytest.skip`` when ``wycheproof`` is ``None``; otherwise do
    nothing.
    """
    # Kept as a function of its own so that both branches are easy to test
    if wycheproof is not None:
        return
    pytest.skip("--wycheproof-root not provided")
920
921
def load_wycheproof_tests(wycheproof, test_file):
    """
    Yield a ``WycheproofTest`` for every test case in a Wycheproof file.

    The JSON document at ``<wycheproof>/testvectors/<test_file>`` is loaded
    and its ``"testGroups"`` are walked in order. The ``"tests"`` list is
    removed from each group dict before the group is handed, together with
    each of its test cases, to ``WycheproofTest``.

    :param wycheproof: Directory containing the ``testvectors`` directory.
    :param test_file: Name of the JSON file inside ``testvectors``.
    """
    vector_path = os.path.join(wycheproof, "testvectors", test_file)
    with open(vector_path) as handle:
        document = json.load(handle)
        for testgroup in document["testGroups"]:
            testcases = testgroup.pop("tests")
            for testcase in testcases:
                yield WycheproofTest(testgroup, testcase)
930