mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Merge commit '57e0f5df29abefd919c334c994628a994ba2868c' into develop
This commit is contained in:
@@ -360,10 +360,9 @@ struct Tensor
|
||||
|
||||
std::size_t GetElementSpaceSize() const
|
||||
{
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
|
||||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
|
||||
if constexpr(ck::is_packed_type_v<ck::remove_cvref_t<T>>)
|
||||
{
|
||||
return (mDesc.GetElementSpaceSize() + 1) / 2;
|
||||
return (mDesc.GetElementSpaceSize() + 1) / ck::packed_size_v<ck::remove_cvref_t<T>>;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -516,69 +515,31 @@ struct Tensor
|
||||
template <typename... Is>
|
||||
std::size_t GetOffsetFromMultiIndex(Is... is) const
|
||||
{
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
|
||||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
|
||||
{
|
||||
return mDesc.GetOffsetFromMultiIndex(is...) / 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
return mDesc.GetOffsetFromMultiIndex(is...);
|
||||
}
|
||||
return mDesc.GetOffsetFromMultiIndex(is...) / ck::packed_size_v<ck::remove_cvref_t<T>>;
|
||||
}
|
||||
|
||||
template <typename... Is>
|
||||
T& operator()(Is... is)
|
||||
{
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
|
||||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
|
||||
}
|
||||
else
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
|
||||
}
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...) /
|
||||
ck::packed_size_v<ck::remove_cvref_t<T>>];
|
||||
}
|
||||
|
||||
template <typename... Is>
|
||||
const T& operator()(Is... is) const
|
||||
{
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
|
||||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
|
||||
}
|
||||
else
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
|
||||
}
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...) /
|
||||
ck::packed_size_v<ck::remove_cvref_t<T>>];
|
||||
}
|
||||
|
||||
/// Mutable access to the storage element addressed by a runtime index vector.
///
/// The descriptor offset is divided by the number of logical elements packed
/// into one T, so packed types of any pack size are addressed correctly.
/// NOTE(review): assumes ck::packed_size_v is 1 for non-packed T - confirm.
///
/// @param idx  one entry per tensor dimension
/// @return     reference into mData
T& operator()(std::vector<std::size_t> idx)
{
    return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v<ck::remove_cvref_t<T>>];
}
|
||||
|
||||
/// Read-only access to the storage element addressed by a runtime index vector.
///
/// The descriptor offset is divided by the number of logical elements packed
/// into one T, so packed types of any pack size are addressed correctly.
/// NOTE(review): assumes ck::packed_size_v is 1 for non-packed T - confirm.
///
/// @param idx  one entry per tensor dimension
/// @return     const reference into mData
const T& operator()(std::vector<std::size_t> idx) const
{
    return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v<ck::remove_cvref_t<T>>];
}
|
||||
|
||||
typename Data::iterator begin() { return mData.begin(); }
|
||||
|
||||
@@ -67,6 +67,18 @@ struct GeneratorTensor_1<ck::f8_t>
|
||||
return ck::type_convert<ck::f8_t>(value);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_1<ck::bf8_t>
|
||||
{
|
||||
float value = 1.0;
|
||||
|
||||
template <typename... Is>
|
||||
ck::bf8_t operator()(Is...)
|
||||
{
|
||||
return ck::type_convert<ck::bf8_t>(value);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
@@ -93,6 +105,38 @@ struct GeneratorTensor_1<ck::f4x2_pk_t>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_1<ck::f6x32_pk_t>
|
||||
{
|
||||
float value = 1.0;
|
||||
|
||||
template <typename... Is>
|
||||
ck::f6x32_pk_t operator()(Is...)
|
||||
{
|
||||
ck::f6x32_pk_t r;
|
||||
ck::static_for<0, 32, 1>{}([&](auto i) {
|
||||
r.pack(ck::type_convert<ck::f6_t>(value), static_cast<ck::index_t>(i));
|
||||
});
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_1<ck::bf6x32_pk_t>
|
||||
{
|
||||
float value = 1.0;
|
||||
|
||||
template <typename... Is>
|
||||
ck::bf6x32_pk_t operator()(Is...)
|
||||
{
|
||||
ck::bf6x32_pk_t r;
|
||||
ck::static_for<0, 32, 1>{}([&](auto i) {
|
||||
r.pack(ck::type_convert<ck::bf6_t>(value), static_cast<ck::index_t>(i));
|
||||
});
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_1<int8_t>
|
||||
{
|
||||
@@ -132,6 +176,44 @@ struct GeneratorTensor_2
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_2<ck::f6x32_pk_t>
|
||||
{
|
||||
int min_value = 0;
|
||||
int max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
ck::f6x32_pk_t operator()(Is...)
|
||||
{
|
||||
ck::f6x32_pk_t r;
|
||||
ck::static_for<0, 32, 1>{}([&](auto i) {
|
||||
float tmp = (std::rand() % (max_value - min_value)) + min_value;
|
||||
r.pack(ck::type_convert<ck::f6_t>(tmp), static_cast<ck::index_t>(i));
|
||||
});
|
||||
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_2<ck::bf6x32_pk_t>
|
||||
{
|
||||
int min_value = 0;
|
||||
int max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
ck::bf6x32_pk_t operator()(Is...)
|
||||
{
|
||||
ck::bf6x32_pk_t r;
|
||||
ck::static_for<0, 32, 1>{}([&](auto i) {
|
||||
float tmp = (std::rand() % (max_value - min_value)) + min_value;
|
||||
r.pack(ck::type_convert<ck::bf6_t>(tmp), static_cast<ck::index_t>(i));
|
||||
});
|
||||
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_2<ck::bhalf_t>
|
||||
{
|
||||
@@ -342,6 +424,46 @@ struct GeneratorTensor_3<ck::f4x2_pk_t>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_3<ck::f6x32_pk_t>
|
||||
{
|
||||
float min_value = 0;
|
||||
float max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
ck::f6x32_pk_t operator()(Is...)
|
||||
{
|
||||
ck::f6x32_pk_t r;
|
||||
ck::static_for<0, 32, 1>{}([&](auto i) {
|
||||
float rnd = float(std::rand()) / float(RAND_MAX);
|
||||
float fp32 = min_value + rnd * (max_value - min_value);
|
||||
r.pack(ck::type_convert<ck::f6_t>(fp32), static_cast<ck::index_t>(i));
|
||||
});
|
||||
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_3<ck::bf6x32_pk_t>
|
||||
{
|
||||
float min_value = 0;
|
||||
float max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
ck::bf6x32_pk_t operator()(Is...)
|
||||
{
|
||||
ck::bf6x32_pk_t r;
|
||||
ck::static_for<0, 32, 1>{}([&](auto i) {
|
||||
float rnd = float(std::rand()) / float(RAND_MAX);
|
||||
float fp32 = min_value + rnd * (max_value - min_value);
|
||||
r.pack(ck::type_convert<ck::bf6_t>(fp32), static_cast<ck::index_t>(i));
|
||||
});
|
||||
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct GeneratorTensor_4
|
||||
{
|
||||
@@ -360,6 +482,69 @@ struct GeneratorTensor_4
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_4<ck::f4x2_pk_t>
|
||||
{
|
||||
std::mt19937 generator;
|
||||
std::normal_distribution<float> distribution;
|
||||
|
||||
GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1)
|
||||
: generator(seed), distribution(mean, stddev){};
|
||||
|
||||
template <typename... Is>
|
||||
ck::f4x2_pk_t operator()(Is...)
|
||||
{
|
||||
float fp32_tmp0 = distribution(generator);
|
||||
float fp32_tmp1 = distribution(generator);
|
||||
|
||||
return ck::f4x2_pk_t{ck::type_convert<ck::f4x2_t>(ck::float2_t{fp32_tmp0, fp32_tmp1})};
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_4<ck::f6x32_pk_t>
|
||||
{
|
||||
std::mt19937 generator;
|
||||
std::normal_distribution<float> distribution;
|
||||
|
||||
GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1)
|
||||
: generator(seed), distribution(mean, stddev){};
|
||||
|
||||
template <typename... Is>
|
||||
ck::f6x32_pk_t operator()(Is...)
|
||||
{
|
||||
ck::f6x32_pk_t r;
|
||||
ck::static_for<0, 32, 1>{}([&](auto i) {
|
||||
r.pack(ck::type_convert<ck::f6_t>(distribution(generator)),
|
||||
static_cast<ck::index_t>(i));
|
||||
});
|
||||
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_4<ck::bf6x32_pk_t>
|
||||
{
|
||||
std::mt19937 generator;
|
||||
std::normal_distribution<float> distribution;
|
||||
|
||||
GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1)
|
||||
: generator(seed), distribution(mean, stddev){};
|
||||
|
||||
template <typename... Is>
|
||||
ck::bf6x32_pk_t operator()(Is...)
|
||||
{
|
||||
ck::bf6x32_pk_t r;
|
||||
ck::static_for<0, 32, 1>{}([&](auto i) {
|
||||
r.pack(ck::type_convert<ck::bf6_t>(distribution(generator)),
|
||||
static_cast<ck::index_t>(i));
|
||||
});
|
||||
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
||||
struct GeneratorTensor_Checkboard
|
||||
{
|
||||
template <typename... Ts>
|
||||
@@ -405,6 +590,53 @@ struct GeneratorTensor_Sequential
|
||||
}
|
||||
};
|
||||
|
||||
template <ck::index_t Dim>
|
||||
struct GeneratorTensor_Sequential<ck::f4x2_pk_t, Dim>
|
||||
{
|
||||
template <typename... Ts>
|
||||
ck::f4x2_pk_t operator()(Ts... Xs) const
|
||||
{
|
||||
std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
|
||||
|
||||
float tmp = dims[Dim];
|
||||
return ck::type_convert<ck::f4x2_t>(ck::float2_t(tmp));
|
||||
}
|
||||
};
|
||||
|
||||
template <ck::index_t Dim>
|
||||
struct GeneratorTensor_Sequential<ck::f6x32_pk_t, Dim>
|
||||
{
|
||||
template <typename... Ts>
|
||||
ck::f6x32_pk_t operator()(Ts... Xs) const
|
||||
{
|
||||
std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
|
||||
|
||||
float tmp = dims[Dim];
|
||||
|
||||
ck::f6x32_pk_t r;
|
||||
ck::static_for<0, 32, 1>{}(
|
||||
[&](auto i) { r.pack(ck::type_convert<ck::f6_t>(tmp), static_cast<ck::index_t>(i)); });
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
||||
template <ck::index_t Dim>
|
||||
struct GeneratorTensor_Sequential<ck::bf6x32_pk_t, Dim>
|
||||
{
|
||||
template <typename... Ts>
|
||||
ck::bf6x32_pk_t operator()(Ts... Xs) const
|
||||
{
|
||||
std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
|
||||
|
||||
float tmp = dims[Dim];
|
||||
|
||||
ck::bf6x32_pk_t r;
|
||||
ck::static_for<0, 32, 1>{}(
|
||||
[&](auto i) { r.pack(ck::type_convert<ck::bf6_t>(tmp), static_cast<ck::index_t>(i)); });
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, size_t NumEffectiveDim = 2>
|
||||
struct GeneratorTensor_Diagonal
|
||||
{
|
||||
|
||||
@@ -498,7 +498,7 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
0, // cbsz
|
||||
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
0, // blgp
|
||||
0,
|
||||
0,
|
||||
@@ -511,6 +511,28 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
1, // blgp
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
@@ -536,6 +558,62 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f6x32_t& reg_a, const f6x32_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
|
||||
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
|
||||
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
|
||||
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
2, // blgp
|
||||
0, // OPSEL
|
||||
0,
|
||||
0, // OPSEL
|
||||
0);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf6x32_t& reg_a, const bf6x32_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
|
||||
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
|
||||
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
|
||||
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
3, // blgp
|
||||
0, // OPSEL
|
||||
0,
|
||||
0, // OPSEL
|
||||
0);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
@@ -583,6 +661,43 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf8x32_t& reg_a,
|
||||
const int32_t& scale_a,
|
||||
const bf8x32_t& reg_b,
|
||||
const int32_t& scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
1, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
// XXX: Note on the scale_a and scale_b parameters:
|
||||
// If compiler detects that one or both scales are constant values, it will treat that
|
||||
// constant as F32 constant. I.e., if scale_a at some point was declared as
|
||||
// `e8m0_bexp_t a_scale{1.0f}`, the instruction would only work if scale_a parameter is
|
||||
// assigned value `bit_cast<int32_t>(static_cast<float>(a_scale))`.
|
||||
|
||||
// XXX: Note on the OPSEL parameters: Instruction always takes byte0 as a scale value even
|
||||
// when OPSEL is set otherwise.
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf8x32_t& reg_a,
|
||||
const int32_t& scale_a,
|
||||
@@ -620,6 +735,74 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f6x32_t& reg_a,
|
||||
const int32_t scale_a,
|
||||
const f6x32_t& reg_b,
|
||||
const int32_t scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
|
||||
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
|
||||
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
|
||||
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
2, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf6x32_t& reg_a,
|
||||
const int32_t scale_a,
|
||||
const bf6x32_t& reg_b,
|
||||
const int32_t scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
|
||||
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
|
||||
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
|
||||
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
3, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f4x32_t& reg_a,
|
||||
const int32_t scale_a,
|
||||
@@ -639,7 +822,7 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
|
||||
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
4, // cbsz
|
||||
4, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
4, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
@@ -748,6 +931,101 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf8x32_t& reg_a,
|
||||
const int32_t& scale_a,
|
||||
const f8x32_t& reg_b,
|
||||
const int32_t& scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
0, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f6x32_t& reg_a,
|
||||
const int32_t scale_a,
|
||||
const f6x32_t& reg_b,
|
||||
const int32_t scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
|
||||
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
|
||||
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
2, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf6x32_t& reg_a,
|
||||
const int32_t scale_a,
|
||||
const bf6x32_t& reg_b,
|
||||
const int32_t scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
|
||||
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
|
||||
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
3, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f4x32_t& reg_a,
|
||||
const int32_t scale_a,
|
||||
@@ -778,35 +1056,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf8x32_t& reg_a,
|
||||
const int32_t& scale_a,
|
||||
const f8x32_t& reg_b,
|
||||
const int32_t& scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
0, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
@@ -833,7 +1082,7 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
0, // cbsz
|
||||
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
0, // blgp
|
||||
0,
|
||||
0,
|
||||
@@ -846,6 +1095,29 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
1, // blgp
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
@@ -870,6 +1142,60 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f6x32_t& reg_a, const f6x32_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
|
||||
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
|
||||
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
2, // blgp
|
||||
0, // OPSEL
|
||||
0,
|
||||
0, // OPSEL
|
||||
0);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf6x32_t& reg_a, const bf6x32_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
|
||||
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
|
||||
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
3, // blgp
|
||||
0, // OPSEL
|
||||
0,
|
||||
0, // OPSEL
|
||||
0);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -32,8 +32,14 @@ using f4_t = unsigned _BitInt(4);
|
||||
using f6_t = _BitInt(6); // e2m3 format
|
||||
using bf6_t = unsigned _BitInt(6); // e3m2 format
|
||||
|
||||
// scalar_type
|
||||
template <typename TV>
|
||||
struct scalar_type;
|
||||
|
||||
struct f4x2_pk_t
|
||||
{
|
||||
static constexpr int packed_size = 2;
|
||||
|
||||
using type = uint8_t;
|
||||
type data;
|
||||
__host__ __device__ f4x2_pk_t() : data{type{}} {}
|
||||
@@ -55,269 +61,82 @@ struct f4x2_pk_t
|
||||
}
|
||||
};
|
||||
|
||||
struct f6x16_pk_t
|
||||
template <typename BitType, index_t pk_size>
|
||||
struct f6_pk_t
|
||||
{
|
||||
// store 16 elements of f6_t in an array of 3 uint32_t
|
||||
using element_type = uint32_t;
|
||||
using type = StaticallyIndexedArray_v2<element_type, 3>;
|
||||
type data;
|
||||
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
|
||||
f6x16_pk_t() : data{type{}} {}
|
||||
f6x16_pk_t(type init) : data{init} {}
|
||||
using element_type = uint32_t; // element storage fundamental type
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ inline f6_t unpack(Number<I>)
|
||||
static constexpr index_t packed_size = pk_size;
|
||||
static constexpr index_t num_bits_elem = 6;
|
||||
static constexpr index_t num_bits_vec_elem = sizeof(element_type) * CHAR_BIT;
|
||||
static_assert((packed_size * num_bits_elem) % num_bits_vec_elem == 0,
|
||||
"Packed elements must fit exactly into the element storage.");
|
||||
static constexpr index_t vector_size = (packed_size * num_bits_elem) / num_bits_vec_elem;
|
||||
|
||||
using storage_type = StaticallyIndexedArray_v2<element_type, vector_size>;
|
||||
storage_type data; // packed data
|
||||
|
||||
using type = f6_pk_t<BitType, packed_size>;
|
||||
|
||||
__host__ __device__ constexpr f6_pk_t() : data{} {}
|
||||
__host__ __device__ constexpr f6_pk_t(storage_type init) : data{init} {}
|
||||
template <typename T, typename = enable_if_t<scalar_type<T>::vector_size == packed_size>>
|
||||
__host__ __device__ f6_pk_t(const T& v) : data{}
|
||||
{
|
||||
static_assert(I < 16, "Index out of range for 16 f6_t elements.");
|
||||
static_for<0, packed_size, 1>{}(
|
||||
[&](auto i) { pack(v[static_cast<index_t>(i)], static_cast<index_t>(i)); });
|
||||
}
|
||||
|
||||
constexpr int num_bits_elem = 6;
|
||||
constexpr int num_bits_vec_elem = 32;
|
||||
constexpr int vector_size = 3;
|
||||
constexpr int bit_pos = I * num_bits_elem;
|
||||
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
|
||||
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
|
||||
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
template <typename T>
|
||||
__host__ __device__ void pack(const T x, const index_t i)
|
||||
{
|
||||
static_assert(is_integral<T>::value || is_same_v<T, BitType>,
|
||||
"T must be an integral type.");
|
||||
|
||||
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
|
||||
uint32_t bits = static_cast<uint32_t>(x) & 0x3F;
|
||||
const int bit_pos = i * num_bits_elem;
|
||||
const int arr_index = bit_pos / num_bits_vec_elem;
|
||||
const int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
uint32_t old_value = data.data_[arr_index];
|
||||
|
||||
// insert bits into the current 32-bit block
|
||||
old_value |= (bits << bit_offset);
|
||||
data.data_[arr_index] = old_value;
|
||||
|
||||
// if it crosses into the next block, shift the remainder
|
||||
if(overhang > 0 && (arr_index + 1) < vector_size)
|
||||
{
|
||||
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
|
||||
uint32_t next_value = data.data_[arr_index + 1];
|
||||
next_value |= (bits >> (num_bits_elem - overhang));
|
||||
data.data_[arr_index + 1] = next_value;
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ static inline BitType unpack(const type& pk, const index_t i)
|
||||
{
|
||||
const int bit_pos = i * num_bits_elem;
|
||||
const int arr_idx = bit_pos / num_bits_vec_elem;
|
||||
const int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
|
||||
uint32_t bits = pk.data.data_[arr_idx] >> bit_offset;
|
||||
if(overhang > 0 && (arr_idx + 1) < vector_size)
|
||||
{
|
||||
bits |= (pk.data.data_[arr_idx + 1] & ((1u << overhang) - 1))
|
||||
<< (num_bits_elem - overhang);
|
||||
}
|
||||
|
||||
return static_cast<f6_t>(bits & 0x3F);
|
||||
return static_cast<BitType>(bits & 0x3F);
|
||||
}
|
||||
|
||||
__host__ __device__ inline type pack(const test_vec_t& x)
|
||||
{
|
||||
type packed{};
|
||||
|
||||
// for each of the 16 f6_t values, place its 6 bits in the correct position
|
||||
ck::static_for<0, 16, 1>{}([&](auto i) {
|
||||
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
|
||||
constexpr int num_bits_elem = 6;
|
||||
constexpr int num_bits_vec_elem = 32;
|
||||
constexpr int vector_size = 3;
|
||||
constexpr int bit_pos = i * num_bits_elem;
|
||||
constexpr int arr_index = bit_pos / num_bits_vec_elem;
|
||||
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
uint32_t old_value = packed.At(Number<arr_index>{});
|
||||
|
||||
// insert bits into the current 32-bit block
|
||||
old_value |= (bits << bit_offset);
|
||||
packed.At(Number<arr_index>{}) = old_value;
|
||||
|
||||
// if it crosses into the next block, shift the remainder
|
||||
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
|
||||
{
|
||||
uint32_t next_value = packed.At(Number<arr_index + 1>{});
|
||||
next_value |= (bits >> (num_bits_elem - overhang));
|
||||
packed.At(Number<arr_index + 1>{}) = next_value;
|
||||
}
|
||||
});
|
||||
|
||||
return packed;
|
||||
}
|
||||
__host__ __device__ inline BitType unpack(const index_t i) const { return unpack(*this, i); }
|
||||
};
|
||||
|
||||
struct f6x32_pk_t
|
||||
{
|
||||
// store 32 elements of f6_t in an array of 6 uint32_t
|
||||
using element_type = uint32_t;
|
||||
using type = StaticallyIndexedArray_v2<element_type, 6>;
|
||||
type data;
|
||||
typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));
|
||||
f6x32_pk_t() : data{type{}} {}
|
||||
f6x32_pk_t(type init) : data{init} {}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ inline f6_t unpack(Number<I>)
|
||||
{
|
||||
static_assert(I < 32, "Index out of range for 32 f6_t elements.");
|
||||
|
||||
constexpr int num_bits_elem = 6;
|
||||
constexpr int num_bits_vec_elem = 32;
|
||||
constexpr int vector_size = 6;
|
||||
constexpr int bit_pos = I * num_bits_elem;
|
||||
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
|
||||
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
|
||||
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
|
||||
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
|
||||
{
|
||||
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
|
||||
<< (num_bits_elem - overhang);
|
||||
}
|
||||
|
||||
return static_cast<f6_t>(bits & 0x3F);
|
||||
}
|
||||
|
||||
__host__ __device__ inline type pack(const test_vec_t& x)
|
||||
{
|
||||
type packed{};
|
||||
|
||||
// for each of the 32 f6_t values, place its 6 bits in the correct position
|
||||
ck::static_for<0, 32, 1>{}([&](auto i) {
|
||||
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
|
||||
constexpr int num_bits_elem = 6;
|
||||
constexpr int num_bits_vec_elem = 32;
|
||||
constexpr int vector_size = 6;
|
||||
constexpr int bit_pos = i * num_bits_elem;
|
||||
constexpr int arr_index = bit_pos / num_bits_vec_elem;
|
||||
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
uint32_t old_value = packed.At(Number<arr_index>{});
|
||||
|
||||
// insert bits into the current 32-bit block
|
||||
old_value |= (bits << bit_offset);
|
||||
packed.At(Number<arr_index>{}) = old_value;
|
||||
|
||||
// if it crosses into the next block, shift the remainder
|
||||
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
|
||||
{
|
||||
uint32_t next_value = packed.At(Number<arr_index + 1>{});
|
||||
next_value |= (bits >> (num_bits_elem - overhang));
|
||||
packed.At(Number<arr_index + 1>{}) = next_value;
|
||||
}
|
||||
});
|
||||
|
||||
return packed;
|
||||
}
|
||||
};
|
||||
|
||||
struct bf6x16_pk_t
|
||||
{
|
||||
// store 16 elements of bf6_t in an array of 3 uint32_t
|
||||
using element_type = uint32_t;
|
||||
using type = StaticallyIndexedArray_v2<element_type, 3>;
|
||||
type data;
|
||||
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
|
||||
bf6x16_pk_t() : data{type{}} {}
|
||||
bf6x16_pk_t(type init) : data{init} {}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ inline bf6_t unpack(Number<I>)
|
||||
{
|
||||
static_assert(I < 16, "Index out of range for 16 f6_t elements.");
|
||||
|
||||
constexpr int num_bits_elem = 6;
|
||||
constexpr int num_bits_vec_elem = 32;
|
||||
constexpr int vector_size = 3;
|
||||
constexpr int bit_pos = I * num_bits_elem;
|
||||
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
|
||||
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
|
||||
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
|
||||
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
|
||||
{
|
||||
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
|
||||
<< (num_bits_elem - overhang);
|
||||
}
|
||||
|
||||
return static_cast<bf6_t>(bits & 0x3F);
|
||||
}
|
||||
|
||||
__host__ __device__ inline type pack(const test_vec_t& x)
|
||||
{
|
||||
type packed{};
|
||||
|
||||
// for each of the 16 bf6_t values, place its 6 bits in the correct position
|
||||
ck::static_for<0, 16, 1>{}([&](auto i) {
|
||||
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
|
||||
constexpr int num_bits_elem = 6;
|
||||
constexpr int num_bits_vec_elem = 32;
|
||||
constexpr int vector_size = 3;
|
||||
constexpr int bit_pos = i * num_bits_elem;
|
||||
constexpr int arr_index = bit_pos / num_bits_vec_elem;
|
||||
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
uint32_t old_value = packed.At(Number<arr_index>{});
|
||||
|
||||
// insert bits into the current 32-bit block
|
||||
old_value |= (bits << bit_offset);
|
||||
packed.At(Number<arr_index>{}) = old_value;
|
||||
|
||||
// if it crosses into the next block, shift the remainder
|
||||
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
|
||||
{
|
||||
uint32_t next_value = packed.At(Number<arr_index + 1>{});
|
||||
next_value |= (bits >> (num_bits_elem - overhang));
|
||||
packed.At(Number<arr_index + 1>{}) = next_value;
|
||||
}
|
||||
});
|
||||
|
||||
return packed;
|
||||
}
|
||||
};
|
||||
|
||||
struct bf6x32_pk_t
|
||||
{
|
||||
// store 32 elements of bf6_t in an array of 6 uint32_t
|
||||
using element_type = uint32_t;
|
||||
using type = StaticallyIndexedArray_v2<element_type, 6>;
|
||||
type data;
|
||||
typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));
|
||||
bf6x32_pk_t() : data{type{}} {}
|
||||
bf6x32_pk_t(type init) : data{init} {}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ inline bf6_t unpack(Number<I>)
|
||||
{
|
||||
static_assert(I < 32, "Index out of range for 32 f6_t elements.");
|
||||
|
||||
constexpr int num_bits_elem = 6;
|
||||
constexpr int num_bits_vec_elem = 32;
|
||||
constexpr int vector_size = 6;
|
||||
constexpr int bit_pos = I * num_bits_elem;
|
||||
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
|
||||
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
|
||||
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
|
||||
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
|
||||
{
|
||||
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
|
||||
<< (num_bits_elem - overhang);
|
||||
}
|
||||
|
||||
return static_cast<bf6_t>(bits & 0x3F);
|
||||
}
|
||||
|
||||
__host__ __device__ inline type pack(const test_vec_t& x)
|
||||
{
|
||||
type packed{};
|
||||
|
||||
// for each of the 32 bf6_t values, place its 6 bits in the correct position
|
||||
ck::static_for<0, 32, 1>{}([&](auto i) {
|
||||
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
|
||||
constexpr int num_bits_elem = 6;
|
||||
constexpr int num_bits_vec_elem = 32;
|
||||
constexpr int vector_size = 6;
|
||||
constexpr int bit_pos = i * num_bits_elem;
|
||||
constexpr int arr_index = bit_pos / num_bits_vec_elem;
|
||||
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
uint32_t old_value = packed.At(Number<arr_index>{});
|
||||
|
||||
// insert bits into the current 32-bit block
|
||||
old_value |= (bits << bit_offset);
|
||||
packed.At(Number<arr_index>{}) = old_value;
|
||||
|
||||
// if it crosses into the next block, shift the remainder
|
||||
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
|
||||
{
|
||||
uint32_t next_value = packed.At(Number<arr_index + 1>{});
|
||||
next_value |= (bits >> (num_bits_elem - overhang));
|
||||
packed.At(Number<arr_index + 1>{}) = next_value;
|
||||
}
|
||||
});
|
||||
|
||||
return packed;
|
||||
}
|
||||
};
|
||||
using f6x16_pk_t = f6_pk_t<f6_t, 16>;
|
||||
using f6x32_pk_t = f6_pk_t<f6_t, 32>;
|
||||
using bf6x16_pk_t = f6_pk_t<bf6_t, 16>;
|
||||
using bf6x32_pk_t = f6_pk_t<bf6_t, 32>;
|
||||
|
||||
// custom data type - pack int4 data
|
||||
struct pk_i4_t
|
||||
@@ -335,15 +154,14 @@ inline constexpr auto next_pow2(uint32_t x)
|
||||
}
|
||||
|
||||
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
|
||||
// native types: bool, f4_t, f6_t, bf6_t
|
||||
// native types: bool
|
||||
template <typename T>
|
||||
inline constexpr bool is_native_type()
|
||||
{
|
||||
return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value ||
|
||||
is_same<T, bhalf_t>::value || is_same<T, int32_t>::value || is_same<T, int8_t>::value ||
|
||||
is_same<T, uint8_t>::value || is_same<T, f8_fnuz_t>::value ||
|
||||
is_same<T, bf8_fnuz_t>::value || is_same<T, bool>::value || is_same<T, f4_t>::value ||
|
||||
is_same<T, f6_t>::value || is_same<T, bf6_t>::value;
|
||||
is_same<T, bhalf_t>::value || is_same<T, int32_t>::value ||
|
||||
is_same<T, uint32_t>::value || is_same<T, int8_t>::value || is_same<T, uint8_t>::value ||
|
||||
is_same<T, f8_fnuz_t>::value || is_same<T, bf8_fnuz_t>::value || is_same<T, bool>::value;
|
||||
}
|
||||
|
||||
// scalar_type
|
||||
@@ -484,6 +302,106 @@ struct scalar_type<bool>
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
// Default behavior for types that do not need special handling
|
||||
template <typename T>
|
||||
struct packed_type
|
||||
{
|
||||
using type = T;
|
||||
static constexpr index_t packed_size = 1; // number of packed elements
|
||||
};
|
||||
|
||||
template <>
|
||||
struct packed_type<int4_t>
|
||||
{
|
||||
using type = pk_i4_t;
|
||||
static constexpr index_t packed_size = 2; // number of packed elements
|
||||
};
|
||||
|
||||
template <>
|
||||
struct packed_type<f4_t>
|
||||
{
|
||||
using type = f4x2_pk_t;
|
||||
static constexpr index_t packed_size = 2; // number of packed elements
|
||||
};
|
||||
|
||||
template <>
|
||||
struct packed_type<f6_t>
|
||||
{
|
||||
using type = f6x32_pk_t;
|
||||
static constexpr index_t packed_size = f6x32_pk_t::packed_size; // number of packed elements
|
||||
};
|
||||
|
||||
template <>
|
||||
struct packed_type<bf6_t>
|
||||
{
|
||||
using type = bf6x32_pk_t;
|
||||
static constexpr index_t packed_size = bf6x32_pk_t::packed_size; // number of packed elements
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using packed_type_t = typename packed_type<T>::type;
|
||||
|
||||
// Check if the type has packed type specialization
|
||||
template <typename T>
|
||||
inline constexpr bool has_packed_type_v = !is_same_v<packed_type_t<T>, T>;
|
||||
|
||||
template <typename T>
|
||||
struct element_type
|
||||
{
|
||||
private:
|
||||
static constexpr auto get_element_type()
|
||||
{
|
||||
using U = remove_cvref_t<T>;
|
||||
if constexpr(is_same_v<U, pk_i4_t>)
|
||||
return int4_t{};
|
||||
else if constexpr(is_same_v<U, f4x2_pk_t>)
|
||||
return f4_t{};
|
||||
else if constexpr(is_same_v<U, f6x16_pk_t>)
|
||||
return f6_t{};
|
||||
else if constexpr(is_same_v<U, bf6x16_pk_t>)
|
||||
return bf6_t{};
|
||||
else if constexpr(is_same_v<U, f6x32_pk_t>)
|
||||
return f6_t{};
|
||||
else if constexpr(is_same_v<U, bf6x32_pk_t>)
|
||||
return bf6_t{};
|
||||
else
|
||||
return T{};
|
||||
}
|
||||
|
||||
public:
|
||||
using type = decltype(get_element_type());
|
||||
};
|
||||
template <typename T>
|
||||
using element_type_t = typename element_type<T>::type;
|
||||
|
||||
template <typename T>
|
||||
inline constexpr bool is_packed_type_v =
|
||||
has_packed_type_v<element_type_t<T>>&& is_same_v<T, packed_type_t<element_type_t<T>>>;
|
||||
|
||||
template <typename T>
|
||||
struct packed_size
|
||||
{
|
||||
private:
|
||||
static constexpr auto get_packed_size()
|
||||
{
|
||||
using U = remove_cvref_t<T>;
|
||||
if constexpr(is_packed_type_v<U>)
|
||||
return Number<packed_type<element_type_t<U>>::packed_size>{};
|
||||
else
|
||||
return Number<packed_type<U>::packed_size>{};
|
||||
}
|
||||
|
||||
public:
|
||||
using type = decltype(get_packed_size());
|
||||
static constexpr auto value = get_packed_size();
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using packed_size_t = typename packed_size<T>::type;
|
||||
|
||||
template <typename T>
|
||||
inline constexpr index_t packed_size_v = packed_size<T>::value;
|
||||
|
||||
#if defined(_WIN32)
|
||||
using int64_t = long long;
|
||||
#else
|
||||
|
||||
@@ -365,6 +365,88 @@ struct vector_type<T, 5, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 6, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
typedef T d3_t __attribute__((ext_vector_type(3)));
|
||||
typedef T d6_t __attribute__((ext_vector_type(6)));
|
||||
|
||||
using type = d6_t;
|
||||
|
||||
union
|
||||
{
|
||||
d6_t d6_;
|
||||
StaticallyIndexedArray<d1_t, 6> d1x6_;
|
||||
StaticallyIndexedArray<d2_t, 3> d2x3_;
|
||||
StaticallyIndexedArray<d3_t, 2> d3x2_;
|
||||
StaticallyIndexedArray<d6_t, 1> d6x1_;
|
||||
} data_;
|
||||
|
||||
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d3_t>::value || is_same<X, d6_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x6_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x3_;
|
||||
}
|
||||
else if constexpr(is_same<X, d3_t>::value)
|
||||
{
|
||||
return data_.d3x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d6_t>::value)
|
||||
{
|
||||
return data_.d6x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d3_t>::value || is_same<X, d6_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x6_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x3_;
|
||||
}
|
||||
else if constexpr(is_same<X, d3_t>::value)
|
||||
{
|
||||
return data_.d3x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d6_t>::value)
|
||||
{
|
||||
return data_.d6x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 7, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
@@ -1221,25 +1303,25 @@ struct nnvb_data_t_selector<e8m0_bexp_t>
|
||||
template <>
|
||||
struct nnvb_data_t_selector<f6x16_pk_t>
|
||||
{
|
||||
using type = f6x16_pk_t::type;
|
||||
using type = f6x16_pk_t::storage_type;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct nnvb_data_t_selector<f6x32_pk_t>
|
||||
{
|
||||
using type = f6x32_pk_t::type;
|
||||
using type = f6x32_pk_t::storage_type;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct nnvb_data_t_selector<bf6x16_pk_t>
|
||||
{
|
||||
using type = bf6x16_pk_t::type;
|
||||
using type = bf6x16_pk_t::storage_type;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct nnvb_data_t_selector<bf6x32_pk_t>
|
||||
{
|
||||
using type = bf6x32_pk_t::type;
|
||||
using type = bf6x32_pk_t::storage_type;
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -1406,12 +1488,23 @@ struct non_native_vector_base<T, N, ck::enable_if_t<sizeof(T) == 12 || sizeof(T)
|
||||
};
|
||||
|
||||
template <typename T, index_t N>
|
||||
struct scalar_type<non_native_vector_base<T, N>>
|
||||
struct scalar_type<non_native_vector_base<
|
||||
T,
|
||||
N,
|
||||
ck::enable_if_t<sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8>>>
|
||||
{
|
||||
using type = typename non_native_vector_base<T, N>::data_t;
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
template <typename T, index_t N>
|
||||
struct scalar_type<
|
||||
non_native_vector_base<T, N, ck::enable_if_t<sizeof(T) == 12 || sizeof(T) == 24>>>
|
||||
{
|
||||
using type = typename non_native_vector_base<T, N>::element_t;
|
||||
static constexpr index_t vector_size = N * non_native_vector_base<T, N>::size_factor;
|
||||
};
|
||||
|
||||
// non-native vector_type implementation
|
||||
template <typename T>
|
||||
struct vector_type<T, 1, typename ck::enable_if_t<!is_native_type<T>()>>
|
||||
@@ -2025,6 +2118,7 @@ using bhalf32_t = typename vector_type<bhalf_t, 32>::type;
|
||||
// i32
|
||||
using int32x2_t = typename vector_type<int32_t, 2>::type;
|
||||
using int32x4_t = typename vector_type<int32_t, 4>::type;
|
||||
using int32x6_t = typename vector_type<int32_t, 6>::type;
|
||||
using int32x8_t = typename vector_type<int32_t, 8>::type;
|
||||
using int32x16_t = typename vector_type<int32_t, 16>::type;
|
||||
using int32x32_t = typename vector_type<int32_t, 32>::type;
|
||||
|
||||
@@ -66,7 +66,7 @@ __host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value)
|
||||
: NumericUtils<f4_t>::data_max_positive_normal_mask;
|
||||
}
|
||||
|
||||
if(std::abs(value) > NumericLimits<f4_t>::Max()) // covers inf case as well
|
||||
if(std::abs(value) > NumericLimits<f4_t>::DataMaxNorm()) // covers inf case as well
|
||||
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
|
||||
: NumericUtils<f4_t>::data_max_positive_normal_mask;
|
||||
|
||||
@@ -74,8 +74,8 @@ __host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value)
|
||||
|
||||
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
|
||||
NumericLimits<f4_t>::DataMinSubnorm())
|
||||
return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
|
||||
: NumericUtils<f4_t>::positive_zero_mask;
|
||||
return sign ? NumericUtils<f4_t>::negative_zero_mask
|
||||
: NumericUtils<f4_t>::positive_zero_mask;
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -91,7 +91,7 @@ __host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32
|
||||
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
|
||||
: NumericUtils<f4_t>::data_max_positive_normal_mask;
|
||||
|
||||
if(std::abs(value) > NumericLimits<f4_t>::Max()) // covers inf case as well
|
||||
if(std::abs(value) > NumericLimits<f4_t>::DataMaxNorm()) // covers inf case as well
|
||||
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
|
||||
: NumericUtils<f4_t>::data_max_positive_normal_mask;
|
||||
|
||||
@@ -99,8 +99,8 @@ __host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32
|
||||
|
||||
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
|
||||
NumericLimits<f4_t>::DataMinSubnorm())
|
||||
return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
|
||||
: NumericUtils<f4_t>::positive_zero_mask;
|
||||
return sign ? NumericUtils<f4_t>::negative_zero_mask
|
||||
: NumericUtils<f4_t>::positive_zero_mask;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -201,7 +201,7 @@ __host__ __device__ inline f6_t sat_convert_to_type<f6_t>(float value)
|
||||
: NumericUtils<f6_t>::data_max_positive_normal_mask;
|
||||
}
|
||||
|
||||
if(std::abs(value) > NumericLimits<f6_t>::Max()) // covers inf case as well
|
||||
if(std::abs(value) > NumericLimits<f6_t>::DataMaxNorm()) // covers inf case as well
|
||||
return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
|
||||
: NumericUtils<f6_t>::data_max_positive_normal_mask;
|
||||
|
||||
@@ -239,7 +239,7 @@ __host__ __device__ inline bf6_t sat_convert_to_type<bf6_t>(float value)
|
||||
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
|
||||
}
|
||||
|
||||
if(std::abs(value) > NumericLimits<bf6_t>::Max()) // covers inf case as well
|
||||
if(std::abs(value) > NumericLimits<bf6_t>::DataMaxNorm()) // covers inf case as well
|
||||
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
|
||||
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
|
||||
|
||||
@@ -274,7 +274,7 @@ __host__ __device__ inline f6_t sat_convert_to_type_sr<f6_t>(float value, uint32
|
||||
return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
|
||||
: NumericUtils<f6_t>::data_max_positive_normal_mask;
|
||||
|
||||
if(std::abs(value) > NumericLimits<f6_t>::Max()) // covers inf case as well
|
||||
if(std::abs(value) > NumericLimits<f6_t>::DataMaxNorm()) // covers inf case as well
|
||||
return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
|
||||
: NumericUtils<f6_t>::data_max_positive_normal_mask;
|
||||
|
||||
@@ -308,7 +308,7 @@ __host__ __device__ inline bf6_t sat_convert_to_type_sr<bf6_t>(float value, uint
|
||||
if(std::isnan(value))
|
||||
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
|
||||
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
|
||||
if(std::abs(value) > NumericLimits<bf6_t>::Max()) // covers inf case as well
|
||||
if(std::abs(value) > NumericLimits<bf6_t>::DataMaxNorm()) // covers inf case as well
|
||||
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
|
||||
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user