Merge commit '57e0f5df29abefd919c334c994628a994ba2868c' into develop

This commit is contained in:
assistant-librarian[bot]
2025-05-19 22:06:56 +00:00
parent 0b87df9c4a
commit 9d088bc569
15 changed files with 1602 additions and 588 deletions

View File

@@ -360,10 +360,9 @@ struct Tensor
std::size_t GetElementSpaceSize() const
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
if constexpr(ck::is_packed_type_v<ck::remove_cvref_t<T>>)
{
return (mDesc.GetElementSpaceSize() + 1) / 2;
return (mDesc.GetElementSpaceSize() + 1) / ck::packed_size_v<ck::remove_cvref_t<T>>;
}
else
{
@@ -516,69 +515,31 @@ struct Tensor
template <typename... Is>
std::size_t GetOffsetFromMultiIndex(Is... is) const
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return mDesc.GetOffsetFromMultiIndex(is...) / 2;
}
else
{
return mDesc.GetOffsetFromMultiIndex(is...);
}
return mDesc.GetOffsetFromMultiIndex(is...) / ck::packed_size_v<ck::remove_cvref_t<T>>;
}
template <typename... Is>
T& operator()(Is... is)
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
}
return mData[mDesc.GetOffsetFromMultiIndex(is...) /
ck::packed_size_v<ck::remove_cvref_t<T>>];
}
template <typename... Is>
const T& operator()(Is... is) const
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
}
return mData[mDesc.GetOffsetFromMultiIndex(is...) /
ck::packed_size_v<ck::remove_cvref_t<T>>];
}
T& operator()(std::vector<std::size_t> idx)
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v<ck::remove_cvref_t<T>>];
}
const T& operator()(std::vector<std::size_t> idx) const
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v<ck::remove_cvref_t<T>>];
}
typename Data::iterator begin() { return mData.begin(); }

View File

@@ -67,6 +67,18 @@ struct GeneratorTensor_1<ck::f8_t>
return ck::type_convert<ck::f8_t>(value);
}
};
// Constant-fill generator for bf8: every call returns the float member
// `value` converted to ck::bf8_t, regardless of the index arguments.
template <>
struct GeneratorTensor_1<ck::bf8_t>
{
    float value = 1.0;

    template <typename... Is>
    ck::bf8_t operator()(Is...)
    {
        const float fill = value;
        return ck::type_convert<ck::bf8_t>(fill);
    }
};
#endif
template <>
@@ -93,6 +105,38 @@ struct GeneratorTensor_1<ck::f4x2_pk_t>
}
};
// Constant-fill generator for the 32-wide packed f6 type: converts `value`
// to f6 once and packs that same code into all 32 lanes of the result.
template <>
struct GeneratorTensor_1<ck::f6x32_pk_t>
{
    float value = 1.0;

    template <typename... Is>
    ck::f6x32_pk_t operator()(Is...)
    {
        ck::f6x32_pk_t packed;
        const auto elem = ck::type_convert<ck::f6_t>(value);
        ck::static_for<0, 32, 1>{}(
            [&](auto idx) { packed.pack(elem, static_cast<ck::index_t>(idx)); });
        return packed;
    }
};
// Constant-fill generator for the 32-wide packed bf6 type: converts `value`
// to bf6 once and packs that same code into all 32 lanes of the result.
template <>
struct GeneratorTensor_1<ck::bf6x32_pk_t>
{
    float value = 1.0;

    template <typename... Is>
    ck::bf6x32_pk_t operator()(Is...)
    {
        ck::bf6x32_pk_t packed;
        const auto elem = ck::type_convert<ck::bf6_t>(value);
        ck::static_for<0, 32, 1>{}(
            [&](auto idx) { packed.pack(elem, static_cast<ck::index_t>(idx)); });
        return packed;
    }
};
template <>
struct GeneratorTensor_1<int8_t>
{
@@ -132,6 +176,44 @@ struct GeneratorTensor_2
}
};
// Random-integer generator for the 32-wide packed f6 type.
//
// Each of the 32 lanes receives an independent integer drawn from
// [min_value, max_value) via std::rand, converted to f6 and packed in place.
//
// Fix: the original computed `std::rand() % (max_value - min_value)` without
// guarding the span, which is undefined behavior (modulo by zero) whenever
// min_value == max_value. A non-positive span now degenerates to min_value.
template <>
struct GeneratorTensor_2<ck::f6x32_pk_t>
{
    int min_value = 0;
    int max_value = 1;

    template <typename... Is>
    ck::f6x32_pk_t operator()(Is...)
    {
        const int span = max_value - min_value; // loop-invariant
        ck::f6x32_pk_t r;
        ck::static_for<0, 32, 1>{}([&](auto i) {
            const float tmp = span > 0 ? static_cast<float>(std::rand() % span + min_value)
                                       : static_cast<float>(min_value);
            r.pack(ck::type_convert<ck::f6_t>(tmp), static_cast<ck::index_t>(i));
        });
        return r;
    }
};
// Random-integer generator for the 32-wide packed bf6 type.
//
// Each of the 32 lanes receives an independent integer drawn from
// [min_value, max_value) via std::rand, converted to bf6 and packed in place.
//
// Fix: the original computed `std::rand() % (max_value - min_value)` without
// guarding the span, which is undefined behavior (modulo by zero) whenever
// min_value == max_value. A non-positive span now degenerates to min_value.
template <>
struct GeneratorTensor_2<ck::bf6x32_pk_t>
{
    int min_value = 0;
    int max_value = 1;

