1 /*------------------------------------------------------------------------
2 / OCB Version 3 Reference Code (Optimized C) Last modified 12-JUN-2013
3 /-------------------------------------------------------------------------
4 / Copyright (c) 2013 Ted Krovetz.
5 /
6 / Permission to use, copy, modify, and/or distribute this software for any
7 / purpose with or without fee is hereby granted, provided that the above
8 / copyright notice and this permission notice appear in all copies.
9 /
10 / THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 / WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 / MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13 / ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 / WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15 / ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16 / OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 /
18 / Phillip Rogaway holds patents relevant to OCB. See the following for
19 / his patent grant: http://www.cs.ucdavis.edu/~rogaway/ocb/grant.htm
20 /
21 / Special thanks to Keegan McAllister for suggesting several good improvements
22 /
23 / Comments are welcome: Ted Krovetz <ted@krovetz.net> - Dedicated to Laurel K
24 /------------------------------------------------------------------------- */
25
26 /* ----------------------------------------------------------------------- */
27 /* Usage notes */
28 /* ----------------------------------------------------------------------- */
29
30 /* - When AE_PENDING is passed as the 'final' parameter of any function,
31 / the length parameters must be a multiple of (BPI*16).
32 / - When available, SSE or AltiVec registers are used to manipulate data.
33 / So, when on machines with these facilities, all pointers passed to
34 / any function should be 16-byte aligned.
35 / - Plaintext and ciphertext pointers may be equal (ie, plaintext gets
36 / encrypted in-place), but no other pair of pointers may be equal.
37 / - This code assumes all x86 processors have SSE2 and SSSE3 instructions
38 / when compiling under MSVC. If untrue, alter the #define.
39 / - This code is tested for C99 and recent versions of GCC and MSVC. */
40
41 /* ----------------------------------------------------------------------- */
42 /* User configuration options */
43 /* ----------------------------------------------------------------------- */
44
45 /* Set the AES key length to use and length of authentication tag to produce.
46 / Setting either to 0 requires the value be set at runtime via ae_init().
47 / Some optimizations occur for each when set to a fixed value. */
48 #define OCB_KEY_LEN 16 /* 0, 16, 24 or 32. 0 means set in ae_init */
49 #define OCB_TAG_LEN 16 /* 0 to 16. 0 means set in ae_init */
50
51 /* This implementation has built-in support for multiple AES APIs. Set any
52 / one of the following to non-zero to specify which to use. */
53 #define USE_OPENSSL_AES 1 /* http://openssl.org */
54 #define USE_REFERENCE_AES 0 /* Internet search: rijndael-alg-fst.c */
55 #define USE_AES_NI 0 /* Uses compiler's intrinsics */
56
57 /* During encryption and decryption, various "L values" are required.
58 / The L values can be precomputed during initialization (requiring extra
59 / space in ae_ctx), generated as needed (slightly slowing encryption and
60 / decryption), or some combination of the two. L_TABLE_SZ specifies how many
61 / L values to precompute. L_TABLE_SZ must be at least 3. L_TABLE_SZ*16 bytes
62 / are used for L values in ae_ctx. Plaintext and ciphertexts shorter than
63 / 2^L_TABLE_SZ blocks need no L values calculated dynamically. */
64 #define L_TABLE_SZ 16
65
66 /* Set L_TABLE_SZ_IS_ENOUGH non-zero iff you know that all plaintexts
67 / will be shorter than 2^(L_TABLE_SZ+4) bytes in length. This results
68 / in better performance. */
69 #define L_TABLE_SZ_IS_ENOUGH 1
70
71 /* ----------------------------------------------------------------------- */
72 /* Includes and compiler specific definitions */
73 /* ----------------------------------------------------------------------- */
74
75 #include "ae.h"
76 #include <stdlib.h>
77 #include <string.h>
78
79 /* Define standard sized integers */
80 #if defined(_MSC_VER) && (_MSC_VER < 1600)
81 typedef unsigned __int8 uint8_t;
82 typedef unsigned __int32 uint32_t;
83 typedef unsigned __int64 uint64_t;
84 typedef __int64 int64_t;
85 #else
86 #include <stdint.h>
87 #endif
88
89 /* Compiler-specific intrinsics and fixes: bswap64, ntz */
90 #if _MSC_VER
91 #define inline __inline /* MSVC doesn't recognize "inline" in C */
92 #define restrict __restrict /* MSVC doesn't recognize "restrict" in C */
93 #define __SSE2__ (_M_IX86 || _M_AMD64 || _M_X64) /* Assume SSE2 */
94 #define __SSSE3__ (_M_IX86 || _M_AMD64 || _M_X64) /* Assume SSSE3 */
95 #include <intrin.h>
96 #pragma intrinsic(_byteswap_uint64, _BitScanForward, memcpy)
97 #define bswap64(x) _byteswap_uint64(x)
/* Count trailing zero bits of x (x must be nonzero).
 * Fix: MSVC's _BitScanForward takes an `unsigned long*` index parameter,
 * not `unsigned*`; passing &x relied on the two types happening to have
 * the same layout.  Use a proper unsigned long for the result. */
static inline unsigned ntz(unsigned x) {
    unsigned long index;
    _BitScanForward(&index, (unsigned long)x);
    return (unsigned)index;
}
102 #elif __GNUC__
103 #define inline __inline__ /* No "inline" in GCC ansi C mode */
104 #define restrict __restrict__ /* No "restrict" in GCC ansi C mode */
105 #define bswap64(x) __builtin_bswap64(x) /* Assuming GCC 4.3+ */
106 #define ntz(x) __builtin_ctz((unsigned)(x)) /* Assuming GCC 3.4+ */
107 #else /* Assume some C99 features: stdint.h, inline, restrict */
#define bswap32(x) \
    ((((x)&0xff000000u) >> 24) | (((x)&0x00ff0000u) >> 8) | (((x)&0x0000ff00u) << 8) | \
     (((x)&0x000000ffu) << 24))

/* Byte-reverse a 64-bit word: byte-reverse each 32-bit half with bswap32
 * and exchange the halves.  Endian-independent (pure shifts/masks). */
static inline uint64_t bswap64(uint64_t x) {
    uint32_t low_half = (uint32_t)x;
    uint32_t high_half = (uint32_t)(x >> 32);
    return ((uint64_t)bswap32(low_half) << 32) | (uint64_t)bswap32(high_half);
}
122
123 #if (L_TABLE_SZ <= 9) && (L_TABLE_SZ_IS_ENOUGH) /* < 2^13 byte texts */
/* Count trailing zero bits of x by table lookup.
 * The table is indexed by x/4, so results are only meaningful when x is a
 * multiple of 4 (callers pass block counters advancing in steps of
 * BPI >= 4) and x < 512 (this branch is guarded by L_TABLE_SZ <= 9, i.e.
 * short texts).  tz_table[0] maps x == 0 to 0. */
static inline unsigned ntz(unsigned x) {
    static const unsigned char tz_table[] = {
        0, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2, 6, 2, 3, 2, 4, 2, 3, 2, 5, 2,
        3, 2, 4, 2, 3, 2, 7, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2, 6, 2, 3, 2,
        4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2, 8, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2,
        3, 2, 6, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2, 7, 2, 3, 2, 4, 2, 3, 2,
        5, 2, 3, 2, 4, 2, 3, 2, 6, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2};
    return tz_table[x / 4];
}
133 #else /* From http://supertech.csail.mit.edu/papers/debruijn.pdf */
/* Count trailing zeros of a nonzero word: isolate the lowest set bit,
 * multiply by a de Bruijn constant, and use the top five bits of the
 * product to index a position table.
 * (From http://supertech.csail.mit.edu/papers/debruijn.pdf) */
