1 /*
2  *  Copyright 2003 The WebRTC Project Authors. All rights reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
// Registry configuration wrapper class implementation
12 //
13 // Change made by S. Ganesh - ganesh@google.com:
14 //   Use SHQueryValueEx instead of RegQueryValueEx throughout.
15 //   A call to the SHLWAPI function is essentially a call to the standard
16 //   function but with post-processing:
17 //   * to fix REG_SZ or REG_EXPAND_SZ data that is not properly null-terminated;
18 //   * to expand REG_EXPAND_SZ data.
19 
#include "webrtc/base/win32regkey.h"

#include <shlwapi.h>
#include <string.h>

#include "webrtc/base/common.h"
#include "webrtc/base/logging.h"
#include "webrtc/base/scoped_ptr.h"
27 
28 namespace rtc {
29 
RegKey()30 RegKey::RegKey() {
31   h_key_ = NULL;
32 }
33 
~RegKey()34 RegKey::~RegKey() {
35   Close();
36 }
37 
Create(HKEY parent_key,const wchar_t * key_name)38 HRESULT RegKey::Create(HKEY parent_key, const wchar_t* key_name) {
39   return Create(parent_key,
40                 key_name,
41                 REG_NONE,
42                 REG_OPTION_NON_VOLATILE,
43                 KEY_ALL_ACCESS,
44                 NULL,
45                 NULL);
46 }
47 
Open(HKEY parent_key,const wchar_t * key_name)48 HRESULT RegKey::Open(HKEY parent_key, const wchar_t* key_name) {
49   return Open(parent_key, key_name, KEY_ALL_ACCESS);
50 }
51 
HasValue(const TCHAR * value_name) const52 bool RegKey::HasValue(const TCHAR* value_name) const {
53   return (ERROR_SUCCESS == ::RegQueryValueEx(h_key_, value_name, NULL,
54                                              NULL, NULL, NULL));
55 }
56 
SetValue(const wchar_t * full_key_name,const wchar_t * value_name,DWORD value)57 HRESULT RegKey::SetValue(const wchar_t* full_key_name,
58                          const wchar_t* value_name,
59                          DWORD value) {
60   ASSERT(full_key_name != NULL);
61 
62   return SetValueStaticHelper(full_key_name, value_name, REG_DWORD, &value);
63 }
64 
SetValue(const wchar_t * full_key_name,const wchar_t * value_name,DWORD64 value)65 HRESULT RegKey::SetValue(const wchar_t* full_key_name,
66                          const wchar_t* value_name,
67                          DWORD64 value) {
68   ASSERT(full_key_name != NULL);
69 
70   return SetValueStaticHelper(full_key_name, value_name, REG_QWORD, &value);
71 }
72 
SetValue(const wchar_t * full_key_name,const wchar_t * value_name,float value)73 HRESULT RegKey::SetValue(const wchar_t* full_key_name,
74                          const wchar_t* value_name,
75                          float value) {
76   ASSERT(full_key_name != NULL);
77 
78   return SetValueStaticHelper(full_key_name, value_name,
79                               REG_BINARY, &value, sizeof(value));
80 }
81 
SetValue(const wchar_t * full_key_name,const wchar_t * value_name,double value)82 HRESULT RegKey::SetValue(const wchar_t* full_key_name,
83                          const wchar_t* value_name,
84                          double value) {
85   ASSERT(full_key_name != NULL);
86 
87   return SetValueStaticHelper(full_key_name, value_name,
88                               REG_BINARY, &value, sizeof(value));
89 }
90 
SetValue(const wchar_t * full_key_name,const wchar_t * value_name,const TCHAR * value)91 HRESULT RegKey::SetValue(const wchar_t* full_key_name,
92                          const wchar_t* value_name,
93                          const TCHAR* value) {
94   ASSERT(full_key_name != NULL);
95   ASSERT(value != NULL);
96 
97   return SetValueStaticHelper(full_key_name, value_name,
98                               REG_SZ, const_cast<wchar_t*>(value));
99 }
100 
SetValue(const wchar_t * full_key_name,const wchar_t * value_name,const uint8_t * value,DWORD byte_count)101 HRESULT RegKey::SetValue(const wchar_t* full_key_name,
102                          const wchar_t* value_name,
103                          const uint8_t* value,
104                          DWORD byte_count) {
105   ASSERT(full_key_name != NULL);
106 
107   return SetValueStaticHelper(full_key_name, value_name, REG_BINARY,
108                               const_cast<uint8_t*>(value), byte_count);
109 }
110 
SetValueMultiSZ(const wchar_t * full_key_name,const wchar_t * value_name,const uint8_t * value,DWORD byte_count)111 HRESULT RegKey::SetValueMultiSZ(const wchar_t* full_key_name,
112                                 const wchar_t* value_name,
113                                 const uint8_t* value,
114                                 DWORD byte_count) {
115   ASSERT(full_key_name != NULL);
116 
117   return SetValueStaticHelper(full_key_name, value_name, REG_MULTI_SZ,
118                               const_cast<uint8_t*>(value), byte_count);
119 }
120 
GetValue(const wchar_t * full_key_name,const wchar_t * value_name,DWORD * value)121 HRESULT RegKey::GetValue(const wchar_t* full_key_name,
122                          const wchar_t* value_name,
123                          DWORD* value) {
124   ASSERT(full_key_name != NULL);
125   ASSERT(value != NULL);
126 
127   return GetValueStaticHelper(full_key_name, value_name, REG_DWORD, value);
128 }
129 
GetValue(const wchar_t * full_key_name,const wchar_t * value_name,DWORD64 * value)130 HRESULT RegKey::GetValue(const wchar_t* full_key_name,
131                          const wchar_t* value_name,
132                          DWORD64* value) {
133   ASSERT(full_key_name != NULL);
134   ASSERT(value != NULL);
135 
136   return GetValueStaticHelper(full_key_name, value_name, REG_QWORD, value);
137 }
138 
GetValue(const wchar_t * full_key_name,const wchar_t * value_name,float * value)139 HRESULT RegKey::GetValue(const wchar_t* full_key_name,
140                          const wchar_t* value_name,
141                          float* value) {
142   ASSERT(value != NULL);
143   ASSERT(full_key_name != NULL);
144 
145   DWORD byte_count = 0;
146   scoped_ptr<byte[]> buffer;
147   HRESULT hr = GetValueStaticHelper(full_key_name, value_name,
148                                     REG_BINARY, buffer.accept(), &byte_count);
149   if (SUCCEEDED(hr)) {
150     ASSERT(byte_count == sizeof(*value));
151     if (byte_count == sizeof(*value)) {
152       *value = *reinterpret_cast<float*>(buffer.get());
153     }
154   }
155   return hr;
156 }
157 
GetValue(const wchar_t * full_key_name,const wchar_t * value_name,double * value)158 HRESULT RegKey::GetValue(const wchar_t* full_key_name,
159                          const wchar_t* value_name,
160                          double* value) {
161   ASSERT(value != NULL);
162   ASSERT(full_key_name != NULL);
163 
164   DWORD byte_count = 0;
165   scoped_ptr<byte[]> buffer;
166   HRESULT hr = GetValueStaticHelper(full_key_name, value_name,
167                                     REG_BINARY, buffer.accept(), &byte_count);
168   if (SUCCEEDED(hr)) {
169     ASSERT(byte_count == sizeof(*value));
170     if (byte_count == sizeof(*value)) {
171       *value = *reinterpret_cast<double*>(buffer.get());
172     }
173   }
174   return hr;
175 }
176 
GetValue(const wchar_t * full_key_name,const wchar_t * value_name,wchar_t ** value)177 HRESULT RegKey::GetValue(const wchar_t* full_key_name,
178                          const wchar_t* value_name,
179                          wchar_t** value) {
180   ASSERT(full_key_name != NULL);
181   ASSERT(value != NULL);
182 
183   return GetValueStaticHelper(full_key_name, value_name, REG_SZ, value);
184 }
185 
GetValue(const wchar_t * full_key_name,const wchar_t * value_name,std::wstring * value)186 HRESULT RegKey::GetValue(const wchar_t* full_key_name,
187                          const wchar_t* value_name,
188                          std::wstring* value) {
189   ASSERT(full_key_name != NULL);
190   ASSERT(value != NULL);
191 
192   scoped_ptr<wchar_t[]> buffer;
193   HRESULT hr = RegKey::GetValue(full_key_name, value_name, buffer.accept());
194   if (SUCCEEDED(hr)) {
195     value->assign(buffer.get());
196   }
197   return hr;
198 }
199 
GetValue(const wchar_t * full_key_name,const wchar_t * value_name,std::vector<std::wstring> * value)200 HRESULT RegKey::GetValue(const wchar_t* full_key_name,
201                          const wchar_t* value_name,
202                          std::vector<std::wstring>* value) {
203   ASSERT(full_key_name != NULL);
204   ASSERT(value != NULL);
205 
206   return GetValueStaticHelper(full_key_name, value_name, REG_MULTI_SZ, value);
207 }
208 
GetValue(const wchar_t * full_key_name,const wchar_t * value_name,uint8_t ** value,DWORD * byte_count)209 HRESULT RegKey::GetValue(const wchar_t* full_key_name,
210                          const wchar_t* value_name,
211                          uint8_t** value,
212                          DWORD* byte_count) {
213   ASSERT(full_key_name != NULL);
214   ASSERT(value != NULL);
215   ASSERT(byte_count != NULL);
216 
217   return GetValueStaticHelper(full_key_name, value_name,
218                               REG_BINARY, value, byte_count);
219 }
220 
DeleteSubKey(const wchar_t * key_name)221 HRESULT RegKey::DeleteSubKey(const wchar_t* key_name) {
222   ASSERT(key_name != NULL);
223   ASSERT(h_key_ != NULL);
224 
225   LONG res = ::RegDeleteKey(h_key_, key_name);
226   HRESULT hr = HRESULT_FROM_WIN32(res);
227   if (hr == HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND) ||
228       hr == HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND)) {
229     hr = S_FALSE;
230   }
231   return hr;
232 }
233 
DeleteValue(const wchar_t * value_name)234 HRESULT RegKey::DeleteValue(const wchar_t* value_name) {
235   ASSERT(h_key_ != NULL);
236 
237   LONG res = ::RegDeleteValue(h_key_, value_name);
238   HRESULT hr = HRESULT_FROM_WIN32(res);
239   if (hr == HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND) ||
240       hr == HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND)) {
241     hr = S_FALSE;
242   }
243   return hr;
244 }
245 
Close()246 HRESULT RegKey::Close() {
247   HRESULT hr = S_OK;
248   if (h_key_ != NULL) {
249     LONG res = ::RegCloseKey(h_key_);
250     hr = HRESULT_FROM_WIN32(res);
251     h_key_ = NULL;
252   }
253   return hr;
254 }
255 
Create(HKEY parent_key,const wchar_t * key_name,wchar_t * lpszClass,DWORD options,REGSAM sam_desired,LPSECURITY_ATTRIBUTES lpSecAttr,LPDWORD lpdwDisposition)256 HRESULT RegKey::Create(HKEY parent_key,
257                        const wchar_t* key_name,
258                        wchar_t* lpszClass,
259                        DWORD options,
260                        REGSAM sam_desired,
261                        LPSECURITY_ATTRIBUTES lpSecAttr,
262                        LPDWORD lpdwDisposition) {
263   ASSERT(key_name != NULL);
264   ASSERT(parent_key != NULL);
265 
266   DWORD dw = 0;
267   HKEY h_key = NULL;
268   LONG res = ::RegCreateKeyEx(parent_key, key_name, 0, lpszClass, options,
269                               sam_desired, lpSecAttr, &h_key, &dw);
270   HRESULT hr = HRESULT_FROM_WIN32(res);
271 
272   if (lpdwDisposition) {
273     *lpdwDisposition = dw;
274   }
275 
276   // we have to close the currently opened key
277   // before replacing it with the new one
278   if (hr == S_OK) {
279     hr = Close();
280     ASSERT(hr == S_OK);
281     h_key_ = h_key;
282   }
283   return hr;
284 }
285 
Open(HKEY parent_key,const wchar_t * key_name,REGSAM sam_desired)286 HRESULT RegKey::Open(HKEY parent_key,
287                      const wchar_t* key_name,
288                      REGSAM sam_desired) {
289   ASSERT(key_name != NULL);
290   ASSERT(parent_key != NULL);
291 
292   HKEY h_key = NULL;
293   LONG res = ::RegOpenKeyEx(parent_key, key_name, 0, sam_desired, &h_key);
294   HRESULT hr = HRESULT_FROM_WIN32(res);
295 
296   // we have to close the currently opened key
297   // before replacing it with the new one
298   if (hr == S_OK) {
299     // close the currently opened key if any
300     hr = Close();
301     ASSERT(hr == S_OK);
302     h_key_ = h_key;
303   }
304   return hr;
305 }
306 
307 // save the key and all of its subkeys and values to a file
Save(const wchar_t * full_key_name,const wchar_t * file_name)308 HRESULT RegKey::Save(const wchar_t* full_key_name, const wchar_t* file_name) {
309   ASSERT(full_key_name != NULL);
310   ASSERT(file_name != NULL);
311 
312   std::wstring key_name(full_key_name);
313   HKEY h_key = GetRootKeyInfo(&key_name);
314   if (!h_key) {
315     return E_FAIL;
316   }
317 
318   RegKey key;
319   HRESULT hr = key.Open(h_key, key_name.c_str(), KEY_READ);
320   if (FAILED(hr)) {
321     return hr;
322   }
323 
324   AdjustCurrentProcessPrivilege(SE_BACKUP_NAME, true);
325   LONG res = ::RegSaveKey(key.h_key_, file_name, NULL);
326   AdjustCurrentProcessPrivilege(SE_BACKUP_NAME, false);
327 
328   return HRESULT_FROM_WIN32(res);
329 }
330 
331 // restore the key and all of its subkeys and values which are saved into a file
Restore(const wchar_t * full_key_name,const wchar_t * file_name)332 HRESULT RegKey::Restore(const wchar_t* full_key_name,
333                         const wchar_t* file_name) {
334   ASSERT(full_key_name != NULL);
335   ASSERT(file_name != NULL);
336 
337   std::wstring key_name(full_key_name);
338   HKEY h_key = GetRootKeyInfo(&key_name);
339   if (!h_key) {
340     return E_FAIL;
341   }
342 
343   RegKey key;
344   HRESULT hr = key.Open(h_key, key_name.c_str(), KEY_WRITE);
345   if (FAILED(hr)) {
346     return hr;
347   }
348 
349   AdjustCurrentProcessPrivilege(SE_RESTORE_NAME, true);
350   LONG res = ::RegRestoreKey(key.h_key_, file_name, REG_FORCE_RESTORE);
351   AdjustCurrentProcessPrivilege(SE_RESTORE_NAME, false);
352 
353   return HRESULT_FROM_WIN32(res);
354 }
355 
356 // check if the current key has the specified subkey
HasSubkey(const wchar_t * key_name) const357 bool RegKey::HasSubkey(const wchar_t* key_name) const {
358   ASSERT(key_name != NULL);
359 
360   RegKey key;
361   HRESULT hr = key.Open(h_key_, key_name, KEY_READ);
362   key.Close();
363   return hr == S_OK;
364 }
365 
366 // static flush key
FlushKey(const wchar_t * full_key_name)367 HRESULT RegKey::FlushKey(const wchar_t* full_key_name) {
368   ASSERT(full_key_name != NULL);
369 
370   HRESULT hr = HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND);
371   // get the root HKEY
372   std::wstring key_name(full_key_name);
373   HKEY h_key = GetRootKeyInfo(&key_name);
374 
375   if (h_key != NULL) {
376     LONG res = ::RegFlushKey(h_key);
377     hr = HRESULT_FROM_WIN32(res);
378   }
379   return hr;
380 }
381 
382 // static SET helper
SetValueStaticHelper(const wchar_t * full_key_name,const wchar_t * value_name,DWORD type,LPVOID value,DWORD byte_count)383 HRESULT RegKey::SetValueStaticHelper(const wchar_t* full_key_name,
384                                      const wchar_t* value_name,
385                                      DWORD type,
386                                      LPVOID value,
387                                      DWORD byte_count) {
388   ASSERT(full_key_name != NULL);
389 
390   HRESULT hr = HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND);
391   // get the root HKEY
392   std::wstring key_name(full_key_name);
393   HKEY h_key = GetRootKeyInfo(&key_name);
394 
395   if (h_key != NULL) {
396     RegKey key;
397     hr = key.Create(h_key, key_name.c_str());
398     if (hr == S_OK) {
399       switch (type) {
400         case REG_DWORD:
401           hr = key.SetValue(value_name, *(static_cast<DWORD*>(value)));
402           break;
403         case REG_QWORD:
404           hr = key.SetValue(value_name, *(static_cast<DWORD64*>(value)));
405           break;
406         case REG_SZ:
407           hr = key.SetValue(value_name, static_cast<const wchar_t*>(value));
408           break;
409         case REG_BINARY:
410           hr = key.SetValue(value_name, static_cast<const uint8_t*>(value),
411                             byte_count);
412           break;
413         case REG_MULTI_SZ:
414           hr = key.SetValue(value_name, static_cast<const uint8_t*>(value),
415                             byte_count, type);
416           break;
417         default:
418           ASSERT(false);
419           hr = HRESULT_FROM_WIN32(ERROR_DATATYPE_MISMATCH);
420           break;
421       }
422       // close the key after writing
423       HRESULT temp_hr = key.Close();
424       if (hr == S_OK) {
425         hr = temp_hr;
426       }
427     }
428   }
429   return hr;
430 }
431 
432 // static GET helper
GetValueStaticHelper(const wchar_t * full_key_name,const wchar_t * value_name,DWORD type,LPVOID value,DWORD * byte_count)433 HRESULT RegKey::GetValueStaticHelper(const wchar_t* full_key_name,
434                                      const wchar_t* value_name,
435                                      DWORD type,
436                                      LPVOID value,
437                                      DWORD* byte_count) {
438   ASSERT(full_key_name != NULL);
439 
440   HRESULT hr = HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND);
441   // get the root HKEY
442   std::wstring key_name(full_key_name);
443   HKEY h_key = GetRootKeyInfo(&key_name);
444 
445   if (h_key != NULL) {
446     RegKey key;
447     hr = key.Open(h_key, key_name.c_str(), KEY_READ);
448     if (hr == S_OK) {
449       switch (type) {
450         case REG_DWORD:
451           hr = key.GetValue(value_name, reinterpret_cast<DWORD*>(value));
452           break;
453         case REG_QWORD:
454           hr = key.GetValue(value_name, reinterpret_cast<DWORD64*>(value));
455           break;
456         case REG_SZ:
457           hr = key.GetValue(value_name, reinterpret_cast<wchar_t**>(value));
458           break;
459         case REG_MULTI_SZ:
460           hr = key.GetValue(value_name, reinterpret_cast<
461                                             std::vector<std::wstring>*>(value));
462           break;
463         case REG_BINARY:
464           hr = key.GetValue(value_name, reinterpret_cast<uint8_t**>(value),
465                             byte_count);
466           break;
467         default:
468           ASSERT(false);
469           hr = HRESULT_FROM_WIN32(ERROR_DATATYPE_MISMATCH);
470           break;
471       }
472       // close the key after writing
473       HRESULT temp_hr = key.Close();
474       if (hr == S_OK) {
475         hr = temp_hr;
476       }
477     }
478   }
479   return hr;
480 }
481 
482 // GET helper
GetValueHelper(const wchar_t * value_name,DWORD * type,uint8_t ** value,DWORD * byte_count) const483 HRESULT RegKey::GetValueHelper(const wchar_t* value_name,
484                                DWORD* type,
485                                uint8_t** value,
486                                DWORD* byte_count) const {
487   ASSERT(byte_count != NULL);
488   ASSERT(value != NULL);
489   ASSERT(type != NULL);
490 
491   // init return buffer
492   *value = NULL;
493 
494   // get the size of the return data buffer
495   LONG res = ::SHQueryValueEx(h_key_, value_name, NULL, type, NULL, byte_count);
496   HRESULT hr = HRESULT_FROM_WIN32(res);
497 
498   if (hr == S_OK) {
499     // if the value length is 0, nothing to do
500     if (*byte_count != 0) {
501       // allocate the buffer
502       *value = new byte[*byte_count];
503       ASSERT(*value != NULL);
504 
505       // make the call again to get the data
506       res = ::SHQueryValueEx(h_key_, value_name, NULL,
507                              type, *value, byte_count);
508       hr = HRESULT_FROM_WIN32(res);
509       ASSERT(hr == S_OK);
510     }
511   }
512   return hr;
513 }
514 
515 // Int32 Get
GetValue(const wchar_t * value_name,DWORD * value) const516 HRESULT RegKey::GetValue(const wchar_t* value_name, DWORD* value) const {
517   ASSERT(value != NULL);
518 
519   DWORD type = 0;
520   DWORD byte_count = sizeof(DWORD);
521   LONG res = ::SHQueryValueEx(h_key_, value_name, NULL, &type,
522                               value, &byte_count);
523   HRESULT hr = HRESULT_FROM_WIN32(res);
524   ASSERT((hr != S_OK) || (type == REG_DWORD));
525   ASSERT((hr != S_OK) || (byte_count == sizeof(DWORD)));
526   return hr;
527 }
528 
529 // Int64 Get
GetValue(const wchar_t * value_name,DWORD64 * value) const530 HRESULT RegKey::GetValue(const wchar_t* value_name, DWORD64* value) const {
531   ASSERT(value != NULL);
532 
533   DWORD type = 0;
534   DWORD byte_count = sizeof(DWORD64);
535   LONG res = ::SHQueryValueEx(h_key_, value_name, NULL, &type,
536                               value, &byte_count);
537   HRESULT hr = HRESULT_FROM_WIN32(res);
538   ASSERT((hr != S_OK) || (type == REG_QWORD));
539   ASSERT((hr != S_OK) || (byte_count == sizeof(DWORD64)));
540   return hr;
541 }
542 
543 // String Get
GetValue(const wchar_t * value_name,wchar_t ** value) const544 HRESULT RegKey::GetValue(const wchar_t* value_name, wchar_t** value) const {
545   ASSERT(value != NULL);
546 
547   DWORD byte_count = 0;
548   DWORD type = 0;
549 
550   // first get the size of the string buffer
551   LONG res = ::SHQueryValueEx(h_key_, value_name, NULL,
552                               &type, NULL, &byte_count);
553   HRESULT hr = HRESULT_FROM_WIN32(res);
554 
555   if (hr == S_OK) {
556     // allocate room for the string and a terminating \0
557     *value = new wchar_t[(byte_count / sizeof(wchar_t)) + 1];
558 
559     if ((*value) != NULL) {
560       if (byte_count != 0) {
561         // make the call again
562         res = ::SHQueryValueEx(h_key_, value_name, NULL, &type,
563                                *value, &byte_count);
564         hr = HRESULT_FROM_WIN32(res);
565       } else {
566         (*value)[0] = L'\0';
567       }
568 
569       ASSERT((hr != S_OK) || (type == REG_SZ) ||
570              (type == REG_MULTI_SZ) || (type == REG_EXPAND_SZ));
571     } else {
572       hr = E_OUTOFMEMORY;
573     }
574   }
575 
576   return hr;
577 }
578 
579 // get a string value
GetValue(const wchar_t * value_name,std::wstring * value) const580 HRESULT RegKey::GetValue(const wchar_t* value_name, std::wstring* value) const {
581   ASSERT(value != NULL);
582 
583   DWORD byte_count = 0;
584   DWORD type = 0;
585 
586   // first get the size of the string buffer
587   LONG res = ::SHQueryValueEx(h_key_, value_name, NULL,
588                               &type, NULL, &byte_count);
589   HRESULT hr = HRESULT_FROM_WIN32(res);
590 
591   if (hr == S_OK) {
592     if (byte_count != 0) {
593       // Allocate some memory and make the call again
594       value->resize(byte_count / sizeof(wchar_t) + 1);
595       res = ::SHQueryValueEx(h_key_, value_name, NULL, &type,
596                              &value->at(0), &byte_count);
597       hr = HRESULT_FROM_WIN32(res);
598       value->resize(wcslen(value->data()));
599     } else {
600       value->clear();
601     }
602 
603     ASSERT((hr != S_OK) || (type == REG_SZ) ||
604            (type == REG_MULTI_SZ) || (type == REG_EXPAND_SZ));
605   }
606 
607   return hr;
608 }
609 
610 // convert REG_MULTI_SZ bytes to string array
MultiSZBytesToStringArray(const uint8_t * buffer,DWORD byte_count,std::vector<std::wstring> * value)611 HRESULT RegKey::MultiSZBytesToStringArray(const uint8_t* buffer,
612                                           DWORD byte_count,
613                                           std::vector<std::wstring>* value) {
614   ASSERT(buffer != NULL);
615   ASSERT(value != NULL);
616 
617   const wchar_t* data = reinterpret_cast<const wchar_t*>(buffer);
618   DWORD data_len = byte_count / sizeof(wchar_t);
619   value->clear();
620   if (data_len > 1) {
621     // must be terminated by two null characters
622     if (data[data_len - 1] != 0 || data[data_len - 2] != 0) {
623       return E_INVALIDARG;
624     }
625 
626     // put null-terminated strings into arrays
627     while (*data) {
628       std::wstring str(data);
629       value->push_back(str);
630       data += str.length() + 1;
631     }
632   }
633   return S_OK;
634 }
635 
636 // get a std::vector<std::wstring> value from REG_MULTI_SZ type
GetValue(const wchar_t * value_name,std::vector<std::wstring> * value) const637 HRESULT RegKey::GetValue(const wchar_t* value_name,
638                          std::vector<std::wstring>* value) const {
639   ASSERT(value != NULL);
640 
641   DWORD byte_count = 0;
642   DWORD type = 0;
643   uint8_t* buffer = 0;
644 
645   // first get the size of the buffer
646   HRESULT hr = GetValueHelper(value_name, &type, &buffer, &byte_count);
647   ASSERT((hr != S_OK) || (type == REG_MULTI_SZ));
648 
649   if (SUCCEEDED(hr)) {
650     hr = MultiSZBytesToStringArray(buffer, byte_count, value);
651   }
652 
653   return hr;
654 }
655 
656 // Binary data Get
GetValue(const wchar_t * value_name,uint8_t ** value,DWORD * byte_count) const657 HRESULT RegKey::GetValue(const wchar_t* value_name,
658                          uint8_t** value,
659                          DWORD* byte_count) const {
660   ASSERT(byte_count != NULL);
661   ASSERT(value != NULL);
662 
663   DWORD type = 0;
664   HRESULT hr = GetValueHelper(value_name, &type, value, byte_count);
665   ASSERT((hr != S_OK) || (type == REG_MULTI_SZ) || (type == REG_BINARY));
666   return hr;
667 }
668 
669 // Raw data get
GetValue(const wchar_t * value_name,uint8_t ** value,DWORD * byte_count,DWORD * type) const670 HRESULT RegKey::GetValue(const wchar_t* value_name,
671                          uint8_t** value,
672                          DWORD* byte_count,
673                          DWORD* type) const {
674   ASSERT(type != NULL);
675   ASSERT(byte_count != NULL);
676   ASSERT(value != NULL);
677 
678   return GetValueHelper(value_name, type, value, byte_count);
679 }
680 
681 // Int32 set
SetValue(const wchar_t * value_name,DWORD value) const682 HRESULT RegKey::SetValue(const wchar_t* value_name, DWORD value) const {
683   ASSERT(h_key_ != NULL);
684 
685   LONG res =
686       ::RegSetValueEx(h_key_, value_name, NULL, REG_DWORD,
687                       reinterpret_cast<const uint8_t*>(&value), sizeof(DWORD));
688   return HRESULT_FROM_WIN32(res);
689 }
690 
691 // Int64 set
SetValue(const wchar_t * value_name,DWORD64 value) const692 HRESULT RegKey::SetValue(const wchar_t* value_name, DWORD64 value) const {
693   ASSERT(h_key_ != NULL);
694 
695   LONG res = ::RegSetValueEx(h_key_, value_name, NULL, REG_QWORD,
696                              reinterpret_cast<const uint8_t*>(&value),
697                              sizeof(DWORD64));
698   return HRESULT_FROM_WIN32(res);
699 }
700 
701 // String set
SetValue(const wchar_t * value_name,const wchar_t * value) const702 HRESULT RegKey::SetValue(const wchar_t* value_name,
703                          const wchar_t* value) const {
704   ASSERT(value != NULL);
705   ASSERT(h_key_ != NULL);
706 
707   LONG res = ::RegSetValueEx(h_key_, value_name, NULL, REG_SZ,
708                              reinterpret_cast<const uint8_t*>(value),
709                              (lstrlen(value) + 1) * sizeof(wchar_t));
710   return HRESULT_FROM_WIN32(res);
711 }
712 
713 // Binary data set
SetValue(const wchar_t * value_name,const uint8_t * value,DWORD byte_count) const714 HRESULT RegKey::SetValue(const wchar_t* value_name,
715                          const uint8_t* value,
716                          DWORD byte_count) const {
717   ASSERT(h_key_ != NULL);
718 
719   // special case - if 'value' is NULL make sure byte_count is zero
720   if (value == NULL) {
721     byte_count = 0;
722   }
723 
724   LONG res = ::RegSetValueEx(h_key_, value_name, NULL,
725                              REG_BINARY, value, byte_count);
726   return HRESULT_FROM_WIN32(res);
727 }
728 
729 // Raw data set
SetValue(const wchar_t * value_name,const uint8_t * value,DWORD byte_count,DWORD type) const730 HRESULT RegKey::SetValue(const wchar_t* value_name,
731                          const uint8_t* value,
732                          DWORD byte_count,
733                          DWORD type) const {
734   ASSERT(value != NULL);
735   ASSERT(h_key_ != NULL);
736 
737   LONG res = ::RegSetValueEx(h_key_, value_name, NULL, type, value, byte_count);
738   return HRESULT_FROM_WIN32(res);
739 }
740 
HasKey(const wchar_t * full_key_name)741 bool RegKey::HasKey(const wchar_t* full_key_name) {
742   ASSERT(full_key_name != NULL);
743 
744   // get the root HKEY
745   std::wstring key_name(full_key_name);
746   HKEY h_key = GetRootKeyInfo(&key_name);
747 
748   if (h_key != NULL) {
749     RegKey key;
750     HRESULT hr = key.Open(h_key, key_name.c_str(), KEY_READ);
751     key.Close();
752     return S_OK == hr;
753   }
754   return false;
755 }
756 
757 // static version of HasValue
HasValue(const wchar_t * full_key_name,const wchar_t * value_name)758 bool RegKey::HasValue(const wchar_t* full_key_name, const wchar_t* value_name) {
759   ASSERT(full_key_name != NULL);
760 
761   bool has_value = false;
762   // get the root HKEY
763   std::wstring key_name(full_key_name);
764   HKEY h_key = GetRootKeyInfo(&key_name);
765 
766   if (h_key != NULL) {
767     RegKey key;
768     if (key.Open(h_key, key_name.c_str(), KEY_READ) == S_OK) {
769       has_value = key.HasValue(value_name);
770       key.Close();
771     }
772   }
773   return has_value;
774 }
775 
GetValueType(const wchar_t * full_key_name,const wchar_t * value_name,DWORD * value_type)776 HRESULT RegKey::GetValueType(const wchar_t* full_key_name,
777                              const wchar_t* value_name,
778                              DWORD* value_type) {
779   ASSERT(full_key_name != NULL);
780   ASSERT(value_type != NULL);
781 
782   *value_type = REG_NONE;
783 
784   std::wstring key_name(full_key_name);
785   HKEY h_key = GetRootKeyInfo(&key_name);
786 
787   RegKey key;
788   HRESULT hr = key.Open(h_key, key_name.c_str(), KEY_READ);
789   if (SUCCEEDED(hr)) {
790     LONG res = ::SHQueryValueEx(key.h_key_, value_name, NULL, value_type,
791                                 NULL, NULL);
792     if (res != ERROR_SUCCESS) {
793       hr = HRESULT_FROM_WIN32(res);
794     }
795   }
796 
797   return hr;
798 }
799 
DeleteKey(const wchar_t * full_key_name)800 HRESULT RegKey::DeleteKey(const wchar_t* full_key_name) {
801   ASSERT(full_key_name != NULL);
802 
803   return DeleteKey(full_key_name, true);
804 }
805 
DeleteKey(const wchar_t * full_key_name,bool recursively)806 HRESULT RegKey::DeleteKey(const wchar_t* full_key_name, bool recursively) {
807   ASSERT(full_key_name != NULL);
808 
809   // need to open the parent key first
810   // get the root HKEY
811   std::wstring key_name(full_key_name);
812   HKEY h_key = GetRootKeyInfo(&key_name);
813 
814   // get the parent key
815   std::wstring parent_key(GetParentKeyInfo(&key_name));
816 
817   RegKey key;
818   HRESULT hr = key.Open(h_key, parent_key.c_str());
819 
820   if (hr == S_OK) {
821     hr = recursively ? key.RecurseDeleteSubKey(key_name.c_str())
822                      : key.DeleteSubKey(key_name.c_str());
823   } else if (hr == HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND) ||
824              hr == HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND)) {
825     hr = S_FALSE;
826   }
827 
828   key.Close();
829   return hr;
830 }
831 
DeleteValue(const wchar_t * full_key_name,const wchar_t * value_name)832 HRESULT RegKey::DeleteValue(const wchar_t* full_key_name,
833                             const wchar_t* value_name) {
834   ASSERT(full_key_name != NULL);
835 
836   HRESULT hr = HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND);
837   // get the root HKEY
838   std::wstring key_name(full_key_name);
839   HKEY h_key = GetRootKeyInfo(&key_name);
840 
841   if (h_key != NULL) {
842     RegKey key;
843     hr = key.Open(h_key, key_name.c_str());
844     if (hr == S_OK) {
845       hr = key.DeleteValue(value_name);
846       key.Close();
847     }
848   }
849   return hr;
850 }
851 
// Deletes the subkey |key_name| of the currently open key together with its
// entire subtree, depth-first. The subkey is opened, then its child at
// index 0 is repeatedly enumerated and recursively deleted until
// enumeration fails (e.g. no children remain). Only then is the now-empty
// subkey removed with DeleteSubKey().
//
// Returns the first non-S_OK result encountered, otherwise DeleteSubKey()'s
// result.
// NOTE(review): unlike DeleteSubKey(), a missing |key_name| is not mapped to
// S_FALSE here; the failed Open() result is returned as-is. Also, an S_FALSE
// from a nested call stops the walk without deleting |key_name|. Confirm
// that callers expect this.
HRESULT RegKey::RecurseDeleteSubKey(const wchar_t* key_name) {
  ASSERT(key_name != NULL);

  // Two-argument Open() requests KEY_ALL_ACCESS, so children can be
  // enumerated and deleted through |key|.
  RegKey key;
  HRESULT hr = key.Open(h_key_, key_name);

  if (hr == S_OK) {
    // enumerate all subkeys of this key and recursively delete them
    // Index 0 is requested every time: each successful iteration deletes
    // that child, so the next remaining child moves to index 0.
    FILETIME time = {0};
    wchar_t key_name_buf[kMaxKeyNameChars] = {0};
    DWORD key_name_buf_size = kMaxKeyNameChars;
    while (hr == S_OK &&
        ::RegEnumKeyEx(key.h_key_, 0, key_name_buf, &key_name_buf_size,
                       NULL, NULL, NULL,  &time) == ERROR_SUCCESS) {
      hr = key.RecurseDeleteSubKey(key_name_buf);

      // restore the buffer size
      // (RegEnumKeyEx overwrites it with the length of the returned name.)
      key_name_buf_size = kMaxKeyNameChars;
    }
    // close the top key
    key.Close();
  }

  if (hr == S_OK) {
    // the key has no more children keys
    // delete the key and all of its values
    hr = DeleteSubKey(key_name);
  }

  return hr;
}
883 
GetRootKeyInfo(std::wstring * full_key_name)884 HKEY RegKey::GetRootKeyInfo(std::wstring* full_key_name) {
885   ASSERT(full_key_name != NULL);
886 
887   HKEY h_key = NULL;
888   // get the root HKEY
889   size_t index = full_key_name->find(L'\\');
890   std::wstring root_key;
891 
892   if (index == -1) {
893     root_key = *full_key_name;
894     *full_key_name = L"";
895   } else {
896     root_key = full_key_name->substr(0, index);
897     *full_key_name = full_key_name->substr(index + 1,
898                                            full_key_name->length() - index - 1);
899   }
900 
901   for (std::wstring::iterator iter = root_key.begin();
902        iter != root_key.end(); ++iter) {
903     *iter = toupper(*iter);
904   }
905 
906   if (!root_key.compare(L"HKLM") ||
907       !root_key.compare(L"HKEY_LOCAL_MACHINE")) {
908     h_key = HKEY_LOCAL_MACHINE;
909   } else if (!root_key.compare(L"HKCU") ||
910              !root_key.compare(L"HKEY_CURRENT_USER")) {
911     h_key = HKEY_CURRENT_USER;
912   } else if (!root_key.compare(L"HKU") ||
913              !root_key.compare(L"HKEY_USERS")) {
914     h_key = HKEY_USERS;
915   } else if (!root_key.compare(L"HKCR") ||
916              !root_key.compare(L"HKEY_CLASSES_ROOT")) {
917     h_key = HKEY_CLASSES_ROOT;
918   }
919 
920   return h_key;
921 }
922 
923 
924 // Returns true if this key name is 'safe' for deletion
925 // (doesn't specify a key root)
SafeKeyNameForDeletion(const wchar_t * key_name)926 bool RegKey::SafeKeyNameForDeletion(const wchar_t* key_name) {
927   ASSERT(key_name != NULL);
928   std::wstring key(key_name);
929 
930   HKEY root_key = GetRootKeyInfo(&key);
931 
932   if (!root_key) {
933     key = key_name;
934   }
935   if (key.empty()) {
936     return false;
937   }
938   bool found_subkey = false, backslash_found = false;
939   for (size_t i = 0 ; i < key.length() ; ++i) {
940     if (key[i] == L'\\') {
941       backslash_found = true;
942     } else if (backslash_found) {
943       found_subkey = true;
944       break;
945     }
946   }
947   return (root_key == HKEY_USERS) ? found_subkey : true;
948 }
949 
GetParentKeyInfo(std::wstring * key_name)950 std::wstring RegKey::GetParentKeyInfo(std::wstring* key_name) {
951   ASSERT(key_name != NULL);
952 
953   // get the parent key
954   size_t index = key_name->rfind(L'\\');
955   std::wstring parent_key;
956   if (index == -1) {
957     parent_key = L"";
958   } else {
959     parent_key = key_name->substr(0, index);
960     *key_name = key_name->substr(index + 1, key_name->length() - index - 1);
961   }
962 
963   return parent_key;
964 }
965 
966 // get the number of values for this key
GetValueCount()967 uint32_t RegKey::GetValueCount() {
968   DWORD num_values = 0;
969 
970   if (ERROR_SUCCESS != ::RegQueryInfoKey(
971         h_key_,  // key handle
972         NULL,  // buffer for class name
973         NULL,  // size of class string
974         NULL,  // reserved
975         NULL,  // number of subkeys
976         NULL,  // longest subkey size
977         NULL,  // longest class string
978         &num_values,  // number of values for this key
979         NULL,  // longest value name
980         NULL,  // longest value data
981         NULL,  // security descriptor
982         NULL)) {  // last write time
983     ASSERT(false);
984   }
985   return num_values;
986 }
987 
988 // Enumerators for the value_names for this key
989 
990 // Called to get the value name for the given value name index
991 // Use GetValueCount() to get the total value_name count for this key
992 // Returns failure if no key at the specified index
GetValueNameAt(int index,std::wstring * value_name,DWORD * type)993 HRESULT RegKey::GetValueNameAt(int index, std::wstring* value_name,
994                                DWORD* type) {
995   ASSERT(value_name != NULL);
996 
997   LONG res = ERROR_SUCCESS;
998   wchar_t value_name_buf[kMaxValueNameChars] = {0};
999   DWORD value_name_buf_size = kMaxValueNameChars;
1000   res = ::RegEnumValue(h_key_, index, value_name_buf, &value_name_buf_size,
1001                        NULL, type, NULL, NULL);
1002 
1003   if (res == ERROR_SUCCESS) {
1004     value_name->assign(value_name_buf);
1005   }
1006 
1007   return HRESULT_FROM_WIN32(res);
1008 }
1009 
GetSubkeyCount()1010 uint32_t RegKey::GetSubkeyCount() {
1011   // number of values for key
1012   DWORD num_subkeys = 0;
1013 
1014   if (ERROR_SUCCESS != ::RegQueryInfoKey(
1015           h_key_,  // key handle
1016           NULL,  // buffer for class name
1017           NULL,  // size of class string
1018           NULL,  // reserved
1019           &num_subkeys,  // number of subkeys
1020           NULL,  // longest subkey size
1021           NULL,  // longest class string
1022           NULL,  // number of values for this key
1023           NULL,  // longest value name
1024           NULL,  // longest value data
1025           NULL,  // security descriptor
1026           NULL)) { // last write time
1027     ASSERT(false);
1028   }
1029   return num_subkeys;
1030 }
1031 
GetSubkeyNameAt(int index,std::wstring * key_name)1032 HRESULT RegKey::GetSubkeyNameAt(int index, std::wstring* key_name) {
1033   ASSERT(key_name != NULL);
1034 
1035   LONG res = ERROR_SUCCESS;
1036   wchar_t key_name_buf[kMaxKeyNameChars] = {0};
1037   DWORD key_name_buf_size = kMaxKeyNameChars;
1038 
1039   res = ::RegEnumKeyEx(h_key_, index, key_name_buf, &key_name_buf_size,
1040                        NULL, NULL, NULL, NULL);
1041 
1042   if (res == ERROR_SUCCESS) {
1043     key_name->assign(key_name_buf);
1044   }
1045 
1046   return HRESULT_FROM_WIN32(res);
1047 }
1048 
1049 // Is the key empty: having no sub-keys and values
IsKeyEmpty(const wchar_t * full_key_name)1050 bool RegKey::IsKeyEmpty(const wchar_t* full_key_name) {
1051   ASSERT(full_key_name != NULL);
1052 
1053   bool is_empty = true;
1054 
1055   // Get the root HKEY
1056   std::wstring key_name(full_key_name);
1057   HKEY h_key = GetRootKeyInfo(&key_name);
1058 
1059   // Open the key to check
1060   if (h_key != NULL) {
1061     RegKey key;
1062     HRESULT hr = key.Open(h_key, key_name.c_str(), KEY_READ);
1063     if (SUCCEEDED(hr)) {
1064       is_empty = key.GetSubkeyCount() == 0 && key.GetValueCount() == 0;
1065       key.Close();
1066     }
1067   }
1068 
1069   return is_empty;
1070 }
1071 
AdjustCurrentProcessPrivilege(const TCHAR * privilege,bool to_enable)1072 bool AdjustCurrentProcessPrivilege(const TCHAR* privilege, bool to_enable) {
1073   ASSERT(privilege != NULL);
1074 
1075   bool ret = false;
1076   HANDLE token;
1077   if (::OpenProcessToken(::GetCurrentProcess(),
1078                          TOKEN_ADJUST_PRIVILEGES | TOKEN_QUERY, &token)) {
1079     LUID luid;
1080     memset(&luid, 0, sizeof(luid));
1081     if (::LookupPrivilegeValue(NULL, privilege, &luid)) {
1082       TOKEN_PRIVILEGES privs;
1083       privs.PrivilegeCount = 1;
1084       privs.Privileges[0].Luid = luid;
1085       privs.Privileges[0].Attributes = to_enable ? SE_PRIVILEGE_ENABLED : 0;
1086       if (::AdjustTokenPrivileges(token, FALSE, &privs, 0, NULL, 0)) {
1087         ret = true;
1088       } else {
1089         LOG_GLE(LS_ERROR) << "AdjustTokenPrivileges failed";
1090       }
1091     } else {
1092       LOG_GLE(LS_ERROR) << "LookupPrivilegeValue failed";
1093     }
1094     CloseHandle(token);
1095   } else {
1096     LOG_GLE(LS_ERROR) << "OpenProcessToken(GetCurrentProcess) failed";
1097   }
1098 
1099   return ret;
1100 }
1101 
1102 }  // namespace rtc
1103