1 /* Written by Dr Stephen N Henson (steve@openssl.org) for the OpenSSL
2  * project 2005.
3  */
4 /* ====================================================================
5  * Copyright (c) 2005 The OpenSSL Project.  All rights reserved.
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions
9  * are met:
10  *
11  * 1. Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  *
14  * 2. Redistributions in binary form must reproduce the above copyright
15  *    notice, this list of conditions and the following disclaimer in
16  *    the documentation and/or other materials provided with the
17  *    distribution.
18  *
19  * 3. All advertising materials mentioning features or use of this
20  *    software must display the following acknowledgment:
21  *    "This product includes software developed by the OpenSSL Project
22  *    for use in the OpenSSL Toolkit. (http://www.OpenSSL.org/)"
23  *
24  * 4. The names "OpenSSL Toolkit" and "OpenSSL Project" must not be used to
25  *    endorse or promote products derived from this software without
26  *    prior written permission. For written permission, please contact
27  *    licensing@OpenSSL.org.
28  *
29  * 5. Products derived from this software may not be called "OpenSSL"
30  *    nor may "OpenSSL" appear in their names without prior written
31  *    permission of the OpenSSL Project.
32  *
33  * 6. Redistributions of any form whatsoever must retain the following
34  *    acknowledgment:
35  *    "This product includes software developed by the OpenSSL Project
36  *    for use in the OpenSSL Toolkit (http://www.OpenSSL.org/)"
37  *
38  * THIS SOFTWARE IS PROVIDED BY THE OpenSSL PROJECT ``AS IS'' AND ANY
39  * EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
40  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
41  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE OpenSSL PROJECT OR
42  * ITS CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
43  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
44  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
45  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
46  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
47  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
48  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
49  * OF THE POSSIBILITY OF SUCH DAMAGE.
50  * ====================================================================
51  *
52  * This product includes cryptographic software written by Eric Young
53  * (eay@cryptsoft.com).  This product includes software written by Tim
54  * Hudson (tjh@cryptsoft.com). */
55 
56 #include <openssl/rsa.h>
57 
58 #include <assert.h>
59 #include <limits.h>
60 #include <string.h>
61 
62 #include <openssl/bn.h>
63 #include <openssl/digest.h>
64 #include <openssl/err.h>
65 #include <openssl/mem.h>
66 #include <openssl/rand.h>
67 #include <openssl/sha.h>
68 
69 #include "internal.h"
70 #include "../internal.h"
71 
72 /* TODO(fork): don't the check functions have to be constant time? */
73 
RSA_padding_add_PKCS1_type_1(uint8_t * to,size_t to_len,const uint8_t * from,size_t from_len)74 int RSA_padding_add_PKCS1_type_1(uint8_t *to, size_t to_len,
75                                  const uint8_t *from, size_t from_len) {
76   /* See RFC 8017, section 9.2. */
77   if (to_len < RSA_PKCS1_PADDING_SIZE) {
78     OPENSSL_PUT_ERROR(RSA, RSA_R_KEY_SIZE_TOO_SMALL);
79     return 0;
80   }
81 
82   if (from_len > to_len - RSA_PKCS1_PADDING_SIZE) {
83     OPENSSL_PUT_ERROR(RSA, RSA_R_DATA_TOO_LARGE_FOR_KEY_SIZE);
84     return 0;
85   }
86 
87   to[0] = 0;
88   to[1] = 1;
89   OPENSSL_memset(to + 2, 0xff, to_len - 3 - from_len);
90   to[to_len - from_len - 1] = 0;
91   OPENSSL_memcpy(to + to_len - from_len, from, from_len);
92   return 1;
93 }
94 
RSA_padding_check_PKCS1_type_1(uint8_t * to,unsigned to_len,const uint8_t * from,unsigned from_len)95 int RSA_padding_check_PKCS1_type_1(uint8_t *to, unsigned to_len,
96                                    const uint8_t *from, unsigned from_len) {
97   unsigned i, j;
98   const uint8_t *p;
99 
100   if (from_len < 2) {
101     OPENSSL_PUT_ERROR(RSA, RSA_R_DATA_TOO_SMALL);
102     return -1;
103   }
104 
105   p = from;
106   if ((*(p++) != 0) || (*(p++) != 1)) {
107     OPENSSL_PUT_ERROR(RSA, RSA_R_BLOCK_TYPE_IS_NOT_01);
108     return -1;
109   }
110 
111   /* scan over padding data */
112   j = from_len - 2; /* one for leading 00, one for type. */
113   for (i = 0; i < j; i++) {
114     /* should decrypt to 0xff */
115     if (*p != 0xff) {
116       if (*p == 0) {
117         p++;
118         break;
119       } else {
120         OPENSSL_PUT_ERROR(RSA, RSA_R_BAD_FIXED_HEADER_DECRYPT);
121         return -1;
122       }
123     }
124     p++;
125   }
126 
127   if (i == j) {
128     OPENSSL_PUT_ERROR(RSA, RSA_R_NULL_BEFORE_BLOCK_MISSING);
129     return -1;
130   }
131 
132   if (i < 8) {
133     OPENSSL_PUT_ERROR(RSA, RSA_R_BAD_PAD_BYTE_COUNT);
134     return -1;
135   }
136   i++; /* Skip over the '\0' */
137   j -= i;
138   if (j > to_len) {
139     OPENSSL_PUT_ERROR(RSA, RSA_R_DATA_TOO_LARGE);
140     return -1;
141   }
142   OPENSSL_memcpy(to, p, j);
143 
144   return j;
145 }
146 
/* rand_nonzero fills |out| with |len| bytes from |RAND_bytes|, redrawing any
 * byte that comes back as zero until it is non-zero. It returns one on
 * success and zero if any |RAND_bytes| call fails. */