    template <typename... Is>
    ck::bf6x32_pk_t operator()(Is...)
    {
        const int span = max_value - min_value; // loop-invariant
        ck::bf6x32_pk_t r;
        ck::static_for<0, 32, 1>{}([&](auto i) {
            const float tmp = span > 0 ? static_cast<float>(std::rand() % span + min_value)
                                       : static_cast<float>(min_value);
            r.pack(ck::type_convert<ck::bf6_t>(tmp), static_cast<ck::index_t>(i));
        });
        return r;
    }
};
template <>
struct GeneratorTensor_2<ck::bhalf_t>
{
@@ -342,6 +424,46 @@ struct GeneratorTensor_3<ck::f4x2_pk_t>
}
};
// Uniform-random generator for the 32-wide packed f6 type: every lane gets an
// independent sample from [min_value, max_value], converted to f6.
template <>
struct GeneratorTensor_3<ck::f6x32_pk_t>
{
    float min_value = 0;
    float max_value = 1;

    template <typename... Is>
    ck::f6x32_pk_t operator()(Is...)
    {
        ck::f6x32_pk_t packed;
        ck::static_for<0, 32, 1>{}([&](auto idx) {
            const float u      = float(std::rand()) / float(RAND_MAX);
            const float sample = min_value + u * (max_value - min_value);
            packed.pack(ck::type_convert<ck::f6_t>(sample), static_cast<ck::index_t>(idx));
        });
        return packed;
    }
};
// Uniform-random generator for the 32-wide packed bf6 type: every lane gets an
// independent sample from [min_value, max_value], converted to bf6.
template <>
struct GeneratorTensor_3<ck::bf6x32_pk_t>
{
    float min_value = 0;
    float max_value = 1;

    template <typename... Is>
    ck::bf6x32_pk_t operator()(Is...)
    {
        ck::bf6x32_pk_t packed;
        ck::static_for<0, 32, 1>{}([&](auto idx) {
            const float u      = float(std::rand()) / float(RAND_MAX);
            const float sample = min_value + u * (max_value - min_value);
            packed.pack(ck::type_convert<ck::bf6_t>(sample), static_cast<ck::index_t>(idx));
        });
        return packed;
    }
};
template <typename T>
struct GeneratorTensor_4
{
@@ -360,6 +482,69 @@ struct GeneratorTensor_4
}
};
// Normal-distribution generator for the two-lane packed f4 type: draws two
// independent samples per call and packs them as one f4x2 pair.
template <>
struct GeneratorTensor_4<ck::f4x2_pk_t>
{
    std::mt19937 generator;
    std::normal_distribution<float> distribution;

    GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1)
        : generator(seed), distribution(mean, stddev){};

    template <typename... Is>
    ck::f4x2_pk_t operator()(Is...)
    {
        // Draw order matters for reproducibility: lane 0 first, then lane 1.
        const float lane0 = distribution(generator);
        const float lane1 = distribution(generator);
        return ck::f4x2_pk_t{ck::type_convert<ck::f4x2_t>(ck::float2_t{lane0, lane1})};
    }
};
// Normal-distribution generator for the 32-wide packed f6 type: each lane
// gets an independent sample, converted to f6 and packed in index order.
template <>
struct GeneratorTensor_4<ck::f6x32_pk_t>
{
    std::mt19937 generator;
    std::normal_distribution<float> distribution;

    GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1)
        : generator(seed), distribution(mean, stddev){};

    template <typename... Is>
    ck::f6x32_pk_t operator()(Is...)
    {
        ck::f6x32_pk_t packed;
        ck::static_for<0, 32, 1>{}([&](auto idx) {
            const float sample = distribution(generator);
            packed.pack(ck::type_convert<ck::f6_t>(sample), static_cast<ck::index_t>(idx));
        });
        return packed;
    }
};
// Normal-distribution generator for the 32-wide packed bf6 type: each lane
// gets an independent sample, converted to bf6 and packed in index order.
template <>
struct GeneratorTensor_4<ck::bf6x32_pk_t>
{
    std::mt19937 generator;
    std::normal_distribution<float> distribution;

    GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1)
        : generator(seed), distribution(mean, stddev){};

    template <typename... Is>
    ck::bf6x32_pk_t operator()(Is...)
    {
        ck::bf6x32_pk_t packed;
        ck::static_for<0, 32, 1>{}([&](auto idx) {
            const float sample = distribution(generator);
            packed.pack(ck::type_convert<ck::bf6_t>(sample), static_cast<ck::index_t>(idx));
        });
        return packed;
    }
};
struct GeneratorTensor_Checkboard
{
template <typename... Ts>
@@ -405,6 +590,53 @@ struct GeneratorTensor_Sequential
}
};
// Sequential generator for the two-lane packed f4 type: both lanes of the
// returned pair carry the coordinate of dimension `Dim`, converted to f4.
template <ck::index_t Dim>
struct GeneratorTensor_Sequential<ck::f4x2_pk_t, Dim>
{
    template <typename... Ts>
    ck::f4x2_pk_t operator()(Ts... Xs) const
    {
        // Fix: reject an out-of-range Dim at compile time instead of reading
        // past the end of `dims` (undefined behavior) at run time.
        static_assert(Dim >= 0 && Dim < static_cast<ck::index_t>(sizeof...(Ts)),
                      "Dim must index one of the supplied coordinates.");
        std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
        float tmp = dims[Dim];
        return ck::type_convert<ck::f4x2_t>(ck::float2_t(tmp));
    }
};
// Sequential generator for the 32-wide packed f6 type: all 32 lanes carry the
// coordinate of dimension `Dim`, converted to f6.
template <ck::index_t Dim>
struct GeneratorTensor_Sequential<ck::f6x32_pk_t, Dim>
{
    template <typename... Ts>
    ck::f6x32_pk_t operator()(Ts... Xs) const
    {
        // Fix: reject an out-of-range Dim at compile time instead of reading
        // past the end of `dims` (undefined behavior) at run time.
        static_assert(Dim >= 0 && Dim < static_cast<ck::index_t>(sizeof...(Ts)),
                      "Dim must index one of the supplied coordinates.");
        std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
        float tmp = dims[Dim];
        ck::f6x32_pk_t r;
        ck::static_for<0, 32, 1>{}(
            [&](auto i) { r.pack(ck::type_convert<ck::f6_t>(tmp), static_cast<ck::index_t>(i)); });
        return r;
    }
};
// Sequential generator for the 32-wide packed bf6 type: all 32 lanes carry the
// coordinate of dimension `Dim`, converted to bf6.
template <ck::index_t Dim>
struct GeneratorTensor_Sequential<ck::bf6x32_pk_t, Dim>
{
    template <typename... Ts>
    ck::bf6x32_pk_t operator()(Ts... Xs) const
    {
        // Fix: reject an out-of-range Dim at compile time instead of reading
        // past the end of `dims` (undefined behavior) at run time.
        static_assert(Dim >= 0 && Dim < static_cast<ck::index_t>(sizeof...(Ts)),
                      "Dim must index one of the supplied coordinates.");
        std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
        float tmp = dims[Dim];
        ck::bf6x32_pk_t r;
        ck::static_for<0, 32, 1>{}(
            [&](auto i) { r.pack(ck::type_convert<ck::bf6_t>(tmp), static_cast<ck::index_t>(i)); });
        return r;
    }
};
template <typename T, size_t NumEffectiveDim = 2>
struct GeneratorTensor_Diagonal
{

View File

@@ -498,7 +498,7 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
reg_a,
reg_b,
reg_c.template AsType<float16_t>()[Number<0>{}],
0, // cbsz
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
0, // blgp
0,
0,
@@ -511,6 +511,28 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
#endif
}
// 32x32x64 MFMA for bf8 (FP8 E5M2) A/B fragments on gfx950.
// Issues the *scaled* f8f6f4 builtin with both scale operands fixed to 0,
// i.e. the plain (unscaled) multiply-accumulate; cbsz/blgp = 1 tag the A/B
// operand format as FP8 E5M2. On non-gfx950 targets the arguments are
// consumed via `ignore` and the call is a no-op.
template <class FloatC>
__device__ static void Run(const bf8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
    reg_c.template AsType<float16_t>()(Number<0>{}) =
        __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
            reg_a,
            reg_b,
            reg_c.template AsType<float16_t>()[Number<0>{}],
            1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
            1, // blgp
            0, // OPSEL (unused here)
            0, // scale_a = 0 -> unscaled
            0, // OPSEL (unused here)
            0); // scale_b = 0 -> unscaled
#else
    ignore = reg_a;
    ignore = reg_b;
    ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
{
@@ -536,6 +558,62 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
// 32x32x64 MFMA for f6 (FP6 E2M3) A/B fragments on gfx950, unscaled.
// The 32 six-bit elements occupy six dwords; they are widened to the
// builtin's eight-dword operand type with the two trailing dwords zeroed.
// cbsz/blgp = 2 tag the operand format as FP6 E2M3.
template <class FloatC>
__device__ static void Run(const f6x32_t& reg_a, const f6x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
    int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
    int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
    using arg_type  = int32x8_t;
    reg_c.template AsType<float16_t>()(Number<0>{}) =
        __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
            arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
            arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
            reg_c.template AsType<float16_t>()[Number<0>{}],
            2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
            2, // blgp
            0, // OPSEL
            0, // scale_a = 0 -> unscaled
            0, // OPSEL
            0); // scale_b = 0 -> unscaled
#else
    ignore = reg_a;
    ignore = reg_b;
    ignore = reg_c;
#endif
}
// 32x32x64 MFMA for bf6 (FP6 E3M2) A/B fragments on gfx950, unscaled.
// Same dword-widening scheme as the f6 overload; cbsz/blgp = 3 select the
// FP6 E3M2 encoding.
template <class FloatC>
__device__ static void Run(const bf6x32_t& reg_a, const bf6x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
    int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
    int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
    using arg_type  = int32x8_t;
    reg_c.template AsType<float16_t>()(Number<0>{}) =
        __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
            arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
            arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
            reg_c.template AsType<float16_t>()[Number<0>{}],
            3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
            3, // blgp
            0, // OPSEL
            0, // scale_a = 0 -> unscaled
            0, // OPSEL
            0); // scale_b = 0 -> unscaled
#else
    ignore = reg_a;
    ignore = reg_b;
    ignore = reg_c;
#endif
}
};
@@ -583,6 +661,43 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
#endif
}
// Scaled 32x32x64 MFMA for bf8 (FP8 E5M2) A/B fragments on gfx950.
// scale_a/scale_b carry per-block exponent scales (byte 0 is used; see the
// OPSEL note below). cbsz/blgp = 1 select the FP8 E5M2 encoding.
template <class FloatC>
__device__ static void Run(const bf8x32_t& reg_a,
                           const int32_t& scale_a,
                           const bf8x32_t& reg_b,
                           const int32_t& scale_b,
                           FloatC& reg_c)
{
#if defined(__gfx950__)
    // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
    reg_c.template AsType<float16_t>()(Number<0>{}) =
        __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
            reg_a,
            reg_b,
            reg_c.template AsType<float16_t>()[Number<0>{}],
            1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
            1, // blgp
            0, // OPSEL
            scale_a,
            0, // OPSEL
            scale_b);
    // XXX: Note on the scale_a and scale_b parameters:
    // If compiler detects that one or both scales are constant values, it will treat that
    // constant as F32 constant. I.e., if scale_a at some point was declared as
    // `e8m0_bexp_t a_scale{1.0f}`, the instruction would only work if scale_a parameter is
    // assigned value `bit_cast<int32_t>(static_cast<float>(a_scale))`.
    // XXX: Note on the OPSEL parameters: Instruction always takes byte0 as a scale value even
    // when OPSEL is set otherwise.
#else
    ignore = reg_a;
    ignore = scale_a;
    ignore = reg_b;
    ignore = scale_b;
    ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const bf8x32_t& reg_a,
