Fix shufti false positive on vector edge

If we look for a pattern like "ab" and the letter 'a' fall at the end of the vector,
then it was reporting a positive match, regardless of the second letter.
This patch fix this false positive, but slows shufti down by 16%.

Signed-off-by: Yoan Picchi <yoan.picchi@arm.com>
This commit is contained in:
Yoan Picchi 2024-10-28 18:05:18 +00:00
parent 5853cd39d2
commit 76949c215c
5 changed files with 62 additions and 37 deletions

View File

@ -46,7 +46,7 @@ const SuperVector<S> blockSingleMask(SuperVector<S> mask_lo, SuperVector<S> mask
template <uint16_t S>
static really_inline
SuperVector<S> blockDoubleMask(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi, SuperVector<S> mask2_lo, SuperVector<S> mask2_hi, SuperVector<S> chars) {
SuperVector<S> blockDoubleMask(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi, SuperVector<S> mask2_lo, SuperVector<S> mask2_hi, SuperVector<S> chars, SuperVector<S> offset_char) {
const SuperVector<S> low4bits = SuperVector<S>::dup_u8(0xf);
SuperVector<S> chars_lo = chars & low4bits;
@ -60,14 +60,18 @@ SuperVector<S> blockDoubleMask(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi,
SuperVector<S> t1 = c1_lo | c1_hi;
t1.print8("t1");
SuperVector<S> c2_lo = mask2_lo.template pshufb<true>(chars_lo);
SuperVector<S> chars_lo2 = offset_char & low4bits;
chars_lo.print8("chars_lo2");
SuperVector<S> chars_hi2 = offset_char.template vshr_64_imm<4>() & low4bits;
chars_hi.print8("chars_hi2");
SuperVector<S> c2_lo = mask2_lo.template pshufb<true>(chars_lo2);
c2_lo.print8("c2_lo");
SuperVector<S> c2_hi = mask2_hi.template pshufb<true>(chars_hi);
SuperVector<S> c2_hi = mask2_hi.template pshufb<true>(chars_hi2);
c2_hi.print8("c2_hi");
SuperVector<S> t2 = c2_lo | c2_hi;
t2.print8("t2");
t2.template vshr_128_imm<1>().print8("t2.vshr_128(1)");
SuperVector<S> t = t1 | (t2.template vshr_128_imm<1>());
SuperVector<S> t = t1 | t2;
t.print8("t");
return !t.eq(SuperVector<S>::Ones());

View File

@ -48,7 +48,7 @@ const SuperVector<S> blockSingleMask(SuperVector<S> mask_lo, SuperVector<S> mask
template <uint16_t S>
static really_inline
SuperVector<S> blockDoubleMask(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi, SuperVector<S> mask2_lo, SuperVector<S> mask2_hi, SuperVector<S> chars) {
SuperVector<S> blockDoubleMask(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi, SuperVector<S> mask2_lo, SuperVector<S> mask2_hi, SuperVector<S> chars, SuperVector<S> offset_char) {
const SuperVector<S> low4bits = SuperVector<S>::dup_u8(0xf);
SuperVector<S> chars_lo = chars & low4bits;
@ -62,14 +62,18 @@ SuperVector<S> blockDoubleMask(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi,
SuperVector<S> t1 = c1_lo | c1_hi;
t1.print8("t1");
SuperVector<S> c2_lo = mask2_lo.template pshufb<true>(chars_lo);
SuperVector<S> chars_lo2 = offset_char & low4bits;
chars_lo.print8("chars_lo2");
SuperVector<S> chars_hi2 = offset_char.template vshr_64_imm<4>() & low4bits;
chars_hi.print8("chars_hi2");
SuperVector<S> c2_lo = mask2_lo.template pshufb<true>(chars_lo2);
c2_lo.print8("c2_lo");
SuperVector<S> c2_hi = mask2_hi.template pshufb<true>(chars_hi);
SuperVector<S> c2_hi = mask2_hi.template pshufb<true>(chars_hi2);
c2_hi.print8("c2_hi");
SuperVector<S> t2 = c2_lo | c2_hi;
t2.print8("t2");
t2.template vshr_128_imm<1>().print8("t2.vshr_128(1)");
SuperVector<S> t = t1 | (t2.template vshr_128_imm<1>());
SuperVector<S> t = t1 | t2;
t.print8("t");
return t.eq(SuperVector<S>::Ones());

View File

@ -50,7 +50,7 @@ static really_inline
const SuperVector<S> blockSingleMask(SuperVector<S> mask_lo, SuperVector<S> mask_hi, SuperVector<S> chars);
template <uint16_t S>
static really_inline
SuperVector<S> blockDoubleMask(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi, SuperVector<S> mask2_lo, SuperVector<S> mask2_hi, SuperVector<S> chars);
SuperVector<S> blockDoubleMask(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi, SuperVector<S> mask2_lo, SuperVector<S> mask2_hi, SuperVector<S> chars, SuperVector<S> offset_chars);
#if defined(VS_SIMDE_BACKEND)
#include "x86/shufti.hpp"
@ -82,9 +82,9 @@ const u8 *revBlock(SuperVector<S> mask_lo, SuperVector<S> mask_hi, SuperVector<S
template <uint16_t S>
static really_inline
const u8 *fwdBlockDouble(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi, SuperVector<S> mask2_lo, SuperVector<S> mask2_hi, SuperVector<S> chars, const u8 *buf) {
const u8 *fwdBlockDouble(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi, SuperVector<S> mask2_lo, SuperVector<S> mask2_hi, SuperVector<S> chars, SuperVector<S> offset_chars, const u8 *buf) {
SuperVector<S> mask = blockDoubleMask(mask1_lo, mask1_hi, mask2_lo, mask2_hi, chars);
SuperVector<S> mask = blockDoubleMask(mask1_lo, mask1_hi, mask2_lo, mask2_hi, chars, offset_chars);
return first_zero_match_inverted<S>(buf, mask);
}
@ -204,6 +204,8 @@ const u8 *shuftiDoubleExecReal(m128 mask1_lo, m128 mask1_hi, m128 mask2_lo, m128
DEBUG_PRINTF("shufti %p len %zu\n", buf, buf_end - buf);
DEBUG_PRINTF("b %s\n", buf);
const u8 *buf_one_off_end = buf_end - 1;
const SuperVector<S> wide_mask1_lo(mask1_lo);
const SuperVector<S> wide_mask1_hi(mask1_hi);
const SuperVector<S> wide_mask2_lo(mask2_lo);
@ -217,24 +219,26 @@ const u8 *shuftiDoubleExecReal(m128 mask1_lo, m128 mask1_hi, m128 mask2_lo, m128
__builtin_prefetch(d + 3*64);
__builtin_prefetch(d + 4*64);
DEBUG_PRINTF("start %p end %p \n", d, buf_end);
assert(d < buf_end);
if (d + S <= buf_end) {
assert(d < buf_one_off_end);
if (d + S <= buf_one_off_end) {
// peel off first part to cacheline boundary
DEBUG_PRINTF("until aligned %p \n", ROUNDUP_PTR(d, S));
if (!ISALIGNED_N(d, S)) {
SuperVector<S> chars = SuperVector<S>::loadu(d);
rv = fwdBlockDouble(wide_mask1_lo, wide_mask1_hi, wide_mask2_lo, wide_mask2_hi, chars, d);
SuperVector<S> offset_char = SuperVector<S>::loadu(d + 1);
rv = fwdBlockDouble(wide_mask1_lo, wide_mask1_hi, wide_mask2_lo, wide_mask2_hi, chars, offset_char, d);
DEBUG_PRINTF("rv %p \n", rv);
if (rv) return rv;
d = ROUNDUP_PTR(d, S);
}
while(d + S <= buf_end) {
while(d + S <= buf_one_off_end) {
__builtin_prefetch(d + 64);
DEBUG_PRINTF("d %p \n", d);
SuperVector<S> chars = SuperVector<S>::load(d);
rv = fwdBlockDouble(wide_mask1_lo, wide_mask1_hi, wide_mask2_lo, wide_mask2_hi, chars, d);
SuperVector<S> offset_char = SuperVector<S>::loadu(d + 1);
rv = fwdBlockDouble(wide_mask1_lo, wide_mask1_hi, wide_mask2_lo, wide_mask2_hi, chars, offset_char, d);
if (rv) return rv;
d += S;
}
@ -243,17 +247,19 @@ const u8 *shuftiDoubleExecReal(m128 mask1_lo, m128 mask1_hi, m128 mask2_lo, m128
DEBUG_PRINTF("tail d %p e %p \n", d, buf_end);
// finish off tail
if (d != buf_end) {
if (d < buf_one_off_end) {
SuperVector<S> chars = SuperVector<S>::Zeroes();
SuperVector<S> offset_char = SuperVector<S>::Zeroes();
const u8 *end_buf;
if (buf_end - buf < S) {
memcpy(&chars.u, buf, buf_end - buf);
end_buf = buf;
} else {
chars = SuperVector<S>::loadu(buf_end - S);
chars = SuperVector<S>::loadu(buf_one_off_end - S);
offset_char = SuperVector<S>::loadu(buf_end - S);
end_buf = buf_end - S;
}
rv = fwdBlockDouble(wide_mask1_lo, wide_mask1_hi, wide_mask2_lo, wide_mask2_hi, chars, end_buf);
rv = fwdBlockDouble(wide_mask1_lo, wide_mask1_hi, wide_mask2_lo, wide_mask2_hi, chars, offset_char, end_buf);
DEBUG_PRINTF("rv %p \n", rv);
if (rv && rv < buf_end) return rv;
}

View File

@ -155,31 +155,36 @@ svbool_t doubleMatched(svuint8_t mask1_lo, svuint8_t mask1_hi,
svuint8_t mask2_lo, svuint8_t mask2_hi,
const u8 *buf, const svbool_t pg) {
svuint8_t vec = svld1_u8(pg, buf);
svuint8_t vec2 = svld1_u8(pg, buf + 1);
svuint8_t chars_lo = svand_x(svptrue_b8(), vec, (uint8_t)0xf);
svuint8_t chars_hi = svlsr_x(svptrue_b8(), vec, 4);
svuint8_t chars_lo2 = svand_x(svptrue_b8(), vec2, (uint8_t)0xf);
svuint8_t chars_hi2 = svlsr_x(svptrue_b8(), vec2, 4);
svuint8_t c1_lo = svtbl(mask1_lo, chars_lo);
svuint8_t c1_hi = svtbl(mask1_hi, chars_hi);
svuint8_t t1 = svorr_x(svptrue_b8(), c1_lo, c1_hi);
svuint8_t c2_lo = svtbl(mask2_lo, chars_lo);
svuint8_t c2_hi = svtbl(mask2_hi, chars_hi);
svuint8_t t2 = svext(svorr_z(pg, c2_lo, c2_hi), svdup_u8(0), 1);
svuint8_t c2_lo = svtbl(mask2_lo, chars_lo2);
svuint8_t c2_hi = svtbl(mask2_hi, chars_hi2);
svuint8_t t2 = svorr_x(svptrue_b8(), c2_lo, c2_hi);
svuint8_t t = svorr_x(svptrue_b8(), t1, t2);
return svnot_z(svptrue_b8(), svcmpeq(svptrue_b8(), t, (uint8_t)0xff));
return svnot_z(pg, svcmpeq(svptrue_b8(), t, (uint8_t)0xff));
}
static really_inline
const u8 *dshuftiOnce(svuint8_t mask1_lo, svuint8_t mask1_hi,
svuint8_t mask2_lo, svuint8_t mask2_hi,
const u8 *buf, const u8 *buf_end) {
const u8 *buf_one_off_end = buf_end - 1;
DEBUG_PRINTF("start %p end %p\n", buf, buf_end);
assert(buf < buf_end);
assert(buf < buf_one_off_end);
DEBUG_PRINTF("l = %td\n", buf_end - buf);
svbool_t pg = svwhilelt_b8_s64(0, buf_end - buf);
svbool_t pg = svwhilelt_b8_s64(0, buf_one_off_end - buf);
svbool_t matched = doubleMatched(mask1_lo, mask1_hi, mask2_lo, mask2_hi,
buf, pg);
return accelSearchCheckMatched(buf, matched);
@ -199,9 +204,11 @@ static really_inline
const u8 *dshuftiSearch(svuint8_t mask1_lo, svuint8_t mask1_hi,
svuint8_t mask2_lo, svuint8_t mask2_hi,
const u8 *buf, const u8 *buf_end) {
assert(buf < buf_end);
const u8 *buf_one_off_end = buf_end - 1;
assert(buf < buf_one_off_end);
size_t len = buf_end - buf;
if (len <= svcntb()) {
if (len <= svcntb() + 1) {
return dshuftiOnce(mask1_lo, mask1_hi,
mask2_lo, mask2_hi, buf, buf_end);
}
@ -214,7 +221,7 @@ const u8 *dshuftiSearch(svuint8_t mask1_lo, svuint8_t mask1_hi,
if (ptr) return ptr;
}
buf = aligned_buf;
size_t loops = (buf_end - buf) / svcntb();
size_t loops = (buf_one_off_end - buf) / svcntb();
DEBUG_PRINTF("loops %zu \n", loops);
for (size_t i = 0; i < loops; i++, buf += svcntb()) {
const u8 *ptr = dshuftiLoopBody(mask1_lo, mask1_hi,
@ -222,9 +229,9 @@ const u8 *dshuftiSearch(svuint8_t mask1_lo, svuint8_t mask1_hi,
if (ptr) return ptr;
}
DEBUG_PRINTF("buf %p buf_end %p \n", buf, buf_end);
return buf == buf_end ? NULL : dshuftiLoopBody(mask1_lo, mask1_hi,
return buf == buf_one_off_end ? NULL : dshuftiLoopBody(mask1_lo, mask1_hi,
mask2_lo, mask2_hi,
buf_end - svcntb());
buf_one_off_end - svcntb());
}
const u8 *shuftiDoubleExec(m128 mask1_lo, m128 mask1_hi,

View File

@ -46,7 +46,7 @@ const SuperVector<S> blockSingleMask(SuperVector<S> mask_lo, SuperVector<S> mask
template <uint16_t S>
static really_inline
SuperVector<S> blockDoubleMask(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi, SuperVector<S> mask2_lo, SuperVector<S> mask2_hi, SuperVector<S> chars) {
SuperVector<S> blockDoubleMask(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi, SuperVector<S> mask2_lo, SuperVector<S> mask2_hi, SuperVector<S> chars, SuperVector<S> offset_char) {
const SuperVector<S> low4bits = SuperVector<S>::dup_u8(0xf);
SuperVector<S> chars_lo = chars & low4bits;
@ -60,14 +60,18 @@ SuperVector<S> blockDoubleMask(SuperVector<S> mask1_lo, SuperVector<S> mask1_hi,
SuperVector<S> c1 = c1_lo | c1_hi;
c1.print8("c1");
SuperVector<S> c2_lo = mask2_lo.pshufb(chars_lo);
SuperVector<S> chars_lo2 = offset_char & low4bits;
chars_lo.print8("chars_lo2");
SuperVector<S> chars_hi2 = offset_char.template vshr_64_imm<4>() & low4bits;
chars_hi.print8("chars_hi2");
SuperVector<S> c2_lo = mask2_lo.pshufb(chars_lo2);
c2_lo.print8("c2_lo");
SuperVector<S> c2_hi = mask2_hi.pshufb(chars_hi);
SuperVector<S> c2_hi = mask2_hi.pshufb(chars_hi2);
c2_hi.print8("c2_hi");
SuperVector<S> c2 = c2_lo | c2_hi;
c2.print8("c2");
c2.template vshr_128_imm<1>().print8("c2.vshr_128(1)");
SuperVector<S> c = c1 | (c2.template vshr_128_imm<1>());
SuperVector<S> c = c1 | c2;
c.print8("c");
return c.eq(SuperVector<S>::Ones());