static int rand_nonzero(uint8_t *out, size_t len) {
  if (!RAND_bytes(out, len)) {
    return 0;
  }

  uint8_t *cur = out;
  for (size_t left = len; left != 0; left--, cur++) {
    /* Replace this byte, one draw at a time, for as long as it is zero. */
    while (*cur == 0) {
      if (!RAND_bytes(cur, 1)) {
        return 0;
      }
    }
  }

  return 1;
}
162 
RSA_padding_add_PKCS1_type_2(uint8_t * to,size_t to_len,const uint8_t * from,size_t from_len)163 int RSA_padding_add_PKCS1_type_2(uint8_t *to, size_t to_len,
164                                  const uint8_t *from, size_t from_len) {
165   /* See RFC 8017, section 7.2.1. */
166   if (to_len < RSA_PKCS1_PADDING_SIZE) {
167     OPENSSL_PUT_ERROR(RSA, RSA_R_KEY_SIZE_TOO_SMALL);
168     return 0;
169   }
170 
171   if (from_len > to_len - RSA_PKCS1_PADDING_SIZE) {
172     OPENSSL_PUT_ERROR(RSA, RSA_R_DATA_TOO_LARGE_FOR_KEY_SIZE);
173     return 0;
174   }
175 
176   to[0] = 0;
177   to[1] = 2;
178 
179   size_t padding_len = to_len - 3 - from_len;
180   if (!rand_nonzero(to + 2, padding_len)) {
181     return 0;
182   }
183 
184   to[2 + padding_len] = 0;
185   OPENSSL_memcpy(to + to_len - from_len, from, from_len);
186   return 1;
187 }
188 
RSA_padding_check_PKCS1_type_2(uint8_t * to,unsigned to_len,const uint8_t * from,unsigned from_len)189 int RSA_padding_check_PKCS1_type_2(uint8_t *to, unsigned to_len,
190                                    const uint8_t *from, unsigned from_len) {
191   if (from_len == 0) {
192     OPENSSL_PUT_ERROR(RSA, RSA_R_EMPTY_PUBLIC_KEY);
193     return -1;
194   }
195 
196   /* PKCS#1 v1.5 decryption. See "PKCS #1 v2.2: RSA Cryptography
197    * Standard", section 7.2.2. */
198   if (from_len < RSA_PKCS1_PADDING_SIZE) {
199     /* |from| is zero-padded to the size of the RSA modulus, a public value, so
200      * this can be rejected in non-constant time. */
201     OPENSSL_PUT_ERROR(RSA, RSA_R_KEY_SIZE_TOO_SMALL);
202     return -1;
203   }
204 
205   unsigned first_byte_is_zero = constant_time_eq(from[0], 0);
206   unsigned second_byte_is_two = constant_time_eq(from[1], 2);
207 
208   unsigned i, zero_index = 0, looking_for_index = ~0u;
209   for (i = 2; i < from_len; i++) {
210     unsigned equals0 = constant_time_is_zero(from[i]);
211     zero_index = constant_time_select(looking_for_index & equals0, (unsigned)i,
212                                       zero_index);
213     looking_for_index = constant_time_select(equals0, 0, looking_for_index);
214   }
215 
216   /* The input must begin with 00 02. */
217   unsigned valid_index = first_byte_is_zero;
218   valid_index &= second_byte_is_two;
219 
220   /* We must have found the end of PS. */
221   valid_index &= ~looking_for_index;
222 
223   /* PS must be at least 8 bytes long, and it starts two bytes into |from|. */
224   valid_index &= constant_time_ge(zero_index, 2 + 8);
225 
226   /* Skip the zero byte. */
227   zero_index++;
228 
229   /* NOTE: Although this logic attempts to be constant time, the API contracts
230    * of this function and |RSA_decrypt| with |RSA_PKCS1_PADDING| make it
231    * impossible to completely avoid Bleichenbacher's attack. Consumers should
232    * use |RSA_unpad_key_pkcs1|. */
233   if (!valid_index) {
234     OPENSSL_PUT_ERROR(RSA, RSA_R_PKCS_DECODING_ERROR);
235     return -1;
236   }
237 
238   const unsigned msg_len = from_len - zero_index;
239   if (msg_len > to_len) {
240     /* This shouldn't happen because this function is always called with
241      * |to_len| as the key size and |from_len| is bounded by the key size. */
242     OPENSSL_PUT_ERROR(RSA, RSA_R_PKCS_DECODING_ERROR);
243     return -1;
244   }
245 
246   if (msg_len > INT_MAX) {
247     OPENSSL_PUT_ERROR(RSA, ERR_R_OVERFLOW);
248     return -1;
249   }
250 
251   OPENSSL_memcpy(to, &from[zero_index], msg_len);
252   return (int)msg_len;
253 }
254 
RSA_padding_add_none(uint8_t * to,size_t to_len,const uint8_t * from,size_t from_len)255 int RSA_padding_add_none(uint8_t *to, size_t to_len, const uint8_t *from,
256                          size_t from_len) {
257   if (from_len > to_len) {
258     OPENSSL_PUT_ERROR(RSA, RSA_R_DATA_TOO_LARGE_FOR_KEY_SIZE);
259     return 0;
260   }
261 
262   if (from_len < to_len) {
263     OPENSSL_PUT_ERROR(RSA, RSA_R_DATA_TOO_SMALL_FOR_KEY_SIZE);
264     return 0;
265   }
266 
267   OPENSSL_memcpy(to, from, from_len);
268   return 1;
269 }
270 
PKCS1_MGF1(uint8_t * out,size_t len,const uint8_t * seed,size_t seed_len,const EVP_MD * md)271 static int PKCS1_MGF1(uint8_t *out, size_t len, const uint8_t *seed,
272                       size_t seed_len, const EVP_MD *md) {
273   int ret = 0;
274   EVP_MD_CTX ctx;
275   EVP_MD_CTX_init(&ctx);
276 
277   size_t md_len = EVP_MD_size(md);
278 
279   for (uint32_t i = 0; len > 0; i++) {
280     uint8_t counter[4];
281     counter[0] = (uint8_t)(i >> 24);
282     counter[1] = (uint8_t)(i >> 16);
283     counter[2] = (uint8_t)(i >> 8);
284     counter[3] = (uint8_t)i;
285     if (!EVP_DigestInit_ex(&ctx, md, NULL) ||
286         !EVP_DigestUpdate(&ctx, seed, seed_len) ||
287         !EVP_DigestUpdate(&ctx, counter, sizeof(counter))) {
288       goto err;
289     }
290 
291     if (md_len <= len) {
292       if (!EVP_DigestFinal_ex(&ctx, out, NULL)) {
293         goto err;
294       }
295       out += md_len;
296       len -= md_len;
297     } else {
298       uint8_t digest[EVP_MAX_MD_SIZE];
299       if (!EVP_DigestFinal_ex(&ctx, digest, NULL)) {
300         goto err;
301       }
302       OPENSSL_memcpy(out, digest, len);
303       len = 0;
304     }
305   }
306 
307   ret = 1;
308 
309 err:
310   EVP_MD_CTX_cleanup(&ctx);
311   return ret;
312 }
313 
RSA_padding_add_PKCS1_OAEP_mgf1(uint8_t * to,size_t to_len,const uint8_t * from,size_t from_len,const uint8_t * param,size_t param_len,const EVP_MD * md,const EVP_MD * mgf1md)314 int RSA_padding_add_PKCS1_OAEP_mgf1(uint8_t *to, size_t to_len,
315                                     const uint8_t *from, size_t from_len,
316                                     const uint8_t *param, size_t param_len,
317                                     const EVP_MD *md, const EVP_MD *mgf1md) {
318   if (md == NULL) {
319     md = EVP_sha1();
320   }
321   if (mgf1md == NULL) {
322     mgf1md = md;
323   }
324 
325   size_t mdlen = EVP_MD_size(md);
326 
327   if (to_len < 2 * mdlen + 2) {
328     OPENSSL_PUT_ERROR(RSA, RSA_R_KEY_SIZE_TOO_SMALL);
329     return 0;
330   }
331 
332   size_t emlen = to_len - 1;
333   if (from_len > emlen - 2 * mdlen - 1) {
334     OPENSSL_PUT_ERROR(RSA, RSA_R_DATA_TOO_LARGE_FOR_KEY_SIZE);
335     return 0;
336   }
337 
338   if (emlen < 2 * mdlen + 1) {
339     OPENSSL_PUT_ERROR(RSA, RSA_R_KEY_SIZE_TOO_SMALL);
340     return 0;
341   }
342 
343   to[0] = 0;
344   uint8_t *seed = to + 1;
345   uint8_t *db = to + mdlen + 1;
346 
347   if (!EVP_Digest(param, param_len, db, NULL, md, NULL)) {
348     return 0;
349   }
350   OPENSSL_memset(db + mdlen, 0, emlen - from_len - 2 * mdlen - 1);
351   db[emlen - from_len - mdlen - 1] = 0x01;
352   OPENSSL_memcpy(db + emlen - from_len - mdlen, from, from_len);
353   if (!RAND_bytes(seed, mdlen)) {
354     return 0;
355   }
356 
357   uint8_t *dbmask = OPENSSL_malloc(emlen - mdlen);
358   if (dbmask == NULL) {
359     OPENSSL_PUT_ERROR(RSA, ERR_R_MALLOC_FAILURE);
360     return 0;
361   }
362 
363   int ret = 0;
364   if (!PKCS1_MGF1(dbmask, emlen - mdlen, seed, mdlen, mgf1md)) {
365     goto out;
366   }
367   for (size_t i = 0; i < emlen - mdlen; i++) {
368     db[i] ^= dbmask[i];
369   }
370 
371   uint8_t seedmask[EVP_MAX_MD_SIZE];
372   if (!PKCS1_MGF1(seedmask, mdlen, db, emlen - mdlen, mgf1md)) {
373     goto out;
374   }
375   for (size_t i = 0; i < mdlen; i++) {
376     seed[i] ^= seedmask[i];
377   }
378   ret = 1;
379 
380 out:
381   OPENSSL_free(dbmask);
382   return ret;
383 }
384 
RSA_padding_check_PKCS1_OAEP_mgf1(uint8_t * to,unsigned to_len,const uint8_t * from,unsigned from_len,const uint8_t * param,unsigned param_len,const EVP_MD * md,const EVP_MD * mgf1md)385 int RSA_padding_check_PKCS1_OAEP_mgf1(uint8_t *to, unsigned to_len,
386                                       const uint8_t *from, unsigned from_len,
387                                       const uint8_t *param, unsigned param_len,
388                                       const EVP_MD *md, const EVP_MD *mgf1md) {
389   unsigned i, dblen, mlen = -1, mdlen, bad, looking_for_one_byte, one_index = 0;
390   const uint8_t *maskeddb, *maskedseed;
391   uint8_t *db = NULL, seed[EVP_MAX_MD_SIZE], phash[EVP_MAX_MD_SIZE];
392 
393   if (md == NULL) {
394     md = EVP_sha1();
395   }
396   if (mgf1md == NULL) {
397     mgf1md = md;
398   }
399 
400   mdlen = EVP_MD_size(md);
401 
402   /* The encoded message is one byte smaller than the modulus to ensure that it
403    * doesn't end up greater than the modulus. Thus there's an extra "+1" here
404    * compared to https://tools.ietf.org/html/rfc2437#section-9.1.1.2. */
405   if (from_len < 1 + 2*mdlen + 1) {
406     /* 'from_len' is the length of the modulus, i.e. does not depend on the
407      * particular ciphertext. */
408     goto decoding_err;
409   }
410 
411   dblen = from_len - mdlen - 1;
412   db = OPENSSL_malloc(dblen);
413   if (db == NULL) {
414     OPENSSL_PUT_ERROR(RSA, ERR_R_MALLOC_FAILURE);
415     goto err;
416   }
417 
418   maskedseed = from + 1;
419   maskeddb = from + 1 + mdlen;
420 
421   if (!PKCS1_MGF1(seed, mdlen, maskeddb, dblen, mgf1md)) {
422     goto err;
423   }
424   for (i = 0; i < mdlen; i++) {
425     seed[i] ^= maskedseed[i];
426   }
427 
428   if (!PKCS1_MGF1(db, dblen, seed, mdlen, mgf1md)) {
429     goto err;
430   }
431   for (i = 0; i < dblen; i++) {
432     db[i] ^= maskeddb[i];
433   }
434 
435   if (!EVP_Digest(param, param_len, phash, NULL, md, NULL)) {
436     goto err;
437   }
438 
439   bad = ~constant_time_is_zero(CRYPTO_memcmp(db, phash, mdlen));
440   bad |= ~constant_time_is_zero(from[0]);
441 
442   looking_for_one_byte = ~0u;
443   for (i = mdlen; i < dblen; i++) {
444     unsigned equals1 = constant_time_eq(db[i], 1);
445     unsigned equals0 = constant_time_eq(db[i], 0);
446     one_index = constant_time_select(looking_for_one_byte & equals1, i,
447                                      one_index);
448     looking_for_one_byte =
449         constant_time_select(equals1, 0, looking_for_one_byte);
450     bad |= looking_for_one_byte & ~equals0;
451   }
452 
453   bad |= looking_for_one_byte;
454 
455   if (bad) {
456     goto decoding_err;
457   }
458 
459   one_index++;
460   mlen = dblen - one_index;
461   if (to_len < mlen) {
462     OPENSSL_PUT_ERROR(RSA, RSA_R_DATA_TOO_LARGE);
463     mlen = -1;
464   } else {
465     OPENSSL_memcpy(to, db + one_index, mlen);
466   }
467 
468   OPENSSL_free(db);
469   return mlen;
470 
471 decoding_err:
472   /* to avoid chosen ciphertext attacks, the error message should not reveal
473    * which kind of decoding error happened */
474   OPENSSL_PUT_ERROR(RSA, RSA_R_OAEP_DECODING_ERROR);
475  err:
476   OPENSSL_free(db);
477   return -1;
478 }
479 
/* kPSSZeroes is the run of eight zero bytes hashed ahead of the message hash
 * and salt in the PSS routines in this file. */
