1 /*------------------------------------------------------------------------
2 / OCB Version 3 Reference Code (Optimized C) Last modified 12-JUN-2013
3 /-------------------------------------------------------------------------
4 / Copyright (c) 2013 Ted Krovetz.
5 /
6 / Permission to use, copy, modify, and/or distribute this software for any
7 / purpose with or without fee is hereby granted, provided that the above
8 / copyright notice and this permission notice appear in all copies.
9 /
10 / THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 / WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 / MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13 / ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 / WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15 / ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16 / OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 /
18 / Phillip Rogaway holds patents relevant to OCB. See the following for
19 / his patent grant: http://www.cs.ucdavis.edu/~rogaway/ocb/grant.htm
20 /
21 / Special thanks to Keegan McAllister for suggesting several good improvements
22 /
23 / Comments are welcome: Ted Krovetz <ted@krovetz.net> - Dedicated to Laurel K
24 /------------------------------------------------------------------------- */
25
26 /* ----------------------------------------------------------------------- */
27 /* Usage notes */
28 /* ----------------------------------------------------------------------- */
29
30 /* - When AE_PENDING is passed as the 'final' parameter of any function,
31 / the length parameters must be a multiple of (BPI*16).
32 / - When available, SSE or AltiVec registers are used to manipulate data.
33 / So, when on machines with these facilities, all pointers passed to
34 / any function should be 16-byte aligned.
/ - Plaintext and ciphertext pointers may be equal (i.e., plaintext gets
/ encrypted in place), but no other pair of pointers may be equal.
37 / - This code assumes all x86 processors have SSE2 and SSSE3 instructions
38 / when compiling under MSVC. If untrue, alter the #define.
39 / - This code is tested for C99 and recent versions of GCC and MSVC. */
40
41 /* ----------------------------------------------------------------------- */
42 /* User configuration options */
43 /* ----------------------------------------------------------------------- */
44
45 /* Set the AES key length to use and length of authentication tag to produce.
46 / Setting either to 0 requires the value be set at runtime via ae_init().
47 / Some optimizations occur for each when set to a fixed value. */
48 #define OCB_KEY_LEN 16 /* 0, 16, 24 or 32. 0 means set in ae_init */
49 #define OCB_TAG_LEN 16 /* 0 to 16. 0 means set in ae_init */
50
51 /* This implementation has built-in support for multiple AES APIs. Set any
52 / one of the following to non-zero to specify which to use. */
53 #define USE_OPENSSL_AES 1 /* http://openssl.org */
54 #define USE_REFERENCE_AES 0 /* Internet search: rijndael-alg-fst.c */
55 #define USE_AES_NI 0 /* Uses compiler's intrinsics */
56
57 /* During encryption and decryption, various "L values" are required.
58 / The L values can be precomputed during initialization (requiring extra
59 / space in ae_ctx), generated as needed (slightly slowing encryption and
60 / decryption), or some combination of the two. L_TABLE_SZ specifies how many
61 / L values to precompute. L_TABLE_SZ must be at least 3. L_TABLE_SZ*16 bytes
/ are used for L values in ae_ctx. Plaintexts and ciphertexts shorter than
/ 2^L_TABLE_SZ blocks need no L values to be calculated dynamically. */
64 #define L_TABLE_SZ 16
65
66 /* Set L_TABLE_SZ_IS_ENOUGH non-zero iff you know that all plaintexts
67 / will be shorter than 2^(L_TABLE_SZ+4) bytes in length. This results
68 / in better performance. */
69 #define L_TABLE_SZ_IS_ENOUGH 1
70
71 /* ----------------------------------------------------------------------- */
72 /* Includes and compiler specific definitions */
73 /* ----------------------------------------------------------------------- */
74
75 #include "ae.h"
76 #include <stdlib.h>
77 #include <string.h>
78
79 /* Define standard sized integers */
80 #if defined(_MSC_VER) && (_MSC_VER < 1600)
81 typedef unsigned __int8 uint8_t;
82 typedef unsigned __int32 uint32_t;
83 typedef unsigned __int64 uint64_t;
84 typedef __int64 int64_t;
85 #else
86 #include <stdint.h>
87 #endif
88
89 /* Compiler-specific intrinsics and fixes: bswap64, ntz */
90 #if _MSC_VER
91 #define inline __inline /* MSVC doesn't recognize "inline" in C */
92 #define restrict __restrict /* MSVC doesn't recognize "restrict" in C */
93 #define __SSE2__ (_M_IX86 || _M_AMD64 || _M_X64) /* Assume SSE2 */
94 #define __SSSE3__ (_M_IX86 || _M_AMD64 || _M_X64) /* Assume SSSE3 */
95 #include <intrin.h>
96 #pragma intrinsic(_byteswap_uint64, _BitScanForward, memcpy)
97 #define bswap64(x) _byteswap_uint64(x)
/* Number of trailing zero bits in x (MSVC).  x is expected to be non-zero
 * (callers pass block counters); for x == 0 _BitScanForward leaves the index
 * undefined, so the index is pre-set to 0.
 * _BitScanForward takes an `unsigned long *` index and an `unsigned long`
 * mask, so the result goes into its own `unsigned long` rather than being
 * written back into the argument through an incompatible `unsigned *`. */
static inline unsigned ntz(unsigned x) {
    unsigned long index = 0;
    _BitScanForward(&index, (unsigned long)x);
    return (unsigned)index;
}
102 #elif __GNUC__
103 #define inline __inline__ /* No "inline" in GCC ansi C mode */
104 #define restrict __restrict__ /* No "restrict" in GCC ansi C mode */
105 #define bswap64(x) __builtin_bswap64(x) /* Assuming GCC 4.3+ */
106 #define ntz(x) __builtin_ctz((unsigned)(x)) /* Assuming GCC 3.4+ */
107 #else /* Assume some C99 features: stdint.h, inline, restrict */
/* Reverse the byte order of a 32-bit value. */
#define bswap32(x)                                                          \
    ((((x) & 0x000000ffu) << 24) | (((x) & 0x0000ff00u) << 8) |             \
     (((x) & 0x00ff0000u) >> 8) | (((x) & 0xff000000u) >> 24))

/* Reverse the byte order of a 64-bit value: byte-swap each 32-bit half and
 * exchange the two halves. */
static inline uint64_t bswap64(uint64_t x) {
    const uint32_t low_half = (uint32_t)x;
    const uint32_t high_half = (uint32_t)(x >> 32);
    return ((uint64_t)bswap32(low_half) << 32) | (uint64_t)bswap32(high_half);
}
122
123 #if (L_TABLE_SZ <= 9) && (L_TABLE_SZ_IS_ENOUGH) /* < 2^13 byte texts */
/* Number of trailing zero bits of x (1 <= x < 512), by table lookup.
 * tz_table is indexed by x / 4, so on its own it is only right for multiples
 * of 4 (entry i holds ntz(4*i)).  The final-block code also asks for
 * ntz(k + 2) with k in {0, 4}, i.e. ntz(2) and ntz(6), so values with either
 * of the two low bits set are answered directly instead of from the table.
 * The upper bound 512 follows from this variant only being selected when
 * L_TABLE_SZ <= 9 and L_TABLE_SZ_IS_ENOUGH (texts shorter than 2^13 bytes). */
static inline unsigned ntz(unsigned x) {
    static const unsigned char tz_table[] = {
        0, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2, 6, 2, 3, 2, 4, 2, 3, 2, 5, 2,
        3, 2, 4, 2, 3, 2, 7, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2, 6, 2, 3, 2,
        4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2, 8, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2,
        3, 2, 6, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2, 7, 2, 3, 2, 4, 2, 3, 2,
        5, 2, 3, 2, 4, 2, 3, 2, 6, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2};
    if (x & 3u)
        return (x & 1u) ? 0u : 1u;
    return tz_table[x / 4];
}
133 #else /* From http://supertech.csail.mit.edu/papers/debruijn.pdf */
/* Number of trailing zero bits of a non-zero x via a de Bruijn sequence
 * (http://supertech.csail.mit.edu/papers/debruijn.pdf): isolating the lowest
 * set bit and multiplying by 0x077CB531 leaves a distinct 5-bit value in the
 * top bits, which indexes the bit-position table. */
