/*
 * You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 * lokinet/crypto/libntrup/src/avx/rq_rounded.c
 *
 * 262 lines
 * 7.7 KiB
 * C
 */

#ifdef __AVX__
#include <immintrin.h>
#include "params.h"
#include "crypto_uint32.h"
#include "rq.h"
/*
 * NOTE(review): the enclosing guard tests __AVX__, but the integer
 * intrinsics used in this file (e.g. _mm256_mulhrs_epi16,
 * _mm256_blend_epi16) are AVX2 instructions; confirm that the build always
 * enables AVX2 when this file is compiled.
 */
/* Upper 32 bits of the IEEE-754 double 6755399441055744.0 (2^52 + 2^51),
 * replicated in every 32-bit lane.  Interleaving it above a 32-bit integer u
 * produces the bit pattern of the double 2^52 + 2^51 + u. */
#define alpha_top _mm256_set1_epi32(0x43380000)
/* 2^52 + 2^51 in every double lane; subtracting it from a value built with
 * alpha_top leaves the integer u as an exact double. */
#define alpha _mm256_set1_pd(6755399441055744.0)
/* 10923 in every 16-bit lane.  10923 is about 2^15/3, so a rounding high
 * multiply (_mm256_mulhrs_epi16) by it approximates division by 3. */
#define v10923_16 _mm256_set1_epi16(10923)
/* Lane-wise floor of four doubles.  Note that this macro takes over the
 * name floor for the rest of this file. */
#define floor(x) _mm256_floor_pd(x)
void rq_roundencode(unsigned char *c,const modq *f)
{
int i;
__m256i h[50];
for (i = 0;i < 208;i += 16) {
__m256i a0, a1, a2, b0, b1, b2, c0, c1, c2, d0, d1, d2;
__m256i e0, e1, f0, f1, g0, g1;
a0 = _mm256_castsi128_si256(_mm_loadu_si128((__m128i *) &f[0]));
a1 = _mm256_castsi128_si256(_mm_loadu_si128((__m128i *) &f[8]));
a2 = _mm256_castsi128_si256(_mm_loadu_si128((__m128i *) &f[16]));
a0 = _mm256_inserti128_si256(a0,_mm_loadu_si128((__m128i *) &f[24]),1);
a1 = _mm256_inserti128_si256(a1,_mm_loadu_si128((__m128i *) &f[32]),1);
a2 = _mm256_inserti128_si256(a2,_mm_loadu_si128((__m128i *) &f[40]),1);
f += 48;
a0 = _mm256_mulhrs_epi16(a0,v10923_16);
a1 = _mm256_mulhrs_epi16(a1,v10923_16);
a2 = _mm256_mulhrs_epi16(a2,v10923_16);
/* a0: a0 a1 a2 b0 b1 b2 c0 c1 and similar second half */
/* a1: c2 d0 d1 d2 e0 e1 e2 f0 */
/* a2: f1 f2 g0 g1 g2 h0 h1 h2 */
b1 = _mm256_blend_epi16(a2,a0,0xf0);
b1 = _mm256_shuffle_epi32(b1,0x4e);
b0 = _mm256_blend_epi16(a0,a1,0xf0);
b2 = _mm256_blend_epi16(a1,a2,0xf0);
/* XXX: use shufps instead? */
/* b0: a0 a1 a2 b0 e0 e1 e2 f0 */
/* b1: b1 b2 c0 c1 f1 f2 g0 g1 */
/* b2: c2 d0 d1 d2 g2 h0 h1 h2 */
c1 = _mm256_blend_epi16(b2,b0,0xcc);
c1 = _mm256_shuffle_epi32(c1,0xb1);
c0 = _mm256_blend_epi16(b0,b1,0xcc);
c2 = _mm256_blend_epi16(b1,b2,0xcc);
/* c0: a0 a1 c0 c1 e0 e1 g0 g1 */
/* c1: a2 b0 c2 d0 e2 f0 g2 h0 */
/* c2: b1 b2 d1 d2 f1 f2 h1 h2 */
d1 = _mm256_blend_epi16(c2,c0,0xaa);
d1 = _mm256_shufflelo_epi16(d1,0xb1);
d1 = _mm256_shufflehi_epi16(d1,0xb1);
d0 = _mm256_blend_epi16(c0,c1,0xaa);
d2 = _mm256_blend_epi16(c1,c2,0xaa);
/* d0: a0 b0 c0 d0 e0 f0 g0 h0 */
/* d1: a1 b1 c1 d1 e1 f1 g1 h1 */
/* d2: a2 b2 c2 d2 e2 f2 g2 h2 */
d0 = _mm256_add_epi16(d0,_mm256_set1_epi16(765));
d1 = _mm256_add_epi16(d1,_mm256_set1_epi16(765));
d2 = _mm256_add_epi16(d2,_mm256_set1_epi16(765));
/* want bytes of d0 + 1536*d1 + 1536*1536*d2 */
e0 = d0 & _mm256_set1_epi16(0xff);
d0 = _mm256_srli_epi16(d0,8);
/* want e0, d0 + 6*d1 + 6*1536*d2 */
d1 = _mm256_mullo_epi16(d1,_mm256_set1_epi16(6));
d0 = _mm256_add_epi16(d0,d1);
/* want e0, d0 + 6*1536*d2 */
e1 = _mm256_slli_epi16(d0,8);
e0 = _mm256_add_epi16(e0,e1);
d0 = _mm256_srli_epi16(d0,8);
/* want e0, d0 + 36*d2 */
d2 = _mm256_mullo_epi16(d2,_mm256_set1_epi16(36));
e1 = _mm256_add_epi16(d0,d2);
/* want e0, e1 */
/* e0: out0 out1 out4 out5 out8 out9 ... */
/* e1: out2 out3 out6 out7 out10 out11 ... */
f0 = _mm256_unpacklo_epi16(e0,e1);
f1 = _mm256_unpackhi_epi16(e0,e1);
g0 = _mm256_permute2x128_si256(f0,f1,0x20);
g1 = _mm256_permute2x128_si256(f0,f1,0x31);
_mm256_storeu_si256((__m256i *) c,g0);
_mm256_storeu_si256((__m256i *) (c + 32),g1);
c += 64;
}
for (i = 0;i < 9;++i) {
__m256i x = _mm256_loadu_si256((__m256i *) &f[16 * i]);
_mm256_storeu_si256(&h[i],_mm256_mulhrs_epi16(x,v10923_16));
}
f = (const modq *) h;
for (i = 208;i < 253;++i) {
crypto_int32 f0, f1, f2;
f0 = *f++;
f1 = *f++;
f2 = *f++;
f0 += 1806037245;
f1 *= 3;
f2 *= 9;
f0 += f1 << 9;
f0 += f2 << 18;
*(crypto_int32 *) c = f0;
c += 4;
}
{
crypto_int32 f0, f1;
f0 = *f++;
f1 = *f++;
f0 += 1175805;
f1 *= 3;
f0 += f1 << 9;
*c++ = f0; f0 >>= 8;
*c++ = f0; f0 >>= 8;
*c++ = f0;
}
}
/* Reduce every lane r modulo 1531 and map it to 3*r - 2295. */
static __m256d decode_unbias3(__m256d v)
{
  const __m256d inv1531 = _mm256_set1_pd(0.0006531678641410842804659875326933615724556148052215576171875);
  __m256d quot = _mm256_floor_pd(_mm256_mul_pd(v,inv1531));

  v = _mm256_sub_pd(v,_mm256_mul_pd(_mm256_set1_pd(1531.0),quot));
  v = _mm256_mul_pd(v,_mm256_set1_pd(3.0));
  return _mm256_sub_pd(v,_mm256_set1_pd(2295.0));
}