const int32_t& scale_a,
@@ -620,6 +735,74 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
#endif
}
// Scaled 32x32x64 MFMA for f6 (FP6 E2M3) A/B fragments on gfx950.
// The six data dwords are widened to the builtin's eight-dword operand with
// trailing zeros; cbsz/blgp = 2 select FP6 E2M3; scale_a/scale_b carry the
// per-block scales (byte 0 is consumed).
template <class FloatC>
__device__ static void Run(const f6x32_t& reg_a,
                           const int32_t scale_a,
                           const f6x32_t& reg_b,
                           const int32_t scale_b,
                           FloatC& reg_c)
{
#if defined(__gfx950__)
    int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
    int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
    using arg_type  = int32x8_t;
    reg_c.template AsType<float16_t>()(Number<0>{}) =
        __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
            arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
            arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
            reg_c.template AsType<float16_t>()[Number<0>{}],
            2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
            2, // blgp
            0, // OPSEL
            scale_a,
            0, // OPSEL
            scale_b);
#else
    ignore = reg_a;
    ignore = scale_a;
    ignore = reg_b;
    ignore = scale_b;
    ignore = reg_c;
#endif
}
// Scaled 32x32x64 MFMA for bf6 (FP6 E3M2) A/B fragments on gfx950.
// Same dword-widening as the f6 overload; cbsz/blgp = 3 select FP6 E3M2.
template <class FloatC>
__device__ static void Run(const bf6x32_t& reg_a,
                           const int32_t scale_a,
                           const bf6x32_t& reg_b,
                           const int32_t scale_b,
                           FloatC& reg_c)
{
#if defined(__gfx950__)
    int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
    int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
    using arg_type  = int32x8_t;
    reg_c.template AsType<float16_t>()(Number<0>{}) =
        __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
            arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
            arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
            reg_c.template AsType<float16_t>()[Number<0>{}],
            3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
            3, // blgp
            0, // OPSEL
            scale_a,
            0, // OPSEL
            scale_b);
#else
    ignore = reg_a;
    ignore = scale_a;
    ignore = reg_b;
    ignore = scale_b;
    ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const f4x32_t& reg_a,
const int32_t scale_a,
@@ -639,7 +822,7 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
reg_c.template AsType<float16_t>()[Number<0>{}],
4, // cbsz
4, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
4, // blgp
0, // OPSEL
scale_a,
@@ -748,6 +931,101 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
#endif
}
// Scaled 16x16x128 MFMA with mixed operand formats on gfx950:
// A is bf8 (FP8 E5M2, cbsz = 1) while B is f8 (FP8 E4M3, blgp = 0).
// scale_a/scale_b carry the per-block scales (byte 0 is consumed).
template <class FloatC>
__device__ static void Run(const bf8x32_t& reg_a,
                           const int32_t& scale_a,
                           const f8x32_t& reg_b,
                           const int32_t& scale_b,
                           FloatC& reg_c)
{
#if defined(__gfx950__)
    // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
    reg_c.template AsType<float4_t>()(Number<0>{}) =
        __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
            reg_a,
            reg_b,
            reg_c.template AsType<float4_t>()[Number<0>{}],
            1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
            0, // blgp
            0, // OPSEL
            scale_a,
            0, // OPSEL
            scale_b);
#else
    ignore = reg_a;
    ignore = scale_a;
    ignore = reg_b;
    ignore = scale_b;
    ignore = reg_c;
#endif
}
// Scaled 16x16x128 MFMA for f6 (FP6 E2M3) A/B fragments on gfx950.
// Six data dwords are widened to the builtin's eight-dword operand with
// trailing zeros; cbsz/blgp = 2 select FP6 E2M3.
template <class FloatC>
__device__ static void Run(const f6x32_t& reg_a,
                           const int32_t scale_a,
                           const f6x32_t& reg_b,
                           const int32_t scale_b,
                           FloatC& reg_c)
{
#if defined(__gfx950__)
    int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
    int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
    using arg_type  = int32x8_t;
    reg_c.template AsType<float4_t>()(Number<0>{}) =
        __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
            arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
            arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
            reg_c.template AsType<float4_t>()[Number<0>{}],
            2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
            2, // blgp
            0, // OPSEL
            scale_a,
            0, // OPSEL
            scale_b);
#else
    ignore = reg_a;
    ignore = scale_a;
    ignore = reg_b;
    ignore = scale_b;
    ignore = reg_c;
#endif
}
// Scaled 16x16x128 MFMA for bf6 (FP6 E3M2) A/B fragments on gfx950.
// Same dword-widening as the f6 overload; cbsz/blgp = 3 select FP6 E3M2.
template <class FloatC>
__device__ static void Run(const bf6x32_t& reg_a,
                           const int32_t scale_a,
                           const bf6x32_t& reg_b,
                           const int32_t scale_b,
                           FloatC& reg_c)
{
#if defined(__gfx950__)
    int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
    int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
    using arg_type  = int32x8_t;
    reg_c.template AsType<float4_t>()(Number<0>{}) =
        __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
            arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
            arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
            reg_c.template AsType<float4_t>()[Number<0>{}],
            3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
            3, // blgp
            0, // OPSEL
            scale_a,
            0, // OPSEL
            scale_b);
#else
    ignore = reg_a;
    ignore = scale_a;
    ignore = reg_b;
    ignore = scale_b;
    ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const f4x32_t& reg_a,