static inline unsigned ntz(unsigned x) {
    static const unsigned char bit_position[32] = {0,  1,  28, 2,  29, 14, 24, 3,  30, 22, 20,
                                                   15, 25, 17, 4,  8,  31, 27, 13, 23, 21, 19,
                                                   16, 7,  26, 12, 18, 6,  11, 5,  10, 9};
    const unsigned lowest_bit = x & (0u - x);
    const uint32_t hashed = (uint32_t)(lowest_bit * 0x077CB531u);
    return bit_position[hashed >> 27];
}
140 #endif
141 #endif
142
143 /* ----------------------------------------------------------------------- */
144 /* Define blocks and operations -- Patch if incorrect on your compiler. */
145 /* ----------------------------------------------------------------------- */
146
147 #if __SSE2__ && !KEYMASTER_CLANG_TEST_BUILD
148 #include <xmmintrin.h> /* SSE instructions and _mm_malloc */
149 #include <emmintrin.h> /* SSE2 instructions */
150 typedef __m128i block;
151 #define xor_block(x, y) _mm_xor_si128(x, y)
152 #define zero_block() _mm_setzero_si128()
153 #define unequal_blocks(x, y) (_mm_movemask_epi8(_mm_cmpeq_epi8(x, y)) != 0xffff)
154 #if __SSSE3__ || USE_AES_NI
155 #include <tmmintrin.h> /* SSSE3 instructions */
156 #define swap_if_le(b) \
157 _mm_shuffle_epi8(b, _mm_set_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
158 #else
/* Reverse all 16 bytes of b using SSE2 only (no PSHUFB), converting between
 * memory-correct and register-correct form.  Works from the smallest unit up:
 * bytes within 16-bit lanes, then 16-bit words within each 64-bit half, then
 * the two halves. */
static inline block swap_if_le(block b) {
    block bytes_swapped = _mm_or_si128(_mm_slli_epi16(b, 8), _mm_srli_epi16(b, 8));
    block words_reversed = _mm_shufflelo_epi16(bytes_swapped, _MM_SHUFFLE(0, 1, 2, 3));
    words_reversed = _mm_shufflehi_epi16(words_reversed, _MM_SHUFFLE(0, 1, 2, 3));
    return _mm_shuffle_epi32(words_reversed, _MM_SHUFFLE(1, 0, 3, 2));
}
165 #endif
/* Nonce-dependent offset (SSE2): return bits [bot, bot+128) of the 192-bit
 * stretch KtopStr[0..2] (each word register correct), converted to memory
 * correct byte order.  bot is the low 6 bits of the nonce, so 0..63.
 * KtopStr must be 16-byte aligned: it is read with _mm_load_si128. */
static inline block gen_offset(uint64_t KtopStr[3], unsigned bot) {
    block hi = _mm_load_si128((__m128i*)(KtopStr + 0)); /* hi = B A */
    block lo = _mm_loadu_si128((__m128i*)(KtopStr + 1)); /* lo = C B */
    __m128i lshift = _mm_cvtsi32_si128(bot);
    __m128i rshift = _mm_cvtsi32_si128(64 - bot);
    /* Each 64-bit lane: (word << bot) ^ (next word >> (64 - bot)).  SSE shift
       counts above 63 clear the lane, so bot == 0 needs no special case. */
    lo = _mm_xor_si128(_mm_sll_epi64(hi, lshift), _mm_srl_epi64(lo, rshift));
#if __SSSE3__ || USE_AES_NI
    /* Byte-reverse each 64-bit lane (lane order kept): register -> memory correct */
    return _mm_shuffle_epi8(lo, _mm_set_epi8(8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7));
#else
    /* Same effect without PSHUFB: swap the lanes, then reverse all 16 bytes */
    return swap_if_le(_mm_shuffle_epi32(lo, _MM_SHUFFLE(1, 0, 3, 2)));
#endif
}
/* Multiply a register-correct block by x in GF(2^128) ("doubling" in OCB).
 * The 128-bit value is shifted left by one bit; if its top bit was set, 135
 * (0x87) is xored into the low byte.  This is done on 32-bit lanes: each
 * lane's sign bit becomes a carry of 1 into the next-higher lane, and lane 3's
 * sign bit wraps around into lane 0 as 135. */
static inline block double_block(block bl) {
    const __m128i mask = _mm_set_epi32(135, 1, 1, 1);
    __m128i tmp = _mm_srai_epi32(bl, 31);                /* all ones where a lane's top bit is set */
    tmp = _mm_and_si128(tmp, mask);                      /* carry value each lane passes on */
    tmp = _mm_shuffle_epi32(tmp, _MM_SHUFFLE(2, 1, 0, 3)); /* move carries one lane up (3 -> 0) */
    bl = _mm_slli_epi32(bl, 1);
    return _mm_xor_si128(bl, tmp);
}
186 #elif __ALTIVEC__
187 #include <altivec.h>
188 typedef vector unsigned block;
189 #define xor_block(x, y) vec_xor(x, y)
190 #define zero_block() vec_splat_u32(0)
191 #define unequal_blocks(x, y) vec_any_ne(x, y)
192 #define swap_if_le(b) (b)
193 #if __PPC64__
gen_offset(uint64_t KtopStr[3],unsigned bot)194 block gen_offset(uint64_t KtopStr[3], unsigned bot) {
195 union {
196 uint64_t u64[2];
197 block bl;
198 } rval;
199 rval.u64[0] = (KtopStr[0] << bot) | (KtopStr[1] >> (64 - bot));
200 rval.u64[1] = (KtopStr[1] << bot) | (KtopStr[2] >> (64 - bot));
201 return rval.bl;
202 }
203 #else
204 /* Special handling: Shifts are mod 32, and no 64-bit types */
/* Nonce-dependent offset (32-bit AltiVec): return bits [bot, bot+128) of the
 * 192-bit stretch KtopStr[0..2], with bot in 0..63.  Without 64-bit lanes the
 * 128-bit shift is done on 32-bit words: first by 0 or 1 whole word with
 * vec_sld, then by the remaining 0..31 bits per word.
 * NOTE(review): the load of KtopStr + 2 reads a full 16-byte vector although
 * only KtopStr[2] belongs to the array; the vec_sld calls below only use the
 * words that come from KtopStr[2]. */
block gen_offset(uint64_t KtopStr[3], unsigned bot) {
    const vector unsigned k32 = {32, 32, 32, 32};
    vector unsigned hi = *(vector unsigned*)(KtopStr + 0);
    vector unsigned lo = *(vector unsigned*)(KtopStr + 2);
    vector unsigned bot_vec;
    if (bot < 32) {
        /* lo = for each word of hi, the word that follows it */
        lo = vec_sld(hi, lo, 4);
    } else {
        /* Move the window one word (32 bits), then shift by bot - 32 bits */
        vector unsigned t = vec_sld(hi, lo, 4);
        lo = vec_sld(hi, lo, 8);
        hi = t;
        bot = bot - 32;
    }
    /* Vector shift counts are mod 32, so lo >> (32 - 0) would not clear lo:
       a zero bit-shift must return early. */
    if (bot == 0)
        return hi;
    /* Put bot in lane 0 and splat it as the per-word shift count.
       NOTE(review): writes the vector through an unsigned* (type punning). */
    *(unsigned*)&bot_vec = bot;
    vector unsigned lshift = vec_splat(bot_vec, 0);
    vector unsigned rshift = vec_sub(k32, lshift);
    hi = vec_sl(hi, lshift);
    lo = vec_sr(lo, rshift);
    /* word i = (hi_i << bot) ^ (next_i >> (32 - bot)) */
    return vec_xor(hi, lo);
}
227 #endif
/* Multiply a block by x in GF(2^128) (OCB doubling), byte-wise.  swap_if_le is
 * a no-op on this AltiVec path, so byte 0 is the most significant byte: each
 * byte is shifted left by one and receives the top bit of the following byte
 * as a carry of 1, and the top bit of byte 0 wraps around into byte 15 as
 * 135 (0x87). */
static inline block double_block(block b) {
    const vector unsigned char mask = {135, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
    const vector unsigned char perm = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0};
    const vector unsigned char shift7 = vec_splat_u8(7);
    const vector unsigned char shift1 = vec_splat_u8(1);
    vector unsigned char c = (vector unsigned char)b;
    vector unsigned char t = vec_sra(c, shift7); /* 0xff where a byte's top bit is set */
    t = vec_and(t, mask);                        /* carry each byte hands on */
    t = vec_perm(t, t, perm);                    /* byte i gets byte i+1's carry; byte 15 gets byte 0's */
    c = vec_sl(c, shift1);
    return (block)vec_xor(c, t);
}
240 #elif __ARM_NEON__
241 #include <arm_neon.h>
242 typedef int8x16_t block; /* Yay! Endian-neutral reads! */
243 #define xor_block(x, y) veorq_s8(x, y)
244 #define zero_block() vdupq_n_s8(0)
/* 1 if the two blocks differ in any bit, else 0.  Both 64-bit lanes of the
 * xor are always folded together; there is no early exit. */
static inline int unequal_blocks(block a, block b) {
    const int64x2_t diff = veorq_s64((int64x2_t)a, (int64x2_t)b);
    const int64_t any_bits = vgetq_lane_s64(diff, 0) | vgetq_lane_s64(diff, 1);
    return any_bits != 0;
}
249 #define swap_if_le(b) (b) /* Using endian-neutral int8x16_t */
/* Nonce-dependent offset (NEON): return bits [bot, bot+128) of the 192-bit
 * stretch KtopStr[0..2], with bot in 0..63.  KtopStr is register correct per
 * 64-bit word; the returned block is memory correct. */
block gen_offset(uint64_t KtopStr[3], unsigned bot) {
    const union {
        unsigned x;
        unsigned char endian;
    } little = {1};
    const int64x2_t k64 = {-64, -64};
    uint64x2_t hi = *(uint64x2_t*)(KtopStr + 0); /* hi = A B */
    uint64x2_t lo = *(uint64x2_t*)(KtopStr + 1); /* lo = B C */
    int64x2_t ls = vdupq_n_s64(bot);
    /* VSHL shifts right for negative counts: rs = bot - 64 means >> (64 - bot) */
    int64x2_t rs = vqaddq_s64(k64, ls);
    block rval = (block)veorq_u64(vshlq_u64(hi, ls), vshlq_u64(lo, rs));
    /* On little-endian hosts, byte-reverse each 64-bit lane to make it memory correct */
    if (little.endian)
        rval = vrev64q_s8(rval);
    return rval;
}
/* Multiply a block by x in GF(2^128) (OCB doubling), byte-wise.  swap_if_le is
 * a no-op here (endian-neutral int8x16_t), so byte 0 is the most significant
 * byte: each byte is shifted left by one and receives the top bit of the next
 * byte; byte 0's top bit wraps into byte 15 as 135 (0x87).
 * NOTE(review): 135 does not fit int8_t; the mask relies on the
 * implementation-defined conversion to 0x87. */
static inline block double_block(block b) {
    const block mask = {135, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
    block tmp = vshrq_n_s8(b, 7);   /* 0xff where a byte's top bit is set */
    tmp = vandq_s8(tmp, mask);      /* carry each byte hands on */
    tmp = vextq_s8(tmp, tmp, 1); /* Rotate high byte to end */
    b = vshlq_n_s8(b, 1);
    return veorq_s8(tmp, b);
}
274 #else
/* Portable 128-bit block: two 64-bit halves.  When register correct, l is the
 * most significant half and r the least significant. */
typedef struct { uint64_t l, r; } block;

/* Bitwise xor of two blocks. */
static inline block xor_block(block x, block y) {
    block sum;
    sum.l = x.l ^ y.l;
    sum.r = x.r ^ y.r;
    return sum;
}

/* The all-zero block. */
static inline block zero_block(void) {
    block z;
    z.l = 0;
    z.r = 0;
    return z;
}

/* 1 if x and y differ in any bit, else 0.  Both halves are folded with
 * xor/or and no short-circuit operator, so there is no data-dependent early
 * exit. */
#define unequal_blocks(x, y) ((((x).l ^ (y).l) | ((x).r ^ (y).r)) != 0)
/* Convert a block between memory-correct and register-correct form.  On a
 * little-endian host each 64-bit half is byte-swapped; on a big-endian host b
 * is returned unchanged.  The conversion is its own inverse. */
static inline block swap_if_le(block b) {
    const union {
        unsigned x;
        unsigned char endian;
    } probe = {1};
    block swapped = b;
    if (probe.endian) {
        swapped.l = bswap64(b.l);
        swapped.r = bswap64(b.r);
    }
    return swapped;
}
299
300 /* KtopStr is reg correct by 64 bits, return mem correct */
gen_offset(uint64_t KtopStr[3],unsigned bot)301 block gen_offset(uint64_t KtopStr[3], unsigned bot) {
302 block rval;
303 if (bot != 0) {
304 rval.l = (KtopStr[0] << bot) | (KtopStr[1] >> (64 - bot));
305 rval.r = (KtopStr[1] << bot) | (KtopStr[2] >> (64 - bot));
306 } else {
307 rval.l = KtopStr[0];
308 rval.r = KtopStr[1];
309 }
310 return swap_if_le(rval);
311 }
312
313 #if __GNUC__ && __arm__
/* Multiply a register-correct block by x in GF(2^128) using 32-bit ARM
 * add-with-carry.  The four 32-bit words are doubled from least significant
 * (low word of b.r) to most significant (high word of b.l), chaining the
 * carry.  If a carry leaves the top word, 135 (0x87) is xored into the low
 * word of b.r.  %H<n> names the high register of a 64-bit operand pair. */
static inline block double_block(block b) {
    __asm__("adds %1,%1,%1\n\t"   /* b.r low word, sets carry */
            "adcs %H1,%H1,%H1\n\t" /* b.r high word */
            "adcs %0,%0,%0\n\t"   /* b.l low word */
            "adcs %H0,%H0,%H0\n\t" /* b.l high word; carry = old top bit */
            "it cs\n\t"           /* IT block (needed for Thumb-2) */
            "eorcs %1,%1,#135"    /* reduce if the top bit was set */
            : "+r"(b.l), "+r"(b.r)
            :
            : "cc");
    return b;
}
326 #else
/* Multiply a register-correct block by x in GF(2^128) (OCB doubling): shift
 * the 128-bit value (l = high half, r = low half) left by one bit and, if the
 * bit shifted out of the top was set, xor 135 (0x87) into the low byte.
 * The reduction mask is built with unsigned arithmetic.  The previous form,
 * (int64_t)b.l >> 63, relied on two implementation-defined operations:
 * converting an out-of-range uint64_t to int64_t, and right-shifting a
 * negative value. */
static inline block double_block(block b) {
    const uint64_t reduce = (uint64_t)0 - (b.l >> 63); /* all ones iff top bit set */
    b.l = (b.l << 1) ^ (b.r >> 63);
    b.r = (b.r << 1) ^ (reduce & 135);
    return b;
}
333 #endif
334
335 #endif
336
337 /* ----------------------------------------------------------------------- */
338 /* AES - Code uses OpenSSL API. Other implementations get mapped to it. */
339 /* ----------------------------------------------------------------------- */
340
341 /*---------------*/
342 #if USE_OPENSSL_AES
343 /*---------------*/
344
345 #include <openssl/aes.h> /* http://openssl.org/ */
346
347 /* How to ECB encrypt an array of blocks, in place */
/* ECB-encrypt the nblks blocks at blks in place, one OpenSSL AES_encrypt call
 * per block, working from the last block back to the first. */
static inline void AES_ecb_encrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    block* cursor = blks + nblks;
    while (cursor != blks) {
        --cursor;
        AES_encrypt((unsigned char*)cursor, (unsigned char*)cursor, key);
    }
}
354
/* ECB-decrypt the nblks blocks at blks in place, one OpenSSL AES_decrypt call
 * per block, working from the last block back to the first. */
static inline void AES_ecb_decrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    unsigned idx;
    for (idx = nblks; idx > 0; --idx) {
        unsigned char* const p = (unsigned char*)&blks[idx - 1];
        AES_decrypt(p, p, key);
    }
}
361
362 #define BPI 4 /* Number of blocks in buffer per ECB call */
363
364 /*-------------------*/
365 #elif USE_REFERENCE_AES
366 /*-------------------*/
367
368 #include "rijndael-alg-fst.h" /* Barreto's Public-Domain Code */
#if (OCB_KEY_LEN == 0)
/* Key size chosen at run time: room for an AES-256 schedule (60 words) plus
 * the round count. */