static inline unsigned ntz(unsigned x) {
    static const unsigned char debruijn_pos[32] = {
        0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8,
        31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9};
    uint32_t lowest_bit = (uint32_t)(x & -x); /* keep only the lowest set bit */
    return debruijn_pos[(lowest_bit * 0x077CB531u) >> 27];
}
140 #endif
141 #endif
142
143 /* ----------------------------------------------------------------------- */
144 /* Define blocks and operations -- Patch if incorrect on your compiler. */
145 /* ----------------------------------------------------------------------- */
146
147 #if __SSE2__ && !KEYMASTER_CLANG_TEST_BUILD
148 #include <xmmintrin.h> /* SSE instructions and _mm_malloc */
149 #include <emmintrin.h> /* SSE2 instructions */
150 typedef __m128i block;
151 #define xor_block(x, y) _mm_xor_si128(x, y)
152 #define zero_block() _mm_setzero_si128()
153 #define unequal_blocks(x, y) (_mm_movemask_epi8(_mm_cmpeq_epi8(x, y)) != 0xffff)
154 #if __SSSE3__ || USE_AES_NI
155 #include <tmmintrin.h> /* SSSE3 instructions */
156 #define swap_if_le(b) \
157 _mm_shuffle_epi8(b, _mm_set_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
158 #else
/* Byte-reverse a 128-bit block (this build targets little-endian x86).
 * Without SSSE3's pshufb, do it in three steps: reverse the four 32-bit
 * words, swap the 16-bit halves inside each word, then swap the bytes
 * inside each 16-bit lane via paired 8-bit shifts. */
static inline block swap_if_le(block b) {
    block a = _mm_shuffle_epi32(b, _MM_SHUFFLE(0, 1, 2, 3));
    a = _mm_shufflehi_epi16(a, _MM_SHUFFLE(2, 3, 0, 1));
    a = _mm_shufflelo_epi16(a, _MM_SHUFFLE(2, 3, 0, 1));
    return _mm_xor_si128(_mm_srli_epi16(a, 8), _mm_slli_epi16(a, 8));
}
165 #endif
/* Offset_0 = 128 bits of (Stretch << bot), where KtopStr holds the
 * 192-bit stretched Ktop as three register-correct 64-bit words.
 * The two loads overlap by one word so a single funnel shift per 64-bit
 * lane builds the result; SSE shifts produce 0 for counts >= 64, so the
 * bot == 0 case (rshift == 64) falls out correctly.  The final shuffle
 * converts back to memory-correct byte order. */
static inline block gen_offset(uint64_t KtopStr[3], unsigned bot) {
    block hi = _mm_load_si128((__m128i*)(KtopStr + 0)); /* hi = B A */
    block lo = _mm_loadu_si128((__m128i*)(KtopStr + 1)); /* lo = C B */
    __m128i lshift = _mm_cvtsi32_si128(bot);
    __m128i rshift = _mm_cvtsi32_si128(64 - bot);
    lo = _mm_xor_si128(_mm_sll_epi64(hi, lshift), _mm_srl_epi64(lo, rshift));
#if __SSSE3__ || USE_AES_NI
    return _mm_shuffle_epi8(lo, _mm_set_epi8(8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7));
#else
    return swap_if_le(_mm_shuffle_epi32(lo, _MM_SHUFFLE(1, 0, 3, 2)));
#endif
}
/* GF(2^128) doubling of a register-correct block: shift left one bit and
 * xor in 135 (x^7+x^2+x+1) if the bit shifted out was set.  srai smears
 * each 32-bit lane's sign bit; the mask keeps 135 for the top lane and 1
 * (the inter-lane carry) for the others, and the lane rotation moves
 * each carry into the lane that receives it. */
static inline block double_block(block bl) {
    const __m128i mask = _mm_set_epi32(135, 1, 1, 1);
    __m128i tmp = _mm_srai_epi32(bl, 31);
    tmp = _mm_and_si128(tmp, mask);
    tmp = _mm_shuffle_epi32(tmp, _MM_SHUFFLE(2, 1, 0, 3));
    bl = _mm_slli_epi32(bl, 1);
    return _mm_xor_si128(bl, tmp);
}
186 #elif __ALTIVEC__
187 #include <altivec.h>
188 typedef vector unsigned block;
189 #define xor_block(x, y) vec_xor(x, y)
190 #define zero_block() vec_splat_u32(0)
191 #define unequal_blocks(x, y) vec_any_ne(x, y)
192 #define swap_if_le(b) (b)
193 #if __PPC64__
gen_offset(uint64_t KtopStr[3],unsigned bot)194 block gen_offset(uint64_t KtopStr[3], unsigned bot) {
195 union {
196 uint64_t u64[2];
197 block bl;
198 } rval;
199 rval.u64[0] = (KtopStr[0] << bot) | (KtopStr[1] >> (64 - bot));
200 rval.u64[1] = (KtopStr[1] << bot) | (KtopStr[2] >> (64 - bot));
201 return rval.bl;
202 }
203 #else
204 /* Special handling: Shifts are mod 32, and no 64-bit types */
/* 32-bit AltiVec gen_offset.  Vector shift counts operate mod 32, so a
 * byte-level vec_sld first absorbs the "whole word" part of the shift
 * (bot >= 32), leaving a residual 0..31-bit shift done with vec_sl and
 * vec_sr across the word boundary. */
block gen_offset(uint64_t KtopStr[3], unsigned bot) {
    const vector unsigned k32 = {32, 32, 32, 32};
    vector unsigned hi = *(vector unsigned*)(KtopStr + 0);
    vector unsigned lo = *(vector unsigned*)(KtopStr + 2);
    vector unsigned bot_vec;
    if (bot < 32) {
        lo = vec_sld(hi, lo, 4); /* lo = words 1..4 of the stretched value */
    } else {
        vector unsigned t = vec_sld(hi, lo, 4);
        lo = vec_sld(hi, lo, 8); /* pre-shift everything left by 32 bits */
        hi = t;
        bot = bot - 32;
    }
    if (bot == 0)
        return hi;
    *(unsigned*)&bot_vec = bot; /* write lane 0, then splat it below */
    vector unsigned lshift = vec_splat(bot_vec, 0);
    vector unsigned rshift = vec_sub(k32, lshift);
    hi = vec_sl(hi, lshift);
    lo = vec_sr(lo, rshift);
    return vec_xor(hi, lo);
}
227 #endif
/* GF(2^128) doubling: arithmetic-shift each byte right 7 to extract its
 * top bit, mask so the block's MSB contributes 135 and every other byte
 * contributes 1 (the inter-byte carry), rotate those carry bytes one
 * position, then shift every byte left 1 and fold the carries back in. */
static inline block double_block(block b) {
    const vector unsigned char mask = {135, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
    const vector unsigned char perm = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0};
    const vector unsigned char shift7 = vec_splat_u8(7);
    const vector unsigned char shift1 = vec_splat_u8(1);
    vector unsigned char c = (vector unsigned char)b;
    vector unsigned char t = vec_sra(c, shift7);
    t = vec_and(t, mask);
    t = vec_perm(t, t, perm);
    c = vec_sl(c, shift1);
    return (block)vec_xor(c, t);
}
240 #elif __ARM_NEON__
241 #include <arm_neon.h>
242 typedef int8x16_t block; /* Yay! Endian-neutral reads! */
243 #define xor_block(x, y) veorq_s8(x, y)
244 #define zero_block() vdupq_n_s8(0)
/* Return nonzero iff the two 128-bit blocks differ: XOR them and OR the
 * two 64-bit lanes of the result together. */
static inline int unequal_blocks(block a, block b) {
    int64x2_t t = veorq_s64((int64x2_t)a, (int64x2_t)b);
    return (vgetq_lane_s64(t, 0) | vgetq_lane_s64(t, 1)) != 0;
}
249 #define swap_if_le(b) (b) /* Using endian-neutral int8x16_t */
250 /* KtopStr is reg correct by 64 bits, return mem correct */
/* Offset_0 = 128 bits of (Stretch << bot).  NEON vshl shifts left for
 * non-negative counts and right for negative ones, so two overlapping
 * 128-bit reads of KtopStr are funnel-shifted with counts bot and
 * bot - 64 and XORed.  Little-endian hosts byte-reverse the result to
 * make it memory correct. */
block gen_offset(uint64_t KtopStr[3], unsigned bot) {
    const union {
        unsigned x;
        unsigned char endian;
    } little = {1};
    const int64x2_t k64 = {-64, -64};
    /* Copy hi and lo into local variables to ensure proper alignment */
    uint64x2_t hi = vld1q_u64(KtopStr + 0); /* hi = A B */
    uint64x2_t lo = vld1q_u64(KtopStr + 1); /* lo = B C */
    int64x2_t ls = vdupq_n_s64(bot);
    int64x2_t rs = vqaddq_s64(k64, ls); /* rs = bot - 64 (saturating add) */
    block rval = (block)veorq_u64(vshlq_u64(hi, ls), vshlq_u64(lo, rs));
    if (little.endian)
        rval = vrev64q_s8(rval);
    return rval;
}
/* GF(2^128) doubling (endian-neutral byte order): extract each byte's
 * top bit, keep 135 for the block MSB and 1 for each inter-byte carry,
 * rotate the carry bytes down one position, then shift left and XOR. */
static inline block double_block(block b) {
    const block mask = {135, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
    block tmp = vshrq_n_s8(b, 7);
    tmp = vandq_s8(tmp, mask);
    tmp = vextq_s8(tmp, tmp, 1); /* Rotate high byte to end */
    b = vshlq_n_s8(b, 1);
    return veorq_s8(tmp, b);
}
275 #else
/* Portable fallback: a 128-bit block as two 64-bit words. */
typedef struct { uint64_t l, r; } block;

/* Bitwise XOR of two blocks. */
static inline block xor_block(block x, block y) {
    block out;
    out.l = x.l ^ y.l;
    out.r = x.r ^ y.r;
    return out;
}

/* The all-zero block. */
static inline block zero_block(void) {
    block z = {0, 0};
    return z;
}
286 #define unequal_blocks(x, y) ((((x).l ^ (y).l) | ((x).r ^ (y).r)) != 0)
/* Convert a block between memory-correct and register-correct form:
 * a no-op on big-endian hosts, a byte swap of each 64-bit half on
 * little-endian hosts (endianness detected at runtime via the union). */
static inline block swap_if_le(block b) {
    const union {
        unsigned x;
        unsigned char endian;
    } little = {1};
    if (little.endian) {
        block r;
        r.l = bswap64(b.l);
        r.r = bswap64(b.r);
        return r;
    } else
        return b;
}
300
/* KtopStr is reg correct by 64 bits, return mem correct */
/* Offset_0 = the top 128 bits of (Stretch << bot), where KtopStr holds
 * the 192-bit Stretch as three 64-bit words and bot is 0..63.  bot == 0
 * is special-cased because a shift by 64 (the full type width) would be
 * undefined behavior in C. */
block gen_offset(uint64_t KtopStr[3], unsigned bot) {
    block rval;
    if (bot != 0) {
        rval.l = (KtopStr[0] << bot) | (KtopStr[1] >> (64 - bot));
        rval.r = (KtopStr[1] << bot) | (KtopStr[2] >> (64 - bot));
    } else {
        rval.l = KtopStr[0];
        rval.r = KtopStr[1];
    }
    return swap_if_le(rval);
}
313
314 #if __GNUC__ && __arm__
/* GF(2^128) doubling on 32-bit ARM via inline asm: the block's two
 * 64-bit words live in register pairs (%0/%H0 = l, %1/%H1 = r).  The
 * adds/adcs chain doubles the 128-bit value from the least significant
 * word up, propagating carries; if the final carry (the bit shifted out
 * of the top) is set, the conditional eor folds the polynomial constant
 * 135 into the low word.  Operates on register-correct values, like the
 * portable version below. */
static inline block double_block(block b) {
    __asm__("adds %1,%1,%1\n\t"
            "adcs %H1,%H1,%H1\n\t"
            "adcs %0,%0,%0\n\t"
            "adcs %H0,%H0,%H0\n\t"
            "it cs\n\t"
            "eorcs %1,%1,#135"
            : "+r"(b.l), "+r"(b.r)
            :
            : "cc");
    return b;
}
327 #else
/* GF(2^128) doubling of a register-correct block: shift the 128-bit
 * value left one bit; if the bit shifted out was set, fold in the
 * polynomial constant 135 (x^7 + x^2 + x + 1). */
static inline block double_block(block b) {
    uint64_t msb_was_set = b.l >> 63; /* bit leaving the 128-bit value */
    uint64_t crossing_bit = b.r >> 63; /* bit moving into the high word */
    b.l = (b.l << 1) | crossing_bit;
    b.r = (b.r << 1) ^ (msb_was_set ? 135 : 0);
    return b;
}
334 #endif
335
336 #endif
337
338 /* ----------------------------------------------------------------------- */
339 /* AES - Code uses OpenSSL API. Other implementations get mapped to it. */
340 /* ----------------------------------------------------------------------- */
341
342 /*---------------*/
343 #if USE_OPENSSL_AES
344 /*---------------*/
345
346 #include <openssl/aes.h> /* http://openssl.org/ */
347
348 /* How to ECB encrypt an array of blocks, in place */
/* ECB-encrypt nblks 16-byte blocks in place.  Each block is independent,
 * so iteration order is irrelevant. */
static inline void AES_ecb_encrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    unsigned i;
    for (i = 0; i < nblks; ++i)
        AES_encrypt((unsigned char*)(blks + i), (unsigned char*)(blks + i), key);
}
355
/* ECB-decrypt nblks 16-byte blocks in place.  Each block is independent,
 * so iteration order is irrelevant. */
static inline void AES_ecb_decrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    unsigned i;
    for (i = 0; i < nblks; ++i)
        AES_decrypt((unsigned char*)(blks + i), (unsigned char*)(blks + i), key);
}
362
363 #define BPI 4 /* Number of blocks in buffer per ECB call */
364
365 /*-------------------*/
366 #elif USE_REFERENCE_AES
367 /*-------------------*/
368
369 #include "rijndael-alg-fst.h" /* Barreto's Public-Domain Code */
370 #if (OCB_KEY_LEN == 0)
371 typedef struct {
372 uint32_t rd_key[60];
373 int rounds;
374 } AES_KEY;
375 #define ROUNDS(ctx) ((ctx)->rounds)
376 #define AES_set_encrypt_key(x, y, z) \
377 do { \
378 rijndaelKeySetupEnc((z)->rd_key, x, y); \
379 (z)->rounds = y / 32 + 6; \
380 } while (0)
381 #define AES_set_decrypt_key(x, y, z) \
382 do { \
383 rijndaelKeySetupDec((z)->rd_key, x, y); \
384 (z)->rounds = y / 32 + 6; \
385 } while (0)
386 #else
387 typedef struct { uint32_t rd_key[OCB_KEY_LEN + 28]; } AES_KEY;
388 #define ROUNDS(ctx) (6 + OCB_KEY_LEN / 4)
389 #define AES_set_encrypt_key(x, y, z) rijndaelKeySetupEnc((z)->rd_key, x, y)
390 #define AES_set_decrypt_key(x, y, z) rijndaelKeySetupDec((z)->rd_key, x, y)
391 #endif
392 #define AES_encrypt(x, y, z) rijndaelEncrypt((z)->rd_key, ROUNDS(z), x, y)
393 #define AES_decrypt(x, y, z) rijndaelDecrypt((z)->rd_key, ROUNDS(z), x, y)
394
/* ECB-encrypt nblks 16-byte blocks in place with the reference cipher.
 * Blocks are independent, so any iteration order gives the same result. */
