From db3b0e9474cdd063b60c63dfcb340ebd41b8acca Mon Sep 17 00:00:00 2001 From: Konstantinos Margaritis Date: Mon, 18 Dec 2023 20:23:07 +0000 Subject: [PATCH] comparemask_type is u64a on Arm, use single load_mask --- src/hwlm/noodle_engine_simd.hpp | 17 +++++++++++++---- src/util/supervector/arch/arm/impl.cpp | 19 +++++++++++++++++-- src/util/supervector/supervector.hpp | 7 +++++-- 3 files changed, 35 insertions(+), 8 deletions(-) diff --git a/src/hwlm/noodle_engine_simd.hpp b/src/hwlm/noodle_engine_simd.hpp index 9af76768..23827873 100644 --- a/src/hwlm/noodle_engine_simd.hpp +++ b/src/hwlm/noodle_engine_simd.hpp @@ -86,15 +86,21 @@ hwlm_error_t scanSingleMain(const struct noodTable *n, const u8 *buf, DEBUG_PRINTF("d - d0: %ld \n", d - d0); #if defined(HAVE_MASKED_LOADS) uint8_t l = d - d0; - typename SuperVector::comparemask_type mask = ~SuperVector::single_load_mask(l); + typename SuperVector::comparemask_type mask = ~SuperVector::load_mask(l); SuperVector chars = SuperVector::loadu_maskz(d0, mask) & caseMask; typename SuperVector::comparemask_type z = mask1.eqmask(chars); DEBUG_PRINTF("mask: %08llx\n", mask); hwlm_error_t rv = single_zscan(n, d0, buf, z, len, cbi); #else uint8_t l = d0 + S - d; + DEBUG_PRINTF("l: %d \n", l); SuperVector chars = SuperVector::loadu_maskz(d, l) & caseMask; + chars.print8("chars"); typename SuperVector::comparemask_type z = mask1.eqmask(chars); + DEBUG_PRINTF("z: %08llx\n", (u64a) z); + z = SuperVector::iteration_mask(z); + DEBUG_PRINTF("z: %08llx\n", (u64a) z); + hwlm_error_t rv = single_zscan(n, d, buf, z, len, cbi); #endif chars.print32("chars"); @@ -125,6 +131,8 @@ hwlm_error_t scanSingleMain(const struct noodTable *n, const u8 *buf, uint8_t l = buf_end - d; SuperVector chars = SuperVector::loadu_maskz(d, l) & caseMask; typename SuperVector::comparemask_type z = mask1.eqmask(chars); + z = SuperVector::iteration_mask(z); + hwlm_error_t rv = single_zscan(n, d, buf, z, len, cbi); RETURN_IF_TERMINATED(rv); } @@ -160,12 +168,12 @@ hwlm_error_t scanDoubleMain(const struct noodTable *n, const u8 *buf, const u8 *d0 = ROUNDDOWN_PTR(d, S); #if defined(HAVE_MASKED_LOADS) uint8_t l = d - d0; - typename SuperVector::comparemask_type mask = ~SuperVector::double_load_mask(l); + typename SuperVector::comparemask_type mask = ~SuperVector::load_mask(l); SuperVector chars = SuperVector::loadu_maskz(d0, mask) & caseMask; typename SuperVector::comparemask_type z1 = mask1.eqmask(chars); typename SuperVector::comparemask_type z2 = mask2.eqmask(chars); typename SuperVector::comparemask_type z = (z1 << SuperVector::mask_width()) & z2; - DEBUG_PRINTF("z: %0llx\n", z); + z = SuperVector::iteration_mask(z); lastz1 = z1 >> (S - 1); DEBUG_PRINTF("mask: %08llx\n", mask); @@ -176,8 +184,9 @@ hwlm_error_t scanDoubleMain(const struct noodTable *n, const u8 *buf, chars.print8("chars"); typename SuperVector::comparemask_type z1 = mask1.eqmask(chars); typename SuperVector::comparemask_type z2 = mask2.eqmask(chars); - typename SuperVector::comparemask_type z = (z1 << SuperVector::mask_width()) & z2; + z = SuperVector::iteration_mask(z); + hwlm_error_t rv = double_zscan(n, d, buf, z, len, cbi); lastz1 = z1 >> (l - 1); #endif diff --git a/src/util/supervector/arch/arm/impl.cpp b/src/util/supervector/arch/arm/impl.cpp index 55f6c55c..bd866223 100644 --- a/src/util/supervector/arch/arm/impl.cpp +++ b/src/util/supervector/arch/arm/impl.cpp @@ -525,11 +525,26 @@ really_inline SuperVector<16> SuperVector<16>::load(void const *ptr) template <> really_inline SuperVector<16> SuperVector<16>::loadu_maskz(void const *ptr, uint8_t const len) { - SuperVector mask = Ones_vshr(16 -len); - SuperVector<16> v = loadu(ptr); + SuperVector mask = Ones_vshr(16 - len); + SuperVector v = loadu(ptr); return mask & v; } +template <> +really_inline SuperVector<16> SuperVector<16>::loadu_maskz(void const *ptr, typename base_type::comparemask_type const mask) +{ + DEBUG_PRINTF("mask = %08llx\n", mask); + SuperVector v = loadu(ptr); + (void)mask; + return v; // FIXME: & mask +} + +template<> +really_inline typename SuperVector<16>::comparemask_type SuperVector<16>::findLSB(typename SuperVector<16>::comparemask_type &z) +{ + return findAndClearLSB_64(&z) >> 2; +} + template<> really_inline SuperVector<16> SuperVector<16>::alignr(SuperVector<16> &other, int8_t offset) { diff --git a/src/util/supervector/supervector.hpp b/src/util/supervector/supervector.hpp index 3c4b1eea..6d2bc809 100644 --- a/src/util/supervector/supervector.hpp +++ b/src/util/supervector/supervector.hpp @@ -130,7 +130,11 @@ struct BaseVector<16> static constexpr bool is_valid = true; static constexpr u16 size = 16; using type = m128; +#if defined(ARCH_ARM32) || defined(ARCH_AARCH64) + using comparemask_type = u64a; +#else using comparemask_type = u32; +#endif static constexpr bool has_previous = false; using previous_type = u64a; static constexpr u16 previous_size = 8; @@ -229,8 +233,7 @@ public: static typename base_type::comparemask_type iteration_mask(typename base_type::comparemask_type mask); - static typename base_type::comparemask_type single_load_mask(uint8_t const len) { return (((1ULL) << (len)) - 1ULL); } - static typename base_type::comparemask_type double_load_mask(uint8_t const len) { return (((1ULL) << (len)) - 1ULL); } + static typename base_type::comparemask_type load_mask(uint8_t const len) { return (((1ULL) << (len)) - 1ULL); } static typename base_type::comparemask_type findLSB(typename base_type::comparemask_type &z); static SuperVector loadu(void const *ptr); static SuperVector load(void const *ptr);