static const uint8_t kPSSZeroes[8] = {0};
481 
/* RSA_verify_PKCS1_PSS_mgf1 checks that |EM| is a valid PSS encoding of the
 * message digest |mHash| for the key |rsa|. |Hash| is the digest used for
 * |mHash| and for the final comparison; |mgf1Hash| is the digest used by the
 * mask generation function and defaults to |Hash| when NULL.
 *
 * |sLen| is the expected salt length, with these special values:
 *   -1  the salt length equals the digest length
 *   -2  the salt length is recovered from the encoding
 *   -N  reserved (rejected)
 *
 * It returns one if the encoding verifies and zero otherwise. Apart from
 * digest failures, every failure path pushes an error. */
int RSA_verify_PKCS1_PSS_mgf1(RSA *rsa, const uint8_t *mHash,
                              const EVP_MD *Hash, const EVP_MD *mgf1Hash,
                              const uint8_t *EM, int sLen) {
  int i;
  int ret = 0;
  int maskedDBLen, MSBits, emLen;
  size_t hLen;
  const uint8_t *H;
  uint8_t *DB = NULL;
  EVP_MD_CTX ctx;
  uint8_t H_[EVP_MAX_MD_SIZE];
  EVP_MD_CTX_init(&ctx);

  if (mgf1Hash == NULL) {
    mgf1Hash = Hash;
  }

  hLen = EVP_MD_size(Hash);

  /* Resolve the special salt lengths. -2 is left as is and handled below by
   * skipping the salt-length comparison. */
  if (sLen == -1) {
    sLen = hLen;
  } else if (sLen < -2) {
    OPENSSL_PUT_ERROR(RSA, RSA_R_SLEN_CHECK_FAILED);
    goto err;
  }

  /* MSBits is the number of bits of the first octet that the encoding may
   * use; the rest must be zero. */
  MSBits = (BN_num_bits(rsa->n) - 1) & 0x7;
  emLen = RSA_size(rsa);
  if (EM[0] & (0xFF << MSBits)) {
    OPENSSL_PUT_ERROR(RSA, RSA_R_FIRST_OCTET_INVALID);
    goto err;
  }
  if (MSBits == 0) {
    /* The first octet carries no encoding bits at all; skip it. */
    EM++;
    emLen--;
  }

  /* The encoding must hold the digest, the 01 separator, the trailing 0xbc
   * and, when a salt length is given, the salt. The salt bound is written as
   * a subtraction, which is non-negative once the first test has passed, so
   * that a very large |sLen| cannot overflow a signed addition. */
  if (emLen < (int)hLen + 2 ||
      (sLen >= 0 && sLen > emLen - (int)hLen - 2)) {
    OPENSSL_PUT_ERROR(RSA, RSA_R_DATA_TOO_LARGE);
    goto err;
  }
  if (EM[emLen - 1] != 0xbc) {
    OPENSSL_PUT_ERROR(RSA, RSA_R_LAST_OCTET_INVALID);
    goto err;
  }

  /* EM = maskedDB || H || 0xbc. Unmask DB with MGF1(H). */
  maskedDBLen = emLen - hLen - 1;
  H = EM + maskedDBLen;
  DB = OPENSSL_malloc(maskedDBLen);
  if (!DB) {
    OPENSSL_PUT_ERROR(RSA, ERR_R_MALLOC_FAILURE);
    goto err;
  }
  if (!PKCS1_MGF1(DB, maskedDBLen, H, hLen, mgf1Hash)) {
    goto err;
  }
  for (i = 0; i < maskedDBLen; i++) {
    DB[i] ^= EM[i];
  }
  if (MSBits) {
    DB[0] &= 0xFF >> (8 - MSBits);
  }

  /* DB = 00 ... 00 || 01 || salt. Skip the zero run, stopping at the last
   * byte at the latest, then require the 01 separator. */
  for (i = 0; i < (maskedDBLen - 1) && DB[i] == 0; i++) {
    ;
  }
  if (DB[i++] != 0x1) {
    OPENSSL_PUT_ERROR(RSA, RSA_R_SLEN_RECOVERY_FAILED);
    goto err;
  }
  if (sLen >= 0 && (maskedDBLen - i) != sLen) {
    OPENSSL_PUT_ERROR(RSA, RSA_R_SLEN_CHECK_FAILED);
    goto err;
  }

  /* Recompute H' = Hash(eight zero bytes || mHash || salt) and compare. */
  if (!EVP_DigestInit_ex(&ctx, Hash, NULL) ||
      !EVP_DigestUpdate(&ctx, kPSSZeroes, sizeof(kPSSZeroes)) ||
      !EVP_DigestUpdate(&ctx, mHash, hLen) ||
      !EVP_DigestUpdate(&ctx, DB + i, maskedDBLen - i) ||
      !EVP_DigestFinal_ex(&ctx, H_, NULL)) {
    goto err;
  }
  if (OPENSSL_memcmp(H_, H, hLen)) {
    OPENSSL_PUT_ERROR(RSA, RSA_R_BAD_SIGNATURE);
    ret = 0;
  } else {
    ret = 1;
  }

err:
  OPENSSL_free(DB);
  EVP_MD_CTX_cleanup(&ctx);

  return ret;
}
580 
/* RSA_padding_add_PKCS1_PSS_mgf1 writes the PSS encoding of the message
 * digest |mHash| for the key |rsa| to |EM|. |Hash| is the digest used for
 * |mHash| and for the inner hash; |mgf1Hash| is the digest used by the mask
 * generation function and defaults to |Hash| when NULL.
 *
 * |sLenRequested| selects the salt length, with these special values:
 *   -1  the salt length equals the digest length
 *   -2  the salt length is maximized
 *   -N  reserved (rejected)
 *
 * It returns one on success and zero on failure. */