static void AES_ecb_encrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    unsigned i;
    for (i = 0; i < nblks; ++i)
        AES_encrypt((unsigned char*)(blks + i), (unsigned char*)(blks + i), key);
}
401
/* ECB-decrypt nblks 16-byte blocks in place with the reference cipher.
 * Blocks are independent, so any iteration order gives the same result. */
void AES_ecb_decrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    unsigned i;
    for (i = 0; i < nblks; ++i)
        AES_decrypt((unsigned char*)(blks + i), (unsigned char*)(blks + i), key);
}
408
409 #define BPI 4 /* Number of blocks in buffer per ECB call */
410
411 /*----------*/
412 #elif USE_AES_NI
413 /*----------*/
414
415 #include <wmmintrin.h>
416
417 #if (OCB_KEY_LEN == 0)
418 typedef struct {
419 __m128i rd_key[15];
420 int rounds;
421 } AES_KEY;
422 #define ROUNDS(ctx) ((ctx)->rounds)
423 #else
424 typedef struct { __m128i rd_key[7 + OCB_KEY_LEN / 4]; } AES_KEY;
425 #define ROUNDS(ctx) (6 + OCB_KEY_LEN / 4)
426 #endif
427
428 #define EXPAND_ASSIST(v1, v2, v3, v4, shuff_const, aes_const) \
429 v2 = _mm_aeskeygenassist_si128(v4, aes_const); \
430 v3 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(v3), _mm_castsi128_ps(v1), 16)); \
431 v1 = _mm_xor_si128(v1, v3); \
432 v3 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(v3), _mm_castsi128_ps(v1), 140)); \
433 v1 = _mm_xor_si128(v1, v3); \
434 v2 = _mm_shuffle_epi32(v2, shuff_const); \
435 v1 = _mm_xor_si128(v1, v2)
436
437 #define EXPAND192_STEP(idx, aes_const) \
438 EXPAND_ASSIST(x0, x1, x2, x3, 85, aes_const); \
439 x3 = _mm_xor_si128(x3, _mm_slli_si128(x3, 4)); \
440 x3 = _mm_xor_si128(x3, _mm_shuffle_epi32(x0, 255)); \
441 kp[idx] = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(tmp), _mm_castsi128_ps(x0), 68)); \
442 kp[idx + 1] = \
443 _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(x0), _mm_castsi128_ps(x3), 78)); \
444 EXPAND_ASSIST(x0, x1, x2, x3, 85, (aes_const * 2)); \
445 x3 = _mm_xor_si128(x3, _mm_slli_si128(x3, 4)); \
446 x3 = _mm_xor_si128(x3, _mm_shuffle_epi32(x0, 255)); \
447 kp[idx + 2] = x0; \
448 tmp = x3
449
/* Expand a 16-byte AES-128 key into the 11 round keys at 'key'.
 * Each EXPAND_ASSIST derives the next round key from the previous one
 * using AESKEYGENASSIST with round constants 1,2,4,...,128,27,54. */
static void AES_128_Key_Expansion(const unsigned char* userkey, void* key) {
    __m128i x0, x1, x2;
    __m128i* kp = (__m128i*)key;
    kp[0] = x0 = _mm_loadu_si128((__m128i*)userkey);
    x2 = _mm_setzero_si128();
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 1);
    kp[1] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 2);
    kp[2] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 4);
    kp[3] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 8);
    kp[4] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 16);
    kp[5] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 32);
    kp[6] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 64);
    kp[7] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 128);
    kp[8] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 27);
    kp[9] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 54);
    kp[10] = x0;
}
476
/* Expand a 24-byte AES-192 key into the 13 round keys at 'key'.
 * The 192-bit schedule words straddle 128-bit slots, so each
 * EXPAND192_STEP emits three round-key slots per two expansion rounds
 * (indices 1..3, 4..6, 7..9, 10..12). */
static void AES_192_Key_Expansion(const unsigned char* userkey, void* key) {
    __m128i x0, x1, x2, x3, tmp, *kp = (__m128i*)key;
    kp[0] = x0 = _mm_loadu_si128((__m128i*)userkey);
    tmp = x3 = _mm_loadu_si128((__m128i*)(userkey + 16));
    x2 = _mm_setzero_si128();
    EXPAND192_STEP(1, 1);
    EXPAND192_STEP(4, 4);
    EXPAND192_STEP(7, 16);
    EXPAND192_STEP(10, 64);
}
487
/* Expand a 32-byte AES-256 key into the 15 round keys at 'key'.
 * The 256-bit schedule alternates two kinds of steps: shuffle constant
 * 255 (rotated word, with round constant) and 170 (unrotated word, round
 * constant unused), matching the AES-256 key-schedule definition. */
static void AES_256_Key_Expansion(const unsigned char* userkey, void* key) {
    __m128i x0, x1, x2, x3, *kp = (__m128i*)key;
    kp[0] = x0 = _mm_loadu_si128((__m128i*)userkey);
    kp[1] = x3 = _mm_loadu_si128((__m128i*)(userkey + 16));
    x2 = _mm_setzero_si128();
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 1);
    kp[2] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 1);
    kp[3] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 2);
    kp[4] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 2);
    kp[5] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 4);
    kp[6] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 4);
    kp[7] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 8);
    kp[8] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 8);
    kp[9] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 16);
    kp[10] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 16);
    kp[11] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 32);
    kp[12] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 32);
    kp[13] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 64);
    kp[14] = x0;
}
520
/* Expand 'userKey' into an AES-NI encryption key schedule.
 * bits must be 128, 192 or 256.  Returns 0 on success, -2 for an
 * unsupported key size (mirroring OpenSSL's convention); the original
 * silently returned 0 and left 'key' uninitialized in that case. */
static int AES_set_encrypt_key(const unsigned char* userKey, const int bits, AES_KEY* key) {
    if (bits == 128) {
        AES_128_Key_Expansion(userKey, key);
    } else if (bits == 192) {
        AES_192_Key_Expansion(userKey, key);
    } else if (bits == 256) {
        AES_256_Key_Expansion(userKey, key);
    } else {
        return -2; /* unsupported key length */
    }
#if (OCB_KEY_LEN == 0)
    key->rounds = 6 + bits / 32;
#endif
    return 0;
}
534
/* Derive a decryption schedule from an encryption schedule for use with
 * AESDEC (equivalent inverse cipher): reverse the round-key order, and
 * run every key except the outermost two through AESIMC. */
static void AES_set_decrypt_key_fast(AES_KEY* dkey, const AES_KEY* ekey) {
    int j = 0;
    int i = ROUNDS(ekey);
#if (OCB_KEY_LEN == 0)
    dkey->rounds = i;
#endif
    dkey->rd_key[i--] = ekey->rd_key[j++]; /* last <- first, no InvMixColumns */
    while (i)
        dkey->rd_key[i--] = _mm_aesimc_si128(ekey->rd_key[j++]);
    dkey->rd_key[i] = ekey->rd_key[j]; /* first <- last, no InvMixColumns */
}
546
/* Build a decryption schedule directly from the raw key: expand an
 * encryption schedule into a temporary, then invert it.  Returns 0.
 * NOTE(review): temp_key holds key material and is not zeroized before
 * this function returns -- consider wiping it. */