typedef struct {
    uint32_t rd_key[60];
    int rounds;
} AES_KEY;
#define ROUNDS(ctx) ((ctx)->rounds)
/* Statement-like (do/while) macros: they yield no value.  Every argument is
 * parenthesized so expression arguments such as `key_len * 8` bind
 * correctly; `y` (key length in bits) is evaluated twice. */
#define AES_set_encrypt_key(x, y, z)                \
    do {                                            \
        rijndaelKeySetupEnc((z)->rd_key, (x), (y)); \
        (z)->rounds = (y) / 32 + 6;                 \
    } while (0)
#define AES_set_decrypt_key(x, y, z)                \
    do {                                            \
        rijndaelKeySetupDec((z)->rd_key, (x), (y)); \
        (z)->rounds = (y) / 32 + 6;                 \
    } while (0)
#else
/* Fixed key size: 4 * (rounds + 1) = OCB_KEY_LEN + 28 schedule words. */
typedef struct { uint32_t rd_key[OCB_KEY_LEN + 28]; } AES_KEY;
#define ROUNDS(ctx) (6 + OCB_KEY_LEN / 4)
#define AES_set_encrypt_key(x, y, z) rijndaelKeySetupEnc((z)->rd_key, (x), (y))
#define AES_set_decrypt_key(x, y, z) rijndaelKeySetupDec((z)->rd_key, (x), (y))
#endif
/* Map the OpenSSL single-block API onto Barreto's reference functions. */
#define AES_encrypt(x, y, z) rijndaelEncrypt((z)->rd_key, ROUNDS(z), (x), (y))
#define AES_decrypt(x, y, z) rijndaelDecrypt((z)->rd_key, ROUNDS(z), (x), (y))
393
/* ECB-encrypt the nblks blocks at blks in place with the reference AES code,
 * one AES_encrypt call per block, last block first. */
static void AES_ecb_encrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    unsigned remaining = nblks;
    while (remaining-- > 0) {
        unsigned char* p = (unsigned char*)&blks[remaining];
        AES_encrypt(p, p, key);
    }
}
400
/* ECB-decrypt the nblks blocks at blks in place with the reference AES code,
 * one AES_decrypt call per block, last block first.
 * Declared static like its encrypt counterpart and the other AES back-ends:
 * it is an internal helper and should not be exported as a global symbol. */
static void AES_ecb_decrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    while (nblks) {
        --nblks;
        AES_decrypt((unsigned char*)(blks + nblks), (unsigned char*)(blks + nblks), key);
    }
}
407
408 #define BPI 4 /* Number of blocks in buffer per ECB call */
409
410 /*----------*/
411 #elif USE_AES_NI
412 /*----------*/
413
414 #include <wmmintrin.h>
415
416 #if (OCB_KEY_LEN == 0)
417 typedef struct {
418 __m128i rd_key[15];
419 int rounds;
420 } AES_KEY;
421 #define ROUNDS(ctx) ((ctx)->rounds)
422 #else
423 typedef struct { __m128i rd_key[7 + OCB_KEY_LEN / 4]; } AES_KEY;
424 #define ROUNDS(ctx) (6 + OCB_KEY_LEN / 4)
425 #endif
426
/* One AES key-schedule step using AESKEYGENASSIST.
 *   v1: previous round key (4 words), updated in place to the next one.
 *   v2: scratch; receives AESKEYGENASSIST(v4, aes_const).
 *   v3: scratch whose lane 0 must be zero on entry (callers zero it once
 *       with _mm_setzero_si128(); the shuffles keep lane 0 at zero).
 *   v4: key whose words feed AESKEYGENASSIST.
 *   shuff_const: 255 broadcasts word 3 (RotWord(SubWord) ^ rcon); 170
 *       broadcasts word 2 (SubWord only, used for AES-256's odd steps).
 *   aes_const: round constant; must be a compile-time constant (immediate).
 * The two shuffle/xor pairs turn v1's words into their running xor
 * (w0, w0^w1, w0^w1^w2, w0^w1^w2^w3) before the broadcast word is xored in.
 * Expands to several statements without do/while: use only as a full
 * statement list, never as the body of an unbraced if/for. */
#define EXPAND_ASSIST(v1, v2, v3, v4, shuff_const, aes_const) \
    v2 = _mm_aeskeygenassist_si128(v4, aes_const); \
    v3 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(v3), _mm_castsi128_ps(v1), 16)); \
    v1 = _mm_xor_si128(v1, v3); \
    v3 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(v3), _mm_castsi128_ps(v1), 140)); \
    v1 = _mm_xor_si128(v1, v3); \
    v2 = _mm_shuffle_epi32(v2, shuff_const); \
    v1 = _mm_xor_si128(v1, v2)

/* Two AES-192 key-schedule steps producing three 128-bit round keys at
 * kp[idx], kp[idx+1], kp[idx+2].  Uses the caller's locals x0..x3, tmp and kp:
 * x0 holds the first four schedule words, x3 the next two (in its low half),
 * and tmp the pending half round key carried between invocations.  The second
 * step uses round constant aes_const * 2. */
#define EXPAND192_STEP(idx, aes_const) \
    EXPAND_ASSIST(x0, x1, x2, x3, 85, aes_const); \
    x3 = _mm_xor_si128(x3, _mm_slli_si128(x3, 4)); \
    x3 = _mm_xor_si128(x3, _mm_shuffle_epi32(x0, 255)); \
    kp[idx] = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(tmp), _mm_castsi128_ps(x0), 68)); \
    kp[idx + 1] = \
        _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(x0), _mm_castsi128_ps(x3), 78)); \
    EXPAND_ASSIST(x0, x1, x2, x3, 85, (aes_const * 2)); \
    x3 = _mm_xor_si128(x3, _mm_slli_si128(x3, 4)); \
    x3 = _mm_xor_si128(x3, _mm_shuffle_epi32(x0, 255)); \
    kp[idx + 2] = x0; \
    tmp = x3
448
/* AES-128 key schedule with AES-NI: write the 11 round keys kp[0..10] for the
 * 16-byte userkey (read with an unaligned load) into key.  Steps are unrolled
 * because EXPAND_ASSIST's round constant (1, 2, 4, ..., 128, 27, 54) must be
 * an immediate. */
static void AES_128_Key_Expansion(const unsigned char* userkey, void* key) {
    __m128i x0, x1, x2;
    __m128i* kp = (__m128i*)key;
    kp[0] = x0 = _mm_loadu_si128((__m128i*)userkey);
    x2 = _mm_setzero_si128(); /* EXPAND_ASSIST scratch: lane 0 must be zero */
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 1);
    kp[1] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 2);
    kp[2] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 4);
    kp[3] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 8);
    kp[4] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 16);
    kp[5] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 32);
    kp[6] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 64);
    kp[7] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 128);
    kp[8] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 27);
    kp[9] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 54);
    kp[10] = x0;
}
475
/* AES-192 key schedule with AES-NI: write the 13 round keys kp[0..12] for the
 * 24-byte userkey into key.  Each EXPAND192_STEP emits three round keys from
 * two schedule steps (round constants c and 2c).
 * NOTE(review): the second load reads 16 bytes at userkey + 16, i.e. 8 bytes
 * past a 24-byte key; only its low 8 bytes matter. */
static void AES_192_Key_Expansion(const unsigned char* userkey, void* key) {
    __m128i x0, x1, x2, x3, tmp, *kp = (__m128i*)key;
    kp[0] = x0 = _mm_loadu_si128((__m128i*)userkey);
    tmp = x3 = _mm_loadu_si128((__m128i*)(userkey + 16));
    x2 = _mm_setzero_si128(); /* EXPAND_ASSIST scratch: lane 0 must be zero */
    EXPAND192_STEP(1, 1);
    EXPAND192_STEP(4, 4);
    EXPAND192_STEP(7, 16);
    EXPAND192_STEP(10, 64);
}
486
/* AES-256 key schedule with AES-NI: write the 15 round keys kp[0..14] for the
 * 32-byte userkey into key.  Steps alternate between the RotWord/SubWord/rcon
 * step (shuffle 255, word 3 of AESKEYGENASSIST) and the SubWord-only step
 * (shuffle 170, word 2), as AES-256 requires; the round constant is ignored
 * in the latter. */
