1 /* -*- Mode: C; tab-width: 4 -*-
2  *
3  * Copyright (c) 2002-2004 Apple Computer, Inc. All rights reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
#include "Secret.h"
#include <ctype.h>
#include <stdarg.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <winsock2.h>
#include <ws2tcpip.h>
#include <windows.h>
#include <process.h>
#include <ntsecapi.h>
#include <lm.h>
#include "DebugServices.h"
31 
32 
33 mDNSlocal OSStatus MakeLsaStringFromUTF8String( PLSA_UNICODE_STRING output, const char * input );
34 mDNSlocal OSStatus MakeUTF8StringFromLsaString( char * output, size_t len, PLSA_UNICODE_STRING input );
35 
36 
37 BOOL
LsaGetSecret(const char * inDomain,char * outDomain,unsigned outDomainSize,char * outKey,unsigned outKeySize,char * outSecret,unsigned outSecretSize)38 LsaGetSecret( const char * inDomain, char * outDomain, unsigned outDomainSize, char * outKey, unsigned outKeySize, char * outSecret, unsigned outSecretSize )
39 {
40 	PLSA_UNICODE_STRING		domainLSA;
41 	PLSA_UNICODE_STRING		keyLSA;
42 	PLSA_UNICODE_STRING		secretLSA;
43 	size_t					i;
44 	size_t					dlen;
45 	LSA_OBJECT_ATTRIBUTES	attrs;
46 	LSA_HANDLE				handle = NULL;
47 	NTSTATUS				res;
48 	OSStatus				err;
49 
50 	check( inDomain );
51 	check( outDomain );
52 	check( outKey );
53 	check( outSecret );
54 
55 	// Initialize
56 
57 	domainLSA	= NULL;
58 	keyLSA		= NULL;
59 	secretLSA	= NULL;
60 
61 	// Make sure we have enough space to add trailing dot
62 
63 	dlen = strlen( inDomain );
64 	err = strcpy_s( outDomain, outDomainSize - 2, inDomain );
65 	require_noerr( err, exit );
66 
67 	// If there isn't a trailing dot, add one because the mDNSResponder
68 	// presents names with the trailing dot.
69 
70 	if ( outDomain[ dlen - 1 ] != '.' )
71 	{
72 		outDomain[ dlen++ ] = '.';
73 		outDomain[ dlen ] = '\0';
74 	}
75 
76 	// Canonicalize name by converting to lower case (keychain and some name servers are case sensitive)
77 
78 	for ( i = 0; i < dlen; i++ )
79 	{
80 		outDomain[i] = (char) tolower( outDomain[i] );  // canonicalize -> lower case
81 	}
82 
83 	// attrs are reserved, so initialize to zeroes.
84 
85 	ZeroMemory( &attrs, sizeof( attrs ) );
86 
87 	// Get a handle to the Policy object on the local system
88 
89 	res = LsaOpenPolicy( NULL, &attrs, POLICY_GET_PRIVATE_INFORMATION, &handle );
90 	err = translate_errno( res == 0, LsaNtStatusToWinError( res ), kUnknownErr );
91 	require_noerr( err, exit );
92 
93 	// Get the encrypted data
94 
95 	domainLSA = ( PLSA_UNICODE_STRING ) malloc( sizeof( LSA_UNICODE_STRING ) );
96 	require_action( domainLSA != NULL, exit, err = mStatus_NoMemoryErr );
97 	err = MakeLsaStringFromUTF8String( domainLSA, outDomain );
98 	require_noerr( err, exit );
99 
100 	// Retrieve the key
101 
102 	res = LsaRetrievePrivateData( handle, domainLSA, &keyLSA );
103 	err = translate_errno( res == 0, LsaNtStatusToWinError( res ), kUnknownErr );
104 	require_noerr_quiet( err, exit );
105 
106 	// <rdar://problem/4192119> Lsa secrets use a flat naming space.  Therefore, we will prepend "$" to the keyname to
107 	// make sure it doesn't conflict with a zone name.
108 	// Strip off the "$" prefix.
109 
110 	err = MakeUTF8StringFromLsaString( outKey, outKeySize, keyLSA );
111 	require_noerr( err, exit );
112 	require_action( outKey[0] == '$', exit, err = kUnknownErr );
113 	memcpy( outKey, outKey + 1, strlen( outKey ) );
114 
115 	// Retrieve the secret
116 
117 	res = LsaRetrievePrivateData( handle, keyLSA, &secretLSA );
118 	err = translate_errno( res == 0, LsaNtStatusToWinError( res ), kUnknownErr );
119 	require_noerr_quiet( err, exit );
120 
121 	// Convert the secret to UTF8 string
122 
123 	err = MakeUTF8StringFromLsaString( outSecret, outSecretSize, secretLSA );
124 	require_noerr( err, exit );
125 
126 exit:
127 
128 	if ( domainLSA != NULL )
129 	{
130 		if ( domainLSA->Buffer != NULL )
131 		{
132 			free( domainLSA->Buffer );
133 		}
134 
135 		free( domainLSA );
136 	}
137 
138 	if ( keyLSA != NULL )
139 	{
140 		LsaFreeMemory( keyLSA );
141 	}
142 
143 	if ( secretLSA != NULL )
144 	{
145 		LsaFreeMemory( secretLSA );
146 	}
147 
148 	if ( handle )
149 	{
150 		LsaClose( handle );
151 		handle = NULL;
152 	}
153 
154 	return ( !err ) ? TRUE : FALSE;
155 }
156 
157 
158 mDNSBool
LsaSetSecret(const char * inDomain,const char * inKey,const char * inSecret)159 LsaSetSecret( const char * inDomain, const char * inKey, const char * inSecret )
160 {
161 	size_t					inDomainLength;
162 	size_t					inKeyLength;
163 	char					domain[ 1024 ];
164 	char					key[ 1024 ];
165 	LSA_OBJECT_ATTRIBUTES	attrs;
166 	LSA_HANDLE				handle = NULL;
167 	NTSTATUS				res;
168 	LSA_UNICODE_STRING		lucZoneName;
169 	LSA_UNICODE_STRING		lucKeyName;
170 	LSA_UNICODE_STRING		lucSecretName;
171 	BOOL					ok = TRUE;
172 	OSStatus				err;
173 
174 	require_action( inDomain != NULL, exit, ok = FALSE );
175 	require_action( inKey != NULL, exit, ok = FALSE );
176 	require_action( inSecret != NULL, exit, ok = FALSE );
177 
178 	// If there isn't a trailing dot, add one because the mDNSResponder
179 	// presents names with the trailing dot.
180 
181 	ZeroMemory( domain, sizeof( domain ) );
182 	inDomainLength = strlen( inDomain );
183 	require_action( inDomainLength > 0, exit, ok = FALSE );
184 	err = strcpy_s( domain, sizeof( domain ) - 2, inDomain );
185 	require_action( !err, exit, ok = FALSE );
186 
187 	if ( domain[ inDomainLength - 1 ] != '.' )
188 	{
189 		domain[ inDomainLength++ ] = '.';
190 		domain[ inDomainLength ] = '\0';
191 	}
192 
193 	// <rdar://problem/4192119>
194 	//
195 	// Prepend "$" to the key name, so that there will
196 	// be no conflict between the zone name and the key
197 	// name
198 
199 	ZeroMemory( key, sizeof( key ) );
200 	inKeyLength = strlen( inKey );
201 	require_action( inKeyLength > 0 , exit, ok = FALSE );
202 	key[ 0 ] = '$';
203 	err = strcpy_s( key + 1, sizeof( key ) - 3, inKey );
204 	require_action( !err, exit, ok = FALSE );
205 	inKeyLength++;
206 
207 	if ( key[ inKeyLength - 1 ] != '.' )
208 	{
209 		key[ inKeyLength++ ] = '.';
210 		key[ inKeyLength ] = '\0';
211 	}
212 
213 	// attrs are reserved, so initialize to zeroes.
214 
215 	ZeroMemory( &attrs, sizeof( attrs ) );
216 
217 	// Get a handle to the Policy object on the local system
218 
219 	res = LsaOpenPolicy( NULL, &attrs, POLICY_ALL_ACCESS, &handle );
220 	err = translate_errno( res == 0, LsaNtStatusToWinError( res ), kUnknownErr );
221 	require_noerr( err, exit );
222 
223 	// Intializing PLSA_UNICODE_STRING structures
224 
225 	err = MakeLsaStringFromUTF8String( &lucZoneName, domain );
226 	require_noerr( err, exit );
227 
228 	err = MakeLsaStringFromUTF8String( &lucKeyName, key );
229 	require_noerr( err, exit );
230 
231 	err = MakeLsaStringFromUTF8String( &lucSecretName, inSecret );
232 	require_noerr( err, exit );
233 
234 	// Store the private data.
235 
236 	res = LsaStorePrivateData( handle, &lucZoneName, &lucKeyName );
237 	err = translate_errno( res == 0, LsaNtStatusToWinError( res ), kUnknownErr );
238 	require_noerr( err, exit );
239 
240 	res = LsaStorePrivateData( handle, &lucKeyName, &lucSecretName );
241 	err = translate_errno( res == 0, LsaNtStatusToWinError( res ), kUnknownErr );
242 	require_noerr( err, exit );
243 
244 exit:
245 
246 	if ( handle )
247 	{
248 		LsaClose( handle );
249 		handle = NULL;
250 	}
251 
252 	return ok;
253 }
254 
255 
256 //===========================================================================================================================
257 //	MakeLsaStringFromUTF8String
258 //===========================================================================================================================
259 
260 mDNSlocal OSStatus
MakeLsaStringFromUTF8String(PLSA_UNICODE_STRING output,const char * input)261 MakeLsaStringFromUTF8String( PLSA_UNICODE_STRING output, const char * input )
262 {
263 	int			size;
264 	OSStatus	err;
265 
266 	check( input );
267 	check( output );
268 
269 	output->Buffer = NULL;
270 
271 	size = MultiByteToWideChar( CP_UTF8, 0, input, -1, NULL, 0 );
272 	err = translate_errno( size > 0, GetLastError(), kUnknownErr );
273 	require_noerr( err, exit );
274 
275 	output->Length = (USHORT)( size * sizeof( wchar_t ) );
276 	output->Buffer = (PWCHAR) malloc( output->Length );
277 	require_action( output->Buffer, exit, err = mStatus_NoMemoryErr );
278 	size = MultiByteToWideChar( CP_UTF8, 0, input, -1, output->Buffer, size );
279 	err = translate_errno( size > 0, GetLastError(), kUnknownErr );
280 	require_noerr( err, exit );
281 
282 	// We're going to subtrace one wchar_t from the size, because we didn't
283 	// include it when we encoded the string
284 
285 	output->MaximumLength = output->Length;
286 	output->Length		-= sizeof( wchar_t );
287 
288 exit:
289 
290 	if ( err && output->Buffer )
291 	{
292 		free( output->Buffer );
293 		output->Buffer = NULL;
294 	}
295 
296 	return( err );
297 }
298 
299 
300 
301 //===========================================================================================================================
302 //	MakeUTF8StringFromLsaString
303 //===========================================================================================================================
304 
305 mDNSlocal OSStatus
MakeUTF8StringFromLsaString(char * output,size_t len,PLSA_UNICODE_STRING input)306 MakeUTF8StringFromLsaString( char * output, size_t len, PLSA_UNICODE_STRING input )
307 {
308 	size_t		size;
309 	OSStatus	err = kNoErr;
310 
311 	// The Length field of this structure holds the number of bytes,
312 	// but WideCharToMultiByte expects the number of wchar_t's. So
313 	// we divide by sizeof(wchar_t) to get the correct number.
314 
315 	size = (size_t) WideCharToMultiByte(CP_UTF8, 0, input->Buffer, ( input->Length / sizeof( wchar_t ) ), NULL, 0, NULL, NULL);
316 	err = translate_errno( size != 0, GetLastError(), kUnknownErr );
317 	require_noerr( err, exit );
318 
319 	// Ensure that we have enough space (Add one for trailing '\0')
320 
321 	require_action( ( size + 1 ) <= len, exit, err = mStatus_NoMemoryErr );
322 
323 	// Convert the string
324 
325 	size = (size_t) WideCharToMultiByte( CP_UTF8, 0, input->Buffer, ( input->Length / sizeof( wchar_t ) ), output, (int) size, NULL, NULL);
326 	err = translate_errno( size != 0, GetLastError(), kUnknownErr );
327 	require_noerr( err, exit );
328 
329 	// have to add the trailing 0 because WideCharToMultiByte doesn't do it,
330 	// although it does return the correct size
331 
332 	output[size] = '\0';
333 
334 exit:
335 
336 	return err;
337 }
338 
339