1 // Copyright 2012 Google Inc. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include "polo/pairing/pairingsession.h"

#include <string>

#include <glog/logging.h>
#include "polo/encoding/hexadecimalencoder.h"
#include "polo/util/poloutil.h"
20 
21 namespace polo {
22 namespace pairing {
23 
PairingSession(wire::PoloWireAdapter * wire,PairingContext * context,PoloChallengeResponse * challenge)24 PairingSession::PairingSession(wire::PoloWireAdapter* wire,
25                                PairingContext* context,
26                                PoloChallengeResponse* challenge)
27     : state_(kUninitialized),
28       wire_(wire),
29       context_(context),
30       challenge_(challenge),
31       configuration_(NULL),
32       encoder_(NULL),
33       nonce_(NULL),
34       secret_(NULL) {
35   wire_->set_listener(this);
36 
37   local_options_.set_protocol_role_preference(context->is_server() ?
38       message::OptionsMessage::kDisplayDevice
39       : message::OptionsMessage::kInputDevice);
40 }
41 
~PairingSession()42 PairingSession::~PairingSession() {
43   if (configuration_) {
44     delete configuration_;
45   }
46 
47   if (encoder_) {
48     delete encoder_;
49   }
50 
51   if (nonce_) {
52     delete nonce_;
53   }
54 
55   if (secret_) {
56     delete secret_;
57   }
58 }
59 
AddInputEncoding(const encoding::EncodingOption & encoding)60 void PairingSession::AddInputEncoding(
61     const encoding::EncodingOption& encoding) {
62   if (state_ != kUninitialized) {
63     LOG(ERROR) << "Attempt to add input encoding to active session";
64     return;
65   }
66 
67   if (!IsValidEncodingOption(encoding)) {
68     LOG(ERROR) << "Invalid input encoding: " << encoding.ToString();
69     return;
70   }
71 
72   local_options_.AddInputEncoding(encoding);
73 }
74 
AddOutputEncoding(const encoding::EncodingOption & encoding)75 void PairingSession::AddOutputEncoding(
76     const encoding::EncodingOption& encoding) {
77   if (state_ != kUninitialized) {
78     LOG(ERROR) << "Attempt to add output encoding to active session";
79     return;
80   }
81 
82   if (!IsValidEncodingOption(encoding)) {
83     LOG(ERROR) << "Invalid output encoding: " << encoding.ToString();
84     return;
85   }
86 
87   local_options_.AddOutputEncoding(encoding);
88 }
89 
SetSecret(const Gamma & secret)90 bool PairingSession::SetSecret(const Gamma& secret) {
91   secret_ = new Gamma(secret);
92 
93   if (!IsInputDevice() || state_ != kWaitingForSecret) {
94     LOG(ERROR) << "Invalid state: unexpected secret";
95     return false;
96   }
97 
98   if (!challenge().CheckGamma(secret)) {
99     LOG(ERROR) << "Secret failed local check";
100     return false;
101   }
102 
103   nonce_ = challenge().ExtractNonce(secret);
104   if (!nonce_) {
105     LOG(ERROR) << "Failed to extract nonce";
106     return false;
107   }
108 
109   const Alpha* gen_alpha = challenge().GetAlpha(*nonce_);
110   if (!gen_alpha) {
111     LOG(ERROR) << "Failed to get alpha";
112     return false;
113   }
114 
115   message::SecretMessage secret_message(*gen_alpha);
116   delete gen_alpha;
117 
118   wire_->SendSecretMessage(secret_message);
119 
120   LOG(INFO) << "Waiting for SecretAck...";
121   wire_->GetNextMessage();
122 
123   return true;
124 }
125 
DoPair(PairingListener * listener)126 void PairingSession::DoPair(PairingListener *listener) {
127   listener_ = listener;
128   listener_->OnSessionCreated();
129 
130   if (context_->is_server()) {
131     LOG(INFO) << "Pairing started (SERVER mode)";
132   } else {
133     LOG(INFO) << "Pairing started (CLIENT mode)";
134   }
135   LOG(INFO) << "Local options: " << local_options_.ToString();
136 
137   set_state(kInitializing);
138   DoInitializationPhase();
139 }
140 
DoPairingPhase()141 void PairingSession::DoPairingPhase() {
142   if (IsInputDevice()) {
143     DoInputPairing();
144   } else {
145     DoOutputPairing();
146   }
147 }
148 
DoInputPairing()149 void PairingSession::DoInputPairing() {
150   set_state(kWaitingForSecret);
151   listener_->OnPerformInputDeviceRole();
152 }
153 
DoOutputPairing()154 void PairingSession::DoOutputPairing() {
155   size_t nonce_length = configuration_->encoding().symbol_length() / 2;
156   size_t bytes_needed = nonce_length / encoder_->symbols_per_byte();
157 
158   uint8_t* random = util::PoloUtil::GenerateRandomBytes(bytes_needed);
159   nonce_ = new Nonce(random, random + bytes_needed);
160   delete[] random;
161 
162   const Gamma* gamma = challenge().GetGamma(*nonce_);
163   if (!gamma) {
164     LOG(ERROR) << "Failed to get gamma";
165     wire()->SendErrorMessage(kErrorProtocol);
166     listener()->OnError(kErrorProtocol);
167     return;
168   }
169 
170   listener_->OnPerformOutputDeviceRole(*gamma);
171   delete gamma;
172 
173   set_state(kWaitingForSecret);
174 
175   LOG(INFO) << "Waiting for Secret...";
176   wire_->GetNextMessage();
177 }
178 
set_state(ProtocolState state)179 void PairingSession::set_state(ProtocolState state) {
180   LOG(INFO) << "New state: " << state;
181   state_ = state;
182 }
183 
SetConfiguration(const message::ConfigurationMessage & message)184 bool PairingSession::SetConfiguration(
185     const message::ConfigurationMessage& message) {
186   const encoding::EncodingOption& encoding = message.encoding();
187 
188   if (!IsValidEncodingOption(encoding)) {
189     LOG(ERROR) << "Invalid configuration: " << encoding.ToString();
190     return false;
191   }
192 
193   if (encoder_) {
194     delete encoder_;
195     encoder_ = NULL;
196   }
197 
198   switch (encoding.encoding_type()) {
199     case encoding::EncodingOption::kHexadecimal:
200       encoder_ = new encoding::HexadecimalEncoder();
201       break;
202     default:
203       LOG(ERROR) << "Unsupported encoding type: "
204           << encoding.encoding_type();
205       return false;
206   }
207 
208   if (configuration_) {
209     delete configuration_;
210   }
211   configuration_ = new message::ConfigurationMessage(message.encoding(),
212                                                      message.client_role());
213   return true;
214 }
215 
OnSecretMessage(const message::SecretMessage & message)216 void PairingSession::OnSecretMessage(const message::SecretMessage& message) {
217   if (state() != kWaitingForSecret) {
218     LOG(ERROR) << "Invalid state: unexpected secret message";
219     wire()->SendErrorMessage(kErrorProtocol);
220     listener()->OnError(kErrorProtocol);
221     return;
222   }
223 
224   if (!VerifySecret(message.secret())) {
225     wire()->SendErrorMessage(kErrorInvalidChallengeResponse);
226     listener_->OnError(kErrorInvalidChallengeResponse);
227     return;
228   }
229 
230   const Alpha* alpha = challenge().GetAlpha(*nonce_);
231   if (!alpha) {
232     LOG(ERROR) << "Failed to get alpha";
233     wire()->SendErrorMessage(kErrorProtocol);
234     listener()->OnError(kErrorProtocol);
235     return;
236   }
237 
238   message::SecretAckMessage ack(*alpha);
239   delete alpha;
240 
241   wire_->SendSecretAckMessage(ack);
242 
243   listener_->OnPairingSuccess();
244 }
245 
OnSecretAckMessage(const message::SecretAckMessage & message)246 void PairingSession::OnSecretAckMessage(
247     const message::SecretAckMessage& message) {
248   if (kVerifySecretAck && !VerifySecret(message.secret())) {
249     wire()->SendErrorMessage(kErrorInvalidChallengeResponse);
250     listener_->OnError(kErrorInvalidChallengeResponse);
251     return;
252   }
253 
254   listener_->OnPairingSuccess();
255 }
256 
OnError(pairing::PoloError error)257 void PairingSession::OnError(pairing::PoloError error) {
258   listener_->OnError(error);
259 }
260 
VerifySecret(const Alpha & secret) const261 bool PairingSession::VerifySecret(const Alpha& secret) const {
262   if (!nonce_) {
263     LOG(ERROR) << "Nonce not set";
264     return false;
265   }
266 
267   const Alpha* gen_alpha = challenge().GetAlpha(*nonce_);
268   if (!gen_alpha) {
269     LOG(ERROR) << "Failed to get alpha";
270     return false;
271   }
272 
273   bool valid = (secret == *gen_alpha);
274 
275   if (!valid) {
276     LOG(ERROR) << "Inband secret did not match. Expected ["
277         << util::PoloUtil::BytesToHexString(&(*gen_alpha)[0], gen_alpha->size())
278         << "], got ["
279         << util::PoloUtil::BytesToHexString(&secret[0], secret.size())
280         << "]";
281   }
282 
283   delete gen_alpha;
284   return valid;
285 }
286 
GetLocalRole() const287 message::OptionsMessage::ProtocolRole PairingSession::GetLocalRole() const {
288   if (!configuration_) {
289     return message::OptionsMessage::kUnknown;
290   }
291 
292   if (context_->is_client()) {
293     return configuration_->client_role();
294   } else {
295     return configuration_->client_role() ==
296         message::OptionsMessage::kDisplayDevice ?
297             message::OptionsMessage::kInputDevice
298             : message::OptionsMessage::kDisplayDevice;
299   }
300 }
301 
IsInputDevice() const302 bool PairingSession::IsInputDevice() const {
303   return GetLocalRole() == message::OptionsMessage::kInputDevice;
304 }
305 
IsValidEncodingOption(const encoding::EncodingOption & option) const306 bool PairingSession::IsValidEncodingOption(
307     const encoding::EncodingOption& option) const {
308   // Legal values of GAMMALEN must be an even number of at least 2 bytes.
309   return option.encoding_type() != encoding::EncodingOption::kUnknown
310       && (option.symbol_length() % 2 == 0)
311       && (option.symbol_length() >= 2);
312 }
313 
314 }  // namespace pairing
315 }  // namespace polo
316