const int32_t scale_a,
@@ -778,35 +1056,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
ignore = reg_b;
ignore = scale_b;
ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const bf8x32_t& reg_a,
const int32_t& scale_a,
const f8x32_t& reg_b,
const int32_t& scale_b,
FloatC& reg_c)
{
#if defined(__gfx950__)
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float4_t>()[Number<0>{}],
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
0, // blgp
0, // OPSEL
scale_a,
0, // OPSEL
scale_b);
#else
ignore = reg_a;
ignore = scale_a;
ignore = reg_b;
ignore = scale_b;
ignore = reg_c;
#endif
}
};
@@ -833,7 +1082,7 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
reg_a,
reg_b,
reg_c.template AsType<float4_t>()[Number<0>{}],
0, // cbsz
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
0, // blgp
0,
0,
@@ -846,6 +1095,29 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
#endif
}
// 16x16x128 MFMA for bf8 (FP8 E5M2) A/B fragments on gfx950, unscaled:
// the scaled builtin is issued with both scale operands fixed to 0.
// cbsz/blgp = 1 select the FP8 E5M2 encoding.
template <class FloatC>
__device__ static void Run(const bf8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
    // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
    reg_c.template AsType<float4_t>()(Number<0>{}) =
        __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
            reg_a,
            reg_b,
            reg_c.template AsType<float4_t>()[Number<0>{}],
            1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
            1, // blgp
            0, // OPSEL
            0, // scale_a = 0 -> unscaled
            0, // OPSEL
            0); // scale_b = 0 -> unscaled
#else
    ignore = reg_a;
    ignore = reg_b;
    ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
{
@@ -870,6 +1142,60 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
// 16x16x128 MFMA for f6 (FP6 E2M3) A/B fragments on gfx950, unscaled.
// Six data dwords are widened to the builtin's eight-dword operand with
// trailing zeros; cbsz/blgp = 2 select FP6 E2M3.
template <class FloatC>
__device__ static void Run(const f6x32_t& reg_a, const f6x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
    int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
    int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
    using arg_type  = int32x8_t;
    reg_c.template AsType<float4_t>()(Number<0>{}) =
        __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
            arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
            arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
            reg_c.template AsType<float4_t>()[Number<0>{}],
            2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
            2, // blgp
            0, // OPSEL
            0, // scale_a = 0 -> unscaled
            0, // OPSEL
            0); // scale_b = 0 -> unscaled
#else
    ignore = reg_a;
    ignore = reg_b;
    ignore = reg_c;
#endif
}
// 16x16x128 MFMA for bf6 (FP6 E3M2) A/B fragments on gfx950, unscaled.
// Same dword-widening as the f6 overload; cbsz/blgp = 3 select FP6 E3M2.
template <class FloatC>
__device__ static void Run(const bf6x32_t& reg_a, const bf6x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
    int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
    int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
    using arg_type  = int32x8_t;
    reg_c.template AsType<float4_t>()(Number<0>{}) =
        __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
            arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
            arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
            reg_c.template AsType<float4_t>()[Number<0>{}],
            3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
            3, // blgp
            0, // OPSEL
            0, // scale_a = 0 -> unscaled
            0, // OPSEL
            0); // scale_b = 0 -> unscaled
#else
    ignore = reg_a;
    ignore = reg_b;
    ignore = reg_c;
#endif
}
};

View File