static int AES_set_decrypt_key(const unsigned char* userKey, const int bits, AES_KEY* key) {
    AES_KEY temp_key;
    AES_set_encrypt_key(userKey, bits, &temp_key);
    AES_set_decrypt_key_fast(key, &temp_key);
    return 0;
}
553
/* Encrypt one 16-byte block with AES-NI: whitening XOR, rnds-1 AESENC
 * rounds, one AESENCLAST.  'in' and 'out' must be 16-byte aligned
 * (_mm_load_si128/_mm_store_si128 are aligned accesses). */
static inline void AES_encrypt(const unsigned char* in, unsigned char* out, const AES_KEY* key) {
    int j, rnds = ROUNDS(key);
    const __m128i* sched = ((__m128i*)(key->rd_key));
    __m128i tmp = _mm_load_si128((__m128i*)in);
    tmp = _mm_xor_si128(tmp, sched[0]);
    for (j = 1; j < rnds; j++)
        tmp = _mm_aesenc_si128(tmp, sched[j]);
    tmp = _mm_aesenclast_si128(tmp, sched[j]);
    _mm_store_si128((__m128i*)out, tmp);
}
564
/* Decrypt one 16-byte block with AES-NI: whitening XOR, rnds-1 AESDEC
 * rounds, one AESDECLAST.  Requires a schedule prepared for AESDEC (see
 * AES_set_decrypt_key_fast); 'in'/'out' must be 16-byte aligned. */
static inline void AES_decrypt(const unsigned char* in, unsigned char* out, const AES_KEY* key) {
    int j, rnds = ROUNDS(key);
    const __m128i* sched = ((__m128i*)(key->rd_key));
    __m128i tmp = _mm_load_si128((__m128i*)in);
    tmp = _mm_xor_si128(tmp, sched[0]);
    for (j = 1; j < rnds; j++)
        tmp = _mm_aesdec_si128(tmp, sched[j]);
    tmp = _mm_aesdeclast_si128(tmp, sched[j]);
    _mm_store_si128((__m128i*)out, tmp);
}
575
/* ECB-encrypt nblks blocks in place, applying each AES round to every
 * block before advancing to the next round so the independent AESENC
 * instructions can overlap in the pipeline. */
static inline void AES_ecb_encrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    unsigned i, j, rnds = ROUNDS(key);
    const __m128i* sched = ((__m128i*)(key->rd_key));
    for (i = 0; i < nblks; ++i)
        blks[i] = _mm_xor_si128(blks[i], sched[0]);
    for (j = 1; j < rnds; ++j)
        for (i = 0; i < nblks; ++i)
            blks[i] = _mm_aesenc_si128(blks[i], sched[j]);
    for (i = 0; i < nblks; ++i)
        blks[i] = _mm_aesenclast_si128(blks[i], sched[j]);
}
587
/* ECB-decrypt nblks blocks in place, round-major for AESDEC pipelining
 * (mirror of AES_ecb_encrypt_blks above). */
static inline void AES_ecb_decrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    unsigned i, j, rnds = ROUNDS(key);
    const __m128i* sched = ((__m128i*)(key->rd_key));
    for (i = 0; i < nblks; ++i)
        blks[i] = _mm_xor_si128(blks[i], sched[0]);
    for (j = 1; j < rnds; ++j)
        for (i = 0; i < nblks; ++i)
            blks[i] = _mm_aesdec_si128(blks[i], sched[j]);
    for (i = 0; i < nblks; ++i)
        blks[i] = _mm_aesdeclast_si128(blks[i], sched[j]);
}
599
600 #define BPI 8 /* Number of blocks in buffer per ECB call */
601 /* Set to 4 for Westmere, 8 for Sandy Bridge */
602
603 #endif
604
605 /* ----------------------------------------------------------------------- */
606 /* Define OCB context structure. */
607 /* ----------------------------------------------------------------------- */
608
609 /*------------------------------------------------------------------------
610 / Each item in the OCB context is stored either "memory correct" or
611 / "register correct". On big-endian machines, this is identical. On
612 / little-endian machines, one must choose whether the byte-string
613 / is in the correct order when it resides in memory or in registers.
614 / It must be register correct whenever it is to be manipulated
615 / arithmetically, but must be memory correct whenever it interacts
616 / with the plaintext or ciphertext.
617 /------------------------------------------------------------------------- */
618
struct _ae_ctx {
    block offset;                 /* Memory correct */
    block checksum;               /* Memory correct */
    block Lstar;                  /* Memory correct */
    block Ldollar;                /* Memory correct */
    block L[L_TABLE_SZ];          /* Memory correct */
    block ad_checksum;            /* Memory correct */
    block ad_offset;              /* Memory correct */
    block cached_Top;             /* Memory correct */
    uint64_t KtopStr[3];          /* Register correct, each item */
    uint32_t ad_blocks_processed; /* # of full AD blocks consumed so far */
    uint32_t blocks_processed;    /* # of full message blocks consumed
                                     (updated outside this section) */
    AES_KEY decrypt_key;          /* expanded AES decryption schedule */
    AES_KEY encrypt_key;          /* expanded AES encryption schedule */
#if (OCB_TAG_LEN == 0)
    unsigned tag_len;             /* tag length in bytes, set by ae_init */
#endif
};
637
638 /* ----------------------------------------------------------------------- */
639 /* L table lookup (or on-the-fly generation) */
640 /* ----------------------------------------------------------------------- */
641
642 #if L_TABLE_SZ_IS_ENOUGH
643 #define getL(_ctx, _tz) ((_ctx)->L[_tz])
644 #else
/* Return L_{tz}: a table lookup when tz was precomputed, otherwise
 * derived from the last table entry by repeated GF(2^128) doubling.
 * Table entries are memory correct; doubling happens register correct. */
static block getL(const ae_ctx* ctx, unsigned tz) {
    if (tz < L_TABLE_SZ)
        return ctx->L[tz];
    else {
        unsigned i;
        /* Bring L[MAX] into registers, make it register correct */
        block rval = swap_if_le(ctx->L[L_TABLE_SZ - 1]);
        rval = double_block(rval);
        for (i = L_TABLE_SZ; i < tz; i++)
            rval = double_block(rval);
        return swap_if_le(rval); /* To memory correct */
    }
}
658 #endif
659
660 /* ----------------------------------------------------------------------- */
661 /* Public functions */
662 /* ----------------------------------------------------------------------- */
663
664 /* 32-bit SSE2 and Altivec systems need to be forced to allocate memory
665 on 16-byte alignments. (I believe all major 64-bit systems do already.) */
666
/* Allocate an ae_ctx, forcing 16-byte alignment on 32-bit SSE2/AltiVec
 * targets whose malloc may not align that strongly.  'misc' is accepted
 * for API compatibility and ignored.  Returns NULL on failure. */
ae_ctx* ae_allocate(void* misc) {
    void* p;
    (void)misc; /* misc unused in this implementation */
#if (__SSE2__ && !_M_X64 && !_M_AMD64 && !__amd64__)
    p = _mm_malloc(sizeof(ae_ctx), 16);
#elif(__ALTIVEC__ && !__PPC64__)
    if (posix_memalign(&p, 16, sizeof(ae_ctx)) != 0)
        p = NULL;
#else
    p = malloc(sizeof(ae_ctx));
#endif
    return (ae_ctx*)p;
}
680
/* Release a context from ae_allocate, using the deallocator that matches
 * the allocator chosen there (_mm_malloc pairs with _mm_free). */
void ae_free(ae_ctx* ctx) {
#if (__SSE2__ && !_M_X64 && !_M_AMD64 && !__amd64__)
    _mm_free(ctx);
#else
    free(ctx);
#endif
}
688
689 /* ----------------------------------------------------------------------- */
690
int ae_clear(ae_ctx* ctx) /* Zero ae_ctx and undo initialization */
{
    /* NOTE(review): a plain memset of secret material can be elided by
     * optimizing compilers in some situations; memset_s/explicit_bzero
     * would guarantee the wipe -- confirm against target toolchains. */
    memset(ctx, 0, sizeof(ae_ctx));
    return AE_SUCCESS;
}
696
/* Size of an ae_ctx, for callers that allocate contexts themselves. */
int ae_ctx_sizeof(void) {
    return (int)sizeof(ae_ctx);
}
700
701 /* ----------------------------------------------------------------------- */
702
/* Initialize 'ctx' for OCB with the given AES key.
 * key/key_len: raw AES key (key_len is overridden by OCB_KEY_LEN when
 *              that is fixed at compile time).
 * nonce_len:   must be 12; anything else returns AE_NOT_SUPPORTED.
 * tag_len:     stored only when OCB_TAG_LEN == 0 (runtime tag lengths).
 * Precomputes Lstar, Ldollar and the L[] table by successive GF(2^128)
 * doublings of E_K(0).  Returns AE_SUCCESS on success. */
