1 /* Copyright (c) 2016, Google Inc.
2  *
3  * Permission to use, copy, modify, and/or distribute this software for any
4  * purpose with or without fee is hereby granted, provided that the above
5  * copyright notice and this permission notice appear in all copies.
6  *
7  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10  * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12  * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13  * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
14 
15 #include <string>
16 
17 #include <stdint.h>
18 #include <stdio.h>
19 #include <string.h>
20 
21 #include <openssl/curve25519.h>
22 
23 #include "../internal.h"
24 
25 
26 struct SPAKE2Run {
RunSPAKE2Run27   bool Run() {
28     bssl::UniquePtr<SPAKE2_CTX> alice(SPAKE2_CTX_new(
29         spake2_role_alice,
30         reinterpret_cast<const uint8_t *>(alice_names.first.data()),
31         alice_names.first.size(),
32         reinterpret_cast<const uint8_t *>(alice_names.second.data()),
33         alice_names.second.size()));
34     bssl::UniquePtr<SPAKE2_CTX> bob(SPAKE2_CTX_new(
35         spake2_role_bob,
36         reinterpret_cast<const uint8_t *>(bob_names.first.data()),
37         bob_names.first.size(),
38         reinterpret_cast<const uint8_t *>(bob_names.second.data()),
39         bob_names.second.size()));
40 
41     if (!alice || !bob) {
42       return false;
43     }
44 
45     uint8_t alice_msg[SPAKE2_MAX_MSG_SIZE];
46     uint8_t bob_msg[SPAKE2_MAX_MSG_SIZE];
47     size_t alice_msg_len, bob_msg_len;
48 
49     if (!SPAKE2_generate_msg(
50             alice.get(), alice_msg, &alice_msg_len, sizeof(alice_msg),
51             reinterpret_cast<const uint8_t *>(alice_password.data()),
52             alice_password.size()) ||
53         !SPAKE2_generate_msg(
54             bob.get(), bob_msg, &bob_msg_len, sizeof(bob_msg),
55             reinterpret_cast<const uint8_t *>(bob_password.data()),
56             bob_password.size())) {
57       return false;
58     }
59 
60     if (alice_corrupt_msg_bit >= 0 &&
61         static_cast<size_t>(alice_corrupt_msg_bit) < 8 * alice_msg_len) {
62       alice_msg[alice_corrupt_msg_bit/8] ^= 1 << (alice_corrupt_msg_bit & 7);
63     }
64 
65     uint8_t alice_key[64], bob_key[64];
66     size_t alice_key_len, bob_key_len;
67 
68     if (!SPAKE2_process_msg(alice.get(), alice_key, &alice_key_len,
69                             sizeof(alice_key), bob_msg, bob_msg_len) ||
70         !SPAKE2_process_msg(bob.get(), bob_key, &bob_key_len, sizeof(bob_key),
71                             alice_msg, alice_msg_len)) {
72       return false;
73     }
74 
75     key_matches_ = (alice_key_len == bob_key_len &&
76                     OPENSSL_memcmp(alice_key, bob_key, alice_key_len) == 0);
77 
78     return true;
79   }
80 
key_matchesSPAKE2Run81   bool key_matches() const {
82     return key_matches_;
83   }
84 
85   std::string alice_password = "password";
86   std::string bob_password = "password";
87   std::pair<std::string, std::string> alice_names = {"alice", "bob"};
88   std::pair<std::string, std::string> bob_names = {"bob", "alice"};
89   int alice_corrupt_msg_bit = -1;
90 
91  private:
92   bool key_matches_ = false;
93 };
94 
TestSPAKE2()95 static bool TestSPAKE2() {
96   for (unsigned i = 0; i < 20; i++) {
97     SPAKE2Run spake2;
98     if (!spake2.Run()) {
99       fprintf(stderr, "TestSPAKE2: SPAKE2 failed.\n");
100       return false;
101     }
102 
103     if (!spake2.key_matches()) {
104       fprintf(stderr, "Key didn't match for equal passwords.\n");
105       return false;
106     }
107   }
108 
109   return true;
110 }
111 
TestWrongPassword()112 static bool TestWrongPassword() {
113   SPAKE2Run spake2;
114   spake2.bob_password = "wrong password";
115   if (!spake2.Run()) {
116     fprintf(stderr, "TestSPAKE2: SPAKE2 failed.\n");
117     return false;
118   }
119 
120   if (spake2.key_matches()) {
121     fprintf(stderr, "Key matched for unequal passwords.\n");
122     return false;
123   }
124 
125   return true;
126 }
127 
TestWrongNames()128 static bool TestWrongNames() {
129   SPAKE2Run spake2;
130   spake2.alice_names.second = "charlie";
131   spake2.bob_names.second = "charlie";
132   if (!spake2.Run()) {
133     fprintf(stderr, "TestSPAKE2: SPAKE2 failed.\n");
134     return false;
135   }
136 
137   if (spake2.key_matches()) {
138     fprintf(stderr, "Key matched for unequal names.\n");
139     return false;
140   }
141 
142   return true;
143 }
144 
TestCorruptMessages()145 static bool TestCorruptMessages() {
146   for (int i = 0; i < 8 * SPAKE2_MAX_MSG_SIZE; i++) {
147     SPAKE2Run spake2;
148     spake2.alice_corrupt_msg_bit = i;
149     if (spake2.Run() && spake2.key_matches()) {
150       fprintf(stderr, "Passed after corrupting Alice's message, bit %d\n", i);
151       return false;
152     }
153   }
154 
155   return true;
156 }
157 
/* TODO(agl): add known-answer tests with fixed vectors once the SPAKE2
 * construction is finalized. */
159 
main(int argc,char ** argv)160 int main(int argc, char **argv) {
161   if (!TestSPAKE2() ||
162       !TestWrongPassword() ||
163       !TestWrongNames() ||
164       !TestCorruptMessages()) {
165     return 1;
166   }
167 
168   printf("PASS\n");
169   return 0;
170 }
171