static void AES_256_Key_Expansion(const unsigned char* userkey, void* key) {
    __m128i x0, x1, x2, x3, *kp = (__m128i*)key;
    kp[0] = x0 = _mm_loadu_si128((__m128i*)userkey);
    kp[1] = x3 = _mm_loadu_si128((__m128i*)(userkey + 16));
    x2 = _mm_setzero_si128(); /* EXPAND_ASSIST scratch: lane 0 must be zero */
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 1);
    kp[2] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 1);
    kp[3] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 2);
    kp[4] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 2);
    kp[5] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 4);
    kp[6] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 4);
    kp[7] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 8);
    kp[8] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 8);
    kp[9] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 16);
    kp[10] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 16);
    kp[11] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 32);
    kp[12] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 32);
    kp[13] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 64);
    kp[14] = x0;
}
519
/* AES-NI stand-in for OpenSSL's AES_set_encrypt_key: expand userKey into *key.
 * bits must be 128, 192 or 256.  When OCB_KEY_LEN is fixed, AES_KEY only has
 * room for that key size's schedule, so bits must equal OCB_KEY_LEN * 8.
 * Returns 0 on success.  Returns -2 for an unsupported size (the value
 * OpenSSL uses for a bad key length) and leaves *key untouched; previously
 * success was reported with no schedule written, or a too-large schedule
 * overflowed AES_KEY. */
static int AES_set_encrypt_key(const unsigned char* userKey, const int bits, AES_KEY* key) {
#if (OCB_KEY_LEN > 0)
    if (bits != OCB_KEY_LEN * 8)
        return -2;
#endif
    switch (bits) {
    case 128:
        AES_128_Key_Expansion(userKey, key);
        break;
    case 192:
        AES_192_Key_Expansion(userKey, key);
        break;
    case 256:
        AES_256_Key_Expansion(userKey, key);
        break;
    default:
        return -2;
    }
#if (OCB_KEY_LEN == 0)
    key->rounds = 6 + bits / 32;
#endif
    return 0;
}
533
/* Build an AES-NI decryption schedule (for AESDEC) from an encryption
 * schedule.  Round keys are taken in reverse order; all but the first and
 * last are passed through AESIMC (InvMixColumns).  dkey and ekey must not
 * overlap: dkey->rd_key[rounds] is written before ekey->rd_key[rounds] is
 * read. */
static void AES_set_decrypt_key_fast(AES_KEY* dkey, const AES_KEY* ekey) {
    const int nr = ROUNDS(ekey);
    int r;
#if (OCB_KEY_LEN == 0)
    dkey->rounds = nr;
#endif
    dkey->rd_key[nr] = ekey->rd_key[0];
    for (r = 1; r < nr; ++r)
        dkey->rd_key[nr - r] = _mm_aesimc_si128(ekey->rd_key[r]);
    dkey->rd_key[0] = ekey->rd_key[nr];
}
545
/* AES-NI stand-in for OpenSSL's AES_set_decrypt_key.  It builds the
 * encryption schedule in a temporary, derives the decryption schedule into
 * *key, then wipes the temporary so no copy of the key schedule stays on the
 * stack.  Returns 0 on success, or the non-zero status from
 * AES_set_encrypt_key, in which case *key is left untouched. */
static int AES_set_decrypt_key(const unsigned char* userKey, const int bits, AES_KEY* key) {
    AES_KEY temp_key;
    volatile unsigned char* wipe = (volatile unsigned char*)&temp_key;
    size_t n;
    int rc = AES_set_encrypt_key(userKey, bits, &temp_key);
    if (rc == 0)
        AES_set_decrypt_key_fast(key, &temp_key);
    /* Volatile stores: a plain memset of a dying local may be optimized away */
    for (n = 0; n < sizeof(temp_key); ++n)
        wipe[n] = 0;
    return rc;
}
552
AES_encrypt(const unsigned char * in,unsigned char * out,const AES_KEY * key)553 static inline void AES_encrypt(const unsigned char* in, unsigned char* out, const AES_KEY* key) {
554 int j, rnds = ROUNDS(key);
555 const __m128i* sched = ((__m128i*)(key->rd_key));
556 __m128i tmp = _mm_load_si128((__m128i*)in);
557 tmp = _mm_xor_si128(tmp, sched[0]);
558 for (j = 1; j < rnds; j++)
559 tmp = _mm_aesenc_si128(tmp, sched[j]);
560 tmp = _mm_aesenclast_si128(tmp, sched[j]);
561 _mm_store_si128((__m128i*)out, tmp);
562 }
563
AES_decrypt(const unsigned char * in,unsigned char * out,const AES_KEY * key)564 static inline void AES_decrypt(const unsigned char* in, unsigned char* out, const AES_KEY* key) {
565 int j, rnds = ROUNDS(key);
566 const __m128i* sched = ((__m128i*)(key->rd_key));
567 __m128i tmp = _mm_load_si128((__m128i*)in);
568 tmp = _mm_xor_si128(tmp, sched[0]);
569 for (j = 1; j < rnds; j++)
570 tmp = _mm_aesdec_si128(tmp, sched[j]);
571 tmp = _mm_aesdeclast_si128(tmp, sched[j]);
572 _mm_store_si128((__m128i*)out, tmp);
573 }
574
/* ECB-encrypt nblks blocks in place with AES-NI.  The rounds are interleaved
 * across blocks (every block gets round r before any gets round r+1),
 * presumably so the independent AESENC operations can be pipelined. */
static inline void AES_ecb_encrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    const unsigned last_round = ROUNDS(key);
    const __m128i* round_keys = (const __m128i*)key->rd_key;
    unsigned round, n;

    for (n = 0; n < nblks; n++)
        blks[n] = _mm_xor_si128(blks[n], round_keys[0]);
    for (round = 1; round < last_round; round++) {
        for (n = 0; n < nblks; n++)
            blks[n] = _mm_aesenc_si128(blks[n], round_keys[round]);
    }
    for (n = 0; n < nblks; n++)
        blks[n] = _mm_aesenclast_si128(blks[n], round_keys[last_round]);
}
586
/* ECB-decrypt nblks blocks in place with AES-NI, using a decryption schedule.
 * The rounds are interleaved across blocks, as in AES_ecb_encrypt_blks. */
static inline void AES_ecb_decrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    const unsigned last_round = ROUNDS(key);
    const __m128i* round_keys = (const __m128i*)key->rd_key;
    unsigned round, n;

    for (n = 0; n < nblks; n++)
        blks[n] = _mm_xor_si128(blks[n], round_keys[0]);
    for (round = 1; round < last_round; round++) {
        for (n = 0; n < nblks; n++)
            blks[n] = _mm_aesdec_si128(blks[n], round_keys[round]);
    }
    for (n = 0; n < nblks; n++)
        blks[n] = _mm_aesdeclast_si128(blks[n], round_keys[last_round]);
}
598
599 #define BPI 8 /* Number of blocks in buffer per ECB call */
600 /* Set to 4 for Westmere, 8 for Sandy Bridge */
601
602 #endif
603
604 /* ----------------------------------------------------------------------- */
605 /* Define OCB context structure. */
606 /* ----------------------------------------------------------------------- */
607
608 /*------------------------------------------------------------------------
609 / Each item in the OCB context is stored either "memory correct" or
610 / "register correct". On big-endian machines, this is identical. On
611 / little-endian machines, one must choose whether the byte-string
612 / is in the correct order when it resides in memory or in registers.
613 / It must be register correct whenever it is to be manipulated
614 / arithmetically, but must be memory correct whenever it interacts
615 / with the plaintext or ciphertext.
616 /------------------------------------------------------------------------- */
617
/* OCB context: AES key schedules, key-derived L values, nonce cache and
 * running per-message state.  Field order matters: KtopStr must stay 16-byte
 * aligned because the SSE2 gen_offset reads it with _mm_load_si128. */
struct _ae_ctx {
    block offset;                 /* Memory correct; presumably the running message offset (set outside this view) */
    block checksum;               /* Memory correct; presumably the running plaintext checksum (set outside this view) */
    block Lstar;                  /* Memory correct; E_K(0^128), computed in ae_init */
    block Ldollar;                /* Memory correct; double(Lstar) */
    block L[L_TABLE_SZ];          /* Memory correct; L[0] = double(Ldollar), L[i] = double(L[i-1]) */
    block ad_checksum;            /* Memory correct; xor of the encrypted associated-data blocks */
    block ad_offset;              /* Memory correct; running offset for associated data */
    block cached_Top;             /* Memory correct; last nonce block (low 6 bits cleared) whose encryption is in KtopStr */
    uint64_t KtopStr[3];          /* Register correct, each item; Ktop in [0..1], stretch word in [2] */
    uint32_t ad_blocks_processed; /* associated-data blocks absorbed so far */
    uint32_t blocks_processed;    /* presumably message blocks processed so far (set outside this view) */
    AES_KEY decrypt_key;          /* AES decryption schedule, set in ae_init */
    AES_KEY encrypt_key;          /* AES encryption schedule, set in ae_init */
#if (OCB_TAG_LEN == 0)
    unsigned tag_len;             /* tag length in bytes, set in ae_init */
#endif
};
636
637 /* ----------------------------------------------------------------------- */
638 /* L table lookup (or on-the-fly generation) */
639 /* ----------------------------------------------------------------------- */
640
641 #if L_TABLE_SZ_IS_ENOUGH
642 #define getL(_ctx, _tz) ((_ctx)->L[_tz])
643 #else
getL(const ae_ctx * ctx,unsigned tz)644 static block getL(const ae_ctx* ctx, unsigned tz) {
645 if (tz < L_TABLE_SZ)
646 return ctx->L[tz];
647 else {
648 unsigned i;
649 /* Bring L[MAX] into registers, make it register correct */
650 block rval = swap_if_le(ctx->L[L_TABLE_SZ - 1]);
651 rval = double_block(rval);
652 for (i = L_TABLE_SZ; i < tz; i++)
653 rval = double_block(rval);
654 return swap_if_le(rval); /* To memory correct */
655 }
656 }
657 #endif
658
659 /* ----------------------------------------------------------------------- */
660 /* Public functions */
661 /* ----------------------------------------------------------------------- */
662
663 /* 32-bit SSE2 and Altivec systems need to be forced to allocate memory
664 on 16-byte alignments. (I believe all major 64-bit systems do already.) */
665
ae_allocate(void * misc)666 ae_ctx* ae_allocate(void* misc) {
667 void* p;
668 (void)misc; /* misc unused in this implementation */
669 #if (__SSE2__ && !_M_X64 && !_M_AMD64 && !__amd64__)
670 p = _mm_malloc(sizeof(ae_ctx), 16);
671 #elif(__ALTIVEC__ && !__PPC64__)
672 if (posix_memalign(&p, 16, sizeof(ae_ctx)) != 0)
673 p = NULL;
674 #else
675 p = malloc(sizeof(ae_ctx));
676 #endif
677 return (ae_ctx*)p;
678 }
679
/* Release a context obtained from ae_allocate.  The context holds AES key
 * schedules and key-derived L values, so it is zeroed with volatile stores
 * before being handed back to the allocator.  A plain memset right before
 * the deallocation may be removed as a dead store.  A NULL ctx is not wiped
 * and is passed to the deallocator as before. */
void ae_free(ae_ctx* ctx) {
    if (ctx != NULL) {
        volatile unsigned char* bytes = (volatile unsigned char*)ctx;
        size_t n;
        for (n = 0; n < sizeof(ae_ctx); n++)
            bytes[n] = 0;
    }
#if (__SSE2__ && !_M_X64 && !_M_AMD64 && !__amd64__)
    _mm_free(ctx);
#else
    free(ctx);
#endif
}
687
688 /* ----------------------------------------------------------------------- */
689
/* Set every byte of *ctx to zero, undoing ae_init.  Always returns
 * AE_SUCCESS. */