int ae_init(ae_ctx* ctx, const void* key, int key_len, int nonce_len, int tag_len) {
    unsigned i;
    block tmp_blk;

    if (nonce_len != 12)
        return AE_NOT_SUPPORTED;

/* Initialize encryption & decryption keys */
#if (OCB_KEY_LEN > 0)
    key_len = OCB_KEY_LEN;
#endif
    AES_set_encrypt_key((unsigned char*)key, key_len * 8, &ctx->encrypt_key);
#if USE_AES_NI
    AES_set_decrypt_key_fast(&ctx->decrypt_key, &ctx->encrypt_key);
#else
    AES_set_decrypt_key((unsigned char*)key, (int)(key_len * 8), &ctx->decrypt_key);
#endif

    /* Zero things that need zeroing */
    ctx->cached_Top = ctx->ad_checksum = zero_block();
    ctx->ad_blocks_processed = 0;

    /* Compute key-dependent values */
    /* Lstar = E_K(0^128); cached_Top is still all-zero at this point */
    AES_encrypt((unsigned char*)&ctx->cached_Top, (unsigned char*)&ctx->Lstar, &ctx->encrypt_key);
    tmp_blk = swap_if_le(ctx->Lstar);
    tmp_blk = double_block(tmp_blk); /* Ldollar = double(Lstar) */
    ctx->Ldollar = swap_if_le(tmp_blk);
    tmp_blk = double_block(tmp_blk); /* L[0] = double(Ldollar), L[i] = double(L[i-1]) */
    ctx->L[0] = swap_if_le(tmp_blk);
    for (i = 1; i < L_TABLE_SZ; i++) {
        tmp_blk = double_block(tmp_blk);
        ctx->L[i] = swap_if_le(tmp_blk);
    }

#if (OCB_TAG_LEN == 0)
    ctx->tag_len = tag_len;
#else
    (void)tag_len; /* Suppress var not used error */
#endif

    return AE_SUCCESS;
}
745
746 /* ----------------------------------------------------------------------- */
747
gen_offset_from_nonce(ae_ctx * ctx,const void * nonce)748 static block gen_offset_from_nonce(ae_ctx* ctx, const void* nonce) {
749 const union {
750 unsigned x;
751 unsigned char endian;
752 } little = {1};
753 union {
754 uint32_t u32[4];
755 uint8_t u8[16];
756 block bl;
757 } tmp;
758 unsigned idx;
759
760 /* Replace cached nonce Top if needed */
761 #if (OCB_TAG_LEN > 0)
762 if (little.endian)
763 tmp.u32[0] = 0x01000000 + ((OCB_TAG_LEN * 8 % 128) << 1);
764 else
765 tmp.u32[0] = 0x00000001 + ((OCB_TAG_LEN * 8 % 128) << 25);
766 #else
767 if (little.endian)
768 tmp.u32[0] = 0x01000000 + ((ctx->tag_len * 8 % 128) << 1);
769 else
770 tmp.u32[0] = 0x00000001 + ((ctx->tag_len * 8 % 128) << 25);
771 #endif
772 tmp.u32[1] = ((uint32_t*)nonce)[0];
773 tmp.u32[2] = ((uint32_t*)nonce)[1];
774 tmp.u32[3] = ((uint32_t*)nonce)[2];
775 idx = (unsigned)(tmp.u8[15] & 0x3f); /* Get low 6 bits of nonce */
776 tmp.u8[15] = tmp.u8[15] & 0xc0; /* Zero low 6 bits of nonce */
777 if (unequal_blocks(tmp.bl, ctx->cached_Top)) { /* Cached? */
778 ctx->cached_Top = tmp.bl; /* Update cache, KtopStr */
779 AES_encrypt(tmp.u8, (unsigned char*)&ctx->KtopStr, &ctx->encrypt_key);
780 if (little.endian) { /* Make Register Correct */
781 ctx->KtopStr[0] = bswap64(ctx->KtopStr[0]);
782 ctx->KtopStr[1] = bswap64(ctx->KtopStr[1]);
783 }
784 ctx->KtopStr[2] = ctx->KtopStr[0] ^ (ctx->KtopStr[0] << 8) ^ (ctx->KtopStr[1] >> 56);
785 }
786 return gen_offset(ctx->KtopStr, idx);
787 }
788
/* Absorb associated data (AD) into ctx->ad_checksum using OCB3's PMAC-style
 * hash.  AD may be supplied incrementally across calls: each non-final call
 * must pass a length that is a multiple of BPI*16 (see the usage notes at the
 * top of this file); the running state lives in ctx->ad_offset,
 * ctx->ad_checksum and ctx->ad_blocks_processed.  When 'final' is non-zero
 * the trailing partial chunk is absorbed as well, a short last block being
 * padded as  X || 0x80 || 0...0  per the OCB3 specification. */
static void process_ad(ae_ctx* ctx, const void* ad, int ad_len, int final) {
    union {
        uint32_t u32[4];
        uint8_t u8[16];
        block bl;
    } tmp;
    block ad_offset, ad_checksum;
    const block* adp = (block*)ad;
    unsigned i, k, tz, remaining;

    /* Load incremental state cached in the context */
    ad_offset = ctx->ad_offset;
    ad_checksum = ctx->ad_checksum;
    i = ad_len / (BPI * 16); /* Number of full BPI-block chunks to absorb */
    if (i) {
        unsigned ad_block_num = ctx->ad_blocks_processed;
        do {
            block ta[BPI], oa[BPI];
            ad_block_num += BPI;
            /* ntz of the chunk's last (absolute) block index selects the
               L_{ntz} value that closes this chunk's offset chain. */
            tz = ntz(ad_block_num);
            /* Unrolled OCB offset schedule: each Offset_i is the previous
               offset XOR L_{ntz(i)}.  The XOR chain below is algebraically
               equivalent but shortens the dependency path. */
            oa[0] = xor_block(ad_offset, ctx->L[0]);
            ta[0] = xor_block(oa[0], adp[0]);
            oa[1] = xor_block(oa[0], ctx->L[1]);
            ta[1] = xor_block(oa[1], adp[1]);
            oa[2] = xor_block(ad_offset, ctx->L[1]);
            ta[2] = xor_block(oa[2], adp[2]);
#if BPI == 4
            ad_offset = xor_block(oa[2], getL(ctx, tz));
            ta[3] = xor_block(ad_offset, adp[3]);
#elif BPI == 8
            oa[3] = xor_block(oa[2], ctx->L[2]);
            ta[3] = xor_block(oa[3], adp[3]);
            oa[4] = xor_block(oa[1], ctx->L[2]);
            ta[4] = xor_block(oa[4], adp[4]);
            oa[5] = xor_block(oa[0], ctx->L[2]);
            ta[5] = xor_block(oa[5], adp[5]);
            oa[6] = xor_block(ad_offset, ctx->L[2]);
            ta[6] = xor_block(oa[6], adp[6]);
            ad_offset = xor_block(oa[6], getL(ctx, tz));
            ta[7] = xor_block(ad_offset, adp[7]);
#endif
            /* AD hash = XOR of E_K(A_i xor Offset_i) over all blocks */
            AES_ecb_encrypt_blks(ta, BPI, &ctx->encrypt_key);
            ad_checksum = xor_block(ad_checksum, ta[0]);
            ad_checksum = xor_block(ad_checksum, ta[1]);
            ad_checksum = xor_block(ad_checksum, ta[2]);
            ad_checksum = xor_block(ad_checksum, ta[3]);
#if (BPI == 8)
            ad_checksum = xor_block(ad_checksum, ta[4]);
            ad_checksum = xor_block(ad_checksum, ta[5]);
            ad_checksum = xor_block(ad_checksum, ta[6]);
            ad_checksum = xor_block(ad_checksum, ta[7]);
#endif
            adp += BPI;
        } while (--i);
        /* Persist incremental state for a later continuation call */
        ctx->ad_blocks_processed = ad_block_num;
        ctx->ad_offset = ad_offset;
        ctx->ad_checksum = ad_checksum;
    }

    if (final) {
        block ta[BPI];

        /* Process remaining associated data, compute its tag contribution */
        remaining = ((unsigned)ad_len) % (BPI * 16);
        if (remaining) {
            k = 0; /* Number of blocks queued in ta[] for ECB encryption */
#if (BPI == 8)
            if (remaining >= 64) {
                tmp.bl = xor_block(ad_offset, ctx->L[0]);
                ta[0] = xor_block(tmp.bl, adp[0]);
                tmp.bl = xor_block(tmp.bl, ctx->L[1]);
                ta[1] = xor_block(tmp.bl, adp[1]);
                ad_offset = xor_block(ad_offset, ctx->L[1]);
                ta[2] = xor_block(ad_offset, adp[2]);
                ad_offset = xor_block(ad_offset, ctx->L[2]);
                ta[3] = xor_block(ad_offset, adp[3]);
                remaining -= 64;
                k = 4;
            }
#endif
            if (remaining >= 32) {
                ad_offset = xor_block(ad_offset, ctx->L[0]);
                ta[k] = xor_block(ad_offset, adp[k]);
                ad_offset = xor_block(ad_offset, getL(ctx, ntz(k + 2)));
                ta[k + 1] = xor_block(ad_offset, adp[k + 1]);
                remaining -= 32;
                k += 2;
            }
            if (remaining >= 16) {
                ad_offset = xor_block(ad_offset, ctx->L[0]);
                ta[k] = xor_block(ad_offset, adp[k]);
                remaining = remaining - 16;
                ++k;
            }
            if (remaining) {
                /* Short final block: offset uses L_* and the block is
                   padded with a single 0x80 byte then zeros. */
                ad_offset = xor_block(ad_offset, ctx->Lstar);
                tmp.bl = zero_block();
                memcpy(tmp.u8, adp + k, remaining);
                tmp.u8[remaining] = (unsigned char)0x80u;
                ta[k] = xor_block(ad_offset, tmp.bl);
                ++k;
            }
            AES_ecb_encrypt_blks(ta, k, &ctx->encrypt_key);
            /* Fold the k freshly encrypted blocks into the AD checksum;
               each case intentionally falls through to the next. */
            switch (k) {
#if (BPI == 8)
                case 8:
                    ad_checksum = xor_block(ad_checksum, ta[7]);
                    /* fallthrough */
                case 7:
                    ad_checksum = xor_block(ad_checksum, ta[6]);
                    /* fallthrough */
                case 6:
                    ad_checksum = xor_block(ad_checksum, ta[5]);
                    /* fallthrough */
                case 5:
                    ad_checksum = xor_block(ad_checksum, ta[4]);
                    /* fallthrough */
#endif
                case 4:
                    ad_checksum = xor_block(ad_checksum, ta[3]);
                    /* fallthrough */
                case 3:
                    ad_checksum = xor_block(ad_checksum, ta[2]);
                    /* fallthrough */
                case 2:
                    ad_checksum = xor_block(ad_checksum, ta[1]);
                    /* fallthrough */
                case 1:
                    ad_checksum = xor_block(ad_checksum, ta[0]);
            }
            ctx->ad_checksum = ad_checksum;
        }
    }
}
915
916 /* ----------------------------------------------------------------------- */
917
/* OCB3 encrypt-and-authenticate.
 *
 *  nonce  - non-NULL begins a new message (resets per-message state);
 *           NULL continues a message begun by an earlier AE_PENDING call.
 *  pt     - pt_len bytes of plaintext (must be a multiple of BPI*16 on
 *           non-final calls; see usage notes at the top of the file).
 *  ad     - ad_len bytes of associated data.  NOTE(review): a negative
 *           ad_len on a new-nonce call skips resetting ctx->ad_checksum,
 *           apparently to reuse the previous message's AD hash (the test
 *           harness below relies on this) -- confirm against ae.h.
 *  ct     - receives the ciphertext (may alias pt).
 *  tag    - if non-NULL receives the tag; if NULL the tag is appended to ct.
 *  final  - non-zero finalizes the message and emits the tag.
 *
 * Returns the number of bytes written to ct (including any appended tag). */
