From 213ac75e0ed547fd72cdfb9240053f56ec6dd5ee Mon Sep 17 00:00:00 2001 From: Yoan Picchi Date: Mon, 6 Jan 2025 17:50:18 +0000 Subject: [PATCH] Fix double shufti reporting false positives Double shufti used to offset one vector, resulting in losing one character at the end of every vector. This was replaced by a magic value indicating a match. This meant that if the first char of a pattern fell on the last char of a vector, double shufti would assume the second character is present and report a match. This patch fixes it by keeping the previous vector and feeding its data to the new one when we shift it, preventing any loss of data. Signed-off-by: Yoan Picchi --- src/nfa/arm/shufti.hpp | 17 +++++-- src/nfa/ppc64el/shufti.hpp | 17 +++++-- src/nfa/shufti_simd.hpp | 75 +++++++++++++++++++++++---- src/nfa/shufti_sve.hpp | 101 +++++++++++++++++++++++++------------ src/nfa/x86/shufti.hpp | 43 ++++++++++++++-- unit/internal/shufti.cpp | 7 ++- 6 files changed, 203 insertions(+), 57 deletions(-) diff --git a/src/nfa/arm/shufti.hpp b/src/nfa/arm/shufti.hpp index e710fd16..97931f4d 100644 --- a/src/nfa/arm/shufti.hpp +++ b/src/nfa/arm/shufti.hpp @@ -46,7 +46,7 @@ const SuperVector blockSingleMask(SuperVector mask_lo, SuperVector mask template static really_inline -SuperVector blockDoubleMask(SuperVector mask1_lo, SuperVector mask1_hi, SuperVector mask2_lo, SuperVector mask2_hi, SuperVector chars) { +SuperVector blockDoubleMask(SuperVector mask1_lo, SuperVector mask1_hi, SuperVector mask2_lo, SuperVector mask2_hi, SuperVector *inout_t1, SuperVector chars) { const SuperVector low4bits = SuperVector::dup_u8(0xf); SuperVector chars_lo = chars & low4bits; @@ -57,18 +57,25 @@ SuperVector blockDoubleMask(SuperVector mask1_lo, SuperVector mask1_hi, c1_lo.print8("c1_lo"); SuperVector c1_hi = mask1_hi.template pshufb(chars_hi); c1_hi.print8("c1_hi"); - SuperVector t1 = c1_lo | c1_hi; - t1.print8("t1"); + SuperVector new_t1 = c1_lo | c1_hi; + // t1 is the match mask for the first char of 
the patterns + new_t1.print8("t1"); SuperVector c2_lo = mask2_lo.template pshufb(chars_lo); c2_lo.print8("c2_lo"); SuperVector c2_hi = mask2_hi.template pshufb(chars_hi); c2_hi.print8("c2_hi"); SuperVector t2 = c2_lo | c2_hi; + // t2 is the match mask for the second char of the patterns t2.print8("t2"); - t2.template vshr_128_imm<1>().print8("t2.vshr_128(1)"); - SuperVector t = t1 | (t2.template vshr_128_imm<1>()); + + // offset t1 so it aligns with t2. The hole created by the offset is filled + // with the last elements of the previous t1 so no info is lost. + // Bits set to 0 lining up indicate a match. + SuperVector t = (new_t1.alignr(*inout_t1, S-1)) | t2; t.print8("t"); + *inout_t1 = new_t1; + return !t.eq(SuperVector::Ones()); } \ No newline at end of file diff --git a/src/nfa/ppc64el/shufti.hpp b/src/nfa/ppc64el/shufti.hpp index dedeb52d..580dbe40 100644 --- a/src/nfa/ppc64el/shufti.hpp +++ b/src/nfa/ppc64el/shufti.hpp @@ -48,7 +48,7 @@ const SuperVector blockSingleMask(SuperVector mask_lo, SuperVector mask template static really_inline -SuperVector blockDoubleMask(SuperVector mask1_lo, SuperVector mask1_hi, SuperVector mask2_lo, SuperVector mask2_hi, SuperVector chars) { +SuperVector blockDoubleMask(SuperVector mask1_lo, SuperVector mask1_hi, SuperVector mask2_lo, SuperVector mask2_hi, SuperVector *inout_t1, SuperVector chars) { const SuperVector low4bits = SuperVector::dup_u8(0xf); SuperVector chars_lo = chars & low4bits; @@ -59,18 +59,25 @@ SuperVector blockDoubleMask(SuperVector mask1_lo, SuperVector mask1_hi, c1_lo.print8("c1_lo"); SuperVector c1_hi = mask1_hi.template pshufb(chars_hi); c1_hi.print8("c1_hi"); - SuperVector t1 = c1_lo | c1_hi; - t1.print8("t1"); + SuperVector new_t1 = c1_lo | c1_hi; + // t1 is the match mask for the first char of the patterns + new_t1.print8("t1"); SuperVector c2_lo = mask2_lo.template pshufb(chars_lo); c2_lo.print8("c2_lo"); SuperVector c2_hi = mask2_hi.template pshufb(chars_hi); c2_hi.print8("c2_hi"); SuperVector t2 = 
c2_lo | c2_hi; + // t2 is the match mask for the second char of the patterns t2.print8("t2"); - t2.template vshr_128_imm<1>().print8("t2.vshr_128(1)"); - SuperVector t = t1 | (t2.template vshr_128_imm<1>()); + + // offset t1 so it aligns with t2. The hole created by the offset is filled + // with the last elements of the previous t1 so no info is lost. + // If bits with value 0 line up, it indicates a match. + SuperVector t = (new_t1.alignr(*inout_t1, S-1)) | t2; t.print8("t"); + *inout_t1 = new_t1; + return t.eq(SuperVector::Ones()); } diff --git a/src/nfa/shufti_simd.hpp b/src/nfa/shufti_simd.hpp index 1a00b87b..eede8a43 100644 --- a/src/nfa/shufti_simd.hpp +++ b/src/nfa/shufti_simd.hpp @@ -50,7 +50,7 @@ static really_inline const SuperVector blockSingleMask(SuperVector mask_lo, SuperVector mask_hi, SuperVector chars); template static really_inline -SuperVector blockDoubleMask(SuperVector mask1_lo, SuperVector mask1_hi, SuperVector mask2_lo, SuperVector mask2_hi, SuperVector chars); +SuperVector blockDoubleMask(SuperVector mask1_lo, SuperVector mask1_hi, SuperVector mask2_lo, SuperVector mask2_hi, SuperVector *inout_first_char_mask, SuperVector chars); #if defined(VS_SIMDE_BACKEND) #include "x86/shufti.hpp" @@ -82,11 +82,13 @@ const u8 *revBlock(SuperVector mask_lo, SuperVector mask_hi, SuperVector static really_inline -const u8 *fwdBlockDouble(SuperVector mask1_lo, SuperVector mask1_hi, SuperVector mask2_lo, SuperVector mask2_hi, SuperVector chars, const u8 *buf) { +const u8 *fwdBlockDouble(SuperVector mask1_lo, SuperVector mask1_hi, SuperVector mask2_lo, SuperVector mask2_hi, SuperVector *prev_first_char_mask, SuperVector chars, const u8 *buf) { - SuperVector mask = blockDoubleMask(mask1_lo, mask1_hi, mask2_lo, mask2_hi, chars); + SuperVector mask = blockDoubleMask(mask1_lo, mask1_hi, mask2_lo, mask2_hi, prev_first_char_mask, chars); - return first_zero_match_inverted(buf, mask); + // By shifting first_char_mask instead of the legacy t2 mask, we would report +
// on the second char instead of the first. we offset the buf to compensate. + return first_zero_match_inverted(buf-1, mask); } template @@ -196,6 +198,38 @@ const u8 *rshuftiExecReal(m128 mask_lo, m128 mask_hi, const u8 *buf, const u8 *b return buf - 1; } +// A match on the last char is valid if and only if it match a single char +// pattern, not a char pair. So we manually check the last match with the +// wildcard patterns. +template +static really_inline +const u8 *check_last_byte(SuperVector mask2_lo, SuperVector mask2_hi, + SuperVector mask, uint8_t mask_len, const u8 *buf_end) { + uint8_t last_elem = mask.u.u8[mask_len - 1]; + + SuperVector reduce = mask2_lo | mask2_hi; +#if defined(HAVE_SIMD_512_BITS) + if constexpr (S >= 64) + reduce = reduce | reduce.vshr_512(32); +#endif +#if defined(HAVE_SIMD_256_BITS) + if constexpr (S >= 32) + reduce = reduce | reduce.vshr_256(16); +#endif + reduce = reduce | reduce.vshr_128(8); + reduce = reduce | reduce.vshr_64(32); + reduce = reduce | reduce.vshr_32(16); + reduce = reduce | reduce.vshr_16(8); + uint8_t match_inverted = reduce.u.u8[0] | last_elem; + + // if 0xff, then no match + int match = match_inverted != 0xff; + if(match) { + return buf_end - 1; + } + return NULL; +} + template const u8 *shuftiDoubleExecReal(m128 mask1_lo, m128 mask1_hi, m128 mask2_lo, m128 mask2_hi, const u8 *buf, const u8 *buf_end) { @@ -216,6 +250,8 @@ const u8 *shuftiDoubleExecReal(m128 mask1_lo, m128 mask1_hi, m128 mask2_lo, m128 __builtin_prefetch(d + 2*64); __builtin_prefetch(d + 3*64); __builtin_prefetch(d + 4*64); + + SuperVector first_char_mask = SuperVector::Ones(); DEBUG_PRINTF("start %p end %p \n", d, buf_end); assert(d < buf_end); if (d + S <= buf_end) { @@ -223,33 +259,54 @@ const u8 *shuftiDoubleExecReal(m128 mask1_lo, m128 mask1_hi, m128 mask2_lo, m128 DEBUG_PRINTF("until aligned %p \n", ROUNDUP_PTR(d, S)); if (!ISALIGNED_N(d, S)) { SuperVector chars = SuperVector::loadu(d); - rv = fwdBlockDouble(wide_mask1_lo, wide_mask1_hi, 
wide_mask2_lo, wide_mask2_hi, chars, d); + rv = fwdBlockDouble(wide_mask1_lo, wide_mask1_hi, wide_mask2_lo, wide_mask2_hi, &first_char_mask, chars, d); DEBUG_PRINTF("rv %p \n", rv); if (rv) return rv; d = ROUNDUP_PTR(d, S); + ptrdiff_t offset = d - buf; + first_char_mask.print8("inout_c1"); + if constexpr (S == 16) { + first_char_mask = first_char_mask.vshl_128(S - offset); + } +#ifdef HAVE_SIMD_256_BITS + else if constexpr (S == 32) { + first_char_mask = first_char_mask.vshl_256(S - offset); + } +#endif +#ifdef HAVE_SIMD_512_BITS + else if constexpr (S == 64) { + first_char_mask = first_char_mask.vshl_512(S - offset); + } +#endif + first_char_mask.print8("inout_c1 shifted"); } + first_char_mask = SuperVector::Ones(); while(d + S <= buf_end) { __builtin_prefetch(d + 64); DEBUG_PRINTF("d %p \n", d); SuperVector chars = SuperVector::load(d); - rv = fwdBlockDouble(wide_mask1_lo, wide_mask1_hi, wide_mask2_lo, wide_mask2_hi, chars, d); - if (rv) return rv; + rv = fwdBlockDouble(wide_mask1_lo, wide_mask1_hi, wide_mask2_lo, wide_mask2_hi, &first_char_mask, chars, d); + if (rv && rv < buf_end - 1) return rv; d += S; } } + ptrdiff_t last_mask_len = S; DEBUG_PRINTF("tail d %p e %p \n", d, buf_end); // finish off tail if (d != buf_end) { SuperVector chars = SuperVector::loadu(d); - rv = fwdBlockDouble(wide_mask1_lo, wide_mask1_hi, wide_mask2_lo, wide_mask2_hi, chars, d); + rv = fwdBlockDouble(wide_mask1_lo, wide_mask1_hi, wide_mask2_lo, wide_mask2_hi, &first_char_mask, chars, d); DEBUG_PRINTF("rv %p \n", rv); - if (rv && rv < buf_end) return rv; + if (rv && rv < buf_end - 1) return rv; + last_mask_len = buf_end - d; } + rv = check_last_byte(wide_mask2_lo, wide_mask2_hi, first_char_mask, last_mask_len, buf_end); + if (rv) return rv; return buf_end; } diff --git a/src/nfa/shufti_sve.hpp b/src/nfa/shufti_sve.hpp index 76f1e7ad..3e1bc86c 100644 --- a/src/nfa/shufti_sve.hpp +++ b/src/nfa/shufti_sve.hpp @@ -153,7 +153,7 @@ const u8 *rshuftiExec(m128 mask_lo, m128 mask_hi, const u8 
*buf, static really_inline svbool_t doubleMatched(svuint8_t mask1_lo, svuint8_t mask1_hi, svuint8_t mask2_lo, svuint8_t mask2_hi, - const u8 *buf, const svbool_t pg) { + svuint8_t* inout_t1, const u8 *buf, const svbool_t pg) { svuint8_t vec = svld1_u8(pg, buf); svuint8_t chars_lo = svand_x(svptrue_b8(), vec, (uint8_t)0xf); @@ -161,38 +161,59 @@ svbool_t doubleMatched(svuint8_t mask1_lo, svuint8_t mask1_hi, svuint8_t c1_lo = svtbl(mask1_lo, chars_lo); svuint8_t c1_hi = svtbl(mask1_hi, chars_hi); - svuint8_t t1 = svorr_x(svptrue_b8(), c1_lo, c1_hi); + svuint8_t new_t1 = svorr_z(svptrue_b8(), c1_lo, c1_hi); svuint8_t c2_lo = svtbl(mask2_lo, chars_lo); svuint8_t c2_hi = svtbl(mask2_hi, chars_hi); - svuint8_t t2 = svext(svorr_z(pg, c2_lo, c2_hi), svdup_u8(0), 1); + svuint8_t t2 = svorr_x(svptrue_b8(), c2_lo, c2_hi); - svuint8_t t = svorr_x(svptrue_b8(), t1, t2); + // shift t1 left by one and feeds in the last element from the previous t1 + uint8_t last_elem = svlastb(svptrue_b8(), *inout_t1); + svuint8_t merged_t1 = svinsr(new_t1, last_elem); + svuint8_t t = svorr_x(svptrue_b8(), merged_t1, t2); + *inout_t1 = new_t1; return svnot_z(svptrue_b8(), svcmpeq(svptrue_b8(), t, (uint8_t)0xff)); } +static really_inline +const u8 *check_last_byte(svuint8_t mask2_lo, svuint8_t mask2_hi, + uint8_t last_elem, const u8 *buf_end) { + uint8_t wild_lo = svorv(svptrue_b8(), mask2_lo); + uint8_t wild_hi = svorv(svptrue_b8(), mask2_hi); + uint8_t match_inverted = wild_lo | wild_hi | last_elem; + int match = match_inverted != 0xff; + if(match) { + return buf_end - 1; + } + return NULL; +} + static really_inline const u8 *dshuftiOnce(svuint8_t mask1_lo, svuint8_t mask1_hi, svuint8_t mask2_lo, svuint8_t mask2_hi, - const u8 *buf, const u8 *buf_end) { + svuint8_t *inout_t1, const u8 *buf, const u8 *buf_end) { DEBUG_PRINTF("start %p end %p\n", buf, buf_end); assert(buf < buf_end); DEBUG_PRINTF("l = %td\n", buf_end - buf); svbool_t pg = svwhilelt_b8_s64(0, buf_end - buf); svbool_t matched = 
doubleMatched(mask1_lo, mask1_hi, mask2_lo, mask2_hi, - buf, pg); - return accelSearchCheckMatched(buf, matched); + inout_t1, buf, pg); + // doubleMatched returns the match position of the second char, but here we + // return the position of the first char, hence the buffer offset + return accelSearchCheckMatched(buf - 1, matched); } static really_inline const u8 *dshuftiLoopBody(svuint8_t mask1_lo, svuint8_t mask1_hi, svuint8_t mask2_lo, svuint8_t mask2_hi, - const u8 *buf) { + svuint8_t *inout_t1, const u8 *buf) { DEBUG_PRINTF("start %p end %p\n", buf, buf + svcntb()); svbool_t matched = doubleMatched(mask1_lo, mask1_hi, mask2_lo, mask2_hi, - buf, svptrue_b8()); - return accelSearchCheckMatched(buf, matched); + inout_t1, buf, svptrue_b8()); + // doubleMatched returns the match position of the second char, but here we + // return the position of the first char, hence the buffer offset + return accelSearchCheckMatched(buf - 1, matched); } static really_inline @@ -200,31 +221,47 @@ const u8 *dshuftiSearch(svuint8_t mask1_lo, svuint8_t mask1_hi, svuint8_t mask2_lo, svuint8_t mask2_hi, const u8 *buf, const u8 *buf_end) { assert(buf < buf_end); + svuint8_t inout_t1 = svdup_u8(0xff); size_t len = buf_end - buf; - if (len <= svcntb()) { - return dshuftiOnce(mask1_lo, mask1_hi, - mask2_lo, mask2_hi, buf, buf_end); - } - // peel off first part to align to the vector size - const u8 *aligned_buf = ROUNDUP_PTR(buf, svcntb_pat(SV_POW2)); - assert(aligned_buf < buf_end); - if (buf != aligned_buf) { - const u8 *ptr = dshuftiLoopBody(mask1_lo, mask1_hi, - mask2_lo, mask2_hi, buf); - if (ptr) return ptr; - } - buf = aligned_buf; - size_t loops = (buf_end - buf) / svcntb(); - DEBUG_PRINTF("loops %zu \n", loops); - for (size_t i = 0; i < loops; i++, buf += svcntb()) { - const u8 *ptr = dshuftiLoopBody(mask1_lo, mask1_hi, - mask2_lo, mask2_hi, buf); - if (ptr) return ptr; + if (len > svcntb()) { + // peel off first part to align to the vector size + const u8 *aligned_buf = ROUNDUP_PTR(buf, 
svcntb_pat(SV_POW2)); + assert(aligned_buf < buf_end); + if (buf != aligned_buf) { + const u8 *ptr = dshuftiLoopBody(mask1_lo, mask1_hi, mask2_lo, + mask2_hi, &inout_t1, buf); + if (ptr) return ptr; + // The last match in inout won't line up with the next round as we + // use an overlap. We need to set inout according to the last + // unique-searched char. + size_t offset = aligned_buf - buf; + uint8_t last_unique_elem = + svlastb(svwhilelt_b8(0UL, offset), inout_t1); + inout_t1 = svdup_u8(last_unique_elem); + } + buf = aligned_buf; + size_t loops = (buf_end - buf) / svcntb(); + DEBUG_PRINTF("loops %zu \n", loops); + for (size_t i = 0; i < loops; i++, buf += svcntb()) { + const u8 *ptr = dshuftiLoopBody(mask1_lo, mask1_hi, mask2_lo, + mask2_hi, &inout_t1, buf); + if (ptr) return ptr; + } + if (buf == buf_end) { + uint8_t last_elem = svlastb(svptrue_b8(), inout_t1); + return check_last_byte(mask2_lo, mask2_hi, last_elem, buf_end); + } } DEBUG_PRINTF("buf %p buf_end %p \n", buf, buf_end); - return buf == buf_end ? 
NULL : dshuftiLoopBody(mask1_lo, mask1_hi, - mask2_lo, mask2_hi, - buf_end - svcntb()); + + len = buf_end - buf; + const u8 *ptr = dshuftiOnce(mask1_lo, mask1_hi, + mask2_lo, mask2_hi, &inout_t1, buf, buf_end); + if (ptr) return ptr; + uint8_t last_elem = + svlastb(svwhilelt_b8(0UL, len), inout_t1); + return check_last_byte(mask2_lo, mask2_hi, last_elem, buf_end); + } const u8 *shuftiDoubleExec(m128 mask1_lo, m128 mask1_hi, diff --git a/src/nfa/x86/shufti.hpp b/src/nfa/x86/shufti.hpp index 6fb34b2f..88aa4904 100644 --- a/src/nfa/x86/shufti.hpp +++ b/src/nfa/x86/shufti.hpp @@ -46,7 +46,7 @@ const SuperVector blockSingleMask(SuperVector mask_lo, SuperVector mask template static really_inline -SuperVector blockDoubleMask(SuperVector mask1_lo, SuperVector mask1_hi, SuperVector mask2_lo, SuperVector mask2_hi, SuperVector chars) { +SuperVector blockDoubleMask(SuperVector mask1_lo, SuperVector mask1_hi, SuperVector mask2_lo, SuperVector mask2_hi, SuperVector *inout_c1, SuperVector chars) { const SuperVector low4bits = SuperVector::dup_u8(0xf); SuperVector chars_lo = chars & low4bits; @@ -57,18 +57,51 @@ SuperVector blockDoubleMask(SuperVector mask1_lo, SuperVector mask1_hi, c1_lo.print8("c1_lo"); SuperVector c1_hi = mask1_hi.pshufb(chars_hi); c1_hi.print8("c1_hi"); - SuperVector c1 = c1_lo | c1_hi; - c1.print8("c1"); + SuperVector new_c1 = c1_lo | c1_hi; + // c1 is the match mask for the first char of the patterns + new_c1.print8("c1"); SuperVector c2_lo = mask2_lo.pshufb(chars_lo); c2_lo.print8("c2_lo"); SuperVector c2_hi = mask2_hi.pshufb(chars_hi); c2_hi.print8("c2_hi"); SuperVector c2 = c2_lo | c2_hi; + // c2 is the match mask for the second char of the patterns c2.print8("c2"); - c2.template vshr_128_imm<1>().print8("c2.vshr_128(1)"); - SuperVector c = c1 | (c2.template vshr_128_imm<1>()); + + // We want to shift the whole vector left by 1 and insert the last element of inout_c1. 
+ // Due to the lack of direct instructions to insert, extract and concatenate vectors, + // we need to store and load the vector. + uint8_t tmp_buf[2*S]; + SuperVector offset_c1; + if constexpr (S == 16) { + _mm_storeu_si128(reinterpret_cast(&tmp_buf[0]), inout_c1->u.v128[0]); + _mm_storeu_si128(reinterpret_cast(&tmp_buf[S]), new_c1.u.v128[0]); + offset_c1 = SuperVector(_mm_loadu_si128(reinterpret_cast(&tmp_buf[S-1]))); + } +#ifdef HAVE_AVX2 + else if constexpr (S == 32) { + _mm256_storeu_si256(reinterpret_cast(&tmp_buf[0]), inout_c1->u.v256[0]); + _mm256_storeu_si256(reinterpret_cast(&tmp_buf[S]), new_c1.u.v256[0]); + offset_c1 = SuperVector(_mm256_loadu_si256(reinterpret_cast(&tmp_buf[S-1]))); + } +#endif +#ifdef HAVE_AVX512 + else if constexpr (S == 64) { + _mm512_storeu_si512(reinterpret_cast(&tmp_buf[0]), inout_c1->u.v512[0]); + _mm512_storeu_si512(reinterpret_cast(&tmp_buf[S]), new_c1.u.v512[0]); + offset_c1 = SuperVector(_mm512_load_si512(reinterpret_cast(&tmp_buf[S-1]))); + } +#endif + offset_c1.print8("offset c1"); + + // offset c1 so it aligns with c2. The hole created by the offset is filled + // with the last elements of the previous c1 so no info is lost. + // If bits with value 0 line up, it indicates a match. + SuperVector c = offset_c1 | c2; c.print8("c"); + *inout_c1 = new_c1; + return c.eq(SuperVector::Ones()); } diff --git a/unit/internal/shufti.cpp b/unit/internal/shufti.cpp index e7d8532f..af4b633f 100644 --- a/unit/internal/shufti.cpp +++ b/unit/internal/shufti.cpp @@ -899,7 +899,12 @@ TEST(DoubleShufti, ExecMatchMixed3) { const u8 *rv = shuftiDoubleExec(lo1, hi1, lo2, hi2, reinterpret_cast(t2), reinterpret_cast(t2) + len); - ASSERT_EQ(reinterpret_cast(&t2[len - i]), rv); + if(i < 2) { + // i=0 is "xy" out of buffer. i=1 is "x" in buffer but not "y" + ASSERT_EQ(reinterpret_cast(t2 + len), rv); + }else { + ASSERT_EQ(reinterpret_cast(&t2[len - i]), rv); + } } }