diff --git a/src/nfa/truffle.c b/src/nfa/truffle.c index 331ae6d6..d31b1a56 100644 --- a/src/nfa/truffle.c +++ b/src/nfa/truffle.c @@ -231,7 +231,7 @@ const u8 *rtruffleExec(m128 shuf_mask_lo_highclear, return buf - 1; } -#else +#elif !defined(HAVE_AVX512) // AVX2 @@ -425,4 +425,184 @@ const u8 *rtruffleExec(m128 shuf_mask_lo_highclear, return buf - 1; } +#else // AVX512 + +static really_inline +const u8 *lastMatch(const u8 *buf, u64a z) { + if (unlikely(z != ~0ULL)) { + u64a pos = clz64(~z); + assert(pos < 64); + return buf + (63 - pos); + } + + return NULL; // no match +} + +static really_inline +const u8 *firstMatch(const u8 *buf, u64a z) { + if (unlikely(z != ~0ULL)) { + u64a pos = ctz64(~z); + assert(pos < 64); + DEBUG_PRINTF("pos %llu\n", pos); + return buf + pos; + } + + return NULL; // no match +} + +static really_inline +u64a block(m512 shuf_mask_lo_highclear, m512 shuf_mask_lo_highset, m512 v) { + m512 highconst = set64x8(0x80); + m512 shuf_mask_hi = set8x64(0x8040201008040201); + + // and now do the real work + m512 shuf1 = pshufb_m512(shuf_mask_lo_highclear, v); + m512 t1 = xor512(v, highconst); + m512 shuf2 = pshufb_m512(shuf_mask_lo_highset, t1); + m512 t2 = andnot512(highconst, rshift64_m512(v, 4)); + m512 shuf3 = pshufb_m512(shuf_mask_hi, t2); + m512 tmp = and512(or512(shuf1, shuf2), shuf3); + u64a z = eq512mask(tmp, zeroes512()); + + return z; +} + +static really_inline +const u8 *truffleMini(m512 shuf_mask_lo_highclear, m512 shuf_mask_lo_highset, + const u8 *buf, const u8 *buf_end) { + uintptr_t len = buf_end - buf; + assert(len <= 64); + + __mmask64 mask = (~0ULL) >> (64 - len); + + m512 chars = loadu_maskz_m512(mask, buf); + + u64a z = block(shuf_mask_lo_highclear, shuf_mask_lo_highset, chars); + + const u8 *rv = firstMatch(buf, z | ~mask); + + return rv; +} + +static really_inline +const u8 *fwdBlock(m512 shuf_mask_lo_highclear, m512 shuf_mask_lo_highset, + m512 v, const u8 *buf) { + u64a z = block(shuf_mask_lo_highclear, shuf_mask_lo_highset, v); + return firstMatch(buf, z); +} + +static really_inline +const u8 *revBlock(m512 shuf_mask_lo_highclear, m512 shuf_mask_lo_highset, + m512 v, const u8 *buf) { + u64a z = block(shuf_mask_lo_highclear, shuf_mask_lo_highset, v); + return lastMatch(buf, z); +} + +const u8 *truffleExec(m128 shuf_mask_lo_highclear, m128 shuf_mask_lo_highset, + const u8 *buf, const u8 *buf_end) { + DEBUG_PRINTF("len %zu\n", buf_end - buf); + const m512 wide_clear = set4x128(shuf_mask_lo_highclear); + const m512 wide_set = set4x128(shuf_mask_lo_highset); + + assert(buf && buf_end); + assert(buf < buf_end); + const u8 *rv; + + if (buf_end - buf <= 64) { + rv = truffleMini(wide_clear, wide_set, buf, buf_end); + return rv ? rv : buf_end; + } + + assert(buf_end - buf >= 64); + if ((uintptr_t)buf % 64) { + // Preconditioning: most of the time our buffer won't be aligned. + rv = truffleMini(wide_clear, wide_set, buf, ROUNDUP_PTR(buf, 64)); + if (rv) { + return rv; + } + buf = ROUNDUP_PTR(buf, 64); + } + const u8 *last_block = buf_end - 64; + while (buf < last_block) { + m512 lchars = load512(buf); + rv = fwdBlock(wide_clear, wide_set, lchars, buf); + if (rv) { + return rv; + } + buf += 64; + } + + // Use an unaligned load to mop up the last 64 bytes and get an accurate + // picture to buf_end. + assert(buf <= buf_end && buf >= buf_end - 64); + m512 chars = loadu512(buf_end - 64); + rv = fwdBlock(wide_clear, wide_set, chars, buf_end - 64); + if (rv) { + return rv; + } + return buf_end; +} + +static really_inline +const u8 *truffleRevMini(m512 shuf_mask_lo_highclear, m512 shuf_mask_lo_highset, + const u8 *buf, const u8 *buf_end) { + uintptr_t len = buf_end - buf; + assert(len < 64); + + __mmask64 mask = (~0ULL) >> (64 - len); + m512 chars = loadu_maskz_m512(mask, buf); + u64a z = block(shuf_mask_lo_highclear, shuf_mask_lo_highset, chars); + DEBUG_PRINTF("mask 0x%016llx z 0x%016llx\n", mask, z); + const u8 *rv = lastMatch(buf, z | ~mask); + + if (rv) { + return rv; + } + return buf - 1; +} + +const u8 *rtruffleExec(m128 shuf_mask_lo_highclear, m128 shuf_mask_lo_highset, + const u8 *buf, const u8 *buf_end) { + const m512 wide_clear = set4x128(shuf_mask_lo_highclear); + const m512 wide_set = set4x128(shuf_mask_lo_highset); + assert(buf && buf_end); + assert(buf < buf_end); + const u8 *rv; + + DEBUG_PRINTF("len %zu\n", buf_end - buf); + + if (buf_end - buf < 64) { + return truffleRevMini(wide_clear, wide_set, buf, buf_end); + } + + assert(buf_end - buf >= 64); + + // Preconditioning: most of the time our buffer won't be aligned. + m512 chars = loadu512(buf_end - 64); + rv = revBlock(wide_clear, wide_set, chars, buf_end - 64); + if (rv) { + return rv; + } + buf_end = (const u8 *)ROUNDDOWN_N((uintptr_t)buf_end, 64); + + const u8 *last_block = buf + 64; + while (buf_end > last_block) { + buf_end -= 64; + m512 lchars = load512(buf_end); + rv = revBlock(wide_clear, wide_set, lchars, buf_end); + if (rv) { + return rv; + } + } + + // Use an unaligned load to mop up the last 64 bytes and get an accurate + // picture to buf_end. + chars = loadu512(buf); + rv = revBlock(wide_clear, wide_set, chars, buf); + if (rv) { + return rv; + } + return buf - 1; +} + #endif