int ae_encrypt(ae_ctx* ctx, const void* nonce, const void* pt, int pt_len, const void* ad,
               int ad_len, void* ct, void* tag, int final) {
    union {
        uint32_t u32[4];
        uint8_t u8[16];
        block bl;
    } tmp;
    block offset, checksum;
    unsigned i, k;
    block* ctp = (block*)ct;
    const block* ptp = (block*)pt;

    /* Non-null nonce means start of new message, init per-message values */
    if (nonce) {
        ctx->offset = gen_offset_from_nonce(ctx, nonce);
        ctx->ad_offset = ctx->checksum = zero_block();
        ctx->ad_blocks_processed = ctx->blocks_processed = 0;
        if (ad_len >= 0)
            ctx->ad_checksum = zero_block();
    }

    /* Process associated data */
    if (ad_len > 0)
        process_ad(ctx, ad, ad_len, final);

    /* Encrypt plaintext data BPI blocks at a time */
    offset = ctx->offset;
    checksum = ctx->checksum;
    i = pt_len / (BPI * 16); /* Number of full BPI-block chunks */
    if (i) {
        block oa[BPI];
        unsigned block_num = ctx->blocks_processed;
        /* oa[BPI-1] seeds the chain; each iteration reads the previous
           iteration's final offset from it. */
        oa[BPI - 1] = offset;
        do {
            block ta[BPI];
            block_num += BPI;
            /* Unrolled OCB offset schedule: Offset_i = Offset_{i-1} xor
               L_{ntz(i)}; ta[] holds PT xor Offset, checksum accumulates
               the raw plaintext blocks. */
            oa[0] = xor_block(oa[BPI - 1], ctx->L[0]);
            ta[0] = xor_block(oa[0], ptp[0]);
            checksum = xor_block(checksum, ptp[0]);
            oa[1] = xor_block(oa[0], ctx->L[1]);
            ta[1] = xor_block(oa[1], ptp[1]);
            checksum = xor_block(checksum, ptp[1]);
            oa[2] = xor_block(oa[1], ctx->L[0]);
            ta[2] = xor_block(oa[2], ptp[2]);
            checksum = xor_block(checksum, ptp[2]);
#if BPI == 4
            oa[3] = xor_block(oa[2], getL(ctx, ntz(block_num)));
            ta[3] = xor_block(oa[3], ptp[3]);
            checksum = xor_block(checksum, ptp[3]);
#elif BPI == 8
            oa[3] = xor_block(oa[2], ctx->L[2]);
            ta[3] = xor_block(oa[3], ptp[3]);
            checksum = xor_block(checksum, ptp[3]);
            oa[4] = xor_block(oa[1], ctx->L[2]);
            ta[4] = xor_block(oa[4], ptp[4]);
            checksum = xor_block(checksum, ptp[4]);
            oa[5] = xor_block(oa[0], ctx->L[2]);
            ta[5] = xor_block(oa[5], ptp[5]);
            checksum = xor_block(checksum, ptp[5]);
            /* oa[7] here still holds the PREVIOUS iteration's final offset */
            oa[6] = xor_block(oa[7], ctx->L[2]);
            ta[6] = xor_block(oa[6], ptp[6]);
            checksum = xor_block(checksum, ptp[6]);
            oa[7] = xor_block(oa[6], getL(ctx, ntz(block_num)));
            ta[7] = xor_block(oa[7], ptp[7]);
            checksum = xor_block(checksum, ptp[7]);
#endif
            AES_ecb_encrypt_blks(ta, BPI, &ctx->encrypt_key);
            /* C_i = E_K(P_i xor Offset_i) xor Offset_i */
            ctp[0] = xor_block(ta[0], oa[0]);
            ctp[1] = xor_block(ta[1], oa[1]);
            ctp[2] = xor_block(ta[2], oa[2]);
            ctp[3] = xor_block(ta[3], oa[3]);
#if (BPI == 8)
            ctp[4] = xor_block(ta[4], oa[4]);
            ctp[5] = xor_block(ta[5], oa[5]);
            ctp[6] = xor_block(ta[6], oa[6]);
            ctp[7] = xor_block(ta[7], oa[7]);
#endif
            ptp += BPI;
            ctp += BPI;
        } while (--i);
        /* Persist incremental state for a later continuation call */
        ctx->offset = offset = oa[BPI - 1];
        ctx->blocks_processed = block_num;
        ctx->checksum = checksum;
    }

    if (final) {
        block ta[BPI + 1], oa[BPI]; /* +1: slot for the tag block */

        /* Process remaining plaintext and compute its tag contribution */
        unsigned remaining = ((unsigned)pt_len) % (BPI * 16);
        k = 0; /* How many blocks in ta[] need ECBing */
        if (remaining) {
#if (BPI == 8)
            if (remaining >= 64) {
                oa[0] = xor_block(offset, ctx->L[0]);
                ta[0] = xor_block(oa[0], ptp[0]);
                checksum = xor_block(checksum, ptp[0]);
                oa[1] = xor_block(oa[0], ctx->L[1]);
                ta[1] = xor_block(oa[1], ptp[1]);
                checksum = xor_block(checksum, ptp[1]);
                oa[2] = xor_block(oa[1], ctx->L[0]);
                ta[2] = xor_block(oa[2], ptp[2]);
                checksum = xor_block(checksum, ptp[2]);
                offset = oa[3] = xor_block(oa[2], ctx->L[2]);
                ta[3] = xor_block(offset, ptp[3]);
                checksum = xor_block(checksum, ptp[3]);
                remaining -= 64;
                k = 4;
            }
#endif
            if (remaining >= 32) {
                oa[k] = xor_block(offset, ctx->L[0]);
                ta[k] = xor_block(oa[k], ptp[k]);
                checksum = xor_block(checksum, ptp[k]);
                offset = oa[k + 1] = xor_block(oa[k], ctx->L[1]);
                ta[k + 1] = xor_block(offset, ptp[k + 1]);
                checksum = xor_block(checksum, ptp[k + 1]);
                remaining -= 32;
                k += 2;
            }
            if (remaining >= 16) {
                offset = oa[k] = xor_block(offset, ctx->L[0]);
                ta[k] = xor_block(offset, ptp[k]);
                checksum = xor_block(checksum, ptp[k]);
                remaining -= 16;
                ++k;
            }
            if (remaining) {
                /* Short final block: checksum absorbs P* || 0x80 || 0..0,
                   and ta[k] queues Offset xor L_* whose encryption becomes
                   the XOR pad for the partial ciphertext (applied below). */
                tmp.bl = zero_block();
                memcpy(tmp.u8, ptp + k, remaining);
                tmp.u8[remaining] = (unsigned char)0x80u;
                checksum = xor_block(checksum, tmp.bl);
                ta[k] = offset = xor_block(offset, ctx->Lstar);
                ++k;
            }
        }
        offset = xor_block(offset, ctx->Ldollar); /* Part of tag gen */
        ta[k] = xor_block(offset, checksum);      /* Part of tag gen */
        AES_ecb_encrypt_blks(ta, k + 1, &ctx->encrypt_key);
        offset = xor_block(ta[k], ctx->ad_checksum); /* Part of tag gen */
        if (remaining) {
            /* Partial block: C* = P* xor (first 'remaining' bytes of pad) */
            --k;
            tmp.bl = xor_block(tmp.bl, ta[k]);
            memcpy(ctp + k, tmp.u8, remaining);
        }
        /* Emit the k full ciphertext blocks; cases fall through by design */
        switch (k) {
#if (BPI == 8)
            case 7:
                ctp[6] = xor_block(ta[6], oa[6]);
                /* fallthrough */
            case 6:
                ctp[5] = xor_block(ta[5], oa[5]);
                /* fallthrough */
            case 5:
                ctp[4] = xor_block(ta[4], oa[4]);
                /* fallthrough */
            case 4:
                ctp[3] = xor_block(ta[3], oa[3]);
                /* fallthrough */
#endif
            case 3:
                ctp[2] = xor_block(ta[2], oa[2]);
                /* fallthrough */
            case 2:
                ctp[1] = xor_block(ta[1], oa[1]);
                /* fallthrough */
            case 1:
                ctp[0] = xor_block(ta[0], oa[0]);
        }

        /* Emit the tag: into the caller's buffer if one was supplied,
           otherwise append it to the ciphertext and grow the return length. */
        if (tag) {
#if (OCB_TAG_LEN == 16)
            *(block*)tag = offset;
#elif(OCB_TAG_LEN > 0)
            memcpy((char*)tag, &offset, OCB_TAG_LEN);
#else
            memcpy((char*)tag, &offset, ctx->tag_len);
#endif
        } else {
#if (OCB_TAG_LEN > 0)
            memcpy((char*)ct + pt_len, &offset, OCB_TAG_LEN);
            pt_len += OCB_TAG_LEN;
#else
            memcpy((char*)ct + pt_len, &offset, ctx->tag_len);
            pt_len += ctx->tag_len;
#endif
        }
    }
    return (int)pt_len;
}
1104
1105 /* ----------------------------------------------------------------------- */
1106
/* Compare n bytes of two memory regions without data-dependent branches:
   the loop always inspects every byte, so (under reasonable compiler and
   hardware assumptions) the running time does not leak where the regions
   first differ.  Use this instead of memcmp for tag verification to avoid
   timing side-channel attacks.

   Returns 0 when the regions are equal, non-zero otherwise. */
