diff --git a/src/util/arch/arm/simd_utils.h b/src/util/arch/arm/simd_utils.h index 68c29c67..8d8c4456 100644 --- a/src/util/arch/arm/simd_utils.h +++ b/src/util/arch/arm/simd_utils.h @@ -380,19 +380,15 @@ static really_inline m128 eq64_m128(m128 a, m128 b) { } static really_inline u32 movemask128(m128 a) { - static const uint8x16_t powers = {1, 2, 4, 8, 16, 32, 64, 128, - 1, 2, 4, 8, 16, 32, 64, 128}; - - // Compute the mask from the input - uint8x16_t mask = (uint8x16_t)vpaddlq_u32( - vpaddlq_u16(vpaddlq_u8(vandq_u8((uint8x16_t)a, powers)))); - uint8x16_t mask1 = vextq_u8(mask, (uint8x16_t)zeroes128(), 7); - mask = vorrq_u8(mask, mask1); - - // Get the resulting bytes - uint16_t output; - vst1q_lane_u16((uint16_t *)&output, (uint16x8_t)mask, 0); - return output; + uint8x16_t input = vreinterpretq_u8_s32(a); + uint16x8_t high_bits = vreinterpretq_u16_u8(vshrq_n_u8(input, 7)); + uint32x4_t paired16 = + vreinterpretq_u32_u16(vsraq_n_u16(high_bits, high_bits, 7)); + uint64x2_t paired32 = + vreinterpretq_u64_u32(vsraq_n_u32(paired16, paired16, 14)); + uint8x16_t paired64 = + vreinterpretq_u8_u64(vsraq_n_u64(paired32, paired32, 28)); + return vgetq_lane_u8(paired64, 0) | ((int) vgetq_lane_u8(paired64, 8) << 8); } static really_inline m128 set1_16x8(u8 c) { diff --git a/src/util/supervector/supervector.hpp b/src/util/supervector/supervector.hpp index 51310db2..5d066c1a 100644 --- a/src/util/supervector/supervector.hpp +++ b/src/util/supervector/supervector.hpp @@ -104,8 +104,7 @@ struct BaseVector static constexpr bool is_valid = false; static constexpr u16 size = 8; using type = void; - using comparemask_type = void; - using cmpmask_type = void; + using comparemask_type = void; static constexpr bool has_previous = false; using previous_type = void; static constexpr u16 previous_size = 4; @@ -117,7 +116,7 @@ struct BaseVector<128> static constexpr bool is_valid = true; static constexpr u16 size = 128; using type = void; - using comparemask_type = u64a; + using 
comparemask_type = u64a; static constexpr bool has_previous = true; using previous_type = m512; static constexpr u16 previous_size = 64; @@ -129,7 +128,7 @@ struct BaseVector<64> static constexpr bool is_valid = true; static constexpr u16 size = 64; using type = m512; - using comparemask_type = u64a; + using comparemask_type = u64a; static constexpr bool has_previous = true; using previous_type = m256; static constexpr u16 previous_size = 32; @@ -142,7 +141,7 @@ struct BaseVector<32> static constexpr bool is_valid = true; static constexpr u16 size = 32; using type = m256; - using comparemask_type = u64a; + using comparemask_type = u64a; static constexpr bool has_previous = true; using previous_type = m128; static constexpr u16 previous_size = 16; @@ -155,7 +154,7 @@ struct BaseVector<16> static constexpr bool is_valid = true; static constexpr u16 size = 16; using type = m128; - using comparemask_type = u64a; + using comparemask_type = u64a; static constexpr bool has_previous = false; using previous_type = u64a; static constexpr u16 previous_size = 8;