1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
4
5from __future__ import absolute_import, division, print_function
6
7import sys
8
9from cryptography import utils
10from cryptography.exceptions import (
11    AlreadyFinalized, InvalidKey, UnsupportedAlgorithm, _Reasons
12)
13from cryptography.hazmat.backends.interfaces import ScryptBackend
14from cryptography.hazmat.primitives import constant_time
15from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
16
17
# Upper bound on memory that the scrypt test suite consults: tests needing
# more than _MEM_LIMIT are skipped. Half of sys.maxsize (rounded down).
_MEM_LIMIT = sys.maxsize >> 1
21
22
@utils.register_interface(KeyDerivationFunction)
class Scrypt(object):
    """Single-use scrypt key derivation backed by a ``ScryptBackend``.

    The actual derivation is delegated to ``backend.derive_scrypt``. An
    instance may be used for exactly one ``derive`` or ``verify`` call;
    any later call raises ``AlreadyFinalized``.
    """

    def __init__(self, salt, length, n, r, p, backend):
        """Validate and store the scrypt parameters.

        :param salt: salt value; must pass ``utils._check_bytes``.
        :param length: requested output length, passed through unchanged to
            the backend (not validated here).
        :param n: cost parameter; must be greater than 1 and a power of 2.
        :param r: must be greater than or equal to 1.
        :param p: must be greater than or equal to 1.
        :param backend: object implementing ``ScryptBackend``.
        :raises UnsupportedAlgorithm: if ``backend`` is not a ScryptBackend.
        :raises ValueError: if ``n``, ``r`` or ``p`` is out of range.
        """
        if not isinstance(backend, ScryptBackend):
            raise UnsupportedAlgorithm(
                "Backend object does not implement ScryptBackend.",
                _Reasons.BACKEND_MISSING_INTERFACE
            )

        # NOTE(review): presumably raises TypeError for non-bytes salt;
        # confirm against cryptography.utils._check_bytes.
        utils._check_bytes("salt", salt)

        # ``n & (n - 1)`` is zero exactly when n has a single bit set, i.e.
        # when n is a power of two.
        if n < 2 or (n & (n - 1)) != 0:
            raise ValueError("n must be greater than 1 and be a power of 2.")

        if r < 1:
            raise ValueError("r must be greater than or equal to 1.")

        if p < 1:
            raise ValueError("p must be greater than or equal to 1.")

        # Only store state once every parameter has been accepted.
        self._used = False
        self._salt = salt
        self._length = length
        self._n = n
        self._r = r
        self._p = p
        self._backend = backend

    def derive(self, key_material):
        """Derive and return a key from ``key_material``.

        :param key_material: input keying material; must pass
            ``utils._check_byteslike``.
        :returns: whatever ``backend.derive_scrypt`` returns for the stored
            salt, length, n, r and p.
        :raises AlreadyFinalized: if this instance was already used.
        """
        if self._used:
            raise AlreadyFinalized("Scrypt instances can only be used once.")

        # Validate the argument *before* consuming the instance, so a wrong
        # argument type does not permanently burn this single-use object.
        utils._check_byteslike("key_material", key_material)
        self._used = True

        return self._backend.derive_scrypt(
            key_material, self._salt, self._length, self._n, self._r, self._p
        )

    def verify(self, key_material, expected_key):
        """Check that ``key_material`` derives to ``expected_key``.

        Consumes the instance exactly like ``derive``. The comparison uses
        ``constant_time.bytes_eq``.

        :raises InvalidKey: if the derived key differs from ``expected_key``.
        :raises AlreadyFinalized: if this instance was already used.
        """
        derived_key = self.derive(key_material)
        if not constant_time.bytes_eq(derived_key, expected_key):
            raise InvalidKey("Keys do not match.")
64