static int constant_time_memcmp(const void* av, const void* bv, size_t n) {
    const uint8_t* pa = (const uint8_t*)av;
    const uint8_t* pb = (const uint8_t*)bv;
    uint8_t diff = 0; /* OR-accumulates every byte difference */
    size_t idx;

    for (idx = 0; idx < n; idx++)
        diff |= (uint8_t)(pa[idx] ^ pb[idx]);

    return (int)diff;
}
1128
/* OCB3 decrypt-and-verify.  Mirror of ae_encrypt:
 *
 *  nonce  - non-NULL begins a new message; NULL continues an AE_PENDING one.
 *  ct     - ct_len bytes of ciphertext; if 'tag' is NULL the tag is expected
 *           bundled at the end of ct and is excluded from the length below.
 *  ad     - ad_len bytes of associated data (negative ad_len keeps the
 *           running AD checksum -- see the note on ae_encrypt).
 *  pt     - receives the plaintext (may alias ct).
 *  tag    - expected tag, or NULL when bundled in ct.
 *  final  - non-zero finalizes the message and verifies the tag.
 *
 * Returns the plaintext length, or AE_INVALID if authentication fails.
 * NOTE(review): plaintext is written to pt BEFORE the tag is verified;
 * callers must discard it when AE_INVALID is returned. */
int ae_decrypt(ae_ctx* ctx, const void* nonce, const void* ct, int ct_len, const void* ad,
               int ad_len, void* pt, const void* tag, int final) {
    union {
        uint32_t u32[4];
        uint8_t u8[16];
        block bl;
    } tmp;
    block offset, checksum;
    unsigned i, k;
    block* ctp = (block*)ct; /* const is cast away; ct is only read */
    block* ptp = (block*)pt;

    /* If the tag is bundled at the end of ct, exclude it from ct_len */
    if ((final) && (!tag))
#if (OCB_TAG_LEN > 0)
        ct_len -= OCB_TAG_LEN;
#else
        ct_len -= ctx->tag_len;
#endif

    /* Non-null nonce means start of new message, init per-message values */
    if (nonce) {
        ctx->offset = gen_offset_from_nonce(ctx, nonce);
        ctx->ad_offset = ctx->checksum = zero_block();
        ctx->ad_blocks_processed = ctx->blocks_processed = 0;
        if (ad_len >= 0)
            ctx->ad_checksum = zero_block();
    }

    /* Process associated data */
    if (ad_len > 0)
        process_ad(ctx, ad, ad_len, final);

    /* Decrypt ciphertext data BPI blocks at a time */
    offset = ctx->offset;
    checksum = ctx->checksum;
    i = ct_len / (BPI * 16); /* Number of full BPI-block chunks */
    if (i) {
        block oa[BPI];
        unsigned block_num = ctx->blocks_processed;
        /* oa[BPI-1] seeds the chain with the previous chunk's final offset */
        oa[BPI - 1] = offset;
        do {
            block ta[BPI];
            block_num += BPI;
            /* Same unrolled offset schedule as ae_encrypt; ta[] holds
               C_i xor Offset_i, ready for ECB decryption. */
            oa[0] = xor_block(oa[BPI - 1], ctx->L[0]);
            ta[0] = xor_block(oa[0], ctp[0]);
            oa[1] = xor_block(oa[0], ctx->L[1]);
            ta[1] = xor_block(oa[1], ctp[1]);
            oa[2] = xor_block(oa[1], ctx->L[0]);
            ta[2] = xor_block(oa[2], ctp[2]);
#if BPI == 4
            oa[3] = xor_block(oa[2], getL(ctx, ntz(block_num)));
            ta[3] = xor_block(oa[3], ctp[3]);
#elif BPI == 8
            oa[3] = xor_block(oa[2], ctx->L[2]);
            ta[3] = xor_block(oa[3], ctp[3]);
            oa[4] = xor_block(oa[1], ctx->L[2]);
            ta[4] = xor_block(oa[4], ctp[4]);
            oa[5] = xor_block(oa[0], ctx->L[2]);
            ta[5] = xor_block(oa[5], ctp[5]);
            /* oa[7] here still holds the PREVIOUS iteration's final offset */
            oa[6] = xor_block(oa[7], ctx->L[2]);
            ta[6] = xor_block(oa[6], ctp[6]);
            oa[7] = xor_block(oa[6], getL(ctx, ntz(block_num)));
            ta[7] = xor_block(oa[7], ctp[7]);
#endif
            AES_ecb_decrypt_blks(ta, BPI, &ctx->decrypt_key);
            /* P_i = D_K(C_i xor Offset_i) xor Offset_i; checksum
               accumulates the recovered plaintext blocks */
            ptp[0] = xor_block(ta[0], oa[0]);
            checksum = xor_block(checksum, ptp[0]);
            ptp[1] = xor_block(ta[1], oa[1]);
            checksum = xor_block(checksum, ptp[1]);
            ptp[2] = xor_block(ta[2], oa[2]);
            checksum = xor_block(checksum, ptp[2]);
            ptp[3] = xor_block(ta[3], oa[3]);
            checksum = xor_block(checksum, ptp[3]);
#if (BPI == 8)
            ptp[4] = xor_block(ta[4], oa[4]);
            checksum = xor_block(checksum, ptp[4]);
            ptp[5] = xor_block(ta[5], oa[5]);
            checksum = xor_block(checksum, ptp[5]);
            ptp[6] = xor_block(ta[6], oa[6]);
            checksum = xor_block(checksum, ptp[6]);
            ptp[7] = xor_block(ta[7], oa[7]);
            checksum = xor_block(checksum, ptp[7]);
#endif
            ptp += BPI;
            ctp += BPI;
        } while (--i);
        /* Persist incremental state for a later continuation call */
        ctx->offset = offset = oa[BPI - 1];
        ctx->blocks_processed = block_num;
        ctx->checksum = checksum;
    }

    if (final) {
        block ta[BPI + 1], oa[BPI]; /* +1 matches ae_encrypt's tag slot */

        /* Process remaining ciphertext and compute its tag contribution */
        unsigned remaining = ((unsigned)ct_len) % (BPI * 16);
        k = 0; /* How many blocks in ta[] need ECBing */
        if (remaining) {
#if (BPI == 8)
            if (remaining >= 64) {
                oa[0] = xor_block(offset, ctx->L[0]);
                ta[0] = xor_block(oa[0], ctp[0]);
                oa[1] = xor_block(oa[0], ctx->L[1]);
                ta[1] = xor_block(oa[1], ctp[1]);
                oa[2] = xor_block(oa[1], ctx->L[0]);
                ta[2] = xor_block(oa[2], ctp[2]);
                offset = oa[3] = xor_block(oa[2], ctx->L[2]);
                ta[3] = xor_block(offset, ctp[3]);
                remaining -= 64;
                k = 4;
            }
#endif
            if (remaining >= 32) {
                oa[k] = xor_block(offset, ctx->L[0]);
                ta[k] = xor_block(oa[k], ctp[k]);
                offset = oa[k + 1] = xor_block(oa[k], ctx->L[1]);
                ta[k + 1] = xor_block(offset, ctp[k + 1]);
                remaining -= 32;
                k += 2;
            }
            if (remaining >= 16) {
                offset = oa[k] = xor_block(offset, ctx->L[0]);
                ta[k] = xor_block(offset, ctp[k]);
                remaining -= 16;
                ++k;
            }
            if (remaining) {
                /* Short final block: pad = E_K(Offset xor L_*); the partial
                   plaintext is C* xor pad, and the checksum absorbs
                   P* || 0x80 || 0..0 exactly as on the encrypt side. */
                block pad;
                offset = xor_block(offset, ctx->Lstar);
                AES_encrypt((unsigned char*)&offset, tmp.u8, &ctx->encrypt_key);
                pad = tmp.bl;
                memcpy(tmp.u8, ctp + k, remaining);
                tmp.bl = xor_block(tmp.bl, pad);
                tmp.u8[remaining] = (unsigned char)0x80u;
                memcpy(ptp + k, tmp.u8, remaining);
                checksum = xor_block(checksum, tmp.bl);
            }
        }
        /* k may be zero here (ct_len a multiple of BPI*16); the blk helper
           is assumed to tolerate a zero count -- TODO confirm. */
        AES_ecb_decrypt_blks(ta, k, &ctx->decrypt_key);
        /* Recover the k full trailing blocks; cases fall through by design */
        switch (k) {
#if (BPI == 8)
            case 7:
                ptp[6] = xor_block(ta[6], oa[6]);
                checksum = xor_block(checksum, ptp[6]);
                /* fallthrough */
            case 6:
                ptp[5] = xor_block(ta[5], oa[5]);
                checksum = xor_block(checksum, ptp[5]);
                /* fallthrough */
            case 5:
                ptp[4] = xor_block(ta[4], oa[4]);
                checksum = xor_block(checksum, ptp[4]);
                /* fallthrough */
            case 4:
                ptp[3] = xor_block(ta[3], oa[3]);
                checksum = xor_block(checksum, ptp[3]);
                /* fallthrough */
#endif
            case 3:
                ptp[2] = xor_block(ta[2], oa[2]);
                checksum = xor_block(checksum, ptp[2]);
                /* fallthrough */
            case 2:
                ptp[1] = xor_block(ta[1], oa[1]);
                checksum = xor_block(checksum, ptp[1]);
                /* fallthrough */
            case 1:
                ptp[0] = xor_block(ta[0], oa[0]);
                checksum = xor_block(checksum, ptp[0]);
        }

        /* Calculate expected tag: E_K(Offset xor L_$ xor Checksum) xor
           AD-checksum */
        offset = xor_block(offset, ctx->Ldollar);
        tmp.bl = xor_block(offset, checksum);
        AES_encrypt(tmp.u8, tmp.u8, &ctx->encrypt_key);
        tmp.bl = xor_block(tmp.bl, ctx->ad_checksum); /* Full tag */

        /* Compare with proposed tag, change ct_len if invalid.  The
           full-block fast path may compare non-constant-time; shorter tags
           use constant_time_memcmp to avoid a timing side channel. */
        if ((OCB_TAG_LEN == 16) && tag) {
            if (unequal_blocks(tmp.bl, *(block*)tag))
                ct_len = AE_INVALID;
        } else {
#if (OCB_TAG_LEN > 0)
            int len = OCB_TAG_LEN;
#else
            int len = ctx->tag_len;
#endif
            if (tag) {
                if (constant_time_memcmp(tag, tmp.u8, len) != 0)
                    ct_len = AE_INVALID;
            } else {
                /* Tag is bundled right after the ciphertext in ct */
                if (constant_time_memcmp((char*)ct + ct_len, tmp.u8, len) != 0)
                    ct_len = AE_INVALID;
            }
        }
    }
    return ct_len;
}
1322
1323 /* ----------------------------------------------------------------------- */
1324 /* Simple test program */
1325 /* ----------------------------------------------------------------------- */
1326
1327 #if 0
1328
1329 #include <stdio.h>
1330 #include <time.h>
1331
1332 #if __GNUC__
1333 #define ALIGN(n) __attribute__((aligned(n)))
1334 #elif _MSC_VER
1335 #define ALIGN(n) __declspec(align(n))
1336 #else /* Not GNU/Microsoft: delete alignment uses. */
1337 #define ALIGN(n)
1338 #endif
1339
/* Print 'len' bytes at 'p' as uppercase hex on one line, optionally
   preceded by the label string 's' (pass NULL for no label). */
