minor fixes, add 2 constructors from half size vectors

This commit is contained in:
Konstantinos Margaritis 2021-07-23 11:43:10 +03:00
parent cabd13d18a
commit f8ce0bb922

View File

@ -32,6 +32,7 @@
#include <cstdint>
#include <cstdio>
#include <type_traits>
#if defined(ARCH_IA32) || defined(ARCH_X86_64)
#include "util/supervector/arch/x86/types.hpp"
@ -88,48 +89,63 @@ using m1024_t = SuperVector<128>;
template <int T>
struct BaseVector
{
static const bool is_valid = false; // for template matches specialisation
using type = void;
using movemask_type = uint32_t;
using previous_type = void;
static constexpr bool is_valid = false;
static constexpr u16 size = 8;
using type = void;
using movemask_type = void;
static constexpr bool has_previous = false;
using previous_type = void;
static constexpr u16 previous_size = 4;
};
template <>
struct BaseVector<128>
{
static constexpr bool is_valid = true;
static constexpr uint16_t size = 128;
using type = void;
using movemask_type = u64a;
static constexpr bool is_valid = true;
static constexpr u16 size = 128;
using type = void;
using movemask_type = u64a;
static constexpr bool has_previous = true;
using previous_type = m512;
static constexpr u16 previous_size = 64;
};
template <>
struct BaseVector<64>
{
static constexpr bool is_valid = true;
static constexpr uint16_t size = 64;
using type = m512;
using movemask_type = u64a;
static constexpr bool is_valid = true;
static constexpr u16 size = 64;
using type = m512;
using movemask_type = u64a;
static constexpr bool has_previous = true;
using previous_type = m256;
static constexpr u16 previous_size = 32;
};
// 128 bit implementation
template <>
struct BaseVector<32>
{
static constexpr bool is_valid = true;
static constexpr uint16_t size = 32;
using type = m256;
using movemask_type = u32;
static constexpr bool is_valid = true;
static constexpr u16 size = 32;
using type = m256;
using movemask_type = u32;
static constexpr bool has_previous = true;
using previous_type = m128;
static constexpr u16 previous_size = 16;
};
// 128 bit implementation
template <>
struct BaseVector<16>
{
static constexpr bool is_valid = true;
static constexpr uint16_t size = 16;
using type = m128;
using movemask_type = u32;
static constexpr bool is_valid = true;
static constexpr u16 size = 16;
using type = m128;
using movemask_type = u32;
static constexpr bool has_previous = false;
using previous_type = u64a;
static constexpr u16 previous_size = 8;
};
template <uint16_t SIZE>
@ -140,6 +156,7 @@ class SuperVector : public BaseVector<SIZE>
public:
using base_type = BaseVector<SIZE>;
using previous_type = typename BaseVector<SIZE>::previous_type;
union {
typename BaseVector<16>::type ALIGN_ATTR(BaseVector<16>::size) v128[SIZE / BaseVector<16>::size];
@ -164,6 +181,9 @@ public:
template<typename T>
SuperVector(T const other);
SuperVector(SuperVector<SIZE/2> const lo, SuperVector<SIZE/2> const hi);
SuperVector(previous_type const lo, previous_type const hi);
static SuperVector dup_u8 (uint8_t other) { return {other}; };
static SuperVector dup_s8 (int8_t other) { return {other}; };
static SuperVector dup_u16(uint16_t other) { return {other}; };
@ -208,38 +228,38 @@ public:
static SuperVector Zeroes();
#if defined(DEBUG)
void print8(const char *label) {
void print8(const char *label) const {
printf("%12s: ", label);
for(s16 i=SIZE-1; i >= 0; i--)
printf("%02x ", u.u8[i]);
printf("\n");
}
void print16(const char *label) {
void print16(const char *label) const {
printf("%12s: ", label);
for(s16 i=SIZE/sizeof(u16)-1; i >= 0; i--)
printf("%04x ", u.u16[i]);
printf("\n");
}
void print32(const char *label) {
void print32(const char *label) const {
printf("%12s: ", label);
for(s16 i=SIZE/sizeof(u32)-1; i >= 0; i--)
printf("%08x ", u.u32[i]);
printf("\n");
}
void print64(const char *label) {
void print64(const char *label) const {
printf("%12s: ", label);
for(s16 i=SIZE/sizeof(u64a)-1; i >= 0; i--)
printf("%016lx ", u.u64[i]);
printf("\n");
}
#else
void print8(const char *label UNUSED) {};
void print16(const char *label UNUSED) {};
void print32(const char *label UNUSED) {};
void print64(const char *label UNUSED) {};
void print8(const char *label UNUSED) const {};
void print16(const char *label UNUSED) const {};
void print32(const char *label UNUSED) const {};
void print64(const char *label UNUSED) const {};
#endif
};