int ae_clear(ae_ctx* ctx) {
    memset(ctx, 0, sizeof *ctx);
    return AE_SUCCESS;
}
695
ae_ctx_sizeof(void)696 int ae_ctx_sizeof(void) {
697 return (int)sizeof(ae_ctx);
698 }
699
700 /* ----------------------------------------------------------------------- */
701
/* Initialize ctx with the AES key `key`.
 *
 *   key_len   - key length in bytes.  Ignored when OCB_KEY_LEN is non-zero
 *               (OCB_KEY_LEN is used instead); otherwise it must be 16, 24
 *               or 32.
 *   nonce_len - nonce length in bytes; only 12 is supported.
 *   tag_len   - tag length in bytes.  Used only when OCB_TAG_LEN is 0, in
 *               which case it must be 0..16.
 *
 * Returns AE_SUCCESS, or AE_NOT_SUPPORTED for an unsupported nonce, key or
 * tag length.  All parameters are checked before ctx is modified.
 * Previously a bad runtime key length silently left the AES schedules unset,
 * and a negative tag_len was stored into the unsigned tag_len field. */
int ae_init(ae_ctx* ctx, const void* key, int key_len, int nonce_len, int tag_len) {
    unsigned i;
    block tmp_blk;

    if (nonce_len != 12)
        return AE_NOT_SUPPORTED;

#if (OCB_KEY_LEN > 0)
    key_len = OCB_KEY_LEN;
#else
    if (key_len != 16 && key_len != 24 && key_len != 32)
        return AE_NOT_SUPPORTED;
#endif
#if (OCB_TAG_LEN == 0)
    /* The tag is at most one 16-byte block */
    if (tag_len < 0 || tag_len > 16)
        return AE_NOT_SUPPORTED;
#endif

    /* Initialize encryption & decryption keys */
    AES_set_encrypt_key((const unsigned char*)key, key_len * 8, &ctx->encrypt_key);
#if USE_AES_NI
    /* Derive the decryption schedule from the encryption one */
    AES_set_decrypt_key_fast(&ctx->decrypt_key, &ctx->encrypt_key);
#else
    AES_set_decrypt_key((const unsigned char*)key, key_len * 8, &ctx->decrypt_key);
#endif

    /* Zero things that need zeroing */
    ctx->cached_Top = ctx->ad_checksum = zero_block();
    ctx->ad_blocks_processed = 0;

    /* Compute key-dependent values: L* = E_K(0^128) (cached_Top is zero here),
       L$ = double(L*), L[0] = double(L$), L[i] = double(L[i-1]).  Doubling
       works on register-correct values; results are stored memory correct. */
    AES_encrypt((unsigned char*)&ctx->cached_Top, (unsigned char*)&ctx->Lstar, &ctx->encrypt_key);
    tmp_blk = swap_if_le(ctx->Lstar);
    tmp_blk = double_block(tmp_blk);
    ctx->Ldollar = swap_if_le(tmp_blk);
    tmp_blk = double_block(tmp_blk);
    ctx->L[0] = swap_if_le(tmp_blk);
    for (i = 1; i < L_TABLE_SZ; i++) {
        tmp_blk = double_block(tmp_blk);
        ctx->L[i] = swap_if_le(tmp_blk);
    }

#if (OCB_TAG_LEN == 0)
    ctx->tag_len = (unsigned)tag_len;
#else
    (void)tag_len; /* Suppress var not used error */
#endif

    return AE_SUCCESS;
}
744
745 /* ----------------------------------------------------------------------- */
746
gen_offset_from_nonce(ae_ctx * ctx,const void * nonce)747 static block gen_offset_from_nonce(ae_ctx* ctx, const void* nonce) {
748 const union {
749 unsigned x;
750 unsigned char endian;
751 } little = {1};
752 union {
753 uint32_t u32[4];
754 uint8_t u8[16];
755 block bl;
756 } tmp;
757 unsigned idx;
758
759 /* Replace cached nonce Top if needed */
760 #if (OCB_TAG_LEN > 0)
761 if (little.endian)
762 tmp.u32[0] = 0x01000000 + ((OCB_TAG_LEN * 8 % 128) << 1);
763 else
764 tmp.u32[0] = 0x00000001 + ((OCB_TAG_LEN * 8 % 128) << 25);
765 #else
766 if (little.endian)
767 tmp.u32[0] = 0x01000000 + ((ctx->tag_len * 8 % 128) << 1);
768 else
769 tmp.u32[0] = 0x00000001 + ((ctx->tag_len * 8 % 128) << 25);
770 #endif
771 tmp.u32[1] = ((uint32_t*)nonce)[0];
772 tmp.u32[2] = ((uint32_t*)nonce)[1];
773 tmp.u32[3] = ((uint32_t*)nonce)[2];
774 idx = (unsigned)(tmp.u8[15] & 0x3f); /* Get low 6 bits of nonce */
775 tmp.u8[15] = tmp.u8[15] & 0xc0; /* Zero low 6 bits of nonce */
776 if (unequal_blocks(tmp.bl, ctx->cached_Top)) { /* Cached? */
777 ctx->cached_Top = tmp.bl; /* Update cache, KtopStr */
778 AES_encrypt(tmp.u8, (unsigned char*)&ctx->KtopStr, &ctx->encrypt_key);
779 if (little.endian) { /* Make Register Correct */
780 ctx->KtopStr[0] = bswap64(ctx->KtopStr[0]);
781 ctx->KtopStr[1] = bswap64(ctx->KtopStr[1]);
782 }
783 ctx->KtopStr[2] = ctx->KtopStr[0] ^ (ctx->KtopStr[0] << 8) ^ (ctx->KtopStr[1] >> 56);
784 }
785 return gen_offset(ctx->KtopStr, idx);
786 }
787
/* Absorb associated data (AD) into the per-message AD state kept in ctx.

   Whole groups of BPI*16 bytes are processed first. Each 16-byte block is
   XORed with its offset, the group is encrypted with AES_ecb_encrypt_blks,
   and the encrypted blocks are XORed into ad_checksum. Afterwards
   ctx->ad_blocks_processed, ctx->ad_offset and ctx->ad_checksum are written
   back, so a later AE_PENDING call can continue where this one stopped.

   When final is non-zero, the trailing ad_len % (BPI*16) bytes are also
   processed. Whole 16-byte blocks are handled as above. A last partial block
   is padded with a single 0x80 byte followed by zeros, and its offset is
   additionally XORed with ctx->Lstar. On this path only ctx->ad_checksum is
   written back; ctx->ad_offset is not updated.

   ctx    - context holding L[], Lstar, encrypt_key and the ad_* state.
   ad     - associated data, ad_len bytes.
   ad_len - must be positive (the visible callers only call this when
            ad_len > 0) and, unless final, a multiple of BPI*16.
   final  - non-zero when this is the last piece of AD for the message. */
static void process_ad(ae_ctx* ctx, const void* ad, int ad_len, int final) {
  union {
    uint32_t u32[4];
    uint8_t u8[16];
    block bl;
  } tmp;
  block ad_offset, ad_checksum;
  const block* adp = (block*)ad;
  unsigned i, k, tz, remaining;

  ad_offset = ctx->ad_offset;
  ad_checksum = ctx->ad_checksum;
  i = ad_len / (BPI * 16); /* number of whole BPI-block groups */
  if (i) {
    unsigned ad_block_num = ctx->ad_blocks_processed;
    do {
      block ta[BPI], oa[BPI];
      ad_block_num += BPI;
      /* Only the last block of a group needs ntz() of its absolute index.
         The others use L[0..2] because ad_block_num is a multiple of BPI. */
      tz = ntz(ad_block_num);
      oa[0] = xor_block(ad_offset, ctx->L[0]);
      ta[0] = xor_block(oa[0], adp[0]);
      oa[1] = xor_block(oa[0], ctx->L[1]);
      ta[1] = xor_block(oa[1], adp[1]);
      /* oa[1] ^ L[0] == ad_offset ^ L[1], since the two L[0] terms cancel */
      oa[2] = xor_block(ad_offset, ctx->L[1]);
      ta[2] = xor_block(oa[2], adp[2]);
#if BPI == 4
      ad_offset = xor_block(oa[2], getL(ctx, tz));
      ta[3] = xor_block(ad_offset, adp[3]);
#elif BPI == 8
      oa[3] = xor_block(oa[2], ctx->L[2]);
      ta[3] = xor_block(oa[3], adp[3]);
      oa[4] = xor_block(oa[1], ctx->L[2]);
      ta[4] = xor_block(oa[4], adp[4]);
      oa[5] = xor_block(oa[0], ctx->L[2]);
      ta[5] = xor_block(oa[5], adp[5]);
      oa[6] = xor_block(ad_offset, ctx->L[2]);
      ta[6] = xor_block(oa[6], adp[6]);
      ad_offset = xor_block(oa[6], getL(ctx, tz));
      ta[7] = xor_block(ad_offset, adp[7]);
#endif
      AES_ecb_encrypt_blks(ta, BPI, &ctx->encrypt_key);
      ad_checksum = xor_block(ad_checksum, ta[0]);
      ad_checksum = xor_block(ad_checksum, ta[1]);
      ad_checksum = xor_block(ad_checksum, ta[2]);
      ad_checksum = xor_block(ad_checksum, ta[3]);
#if (BPI == 8)
      ad_checksum = xor_block(ad_checksum, ta[4]);
      ad_checksum = xor_block(ad_checksum, ta[5]);
      ad_checksum = xor_block(ad_checksum, ta[6]);
      ad_checksum = xor_block(ad_checksum, ta[7]);
#endif
      adp += BPI;
    } while (--i);
    /* Persist state so a following AE_PENDING/final call can continue */
    ctx->ad_blocks_processed = ad_block_num;
    ctx->ad_offset = ad_offset;
    ctx->ad_checksum = ad_checksum;
  }

  if (final) {
    block ta[BPI];

    /* Process remaining associated data, compute its tag contribution */
    remaining = ((unsigned)ad_len) % (BPI * 16);
    if (remaining) {
      k = 0; /* number of blocks collected in ta[] */
#if (BPI == 8)
      if (remaining >= 64) {
        tmp.bl = xor_block(ad_offset, ctx->L[0]);
        ta[0] = xor_block(tmp.bl, adp[0]);
        tmp.bl = xor_block(tmp.bl, ctx->L[1]);
        ta[1] = xor_block(tmp.bl, adp[1]);
        ad_offset = xor_block(ad_offset, ctx->L[1]);
        ta[2] = xor_block(ad_offset, adp[2]);
        ad_offset = xor_block(ad_offset, ctx->L[2]);
        ta[3] = xor_block(ad_offset, adp[3]);
        remaining -= 64;
        k = 4;
      }
#endif
      if (remaining >= 32) {
        ad_offset = xor_block(ad_offset, ctx->L[0]);
        ta[k] = xor_block(ad_offset, adp[k]);
        /* Block k+2 within this group. Because the processed-block count is a
           multiple of BPI, ntz(k + 2) equals ntz of the absolute index. */
        ad_offset = xor_block(ad_offset, getL(ctx, ntz(k + 2)));
        ta[k + 1] = xor_block(ad_offset, adp[k + 1]);
        remaining -= 32;
        k += 2;
      }
      if (remaining >= 16) {
        ad_offset = xor_block(ad_offset, ctx->L[0]);
        ta[k] = xor_block(ad_offset, adp[k]);
        remaining = remaining - 16;
        ++k;
      }
      if (remaining) {
        /* Final partial block: pad as data || 0x80 || 0x00... */
        ad_offset = xor_block(ad_offset, ctx->Lstar);
        tmp.bl = zero_block();
        memcpy(tmp.u8, adp + k, remaining);
        tmp.u8[remaining] = (unsigned char)0x80u;
        ta[k] = xor_block(ad_offset, tmp.bl);
        ++k;
      }
      AES_ecb_encrypt_blks(ta, k, &ctx->encrypt_key);
      /* Fold the k encrypted blocks into the checksum (deliberate fallthrough) */
      switch (k) {
#if (BPI == 8)
        case 8:
          ad_checksum = xor_block(ad_checksum, ta[7]);
          /* fallthrough */
        case 7:
          ad_checksum = xor_block(ad_checksum, ta[6]);
          /* fallthrough */
        case 6:
          ad_checksum = xor_block(ad_checksum, ta[5]);
          /* fallthrough */
        case 5:
          ad_checksum = xor_block(ad_checksum, ta[4]);
          /* fallthrough */
#endif
        case 4:
          ad_checksum = xor_block(ad_checksum, ta[3]);
          /* fallthrough */
        case 3:
          ad_checksum = xor_block(ad_checksum, ta[2]);
          /* fallthrough */
        case 2:
          ad_checksum = xor_block(ad_checksum, ta[1]);
          /* fallthrough */
        case 1:
          ad_checksum = xor_block(ad_checksum, ta[0]);
      }
      ctx->ad_checksum = ad_checksum;
    }
  }
}
914
915 /* ----------------------------------------------------------------------- */
916
/* Encrypt (part of) a message with OCB3 and, when final, produce its tag.

   ctx    - context holding the keys, L[], Lstar, Ldollar, tag_len and the
            per-message state (offset, checksum, block counters, ad_*)
            carried between calls.
   nonce  - non-NULL starts a new message. The per-message state is reset and
            the initial offset is derived from the nonce. NULL continues the
            message begun by an earlier call.
   pt     - plaintext, pt_len bytes (a multiple of BPI*16 unless final).
   ad     - associated data, ad_len bytes. It is processed only when
            ad_len > 0. With a nonce, ad_len >= 0 clears the AD checksum,
            while a negative ad_len keeps the one already stored in ctx.
   ct     - output: pt_len ciphertext bytes, followed by the tag when the tag
            is bundled (final with tag == NULL).
   tag    - when final: if non-NULL the tag is written here, otherwise it is
            appended to ct at byte offset pt_len.
   final  - non-zero to finish the message and emit the tag.

   The tag length is OCB_TAG_LEN when that macro is positive, otherwise
   ctx->tag_len.

   Returns pt_len, plus the tag length when the tag was appended to ct.

   NOTE(review): pt_len is not checked for being negative. A negative value
   would be converted to unsigned in the block-count and remainder
   computations below. Callers must pass a non-negative length. */