@@ -32,8 +32,14 @@ using f4_t = unsigned _BitInt(4);
using f6_t = _BitInt(6); // e2m3 format
using bf6_t = unsigned _BitInt(6); // e3m2 format
// scalar_type
template <typename TV>
struct scalar_type;
struct f4x2_pk_t
{
static constexpr int packed_size = 2;
using type = uint8_t;
type data;
__host__ __device__ f4x2_pk_t() : data{type{}} {}
@@ -55,269 +61,82 @@ struct f4x2_pk_t
}
};
struct f6x16_pk_t
template <typename BitType, index_t pk_size>
struct f6_pk_t
{
// store 16 elements of f6_t in an array of 3 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 3>;
type data;
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
f6x16_pk_t() : data{type{}} {}
f6x16_pk_t(type init) : data{init} {}
using element_type = uint32_t; // element storage fundamental type
template <index_t I>
__host__ __device__ inline f6_t unpack(Number<I>)
static constexpr index_t packed_size = pk_size;
static constexpr index_t num_bits_elem = 6;
static constexpr index_t num_bits_vec_elem = sizeof(element_type) * CHAR_BIT;
static_assert((packed_size * num_bits_elem) % num_bits_vec_elem == 0,
"Packed elements must fit exactly into the element storage.");
static constexpr index_t vector_size = (packed_size * num_bits_elem) / num_bits_vec_elem;
using storage_type = StaticallyIndexedArray_v2<element_type, vector_size>;
storage_type data; // packed data
using type = f6_pk_t<BitType, packed_size>;
__host__ __device__ constexpr f6_pk_t() : data{} {}
__host__ __device__ constexpr f6_pk_t(storage_type init) : data{init} {}
template <typename T, typename = enable_if_t<scalar_type<T>::vector_size == packed_size>>
__host__ __device__ f6_pk_t(const T& v) : data{}
{
static_assert(I < 16, "Index out of range for 16 f6_t elements.");
static_for<0, packed_size, 1>{}(
[&](auto i) { pack(v[static_cast<index_t>(i)], static_cast<index_t>(i)); });
}
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 3;
constexpr int bit_pos = I * num_bits_elem;
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
template <typename T>
__host__ __device__ void pack(const T x, const index_t i)
{
static_assert(is_integral<T>::value || is_same_v<T, BitType>,
"T must be an integral type.");
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
uint32_t bits = static_cast<uint32_t>(x) & 0x3F;
const int bit_pos = i * num_bits_elem;
const int arr_index = bit_pos / num_bits_vec_elem;
const int bit_offset = bit_pos % num_bits_vec_elem;
const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = data.data_[arr_index];
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
data.data_[arr_index] = old_value;
// if it crosses into the next block, shift the remainder
if(overhang > 0 && (arr_index + 1) < vector_size)
{
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
uint32_t next_value = data.data_[arr_index + 1];
next_value |= (bits >> (num_bits_elem - overhang));
data.data_[arr_index + 1] = next_value;
}
}
__host__ __device__ static inline BitType unpack(const type& pk, const index_t i)
{
const int bit_pos = i * num_bits_elem;
const int arr_idx = bit_pos / num_bits_vec_elem;
const int bit_offset = bit_pos % num_bits_vec_elem;
const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t bits = pk.data.data_[arr_idx] >> bit_offset;
if(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (pk.data.data_[arr_idx + 1] & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
}
return static_cast<f6_t>(bits & 0x3F);
return static_cast<BitType>(bits & 0x3F);
}
__host__ __device__ inline type pack(const test_vec_t& x)
{
type packed{};
// for each of the 16 f6_t values, place its 6 bits in the correct position
ck::static_for<0, 16, 1>{}([&](auto i) {
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 3;
constexpr int bit_pos = i * num_bits_elem;
constexpr int arr_index = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = packed.At(Number<arr_index>{});
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
packed.At(Number<arr_index>{}) = old_value;
// if it crosses into the next block, shift the remainder
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
{
uint32_t next_value = packed.At(Number<arr_index + 1>{});
next_value |= (bits >> (num_bits_elem - overhang));
packed.At(Number<arr_index + 1>{}) = next_value;
}
});
return packed;
}
__host__ __device__ inline BitType unpack(const index_t i) const { return unpack(*this, i); }
};
struct f6x32_pk_t
{
// store 32 elements of f6_t in an array of 6 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 6>;
type data;
typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));
f6x32_pk_t() : data{type{}} {}
f6x32_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline f6_t unpack(Number<I>)
{
static_assert(I < 32, "Index out of range for 32 f6_t elements.");
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 6;
constexpr int bit_pos = I * num_bits_elem;
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
}
return static_cast<f6_t>(bits & 0x3F);
}
__host__ __device__ inline type pack(const test_vec_t& x)
{
type packed{};
// for each of the 32 f6_t values, place its 6 bits in the correct position
ck::static_for<0, 32, 1>{}([&](auto i) {
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 6;
constexpr int bit_pos = i * num_bits_elem;
constexpr int arr_index = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = packed.At(Number<arr_index>{});
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
packed.At(Number<arr_index>{}) = old_value;
// if it crosses into the next block, shift the remainder
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
{
uint32_t next_value = packed.At(Number<arr_index + 1>{});
next_value |= (bits >> (num_bits_elem - overhang));
packed.At(Number<arr_index + 1>{}) = next_value;
}
});
return packed;
}
};
struct bf6x16_pk_t
{
// store 16 elements of bf6_t in an array of 3 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 3>;
type data;
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
bf6x16_pk_t() : data{type{}} {}
bf6x16_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline bf6_t unpack(Number<I>)
{
static_assert(I < 16, "Index out of range for 16 f6_t elements.");
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 3;
constexpr int bit_pos = I * num_bits_elem;
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
}
return static_cast<bf6_t>(bits & 0x3F);
}
__host__ __device__ inline type pack(const test_vec_t& x)
{
type packed{};
// for each of the 16 bf6_t values, place its 6 bits in the correct position
ck::static_for<0, 16, 1>{}([&](auto i) {
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 3;
constexpr int bit_pos = i * num_bits_elem;
constexpr int arr_index = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = packed.At(Number<arr_index>{});
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
packed.At(Number<arr_index>{}) = old_value;
// if it crosses into the next block, shift the remainder
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
{
uint32_t next_value = packed.At(Number<arr_index + 1>{});
next_value |= (bits >> (num_bits_elem - overhang));
packed.At(Number<arr_index + 1>{}) = next_value;
}
});
return packed;
}
};
struct bf6x32_pk_t
{
// store 32 elements of bf6_t in an array of 6 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 6>;
type data;
typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));
bf6x32_pk_t() : data{type{}} {}
bf6x32_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline bf6_t unpack(Number<I>)
{
static_assert(I < 32, "Index out of range for 32 f6_t elements.");
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 6;
constexpr int bit_pos = I * num_bits_elem;
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
}
return static_cast<bf6_t>(bits & 0x3F);
}
__host__ __device__ inline type pack(const test_vec_t& x)
{
type packed{};
// for each of the 32 bf6_t values, place its 6 bits in the correct position
ck::static_for<0, 32, 1>{}([&](auto i) {
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 6;
constexpr int bit_pos = i * num_bits_elem;
constexpr int arr_index = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = packed.At(Number<arr_index>{});
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
packed.At(Number<arr_index>{}) = old_value;
// if it crosses into the next block, shift the remainder
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
{
uint32_t next_value = packed.At(Number<arr_index + 1>{});
next_value |= (bits >> (num_bits_elem - overhang));
packed.At(Number<arr_index + 1>{}) = next_value;
}
});
return packed;
}
};
using f6x16_pk_t = f6_pk_t<f6_t, 16>;
using f6x32_pk_t = f6_pk_t<f6_t, 32>;
using bf6x16_pk_t = f6_pk_t<bf6_t, 16>;
using bf6x32_pk_t = f6_pk_t<bf6_t, 32>;
// custom data type - pack int4 data
struct pk_i4_t
@@ -335,15 +154,14 @@ inline constexpr auto next_pow2(uint32_t x)
}
// Returns true for types that have a native (compiler-supported) vector
// representation: double, float, half_t, bhalf_t, int32_t, uint32_t, int8_t,
// uint8_t, f8_fnuz_t, bf8_fnuz_t, bool.
// (Merge residue removed: the pre-merge body — which also listed f4_t, f6_t
// and bf6_t — was left interleaved with the post-merge body, producing
// duplicate, conflicting return fragments. The post-merge version, which adds
// uint32_t and drops the fp4/fp6 scalars, is kept.)
template <typename T>
inline constexpr bool is_native_type()
{
    return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value ||
           is_same<T, bhalf_t>::value || is_same<T, int32_t>::value ||
           is_same<T, uint32_t>::value || is_same<T, int8_t>::value || is_same<T, uint8_t>::value ||
           is_same<T, f8_fnuz_t>::value || is_same<T, bf8_fnuz_t>::value || is_same<T, bool>::value;
}
// scalar_type
@@ -484,6 +302,106 @@ struct scalar_type<bool>
static constexpr index_t vector_size = 1;
};
// packed_type<T>: maps an element type to the type used to store it in packed
// form, together with how many elements one packed object holds.
// Default behavior for types that do not need special handling
template <typename T>
struct packed_type
{
using type = T; // not packed: stored as itself
static constexpr index_t packed_size = 1; // number of packed elements
};
// int4_t values are stored two-per-object in pk_i4_t.
template <>
struct packed_type<int4_t>
{
using type = pk_i4_t;
static constexpr index_t packed_size = 2; // number of packed elements
};
// f4_t values are stored two-per-object in f4x2_pk_t.
template <>
struct packed_type<f4_t>
{
using type = f4x2_pk_t;
static constexpr index_t packed_size = 2; // number of packed elements
};
// f6_t values use the 32-element packed container; the count is taken from
// the container itself so the two stay in sync.
template <>
struct packed_type<f6_t>
{
using type = f6x32_pk_t;
static constexpr index_t packed_size = f6x32_pk_t::packed_size; // number of packed elements
};
// bf6_t values use the 32-element packed container.
template <>
struct packed_type<bf6_t>
{
using type = bf6x32_pk_t;
static constexpr index_t packed_size = bf6x32_pk_t::packed_size; // number of packed elements
};
// Shorthand for the packed storage type of T (T itself when not packed).
template <typename T>
using packed_type_t = typename packed_type<T>::type;
// Check if the type has packed type specialization
// (i.e. packed_type<T>::type differs from T itself).
template <typename T>
inline constexpr bool has_packed_type_v = !is_same_v<packed_type_t<T>, T>;
template <typename T>
struct element_type
{
private:
static constexpr auto get_element_type()
{
using U = remove_cvref_t<T>;
if constexpr(is_same_v<U, pk_i4_t>)
return int4_t{};
else if constexpr(is_same_v<U, f4x2_pk_t>)
return f4_t{};
else if constexpr(is_same_v<U, f6x16_pk_t>)
return f6_t{};
else if constexpr(is_same_v<U, bf6x16_pk_t>)
return bf6_t{};
else if constexpr(is_same_v<U, f6x32_pk_t>)
return f6_t{};
else if constexpr(is_same_v<U, bf6x32_pk_t>)
return bf6_t{};
else
return T{};
}
public:
using type = decltype(get_element_type());
};
// Shorthand for the element type stored inside T.
template <typename T>
using element_type_t = typename element_type<T>::type;
// True iff T itself is a packed storage type: its element type has a
// packed_type specialization and T is exactly that packed representation.
template <typename T>
inline constexpr bool is_packed_type_v =
has_packed_type_v<element_type_t<T>>&& is_same_v<T, packed_type_t<element_type_t<T>>>;
template <typename T>
struct packed_size
{
private:
static constexpr auto get_packed_size()
{
using U = remove_cvref_t<T>;
if constexpr(is_packed_type_v<U>)
return Number<packed_type<element_type_t<U>>::packed_size>{};
else
return Number<packed_type<U>::packed_size>{};
}
public:
using type = decltype(get_packed_size());
static constexpr auto value = get_packed_size();
};
// Convenience aliases for packed_size<T>: the Number<> type and its value.
template <typename T>
using packed_size_t = typename packed_size<T>::type;
template <typename T>
inline constexpr index_t packed_size_v = packed_size<T>::value;
#if defined(_WIN32)
using int64_t = long long;
#else