int RSA_padding_add_PKCS1_PSS_mgf1(RSA *rsa, unsigned char *EM,
                                   const unsigned char *mHash,
                                   const EVP_MD *Hash, const EVP_MD *mgf1Hash,
                                   int sLenRequested) {
  int ok = 0;
  int hashed;
  size_t em_len, db_len, salt_len, top_bits, hash_len;
  unsigned char *salt = NULL;
  unsigned char *h_out, *tail;
  EVP_MD_CTX ctx;

  if (mgf1Hash == NULL) {
    mgf1Hash = Hash;
  }

  hash_len = EVP_MD_size(Hash);

  if (BN_is_zero(rsa->n)) {
    OPENSSL_PUT_ERROR(RSA, RSA_R_EMPTY_PUBLIC_KEY);
    goto err;
  }

  /* |top_bits| is the number of bits of the first octet that the encoding may
   * use. When it is zero, that octet is written as 00 and skipped. */
  top_bits = (BN_num_bits(rsa->n) - 1) & 0x7;
  em_len = RSA_size(rsa);
  if (top_bits == 0) {
    assert(em_len >= 1);
    EM[0] = 0;
    EM++;
    em_len--;
  }

  if (em_len < hash_len + 2) {
    OPENSSL_PUT_ERROR(RSA, RSA_R_DATA_TOO_LARGE_FOR_KEY_SIZE);
    goto err;
  }

  /* Resolve the requested salt length, including the special values. */
  switch (sLenRequested) {
    case -1:
      salt_len = hash_len;
      break;
    case -2:
      salt_len = em_len - hash_len - 2;
      break;
    default:
      if (sLenRequested < 0) {
        OPENSSL_PUT_ERROR(RSA, RSA_R_SLEN_CHECK_FAILED);
        goto err;
      }
      salt_len = (size_t)sLenRequested;
      break;
  }

  if (em_len - hash_len - 2 < salt_len) {
    OPENSSL_PUT_ERROR(RSA, RSA_R_DATA_TOO_LARGE_FOR_KEY_SIZE);
    goto err;
  }

  if (salt_len > 0) {
    salt = OPENSSL_malloc(salt_len);
    if (salt == NULL) {
      OPENSSL_PUT_ERROR(RSA, ERR_R_MALLOC_FAILURE);
      goto err;
    }
    if (!RAND_bytes(salt, salt_len)) {
      goto err;
    }
  }

  /* EM = maskedDB || H || 0xbc. H is hashed directly into its final place. */
  db_len = em_len - hash_len - 1;
  h_out = EM + db_len;

  EVP_MD_CTX_init(&ctx);
  hashed = EVP_DigestInit_ex(&ctx, Hash, NULL) &&
           EVP_DigestUpdate(&ctx, kPSSZeroes, sizeof(kPSSZeroes)) &&
           EVP_DigestUpdate(&ctx, mHash, hash_len) &&
           EVP_DigestUpdate(&ctx, salt, salt_len) &&
           EVP_DigestFinal_ex(&ctx, h_out, NULL);
  EVP_MD_CTX_cleanup(&ctx);
  if (!hashed) {
    goto err;
  }

  /* Write dbMask over the DB region, then XOR DB into it. */
  if (!PKCS1_MGF1(EM, db_len, h_out, hash_len, mgf1Hash)) {
    goto err;
  }

  /* DB = 00 ... 00 || 01 || salt. XOR with the zero run changes nothing, so
   * start at the 01 byte. The length checks above keep this offset within
   * the DB region. */
  tail = EM + (em_len - salt_len - hash_len - 2);
  tail[0] ^= 0x1;
  for (size_t k = 0; k < salt_len; k++) {
    tail[1 + k] ^= salt[k];
  }

  if (top_bits != 0) {
    /* Clear the bits of the first octet above |top_bits|. */
    EM[0] &= 0xFF >> (8 - top_bits);
  }

  /* H is already in place; only the trailer byte remains. */
  EM[em_len - 1] = 0xbc;

  ok = 1;

err:
  OPENSSL_free(salt);

  return ok;
}
692