int ae_encrypt(ae_ctx* ctx, const void* nonce, const void* pt, int pt_len, const void* ad,
               int ad_len, void* ct, void* tag, int final) {
  union {
    uint32_t u32[4];
    uint8_t u8[16];
    block bl;
  } tmp;
  block offset, checksum;
  unsigned i, k;
  block* ctp = (block*)ct;
  const block* ptp = (block*)pt;

  /* Non-null nonce means start of new message, init per-message values */
  if (nonce) {
    ctx->offset = gen_offset_from_nonce(ctx, nonce);
    ctx->ad_offset = ctx->checksum = zero_block();
    ctx->ad_blocks_processed = ctx->blocks_processed = 0;
    if (ad_len >= 0)
      ctx->ad_checksum = zero_block();
  }

  /* Process associated data */
  if (ad_len > 0)
    process_ad(ctx, ad, ad_len, final);

  /* Encrypt plaintext data BPI blocks at a time */
  offset = ctx->offset;
  checksum = ctx->checksum;
  i = pt_len / (BPI * 16);
  if (i) {
    block oa[BPI];
    unsigned block_num = ctx->blocks_processed;
    /* oa[BPI-1] carries the last offset of the previous group into the next */
    oa[BPI - 1] = offset;
    do {
      block ta[BPI];
      block_num += BPI;
      oa[0] = xor_block(oa[BPI - 1], ctx->L[0]);
      ta[0] = xor_block(oa[0], ptp[0]);
      checksum = xor_block(checksum, ptp[0]);
      oa[1] = xor_block(oa[0], ctx->L[1]);
      ta[1] = xor_block(oa[1], ptp[1]);
      checksum = xor_block(checksum, ptp[1]);
      oa[2] = xor_block(oa[1], ctx->L[0]);
      ta[2] = xor_block(oa[2], ptp[2]);
      checksum = xor_block(checksum, ptp[2]);
#if BPI == 4
      oa[3] = xor_block(oa[2], getL(ctx, ntz(block_num)));
      ta[3] = xor_block(oa[3], ptp[3]);
      checksum = xor_block(checksum, ptp[3]);
#elif BPI == 8
      oa[3] = xor_block(oa[2], ctx->L[2]);
      ta[3] = xor_block(oa[3], ptp[3]);
      checksum = xor_block(checksum, ptp[3]);
      oa[4] = xor_block(oa[1], ctx->L[2]);
      ta[4] = xor_block(oa[4], ptp[4]);
      checksum = xor_block(checksum, ptp[4]);
      oa[5] = xor_block(oa[0], ctx->L[2]);
      ta[5] = xor_block(oa[5], ptp[5]);
      checksum = xor_block(checksum, ptp[5]);
      /* oa[7] still holds the previous group's last offset at this point */
      oa[6] = xor_block(oa[7], ctx->L[2]);
      ta[6] = xor_block(oa[6], ptp[6]);
      checksum = xor_block(checksum, ptp[6]);
      oa[7] = xor_block(oa[6], getL(ctx, ntz(block_num)));
      ta[7] = xor_block(oa[7], ptp[7]);
      checksum = xor_block(checksum, ptp[7]);
#endif
      AES_ecb_encrypt_blks(ta, BPI, &ctx->encrypt_key);
      ctp[0] = xor_block(ta[0], oa[0]);
      ctp[1] = xor_block(ta[1], oa[1]);
      ctp[2] = xor_block(ta[2], oa[2]);
      ctp[3] = xor_block(ta[3], oa[3]);
#if (BPI == 8)
      ctp[4] = xor_block(ta[4], oa[4]);
      ctp[5] = xor_block(ta[5], oa[5]);
      ctp[6] = xor_block(ta[6], oa[6]);
      ctp[7] = xor_block(ta[7], oa[7]);
#endif
      ptp += BPI;
      ctp += BPI;
    } while (--i);
    ctx->offset = offset = oa[BPI - 1];
    ctx->blocks_processed = block_num;
    ctx->checksum = checksum;
  }

  if (final) {
    /* ta[] holds up to BPI message blocks plus one block for the tag */
    block ta[BPI + 1], oa[BPI];

    /* Process remaining plaintext and compute its tag contribution */
    unsigned remaining = ((unsigned)pt_len) % (BPI * 16);
    k = 0; /* How many blocks in ta[] need ECBing */
    if (remaining) {
#if (BPI == 8)
      if (remaining >= 64) {
        oa[0] = xor_block(offset, ctx->L[0]);
        ta[0] = xor_block(oa[0], ptp[0]);
        checksum = xor_block(checksum, ptp[0]);
        oa[1] = xor_block(oa[0], ctx->L[1]);
        ta[1] = xor_block(oa[1], ptp[1]);
        checksum = xor_block(checksum, ptp[1]);
        oa[2] = xor_block(oa[1], ctx->L[0]);
        ta[2] = xor_block(oa[2], ptp[2]);
        checksum = xor_block(checksum, ptp[2]);
        offset = oa[3] = xor_block(oa[2], ctx->L[2]);
        ta[3] = xor_block(offset, ptp[3]);
        checksum = xor_block(checksum, ptp[3]);
        remaining -= 64;
        k = 4;
      }
#endif
      if (remaining >= 32) {
        oa[k] = xor_block(offset, ctx->L[0]);
        ta[k] = xor_block(oa[k], ptp[k]);
        checksum = xor_block(checksum, ptp[k]);
        offset = oa[k + 1] = xor_block(oa[k], ctx->L[1]);
        ta[k + 1] = xor_block(offset, ptp[k + 1]);
        checksum = xor_block(checksum, ptp[k + 1]);
        remaining -= 32;
        k += 2;
      }
      if (remaining >= 16) {
        offset = oa[k] = xor_block(offset, ctx->L[0]);
        ta[k] = xor_block(offset, ptp[k]);
        checksum = xor_block(checksum, ptp[k]);
        remaining -= 16;
        ++k;
      }
      if (remaining) {
        /* Padded final plaintext (data || 0x80 || 0x00...) goes into the
           checksum. The offset itself is queued in ta[k]: its encryption
           is the pad that will be XORed onto the partial block below. */
        tmp.bl = zero_block();
        memcpy(tmp.u8, ptp + k, remaining);
        tmp.u8[remaining] = (unsigned char)0x80u;
        checksum = xor_block(checksum, tmp.bl);
        ta[k] = offset = xor_block(offset, ctx->Lstar);
        ++k;
      }
    }
    offset = xor_block(offset, ctx->Ldollar); /* Part of tag gen */
    ta[k] = xor_block(offset, checksum);      /* Part of tag gen */
    AES_ecb_encrypt_blks(ta, k + 1, &ctx->encrypt_key);
    /* From here on, offset holds the full 16-byte tag */
    offset = xor_block(ta[k], ctx->ad_checksum); /* Part of tag gen */
    if (remaining) {
      /* ta[k] (after --k) is the encrypted pad; emit only `remaining` bytes */
      --k;
      tmp.bl = xor_block(tmp.bl, ta[k]);
      memcpy(ctp + k, tmp.u8, remaining);
    }
    /* k now counts the whole blocks still to be output (deliberate fallthrough) */
    switch (k) {
#if (BPI == 8)
      case 7:
        ctp[6] = xor_block(ta[6], oa[6]);
        /* fallthrough */
      case 6:
        ctp[5] = xor_block(ta[5], oa[5]);
        /* fallthrough */
      case 5:
        ctp[4] = xor_block(ta[4], oa[4]);
        /* fallthrough */
      case 4:
        ctp[3] = xor_block(ta[3], oa[3]);
        /* fallthrough */
#endif
      case 3:
        ctp[2] = xor_block(ta[2], oa[2]);
        /* fallthrough */
      case 2:
        ctp[1] = xor_block(ta[1], oa[1]);
        /* fallthrough */
      case 1:
        ctp[0] = xor_block(ta[0], oa[0]);
    }

    /* Tag is placed at the correct location: the caller's tag buffer if one
       was given, otherwise appended right after the ciphertext.
     */
    if (tag) {
#if (OCB_TAG_LEN == 16)
      *(block*)tag = offset;
#elif(OCB_TAG_LEN > 0)
      memcpy((char*)tag, &offset, OCB_TAG_LEN);
#else
      memcpy((char*)tag, &offset, ctx->tag_len);
#endif
    } else {
#if (OCB_TAG_LEN > 0)
      memcpy((char*)ct + pt_len, &offset, OCB_TAG_LEN);
      pt_len += OCB_TAG_LEN;
#else
      memcpy((char*)ct + pt_len, &offset, ctx->tag_len);
      pt_len += ctx->tag_len;
#endif
    }
  }
  return (int)pt_len;
}
1103
1104 /* ----------------------------------------------------------------------- */
1105
/* Compare n bytes at av and bv without exiting early on the first mismatch.

   Every byte pair is visited and the differences are OR-ed together, so the
   work done depends only on n and not on where the buffers differ. This
   assumes the compiler and CPU do not reintroduce a data-dependent early
   exit. Use it when checking authentication tags, to avoid a timing side
   channel.

   Returns 0 if the regions are equal. Otherwise it returns a non-zero value:
   the bitwise OR of all byte-wise XOR differences. */
