diff --git a/src/rose/counting_miracle.h b/src/rose/counting_miracle.h index 668de996..d61cc12c 100644 --- a/src/rose/counting_miracle.h +++ b/src/rose/counting_miracle.h @@ -1,5 +1,6 @@ /* * Copyright (c) 2015-2017, Intel Corporation + * Copyright (c) 2021, Arm Limited * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: @@ -41,15 +42,14 @@ #ifdef HAVE_SVE2 - static really_inline -size_t countMatches(const svuint8_t chars, const svbool_t pg, const u8 *buf) { +size_t countMatches(svuint8_t chars, svbool_t pg, const u8 *buf) { svuint8_t vec = svld1_u8(pg, buf); return svcntp_b8(svptrue_b8(), svmatch(pg, vec, chars)); } static really_inline -bool countLoopBody(const svuint8_t chars, const svbool_t pg, const u8 *d, +bool countLoopBody(svuint8_t chars, svbool_t pg, const u8 *d, u32 target_count, u32 *count_inout, const u8 **d_out) { *count_inout += countMatches(chars, pg, d); if (*count_inout >= target_count) { @@ -60,7 +60,7 @@ bool countLoopBody(const svuint8_t chars, const svbool_t pg, const u8 *d, } static really_inline -bool countOnce(const svuint8_t chars, const u8 *d, const u8 *d_end, +bool countOnce(svuint8_t chars, const u8 *d, const u8 *d_end, u32 target_count, u32 *count_inout, const u8 **d_out) { assert(d <= d_end); svbool_t pg = svwhilelt_b8_s64(0, d_end - d); @@ -145,6 +145,74 @@ char roseCountingMiracleScan(u8 c, const u8 *d, const u8 *d_end, #endif +#ifdef HAVE_SVE + +static really_inline +size_t countShuftiMatches(svuint8_t mask_lo, svuint8_t mask_hi, + const svbool_t pg, const u8 *buf) { + svuint8_t vec = svld1_u8(pg, buf); + svuint8_t c_lo = svtbl(mask_lo, svand_z(svptrue_b8(), vec, (uint8_t)0xf)); + svuint8_t c_hi = svtbl(mask_hi, svlsr_z(svptrue_b8(), vec, 4)); + svuint8_t t = svand_z(svptrue_b8(), c_lo, c_hi); + return svcntp_b8(svptrue_b8(), svcmpne(pg, t, (uint8_t)0)); +} + +static really_inline +bool countShuftiLoopBody(svuint8_t mask_lo, svuint8_t mask_hi, + const svbool_t pg, const u8 *d, u32 target_count, + u32 *count_inout, const u8 **d_out) { + *count_inout += countShuftiMatches(mask_lo, mask_hi, pg, d); + if (*count_inout >= target_count) { + *d_out = d; + return true; + } + return false; +} + +static really_inline +bool countShuftiOnce(svuint8_t mask_lo, svuint8_t mask_hi, + const u8 *d, const u8 *d_end, u32 target_count, + u32 *count_inout, const u8 **d_out) { + svbool_t pg = svwhilelt_b8_s64(0, d_end - d); + return countShuftiLoopBody(mask_lo, mask_hi, pg, d, target_count, + count_inout, d_out); +} + +static really_inline +bool roseCountingMiracleScanShufti(svuint8_t mask_lo, svuint8_t mask_hi, + UNUSED u8 poison, const u8 *d, + const u8 *d_end, u32 target_count, + u32 *count_inout, const u8 **d_out) { + assert(d <= d_end); + size_t len = d_end - d; + if (len <= svcntb()) { + char rv = countShuftiOnce(mask_lo, mask_hi, d, d_end, target_count, + count_inout, d_out); + return rv; + } + // peel off first part to align to the vector size + const u8 *aligned_d_end = ROUNDDOWN_PTR(d_end, svcntb_pat(SV_POW2)); + assert(d < aligned_d_end); + if (d_end != aligned_d_end) { + if (countShuftiOnce(mask_lo, mask_hi, aligned_d_end, d_end, + target_count, count_inout, d_out)) return true; + d_end = aligned_d_end; + } + size_t loops = (d_end - d) / svcntb(); + for (size_t i = 0; i < loops; i++) { + d_end -= svcntb(); + if (countShuftiLoopBody(mask_lo, mask_hi, svptrue_b8(), d_end, + target_count, count_inout, d_out)) return true; + } + if (d != d_end) { + if (countShuftiOnce(mask_lo, mask_hi, d, d_end, + target_count, count_inout, d_out)) return true; + } + return false; +} + +#else + #define GET_LO_4(chars) and128(chars, low4bits) #define GET_HI_4(chars) rshift64_m128(andnot128(low4bits, chars), 4) @@ -198,6 +266,8 @@ u32 roseCountingMiracleScanShufti(m128 mask_lo, m128 mask_hi, u8 poison, return 0; } +#endif + /** * \brief "Counting Miracle" scan: If we see more than N instances of a * particular character class we know that the engine must be dead. @@ -277,8 +347,13 @@ int roseCountingMiracleOccurs(const struct RoseEngine *t, } } } else { +#ifdef HAVE_SVE + svuint8_t lo = getSVEMaskFrom128(cm->lo); + svuint8_t hi = getSVEMaskFrom128(cm->hi); +#else m128 lo = cm->lo; m128 hi = cm->hi; +#endif u8 poison = cm->poison; // Scan buffer.