diff --git a/src/hwlm/noodle_engine_sve.hpp b/src/hwlm/noodle_engine_sve.hpp index 24e0ae2a..fcf2f5b4 100644 --- a/src/hwlm/noodle_engine_sve.hpp +++ b/src/hwlm/noodle_engine_sve.hpp @@ -38,12 +38,8 @@ hwlm_error_t checkMatched(const struct noodTable *n, const u8 *buf, size_t len, size_t matchPos = basePos + svcntp_b8(svptrue_b8(), brk); DEBUG_PRINTF("match pos %zu\n", matchPos); assert(matchPos < len); - size_t end_of_match_pos = matchPos - cbi->offsetAdj + n->key_offset - 1; - // doubleMatched can add a fake \0 at the end of the buffer. This check get rid of any match that might include it - if(end_of_match_pos < len) { - hwlmcb_rv_t rv = final(n, buf, len, needsConfirm, cbi, matchPos); - RETURN_IF_TERMINATED(rv); - } + hwlmcb_rv_t rv = final(n, buf, len, needsConfirm, cbi, matchPos); + RETURN_IF_TERMINATED(rv); next_match = svpnext_b8(matched, next_match); } while (unlikely(svptest_any(svptrue_b8(), next_match))); return HWLM_SUCCESS; @@ -152,18 +148,14 @@ hwlm_error_t doubleCheckMatched(const struct noodTable *n, const u8 *buf, } static really_inline -svbool_t doubleMatched(svuint16_t chars, const u8 *d, - svbool_t pg, svbool_t pg_rot, +svbool_t doubleMatchedLoop(svuint16_t chars, const u8 *d, svbool_t * const matched, svbool_t * const matched_rot) { - svuint16_t vec = svreinterpret_u16(svld1_u8(pg, d)); + svuint16_t vec = svreinterpret_u16(svld1_u8(svptrue_b8(), d)); // d - 1 won't underflow as the first position in buf has been dealt // with meaning that d > buf - svuint16_t vec_rot = svreinterpret_u16(svld1_u8(pg_rot, d - 1)); - // we reuse u8 predicates for u16 lanes. This means that we may actually check against one - // undefined extra character at the end of the buffer (usually \0). We check it later to - // reject this spurious match - *matched = svmatch(pg, vec, chars); - *matched_rot = svmatch(pg_rot, vec_rot, chars); + svuint16_t vec_rot = svreinterpret_u16(svld1_u8(svptrue_b8(), d - 1)); + *matched = svmatch(svptrue_b8(), vec, chars); + *matched_rot = svmatch(svptrue_b8(), vec_rot, chars); return svorr_z(svptrue_b8(), *matched, *matched_rot); } @@ -174,10 +166,34 @@ hwlm_error_t scanDoubleOnce(const struct noodTable *n, const u8 *buf, DEBUG_PRINTF("start %p end %p\n", d, e); assert(d < e); assert(d > buf); - svbool_t pg = svwhilelt_b8_s64(0, e - d); - svbool_t pg_rot = svwhilelt_b8_s64(0, e - d + 1); - svbool_t matched, matched_rot; - svbool_t any = doubleMatched(svreinterpret_u16(chars), d, pg, pg_rot, &matched, &matched_rot); + const ptrdiff_t size = e - d; + svbool_t pg = svwhilelt_b8_s64(0, size); + svbool_t pg_rot = svwhilelt_b8_s64(0, size + 1); + + svuint16_t vec = svreinterpret_u16(svld1_u8(pg, d)); + // d - 1 won't underflow as the first position in buf has been dealt + // with meaning that d > buf + svuint16_t vec_rot = svreinterpret_u16(svld1_u8(pg_rot, d - 1)); + + // we reuse u8 predicates for u16 lanes. This means that we will check against one + // extra \0 character at the end of the vector. + if(unlikely(n->key1 == '\0')) { + if (size % 2) { + // if odd, vec has an odd number of lanes and has the spurious \0 + svbool_t lane_to_disable = svrev_b8(svpfirst(svrev_b8(pg), svpfalse())); + pg = sveor_z(svptrue_b8(), pg, lane_to_disable); + } else { + // if even, vec_rot has an odd number of lanes and has the spurious \0 + // we need to disable the last active lane as well, but we know pg is + // the same as pg_rot without the last lane + pg_rot = pg; + } + } + + svbool_t matched = svmatch(pg, vec, svreinterpret_u16(chars)); + svbool_t matched_rot = svmatch(pg_rot, vec_rot, svreinterpret_u16(chars)); + svbool_t any = svorr_z(svptrue_b8(), matched, matched_rot); + return doubleCheckMatched(n, buf, len, cbi, d, matched, matched_rot, any); } @@ -194,8 +210,7 @@ hwlm_error_t scanDoubleLoop(const struct noodTable *n, const u8 *buf, for (size_t i = 0; i < loops; i++, d += svcntb()) { DEBUG_PRINTF("d %p \n", d); svbool_t matched, matched_rot; - svbool_t any = doubleMatched(svreinterpret_u16(chars), d, svptrue_b8(), svptrue_b8(), - &matched, &matched_rot); + svbool_t any = doubleMatchedLoop(svreinterpret_u16(chars), d, &matched, &matched_rot); hwlm_error_t rv = doubleCheckMatched(n, buf, len, cbi, d, matched, matched_rot, any); RETURN_IF_TERMINATED(rv);