static int constant_time_memcmp(const void* av, const void* bv, size_t n) {
  const uint8_t* lhs = (const uint8_t*)av;
  const uint8_t* rhs = (const uint8_t*)bv;
  uint8_t diff_bits = 0;
  size_t pos;

  for (pos = 0; pos < n; ++pos)
    diff_bits |= (uint8_t)(lhs[pos] ^ rhs[pos]);

  return (int)diff_bits;
}
1127
/* Decrypt (part of) an OCB3 message and, when final, verify its tag.

   ctx    - context holding the keys (encrypt_key, decrypt_key), L[], Lstar,
            Ldollar, tag_len and the per-message state (offset, checksum,
            block counters, ad_*) carried between calls.
   nonce  - non-NULL starts a new message and resets the per-message state.
            NULL continues the message begun by an earlier call.
   ct     - ciphertext. When final is set and tag is NULL, the tag is
            expected to follow the ciphertext in this buffer and is counted
            in ct_len.
   ct_len - bytes at ct (a multiple of BPI*16 unless final).
   ad     - associated data, ad_len bytes. It is processed only when
            ad_len > 0. With a nonce, ad_len >= 0 clears the AD checksum,
            while a negative ad_len keeps the one already stored in ctx.
   pt     - output buffer for the recovered plaintext.
   tag    - expected tag, or NULL when it is bundled at the end of ct.
   final  - non-zero to finish the message and verify the tag.

   The tag length is OCB_TAG_LEN when that macro is positive, otherwise
   ctx->tag_len.

   Returns the plaintext length. Returns AE_INVALID if the tag does not
   match, or if ct_len (after removing a bundled tag) is negative. In the
   latter case ctx and pt are left untouched.

   Plaintext is written to pt BEFORE the tag is checked. On AE_INVALID the
   caller must discard whatever pt contains. */
int ae_decrypt(ae_ctx* ctx, const void* nonce, const void* ct, int ct_len, const void* ad,
               int ad_len, void* pt, const void* tag, int final) {
  union {
    uint32_t u32[4];
    uint8_t u8[16];
    block bl;
  } tmp;
  block offset, checksum;
  unsigned i, k;
  const block* ctp = (const block*)ct; /* ciphertext is only read */
  block* ptp = (block*)pt;

  /* When the tag is bundled at the end of ct, exclude it from ct_len */
  if (final && !tag) {
#if (OCB_TAG_LEN > 0)
    ct_len -= OCB_TAG_LEN;
#else
    ct_len -= ctx->tag_len;
#endif
  }

  /* Reject a negative length before touching ctx or any buffer. It comes
     from a bogus argument or, more dangerously, from a bundled ciphertext
     shorter than the tag (e.g. a truncated packet). Left unchecked, it is
     converted to unsigned in the block-count and remainder computations
     below. The resulting bogus (possibly enormous) counts drive
     out-of-bounds reads of ct and writes to pt, and the bundled-tag
     comparison would read before the start of ct. */
  if (ct_len < 0)
    return AE_INVALID;

  /* Non-null nonce means start of new message, init per-message values */
  if (nonce) {
    ctx->offset = gen_offset_from_nonce(ctx, nonce);
    ctx->ad_offset = ctx->checksum = zero_block();
    ctx->ad_blocks_processed = ctx->blocks_processed = 0;
    if (ad_len >= 0)
      ctx->ad_checksum = zero_block();
  }

  /* Process associated data */
  if (ad_len > 0)
    process_ad(ctx, ad, ad_len, final);

  /* Decrypt ciphertext BPI blocks at a time; checksum covers the plaintext */
  offset = ctx->offset;
  checksum = ctx->checksum;
  i = ct_len / (BPI * 16);
  if (i) {
    block oa[BPI];
    unsigned block_num = ctx->blocks_processed;
    /* oa[BPI-1] carries the last offset of the previous group into the next */
    oa[BPI - 1] = offset;
    do {
      block ta[BPI];
      block_num += BPI;
      oa[0] = xor_block(oa[BPI - 1], ctx->L[0]);
      ta[0] = xor_block(oa[0], ctp[0]);
      oa[1] = xor_block(oa[0], ctx->L[1]);
      ta[1] = xor_block(oa[1], ctp[1]);
      oa[2] = xor_block(oa[1], ctx->L[0]);
      ta[2] = xor_block(oa[2], ctp[2]);
#if BPI == 4
      oa[3] = xor_block(oa[2], getL(ctx, ntz(block_num)));
      ta[3] = xor_block(oa[3], ctp[3]);
#elif BPI == 8
      oa[3] = xor_block(oa[2], ctx->L[2]);
      ta[3] = xor_block(oa[3], ctp[3]);
      oa[4] = xor_block(oa[1], ctx->L[2]);
      ta[4] = xor_block(oa[4], ctp[4]);
      oa[5] = xor_block(oa[0], ctx->L[2]);
      ta[5] = xor_block(oa[5], ctp[5]);
      /* oa[7] still holds the previous group's last offset at this point */
      oa[6] = xor_block(oa[7], ctx->L[2]);
      ta[6] = xor_block(oa[6], ctp[6]);
      oa[7] = xor_block(oa[6], getL(ctx, ntz(block_num)));
      ta[7] = xor_block(oa[7], ctp[7]);
#endif
      AES_ecb_decrypt_blks(ta, BPI, &ctx->decrypt_key);
      ptp[0] = xor_block(ta[0], oa[0]);
      checksum = xor_block(checksum, ptp[0]);
      ptp[1] = xor_block(ta[1], oa[1]);
      checksum = xor_block(checksum, ptp[1]);
      ptp[2] = xor_block(ta[2], oa[2]);
      checksum = xor_block(checksum, ptp[2]);
      ptp[3] = xor_block(ta[3], oa[3]);
      checksum = xor_block(checksum, ptp[3]);
#if (BPI == 8)
      ptp[4] = xor_block(ta[4], oa[4]);
      checksum = xor_block(checksum, ptp[4]);
      ptp[5] = xor_block(ta[5], oa[5]);
      checksum = xor_block(checksum, ptp[5]);
      ptp[6] = xor_block(ta[6], oa[6]);
      checksum = xor_block(checksum, ptp[6]);
      ptp[7] = xor_block(ta[7], oa[7]);
      checksum = xor_block(checksum, ptp[7]);
#endif
      ptp += BPI;
      ctp += BPI;
    } while (--i);
    ctx->offset = offset = oa[BPI - 1];
    ctx->blocks_processed = block_num;
    ctx->checksum = checksum;
  }

  if (final) {
    block ta[BPI + 1], oa[BPI];

    /* Decrypt remaining ciphertext and compute its tag contribution */
    unsigned remaining = ((unsigned)ct_len) % (BPI * 16);
    k = 0; /* How many blocks in ta[] need ECBing */
    if (remaining) {
#if (BPI == 8)
      if (remaining >= 64) {
        oa[0] = xor_block(offset, ctx->L[0]);
        ta[0] = xor_block(oa[0], ctp[0]);
        oa[1] = xor_block(oa[0], ctx->L[1]);
        ta[1] = xor_block(oa[1], ctp[1]);
        oa[2] = xor_block(oa[1], ctx->L[0]);
        ta[2] = xor_block(oa[2], ctp[2]);
        offset = oa[3] = xor_block(oa[2], ctx->L[2]);
        ta[3] = xor_block(offset, ctp[3]);
        remaining -= 64;
        k = 4;
      }
#endif
      if (remaining >= 32) {
        oa[k] = xor_block(offset, ctx->L[0]);
        ta[k] = xor_block(oa[k], ctp[k]);
        offset = oa[k + 1] = xor_block(oa[k], ctx->L[1]);
        ta[k + 1] = xor_block(offset, ctp[k + 1]);
        remaining -= 32;
        k += 2;
      }
      if (remaining >= 16) {
        offset = oa[k] = xor_block(offset, ctx->L[0]);
        ta[k] = xor_block(offset, ctp[k]);
        remaining -= 16;
        ++k;
      }
      if (remaining) {
        /* Final partial block. pad = E(offset ^ Lstar). tmp starts as the pad;
           its first `remaining` bytes are overwritten with ciphertext, so
           XORing the pad back in yields plaintext followed by zero bytes.
           Setting the 0x80 marker then gives the padded plaintext for the
           checksum. */
        block pad;
        offset = xor_block(offset, ctx->Lstar);
        AES_encrypt((unsigned char*)&offset, tmp.u8, &ctx->encrypt_key);
        pad = tmp.bl;
        memcpy(tmp.u8, ctp + k, remaining);
        tmp.bl = xor_block(tmp.bl, pad);
        tmp.u8[remaining] = (unsigned char)0x80u;
        memcpy(ptp + k, tmp.u8, remaining);
        checksum = xor_block(checksum, tmp.bl);
      }
    }
    AES_ecb_decrypt_blks(ta, k, &ctx->decrypt_key);
    /* Output the k whole blocks and fold them into the checksum
       (deliberate fallthrough) */
    switch (k) {
#if (BPI == 8)
      case 7:
        ptp[6] = xor_block(ta[6], oa[6]);
        checksum = xor_block(checksum, ptp[6]);
        /* fallthrough */
      case 6:
        ptp[5] = xor_block(ta[5], oa[5]);
        checksum = xor_block(checksum, ptp[5]);
        /* fallthrough */
      case 5:
        ptp[4] = xor_block(ta[4], oa[4]);
        checksum = xor_block(checksum, ptp[4]);
        /* fallthrough */
      case 4:
        ptp[3] = xor_block(ta[3], oa[3]);
        checksum = xor_block(checksum, ptp[3]);
        /* fallthrough */
#endif
      case 3:
        ptp[2] = xor_block(ta[2], oa[2]);
        checksum = xor_block(checksum, ptp[2]);
        /* fallthrough */
      case 2:
        ptp[1] = xor_block(ta[1], oa[1]);
        checksum = xor_block(checksum, ptp[1]);
        /* fallthrough */
      case 1:
        ptp[0] = xor_block(ta[0], oa[0]);
        checksum = xor_block(checksum, ptp[0]);
        /* fallthrough */
      default:
        break;
    }

    /* Calculate expected tag */
    offset = xor_block(offset, ctx->Ldollar);
    tmp.bl = xor_block(offset, checksum);
    AES_encrypt(tmp.u8, tmp.u8, &ctx->encrypt_key);
    tmp.bl = xor_block(tmp.bl, ctx->ad_checksum); /* Full tag */

    /* Compare with proposed tag (full block, or the first tag-length bytes),
       change ct_len if invalid */
    if ((OCB_TAG_LEN == 16) && tag) {
      if (unequal_blocks(tmp.bl, *(const block*)tag))
        ct_len = AE_INVALID;
    } else {
#if (OCB_TAG_LEN > 0)
      int len = OCB_TAG_LEN;
#else
      int len = ctx->tag_len;
#endif
      if (tag) {
        if (constant_time_memcmp(tag, tmp.u8, len) != 0)
          ct_len = AE_INVALID;
      } else {
        /* Bundled tag sits right after the ciphertext; ct_len >= 0 here */
        if (constant_time_memcmp((const char*)ct + ct_len, tmp.u8, len) != 0)
          ct_len = AE_INVALID;
      }
    }
  }
  return ct_len;
}
1321
1322 /* ----------------------------------------------------------------------- */
1323 /* Simple test program */
1324 /* ----------------------------------------------------------------------- */
1325
#if 0
/* Self-test and vector-printing harness, compiled out by the #if 0 above.
   Enable it to build a standalone program whose main() runs validate(). */

