From eabe408e2b03bae096f7cd6b48f2e2d2bff9c85c Mon Sep 17 00:00:00 2001
From: Matthew Barr
Date: Tue, 27 Sep 2016 16:01:08 +1000
Subject: [PATCH] avx512: shufti

---
 src/nfa/shufti.c | 351 ++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 349 insertions(+), 2 deletions(-)

diff --git a/src/nfa/shufti.c b/src/nfa/shufti.c
index dda5060f..390b6510 100644
--- a/src/nfa/shufti.c
+++ b/src/nfa/shufti.c
@@ -360,7 +360,8 @@ const u8 *shuftiDoubleExec(m128 mask1_lo, m128 mask1_hi,
     return buf_end;
 }
 
-#else // AVX2 - 256 wide shuftis
+#elif !defined(HAVE_AVX512)
+// AVX2 - 256 wide shuftis
 #ifdef DEBUG
 DUMP_MSK(256)
 #endif
@@ -389,9 +390,11 @@ u32 block(m256 mask_lo, m256 mask_hi, m256 chars, const m256 low4bits,
 
 static really_inline
 const u8 *firstMatch(const u8 *buf, u32 z) {
+    DEBUG_PRINTF("z 0x%08x\n", z);
     if (unlikely(z != 0xffffffff)) {
         u32 pos = ctz32(~z);
         assert(pos < 32);
+        DEBUG_PRINTF("match @ pos %u\n", pos);
         return buf + pos;
     } else {
         return NULL; // no match
@@ -697,6 +700,7 @@ const u8 *shuftiDoubleExec(m128 mask1_lo, m128 mask1_hi,
                            const u8 *buf, const u8 *buf_end) {
     /* we should always have at least 16 bytes */
     assert(buf_end - buf >= 16);
+    DEBUG_PRINTF("buf %p len %zu\n", buf, buf_end - buf);
 
     if (buf_end - buf < 32) {
         return shuftiDoubleShort(mask1_lo, mask1_hi, mask2_lo, mask2_hi, buf,
@@ -747,4 +751,347 @@ const u8 *shuftiDoubleExec(m128 mask1_lo, m128 mask1_hi,
     return buf_end;
 }
 
-#endif //AVX2
+#else // defined(HAVE_AVX512)
+
+#ifdef DEBUG
+DUMP_MSK(512)
+#endif
+
+static really_inline
+u64a block(m512 mask_lo, m512 mask_hi, m512 chars, const m512 low4bits,
+           const m512 compare) {
+    m512 c_lo = pshufb_m512(mask_lo, and512(chars, low4bits));
+    m512 c_hi = pshufb_m512(mask_hi,
+                            rshift64_m512(andnot512(low4bits, chars), 4));
+    m512 t = and512(c_lo, c_hi);
+
+#ifdef DEBUG
+    DEBUG_PRINTF(" chars: "); dumpMsk512AsChars(chars); printf("\n");
+    DEBUG_PRINTF("  char: "); dumpMsk512(chars); printf("\n");
+    DEBUG_PRINTF("  c_lo: "); dumpMsk512(c_lo); printf("\n");
+    DEBUG_PRINTF("  c_hi: "); dumpMsk512(c_hi); printf("\n");
+    DEBUG_PRINTF("     t: "); dumpMsk512(t); printf("\n");
+#endif
+
+    return eq512mask(t, compare);
+}
+
+static really_inline
+const u8 *firstMatch64(const u8 *buf, u64a z) {
+    DEBUG_PRINTF("z 0x%016llx\n", z);
+    if (unlikely(z != ~0ULL)) {
+        u32 pos = ctz64(~z);
+        DEBUG_PRINTF("match @ pos %u\n", pos);
+        assert(pos < 64);
+        return buf + pos;
+    } else {
+        return NULL; // no match
+    }
+}
+
+static really_inline
+const u8 *fwdBlock512(m512 mask_lo, m512 mask_hi, m512 chars, const u8 *buf,
+                      const m512 low4bits, const m512 zeroes) {
+    u64a z = block(mask_lo, mask_hi, chars, low4bits, zeroes);
+
+    return firstMatch64(buf, z);
+}
+
+static really_inline
+const u8 *shortShufti512(m512 mask_lo, m512 mask_hi, const u8 *buf,
+                         const u8 *buf_end, const m512 low4bits,
+                         const m512 zeroes) {
+    DEBUG_PRINTF("short shufti %p len %zu\n", buf, buf_end - buf);
+    uintptr_t len = buf_end - buf;
+    assert(len <= 64);
+
+    // load mask
+    u64a k = (~0ULL) >> (64 - len);
+    DEBUG_PRINTF("load mask 0x%016llx\n", k);
+
+    m512 chars = loadu_maskz_m512(k, buf);
+
+    u64a z = block(mask_lo, mask_hi, chars, low4bits, zeroes);
+
+    // reuse the load mask to indicate valid bytes
+    return firstMatch64(buf, z | ~k);
+}
+
+/* takes 128 bit masks, but operates on 512 bits of data */
+const u8 *shuftiExec(m128 mask_lo, m128 mask_hi, const u8 *buf,
+                     const u8 *buf_end) {
+    assert(buf && buf_end);
+    assert(buf < buf_end);
+    DEBUG_PRINTF("shufti %p len %zu\n", buf, buf_end - buf);
+    DEBUG_PRINTF("b %s\n", buf);
+
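+    /* Scan strategy: broadcast the two 16-byte nibble tables across all
+     * four 128-bit lanes, handle any unaligned head with a masked short
+     * scan, walk the buffer in 64-byte aligned blocks, and mop up the
+     * tail with one overlapping unaligned load. */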
+    const m512 low4bits = set64x8(0xf);
+    const m512 zeroes = zeroes512();
+    const m512 wide_mask_lo = set4x128(mask_lo);
+    const m512 wide_mask_hi = set4x128(mask_hi);
+    const u8 *rv;
+
+    // small cases.
+    if (buf_end - buf <= 64) {
+        rv = shortShufti512(wide_mask_lo, wide_mask_hi, buf, buf_end, low4bits,
+                            zeroes);
+        return rv ? rv : buf_end;
+    }
+
+    assert(buf_end - buf >= 64);
+
+    // Preconditioning: most of the time our buffer won't be aligned.
+    if ((uintptr_t)buf % 64) {
+        rv = shortShufti512(wide_mask_lo, wide_mask_hi, buf,
+                            ROUNDUP_PTR(buf, 64), low4bits, zeroes);
+        if (rv) {
+            return rv;
+        }
+        buf = ROUNDUP_PTR(buf, 64);
+    }
+
+    const u8 *last_block = ROUNDDOWN_PTR(buf_end, 64);
+    while (buf < last_block) {
+        m512 lchars = load512(buf);
+        rv = fwdBlock512(wide_mask_lo, wide_mask_hi, lchars, buf, low4bits,
+                         zeroes);
+        if (rv) {
+            return rv;
+        }
+        buf += 64;
+    }
+
+    if (buf == buf_end) {
+        goto done;
+    }
+
+    // Use an unaligned load to mop up the last 64 bytes and get an accurate
+    // picture to buf_end.
+    assert(buf <= buf_end && buf >= buf_end - 64);
+    m512 chars = loadu512(buf_end - 64);
+    rv = fwdBlock512(wide_mask_lo, wide_mask_hi, chars, buf_end - 64, low4bits,
+                     zeroes);
+    if (rv) {
+        return rv;
+    }
+done:
+    return buf_end;
+}
+
+static really_inline
+const u8 *lastMatch64(const u8 *buf, u64a z) {
+    DEBUG_PRINTF("z 0x%016llx\n", z);
+    if (unlikely(z != ~0ULL)) {
+        u32 pos = clz64(~z);
+        DEBUG_PRINTF("buf=%p, pos=%u\n", buf, pos);
+        return buf + (63 - pos);
+    } else {
+        return NULL; // no match
+    }
+}
+
+static really_inline
+const u8 *rshortShufti512(m512 mask_lo, m512 mask_hi, const u8 *buf,
+                          const u8 *buf_end, const m512 low4bits,
+                          const m512 zeroes) {
+    DEBUG_PRINTF("short %p len %zu\n", buf, buf_end - buf);
+    uintptr_t len = buf_end - buf;
+    assert(len <= 64);
+
+    // load mask
+    u64a k = (~0ULL) >> (64 - len);
+    DEBUG_PRINTF("load mask 0x%016llx\n", k);
+
+    m512 chars = loadu_maskz_m512(k, buf);
+
+    u64a z = block(mask_lo, mask_hi, chars, low4bits, zeroes);
+
+    // reuse the load mask to indicate valid bytes
+    return lastMatch64(buf, z | ~k);
+}
+
+static really_inline
+const u8 *revBlock512(m512 mask_lo, m512 mask_hi, m512 chars, const u8 *buf,
+                      const m512 low4bits, const m512 zeroes) {
+    m512 c_lo = pshufb_m512(mask_lo, and512(chars, low4bits));
+    m512 c_hi = pshufb_m512(mask_hi,
+                            rshift64_m512(andnot512(low4bits, chars), 4));
+    m512 t = and512(c_lo, c_hi);
+
+#ifdef DEBUG
+    DEBUG_PRINTF(" chars: "); dumpMsk512AsChars(chars); printf("\n");
+    DEBUG_PRINTF("  char: "); dumpMsk512(chars); printf("\n");
+    DEBUG_PRINTF("  c_lo: "); dumpMsk512(c_lo); printf("\n");
+    DEBUG_PRINTF("  c_hi: "); dumpMsk512(c_hi); printf("\n");
+    DEBUG_PRINTF("     t: "); dumpMsk512(t); printf("\n");
+#endif
+
+    u64a z = eq512mask(t, zeroes);
+    return lastMatch64(buf, z);
+}
+
+/* takes 128 bit masks, but operates on 512 bits of data */
+const u8 *rshuftiExec(m128 mask_lo, m128 mask_hi, const u8 *buf,
+                      const u8 *buf_end) {
+    DEBUG_PRINTF("buf %p buf_end %p\n", buf, buf_end);
+    assert(buf && buf_end);
+    assert(buf < buf_end);
+
+    const m512 low4bits = set64x8(0xf);
+    const m512 zeroes = zeroes512();
+    const m512 wide_mask_lo = set4x128(mask_lo);
+    const m512 wide_mask_hi = set4x128(mask_hi);
+    const u8 *rv;
+
+    if (buf_end - buf < 64) {
+        rv = rshortShufti512(wide_mask_lo, wide_mask_hi, buf, buf_end, low4bits,
+                             zeroes);
+        return rv ? rv : buf - 1;
+    }
+
+    if (ROUNDDOWN_PTR(buf_end, 64) != buf_end) {
+        // peel off unaligned portion
+        assert(buf_end - buf >= 64);
+        DEBUG_PRINTF("start\n");
+        rv = rshortShufti512(wide_mask_lo, wide_mask_hi,
+                             ROUNDDOWN_PTR(buf_end, 64), buf_end, low4bits,
+                             zeroes);
+        if (rv) {
+            return rv;
+        }
+        buf_end = ROUNDDOWN_PTR(buf_end, 64);
+    }
+
+    const u8 *last_block = ROUNDUP_PTR(buf, 64);
+    while (buf_end > last_block) {
+        buf_end -= 64;
+        m512 lchars = load512(buf_end);
+        rv = revBlock512(wide_mask_lo, wide_mask_hi, lchars, buf_end, low4bits,
+                         zeroes);
+        if (rv) {
+            return rv;
+        }
+    }
+
+    if (buf_end == buf) {
+        goto done;
+    }
+
+    // Use an unaligned load to mop up the last 64 bytes and get an accurate
+    // picture to buf.
+    m512 chars = loadu512(buf);
+    rv = revBlock512(wide_mask_lo, wide_mask_hi, chars, buf, low4bits, zeroes);
+    if (rv) {
+        return rv;
+    }
+done:
+    return buf - 1;
+}
+
+static really_inline
+const u8 *fwdBlock2(m512 mask1_lo, m512 mask1_hi, m512 mask2_lo, m512 mask2_hi,
+                    m512 chars, const u8 *buf, const m512 low4bits,
+                    const m512 ones, __mmask64 k) {
+    DEBUG_PRINTF("buf %p %.64s\n", buf, buf);
+    m512 chars_lo = and512(chars, low4bits);
+    m512 chars_hi = rshift64_m512(andnot512(low4bits, chars), 4);
+    m512 c_lo = maskz_pshufb_m512(k, mask1_lo, chars_lo);
+    m512 c_hi = maskz_pshufb_m512(k, mask1_hi, chars_hi);
+    m512 t = or512(c_lo, c_hi);
+
+#ifdef DEBUG
+    DEBUG_PRINTF(" chars: "); dumpMsk512AsChars(chars); printf("\n");
+    DEBUG_PRINTF("  char: "); dumpMsk512(chars); printf("\n");
+    DEBUG_PRINTF("  c_lo: "); dumpMsk512(c_lo); printf("\n");
+    DEBUG_PRINTF("  c_hi: "); dumpMsk512(c_hi); printf("\n");
+    DEBUG_PRINTF("     t: "); dumpMsk512(t); printf("\n");
+#endif
+
+    m512 c2_lo = maskz_pshufb_m512(k, mask2_lo, chars_lo);
+    m512 c2_hi = maskz_pshufb_m512(k, mask2_hi, chars_hi);
+    m512 t2 = or512(t, rshift128_m512(or512(c2_lo, c2_hi), 1));
+
+#ifdef DEBUG
+    DEBUG_PRINTF(" c2_lo: "); dumpMsk512(c2_lo); printf("\n");
+    DEBUG_PRINTF(" c2_hi: "); dumpMsk512(c2_hi); printf("\n");
+    DEBUG_PRINTF("    t2: "); dumpMsk512(t2); printf("\n");
+#endif
+    u64a z = eq512mask(t2, ones);
+
+    return firstMatch64(buf, z | ~k);
+}
+
+static really_inline
+const u8 *shortDoubleShufti512(m512 mask1_lo, m512 mask1_hi, m512 mask2_lo,
+                               m512 mask2_hi, const u8 *buf, const u8 *buf_end,
+                               const m512 low4bits, const m512 ones) {
+    DEBUG_PRINTF("short %p len %zu\n", buf, buf_end - buf);
+    uintptr_t len = buf_end - buf;
+    assert(len <= 64);
+
+    u64a k = (~0ULL) >> (64 - len);
+    DEBUG_PRINTF("load mask 0x%016llx\n", k);
+
+    m512 chars = loadu_mask_m512(ones, k, buf);
+
+    const u8 *rv = fwdBlock2(mask1_lo, mask1_hi, mask2_lo, mask2_hi, chars, buf,
+                             low4bits, ones, k);
+
+    return rv;
+}
+
+/* takes 128 bit masks, but operates on 512 bits of data */
+const u8 *shuftiDoubleExec(m128 mask1_lo, m128 mask1_hi,
+                           m128 mask2_lo, m128 mask2_hi,
+                           const u8 *buf, const u8 *buf_end) {
+    /* we should always have at least 16 bytes */
+    assert(buf_end - buf >= 16);
+    DEBUG_PRINTF("buf %p len %zu\n", buf, buf_end - buf);
+
+    const m512 ones = ones512();
+    const m512 low4bits = set64x8(0xf);
+    const m512 wide_mask1_lo = set4x128(mask1_lo);
+    const m512 wide_mask1_hi = set4x128(mask1_hi);
+    const m512 wide_mask2_lo = set4x128(mask2_lo);
+    const m512 wide_mask2_hi = set4x128(mask2_hi);
+    const u8 *rv;
+
+    if (buf_end - buf <= 64) {
+        rv = shortDoubleShufti512(wide_mask1_lo, wide_mask1_hi, wide_mask2_lo,
+                                  wide_mask2_hi, buf, buf_end, low4bits, ones);
+        DEBUG_PRINTF("rv %p\n", rv);
+        return rv ? rv : buf_end;
+    }
+
+    // Preconditioning: most of the time our buffer won't be aligned.
+    if ((uintptr_t)buf % 64) {
+        rv = shortDoubleShufti512(wide_mask1_lo, wide_mask1_hi, wide_mask2_lo,
+                                  wide_mask2_hi, buf, ROUNDUP_PTR(buf, 64),
+                                  low4bits, ones);
+        if (rv) {
+            return rv;
+        }
+
+        buf = ROUNDUP_PTR(buf, 64);
+    }
+
+    const u8 *last_block = buf_end - 64;
+    while (buf < last_block) {
+        m512 lchars = load512(buf);
+        rv = fwdBlock2(wide_mask1_lo, wide_mask1_hi, wide_mask2_lo,
+                       wide_mask2_hi, lchars, buf, low4bits, ones, ~0);
+        if (rv) {
+            return rv;
+        }
+        buf += 64;
+    }
+
+    // Use an unaligned load to mop up the last 64 bytes and get an accurate
+    // picture to buf_end.
+    m512 chars = loadu512(buf_end - 64);
+    rv = fwdBlock2(wide_mask1_lo, wide_mask1_hi, wide_mask2_lo, wide_mask2_hi,
+                   chars, buf_end - 64, low4bits, ones, ~0);
+    if (rv) {
+        return rv;
+    }
+
+    return buf_end;
+}
+#endif
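
For reference, a scalar model of the nibble-table lookup that the 512-bit block() above vectorises; this is not part of the patch, and shufti_scalar and its parameter names are illustrative only. A byte c belongs to the character class exactly when mask_lo[c & 0xf] & mask_hi[c >> 4] is nonzero, which is why block() reports non-matching bytes as set bits (eq512mask against zeroes) and firstMatch64() recovers the first match with ctz64(~z).

#include <stddef.h>
#include <stdint.h>

/* Scalar equivalent of one shufti block: split each byte into its low and
 * high nibble, look each nibble up in a 16-byte table, and AND the two
 * results. A nonzero value means the byte is in the character class. */
static const uint8_t *shufti_scalar(const uint8_t mask_lo[16],
                                    const uint8_t mask_hi[16],
                                    const uint8_t *buf, size_t len) {
    for (size_t i = 0; i < len; i++) {
        uint8_t c = buf[i];
        if (mask_lo[c & 0xf] & mask_hi[c >> 4]) {
            return buf + i; /* first byte in the class */
        }
    }
    return NULL; /* no match; forward callers return buf_end instead */
}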