/*
 * rq_decoderounded: unpack a byte string into coefficients.
 *
 * Each little-endian 32-bit word is split as f0 + 1536*f1 + 1536^2*f2 and
 * every digit d is turned into the coefficient 3*d - 2295 (vector part) or
 * modq_freeze(3*d + q - qshift) (scalar part).
 *
 * Layout: 31 vector iterations consume 32 bytes each and emit 24
 * coefficients (992 bytes, 744 coefficients); 5 scalar words give 15 more;
 * the final 3 bytes give 2; then 7 zero coefficients are appended.
 *
 * f: output; f[0..767] is written (761 decoded values plus 7 zeros).
 * c: input; c[0..1014] is read.
 */
void rq_decoderounded(modq *f,const unsigned char *c)
{
  const __m256d inv_1536sq = _mm256_set1_pd(0.00000042385525173611114052197733521876177320564238470979034900665283203125);
  const __m256d inv_1536 = _mm256_set1_pd(0.00065104166666666673894681149903362893383018672466278076171875);
  int i;

  for (i = 0;i < 248;i += 8) {
    __m256i words = _mm256_loadu_si256((const __m256i *) c);
    __m256i widened[2];
    int j;

    c += 32;
    /* Pair each 32-bit word with alpha_top so that the 64-bit lane reads as
     * the double alpha + word.  Lane order: [0] = a b e f, [1] = c d g h. */
    widened[0] = _mm256_unpacklo_epi32(words,alpha_top);
    widened[1] = _mm256_unpackhi_epi32(words,alpha_top);

    for (j = 0;j < 2;++j) {
      __m256d x, v2, v1, v0;
      __m128i w2, w1, w0;

      x = _mm256_sub_pd(_mm256_castsi256_pd(widened[j]),alpha);

      /* x is f0 + f1*1536 + f2*1536^2 */
      /* with each f between 0 and 1530 */

      v2 = _mm256_floor_pd(_mm256_mul_pd(x,inv_1536sq));
      x = _mm256_sub_pd(x,_mm256_mul_pd(v2,_mm256_set1_pd(2359296.0)));

      v1 = _mm256_floor_pd(_mm256_mul_pd(x,inv_1536));
      x = _mm256_sub_pd(x,_mm256_mul_pd(v1,_mm256_set1_pd(1536.0)));

      v0 = x;

      v2 = decode_unbias3(v2);
      v1 = decode_unbias3(v1);
      v0 = decode_unbias3(v0);

      w2 = _mm256_cvtpd_epi32(v2); /* a2 b2 e2 f2 */
      w1 = _mm256_cvtpd_epi32(v1); /* a1 b1 e1 f1 */
      w0 = _mm256_cvtpd_epi32(v0); /* a0 b0 e0 f0 */

      f[6*j + 0] = _mm_extract_epi32(w0,0);
      f[6*j + 1] = _mm_extract_epi32(w1,0);
      f[6*j + 2] = _mm_extract_epi32(w2,0);
      f[6*j + 3] = _mm_extract_epi32(w0,1);
      f[6*j + 4] = _mm_extract_epi32(w1,1);
      f[6*j + 5] = _mm_extract_epi32(w2,1);

      f[6*j + 12] = _mm_extract_epi32(w0,2);
      f[6*j + 13] = _mm_extract_epi32(w1,2);
      f[6*j + 14] = _mm_extract_epi32(w2,2);
      f[6*j + 15] = _mm_extract_epi32(w0,3);
      f[6*j + 16] = _mm_extract_epi32(w1,3);
      f[6*j + 17] = _mm_extract_epi32(w2,3);
    }

    f += 24;
  }

  for (i = 248;i < 253;++i) {
    crypto_uint32 c0 = *c++;
    crypto_uint32 c1 = *c++;
    crypto_uint32 c2 = *c++;
    crypto_uint32 c3 = *c++;
    crypto_uint32 f0, f1, f2;

    /* f0 + f1*1536 + f2*1536^2 */
    /* = c0 + c1*256 + c2*256^2 + c3*256^3 */
    /* with each f between 0 and 1530 */

    /* f2 = (64/9)c3 + (1/36)c2 + (1/9216)c1 + (1/2359296)c0 - [0,0.99675] */
    /* claim: 2^21 f2 < x < 2^21(f2+1) */
    /* where x = 14913081*c3 + 58254*c2 + 228*(c1+2) */
    /* proof: x - 2^21 f2 = 456 - (8/9)c0 + (4/9)c1 - (2/9)c2 + (1/9)c3 + 2^21 [0,0.99675] */
    /* at least 456 - (8/9)255 - (2/9)255 > 0 */
    /* at most 456 + (4/9)255 + (1/9)255 + 2^21 0.99675 < 2^21 */
    f2 = (14913081*c3 + 58254*c2 + 228*(c1+2)) >> 21;

    c2 += c3 << 8;
    c2 -= (f2 * 9) << 2;

    /* f0 + f1*1536 */
    /* = c0 + c1*256 + c2*256^2 */
    /* c2 <= 35 = floor((1530+1530*1536)/256^2) */
    /* f1 = (128/3)c2 + (1/6)c1 + (1/1536)c0 - (1/1536)f0 */
    /* claim: 2^21 f1 < x < 2^21(f1+1) */
    /* where x = 89478485*c2 + 349525*c1 + 1365*(c0+1) */
    /* proof: x - 2^21 f1 = 1365 - (1/3)c2 - (1/3)c1 - (1/3)c0 + (4096/3)f0 */
    /* at least 1365 - (1/3)35 - (1/3)255 - (1/3)255 > 0 */
    /* at most 1365 + (4096/3)1530 < 2^21 */
    f1 = (89478485*c2 + 349525*c1 + 1365*(c0+1)) >> 21;

    c1 += c2 << 8;
    c1 -= (f1 * 3) << 1;

    c0 += c1 << 8;
    f0 = c0;

    *f++ = modq_freeze(f0 * 3 + q - qshift);
    *f++ = modq_freeze(f1 * 3 + q - qshift);
    *f++ = modq_freeze(f2 * 3 + q - qshift);
  }

  {
    /* Last 3 bytes hold two digits: f0 + f1*1536. */
    crypto_uint32 c0 = *c++;
    crypto_uint32 c1 = *c++;
    crypto_uint32 c2 = *c++;
    crypto_uint32 f0, f1;

    f1 = (89478485*c2 + 349525*c1 + 1365*(c0+1)) >> 21;

    c1 += c2 << 8;
    c1 -= (f1 * 3) << 1;

    c0 += c1 << 8;
    f0 = c0;

    *f++ = modq_freeze(f0 * 3 + q - qshift);
    *f++ = modq_freeze(f1 * 3 + q - qshift);
  }

  /* Zero the 7 padding coefficients that follow the 761 decoded ones. */
  for (i = 0;i < 7;++i)
    *f++ = 0;
}
#endif