1 /* Copyright (c) 2014, Google Inc.
2  *
3  * Permission to use, copy, modify, and/or distribute this software for any
4  * purpose with or without fee is hereby granted, provided that the above
5  * copyright notice and this permission notice appear in all copies.
6  *
7  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10  * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12  * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13  * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
14 
15 #include <openssl/base.h>
16 
17 #include <openssl/err.h>
18 #include <openssl/pem.h>
19 #include <openssl/ssl.h>
20 
21 #include "../crypto/test/scoped_types.h"
22 #include "../ssl/test/scoped_types.h"
23 #include "internal.h"
24 #include "transport_common.h"
25 
26 
27 static const struct argument kArguments[] = {
28     {
29      "-connect", kRequiredArgument,
30      "The hostname and port of the server to connect to, e.g. foo.com:443",
31     },
32     {
33      "-cipher", kOptionalArgument,
34      "An OpenSSL-style cipher suite string that configures the offered ciphers",
35     },
36     {
37      "-max-version", kOptionalArgument,
38      "The maximum acceptable protocol version",
39     },
40     {
41      "-min-version", kOptionalArgument,
42      "The minimum acceptable protocol version",
43     },
44     {
45      "-server-name", kOptionalArgument,
46      "The server name to advertise",
47     },
48     {
49      "-select-next-proto", kOptionalArgument,
50      "An NPN protocol to select if the server supports NPN",
51     },
52     {
53      "-alpn-protos", kOptionalArgument,
54      "A comma-separated list of ALPN protocols to advertise",
55     },
56     {
57      "-fallback-scsv", kBooleanArgument,
58      "Enable FALLBACK_SCSV",
59     },
60     {
61      "-ocsp-stapling", kBooleanArgument,
62      "Advertise support for OCSP stabling",
63     },
64     {
65      "-signed-certificate-timestamps", kBooleanArgument,
66      "Advertise support for signed certificate timestamps",
67     },
68     {
69      "-channel-id-key", kOptionalArgument,
70      "The key to use for signing a channel ID",
71     },
72     {
73      "", kOptionalArgument, "",
74     },
75 };
76 
LoadPrivateKey(const std::string & file)77 static ScopedEVP_PKEY LoadPrivateKey(const std::string &file) {
78   ScopedBIO bio(BIO_new(BIO_s_file()));
79   if (!bio || !BIO_read_filename(bio.get(), file.c_str())) {
80     return nullptr;
81   }
82   ScopedEVP_PKEY pkey(PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr,
83                                               nullptr));
84   return pkey;
85 }
86 
VersionFromString(uint16_t * out_version,const std::string & version)87 static bool VersionFromString(uint16_t *out_version,
88                               const std::string& version) {
89   if (version == "ssl3") {
90     *out_version = SSL3_VERSION;
91     return true;
92   } else if (version == "tls1" || version == "tls1.0") {
93     *out_version = TLS1_VERSION;
94     return true;
95   } else if (version == "tls1.1") {
96     *out_version = TLS1_1_VERSION;
97     return true;
98   } else if (version == "tls1.2") {
99     *out_version = TLS1_2_VERSION;
100     return true;
101   }
102   return false;
103 }
104 
NextProtoSelectCallback(SSL * ssl,uint8_t ** out,uint8_t * outlen,const uint8_t * in,unsigned inlen,void * arg)105 static int NextProtoSelectCallback(SSL* ssl, uint8_t** out, uint8_t* outlen,
106                                    const uint8_t* in, unsigned inlen, void* arg) {
107   *out = reinterpret_cast<uint8_t *>(arg);
108   *outlen = strlen(reinterpret_cast<const char *>(arg));
109   return SSL_TLSEXT_ERR_OK;
110 }
111 
Client(const std::vector<std::string> & args)112 bool Client(const std::vector<std::string> &args) {
113   if (!InitSocketLibrary()) {
114     return false;
115   }
116 
117   std::map<std::string, std::string> args_map;
118 
119   if (!ParseKeyValueArguments(&args_map, args, kArguments)) {
120     PrintUsage(kArguments);
121     return false;
122   }
123 
124   ScopedSSL_CTX ctx(SSL_CTX_new(SSLv23_client_method()));
125 
126   const char *keylog_file = getenv("SSLKEYLOGFILE");
127   if (keylog_file) {
128     BIO *keylog_bio = BIO_new_file(keylog_file, "a");
129     if (!keylog_bio) {
130       ERR_print_errors_cb(PrintErrorCallback, stderr);
131       return false;
132     }
133     SSL_CTX_set_keylog_bio(ctx.get(), keylog_bio);
134   }
135 
136   if (args_map.count("-cipher") != 0 &&
137       !SSL_CTX_set_cipher_list(ctx.get(), args_map["-cipher"].c_str())) {
138     fprintf(stderr, "Failed setting cipher list\n");
139     return false;
140   }
141 
142   if (args_map.count("-max-version") != 0) {
143     uint16_t version;
144     if (!VersionFromString(&version, args_map["-max-version"])) {
145       fprintf(stderr, "Unknown protocol version: '%s'\n",
146               args_map["-max-version"].c_str());
147       return false;
148     }
149     SSL_CTX_set_max_version(ctx.get(), version);
150   }
151 
152   if (args_map.count("-min-version") != 0) {
153     uint16_t version;
154     if (!VersionFromString(&version, args_map["-min-version"])) {
155       fprintf(stderr, "Unknown protocol version: '%s'\n",
156               args_map["-min-version"].c_str());
157       return false;
158     }
159     SSL_CTX_set_min_version(ctx.get(), version);
160   }
161 
162   if (args_map.count("-select-next-proto") != 0) {
163     const std::string &proto = args_map["-select-next-proto"];
164     if (proto.size() > 255) {
165       fprintf(stderr, "Bad NPN protocol: '%s'\n", proto.c_str());
166       return false;
167     }
168     // |SSL_CTX_set_next_proto_select_cb| is not const-correct.
169     SSL_CTX_set_next_proto_select_cb(ctx.get(), NextProtoSelectCallback,
170                                      const_cast<char *>(proto.c_str()));
171   }
172 
173   if (args_map.count("-alpn-protos") != 0) {
174     const std::string &alpn_protos = args_map["-alpn-protos"];
175     std::vector<uint8_t> wire;
176     size_t i = 0;
177     while (i <= alpn_protos.size()) {
178       size_t j = alpn_protos.find(',', i);
179       if (j == std::string::npos) {
180         j = alpn_protos.size();
181       }
182       size_t len = j - i;
183       if (len > 255) {
184         fprintf(stderr, "Invalid ALPN protocols: '%s'\n", alpn_protos.c_str());
185         return false;
186       }
187       wire.push_back(static_cast<uint8_t>(len));
188       wire.resize(wire.size() + len);
189       memcpy(wire.data() + wire.size() - len, alpn_protos.data() + i, len);
190       i = j + 1;
191     }
192     if (SSL_CTX_set_alpn_protos(ctx.get(), wire.data(), wire.size()) != 0) {
193       return false;
194     }
195   }
196 
197   if (args_map.count("-fallback-scsv") != 0) {
198     SSL_CTX_set_mode(ctx.get(), SSL_MODE_SEND_FALLBACK_SCSV);
199   }
200 
201   if (args_map.count("-ocsp-stapling") != 0) {
202     SSL_CTX_enable_ocsp_stapling(ctx.get());
203   }
204 
205   if (args_map.count("-signed-certificate-timestamps") != 0) {
206     SSL_CTX_enable_signed_cert_timestamps(ctx.get());
207   }
208 
209   if (args_map.count("-channel-id-key") != 0) {
210     ScopedEVP_PKEY pkey = LoadPrivateKey(args_map["-channel-id-key"]);
211     if (!pkey || !SSL_CTX_set1_tls_channel_id(ctx.get(), pkey.get())) {
212       return false;
213     }
214     ctx->tlsext_channel_id_enabled_new = 1;
215   }
216 
217   int sock = -1;
218   if (!Connect(&sock, args_map["-connect"])) {
219     return false;
220   }
221 
222   ScopedBIO bio(BIO_new_socket(sock, BIO_CLOSE));
223   ScopedSSL ssl(SSL_new(ctx.get()));
224 
225   if (args_map.count("-server-name") != 0) {
226     SSL_set_tlsext_host_name(ssl.get(), args_map["-server-name"].c_str());
227   }
228 
229   SSL_set_bio(ssl.get(), bio.get(), bio.get());
230   bio.release();
231 
232   int ret = SSL_connect(ssl.get());
233   if (ret != 1) {
234     int ssl_err = SSL_get_error(ssl.get(), ret);
235     fprintf(stderr, "Error while connecting: %d\n", ssl_err);
236     ERR_print_errors_cb(PrintErrorCallback, stderr);
237     return false;
238   }
239 
240   fprintf(stderr, "Connected.\n");
241   PrintConnectionInfo(ssl.get());
242 
243   bool ok = TransferData(ssl.get(), sock);
244 
245   return ok;
246 }
247