// Review notes for the patch below. The hunk text itself is left untouched;
// these lines sit ahead of the first `diff --git` header.
//
// Files touched: src/util/supervector/arch/x86/impl.cpp and
// src/util/supervector/supervector.hpp.
//
// impl.cpp
//   * SuperVector<16>::loadu_maskz(ptr, len) now returns `v & mask` instead
//     of `mask & v`.
//   * SuperVector<32>::loadu_maskz(ptr, len) no longer has a HAVE_AVX512
//     branch: it builds the mask with Ones_vshr(32 - len), does a full
//     _mm256_loadu_si256 and returns `v & mask`.
//   * SuperVector<64>::loadu_maskz(ptr, len) switches from
//     _mm512_mask_loadu_epi8(Zeroes().u.v512[0], mask, ptr) to
//     _mm512_maskz_loadu_epi8(mask, ptr).
//   * New overloads loadu_maskz(ptr, base_type::comparemask_type mask) for
//     widths 16, 32 and 64. Under HAVE_AVX512 (and always for width 64) they
//     call the matching _mm*_maskz_loadu_epi8 intrinsic.
//   * New findLSB specialisations: width 16 forwards to
//     findAndClearLSB_32(&z); widths 32 and 64 forward to
//     findAndClearLSB_64(&z).
//
// supervector.hpp
//   * Removes Z_TYPE, Z_BITS, Z_SHIFT, DOUBLE_LOAD_MASK and SINGLE_LOAD_MASK;
//     Z_POSSHIFT is kept.
//   * Corrects the comment above BaseVector<32> to "256 bit implementation".
//   * Changes BaseVector<16>::comparemask_type from u64a to u32.
//   * Adds static single_load_mask(), double_load_mask() and findLSB() plus
//     the declaration of the mask-taking loadu_maskz overload.
//
// Open points
//   * Without HAVE_AVX512 the mask overloads for widths 16 and 32 discard
//     `mask` via (void)mask and return the unmasked load (marked FIXME in the
//     patch). NOTE(review): presumably callers expect the unselected bytes to
//     be zero -- confirm before relying on these overloads on such builds.
//   * single_load_mask() and double_load_mask() both compute
//     ((1ULL << len) - 1ULL). A shift count of 64 is undefined in C++; the
//     removed 512-bit and ARM DOUBLE_LOAD_MASK variants used
//     (~0ULL) >> (Z_BITS - (l)) instead. NOTE(review): confirm len can never
//     reach 64, or keep a per-width formula.
//   * The removed ARM32/AARCH64 branch used a 64-bit Z_TYPE while
//     BaseVector<16>::comparemask_type is now u32. NOTE(review): check the
//     non-x86 back ends (not visible in this patch) against the narrower type.
//   * The width-32 mask overload prints `mask` with %08llx; the definition of
//     BaseVector<32>::comparemask_type is not visible here. TODO confirm the
//     format matches that type.
//   * The `#include` added to supervector.hpp has no header name in this copy
//     of the patch -- presumably lost in extraction; restore it from the
//     original commit.
//   * The new loadu_maskz declaration names its mask parameter `len`; the
//     definitions call it `mask`.
//   * NOTE(review): line breaks inside the hunks appear to have been collapsed
//     in this copy; regenerate the patch from the commit before applying it.
diff --git a/src/util/supervector/arch/x86/impl.cpp b/src/util/supervector/arch/x86/impl.cpp index b8a75c95..77ffc038 100644 --- a/src/util/supervector/arch/x86/impl.cpp +++ b/src/util/supervector/arch/x86/impl.cpp @@ -524,7 +524,28 @@ really_inline SuperVector<16> SuperVector<16>::loadu_maskz(void const *ptr, uint { SuperVector mask = Ones_vshr(16 -len); SuperVector v = _mm_loadu_si128((const m128 *)ptr); - return mask & v; + return v & mask; +} + +template <> +really_inline SuperVector<16> SuperVector<16>::loadu_maskz(void const *ptr, typename base_type::comparemask_type const mask) +{ +#ifdef HAVE_AVX512 + SuperVector<16> v = _mm_maskz_loadu_epi8(mask, (const m128 *)ptr); + v.print8("v"); + return v; +#else + DEBUG_PRINTF("mask = %08x\n", mask); + SuperVector v = _mm_loadu_si128((const m128 *)ptr); + (void)mask; + return v; // FIXME: & mask +#endif +} + +template<> +really_inline typename SuperVector<16>::comparemask_type SuperVector<16>::findLSB(typename SuperVector<16>::comparemask_type &z) +{ + return findAndClearLSB_32(&z); } template<> @@ -1126,22 +1147,35 @@ really_inline SuperVector<32> SuperVector<32>::load(void const *ptr) template <> really_inline SuperVector<32> SuperVector<32>::loadu_maskz(void const *ptr, uint8_t const len) { + SuperVector mask = Ones_vshr(32 -len); + mask.print8("mask"); + SuperVector<32> v = _mm256_loadu_si256((const m256 *)ptr); + v.print8("v"); + return v & mask; +} + +template <> +really_inline SuperVector<32> SuperVector<32>::loadu_maskz(void const *ptr, typename base_type::comparemask_type const mask) +{ + DEBUG_PRINTF("mask = %08llx\n", mask); #ifdef HAVE_AVX512 - u32 mask = (~0ULL) >> (32 - len); - SuperVector<32> v = _mm256_mask_loadu_epi8(Zeroes().u.v256[0], mask, (const m256 *)ptr); + SuperVector<32> v = _mm256_maskz_loadu_epi8(mask, (const m256 *)ptr); v.print8("v"); return v; #else - DEBUG_PRINTF("len = %d", len); - SuperVector<32> mask = Ones_vshr(32 -len); - mask.print8("mask"); - (Ones() >> (32 - 
len)).print8("mask"); SuperVector<32> v = _mm256_loadu_si256((const m256 *)ptr); v.print8("v"); - return mask & v; + (void)mask; + return v; // FIXME: & mask #endif } +template<> +really_inline typename SuperVector<32>::comparemask_type SuperVector<32>::findLSB(typename SuperVector<32>::comparemask_type &z) +{ + return findAndClearLSB_64(&z); +} + template<> really_inline SuperVector<32> SuperVector<32>::alignr(SuperVector<32> &other, int8_t offset) { @@ -1778,11 +1812,26 @@ really_inline SuperVector<64> SuperVector<64>::loadu_maskz(void const *ptr, uint { u64a mask = (~0ULL) >> (64 - len); DEBUG_PRINTF("mask = %016llx\n", mask); - SuperVector<64> v = _mm512_mask_loadu_epi8(Zeroes().u.v512[0], mask, (const m512 *)ptr); + SuperVector<64> v = _mm512_maskz_loadu_epi8(mask, (const m512 *)ptr); v.print8("v"); return v; } +template <> +really_inline SuperVector<64> SuperVector<64>::loadu_maskz(void const *ptr, typename base_type::comparemask_type const mask) +{ + DEBUG_PRINTF("mask = %016llx\n", mask); + SuperVector<64> v = _mm512_maskz_loadu_epi8(mask, (const m512 *)ptr); + v.print8("v"); + return v; +} + +template<> +really_inline typename SuperVector<64>::comparemask_type SuperVector<64>::findLSB(typename SuperVector<64>::comparemask_type &z) +{ + return findAndClearLSB_64(&z); +} + template<> template<> really_inline SuperVector<64> SuperVector<64>::pshufb(SuperVector<64> b) diff --git a/src/util/supervector/supervector.hpp b/src/util/supervector/supervector.hpp index 253907fa..1d72ee81 100644 --- a/src/util/supervector/supervector.hpp +++ b/src/util/supervector/supervector.hpp @@ -46,34 +46,18 @@ #endif #endif // VS_SIMDE_BACKEND +#include + #if defined(HAVE_SIMD_512_BITS) -using Z_TYPE = u64a; -#define Z_BITS 64 -#define Z_SHIFT 63 #define Z_POSSHIFT 0 -#define DOUBLE_LOAD_MASK(l) ((~0ULL) >> (Z_BITS -(l))) -#define SINGLE_LOAD_MASK(l) (((1ULL) << (l)) - 1ULL) #elif defined(HAVE_SIMD_256_BITS) -using Z_TYPE = u32; -#define Z_BITS 32 -#define Z_SHIFT 31 #define 
Z_POSSHIFT 0 -#define DOUBLE_LOAD_MASK(l) (((1ULL) << (l)) - 1ULL) -#define SINGLE_LOAD_MASK(l) (((1ULL) << (l)) - 1ULL) #elif defined(HAVE_SIMD_128_BITS) #if !defined(VS_SIMDE_BACKEND) && (defined(ARCH_ARM32) || defined(ARCH_AARCH64)) -using Z_TYPE = u64a; -#define Z_BITS 64 #define Z_POSSHIFT 2 -#define DOUBLE_LOAD_MASK(l) ((~0ULL) >> (Z_BITS - (l))) #else -using Z_TYPE = u32; -#define Z_BITS 32 #define Z_POSSHIFT 0 -#define DOUBLE_LOAD_MASK(l) (((1ULL) << (l)) - 1ULL) #endif -#define Z_SHIFT 15 -#define SINGLE_LOAD_MASK(l) (((1ULL) << (l)) - 1ULL) #endif // Define a common assume_aligned using an appropriate compiler built-in, if @@ -138,7 +122,7 @@ struct BaseVector<64> static constexpr u16 previous_size = 32; }; -// 128 bit implementation +// 256 bit implementation template <> struct BaseVector<32> { @@ -158,7 +142,7 @@ struct BaseVector<16> static constexpr bool is_valid = true; static constexpr u16 size = 16; using type = m128; - using comparemask_type = u64a; + using comparemask_type = u32; static constexpr bool has_previous = false; using previous_type = u64a; static constexpr u16 previous_size = 8; @@ -257,9 +241,13 @@ public: static typename base_type::comparemask_type iteration_mask(typename base_type::comparemask_type mask); + static typename base_type::comparemask_type single_load_mask(uint8_t const len) { return (((1ULL) << (len)) - 1ULL); } + static typename base_type::comparemask_type double_load_mask(uint8_t const len) { return (((1ULL) << (len)) - 1ULL); } + static typename base_type::comparemask_type findLSB(typename base_type::comparemask_type &z); static SuperVector loadu(void const *ptr); static SuperVector load(void const *ptr); static SuperVector loadu_maskz(void const *ptr, uint8_t const len); + static SuperVector loadu_maskz(void const *ptr, typename base_type::comparemask_type const len); SuperVector alignr(SuperVector &other, int8_t offset); template