avx512: add basic functions to simd_utils

Extends the m512 type to use avx512 and also changes required
for limex.
This commit is contained in:
Matthew Barr
2016-07-20 11:31:34 +10:00
parent fedd48489f
commit 8a56d16d57
11 changed files with 258 additions and 53 deletions

View File

@@ -240,7 +240,7 @@ extern const u8 simd_onebit_masks[];
static really_inline
m128 mask1bit128(unsigned int n) {
assert(n < sizeof(m128) * 8);
u32 mask_idx = ((n % 8) * 64) + 31;
u32 mask_idx = ((n % 8) * 64) + 95;
mask_idx -= n / 8;
return loadu128(&simd_onebit_masks[mask_idx]);
}
@@ -290,6 +290,18 @@ m256 vpshufb(m256 a, m256 b) {
#endif
}
#if defined(HAVE_AVX512)
static really_inline
m512 pshufb_m512(m512 a, m512 b) {
return _mm512_shuffle_epi8(a, b);
}
static really_inline
m512 maskz_pshufb_m512(__mmask64 k, m512 a, m512 b) {
return _mm512_maskz_shuffle_epi8(k, a, b);
}
#endif
static really_inline
m128 variable_byte_shift_m128(m128 in, s32 amount) {
assert(amount >= -16 && amount <= 16);
@@ -592,7 +604,7 @@ m256 loadbytes256(const void *ptr, unsigned int n) {
static really_inline
m256 mask1bit256(unsigned int n) {
assert(n < sizeof(m256) * 8);
u32 mask_idx = ((n % 8) * 64) + 31;
u32 mask_idx = ((n % 8) * 64) + 95;
mask_idx -= n / 8;
return loadu256(&simd_onebit_masks[mask_idx]);
}
@@ -902,41 +914,110 @@ char testbit384(m384 val, unsigned int n) {
**** 512-bit Primitives
****/
static really_inline m512 and512(m512 a, m512 b) {
#define eq512mask(a, b) _mm512_cmpeq_epi8_mask((a), (b))
#define masked_eq512mask(k, a, b) _mm512_mask_cmpeq_epi8_mask((k), (a), (b))
static really_inline
m512 zeroes512(void) {
#if defined(HAVE_AVX512)
return _mm512_setzero_si512();
#else
m512 rv = {zeroes256(), zeroes256()};
return rv;
#endif
}
static really_inline
m512 ones512(void) {
#if defined(HAVE_AVX512)
return _mm512_set1_epi8(0xFF);
//return _mm512_xor_si512(_mm512_setzero_si512(), _mm512_setzero_si512());
#else
m512 rv = {ones256(), ones256()};
return rv;
#endif
}
#if defined(HAVE_AVX512)
static really_inline
m512 set64x8(u8 a) {
return _mm512_set1_epi8(a);
}
static really_inline
m512 set8x64(u64a a) {
return _mm512_set1_epi64(a);
}
static really_inline
m512 set4x128(m128 a) {
return _mm512_broadcast_i32x4(a);
}
#endif
static really_inline
m512 and512(m512 a, m512 b) {
#if defined(HAVE_AVX512)
return _mm512_and_si512(a, b);
#else
m512 rv;
rv.lo = and256(a.lo, b.lo);
rv.hi = and256(a.hi, b.hi);
return rv;
#endif
}
static really_inline m512 or512(m512 a, m512 b) {
static really_inline
m512 or512(m512 a, m512 b) {
#if defined(HAVE_AVX512)
return _mm512_or_si512(a, b);
#else
m512 rv;
rv.lo = or256(a.lo, b.lo);
rv.hi = or256(a.hi, b.hi);
return rv;
#endif
}
static really_inline m512 xor512(m512 a, m512 b) {
static really_inline
m512 xor512(m512 a, m512 b) {
#if defined(HAVE_AVX512)
return _mm512_xor_si512(a, b);
#else
m512 rv;
rv.lo = xor256(a.lo, b.lo);
rv.hi = xor256(a.hi, b.hi);
return rv;
#endif
}
static really_inline m512 not512(m512 a) {
static really_inline
m512 not512(m512 a) {
#if defined(HAVE_AVX512)
return _mm512_xor_si512(a, ones512());
#else
m512 rv;
rv.lo = not256(a.lo);
rv.hi = not256(a.hi);
return rv;
#endif
}
static really_inline m512 andnot512(m512 a, m512 b) {
static really_inline
m512 andnot512(m512 a, m512 b) {
#if defined(HAVE_AVX512)
return _mm512_andnot_si512(a, b);
#else
m512 rv;
rv.lo = andnot256(a.lo, b.lo);
rv.hi = andnot256(a.hi, b.hi);
return rv;
#endif
}
#if defined(HAVE_AVX512)
#define lshift64_m512(a, b) _mm512_slli_epi64((a), b)
#else
// The shift amount is an immediate
static really_really_inline
m512 lshift64_m512(m512 a, unsigned b) {
@@ -945,29 +1026,37 @@ m512 lshift64_m512(m512 a, unsigned b) {
rv.hi = lshift64_m256(a.hi, b);
return rv;
}
#endif
static really_inline m512 zeroes512(void) {
m512 rv = {zeroes256(), zeroes256()};
return rv;
}
#if defined(HAVE_AVX512)
#define rshift64_m512(a, b) _mm512_srli_epi64((a), (b))
#define rshift128_m512(a, count_immed) _mm512_bsrli_epi128(a, count_immed)
#endif
static really_inline m512 ones512(void) {
m512 rv = {ones256(), ones256()};
return rv;
}
#if !defined(_MM_CMPINT_NE)
#define _MM_CMPINT_NE 0x4
#endif
static really_inline int diff512(m512 a, m512 b) {
static really_inline
int diff512(m512 a, m512 b) {
#if defined(HAVE_AVX512)
return !!_mm512_cmp_epi8_mask(a, b, _MM_CMPINT_NE);
#else
return diff256(a.lo, b.lo) || diff256(a.hi, b.hi);
#endif
}
static really_inline int isnonzero512(m512 a) {
#if !defined(HAVE_AVX2)
static really_inline
int isnonzero512(m512 a) {
#if defined(HAVE_AVX512)
return diff512(a, zeroes512());
#elif defined(HAVE_AVX2)
m256 x = or256(a.lo, a.hi);
return !!diff256(x, zeroes256());
#else
m128 x = or128(a.lo.lo, a.lo.hi);
m128 y = or128(a.hi.lo, a.hi.hi);
return isnonzero128(or128(x, y));
#else
m256 x = or256(a.lo, a.hi);
return !!diff256(x, zeroes256());
#endif
}
@@ -975,8 +1064,11 @@ static really_inline int isnonzero512(m512 a) {
* "Rich" version of diff512(). Takes two vectors a and b and returns a 16-bit
* mask indicating which 32-bit words contain differences.
*/
static really_inline u32 diffrich512(m512 a, m512 b) {
#if defined(HAVE_AVX2)
static really_inline
u32 diffrich512(m512 a, m512 b) {
#if defined(HAVE_AVX512)
return _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_NE);
#elif defined(HAVE_AVX2)
return diffrich256(a.lo, b.lo) | (diffrich256(a.hi, b.hi) << 8);
#else
a.lo.lo = _mm_cmpeq_epi32(a.lo.lo, b.lo.lo);
@@ -993,22 +1085,32 @@ static really_inline u32 diffrich512(m512 a, m512 b) {
* "Rich" version of diffrich(), 64-bit variant. Takes two vectors a and b and
* returns a 16-bit mask indicating which 64-bit words contain differences.
*/
static really_inline u32 diffrich64_512(m512 a, m512 b) {
static really_inline
u32 diffrich64_512(m512 a, m512 b) {
//TODO: cmp_epi64?
u32 d = diffrich512(a, b);
return (d | (d >> 1)) & 0x55555555;
}
// aligned load
static really_inline m512 load512(const void *ptr) {
static really_inline
m512 load512(const void *ptr) {
#if defined(HAVE_AVX512)
return _mm512_load_si512(ptr);
#else
assert(ISALIGNED_N(ptr, alignof(m256)));
m512 rv = { load256(ptr), load256((const char *)ptr + 32) };
return rv;
#endif
}
// aligned store
static really_inline void store512(void *ptr, m512 a) {
assert(ISALIGNED_N(ptr, alignof(m256)));
#if defined(HAVE_AVX2)
static really_inline
void store512(void *ptr, m512 a) {
assert(ISALIGNED_N(ptr, alignof(m512)));
#if defined(HAVE_AVX512)
return _mm512_store_si512(ptr, a);
#elif defined(HAVE_AVX2)
m512 *x = (m512 *)ptr;
store256(&x->lo, a.lo);
store256(&x->hi, a.hi);
@@ -1019,11 +1121,28 @@ static really_inline void store512(void *ptr, m512 a) {
}
// unaligned load
static really_inline m512 loadu512(const void *ptr) {
static really_inline
m512 loadu512(const void *ptr) {
#if defined(HAVE_AVX512)
return _mm512_loadu_si512(ptr);
#else
m512 rv = { loadu256(ptr), loadu256((const char *)ptr + 32) };
return rv;
#endif
}
#if defined(HAVE_AVX512)
static really_inline
m512 loadu_maskz_m512(__mmask64 k, const void *ptr) {
return _mm512_maskz_loadu_epi8(k, ptr);
}
static really_inline
m512 loadu_mask_m512(m512 src, __mmask64 k, const void *ptr) {
return _mm512_mask_loadu_epi8(src, k, ptr);
}
#endif
// packed unaligned store of first N bytes
static really_inline
void storebytes512(void *ptr, m512 a, unsigned int n) {
@@ -1040,6 +1159,14 @@ m512 loadbytes512(const void *ptr, unsigned int n) {
return a;
}
static really_inline
m512 mask1bit512(unsigned int n) {
assert(n < sizeof(m512) * 8);
u32 mask_idx = ((n % 8) * 64) + 95;
mask_idx -= n / 8;
return loadu512(&simd_onebit_masks[mask_idx]);
}
// switches on bit N in the given vector.
static really_inline
void setbit512(m512 *ptr, unsigned int n) {
@@ -1056,6 +1183,8 @@ void setbit512(m512 *ptr, unsigned int n) {
sub = &ptr->hi.hi;
}
setbit128(sub, n % 128);
#elif defined(HAVE_AVX512)
*ptr = or512(mask1bit512(n), *ptr);
#else
m256 *sub;
if (n < 256) {
@@ -1084,6 +1213,8 @@ void clearbit512(m512 *ptr, unsigned int n) {
sub = &ptr->hi.hi;
}
clearbit128(sub, n % 128);
#elif defined(HAVE_AVX512)
*ptr = andnot512(mask1bit512(n), *ptr);
#else
m256 *sub;
if (n < 256) {
@@ -1112,6 +1243,9 @@ char testbit512(m512 val, unsigned int n) {
sub = val.hi.hi;
}
return testbit128(sub, n % 128);
#elif defined(HAVE_AVX512)
const m512 mask = mask1bit512(n);
return !!_mm512_test_epi8_mask(mask, val);
#else
m256 sub;
if (n < 256) {