From 67b414f2f9e543e894ea3204e6ce71721a0c251b Mon Sep 17 00:00:00 2001
From: Konstantinos Margaritis
Date: Mon, 12 Sep 2022 13:09:51 +0000
Subject: [PATCH] [NEON] simplify/optimize shift/align primitives

NEON's variable shift intrinsics (vshlq_*) take a per-lane signed
shift count, and a negative count shifts that lane right. The
immediate forms (vshlq_n_*/vshrq_n_*) require compile-time constants,
which is why the non-constant paths used to dispatch through large
switch tables or Unroller loops. Splatting the (negated, for right
shifts) count with vdupq_n_* and calling vshlq_* handles any run-time
shift amount directly, so all of that dispatch code can go. Logical
right shifts must use the unsigned forms (vshlq_u*); the signed forms
shift arithmetically and would sign-extend. vextq_u8 still requires a
constant immediate, so vshl_128/vshr_128 gain a __builtin_constant_p
fast path and alignr keeps an Unroller fallback for non-constant
offsets. Also fix the Zeroes() guard in vshl_8, which tested N == 16
instead of N == 8.
---
 src/util/arch/arm/simd_utils.h         | 220 +------------------------
 src/util/supervector/arch/arm/impl.cpp |  96 ++++-------
 2 files changed, 41 insertions(+), 275 deletions(-)

diff --git a/src/util/arch/arm/simd_utils.h b/src/util/arch/arm/simd_utils.h
index 45bcd23c..7f8539b0 100644
--- a/src/util/arch/arm/simd_utils.h
+++ b/src/util/arch/arm/simd_utils.h
@@ -112,43 +112,8 @@ m128 lshift_m128(m128 a, unsigned b) {
         return (m128) vshlq_n_u32((uint32x4_t)a, b);
     }
 #endif
-#define CASE_LSHIFT_m128(a, offset) case offset: return (m128)vshlq_n_u32((uint32x4_t)(a), (offset)); break;
-    switch (b) {
-    case 0: return a; break;
-    CASE_LSHIFT_m128(a, 1);
-    CASE_LSHIFT_m128(a, 2);
-    CASE_LSHIFT_m128(a, 3);
-    CASE_LSHIFT_m128(a, 4);
-    CASE_LSHIFT_m128(a, 5);
-    CASE_LSHIFT_m128(a, 6);
-    CASE_LSHIFT_m128(a, 7);
-    CASE_LSHIFT_m128(a, 8);
-    CASE_LSHIFT_m128(a, 9);
-    CASE_LSHIFT_m128(a, 10);
-    CASE_LSHIFT_m128(a, 11);
-    CASE_LSHIFT_m128(a, 12);
-    CASE_LSHIFT_m128(a, 13);
-    CASE_LSHIFT_m128(a, 14);
-    CASE_LSHIFT_m128(a, 15);
-    CASE_LSHIFT_m128(a, 16);
-    CASE_LSHIFT_m128(a, 17);
-    CASE_LSHIFT_m128(a, 18);
-    CASE_LSHIFT_m128(a, 19);
-    CASE_LSHIFT_m128(a, 20);
-    CASE_LSHIFT_m128(a, 21);
-    CASE_LSHIFT_m128(a, 22);
-    CASE_LSHIFT_m128(a, 23);
-    CASE_LSHIFT_m128(a, 24);
-    CASE_LSHIFT_m128(a, 25);
-    CASE_LSHIFT_m128(a, 26);
-    CASE_LSHIFT_m128(a, 27);
-    CASE_LSHIFT_m128(a, 28);
-    CASE_LSHIFT_m128(a, 29);
-    CASE_LSHIFT_m128(a, 30);
-    CASE_LSHIFT_m128(a, 31);
-    default: return zeroes128(); break;
-    }
-#undef CASE_LSHIFT_m128
+    int32x4_t shift_indices = vdupq_n_s32(b);
+    return (m128) vshlq_s32(a, shift_indices);
 }
 
 static really_really_inline
@@ -158,43 +123,8 @@ m128 rshift_m128(m128 a, unsigned b) {
         return (m128) vshrq_n_u32((uint32x4_t)a, b);
     }
 #endif
-#define CASE_RSHIFT_m128(a, offset) case offset: return (m128)vshrq_n_u32((uint32x4_t)(a), (offset)); break;
-    switch (b) {
-    case 0: return a; break;
-    CASE_RSHIFT_m128(a, 1);
-    CASE_RSHIFT_m128(a, 2);
-    CASE_RSHIFT_m128(a, 3);
-    CASE_RSHIFT_m128(a, 4);
-    CASE_RSHIFT_m128(a, 5);
-    CASE_RSHIFT_m128(a, 6);
-    CASE_RSHIFT_m128(a, 7);
-    CASE_RSHIFT_m128(a, 8);
-    CASE_RSHIFT_m128(a, 9);
-    CASE_RSHIFT_m128(a, 10);
-    CASE_RSHIFT_m128(a, 11);
-    CASE_RSHIFT_m128(a, 12);
-    CASE_RSHIFT_m128(a, 13);
-    CASE_RSHIFT_m128(a, 14);
-    CASE_RSHIFT_m128(a, 15);
-    CASE_RSHIFT_m128(a, 16);
-    CASE_RSHIFT_m128(a, 17);
-    CASE_RSHIFT_m128(a, 18);
-    CASE_RSHIFT_m128(a, 19);
-    CASE_RSHIFT_m128(a, 20);
-    CASE_RSHIFT_m128(a, 21);
-    CASE_RSHIFT_m128(a, 22);
-    CASE_RSHIFT_m128(a, 23);
-    CASE_RSHIFT_m128(a, 24);
-    CASE_RSHIFT_m128(a, 25);
-    CASE_RSHIFT_m128(a, 26);
-    CASE_RSHIFT_m128(a, 27);
-    CASE_RSHIFT_m128(a, 28);
-    CASE_RSHIFT_m128(a, 29);
-    CASE_RSHIFT_m128(a, 30);
-    CASE_RSHIFT_m128(a, 31);
-    default: return zeroes128(); break;
-    }
-#undef CASE_RSHIFT_m128
+    int32x4_t shift_indices = vdupq_n_s32(-b);
+    return (m128) vshlq_u32((uint32x4_t)a, shift_indices);
 }
 
 static really_really_inline
@@ -204,75 +134,8 @@ m128 lshift64_m128(m128 a, unsigned b) {
         return (m128) vshlq_n_u64((uint64x2_t)a, b);
     }
 #endif
-#define CASE_LSHIFT64_m128(a, offset) case offset: return (m128)vshlq_n_u64((uint64x2_t)(a), (offset)); break;
-    switch (b) {
-    case 0: return a; break;
-    CASE_LSHIFT64_m128(a, 1);
-    CASE_LSHIFT64_m128(a, 2);
-    CASE_LSHIFT64_m128(a, 3);
-    CASE_LSHIFT64_m128(a, 4);
-    CASE_LSHIFT64_m128(a, 5);
-    CASE_LSHIFT64_m128(a, 6);
-    CASE_LSHIFT64_m128(a, 7);
-    CASE_LSHIFT64_m128(a, 8);
-    CASE_LSHIFT64_m128(a, 9);
-    CASE_LSHIFT64_m128(a, 10);
-    CASE_LSHIFT64_m128(a, 11);
-    CASE_LSHIFT64_m128(a, 12);
-    CASE_LSHIFT64_m128(a, 13);
-    CASE_LSHIFT64_m128(a, 14);
-    CASE_LSHIFT64_m128(a, 15);
-    CASE_LSHIFT64_m128(a, 16);
-    CASE_LSHIFT64_m128(a, 17);
-    CASE_LSHIFT64_m128(a, 18);
-    CASE_LSHIFT64_m128(a, 19);
-    CASE_LSHIFT64_m128(a, 20);
-    CASE_LSHIFT64_m128(a, 21);
-    CASE_LSHIFT64_m128(a, 22);
-    CASE_LSHIFT64_m128(a, 23);
-    CASE_LSHIFT64_m128(a, 24);
-    CASE_LSHIFT64_m128(a, 25);
-    CASE_LSHIFT64_m128(a, 26);
-    CASE_LSHIFT64_m128(a, 27);
-    CASE_LSHIFT64_m128(a, 28);
-    CASE_LSHIFT64_m128(a, 29);
-    CASE_LSHIFT64_m128(a, 30);
-    CASE_LSHIFT64_m128(a, 31);
-    CASE_LSHIFT64_m128(a, 32);
-    CASE_LSHIFT64_m128(a, 33);
-    CASE_LSHIFT64_m128(a, 34);
-    CASE_LSHIFT64_m128(a, 35);
-    CASE_LSHIFT64_m128(a, 36);
-    CASE_LSHIFT64_m128(a, 37);
-    CASE_LSHIFT64_m128(a, 38);
-    CASE_LSHIFT64_m128(a, 39);
-    CASE_LSHIFT64_m128(a, 40);
-    CASE_LSHIFT64_m128(a, 41);
-    CASE_LSHIFT64_m128(a, 42);
-    CASE_LSHIFT64_m128(a, 43);
-    CASE_LSHIFT64_m128(a, 44);
-    CASE_LSHIFT64_m128(a, 45);
-    CASE_LSHIFT64_m128(a, 46);
-    CASE_LSHIFT64_m128(a, 47);
-    CASE_LSHIFT64_m128(a, 48);
-    CASE_LSHIFT64_m128(a, 49);
-    CASE_LSHIFT64_m128(a, 50);
-    CASE_LSHIFT64_m128(a, 51);
-    CASE_LSHIFT64_m128(a, 52);
-    CASE_LSHIFT64_m128(a, 53);
-    CASE_LSHIFT64_m128(a, 54);
-    CASE_LSHIFT64_m128(a, 55);
-    CASE_LSHIFT64_m128(a, 56);
-    CASE_LSHIFT64_m128(a, 57);
-    CASE_LSHIFT64_m128(a, 58);
-    CASE_LSHIFT64_m128(a, 59);
-    CASE_LSHIFT64_m128(a, 60);
-    CASE_LSHIFT64_m128(a, 61);
-    CASE_LSHIFT64_m128(a, 62);
-    CASE_LSHIFT64_m128(a, 63);
-    default: return zeroes128(); break;
-    }
-#undef CASE_LSHIFT64_m128
+    int64x2_t shift_indices = vdupq_n_s64(b);
+    return (m128) vshlq_s64((int64x2_t) a, shift_indices);
 }
 
 static really_really_inline
@@ -282,75 +145,8 @@ m128 rshift64_m128(m128 a, unsigned b) {
         return (m128) vshrq_n_u64((uint64x2_t)a, b);
     }
 #endif
-#define CASE_RSHIFT64_m128(a, offset) case offset: return (m128)vshrq_n_u64((uint64x2_t)(a), (offset)); break;
-    switch (b) {
-    case 0: return a; break;
-    CASE_RSHIFT64_m128(a, 1);
-    CASE_RSHIFT64_m128(a, 2);
-    CASE_RSHIFT64_m128(a, 3);
-    CASE_RSHIFT64_m128(a, 4);
-    CASE_RSHIFT64_m128(a, 5);
-    CASE_RSHIFT64_m128(a, 6);
-    CASE_RSHIFT64_m128(a, 7);
-    CASE_RSHIFT64_m128(a, 8);
-    CASE_RSHIFT64_m128(a, 9);
-    CASE_RSHIFT64_m128(a, 10);
-    CASE_RSHIFT64_m128(a, 11);
-    CASE_RSHIFT64_m128(a, 12);
-    CASE_RSHIFT64_m128(a, 13);
-    CASE_RSHIFT64_m128(a, 14);
-    CASE_RSHIFT64_m128(a, 15);
-    CASE_RSHIFT64_m128(a, 16);
-    CASE_RSHIFT64_m128(a, 17);
-    CASE_RSHIFT64_m128(a, 18);
-    CASE_RSHIFT64_m128(a, 19);
-    CASE_RSHIFT64_m128(a, 20);
-    CASE_RSHIFT64_m128(a, 21);
-    CASE_RSHIFT64_m128(a, 22);
-    CASE_RSHIFT64_m128(a, 23);
-    CASE_RSHIFT64_m128(a, 24);
-    CASE_RSHIFT64_m128(a, 25);
-    CASE_RSHIFT64_m128(a, 26);
-    CASE_RSHIFT64_m128(a, 27);
-    CASE_RSHIFT64_m128(a, 28);
-    CASE_RSHIFT64_m128(a, 29);
-    CASE_RSHIFT64_m128(a, 30);
-    CASE_RSHIFT64_m128(a, 31);
-    CASE_RSHIFT64_m128(a, 32);
-    CASE_RSHIFT64_m128(a, 33);
-    CASE_RSHIFT64_m128(a, 34);
-    CASE_RSHIFT64_m128(a, 35);
-    CASE_RSHIFT64_m128(a, 36);
-    CASE_RSHIFT64_m128(a, 37);
-    CASE_RSHIFT64_m128(a, 38);
-    CASE_RSHIFT64_m128(a, 39);
-    CASE_RSHIFT64_m128(a, 40);
-    CASE_RSHIFT64_m128(a, 41);
-    CASE_RSHIFT64_m128(a, 42);
-    CASE_RSHIFT64_m128(a, 43);
-    CASE_RSHIFT64_m128(a, 44);
-    CASE_RSHIFT64_m128(a, 45);
-    CASE_RSHIFT64_m128(a, 46);
-    CASE_RSHIFT64_m128(a, 47);
-    CASE_RSHIFT64_m128(a, 48);
-    CASE_RSHIFT64_m128(a, 49);
-    CASE_RSHIFT64_m128(a, 50);
-    CASE_RSHIFT64_m128(a, 51);
-    CASE_RSHIFT64_m128(a, 52);
-    CASE_RSHIFT64_m128(a, 53);
-    CASE_RSHIFT64_m128(a, 54);
-    CASE_RSHIFT64_m128(a, 55);
-    CASE_RSHIFT64_m128(a, 56);
-    CASE_RSHIFT64_m128(a, 57);
-    CASE_RSHIFT64_m128(a, 58);
-    CASE_RSHIFT64_m128(a, 59);
-    CASE_RSHIFT64_m128(a, 60);
-    CASE_RSHIFT64_m128(a, 61);
-    CASE_RSHIFT64_m128(a, 62);
-    CASE_RSHIFT64_m128(a, 63);
-    default: return zeroes128(); break;
-    }
-#undef CASE_RSHIFT64_m128
+    int64x2_t shift_indices = vdupq_n_s64(-b);
+    return (m128) vshlq_u64((uint64x2_t)a, shift_indices);
 }
 
 static really_inline m128 eq128(m128 a, m128 b) {
diff --git a/src/util/supervector/arch/arm/impl.cpp b/src/util/supervector/arch/arm/impl.cpp
index b3e4233e..5283ab00 100644
--- a/src/util/supervector/arch/arm/impl.cpp
+++ b/src/util/supervector/arch/arm/impl.cpp
@@ -374,10 +374,9 @@ template <>
 really_inline SuperVector<16> SuperVector<16>::vshl_8 (uint8_t const N) const
 {
     if (N == 0) return *this;
-    if (N == 16) return Zeroes();
-    SuperVector result;
-    Unroller<1, 8>::iterator([&,v=this](auto const i) { constexpr uint8_t n = i.value; if (N == n) result = {vshlq_n_u8(v->u.u8x16[0], n)}; });
-    return result;
+    if (N == 8) return Zeroes();
+    int8x16_t shift_indices = vdupq_n_s8(N);
+    return { vshlq_s8(u.s8x16[0], shift_indices) };
 }
 
 template <>
@@ -385,9 +384,8 @@ really_inline SuperVector<16> SuperVector<16>::vshl_16 (uint8_t const N) const
 {
     if (N == 0) return *this;
     if (N == 16) return Zeroes();
-    SuperVector result;
-    Unroller<1, 16>::iterator([&,v=this](auto const i) { constexpr uint8_t n = i.value; if (N == n) result = {vshlq_n_u16(v->u.u16x8[0], n)}; });
-    return result;
+    int16x8_t shift_indices = vdupq_n_s16(N);
+    return { vshlq_s16(u.s16x8[0], shift_indices) };
 }
 
 template <>
@@ -395,9 +393,8 @@ really_inline SuperVector<16> SuperVector<16>::vshl_32 (uint8_t const N) const
 {
     if (N == 0) return *this;
     if (N == 32) return Zeroes();
-    SuperVector result;
-    Unroller<1, 32>::iterator([&,v=this](auto const i) { constexpr uint8_t n = i.value; if (N == n) result = {vshlq_n_u32(v->u.u32x4[0], n)}; });
-    return result;
+    int32x4_t shift_indices = vdupq_n_s32(N);
+    return { vshlq_s32(u.s32x4[0], shift_indices) };
 }
 
 template <>
@@ -405,9 +402,8 @@ really_inline SuperVector<16> SuperVector<16>::vshl_64 (uint8_t const N) const
 {
     if (N == 0) return *this;
     if (N == 64) return Zeroes();
-    SuperVector result;
-    Unroller<1, 64>::iterator([&,v=this](auto const i) { constexpr uint8_t n = i.value; if (N == n) result = {vshlq_n_u64(v->u.u64x2[0], n)}; });
-    return result;
+    int64x2_t shift_indices = vdupq_n_s64(N);
+    return { vshlq_s64(u.s64x2[0], shift_indices) };
 }
 
 template <>
@@ -415,6 +411,11 @@ really_inline SuperVector<16> SuperVector<16>::vshl_128(uint8_t const N) const
 {
     if (N == 0) return *this;
     if (N == 16) return Zeroes();
+#if defined(HAVE__BUILTIN_CONSTANT_P)
+    if (__builtin_constant_p(N)) {
+        return {vextq_u8(vdupq_n_u8(0), u.u8x16[0], 16 - N)};
+    }
+#endif
     SuperVector result;
     Unroller<1, 16>::iterator([&,v=this](auto const i) { constexpr uint8_t n = i.value; if (N == n) result = {vextq_u8(vdupq_n_u8(0), v->u.u8x16[0], 16 - n)}; });
     return result;
@@ -431,9 +432,8 @@ really_inline SuperVector<16> SuperVector<16>::vshr_8 (uint8_t const N) const
 {
     if (N == 0) return *this;
     if (N == 8) return Zeroes();
-    SuperVector result;
-    Unroller<1, 8>::iterator([&,v=this](auto const i) { constexpr uint8_t n = i.value; if (N == n) result = {vshrq_n_u8(v->u.u8x16[0], n)}; });
-    return result;
+    int8x16_t shift_indices = vdupq_n_s8(-N);
+    return { vshlq_u8(u.u8x16[0], shift_indices) };
 }
 
 template <>
@@ -441,9 +441,8 @@ really_inline SuperVector<16> SuperVector<16>::vshr_16 (uint8_t const N) const
 {
     if (N == 0) return *this;
     if (N == 16) return Zeroes();
-    SuperVector result;
-    Unroller<1, 16>::iterator([&,v=this](auto const i) { constexpr uint8_t n = i.value; if (N == n) result = {vshrq_n_u16(v->u.u16x8[0], n)}; });
-    return result;
+    int16x8_t shift_indices = vdupq_n_s16(-N);
+    return { vshlq_u16(u.u16x8[0], shift_indices) };
 }
 
 template <>
@@ -451,9 +450,8 @@ really_inline SuperVector<16> SuperVector<16>::vshr_32 (uint8_t const N) const
 {
     if (N == 0) return *this;
     if (N == 32) return Zeroes();
-    SuperVector result;
-    Unroller<1, 32>::iterator([&,v=this](auto const i) { constexpr uint8_t n = i.value; if (N == n) result = {vshrq_n_u32(v->u.u32x4[0], n)}; });
-    return result;
+    int32x4_t shift_indices = vdupq_n_s32(-N);
+    return { vshlq_u32(u.u32x4[0], shift_indices) };
 }
 
 template <>
@@ -461,9 +459,8 @@ really_inline SuperVector<16> SuperVector<16>::vshr_64 (uint8_t const N) const
 {
     if (N == 0) return *this;
     if (N == 64) return Zeroes();
-    SuperVector result;
-    Unroller<1, 64>::iterator([&,v=this](auto const i) { constexpr uint8_t n = i.value; if (N == n) result = {vshrq_n_u64(v->u.u64x2[0], n)}; });
-    return result;
+    int64x2_t shift_indices = vdupq_n_s64(-N);
+    return { vshlq_u64(u.u64x2[0], shift_indices) };
 }
 
 template <>
@@ -471,6 +468,11 @@ really_inline SuperVector<16> SuperVector<16>::vshr_128(uint8_t const N) const
 {
     if (N == 0) return *this;
     if (N == 16) return Zeroes();
+#if defined(HAVE__BUILTIN_CONSTANT_P)
+    if (__builtin_constant_p(N)) {
+        return {vextq_u8(u.u8x16[0], vdupq_n_u8(0), N)};
+    }
+#endif
     SuperVector result;
     Unroller<1, 16>::iterator([&,v=this](auto const i) { constexpr uint8_t n = i.value; if (N == n) result = {vextq_u8(v->u.u8x16[0], vdupq_n_u8(0), n)}; });
     return result;
@@ -485,22 +487,12 @@ really_inline SuperVector<16> SuperVector<16>::vshr(uint8_t const N) const
 template <>
 really_inline SuperVector<16> SuperVector<16>::operator>>(uint8_t const N) const
 {
-#if defined(HAVE__BUILTIN_CONSTANT_P)
-    if (__builtin_constant_p(N)) {
-        return {vextq_u8(u.u8x16[0], vdupq_n_u8(0), N)};
-    }
-#endif
     return vshr_128(N);
 }
 
 template <>
 really_inline SuperVector<16> SuperVector<16>::operator<<(uint8_t const N) const
 {
-#if defined(HAVE__BUILTIN_CONSTANT_P)
-    if (__builtin_constant_p(N)) {
-        return {vextq_u8(vdupq_n_u8(0), u.u8x16[0], 16 - N)};
-    }
-#endif
     return vshl_128(N);
 }
 
@@ -534,45 +526,23 @@ template <>
 really_inline SuperVector<16> SuperVector<16>::loadu_maskz(void const *ptr, uint8_t const len)
 {
     SuperVector mask = Ones_vshr(16 -len);
-    //mask.print8("mask");
     SuperVector<16> v = loadu(ptr);
-    //v.print8("v");
     return mask & v;
 }
 
 template<>
 really_inline SuperVector<16> SuperVector<16>::alignr(SuperVector<16> &other, int8_t offset)
 {
+    if (offset == 0) return other;
+    if (offset == 16) return *this;
 #if defined(HAVE__BUILTIN_CONSTANT_P)
     if (__builtin_constant_p(offset)) {
-        if (offset == 16) {
-            return *this;
-        } else {
-            return {vextq_u8(other.u.u8x16[0], u.u8x16[0], offset)};
-        }
+        return {vextq_u8(other.u.u8x16[0], u.u8x16[0], offset)};
     }
 #endif
-    switch(offset) {
-    case 0: return other; break;
-    case 1: return {vextq_u8( other.u.u8x16[0], u.u8x16[0], 1)}; break;
-    case 2: return {vextq_u8( other.u.u8x16[0], u.u8x16[0], 2)}; break;
-    case 3: return {vextq_u8( other.u.u8x16[0], u.u8x16[0], 3)}; break;
-    case 4: return {vextq_u8( other.u.u8x16[0], u.u8x16[0], 4)}; break;
-    case 5: return {vextq_u8( other.u.u8x16[0], u.u8x16[0], 5)}; break;
-    case 6: return {vextq_u8( other.u.u8x16[0], u.u8x16[0], 6)}; break;
-    case 7: return {vextq_u8( other.u.u8x16[0], u.u8x16[0], 7)}; break;
-    case 8: return {vextq_u8( other.u.u8x16[0], u.u8x16[0], 8)}; break;
-    case 9: return {vextq_u8( other.u.u8x16[0], u.u8x16[0], 9)}; break;
-    case 10: return {vextq_u8( other.u.u8x16[0], u.u8x16[0], 10)}; break;
-    case 11: return {vextq_u8( other.u.u8x16[0], u.u8x16[0], 11)}; break;
-    case 12: return {vextq_u8( other.u.u8x16[0], u.u8x16[0], 12)}; break;
-    case 13: return {vextq_u8( other.u.u8x16[0], u.u8x16[0], 13)}; break;
-    case 14: return {vextq_u8( other.u.u8x16[0], u.u8x16[0], 14)}; break;
-    case 15: return {vextq_u8( other.u.u8x16[0], u.u8x16[0], 15)}; break;
-    case 16: return *this; break;
-    default: break;
-    }
-    return *this;
+    SuperVector result;
+    Unroller<1, 16>::iterator([&,v=this](auto const i) { constexpr uint8_t n = i.value; if (offset == n) result = {vextq_u8(other.u.u8x16[0], v->u.u8x16[0], n)}; });
+    return result;
 }
 
 template<>
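
The whole patch hinges on one NEON property: the variable shift
intrinsics (vshlq_u8/u16/u32/u64 and their signed counterparts) read a
per-lane signed count, and a negative count shifts that lane right. A
minimal standalone sketch of that equivalence follows; it is
illustrative only, not part of the patch, and assumes an AArch64
toolchain (the file name shift_check.c is made up for the example):

/* shift_check.c - check that vshlq_u32 with a splatted negative count
 * behaves like a logical right shift, which is the trick rshift_m128
 * and the vshr_* SuperVector methods rely on.
 * Build on AArch64: gcc -O2 shift_check.c -o shift_check
 */
#include <arm_neon.h>
#include <stdio.h>

int main(void) {
    const uint32_t x = 0xDEADBEEFu;  /* MSB set, so an arithmetic shift */
    uint32x4_t a = vdupq_n_u32(x);   /* would give a different answer   */

    for (unsigned b = 0; b < 32; b++) {
        /* Right shift by b: splat -b into every lane and "shift left". */
        uint32x4_t r = vshlq_u32(a, vdupq_n_s32(-(int)b));
        if (vgetq_lane_u32(r, 0) != (x >> b)) {
            printf("mismatch at b = %u\n", b);
            return 1;
        }
    }
    printf("vshlq_u32 with negative counts matches logical >>\n");
    return 0;
}

Running this should print the success line. Swapping in vshlq_s32 (with
an int32x4_t lane type) makes the loop report mismatches for every
b > 0, since the signed form sign-extends; that is why the right-shift
paths above use the unsigned intrinsics.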