View File

@@ -365,6 +365,88 @@ struct vector_type<T, 5, typename ck::enable_if_t<is_native_type<T>()>>
}
};
// 6-wide vector of a native type T. A union lets the same storage be viewed
// as arrays of 1-, 2-, 3- or 6-element sub-vectors; AsType<X>() selects the
// view matching sub-vector type X.
// NOTE(review): clang pads 3- and 6-wide ext vectors to the next power of
// two — confirm the union members alias as intended for this width.
template <typename T>
struct vector_type<T, 6, typename ck::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d3_t __attribute__((ext_vector_type(3)));
typedef T d6_t __attribute__((ext_vector_type(6)));
using type = d6_t;
// all members alias the same 6-element storage
union
{
d6_t d6_;
StaticallyIndexedArray<d1_t, 6> d1x6_;
StaticallyIndexedArray<d2_t, 3> d2x3_;
StaticallyIndexedArray<d3_t, 2> d3x2_;
StaticallyIndexedArray<d6_t, 1> d6x1_;
} data_;
// zero-initialize, or initialize from a 6-wide vector value
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
// Read-only view of the storage as an array of X-typed sub-vectors.
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d3_t>::value || is_same<X, d6_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x6_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x3_;
}
else if constexpr(is_same<X, d3_t>::value)
{
return data_.d3x2_;
}
else if constexpr(is_same<X, d6_t>::value)
{
return data_.d6x1_;
}
else
{
// unreachable: the static_assert above rejects any other X, so 'err'
// is never instantiated in a valid program
return err;
}
}
// Mutable view of the storage as an array of X-typed sub-vectors.
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d3_t>::value || is_same<X, d6_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x6_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x3_;
}
else if constexpr(is_same<X, d3_t>::value)
{
return data_.d3x2_;
}
else if constexpr(is_same<X, d6_t>::value)
{
return data_.d6x1_;
}
else
{
// unreachable: guarded by the static_assert above
return err;
}
}
};
template <typename T>
struct vector_type<T, 7, typename ck::enable_if_t<is_native_type<T>()>>
{
@@ -1221,25 +1303,25 @@ struct nnvb_data_t_selector<e8m0_bexp_t>
// nnvb_data_t_selector for the packed FP6/BF6 containers: the non-native
// vector machinery operates on the container's raw storage_type.
// (Merge residue removed: each specialization carried both the pre-merge
// 'using type = ...::type;' and the post-merge 'using type = ...::storage_type;'
// declaration — a duplicate member. The storage_type version is kept,
// matching the f6_pk_t definition above.)
template <>
struct nnvb_data_t_selector<f6x16_pk_t>
{
    using type = f6x16_pk_t::storage_type;
};
template <>
struct nnvb_data_t_selector<f6x32_pk_t>
{
    using type = f6x32_pk_t::storage_type;
};
template <>
struct nnvb_data_t_selector<bf6x16_pk_t>
{
    using type = bf6x16_pk_t::storage_type;
};
template <>
struct nnvb_data_t_selector<bf6x32_pk_t>
{
    using type = bf6x32_pk_t::storage_type;
};
template <>
@@ -1406,12 +1488,23 @@ struct non_native_vector_base<T, N, ck::enable_if_t<sizeof(T) == 12 || sizeof(T)
};
// scalar_type for non_native_vector_base, split by element size.
// (Merge residue removed: the stale pre-merge declaration header
// 'struct scalar_type<non_native_vector_base<T, N>>' was left back-to-back
// with the new partial-specialization header, leaving two struct headers in a
// row. The post-merge pair of specializations is kept.)
//
// Small elements (1/2/4/8 bytes): the vector's data_t is the scalar and the
// vector size is N.
template <typename T, index_t N>
struct scalar_type<non_native_vector_base<
    T,
    N,
    ck::enable_if_t<sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8>>>
{
    using type = typename non_native_vector_base<T, N>::data_t;
    static constexpr index_t vector_size = N;
};