#include <stdio.h>
#include <time.h>

/* Portable 16-byte alignment for stack buffers passed to the OCB routines */
#if __GNUC__
#define ALIGN(n) __attribute__((aligned(n)))
#elif _MSC_VER
#define ALIGN(n) __declspec(align(n))
#else /* Not GNU/Microsoft: delete alignment uses. */
#define ALIGN(n)
#endif

/* Print len bytes at p as uppercase hex on one line, preceded by the
   string s when s is non-NULL. */
static void pbuf(void *p, unsigned len, const void *s)
{
    unsigned i;
    if (s)
        printf("%s", (char *)s);
    for (i = 0; i < len; i++)
        printf("%02X", (unsigned)(((unsigned char *)p)[i]));
    printf("\n");
}

/* Print ciphertext||tag for three combinations of plaintext length P and
   AD length A (P=A=len, P=0, A=0). It uses nonce bytes 0..11 and
   plaintext/AD bytes 0,1,2,... The tag is appended to ct (tag == NULL). */
static void vectors(ae_ctx *ctx, int len)
{
    ALIGN(16) char pt[128];
    ALIGN(16) char ct[144];
    ALIGN(16) char nonce[] = {0,1,2,3,4,5,6,7,8,9,10,11};
    int i;
    for (i=0; i < 128; i++) pt[i] = i;
    i = ae_encrypt(ctx,nonce,pt,len,pt,len,ct,NULL,AE_FINALIZE);
    printf("P=%d,A=%d: ",len,len); pbuf(ct, i, NULL);
    i = ae_encrypt(ctx,nonce,pt,0,pt,len,ct,NULL,AE_FINALIZE);
    printf("P=%d,A=%d: ",0,len); pbuf(ct, i, NULL);
    i = ae_encrypt(ctx,nonce,pt,len,pt,0,ct,NULL,AE_FINALIZE);
    printf("P=%d,A=%d: ",len,0); pbuf(ct, i, NULL);
}

/* Run two checks with an all-zero key:
   1. For i = 0..127 (nonce[11] = i), encrypt i zero bytes with i bytes of
      AD, with no AD, and with AD only. All outputs are concatenated into
      val_buf, which is then authenticated as AD under nonce[11] = 0, and
      the resulting tag is printed. The active branch splits each message
      into three calls (AE_PENDING, AE_PENDING, AE_FINALIZE) whose first
      two lengths are multiples of BPI*16.
   2. Round-trip encrypt/decrypt of i bytes for i = 0..127. It reports an
      authentication, length or plaintext mismatch and returns early on the
      first failure. */
void validate()
{
    ALIGN(16) char pt[1024];
    ALIGN(16) char ct[1024];
    ALIGN(16) char tag[16];
    ALIGN(16) char nonce[12] = {0,};
    ALIGN(16) char key[32] = {0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31};
    ae_ctx ctx;
    char *val_buf, *next;
    int i, len;

    /* 22400 = sum over i < 128 of (2*i + 3*16): room for all outputs of
       check 1 with a 16-byte tag. The extra 16 bytes allow rounding up to
       a 16-byte boundary.
       NOTE(review): malloc result is not checked, and the aligned pointer
       overwrites the original, so this buffer can never be freed. */
    val_buf = (char *)malloc(22400 + 16);
    next = val_buf = (char *)(((size_t)val_buf + 16) & ~((size_t)15));

    /* Disabled: print sample vectors with a 16-byte key and 12-byte nonce */
    if (0) {
        ae_init(&ctx, key, 16, 12, 16);
        /* pbuf(&ctx, sizeof(ctx), "CTX: "); */
        vectors(&ctx,0);
        vectors(&ctx,8);
        vectors(&ctx,16);
        vectors(&ctx,24);
        vectors(&ctx,32);
        vectors(&ctx,40);
    }

    memset(key,0,32);
    memset(pt,0,128);
    ae_init(&ctx, key, OCB_KEY_LEN, 12, OCB_TAG_LEN);

    /* RFC Vector test */
    for (i = 0; i < 128; i++) {
        /* Split i into first + second + third, with first == second a
           multiple of BPI*16 as AE_PENDING calls require */
        int first = ((i/3)/(BPI*16))*(BPI*16);
        int second = first;
        int third = i - (first + second);

        nonce[11] = i;

        if (0) {
            /* Disabled: single-call variant of the same three encryptions */
            ae_encrypt(&ctx,nonce,pt,i,pt,i,ct,NULL,AE_FINALIZE);
            memcpy(next,ct,(size_t)i+OCB_TAG_LEN);
            next = next+i+OCB_TAG_LEN;

            ae_encrypt(&ctx,nonce,pt,i,pt,0,ct,NULL,AE_FINALIZE);
            memcpy(next,ct,(size_t)i+OCB_TAG_LEN);
            next = next+i+OCB_TAG_LEN;

            ae_encrypt(&ctx,nonce,pt,0,pt,i,ct,NULL,AE_FINALIZE);
            memcpy(next,ct,OCB_TAG_LEN);
            next = next+OCB_TAG_LEN;
        } else {
            ae_encrypt(&ctx,nonce,pt,first,pt,first,ct,NULL,AE_PENDING);
            ae_encrypt(&ctx,NULL,pt+first,second,pt+first,second,ct+first,NULL,AE_PENDING);
            ae_encrypt(&ctx,NULL,pt+first+second,third,pt+first+second,third,ct+first+second,NULL,AE_FINALIZE);
            memcpy(next,ct,(size_t)i+OCB_TAG_LEN);
            next = next+i+OCB_TAG_LEN;

            ae_encrypt(&ctx,nonce,pt,first,pt,0,ct,NULL,AE_PENDING);
            ae_encrypt(&ctx,NULL,pt+first,second,pt,0,ct+first,NULL,AE_PENDING);
            ae_encrypt(&ctx,NULL,pt+first+second,third,pt,0,ct+first+second,NULL,AE_FINALIZE);
            memcpy(next,ct,(size_t)i+OCB_TAG_LEN);
            next = next+i+OCB_TAG_LEN;

            ae_encrypt(&ctx,nonce,pt,0,pt,first,ct,NULL,AE_PENDING);
            ae_encrypt(&ctx,NULL,pt,0,pt+first,second,ct,NULL,AE_PENDING);
            ae_encrypt(&ctx,NULL,pt,0,pt+first+second,third,ct,NULL,AE_FINALIZE);
            memcpy(next,ct,OCB_TAG_LEN);
            next = next+OCB_TAG_LEN;
        }

    }
    /* Authenticate everything collected above as AD and print the tag */
    nonce[11] = 0;
    ae_encrypt(&ctx,nonce,NULL,0,val_buf,next-val_buf,ct,tag,AE_FINALIZE);
    pbuf(tag,OCB_TAG_LEN,0);


    /* Encrypt/Decrypt test */
    for (i = 0; i < 128; i++) {
        int first = ((i/3)/(BPI*16))*(BPI*16);
        int second = first;
        int third = i - (first + second);

        nonce[11] = i%128;

        if (1) {
            /* The first call computes the AD checksum for i bytes of AD. The
               next two calls pass ad_len = -1, which keeps that checksum in
               ctx instead of reprocessing the AD. */
            len = ae_encrypt(&ctx,nonce,val_buf,i,val_buf,i,ct,tag,AE_FINALIZE);
            len = ae_encrypt(&ctx,nonce,val_buf,i,val_buf,-1,ct,tag,AE_FINALIZE);
            len = ae_decrypt(&ctx,nonce,ct,len,val_buf,-1,pt,tag,AE_FINALIZE);
            if (len == -1) { printf("Authentication error: %d\n", i); return; }
            if (len != i) { printf("Length error: %d\n", i); return; }
            if (memcmp(val_buf,pt,i)) { printf("Decrypt error: %d\n", i); return; }
        } else {
            /* Disabled: bundled-tag variant, decrypting in three calls */
            len = ae_encrypt(&ctx,nonce,val_buf,i,val_buf,i,ct,NULL,AE_FINALIZE);
            ae_decrypt(&ctx,nonce,ct,first,val_buf,first,pt,NULL,AE_PENDING);
            ae_decrypt(&ctx,NULL,ct+first,second,val_buf+first,second,pt+first,NULL,AE_PENDING);
            len = ae_decrypt(&ctx,NULL,ct+first+second,len-(first+second),val_buf+first+second,third,pt+first+second,NULL,AE_FINALIZE);
            if (len == -1) { printf("Authentication error: %d\n", i); return; }
            if (memcmp(val_buf,pt,i)) { printf("Decrypt error: %d\n", i); return; }
        }

    }
    printf("Decrypt: PASS\n");
}

/* Standalone entry point for the harness */
int main()
{
    validate();
    return 0;
}
#endif
1473
/* Human-readable name of this OCB3 build, chosen by the first enabled AES
   backend macro, checked in the order USE_AES_NI, USE_REFERENCE_AES,
   USE_OPENSSL_AES. If none of them is enabled, infoString is not defined.
   NOTE(review): this is a mutable, externally visible array, presumably
   referenced elsewhere via an extern declaration. Verify before changing
   its type, constness or linkage. */
#if USE_AES_NI
char infoString[] = "OCB3 (AES-NI)";
#elif USE_REFERENCE_AES
char infoString[] = "OCB3 (Reference)";
#elif USE_OPENSSL_AES
char infoString[] = "OCB3 (OpenSSL)";
#endif
1481