1 /* Copyright (c) 2016, Google Inc.
2 *
3 * Permission to use, copy, modify, and/or distribute this software for any
4 * purpose with or without fee is hereby granted, provided that the above
5 * copyright notice and this permission notice appear in all copies.
6 *
7 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
14
15 #include <string>
16
17 #include <stdint.h>
18 #include <stdio.h>
19 #include <string.h>
20
21 #include <openssl/curve25519.h>
22
23 #include "../internal.h"
24
25
26 struct SPAKE2Run {
RunSPAKE2Run27 bool Run() {
28 bssl::UniquePtr<SPAKE2_CTX> alice(SPAKE2_CTX_new(
29 spake2_role_alice,
30 reinterpret_cast<const uint8_t *>(alice_names.first.data()),
31 alice_names.first.size(),
32 reinterpret_cast<const uint8_t *>(alice_names.second.data()),
33 alice_names.second.size()));
34 bssl::UniquePtr<SPAKE2_CTX> bob(SPAKE2_CTX_new(
35 spake2_role_bob,
36 reinterpret_cast<const uint8_t *>(bob_names.first.data()),
37 bob_names.first.size(),
38 reinterpret_cast<const uint8_t *>(bob_names.second.data()),
39 bob_names.second.size()));
40
41 if (!alice || !bob) {
42 return false;
43 }
44
45 uint8_t alice_msg[SPAKE2_MAX_MSG_SIZE];
46 uint8_t bob_msg[SPAKE2_MAX_MSG_SIZE];
47 size_t alice_msg_len, bob_msg_len;
48
49 if (!SPAKE2_generate_msg(
50 alice.get(), alice_msg, &alice_msg_len, sizeof(alice_msg),
51 reinterpret_cast<const uint8_t *>(alice_password.data()),
52 alice_password.size()) ||
53 !SPAKE2_generate_msg(
54 bob.get(), bob_msg, &bob_msg_len, sizeof(bob_msg),
55 reinterpret_cast<const uint8_t *>(bob_password.data()),
56 bob_password.size())) {
57 return false;
58 }
59
60 if (alice_corrupt_msg_bit >= 0 &&
61 static_cast<size_t>(alice_corrupt_msg_bit) < 8 * alice_msg_len) {
62 alice_msg[alice_corrupt_msg_bit/8] ^= 1 << (alice_corrupt_msg_bit & 7);
63 }
64
65 uint8_t alice_key[64], bob_key[64];
66 size_t alice_key_len, bob_key_len;
67
68 if (!SPAKE2_process_msg(alice.get(), alice_key, &alice_key_len,
69 sizeof(alice_key), bob_msg, bob_msg_len) ||
70 !SPAKE2_process_msg(bob.get(), bob_key, &bob_key_len, sizeof(bob_key),
71 alice_msg, alice_msg_len)) {
72 return false;
73 }
74
75 key_matches_ = (alice_key_len == bob_key_len &&
76 OPENSSL_memcmp(alice_key, bob_key, alice_key_len) == 0);
77
78 return true;
79 }
80
key_matchesSPAKE2Run81 bool key_matches() const {
82 return key_matches_;
83 }
84
85 std::string alice_password = "password";
86 std::string bob_password = "password";
87 std::pair<std::string, std::string> alice_names = {"alice", "bob"};
88 std::pair<std::string, std::string> bob_names = {"bob", "alice"};
89 int alice_corrupt_msg_bit = -1;
90
91 private:
92 bool key_matches_ = false;
93 };
94
TestSPAKE2()95 static bool TestSPAKE2() {
96 for (unsigned i = 0; i < 20; i++) {
97 SPAKE2Run spake2;
98 if (!spake2.Run()) {
99 fprintf(stderr, "TestSPAKE2: SPAKE2 failed.\n");
100 return false;
101 }
102
103 if (!spake2.key_matches()) {
104 fprintf(stderr, "Key didn't match for equal passwords.\n");
105 return false;
106 }
107 }
108
109 return true;
110 }
111
TestWrongPassword()112 static bool TestWrongPassword() {
113 SPAKE2Run spake2;
114 spake2.bob_password = "wrong password";
115 if (!spake2.Run()) {
116 fprintf(stderr, "TestSPAKE2: SPAKE2 failed.\n");
117 return false;
118 }
119
120 if (spake2.key_matches()) {
121 fprintf(stderr, "Key matched for unequal passwords.\n");
122 return false;
123 }
124
125 return true;
126 }
127
TestWrongNames()128 static bool TestWrongNames() {
129 SPAKE2Run spake2;
130 spake2.alice_names.second = "charlie";
131 spake2.bob_names.second = "charlie";
132 if (!spake2.Run()) {
133 fprintf(stderr, "TestSPAKE2: SPAKE2 failed.\n");
134 return false;
135 }
136
137 if (spake2.key_matches()) {
138 fprintf(stderr, "Key matched for unequal names.\n");
139 return false;
140 }
141
142 return true;
143 }
144
TestCorruptMessages()145 static bool TestCorruptMessages() {
146 for (int i = 0; i < 8 * SPAKE2_MAX_MSG_SIZE; i++) {
147 SPAKE2Run spake2;
148 spake2.alice_corrupt_msg_bit = i;
149 if (spake2.Run() && spake2.key_matches()) {
150 fprintf(stderr, "Passed after corrupting Alice's message, bit %d\n", i);
151 return false;
152 }
153 }
154
155 return true;
156 }
157
158 /* TODO(agl): add tests with fixed vectors once SPAKE2 is nailed down. */
159
main(int argc,char ** argv)160 int main(int argc, char **argv) {
161 if (!TestSPAKE2() ||
162 !TestWrongPassword() ||
163 !TestWrongNames() ||
164 !TestCorruptMessages()) {
165 return 1;
166 }
167
168 printf("PASS\n");
169 return 0;
170 }
171