// Large elements (12/24 bytes): each element is decomposed into size_factor
// sub-elements of element_t, so the effective vector size scales accordingly.
template <typename T, index_t N>
struct scalar_type<
    non_native_vector_base<T, N, ck::enable_if_t<sizeof(T) == 12 || sizeof(T) == 24>>>
{
    using type = typename non_native_vector_base<T, N>::element_t;
    static constexpr index_t vector_size = N * non_native_vector_base<T, N>::size_factor;
};
// non-native vector_type implementation
template <typename T>
struct vector_type<T, 1, typename ck::enable_if_t<!is_native_type<T>()>>
@@ -2025,6 +2118,7 @@ using bhalf32_t = typename vector_type<bhalf_t, 32>::type;
// i32: convenience aliases for int32_t vectors of various widths
// (the 6-wide alias pairs with the vector_type<T, 6> specialization).
using int32x2_t = typename vector_type<int32_t, 2>::type;
using int32x4_t = typename vector_type<int32_t, 4>::type;
using int32x6_t = typename vector_type<int32_t, 6>::type;
using int32x8_t = typename vector_type<int32_t, 8>::type;
using int32x16_t = typename vector_type<int32_t, 16>::type;
using int32x32_t = typename vector_type<int32_t, 32>::type;

View File

@@ -66,7 +66,7 @@ __host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value)
: NumericUtils<f4_t>::data_max_positive_normal_mask;
}
if(std::abs(value) > NumericLimits<f4_t>::Max()) // covers inf case as well
if(std::abs(value) > NumericLimits<f4_t>::DataMaxNorm()) // covers inf case as well
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
@@ -74,8 +74,8 @@ __host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value)
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<f4_t>::DataMinSubnorm())
return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
: NumericUtils<f4_t>::positive_zero_mask;
return sign ? NumericUtils<f4_t>::negative_zero_mask
: NumericUtils<f4_t>::positive_zero_mask;
return res;
}
@@ -91,7 +91,7 @@ __host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
if(std::abs(value) > NumericLimits<f4_t>::Max()) // covers inf case as well
if(std::abs(value) > NumericLimits<f4_t>::DataMaxNorm()) // covers inf case as well
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
@@ -99,8 +99,8 @@ __host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<f4_t>::DataMinSubnorm())
return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
: NumericUtils<f4_t>::positive_zero_mask;
return sign ? NumericUtils<f4_t>::negative_zero_mask
: NumericUtils<f4_t>::positive_zero_mask;
return res;
}

View File

@@ -201,7 +201,7 @@ __host__ __device__ inline f6_t sat_convert_to_type<f6_t>(float value)
: NumericUtils<f6_t>::data_max_positive_normal_mask;
}
if(std::abs(value) > NumericLimits<f6_t>::Max()) // covers inf case as well
if(std::abs(value) > NumericLimits<f6_t>::DataMaxNorm()) // covers inf case as well
return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
: NumericUtils<f6_t>::data_max_positive_normal_mask;
@@ -239,7 +239,7 @@ __host__ __device__ inline bf6_t sat_convert_to_type<bf6_t>(float value)
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
}
if(std::abs(value) > NumericLimits<bf6_t>::Max()) // covers inf case as well
if(std::abs(value) > NumericLimits<bf6_t>::DataMaxNorm()) // covers inf case as well
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
@@ -274,7 +274,7 @@ __host__ __device__ inline f6_t sat_convert_to_type_sr<f6_t>(float value, uint32
return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
: NumericUtils<f6_t>::data_max_positive_normal_mask;
if(std::abs(value) > NumericLimits<f6_t>::Max()) // covers inf case as well
if(std::abs(value) > NumericLimits<f6_t>::DataMaxNorm()) // covers inf case as well
return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
: NumericUtils<f6_t>::data_max_positive_normal_mask;
@@ -308,7 +308,7 @@ __host__ __device__ inline bf6_t sat_convert_to_type_sr<bf6_t>(float value, uint
if(std::isnan(value))
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
if(std::abs(value) > NumericLimits<bf6_t>::Max()) // covers inf case as well
if(std::abs(value) > NumericLimits<bf6_t>::DataMaxNorm()) // covers inf case as well
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
: NumericUtils<bf6_t>::data_max_positive_normal_mask;