static void pbuf(void *p, unsigned len, const void *s)
{
    const unsigned char *bytes = (const unsigned char *)p;
    unsigned idx;

    if (s != NULL) {
        printf("%s", (char *)s);
    }
    for (idx = 0; idx < len; idx++) {
        printf("%02X", (unsigned)bytes[idx]);
    }
    printf("\n");
}
1349
1350 static void vectors(ae_ctx *ctx, int len)
1351 {
1352 ALIGN(16) char pt[128];
1353 ALIGN(16) char ct[144];
1354 ALIGN(16) char nonce[] = {0,1,2,3,4,5,6,7,8,9,10,11};
1355 int i;
1356 for (i=0; i < 128; i++) pt[i] = i;
1357 i = ae_encrypt(ctx,nonce,pt,len,pt,len,ct,NULL,AE_FINALIZE);
1358 printf("P=%d,A=%d: ",len,len); pbuf(ct, i, NULL);
1359 i = ae_encrypt(ctx,nonce,pt,0,pt,len,ct,NULL,AE_FINALIZE);
1360 printf("P=%d,A=%d: ",0,len); pbuf(ct, i, NULL);
1361 i = ae_encrypt(ctx,nonce,pt,len,pt,0,ct,NULL,AE_FINALIZE);
1362 printf("P=%d,A=%d: ",len,0); pbuf(ct, i, NULL);
1363 }
1364
/* Self-test driver: regenerates the RFC-style known-answer vectors (as a
 * single hash over all outputs, printed as hex) and then round-trips
 * encrypt/decrypt for message lengths 0..127, exercising both one-shot and
 * incremental (AE_PENDING) operation.  Prints "Decrypt: PASS" on success. */
void validate()
{
    ALIGN(16) char pt[1024];
    ALIGN(16) char ct[1024];
    ALIGN(16) char tag[16];
    ALIGN(16) char nonce[12] = {0,};
    ALIGN(16) char key[32] = {0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31};
    ae_ctx ctx;
    char *val_buf, *next;
    int i, len;

    /* Round val_buf up to a 16-byte boundary for the SSE/AltiVec paths.
       NOTE(review): the original malloc pointer is overwritten, so this
       buffer can never be freed (acceptable leak in test code), and the
       (size_t) pointer cast would be cleaner as uintptr_t. */
    val_buf = (char *)malloc(22400 + 16);
    next = val_buf = (char *)(((size_t)val_buf + 16) & ~((size_t)15));

    /* Disabled: print individual vectors for a fixed 16-byte key setup */
    if (0) {
        ae_init(&ctx, key, 16, 12, 16);
        /* pbuf(&ctx, sizeof(ctx), "CTX: "); */
        vectors(&ctx,0);
        vectors(&ctx,8);
        vectors(&ctx,16);
        vectors(&ctx,24);
        vectors(&ctx,32);
        vectors(&ctx,40);
    }

    /* All-zero key and plaintext for the accumulated known-answer test */
    memset(key,0,32);
    memset(pt,0,128);
    ae_init(&ctx, key, OCB_KEY_LEN, 12, OCB_TAG_LEN);

    /* RFC Vector test */
    for (i = 0; i < 128; i++) {
        /* Split length i into first+second+third so the first two pieces
           are BPI*16-aligned, as required for AE_PENDING calls */
        int first = ((i/3)/(BPI*16))*(BPI*16);
        int second = first;
        int third = i - (first + second);

        nonce[11] = i; /* Unique nonce per message */

        if (0) { /* One-shot variant (disabled; kept for reference) */
            /* (P=i, A=i), (P=i, A=0) and (P=0, A=i) outputs appended to
               val_buf for the final accumulated tag */
            ae_encrypt(&ctx,nonce,pt,i,pt,i,ct,NULL,AE_FINALIZE);
            memcpy(next,ct,(size_t)i+OCB_TAG_LEN);
            next = next+i+OCB_TAG_LEN;

            ae_encrypt(&ctx,nonce,pt,i,pt,0,ct,NULL,AE_FINALIZE);
            memcpy(next,ct,(size_t)i+OCB_TAG_LEN);
            next = next+i+OCB_TAG_LEN;

            ae_encrypt(&ctx,nonce,pt,0,pt,i,ct,NULL,AE_FINALIZE);
            memcpy(next,ct,OCB_TAG_LEN);
            next = next+OCB_TAG_LEN;
        } else { /* Incremental variant: same outputs via three calls */
            ae_encrypt(&ctx,nonce,pt,first,pt,first,ct,NULL,AE_PENDING);
            ae_encrypt(&ctx,NULL,pt+first,second,pt+first,second,ct+first,NULL,AE_PENDING);
            ae_encrypt(&ctx,NULL,pt+first+second,third,pt+first+second,third,ct+first+second,NULL,AE_FINALIZE);
            memcpy(next,ct,(size_t)i+OCB_TAG_LEN);
            next = next+i+OCB_TAG_LEN;

            ae_encrypt(&ctx,nonce,pt,first,pt,0,ct,NULL,AE_PENDING);
            ae_encrypt(&ctx,NULL,pt+first,second,pt,0,ct+first,NULL,AE_PENDING);
            ae_encrypt(&ctx,NULL,pt+first+second,third,pt,0,ct+first+second,NULL,AE_FINALIZE);
            memcpy(next,ct,(size_t)i+OCB_TAG_LEN);
            next = next+i+OCB_TAG_LEN;

            ae_encrypt(&ctx,nonce,pt,0,pt,first,ct,NULL,AE_PENDING);
            ae_encrypt(&ctx,NULL,pt,0,pt+first,second,ct,NULL,AE_PENDING);
            ae_encrypt(&ctx,NULL,pt,0,pt+first+second,third,ct,NULL,AE_FINALIZE);
            memcpy(next,ct,OCB_TAG_LEN);
            next = next+OCB_TAG_LEN;
        }

    }
    /* Hash every collected output: tag of (P=empty, A=val_buf) summarizes
       all vectors in one printable value */
    nonce[11] = 0;
    ae_encrypt(&ctx,nonce,NULL,0,val_buf,next-val_buf,ct,tag,AE_FINALIZE);
    pbuf(tag,OCB_TAG_LEN,0);


    /* Encrypt/Decrypt test */
    for (i = 0; i < 128; i++) {
        int first = ((i/3)/(BPI*16))*(BPI*16);
        int second = first;
        int third = i - (first + second);

        nonce[11] = i%128;

        if (1) { /* One-shot round trip; ad_len=-1 reuses prior AD checksum */
            len = ae_encrypt(&ctx,nonce,val_buf,i,val_buf,i,ct,tag,AE_FINALIZE);
            len = ae_encrypt(&ctx,nonce,val_buf,i,val_buf,-1,ct,tag,AE_FINALIZE);
            len = ae_decrypt(&ctx,nonce,ct,len,val_buf,-1,pt,tag,AE_FINALIZE);
            if (len == -1) { printf("Authentication error: %d\n", i); return; }
            if (len != i) { printf("Length error: %d\n", i); return; }
            if (memcmp(val_buf,pt,i)) { printf("Decrypt error: %d\n", i); return; }
        } else { /* Incremental decrypt variant (disabled) */
            len = ae_encrypt(&ctx,nonce,val_buf,i,val_buf,i,ct,NULL,AE_FINALIZE);
            ae_decrypt(&ctx,nonce,ct,first,val_buf,first,pt,NULL,AE_PENDING);
            ae_decrypt(&ctx,NULL,ct+first,second,val_buf+first,second,pt+first,NULL,AE_PENDING);
            len = ae_decrypt(&ctx,NULL,ct+first+second,len-(first+second),val_buf+first+second,third,pt+first+second,NULL,AE_FINALIZE);
            if (len == -1) { printf("Authentication error: %d\n", i); return; }
            if (memcmp(val_buf,pt,i)) { printf("Decrypt error: %d\n", i); return; }
        }

    }
    printf("Decrypt: PASS\n");
}
1467
/* Entry point of the self-test build: run the validation suite. */
int main(void)
{
    validate();
    return 0;
}
1473 #endif
1474
/* Human-readable name of the AES backend compiled into this object.
   NOTE(review): presumably read by an external benchmark/driver program;
   confirm before renaming or making it static. */
#if USE_AES_NI
char infoString[] = "OCB3 (AES-NI)";
#elif USE_REFERENCE_AES
char infoString[] = "OCB3 (Reference)";
#elif USE_OPENSSL_AES
char infoString[] = "OCB3 (OpenSSL)";
#endif
1482