mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Merge commit '57e0f5df29abefd919c334c994628a994ba2868c' into develop
This commit is contained in:
@@ -141,8 +141,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
|
||||
a_tensors_device.reserve(group_count);
|
||||
b_tensors_device.reserve(group_count);
|
||||
d_tensors_device.reserve(group_count);
|
||||
c_tensors_device.reserve(group_count);
|
||||
d_tensors_device.resize(group_count); // reserve and update vector size
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
|
||||
@@ -360,10 +360,9 @@ struct Tensor
|
||||
|
||||
std::size_t GetElementSpaceSize() const
|
||||
{
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
|
||||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
|
||||
if constexpr(ck::is_packed_type_v<ck::remove_cvref_t<T>>)
|
||||
{
|
||||
return (mDesc.GetElementSpaceSize() + 1) / 2;
|
||||
return (mDesc.GetElementSpaceSize() + 1) / ck::packed_size_v<ck::remove_cvref_t<T>>;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -516,69 +515,31 @@ struct Tensor
|
||||
template <typename... Is>
|
||||
std::size_t GetOffsetFromMultiIndex(Is... is) const
|
||||
{
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
|
||||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
|
||||
{
|
||||
return mDesc.GetOffsetFromMultiIndex(is...) / 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
return mDesc.GetOffsetFromMultiIndex(is...);
|
||||
}
|
||||
return mDesc.GetOffsetFromMultiIndex(is...) / ck::packed_size_v<ck::remove_cvref_t<T>>;
|
||||
}
|
||||
|
||||
template <typename... Is>
|
||||
T& operator()(Is... is)
|
||||
{
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
|
||||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
|
||||
}
|
||||
else
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
|
||||
}
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...) /
|
||||
ck::packed_size_v<ck::remove_cvref_t<T>>];
|
||||
}
|
||||
|
||||
template <typename... Is>
|
||||
const T& operator()(Is... is) const
|
||||
{
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
|
||||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
|
||||
}
|
||||
else
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
|
||||
}
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...) /
|
||||
ck::packed_size_v<ck::remove_cvref_t<T>>];
|
||||
}
|
||||
|
||||
T& operator()(std::vector<std::size_t> idx)
|
||||
{
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
|
||||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
|
||||
}
|
||||
else
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
|
||||
}
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v<ck::remove_cvref_t<T>>];
|
||||
}
|
||||
|
||||
const T& operator()(std::vector<std::size_t> idx) const
|
||||
{
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
|
||||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
|
||||
}
|
||||
else
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
|
||||
}
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v<ck::remove_cvref_t<T>>];
|
||||
}
|
||||
|
||||
typename Data::iterator begin() { return mData.begin(); }
|
||||
|
||||
@@ -67,6 +67,18 @@ struct GeneratorTensor_1<ck::f8_t>
|
||||
return ck::type_convert<ck::f8_t>(value);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_1<ck::bf8_t>
|
||||
{
|
||||
float value = 1.0;
|
||||
|
||||
template <typename... Is>
|
||||
ck::bf8_t operator()(Is...)
|
||||
{
|
||||
return ck::type_convert<ck::bf8_t>(value);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
@@ -93,6 +105,38 @@ struct GeneratorTensor_1<ck::f4x2_pk_t>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_1<ck::f6x32_pk_t>
|
||||
{
|
||||
float value = 1.0;
|
||||
|
||||
template <typename... Is>
|
||||
ck::f6x32_pk_t operator()(Is...)
|
||||
{
|
||||
ck::f6x32_pk_t r;
|
||||
ck::static_for<0, 32, 1>{}([&](auto i) {
|
||||
r.pack(ck::type_convert<ck::f6_t>(value), static_cast<ck::index_t>(i));
|
||||
});
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_1<ck::bf6x32_pk_t>
|
||||
{
|
||||
float value = 1.0;
|
||||
|
||||
template <typename... Is>
|
||||
ck::bf6x32_pk_t operator()(Is...)
|
||||
{
|
||||
ck::bf6x32_pk_t r;
|
||||
ck::static_for<0, 32, 1>{}([&](auto i) {
|
||||
r.pack(ck::type_convert<ck::bf6_t>(value), static_cast<ck::index_t>(i));
|
||||
});
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_1<int8_t>
|
||||
{
|
||||
@@ -132,6 +176,44 @@ struct GeneratorTensor_2
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_2<ck::f6x32_pk_t>
|
||||
{
|
||||
int min_value = 0;
|
||||
int max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
ck::f6x32_pk_t operator()(Is...)
|
||||
{
|
||||
ck::f6x32_pk_t r;
|
||||
ck::static_for<0, 32, 1>{}([&](auto i) {
|
||||
float tmp = (std::rand() % (max_value - min_value)) + min_value;
|
||||
r.pack(ck::type_convert<ck::f6_t>(tmp), static_cast<ck::index_t>(i));
|
||||
});
|
||||
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_2<ck::bf6x32_pk_t>
|
||||
{
|
||||
int min_value = 0;
|
||||
int max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
ck::bf6x32_pk_t operator()(Is...)
|
||||
{
|
||||
ck::bf6x32_pk_t r;
|
||||
ck::static_for<0, 32, 1>{}([&](auto i) {
|
||||
float tmp = (std::rand() % (max_value - min_value)) + min_value;
|
||||
r.pack(ck::type_convert<ck::bf6_t>(tmp), static_cast<ck::index_t>(i));
|
||||
});
|
||||
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_2<ck::bhalf_t>
|
||||
{
|
||||
@@ -342,6 +424,46 @@ struct GeneratorTensor_3<ck::f4x2_pk_t>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_3<ck::f6x32_pk_t>
|
||||
{
|
||||
float min_value = 0;
|
||||
float max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
ck::f6x32_pk_t operator()(Is...)
|
||||
{
|
||||
ck::f6x32_pk_t r;
|
||||
ck::static_for<0, 32, 1>{}([&](auto i) {
|
||||
float rnd = float(std::rand()) / float(RAND_MAX);
|
||||
float fp32 = min_value + rnd * (max_value - min_value);
|
||||
r.pack(ck::type_convert<ck::f6_t>(fp32), static_cast<ck::index_t>(i));
|
||||
});
|
||||
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_3<ck::bf6x32_pk_t>
|
||||
{
|
||||
float min_value = 0;
|
||||
float max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
ck::bf6x32_pk_t operator()(Is...)
|
||||
{
|
||||
ck::bf6x32_pk_t r;
|
||||
ck::static_for<0, 32, 1>{}([&](auto i) {
|
||||
float rnd = float(std::rand()) / float(RAND_MAX);
|
||||
float fp32 = min_value + rnd * (max_value - min_value);
|
||||
r.pack(ck::type_convert<ck::bf6_t>(fp32), static_cast<ck::index_t>(i));
|
||||
});
|
||||
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct GeneratorTensor_4
|
||||
{
|
||||
@@ -360,6 +482,69 @@ struct GeneratorTensor_4
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_4<ck::f4x2_pk_t>
|
||||
{
|
||||
std::mt19937 generator;
|
||||
std::normal_distribution<float> distribution;
|
||||
|
||||
GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1)
|
||||
: generator(seed), distribution(mean, stddev){};
|
||||
|
||||
template <typename... Is>
|
||||
ck::f4x2_pk_t operator()(Is...)
|
||||
{
|
||||
float fp32_tmp0 = distribution(generator);
|
||||
float fp32_tmp1 = distribution(generator);
|
||||
|
||||
return ck::f4x2_pk_t{ck::type_convert<ck::f4x2_t>(ck::float2_t{fp32_tmp0, fp32_tmp1})};
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_4<ck::f6x32_pk_t>
|
||||
{
|
||||
std::mt19937 generator;
|
||||
std::normal_distribution<float> distribution;
|
||||
|
||||
GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1)
|
||||
: generator(seed), distribution(mean, stddev){};
|
||||
|
||||
template <typename... Is>
|
||||
ck::f6x32_pk_t operator()(Is...)
|
||||
{
|
||||
ck::f6x32_pk_t r;
|
||||
ck::static_for<0, 32, 1>{}([&](auto i) {
|
||||
r.pack(ck::type_convert<ck::f6_t>(distribution(generator)),
|
||||
static_cast<ck::index_t>(i));
|
||||
});
|
||||
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_4<ck::bf6x32_pk_t>
|
||||
{
|
||||
std::mt19937 generator;
|
||||
std::normal_distribution<float> distribution;
|
||||
|
||||
GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1)
|
||||
: generator(seed), distribution(mean, stddev){};
|
||||
|
||||
template <typename... Is>
|
||||
ck::bf6x32_pk_t operator()(Is...)
|
||||
{
|
||||
ck::bf6x32_pk_t r;
|
||||
ck::static_for<0, 32, 1>{}([&](auto i) {
|
||||
r.pack(ck::type_convert<ck::bf6_t>(distribution(generator)),
|
||||
static_cast<ck::index_t>(i));
|
||||
});
|
||||
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
||||
struct GeneratorTensor_Checkboard
|
||||
{
|
||||
template <typename... Ts>
|
||||
@@ -405,6 +590,53 @@ struct GeneratorTensor_Sequential
|
||||
}
|
||||
};
|
||||
|
||||
template <ck::index_t Dim>
|
||||
struct GeneratorTensor_Sequential<ck::f4x2_pk_t, Dim>
|
||||
{
|
||||
template <typename... Ts>
|
||||
ck::f4x2_pk_t operator()(Ts... Xs) const
|
||||
{
|
||||
std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
|
||||
|
||||
float tmp = dims[Dim];
|
||||
return ck::type_convert<ck::f4x2_t>(ck::float2_t(tmp));
|
||||
}
|
||||
};
|
||||
|
||||
template <ck::index_t Dim>
|
||||
struct GeneratorTensor_Sequential<ck::f6x32_pk_t, Dim>
|
||||
{
|
||||
template <typename... Ts>
|
||||
ck::f6x32_pk_t operator()(Ts... Xs) const
|
||||
{
|
||||
std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
|
||||
|
||||
float tmp = dims[Dim];
|
||||
|
||||
ck::f6x32_pk_t r;
|
||||
ck::static_for<0, 32, 1>{}(
|
||||
[&](auto i) { r.pack(ck::type_convert<ck::f6_t>(tmp), static_cast<ck::index_t>(i)); });
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
||||
template <ck::index_t Dim>
|
||||
struct GeneratorTensor_Sequential<ck::bf6x32_pk_t, Dim>
|
||||
{
|
||||
template <typename... Ts>
|
||||
ck::bf6x32_pk_t operator()(Ts... Xs) const
|
||||
{
|
||||
std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
|
||||
|
||||
float tmp = dims[Dim];
|
||||
|
||||
ck::bf6x32_pk_t r;
|
||||
ck::static_for<0, 32, 1>{}(
|
||||
[&](auto i) { r.pack(ck::type_convert<ck::bf6_t>(tmp), static_cast<ck::index_t>(i)); });
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, size_t NumEffectiveDim = 2>
|
||||
struct GeneratorTensor_Diagonal
|
||||
{
|
||||
|
||||
@@ -498,7 +498,7 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
0, // cbsz
|
||||
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
0, // blgp
|
||||
0,
|
||||
0,
|
||||
@@ -511,6 +511,28 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
1, // blgp
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
@@ -536,6 +558,62 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f6x32_t& reg_a, const f6x32_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
|
||||
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
|
||||
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
|
||||
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
2, // blgp
|
||||
0, // OPSEL
|
||||
0,
|
||||
0, // OPSEL
|
||||
0);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf6x32_t& reg_a, const bf6x32_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
|
||||
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
|
||||
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
|
||||
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
3, // blgp
|
||||
0, // OPSEL
|
||||
0,
|
||||
0, // OPSEL
|
||||
0);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
@@ -583,6 +661,43 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf8x32_t& reg_a,
|
||||
const int32_t& scale_a,
|
||||
const bf8x32_t& reg_b,
|
||||
const int32_t& scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
1, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
// XXX: Note on the scale_a and scale_b parameters:
|
||||
// If compiler detects that one or both scales are constant values, it will treat that
|
||||
// constant as F32 constant. I.e., if scale_a at some point was declared as
|
||||
// `e8m0_bexp_t a_scale{1.0f}`, the instruction would only work if scale_a parameter is
|
||||
// assigned value `bit_cast<int32_t>(static_cast<float>(a_scale))`.
|
||||
|
||||
// XXX: Note on the OPSEL parameters: Instruction always takes byte0 as a scale value even
|
||||
// when OPSEL is set otherwise.
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf8x32_t& reg_a,
|
||||
const int32_t& scale_a,
|
||||
@@ -620,6 +735,74 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f6x32_t& reg_a,
|
||||
const int32_t scale_a,
|
||||
const f6x32_t& reg_b,
|
||||
const int32_t scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
|
||||
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
|
||||
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
|
||||
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
2, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf6x32_t& reg_a,
|
||||
const int32_t scale_a,
|
||||
const bf6x32_t& reg_b,
|
||||
const int32_t scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
|
||||
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
|
||||
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
|
||||
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
3, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f4x32_t& reg_a,
|
||||
const int32_t scale_a,
|
||||
@@ -639,7 +822,7 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
|
||||
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
4, // cbsz
|
||||
4, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
4, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
@@ -748,6 +931,101 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf8x32_t& reg_a,
|
||||
const int32_t& scale_a,
|
||||
const f8x32_t& reg_b,
|
||||
const int32_t& scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
0, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f6x32_t& reg_a,
|
||||
const int32_t scale_a,
|
||||
const f6x32_t& reg_b,
|
||||
const int32_t scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
|
||||
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
|
||||
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
2, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf6x32_t& reg_a,
|
||||
const int32_t scale_a,
|
||||
const bf6x32_t& reg_b,
|
||||
const int32_t scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
|
||||
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
|
||||
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
3, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f4x32_t& reg_a,
|
||||
const int32_t scale_a,
|
||||
@@ -778,35 +1056,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf8x32_t& reg_a,
|
||||
const int32_t& scale_a,
|
||||
const f8x32_t& reg_b,
|
||||
const int32_t& scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
0, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
@@ -833,7 +1082,7 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
0, // cbsz
|
||||
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
0, // blgp
|
||||
0,
|
||||
0,
|
||||
@@ -846,6 +1095,29 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
1, // blgp
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
@@ -870,6 +1142,60 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f6x32_t& reg_a, const f6x32_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
|
||||
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
|
||||
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
2, // blgp
|
||||
0, // OPSEL
|
||||
0,
|
||||
0, // OPSEL
|
||||
0);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf6x32_t& reg_a, const bf6x32_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
|
||||
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
|
||||
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
3, // blgp
|
||||
0, // OPSEL
|
||||
0,
|
||||
0, // OPSEL
|
||||
0);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -32,8 +32,14 @@ using f4_t = unsigned _BitInt(4);
|
||||
using f6_t = _BitInt(6); // e2m3 format
|
||||
using bf6_t = unsigned _BitInt(6); // e3m2 format
|
||||
|
||||
// scalar_type
|
||||
template <typename TV>
|
||||
struct scalar_type;
|
||||
|
||||
struct f4x2_pk_t
|
||||
{
|
||||
static constexpr int packed_size = 2;
|
||||
|
||||
using type = uint8_t;
|
||||
type data;
|
||||
__host__ __device__ f4x2_pk_t() : data{type{}} {}
|
||||
@@ -55,269 +61,82 @@ struct f4x2_pk_t
|
||||
}
|
||||
};
|
||||
|
||||
struct f6x16_pk_t
|
||||
template <typename BitType, index_t pk_size>
|
||||
struct f6_pk_t
|
||||
{
|
||||
// store 16 elements of f6_t in an array of 3 uint32_t
|
||||
using element_type = uint32_t;
|
||||
using type = StaticallyIndexedArray_v2<element_type, 3>;
|
||||
type data;
|
||||
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
|
||||
f6x16_pk_t() : data{type{}} {}
|
||||
f6x16_pk_t(type init) : data{init} {}
|
||||
using element_type = uint32_t; // element storage fundamental type
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ inline f6_t unpack(Number<I>)
|
||||
static constexpr index_t packed_size = pk_size;
|
||||
static constexpr index_t num_bits_elem = 6;
|
||||
static constexpr index_t num_bits_vec_elem = sizeof(element_type) * CHAR_BIT;
|
||||
static_assert((packed_size * num_bits_elem) % num_bits_vec_elem == 0,
|
||||
"Packed elements must fit exactly into the element storage.");
|
||||
static constexpr index_t vector_size = (packed_size * num_bits_elem) / num_bits_vec_elem;
|
||||
|
||||
using storage_type = StaticallyIndexedArray_v2<element_type, vector_size>;
|
||||
storage_type data; // packed data
|
||||
|
||||
using type = f6_pk_t<BitType, packed_size>;
|
||||
|
||||
__host__ __device__ constexpr f6_pk_t() : data{} {}
|
||||
__host__ __device__ constexpr f6_pk_t(storage_type init) : data{init} {}
|
||||
template <typename T, typename = enable_if_t<scalar_type<T>::vector_size == packed_size>>
|
||||
__host__ __device__ f6_pk_t(const T& v) : data{}
|
||||
{
|
||||
static_assert(I < 16, "Index out of range for 16 f6_t elements.");
|
||||
static_for<0, packed_size, 1>{}(
|
||||
[&](auto i) { pack(v[static_cast<index_t>(i)], static_cast<index_t>(i)); });
|
||||
}
|
||||
|
||||
constexpr int num_bits_elem = 6;
|
||||
constexpr int num_bits_vec_elem = 32;
|
||||
constexpr int vector_size = 3;
|
||||
constexpr int bit_pos = I * num_bits_elem;
|
||||
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
|
||||
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
|
||||
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
template <typename T>
|
||||
__host__ __device__ void pack(const T x, const index_t i)
|
||||
{
|
||||
static_assert(is_integral<T>::value || is_same_v<T, BitType>,
|
||||
"T must be an integral type.");
|
||||
|
||||
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
|
||||
uint32_t bits = static_cast<uint32_t>(x) & 0x3F;
|
||||
const int bit_pos = i * num_bits_elem;
|
||||
const int arr_index = bit_pos / num_bits_vec_elem;
|
||||
const int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
uint32_t old_value = data.data_[arr_index];
|
||||
|
||||
// insert bits into the current 32-bit block
|
||||
old_value |= (bits << bit_offset);
|
||||
data.data_[arr_index] = old_value;
|
||||
|
||||
// if it crosses into the next block, shift the remainder
|
||||
if(overhang > 0 && (arr_index + 1) < vector_size)
|
||||
{
|
||||
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
|
||||
uint32_t next_value = data.data_[arr_index + 1];
|
||||
next_value |= (bits >> (num_bits_elem - overhang));
|
||||
data.data_[arr_index + 1] = next_value;
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ static inline BitType unpack(const type& pk, const index_t i)
|
||||
{
|
||||
const int bit_pos = i * num_bits_elem;
|
||||
const int arr_idx = bit_pos / num_bits_vec_elem;
|
||||
const int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
|
||||
uint32_t bits = pk.data.data_[arr_idx] >> bit_offset;
|
||||
if(overhang > 0 && (arr_idx + 1) < vector_size)
|
||||
{
|
||||
bits |= (pk.data.data_[arr_idx + 1] & ((1u << overhang) - 1))
|
||||
<< (num_bits_elem - overhang);
|
||||
}
|
||||
|
||||
return static_cast<f6_t>(bits & 0x3F);
|
||||
return static_cast<BitType>(bits & 0x3F);
|
||||
}
|
||||
|
||||
/// Packs the low 6 bits of each of the 16 entries of `x` into a `type` object.
///
/// Entry i is written starting at bit `i * 6`; an entry that crosses a 32-bit
/// word boundary has its upper bits written to the start of the next word.
///
/// @param x vector holding the 16 values to pack
/// @return  the packed words
__host__ __device__ inline type pack(const test_vec_t& x)
{
    type result{};

    // every index is a compile-time constant, so all positions are computed statically
    ck::static_for<0, 16, 1>{}([&](auto i) {
        constexpr int num_bits_elem     = 6;
        constexpr int num_bits_vec_elem = 32;
        constexpr int vector_size       = 3;

        constexpr int first_bit = i * num_bits_elem;
        constexpr int word      = first_bit / num_bits_vec_elem;
        constexpr int shift     = first_bit % num_bits_vec_elem;
        constexpr int spill     = shift + num_bits_elem - num_bits_vec_elem;

        const uint32_t value = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;

        // part of the value that lands in the word where the element starts
        result.At(Number<word>{}) = result.At(Number<word>{}) | (value << shift);

        // part of the value that continues in the following word, if any
        if constexpr(spill > 0 && (word + 1) < vector_size)
        {
            result.At(Number<word + 1>{}) =
                result.At(Number<word + 1>{}) | (value >> (num_bits_elem - spill));
        }
    });

    return result;
}
|
||||
// Member overload: forwards *this and the element index i to the two-argument unpack overload.
__host__ __device__ inline BitType unpack(const index_t i) const { return unpack(*this, i); }
|
||||
};
|
||||
|
||||
struct f6x32_pk_t
|
||||
{
|
||||
// store 32 elements of f6_t in an array of 6 uint32_t
|
||||
using element_type = uint32_t;
|
||||
using type = StaticallyIndexedArray_v2<element_type, 6>;
|
||||
type data;
|
||||
typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));
|
||||
f6x32_pk_t() : data{type{}} {}
|
||||
f6x32_pk_t(type init) : data{init} {}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ inline f6_t unpack(Number<I>)
|
||||
{
|
||||
static_assert(I < 32, "Index out of range for 32 f6_t elements.");
|
||||
|
||||
constexpr int num_bits_elem = 6;
|
||||
constexpr int num_bits_vec_elem = 32;
|
||||
constexpr int vector_size = 6;
|
||||
constexpr int bit_pos = I * num_bits_elem;
|
||||
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
|
||||
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
|
||||
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
|
||||
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
|
||||
{
|
||||
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
|
||||
<< (num_bits_elem - overhang);
|
||||
}
|
||||
|
||||
return static_cast<f6_t>(bits & 0x3F);
|
||||
}
|
||||
|
||||
__host__ __device__ inline type pack(const test_vec_t& x)
|
||||
{
|
||||
type packed{};
|
||||
|
||||
// for each of the 32 f6_t values, place its 6 bits in the correct position
|
||||
ck::static_for<0, 32, 1>{}([&](auto i) {
|
||||
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
|
||||
constexpr int num_bits_elem = 6;
|
||||
constexpr int num_bits_vec_elem = 32;
|
||||
constexpr int vector_size = 6;
|
||||
constexpr int bit_pos = i * num_bits_elem;
|
||||
constexpr int arr_index = bit_pos / num_bits_vec_elem;
|
||||
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
uint32_t old_value = packed.At(Number<arr_index>{});
|
||||
|
||||
// insert bits into the current 32-bit block
|
||||
old_value |= (bits << bit_offset);
|
||||
packed.At(Number<arr_index>{}) = old_value;
|
||||
|
||||
// if it crosses into the next block, shift the remainder
|
||||
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
|
||||
{
|
||||
uint32_t next_value = packed.At(Number<arr_index + 1>{});
|
||||
next_value |= (bits >> (num_bits_elem - overhang));
|
||||
packed.At(Number<arr_index + 1>{}) = next_value;
|
||||
}
|
||||
});
|
||||
|
||||
return packed;
|
||||
}
|
||||
};
|
||||
|
||||
struct bf6x16_pk_t
|
||||
{
|
||||
// store 16 elements of bf6_t in an array of 3 uint32_t
|
||||
using element_type = uint32_t;
|
||||
using type = StaticallyIndexedArray_v2<element_type, 3>;
|
||||
type data;
|
||||
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
|
||||
bf6x16_pk_t() : data{type{}} {}
|
||||
bf6x16_pk_t(type init) : data{init} {}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ inline bf6_t unpack(Number<I>)
|
||||
{
|
||||
static_assert(I < 16, "Index out of range for 16 f6_t elements.");
|
||||
|
||||
constexpr int num_bits_elem = 6;
|
||||
constexpr int num_bits_vec_elem = 32;
|
||||
constexpr int vector_size = 3;
|
||||
constexpr int bit_pos = I * num_bits_elem;
|
||||
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
|
||||
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
|
||||
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
|
||||
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
|
||||
{
|
||||
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
|
||||
<< (num_bits_elem - overhang);
|
||||
}
|
||||
|
||||
return static_cast<bf6_t>(bits & 0x3F);
|
||||
}
|
||||
|
||||
__host__ __device__ inline type pack(const test_vec_t& x)
|
||||
{
|
||||
type packed{};
|
||||
|
||||
// for each of the 16 bf6_t values, place its 6 bits in the correct position
|
||||
ck::static_for<0, 16, 1>{}([&](auto i) {
|
||||
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
|
||||
constexpr int num_bits_elem = 6;
|
||||
constexpr int num_bits_vec_elem = 32;
|
||||
constexpr int vector_size = 3;
|
||||
constexpr int bit_pos = i * num_bits_elem;
|
||||
constexpr int arr_index = bit_pos / num_bits_vec_elem;
|
||||
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
uint32_t old_value = packed.At(Number<arr_index>{});
|
||||
|
||||
// insert bits into the current 32-bit block
|
||||
old_value |= (bits << bit_offset);
|
||||
packed.At(Number<arr_index>{}) = old_value;
|
||||
|
||||
// if it crosses into the next block, shift the remainder
|
||||
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
|
||||
{
|
||||
uint32_t next_value = packed.At(Number<arr_index + 1>{});
|
||||
next_value |= (bits >> (num_bits_elem - overhang));
|
||||
packed.At(Number<arr_index + 1>{}) = next_value;
|
||||
}
|
||||
});
|
||||
|
||||
return packed;
|
||||
}
|
||||
};
|
||||
|
||||
struct bf6x32_pk_t
|
||||
{
|
||||
// store 32 elements of bf6_t in an array of 6 uint32_t
|
||||
using element_type = uint32_t;
|
||||
using type = StaticallyIndexedArray_v2<element_type, 6>;
|
||||
type data;
|
||||
typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));
|
||||
bf6x32_pk_t() : data{type{}} {}
|
||||
bf6x32_pk_t(type init) : data{init} {}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ inline bf6_t unpack(Number<I>)
|
||||
{
|
||||
static_assert(I < 32, "Index out of range for 32 f6_t elements.");
|
||||
|
||||
constexpr int num_bits_elem = 6;
|
||||
constexpr int num_bits_vec_elem = 32;
|
||||
constexpr int vector_size = 6;
|
||||
constexpr int bit_pos = I * num_bits_elem;
|
||||
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
|
||||
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
|
||||
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
|
||||
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
|
||||
{
|
||||
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
|
||||
<< (num_bits_elem - overhang);
|
||||
}
|
||||
|
||||
return static_cast<bf6_t>(bits & 0x3F);
|
||||
}
|
||||
|
||||
__host__ __device__ inline type pack(const test_vec_t& x)
|
||||
{
|
||||
type packed{};
|
||||
|
||||
// for each of the 32 bf6_t values, place its 6 bits in the correct position
|
||||
ck::static_for<0, 32, 1>{}([&](auto i) {
|
||||
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
|
||||
constexpr int num_bits_elem = 6;
|
||||
constexpr int num_bits_vec_elem = 32;
|
||||
constexpr int vector_size = 6;
|
||||
constexpr int bit_pos = i * num_bits_elem;
|
||||
constexpr int arr_index = bit_pos / num_bits_vec_elem;
|
||||
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
uint32_t old_value = packed.At(Number<arr_index>{});
|
||||
|
||||
// insert bits into the current 32-bit block
|
||||
old_value |= (bits << bit_offset);
|
||||
packed.At(Number<arr_index>{}) = old_value;
|
||||
|
||||
// if it crosses into the next block, shift the remainder
|
||||
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
|
||||
{
|
||||
uint32_t next_value = packed.At(Number<arr_index + 1>{});
|
||||
next_value |= (bits >> (num_bits_elem - overhang));
|
||||
packed.At(Number<arr_index + 1>{}) = next_value;
|
||||
}
|
||||
});
|
||||
|
||||
return packed;
|
||||
}
|
||||
};
|
||||
using f6x16_pk_t = f6_pk_t<f6_t, 16>;
|
||||
using f6x32_pk_t = f6_pk_t<f6_t, 32>;
|
||||
using bf6x16_pk_t = f6_pk_t<bf6_t, 16>;
|
||||
using bf6x32_pk_t = f6_pk_t<bf6_t, 32>;
|
||||
|
||||
// custom data type - pack int4 data
|
||||
struct pk_i4_t
|
||||
@@ -335,15 +154,14 @@ inline constexpr auto next_pow2(uint32_t x)
|
||||
}
|
||||
|
||||
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
|
||||
// native types: bool, f4_t, f6_t, bf6_t
|
||||
// native types: bool
|
||||
template <typename T>
|
||||
inline constexpr bool is_native_type()
|
||||
{
|
||||
return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value ||
|
||||
is_same<T, bhalf_t>::value || is_same<T, int32_t>::value || is_same<T, int8_t>::value ||
|
||||
is_same<T, uint8_t>::value || is_same<T, f8_fnuz_t>::value ||
|
||||
is_same<T, bf8_fnuz_t>::value || is_same<T, bool>::value || is_same<T, f4_t>::value ||
|
||||
is_same<T, f6_t>::value || is_same<T, bf6_t>::value;
|
||||
is_same<T, bhalf_t>::value || is_same<T, int32_t>::value ||
|
||||
is_same<T, uint32_t>::value || is_same<T, int8_t>::value || is_same<T, uint8_t>::value ||
|
||||
is_same<T, f8_fnuz_t>::value || is_same<T, bf8_fnuz_t>::value || is_same<T, bool>::value;
|
||||
}
|
||||
|
||||
// scalar_type
|
||||
@@ -484,6 +302,106 @@ struct scalar_type<bool>
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
// Default behavior for types that do not need special handling
|
||||
template <typename T>
|
||||
struct packed_type
|
||||
{
|
||||
using type = T;
|
||||
static constexpr index_t packed_size = 1; // number of packed elements
|
||||
};
|
||||
|
||||
template <>
|
||||
struct packed_type<int4_t>
|
||||
{
|
||||
using type = pk_i4_t;
|
||||
static constexpr index_t packed_size = 2; // number of packed elements
|
||||
};
|
||||
|
||||
template <>
|
||||
struct packed_type<f4_t>
|
||||
{
|
||||
using type = f4x2_pk_t;
|
||||
static constexpr index_t packed_size = 2; // number of packed elements
|
||||
};
|
||||
|
||||
template <>
|
||||
struct packed_type<f6_t>
|
||||
{
|
||||
using type = f6x32_pk_t;
|
||||
static constexpr index_t packed_size = f6x32_pk_t::packed_size; // number of packed elements
|
||||
};
|
||||
|
||||
template <>
|
||||
struct packed_type<bf6_t>
|
||||
{
|
||||
using type = bf6x32_pk_t;
|
||||
static constexpr index_t packed_size = bf6x32_pk_t::packed_size; // number of packed elements
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using packed_type_t = typename packed_type<T>::type;
|
||||
|
||||
// Check if the type has packed type specialization
|
||||
template <typename T>
|
||||
inline constexpr bool has_packed_type_v = !is_same_v<packed_type_t<T>, T>;
|
||||
|
||||
template <typename T>
|
||||
struct element_type
|
||||
{
|
||||
private:
|
||||
static constexpr auto get_element_type()
|
||||
{
|
||||
using U = remove_cvref_t<T>;
|
||||
if constexpr(is_same_v<U, pk_i4_t>)
|
||||
return int4_t{};
|
||||
else if constexpr(is_same_v<U, f4x2_pk_t>)
|
||||
return f4_t{};
|
||||
else if constexpr(is_same_v<U, f6x16_pk_t>)
|
||||
return f6_t{};
|
||||
else if constexpr(is_same_v<U, bf6x16_pk_t>)
|
||||
return bf6_t{};
|
||||
else if constexpr(is_same_v<U, f6x32_pk_t>)
|
||||
return f6_t{};
|
||||
else if constexpr(is_same_v<U, bf6x32_pk_t>)
|
||||
return bf6_t{};
|
||||
else
|
||||
return T{};
|
||||
}
|
||||
|
||||
public:
|
||||
using type = decltype(get_element_type());
|
||||
};
|
||||
template <typename T>
|
||||
using element_type_t = typename element_type<T>::type;
|
||||
|
||||
template <typename T>
|
||||
inline constexpr bool is_packed_type_v =
|
||||
has_packed_type_v<element_type_t<T>>&& is_same_v<T, packed_type_t<element_type_t<T>>>;
|
||||
|
||||
template <typename T>
|
||||
struct packed_size
|
||||
{
|
||||
private:
|
||||
static constexpr auto get_packed_size()
|
||||
{
|
||||
using U = remove_cvref_t<T>;
|
||||
if constexpr(is_packed_type_v<U>)
|
||||
return Number<packed_type<element_type_t<U>>::packed_size>{};
|
||||
else
|
||||
return Number<packed_type<U>::packed_size>{};
|
||||
}
|
||||
|
||||
public:
|
||||
using type = decltype(get_packed_size());
|
||||
static constexpr auto value = get_packed_size();
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using packed_size_t = typename packed_size<T>::type;
|
||||
|
||||
template <typename T>
|
||||
inline constexpr index_t packed_size_v = packed_size<T>::value;
|
||||
|
||||
#if defined(_WIN32)
|
||||
using int64_t = long long;
|
||||
#else
|
||||
|
||||
@@ -365,6 +365,88 @@ struct vector_type<T, 5, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 6, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
typedef T d3_t __attribute__((ext_vector_type(3)));
|
||||
typedef T d6_t __attribute__((ext_vector_type(6)));
|
||||
|
||||
using type = d6_t;
|
||||
|
||||
union
|
||||
{
|
||||
d6_t d6_;
|
||||
StaticallyIndexedArray<d1_t, 6> d1x6_;
|
||||
StaticallyIndexedArray<d2_t, 3> d2x3_;
|
||||
StaticallyIndexedArray<d3_t, 2> d3x2_;
|
||||
StaticallyIndexedArray<d6_t, 1> d6x1_;
|
||||
} data_;
|
||||
|
||||
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d3_t>::value || is_same<X, d6_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x6_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x3_;
|
||||
}
|
||||
else if constexpr(is_same<X, d3_t>::value)
|
||||
{
|
||||
return data_.d3x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d6_t>::value)
|
||||
{
|
||||
return data_.d6x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d3_t>::value || is_same<X, d6_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x6_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x3_;
|
||||
}
|
||||
else if constexpr(is_same<X, d3_t>::value)
|
||||
{
|
||||
return data_.d3x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d6_t>::value)
|
||||
{
|
||||
return data_.d6x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 7, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
@@ -1221,25 +1303,25 @@ struct nnvb_data_t_selector<e8m0_bexp_t>
|
||||
template <>
|
||||
struct nnvb_data_t_selector<f6x16_pk_t>
|
||||
{
|
||||
using type = f6x16_pk_t::type;
|
||||
using type = f6x16_pk_t::storage_type;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct nnvb_data_t_selector<f6x32_pk_t>
|
||||
{
|
||||
using type = f6x32_pk_t::type;
|
||||
using type = f6x32_pk_t::storage_type;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct nnvb_data_t_selector<bf6x16_pk_t>
|
||||
{
|
||||
using type = bf6x16_pk_t::type;
|
||||
using type = bf6x16_pk_t::storage_type;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct nnvb_data_t_selector<bf6x32_pk_t>
|
||||
{
|
||||
using type = bf6x32_pk_t::type;
|
||||
using type = bf6x32_pk_t::storage_type;
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -1406,12 +1488,23 @@ struct non_native_vector_base<T, N, ck::enable_if_t<sizeof(T) == 12 || sizeof(T)
|
||||
};
|
||||
|
||||
template <typename T, index_t N>
|
||||
struct scalar_type<non_native_vector_base<T, N>>
|
||||
struct scalar_type<non_native_vector_base<
|
||||
T,
|
||||
N,
|
||||
ck::enable_if_t<sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8>>>
|
||||
{
|
||||
using type = typename non_native_vector_base<T, N>::data_t;
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
template <typename T, index_t N>
|
||||
struct scalar_type<
|
||||
non_native_vector_base<T, N, ck::enable_if_t<sizeof(T) == 12 || sizeof(T) == 24>>>
|
||||
{
|
||||
using type = typename non_native_vector_base<T, N>::element_t;
|
||||
static constexpr index_t vector_size = N * non_native_vector_base<T, N>::size_factor;
|
||||
};
|
||||
|
||||
// non-native vector_type implementation
|
||||
template <typename T>
|
||||
struct vector_type<T, 1, typename ck::enable_if_t<!is_native_type<T>()>>
|
||||
@@ -2025,6 +2118,7 @@ using bhalf32_t = typename vector_type<bhalf_t, 32>::type;
|
||||
// i32
|
||||
using int32x2_t = typename vector_type<int32_t, 2>::type;
|
||||
using int32x4_t = typename vector_type<int32_t, 4>::type;
|
||||
using int32x6_t = typename vector_type<int32_t, 6>::type;
|
||||
using int32x8_t = typename vector_type<int32_t, 8>::type;
|
||||
using int32x16_t = typename vector_type<int32_t, 16>::type;
|
||||
using int32x32_t = typename vector_type<int32_t, 32>::type;
|
||||
|
||||
@@ -66,7 +66,7 @@ __host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value)
|
||||
: NumericUtils<f4_t>::data_max_positive_normal_mask;
|
||||
}
|
||||
|
||||
if(std::abs(value) > NumericLimits<f4_t>::Max()) // covers inf case as well
|
||||
if(std::abs(value) > NumericLimits<f4_t>::DataMaxNorm()) // covers inf case as well
|
||||
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
|
||||
: NumericUtils<f4_t>::data_max_positive_normal_mask;
|
||||
|
||||
@@ -74,8 +74,8 @@ __host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value)
|
||||
|
||||
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
|
||||
NumericLimits<f4_t>::DataMinSubnorm())
|
||||
return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
|
||||
: NumericUtils<f4_t>::positive_zero_mask;
|
||||
return sign ? NumericUtils<f4_t>::negative_zero_mask
|
||||
: NumericUtils<f4_t>::positive_zero_mask;
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -91,7 +91,7 @@ __host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32
|
||||
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
|
||||
: NumericUtils<f4_t>::data_max_positive_normal_mask;
|
||||
|
||||
if(std::abs(value) > NumericLimits<f4_t>::Max()) // covers inf case as well
|
||||
if(std::abs(value) > NumericLimits<f4_t>::DataMaxNorm()) // covers inf case as well
|
||||
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
|
||||
: NumericUtils<f4_t>::data_max_positive_normal_mask;
|
||||
|
||||
@@ -99,8 +99,8 @@ __host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32
|
||||
|
||||
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
|
||||
NumericLimits<f4_t>::DataMinSubnorm())
|
||||
return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
|
||||
: NumericUtils<f4_t>::positive_zero_mask;
|
||||
return sign ? NumericUtils<f4_t>::negative_zero_mask
|
||||
: NumericUtils<f4_t>::positive_zero_mask;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -201,7 +201,7 @@ __host__ __device__ inline f6_t sat_convert_to_type<f6_t>(float value)
|
||||
: NumericUtils<f6_t>::data_max_positive_normal_mask;
|
||||
}
|
||||
|
||||
if(std::abs(value) > NumericLimits<f6_t>::Max()) // covers inf case as well
|
||||
if(std::abs(value) > NumericLimits<f6_t>::DataMaxNorm()) // covers inf case as well
|
||||
return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
|
||||
: NumericUtils<f6_t>::data_max_positive_normal_mask;
|
||||
|
||||
@@ -239,7 +239,7 @@ __host__ __device__ inline bf6_t sat_convert_to_type<bf6_t>(float value)
|
||||
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
|
||||
}
|
||||
|
||||
if(std::abs(value) > NumericLimits<bf6_t>::Max()) // covers inf case as well
|
||||
if(std::abs(value) > NumericLimits<bf6_t>::DataMaxNorm()) // covers inf case as well
|
||||
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
|
||||
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
|
||||
|
||||
@@ -274,7 +274,7 @@ __host__ __device__ inline f6_t sat_convert_to_type_sr<f6_t>(float value, uint32
|
||||
return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
|
||||
: NumericUtils<f6_t>::data_max_positive_normal_mask;
|
||||
|
||||
if(std::abs(value) > NumericLimits<f6_t>::Max()) // covers inf case as well
|
||||
if(std::abs(value) > NumericLimits<f6_t>::DataMaxNorm()) // covers inf case as well
|
||||
return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
|
||||
: NumericUtils<f6_t>::data_max_positive_normal_mask;
|
||||
|
||||
@@ -308,7 +308,7 @@ __host__ __device__ inline bf6_t sat_convert_to_type_sr<bf6_t>(float value, uint
|
||||
if(std::isnan(value))
|
||||
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
|
||||
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
|
||||
if(std::abs(value) > NumericLimits<bf6_t>::Max()) // covers inf case as well
|
||||
if(std::abs(value) > NumericLimits<bf6_t>::DataMaxNorm()) // covers inf case as well
|
||||
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
|
||||
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
|
||||
|
||||
|
||||
@@ -89,6 +89,14 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
v_a = type_convert<ComputeTypeA>(
|
||||
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{})));
|
||||
}
|
||||
else if constexpr(is_same_v<ADataType, f6x16_pk_t> ||
|
||||
is_same_v<ADataType, bf6x16_pk_t> ||
|
||||
is_same_v<ADataType, f6x32_pk_t> ||
|
||||
is_same_v<ADataType, bf6x32_pk_t>)
|
||||
{
|
||||
v_a = type_convert<ComputeTypeA>(
|
||||
arg.a_m_k_(m, k).unpack(k % ADataType::packed_size));
|
||||
}
|
||||
else
|
||||
{
|
||||
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
|
||||
@@ -115,6 +123,14 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
v_b = type_convert<ComputeTypeB>(
|
||||
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{})));
|
||||
}
|
||||
else if constexpr(is_same_v<BDataType, f6x16_pk_t> ||
|
||||
is_same_v<BDataType, bf6x16_pk_t> ||
|
||||
is_same_v<BDataType, f6x32_pk_t> ||
|
||||
is_same_v<BDataType, bf6x32_pk_t>)
|
||||
{
|
||||
v_b = type_convert<ComputeTypeB>(
|
||||
arg.b_k_n_(k, n).unpack(k % BDataType::packed_size));
|
||||
}
|
||||
else
|
||||
{
|
||||
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
|
||||
|
||||
@@ -105,6 +105,16 @@ struct ReferenceMXGemm : public device::BaseOperator
|
||||
type_convert<ComputeTypeA>(
|
||||
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
|
||||
}
|
||||
else if constexpr(is_same_v<ADataType, f6x16_pk_t> ||
|
||||
is_same_v<ADataType, bf6x16_pk_t> ||
|
||||
is_same_v<ADataType, f6x32_pk_t> ||
|
||||
is_same_v<ADataType, bf6x32_pk_t>)
|
||||
{
|
||||
a_m_k_scaled(m, k) =
|
||||
type_convert<ComputeTypeA>(
|
||||
arg.a_m_k_(m, k).unpack(k % ADataType::packed_size)) *
|
||||
type_convert<ComputeTypeA>(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k_scaled(m, k) =
|
||||
@@ -134,6 +144,16 @@ struct ReferenceMXGemm : public device::BaseOperator
|
||||
type_convert<ComputeTypeB>(
|
||||
arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
|
||||
}
|
||||
else if constexpr(is_same_v<BDataType, f6x16_pk_t> ||
|
||||
is_same_v<BDataType, bf6x16_pk_t> ||
|
||||
is_same_v<BDataType, f6x32_pk_t> ||
|
||||
is_same_v<BDataType, bf6x32_pk_t>)
|
||||
{
|
||||
b_k_n_scaled(k, n) =
|
||||
type_convert<ComputeTypeB>(
|
||||
arg.b_k_n_(k, n).unpack(k % BDataType::packed_size)) *
|
||||
type_convert<ComputeTypeB>(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_scaled(k, n) =
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/utility/scaled_type_convert.hpp"
|
||||
|
||||
using ck::bf6_convert_rne;
|
||||
@@ -41,6 +42,11 @@ TEST(BF6, ConvertFP32Nearest)
|
||||
ASSERT_NEAR(max_bf6,
|
||||
type_convert<float>(bf6_convert_rne(std::numeric_limits<float>::infinity())),
|
||||
0.0f);
|
||||
|
||||
// convert float +/-30 to bf6 and back, check if clipped to +/-max_bf6
|
||||
ASSERT_NEAR(-max_bf6, type_convert<float>(bf6_convert_rne(-30.0f)), 0.0f);
|
||||
ASSERT_NEAR(max_bf6, type_convert<float>(bf6_convert_rne(30.0f)), 0.0f);
|
||||
|
||||
// convert float value less than bf6 subnorm to bf6 and back, check if equal to 0.0
|
||||
float less_than_subnorm = 0.03125f;
|
||||
ASSERT_NEAR(0.0f, type_convert<float>(bf6_convert_rne(less_than_subnorm)), 0.0f);
|
||||
@@ -266,21 +272,18 @@ TEST(BF6, TestAsType16x1)
|
||||
vector_type<bf6x16_pk_t, vector_size> right_vec;
|
||||
// check default CTOR
|
||||
ck::static_for<0, packed_size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(
|
||||
right_vec.template AsType<bf6x16_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
|
||||
0);
|
||||
ASSERT_EQ(right_vec.template AsType<bf6x16_pk_t>()(Number<0>{}).unpack(i), 0);
|
||||
});
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, vector_size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<bf6x16_pk_t>()(Number<i>{}) = bf6x16_pk_t{}.pack(test_vec);
|
||||
right_vec.template AsType<bf6x16_pk_t>()(Number<i>{}) = bf6x16_pk_t{test_vec};
|
||||
});
|
||||
// copy the vector
|
||||
vector_type<bf6x16_pk_t, vector_size> left_vec{right_vec};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, packed_size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(
|
||||
left_vec.template AsType<bf6x16_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
|
||||
static_cast<bf6_t>(test_vec[static_cast<int>(i)]));
|
||||
ASSERT_EQ(left_vec.template AsType<bf6x16_pk_t>()(Number<0>{}).unpack(i),
|
||||
static_cast<bf6_t>(test_vec[static_cast<int>(i)]));
|
||||
});
|
||||
}
|
||||
|
||||
@@ -329,23 +332,23 @@ TEST(BF6, TestAsType16x2)
|
||||
// check default CTOR
|
||||
ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
|
||||
ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
|
||||
ASSERT_EQ(right_vec.template AsType<bf6x16_pk_t>()(Number<idx_vector>{})
|
||||
.template unpack<>(Number<idx_element>{}),
|
||||
0);
|
||||
ASSERT_EQ(
|
||||
right_vec.template AsType<bf6x16_pk_t>()(Number<idx_vector>{}).unpack(idx_element),
|
||||
0);
|
||||
});
|
||||
});
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, vector_size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<bf6x16_pk_t>()(Number<i>{}) = bf6x16_pk_t{}.pack(test_vec[i]);
|
||||
right_vec.template AsType<bf6x16_pk_t>()(Number<i>{}) = bf6x16_pk_t{test_vec[i]};
|
||||
});
|
||||
// copy the vector
|
||||
vector_type<bf6x16_pk_t, vector_size> left_vec{right_vec};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
|
||||
ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
|
||||
ASSERT_EQ(left_vec.template AsType<bf6x16_pk_t>()(Number<idx_vector>{})
|
||||
.template unpack<>(Number<idx_element>{}),
|
||||
static_cast<bf6_t>(test_vec[idx_vector][static_cast<int>(idx_element)]));
|
||||
ASSERT_EQ(
|
||||
left_vec.template AsType<bf6x16_pk_t>()(Number<idx_vector>{}).unpack(idx_element),
|
||||
static_cast<bf6_t>(test_vec[idx_vector][static_cast<int>(idx_element)]));
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -369,20 +372,86 @@ TEST(BF6, TestAsType32x1)
|
||||
vector_type<bf6x32_pk_t, vector_size> right_vec;
|
||||
// check default CTOR
|
||||
ck::static_for<0, packed_size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(
|
||||
right_vec.template AsType<bf6x32_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
|
||||
0);
|
||||
ASSERT_EQ(right_vec.template AsType<bf6x32_pk_t>()(Number<0>{}).unpack(i), 0);
|
||||
});
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, vector_size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<bf6x32_pk_t>()(Number<i>{}) = bf6x32_pk_t{}.pack(test_vec);
|
||||
right_vec.template AsType<bf6x32_pk_t>()(Number<i>{}) = bf6x32_pk_t{test_vec};
|
||||
});
|
||||
// copy the vector
|
||||
vector_type<bf6x32_pk_t, vector_size> left_vec{right_vec};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, packed_size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(
|
||||
left_vec.template AsType<bf6x32_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
|
||||
static_cast<bf6_t>(test_vec[static_cast<int>(i)]));
|
||||
ASSERT_EQ(left_vec.template AsType<bf6x32_pk_t>()(Number<0>{}).unpack(i),
|
||||
static_cast<bf6_t>(test_vec[static_cast<int>(i)]));
|
||||
});
|
||||
}
|
||||
|
||||
// Checks every one of the 64 bf6 bit patterns against a reference table:
// decoding a pattern must give the table value, and encoding the table value
// must give the pattern back (compared on the low 6 bits).
TEST(BF6, TestAllValues)
{
    // reference values, indexed by bit pattern
    constexpr std::array<float, 64> kValues = {
        // clang-format off
         0.0f,  0.0625f,  0.125f,  0.1875f,  0.25f,  0.3125f,  0.375f,  0.4375f,
         0.5f,  0.625f,   0.75f,   0.875f,   1.0f,   1.25f,    1.5f,    1.75f,
         2.0f,  2.5f,     3.0f,    3.5f,     4.0f,   5.0f,     6.0f,    7.0f,
         8.0f,  10.0f,    12.0f,   14.0f,    16.0f,  20.0f,    24.0f,   28.0f,
        -0.0f, -0.0625f, -0.125f, -0.1875f, -0.25f, -0.3125f, -0.375f, -0.4375f,
        -0.5f, -0.625f,  -0.75f,  -0.875f,  -1.0f,  -1.25f,   -1.5f,   -1.75f,
        -2.0f, -2.5f,    -3.0f,   -3.5f,    -4.0f,  -5.0f,    -6.0f,   -7.0f,
        -8.0f, -10.0f,   -12.0f,  -14.0f,   -16.0f, -20.0f,   -24.0f,  -28.0f
        // clang-format on
    };

    // the bit patterns 0..63 in ascending order
    constexpr uint8_t kBits[] = {
        // clang-format off
        0b000000, 0b000001, 0b000010, 0b000011, 0b000100, 0b000101, 0b000110, 0b000111,
        0b001000, 0b001001, 0b001010, 0b001011, 0b001100, 0b001101, 0b001110, 0b001111,
        0b010000, 0b010001, 0b010010, 0b010011, 0b010100, 0b010101, 0b010110, 0b010111,
        0b011000, 0b011001, 0b011010, 0b011011, 0b011100, 0b011101, 0b011110, 0b011111,
        0b100000, 0b100001, 0b100010, 0b100011, 0b100100, 0b100101, 0b100110, 0b100111,
        0b101000, 0b101001, 0b101010, 0b101011, 0b101100, 0b101101, 0b101110, 0b101111,
        0b110000, 0b110001, 0b110010, 0b110011, 0b110100, 0b110101, 0b110110, 0b110111,
        0b111000, 0b111001, 0b111010, 0b111011, 0b111100, 0b111101, 0b111110, 0b111111
        // clang-format on
    };

    // the table is printed only when CK_LOGGING is enabled in the environment
    const bool verbose = ck::EnvIsEnabled(CK_ENV(CK_LOGGING));

    if(verbose)
    {
        printf("BF6 Table\n");
    }

    ck::static_for<0, 64, 1>{}([&](auto i) {
        // bits -> float
        float decoded = type_convert<float>(bf6_t(kBits[i]));
        ASSERT_EQ(decoded, kValues[i]);

        // float -> bits
        bf6_t encoded = type_convert<bf6_t>(kValues[i]);
        ASSERT_EQ(encoded & 0x3F, kBits[i] & 0x3F);

        if(!verbose)
        {
            return;
        }

        // Print the binary representation, most significant of the 6 bits first
        printf("Bits: 0b");
        for(int bit = 6; bit-- > 0;)
        {
            printf("%c", (kBits[i] & (1 << bit)) ? '1' : '0');
        }
        printf(", 0x%02X, Value: %f\n", kBits[i], kValues[i]);
    });
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
#include "ck/utility/scaled_type_convert.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
|
||||
using ck::e8m0_bexp_t;
|
||||
using ck::f4_convert_rne;
|
||||
@@ -38,6 +39,11 @@ TEST(FP4, ConvertFP32Nearest)
|
||||
// convert maximal float to fp4 and back, check if clipped to 6.0
|
||||
ASSERT_NEAR(
|
||||
max_fp4, type_convert<float>(f4_convert_rne(std::numeric_limits<float>::max())), abs_tol);
|
||||
|
||||
// convert +/-7.0 to fp4 and back, check if clipped to +/-6.0
|
||||
ASSERT_NEAR(-max_fp4, type_convert<float>(f4_convert_rne(-7.0f)), 0.0);
|
||||
ASSERT_NEAR(max_fp4, type_convert<float>(f4_convert_rne(7.0f)), 0.0);
|
||||
|
||||
// positive norm float value to fp4 and back, check if holds
|
||||
float pos_float = 1.0f;
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f4_convert_rne(pos_float)), abs_tol);
|
||||
@@ -468,3 +474,54 @@ TEST(FP4, TestAsType32)
|
||||
test_vec.at(i + 1));
|
||||
});
|
||||
}
|
||||
|
||||
// Exhaustive round-trip check over all 16 f4_t bit patterns: decoding pattern i must yield
// e2m1ValuesOCP[i], and encoding e2m1ValuesOCP[i] must reproduce the low 4 bits of
// e2m1BitsOCP[i]. When the CK_LOGGING environment switch is enabled the table is also printed.
TEST(FP4, TestAllValues)
{
    // Every 4-bit pattern, in ascending order.
    constexpr uint8_t e2m1BitsOCP[] = {
        // clang-format off
        0b0000, 0b0001, 0b0010, 0b0011, 0b0100, 0b0101, 0b0110, 0b0111,
        0b1000, 0b1001, 0b1010, 0b1011, 0b1100, 0b1101, 0b1110, 0b1111
        // clang-format on
    };

    // Expected float value of each pattern above (index-aligned with e2m1BitsOCP).
    constexpr std::array<float, 16> e2m1ValuesOCP = {
        // clang-format off
         0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
        -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f
        // clang-format on
    };

    const bool ck_logging = ck::EnvIsEnabled(CK_ENV(CK_LOGGING));
    if(ck_logging)
    {
        printf("FP4 Table\n");
    }

    ck::static_for<0, 16, 1>{}([&](auto i) {
        // bits -> float
        const float fp = type_convert<float>(f4_t(e2m1BitsOCP[i]));
        ASSERT_EQ(fp, e2m1ValuesOCP[i]);

        // float -> bits (only the 4 payload bits are compared)
        const f4_t fp4 = type_convert<f4_t>(e2m1ValuesOCP[i]);
        ASSERT_EQ(fp4 & 0xF, e2m1BitsOCP[i] & 0xF);

        if(!ck_logging)
        {
            return;
        }

        // Dump the pattern MSB-first, then its hex form and the expected value.
        printf("Bits: 0b");
        for(int mask = 1 << 3; mask != 0; mask >>= 1)
        {
            printf("%c", (e2m1BitsOCP[i] & mask) ? '1' : '0');
        }
        printf(", 0x%02X, Value: %f\n", e2m1BitsOCP[i], e2m1ValuesOCP[i]);
    });
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/utility/scaled_type_convert.hpp"
|
||||
|
||||
using ck::e8m0_bexp_t;
|
||||
@@ -34,6 +35,11 @@ TEST(FP6, ConvertFP32Nearest)
|
||||
ASSERT_NEAR(0.0f, type_convert<float>(f6_convert_rne(0.0f)), 0.0f);
|
||||
// convert maximal f6_t to float and check if equal to max_fp6
|
||||
ASSERT_NEAR(max_fp6, type_convert<float>(f6_convert_rne(max_fp6)), 0.0f);
|
||||
|
||||
// convert +/-8.0 to fp6 and back, check if clipped to +/-max_fp6
|
||||
ASSERT_NEAR(-max_fp6, type_convert<float>(f6_convert_rne(-8.0f)), 0.0f);
|
||||
ASSERT_NEAR(max_fp6, type_convert<float>(f6_convert_rne(8.0f)), 0.0f);
|
||||
|
||||
// convert maximal float to fp6 and back, check if clipped to max_fp6
|
||||
ASSERT_NEAR(
|
||||
max_fp6, type_convert<float>(f6_convert_rne(std::numeric_limits<float>::max())), 0.0f);
|
||||
@@ -265,20 +271,24 @@ TEST(FP6, TestAsType16x1)
|
||||
vector_type<f6x16_pk_t, vector_size> right_vec;
|
||||
// check default CTOR
|
||||
ck::static_for<0, packed_size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(
|
||||
right_vec.template AsType<f6x16_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}), 0);
|
||||
ASSERT_EQ(right_vec.template AsType<f6x16_pk_t>()(Number<0>{}).unpack(i), 0);
|
||||
});
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, vector_size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<f6x16_pk_t>()(Number<i>{}) = f6x16_pk_t{}.pack(test_vec);
|
||||
right_vec.template AsType<f6x16_pk_t>()(Number<i>{}) = f6x16_pk_t{test_vec};
|
||||
});
|
||||
|
||||
// copy the vector
|
||||
vector_type<f6x16_pk_t, vector_size> left_vec{right_vec};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, packed_size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(
|
||||
left_vec.template AsType<f6x16_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
|
||||
static_cast<f6_t>(test_vec[static_cast<int>(i)]));
|
||||
ASSERT_EQ(left_vec.template AsType<f6x16_pk_t>()(Number<0>{}).unpack(i),
|
||||
static_cast<f6_t>(test_vec[static_cast<int>(i)]))
|
||||
<< " i = " << i << "; left = "
|
||||
<< type_convert<float>(left_vec.template AsType<f6x16_pk_t>()(Number<0>{}).unpack(i))
|
||||
<< " -- right = "
|
||||
<< type_convert<float>(static_cast<f6_t>(test_vec[static_cast<int>(i)])) << " ("
|
||||
<< static_cast<int>(test_vec[static_cast<int>(i)]) << ")" << std::endl;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -327,23 +337,23 @@ TEST(FP6, TestAsType16x2)
|
||||
// check default CTOR
|
||||
ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
|
||||
ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
|
||||
ASSERT_EQ(right_vec.template AsType<f6x16_pk_t>()(Number<idx_vector>{})
|
||||
.template unpack<>(Number<idx_element>{}),
|
||||
0);
|
||||
ASSERT_EQ(
|
||||
right_vec.template AsType<f6x16_pk_t>()(Number<idx_vector>{}).unpack(idx_element),
|
||||
0);
|
||||
});
|
||||
});
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, vector_size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<f6x16_pk_t>()(Number<i>{}) = f6x16_pk_t{}.pack(test_vec[i]);
|
||||
right_vec.template AsType<f6x16_pk_t>()(Number<i>{}) = f6x16_pk_t{test_vec[i]};
|
||||
});
|
||||
// copy the vector
|
||||
vector_type<f6x16_pk_t, vector_size> left_vec{right_vec};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
|
||||
ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
|
||||
ASSERT_EQ(left_vec.template AsType<f6x16_pk_t>()(Number<idx_vector>{})
|
||||
.template unpack<>(Number<idx_element>{}),
|
||||
static_cast<f6_t>(test_vec[idx_vector][static_cast<int>(idx_element)]));
|
||||
ASSERT_EQ(
|
||||
left_vec.template AsType<f6x16_pk_t>()(Number<idx_vector>{}).unpack(idx_element),
|
||||
static_cast<f6_t>(test_vec[idx_vector][static_cast<int>(idx_element)]));
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -367,19 +377,77 @@ TEST(FP6, TestAsType32x1)
|
||||
vector_type<f6x32_pk_t, vector_size> right_vec;
|
||||
// check default CTOR
|
||||
ck::static_for<0, packed_size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(
|
||||
right_vec.template AsType<f6x32_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}), 0);
|
||||
ASSERT_EQ(right_vec.template AsType<f6x32_pk_t>()(Number<0>{}).unpack(i), 0);
|
||||
});
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, vector_size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<f6x32_pk_t>()(Number<i>{}) = f6x32_pk_t{}.pack(test_vec);
|
||||
right_vec.template AsType<f6x32_pk_t>()(Number<i>{}) = f6x32_pk_t{test_vec};
|
||||
});
|
||||
// copy the vector
|
||||
vector_type<f6x32_pk_t, vector_size> left_vec{right_vec};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, packed_size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(
|
||||
left_vec.template AsType<f6x32_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
|
||||
static_cast<f6_t>(test_vec[static_cast<int>(i)]));
|
||||
ASSERT_EQ(left_vec.template AsType<f6x32_pk_t>()(Number<0>{}).unpack(i),
|
||||
static_cast<f6_t>(test_vec[static_cast<int>(i)]));
|
||||
});
|
||||
}
|
||||
|
||||
// Exhaustive round-trip check over all 64 f6_t bit patterns: decoding pattern i must yield
// e2m3ValuesOCP[i], and encoding e2m3ValuesOCP[i] must reproduce the low 6 bits of
// e2m3BitsOCP[i]. When the CK_LOGGING environment switch is enabled the table is also printed.
TEST(FP6, TestAllValues)
{
    // Every 6-bit pattern, in ascending order.
    constexpr uint8_t e2m3BitsOCP[] = {
        // clang-format off
        0b000000, 0b000001, 0b000010, 0b000011, 0b000100, 0b000101, 0b000110, 0b000111,
        0b001000, 0b001001, 0b001010, 0b001011, 0b001100, 0b001101, 0b001110, 0b001111,
        0b010000, 0b010001, 0b010010, 0b010011, 0b010100, 0b010101, 0b010110, 0b010111,
        0b011000, 0b011001, 0b011010, 0b011011, 0b011100, 0b011101, 0b011110, 0b011111,
        0b100000, 0b100001, 0b100010, 0b100011, 0b100100, 0b100101, 0b100110, 0b100111,
        0b101000, 0b101001, 0b101010, 0b101011, 0b101100, 0b101101, 0b101110, 0b101111,
        0b110000, 0b110001, 0b110010, 0b110011, 0b110100, 0b110101, 0b110110, 0b110111,
        0b111000, 0b111001, 0b111010, 0b111011, 0b111100, 0b111101, 0b111110, 0b111111
        // clang-format on
    };

    // Expected float value of each pattern above (index-aligned with e2m3BitsOCP).
    constexpr std::array<float, 64> e2m3ValuesOCP = {
        // clang-format off
         0.0f,  0.125f,  0.25f,  0.375f,  0.5f,  0.625f,  0.75f,  0.875f,
         1.0f,  1.125f,  1.25f,  1.375f,  1.5f,  1.625f,  1.75f,  1.875f,
         2.0f,  2.25f,   2.5f,   2.75f,   3.0f,  3.25f,   3.5f,   3.75f,
         4.0f,  4.5f,    5.0f,   5.5f,    6.0f,  6.5f,    7.0f,   7.5f,
        -0.0f, -0.125f, -0.25f, -0.375f, -0.5f, -0.625f, -0.75f, -0.875f,
        -1.0f, -1.125f, -1.25f, -1.375f, -1.5f, -1.625f, -1.75f, -1.875f,
        -2.0f, -2.25f,  -2.5f,  -2.75f,  -3.0f, -3.25f,  -3.5f,  -3.75f,
        -4.0f, -4.5f,   -5.0f,  -5.5f,   -6.0f, -6.5f,   -7.0f,  -7.5f
        // clang-format on
    };

    const bool ck_logging = ck::EnvIsEnabled(CK_ENV(CK_LOGGING));
    if(ck_logging)
    {
        printf("FP6 Table\n");
    }

    ck::static_for<0, 64, 1>{}([&](auto i) {
        // bits -> float
        const float fp = type_convert<float>(f6_t(e2m3BitsOCP[i]));
        ASSERT_EQ(fp, e2m3ValuesOCP[i]);

        // float -> bits (only the 6 payload bits are compared)
        const f6_t fp6 = type_convert<f6_t>(e2m3ValuesOCP[i]);
        ASSERT_EQ(fp6 & 0x3F, e2m3BitsOCP[i] & 0x3F);

        if(!ck_logging)
        {
            return;
        }

        // Dump the pattern MSB-first, then its hex form and the expected value.
        printf("Bits: 0b");
        for(int mask = 1 << 5; mask != 0; mask >>= 1)
        {
            printf("%c", (e2m3BitsOCP[i] & mask) ? '1' : '0');
        }
        printf(", 0x%02X, Value: %f\n", e2m3BitsOCP[i], e2m3ValuesOCP[i]);
    });
}
|
||||
|
||||
@@ -5,9 +5,12 @@
|
||||
|
||||
#include "mx_mfma_op.hpp"
|
||||
|
||||
using ck::bf6_t;
|
||||
using ck::bf8_t;
|
||||
using ck::e8m0_bexp_t;
|
||||
using ck::f4_t;
|
||||
using ck::f4x2_pk_t;
|
||||
using ck::f6_t;
|
||||
using ck::f8_t;
|
||||
using ck::half_t;
|
||||
using ck::type_convert;
|
||||
@@ -17,13 +20,15 @@ using ck::type_convert;
|
||||
*
|
||||
* @param init - selects initialization algorithm for A and B tensors
|
||||
*/
|
||||
template <typename AType, typename BType, typename CType, ck::MFMA_F8F6F4 mfma>
|
||||
bool run_mfma_km_kn_nm_test(ck::index_t init)
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
ck::MFMA_F8F6F4 mfma>
|
||||
bool run_mfma_test(ck::index_t init)
|
||||
{
|
||||
using ALayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using AccType = float; // only MFMA_F32 instructions supported
|
||||
using CPUAccType = AccType;
|
||||
|
||||
@@ -53,74 +58,153 @@ bool run_mfma_km_kn_nm_test(ck::index_t init)
|
||||
return pass;
|
||||
}
|
||||
|
||||
// Initialization-algorithm selector shared by the tests that read it: a value >= 0 is passed
// through to the test runner unchanged, while a negative value makes each test fall back to its
// own default selector.
const ck::index_t common_init = -4; // set to "< 0" for test-specific initializations
|
||||
|
||||
TEST(MFMA, FP8MFMA16x16x128)
|
||||
{
|
||||
auto AB_init = 5;
|
||||
auto pass = run_mfma_km_kn_nm_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
|
||||
using ALayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
auto AB_init = (common_init < 0) ? 5 : common_init;
|
||||
auto pass = run_mfma_test<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
f8_t,
|
||||
f8_t,
|
||||
half_t,
|
||||
ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
TEST(MFMA, FP8MFMA32x32x64)
|
||||
TEST(MFMA, BF8MFMA16x16x128)
|
||||
{
|
||||
auto AB_init = 5;
|
||||
auto pass = run_mfma_km_kn_nm_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::F32_32x32x64>(AB_init);
|
||||
using ALayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
auto AB_init = (common_init < 0) ? 5 : common_init;
|
||||
auto pass = run_mfma_test<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
bf8_t,
|
||||
bf8_t,
|
||||
half_t,
|
||||
ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Run the test for the given MFMA instruction
|
||||
*
|
||||
* @param init - selects initialization algorithm for A and B tensors
|
||||
*/
|
||||
template <typename AType, typename BType, typename CType, ck::MFMA_F8F6F4 mfma>
|
||||
bool run_mfma_mk_kn_mn_test(ck::index_t init)
|
||||
TEST(MFMA, FP4MFMA16x16x128)
|
||||
{
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
using AccType = float; // only MFMA_F32 instructions supported
|
||||
using CPUAccType = AccType;
|
||||
|
||||
ck::mfma_type<static_cast<ck::MfmaInstr>(mfma)> mfma_instr;
|
||||
constexpr auto BLOCK_M = mfma_instr.m_per_blk;
|
||||
constexpr auto BLOCK_N = mfma_instr.n_per_blk;
|
||||
constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;
|
||||
|
||||
const auto mfma_kernel = ck::
|
||||
matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K, ALayout, BLayout, CLayout>;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
pass = ck::mfma_test::TestMFMA<decltype(mfma_kernel),
|
||||
AType,
|
||||
BType,
|
||||
CType,
|
||||
AccType,
|
||||
CPUAccType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K>{}(mfma_kernel, init);
|
||||
|
||||
return pass;
|
||||
auto AB_init = (common_init < 0) ? 5 : common_init;
|
||||
auto pass =
|
||||
run_mfma_test<ALayout, BLayout, CLayout, f4_t, f4_t, float, ck::MFMA_F8F6F4::F32_16x16x128>(
|
||||
AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
TEST(MFMA, FP4MFMA16x16x128)
|
||||
TEST(MFMA, FP6MFMA16x16x128)
|
||||
{
|
||||
auto AB_init = 4;
|
||||
auto pass = run_mfma_mk_kn_mn_test<f4x2_pk_t, f4x2_pk_t, float, ck::MFMA_F8F6F4::F32_16x16x128>(
|
||||
AB_init);
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
auto AB_init = (common_init < 0) ? 5 : common_init;
|
||||
auto pass =
|
||||
run_mfma_test<ALayout, BLayout, CLayout, f6_t, f6_t, float, ck::MFMA_F8F6F4::F32_16x16x128>(
|
||||
AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
// bf6_t x bf6_t -> float with MFMA_F8F6F4::F32_16x16x128; A row-major, B column-major, C row-major.
TEST(MFMA, BF6MFMA16x16x128)
{
    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    // A non-negative common_init overrides this test's default initialization selector (5).
    const auto init_method = (common_init >= 0) ? common_init : 5;

    const bool pass =
        run_mfma_test<Row, Col, Row, bf6_t, bf6_t, float, ck::MFMA_F8F6F4::F32_16x16x128>(
            init_method);
    EXPECT_TRUE(pass);
}
|
||||
|
||||
// f8_t x f8_t -> float with MFMA_F8F6F4::F32_32x32x64; A, B and C are all column-major.
TEST(MFMA, FP8MFMA32x32x64)
{
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    // A non-negative common_init overrides this test's default initialization selector (5).
    const auto init_method = (common_init >= 0) ? common_init : 5;

    const bool pass =
        run_mfma_test<Col, Col, Col, f8_t, f8_t, float, ck::MFMA_F8F6F4::F32_32x32x64>(
            init_method);
    EXPECT_TRUE(pass);
}
|
||||
|
||||
// bf8_t x bf8_t -> float with MFMA_F8F6F4::F32_32x32x64; A, B and C are all column-major.
TEST(MFMA, BF8MFMA32x32x64)
{
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    // A non-negative common_init overrides this test's default initialization selector (5).
    const auto init_method = (common_init >= 0) ? common_init : 5;

    const bool pass =
        run_mfma_test<Col, Col, Col, bf8_t, bf8_t, float, ck::MFMA_F8F6F4::F32_32x32x64>(
            init_method);
    EXPECT_TRUE(pass);
}
|
||||
|
||||
TEST(MFMA, FP4MFMA32x32x64)
|
||||
{
|
||||
auto AB_init = 4;
|
||||
auto pass = run_mfma_mk_kn_mn_test<f4x2_pk_t, f4x2_pk_t, half_t, ck::MFMA_F8F6F4::F32_32x32x64>(
|
||||
AB_init);
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
auto AB_init = (common_init < 0) ? 5 : common_init;
|
||||
auto pass =
|
||||
run_mfma_test<ALayout, BLayout, CLayout, f4_t, f4_t, half_t, ck::MFMA_F8F6F4::F32_32x32x64>(
|
||||
AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
// f6_t x f6_t -> half_t with MFMA_F8F6F4::F32_32x32x64; A row-major, B column-major, C row-major.
TEST(MFMA, FP6MFMA32x32x64)
{
    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    // A non-negative common_init overrides this test's default initialization selector (5).
    const auto init_method = (common_init >= 0) ? common_init : 5;

    const bool pass =
        run_mfma_test<Row, Col, Row, f6_t, f6_t, half_t, ck::MFMA_F8F6F4::F32_32x32x64>(
            init_method);
    EXPECT_TRUE(pass);
}
|
||||
|
||||
// bf6_t x bf6_t -> half_t with MFMA_F8F6F4::F32_32x32x64; A row-major, B column-major, C row-major.
TEST(MFMA, BF6MFMA32x32x64)
{
    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    // A non-negative common_init overrides this test's default initialization selector (5).
    const auto init_method = (common_init >= 0) ? common_init : 5;

    const bool pass =
        run_mfma_test<Row, Col, Row, bf6_t, bf6_t, half_t, ck::MFMA_F8F6F4::F32_32x32x64>(
            init_method);
    EXPECT_TRUE(pass);
}
|
||||
|
||||
@@ -129,15 +213,18 @@ TEST(MFMA, FP4MFMA32x32x64)
|
||||
*
|
||||
* @param init - selects initialization algorithm for A and B tensors
|
||||
*/
|
||||
template <typename AType, typename BType, typename CType, ck::MFMA_F8F6F4 mfma>
|
||||
bool run_mxmfma_mk_kn_mn_test(ck::index_t init)
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
ck::MFMA_F8F6F4 mfma>
|
||||
bool run_mxmfma_test(ck::index_t init)
|
||||
{
|
||||
static_assert(mfma == ck::MFMA_F8F6F4::SCALE_F32_16x16x128 ||
|
||||
mfma == ck::MFMA_F8F6F4::SCALE_F32_32x32x64,
|
||||
"Only SCALE_F32_16x16x128 and SCALE_F32_32x32x64 are supported");
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
using AccType = float; // only MFMA_F32 instructions supported
|
||||
using ScaleType = ck::e8m0_bexp_t; // biased exponent type
|
||||
@@ -181,34 +268,170 @@ bool run_mxmfma_mk_kn_mn_test(ck::index_t init)
|
||||
|
||||
TEST(MXMFMA, MXFP8MFMA16x16x128)
|
||||
{
|
||||
auto AB_init = 5;
|
||||
auto pass =
|
||||
run_mxmfma_mk_kn_mn_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
auto AB_init = (common_init < 0) ? 5 : common_init;
|
||||
auto pass = run_mxmfma_test<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
f8_t,
|
||||
f8_t,
|
||||
float,
|
||||
ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
TEST(MXMFMA, MXFP8MFMA32x32x64)
|
||||
{
|
||||
auto AB_init = 5;
|
||||
auto pass =
|
||||
run_mxmfma_mk_kn_mn_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
auto AB_init = (common_init < 0) ? 5 : common_init;
|
||||
auto pass = run_mxmfma_test<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
f8_t,
|
||||
f8_t,
|
||||
half_t,
|
||||
ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
// Scaled bf8_t x bf8_t -> float with MFMA_F8F6F4::SCALE_F32_16x16x128;
// A row-major, B column-major, C row-major.
TEST(MXMFMA, MXBF8MFMA16x16x128)
{
    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    // A non-negative common_init overrides this test's default initialization selector (5).
    const auto init_method = (common_init >= 0) ? common_init : 5;

    const bool pass =
        run_mxmfma_test<Row, Col, Row, bf8_t, bf8_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(
            init_method);
    EXPECT_TRUE(pass);
}
|
||||
|
||||
// Scaled bf8_t x bf8_t -> half_t with MFMA_F8F6F4::SCALE_F32_32x32x64;
// A row-major, B column-major, C row-major.
TEST(MXMFMA, MXBF8MFMA32x32x64)
{
    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    // A non-negative common_init overrides this test's default initialization selector (5).
    const auto init_method = (common_init >= 0) ? common_init : 5;

    const bool pass =
        run_mxmfma_test<Row, Col, Row, bf8_t, bf8_t, half_t, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(
            init_method);
    EXPECT_TRUE(pass);
}
|
||||
|
||||
// Scaled f6_t x f6_t -> float with MFMA_F8F6F4::SCALE_F32_16x16x128;
// A row-major, B column-major, C row-major.
TEST(MXMFMA, MXFP6MFMA16x16x128)
{
    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    // A non-negative common_init overrides this test's default initialization selector (5).
    const auto init_method = (common_init >= 0) ? common_init : 5;

    const bool pass =
        run_mxmfma_test<Row, Col, Row, f6_t, f6_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(
            init_method);
    EXPECT_TRUE(pass);
}
|
||||
|
||||
// Scaled f6_t x f6_t -> half_t with MFMA_F8F6F4::SCALE_F32_32x32x64;
// A row-major, B column-major, C row-major.
TEST(MXMFMA, MXFP6MFMA32x32x64)
{
    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    // A non-negative common_init overrides this test's default initialization selector (5).
    const auto init_method = (common_init >= 0) ? common_init : 5;

    const bool pass =
        run_mxmfma_test<Row, Col, Row, f6_t, f6_t, half_t, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(
            init_method);
    EXPECT_TRUE(pass);
}
|
||||
|
||||
// Scaled bf6_t x bf6_t -> float with MFMA_F8F6F4::SCALE_F32_16x16x128;
// A row-major, B column-major, C row-major.
TEST(MXMFMA, MXBF6MFMA16x16x128)
{
    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    // A non-negative common_init overrides this test's default initialization selector (5).
    const auto init_method = (common_init >= 0) ? common_init : 5;

    const bool pass =
        run_mxmfma_test<Row, Col, Row, bf6_t, bf6_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(
            init_method);
    EXPECT_TRUE(pass);
}
|
||||
|
||||
// Scaled bf6_t x bf6_t -> half_t with MFMA_F8F6F4::SCALE_F32_32x32x64;
// A row-major, B column-major, C row-major.
TEST(MXMFMA, MXBF6MFMA32x32x64)
{
    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    // A non-negative common_init overrides this test's default initialization selector (5).
    const auto init_method = (common_init >= 0) ? common_init : 5;

    const bool pass =
        run_mxmfma_test<Row, Col, Row, bf6_t, bf6_t, half_t, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(
            init_method);
    EXPECT_TRUE(pass);
}
|
||||
|
||||
TEST(MXMFMA, MXFP4MFMA16x16x128)
|
||||
{
|
||||
auto AB_init = 4;
|
||||
auto pass =
|
||||
run_mxmfma_mk_kn_mn_test<f4x2_pk_t, f4x2_pk_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(
|
||||
AB_init);
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
auto AB_init = (common_init < 0) ? 5 : common_init;
|
||||
auto pass = run_mxmfma_test<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
f4_t,
|
||||
f4_t,
|
||||
float,
|
||||
ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
TEST(MXMFMA, MXFP4MFMA32x32x64)
|
||||
{
|
||||
auto AB_init = 4;
|
||||
auto pass =
|
||||
run_mxmfma_mk_kn_mn_test<f4x2_pk_t, f4x2_pk_t, half_t, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(
|
||||
AB_init);
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
auto AB_init = (common_init < 0) ? 5 : common_init;
|
||||
auto pass = run_mxmfma_test<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
f4_t,
|
||||
f4_t,
|
||||
half_t,
|
||||
ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
@@ -151,6 +151,8 @@ __device__ AFragT load_A_col_major(AType const* input_ptr)
|
||||
// Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] |
|
||||
// clang-format on
|
||||
|
||||
static_assert(!is_packed_type_v<AType>, "Packed type is not supported");
|
||||
|
||||
static constexpr int32_t WAVE_SIZE = 64;
|
||||
|
||||
// Here we want to load from rows of A in chunks of 16 elements each.
|
||||
@@ -270,12 +272,28 @@ __device__ AFragT load_A_row_major(AType const* input_ptr)
|
||||
// Reg 3 [8:15] | K26K27 | K58K59 | K90K91 | K122K123 | v[13] || Reg 3 [8:15] | K26K27 | K58K59 | v[13] |
|
||||
// Reg 3 [16:23] | K28K29 | K60K61 | K92K93 | K124K125 | v[14] || Reg 3 [16:23] | K28K29 | K60K61 | v[14] |
|
||||
// Reg 3 [24:31] | K30K31 | K62K63 | K94K95 | K126K127 | v[15] || Reg 3 [24:31] | K30K31 | K62K63 | v[15] |
|
||||
|
||||
|
||||
// Register Mapping for 16x128 for FP6: || Register Mapping for 32x64 for FP6:
|
||||
// Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | |
|
||||
// M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 31 | Vector |
|
||||
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element|
|
||||
// Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------|
|
||||
// Reg 0-2 [0:95] | K = 0-15 | K = 32-47 | K = 64-79 | K = 96-111 | v[0] || Reg 0-2 [0:95] | K = 0-15 | K = 32-47 | v[0] |
|
||||
// Reg 3-5 [0:95] | K = 16-31 | K = 48-63 | K = 80-95 | K = 112-127 | v[0] || Reg 3-5 [0:95] | K = 16-31 | K = 48-63 | v[0] |
|
||||
|
||||
// clang-format on
|
||||
|
||||
static constexpr int32_t WAVE_SIZE = 64;
|
||||
|
||||
// FP8 chunk_size = 16, num_chunks = 2, packed_size = 1
|
||||
// FP4 chunk_size = 32, num_chunks = 1, packed_size = 2
|
||||
// FP6 chunk_size = 32, num_chunks = 1, packed_size = 32
|
||||
|
||||
constexpr index_t num_chunks = is_packed_type_v<AType> ? 1 : 2;
|
||||
|
||||
// Here we want to load from rows of A in chunks of 16 elements each (32 for packed types).
|
||||
static constexpr uint32_t chunk_size = 16;
|
||||
constexpr uint32_t chunk_size = is_packed_type_v<AType> ? 32 : 16;
|
||||
|
||||
// each chunk is separated by offset
|
||||
static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_M;
|
||||
@@ -283,43 +301,35 @@ __device__ AFragT load_A_row_major(AType const* input_ptr)
|
||||
// To start the loading process, let's visualize in 2D coords.
|
||||
// Each thread will load 32 elements.
|
||||
// We need to know where they start, and where the next elements are.
|
||||
auto startCoord2D =
|
||||
std::make_pair(threadIdx.x % BLOCK_M, // Row {0-31} | {0-15}
|
||||
(threadIdx.x / BLOCK_M) * chunk_size); // Col {0, 16} | {0, 16, 32, 48}
|
||||
// FP8/6/4 Row {0-31} | {0-15}
|
||||
// FP8 Col {0, 16} | {0, 16, 32, 48}
|
||||
// FP6/4 Col {0, 32} | {0, 32, 64, 96}
|
||||
auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, (threadIdx.x / BLOCK_M) * chunk_size);
|
||||
|
||||
// auto minorStepCoord2D = std::make_pair(0u, 1u); // read rows
|
||||
auto majorStepCoord2D = std::make_pair(0, chunk_offset); // read a chunk from a row
|
||||
|
||||
// Flatten to 1D row_major offsets.
|
||||
auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
|
||||
|
||||
// BLOCK_K is a stride in A matrix
|
||||
auto startOffset = row_major(
|
||||
startCoord2D, BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
|
||||
// auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K /
|
||||
// (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
|
||||
auto kMajorOffset =
|
||||
row_major(majorStepCoord2D,
|
||||
BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
|
||||
|
||||
using ARawT = typename scalar_type<AFragT>::type;
|
||||
using AScalarFragT = vector_type<ARawT, chunk_size>::type;
|
||||
|
||||
constexpr index_t num_chunks =
|
||||
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 1 : 2);
|
||||
using ARawT = typename scalar_type<AFragT>::type;
|
||||
using AScalarChunkT = vector_type<ARawT, scalar_type<AFragT>::vector_size / num_chunks>::type;
|
||||
|
||||
union
|
||||
{
|
||||
AFragT frag;
|
||||
AScalarFragT chunks[num_chunks];
|
||||
AScalarChunkT chunks[num_chunks];
|
||||
} fragA{};
|
||||
|
||||
const AScalarFragT* fragPtr;
|
||||
const AScalarChunkT* fragPtr;
|
||||
|
||||
// BLOCK_K is a stride in A matrix
|
||||
auto startOffset = row_major(startCoord2D, BLOCK_K) / packed_size_v<AType>;
|
||||
auto kMajorOffset = row_major(majorStepCoord2D, BLOCK_K) / packed_size_v<AType>;
|
||||
|
||||
for(index_t chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++)
|
||||
{
|
||||
fragPtr = reinterpret_cast<AScalarFragT const*>(input_ptr + startOffset +
|
||||
chunk_idx * kMajorOffset);
|
||||
fragPtr = reinterpret_cast<AScalarChunkT const*>(input_ptr + startOffset +
|
||||
chunk_idx * kMajorOffset);
|
||||
fragA.chunks[chunk_idx] = *fragPtr;
|
||||
}
|
||||
|
||||
@@ -488,12 +498,27 @@ __device__ BFragT load_B_col_major(BType const* input_ptr)
|
||||
// Reg 3 [8:15] | K26K27 | K58K59 | K90K91 | K122K123 | v[13] || Reg 3 [8:15] | K26K27 | K58K59 | v[13] |
|
||||
// Reg 3 [16:23] | K28K29 | K60K61 | K92K93 | K124K125 | v[14] || Reg 3 [16:23] | K28K29 | K60K61 | v[14] |
|
||||
// Reg 3 [24:31] | K30K31 | K62K63 | K94K95 | K126K127 | v[15] || Reg 3 [24:31] | K30K31 | K62K63 | v[15] |
|
||||
|
||||
// Register Mapping for 16x128 for FP6: || Register Mapping for 32x64 for FP6:
|
||||
// Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | |
|
||||
// N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector |
|
||||
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element|
|
||||
// Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------|
|
||||
// Reg 0-2 [0:95] | K = 0-15 | K = 32-47 | K = 64-79 | K = 96-111 | v[0] || Reg 0-2 [0:95] | K = 0-15 | K = 32-47 | v[0] |
|
||||
// Reg 3-5 [0:95] | K = 16-31 | K = 48-63 | K = 80-95 | K = 112-127 | v[0] || Reg 3-5 [0:95] | K = 16-31 | K = 48-63 | v[0] |
|
||||
|
||||
// clang-format on
|
||||
|
||||
static constexpr int32_t WAVE_SIZE = 64;
|
||||
|
||||
// FP8 chunk_size = 16, num_chunks = 2, packed_size = 1
|
||||
// FP4 chunk_size = 32, num_chunks = 1, packed_size = 2
|
||||
// FP6 chunk_size = 32, num_chunks = 1, packed_size = 32
|
||||
|
||||
constexpr index_t num_chunks = is_packed_type_v<BType> ? 1 : 2;
|
||||
|
||||
// Here we want to load from cols of B in chunks of 16 elements each (32 for packed types).
|
||||
static constexpr uint32_t chunk_size = 16;
|
||||
constexpr uint32_t chunk_size = is_packed_type_v<BType> ? 32 : 16;
|
||||
|
||||
// each chunk is separated by an offset
|
||||
static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_N; // 32 or 64
|
||||
@@ -501,44 +526,36 @@ __device__ BFragT load_B_col_major(BType const* input_ptr)
|
||||
// To start the loading process, let's visualize in 2D coords.
|
||||
// Each thread will load 32 elements.
|
||||
// We need to know where they start, and where the next elements are.
|
||||
auto startCoord2D =
|
||||
std::make_pair((threadIdx.x / BLOCK_N) * chunk_size, // Row {0, 16} | {0, 16, 32, 48}
|
||||
threadIdx.x % BLOCK_N); // Col {0-31} | {0-15}
|
||||
// FP8/6/4 Col {0-31} | {0-15}
|
||||
// FP8 Row {0, 16} | {0, 16, 32, 48}
|
||||
// FP6/4 Row {0, 32} | {0, 32, 64, 96}
|
||||
auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * chunk_size, threadIdx.x % BLOCK_N);
|
||||
|
||||
// Flatten to 1D col_major offsets.
|
||||
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
|
||||
|
||||
// auto minorStepCoord2D = std::make_pair(1u, 0u); // read cols
|
||||
auto majorStepCoord2D = std::make_pair(chunk_offset, 0); // read a chunk from a col
|
||||
|
||||
// BLOCK_K is a stride in B matrix
|
||||
auto startOffset = col_major(
|
||||
startCoord2D, BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1));
|
||||
// auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K /
|
||||
// (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1));
|
||||
auto kMajorOffset =
|
||||
col_major(majorStepCoord2D,
|
||||
BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1));
|
||||
|
||||
using BRawT = typename scalar_type<BFragT>::type;
|
||||
using BScalarFragT = vector_type<BRawT, chunk_size>::type;
|
||||
|
||||
constexpr index_t num_chunks =
|
||||
(ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 1 : 2);
|
||||
using BRawT = typename scalar_type<BFragT>::type;
|
||||
using BScalarChunkT = vector_type<BRawT, scalar_type<BFragT>::vector_size / num_chunks>::type;
|
||||
|
||||
union
|
||||
{
|
||||
BFragT frag;
|
||||
BScalarFragT chunks[num_chunks];
|
||||
BScalarChunkT chunks[num_chunks];
|
||||
} fragB{};
|
||||
|
||||
const BScalarFragT* fragPtr;
|
||||
const BScalarChunkT* fragPtr;
|
||||
|
||||
for(index_t chunk = 0; chunk < num_chunks; chunk++)
|
||||
// BLOCK_K is a stride in B matrix
|
||||
auto startOffset = col_major(startCoord2D, BLOCK_K) / packed_size_v<BType>;
|
||||
auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_K) / packed_size_v<BType>;
|
||||
|
||||
for(index_t chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++)
|
||||
{
|
||||
fragPtr =
|
||||
reinterpret_cast<BScalarFragT const*>(input_ptr + startOffset + chunk * kMajorOffset);
|
||||
fragB.chunks[chunk] = *fragPtr;
|
||||
fragPtr = reinterpret_cast<BScalarChunkT const*>(input_ptr + startOffset +
|
||||
chunk_idx * kMajorOffset);
|
||||
fragB.chunks[chunk_idx] = *fragPtr;
|
||||
}
|
||||
|
||||
return fragB.frag;
|
||||
@@ -904,20 +921,22 @@ template <typename AType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
__global__ void matmul(const AType* a, const BType* b, CType* c)
|
||||
__global__ void matmul(const typename packed_type<AType>::type* a,
|
||||
const typename packed_type<BType>::type* b,
|
||||
CType* c)
|
||||
{
|
||||
using PackedAType = typename packed_type<AType>::type;
|
||||
constexpr auto packed_size_a = packed_type<AType>::packed_size;
|
||||
using PackedBType = typename packed_type<BType>::type;
|
||||
constexpr auto packed_size_b = packed_type<BType>::packed_size;
|
||||
|
||||
constexpr int WAVE_SIZE = 64;
|
||||
assert(threadIdx.x < WAVE_SIZE);
|
||||
assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);
|
||||
|
||||
using AFragT =
|
||||
vector_type<AType,
|
||||
BLOCK_M * BLOCK_K / WAVE_SIZE /
|
||||
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
|
||||
using BFragT =
|
||||
vector_type<BType,
|
||||
BLOCK_K * BLOCK_N / WAVE_SIZE /
|
||||
(ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
|
||||
using AFragT = vector_type<PackedAType, BLOCK_M * BLOCK_K / WAVE_SIZE / packed_size_a>::type;
|
||||
using BFragT = vector_type<PackedBType, BLOCK_K * BLOCK_N / WAVE_SIZE / packed_size_b>::type;
|
||||
|
||||
using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
||||
using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
|
||||
using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
||||
@@ -931,11 +950,11 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
|
||||
// Load the inputs.
|
||||
if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
fragA = load_A_row_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
|
||||
fragA = load_A_row_major<PackedAType, AFragT, BLOCK_M, BLOCK_K>(a);
|
||||
}
|
||||
else
|
||||
{
|
||||
fragA = load_A_col_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
|
||||
fragA = load_A_col_major<PackedAType, AFragT, BLOCK_M, BLOCK_K>(a);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
@@ -944,7 +963,7 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
|
||||
}
|
||||
else
|
||||
{
|
||||
fragB = load_B_col_major<BType, BFragT, BLOCK_K, BLOCK_N>(b);
|
||||
fragB = load_B_col_major<PackedBType, BFragT, BLOCK_K, BLOCK_N>(b);
|
||||
}
|
||||
|
||||
// Matrix multiply-accumulate using MFMA units
|
||||
@@ -979,21 +998,24 @@ template <typename AType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
__global__ void
|
||||
matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, CType* c)
|
||||
__global__ void matmul(const packed_type_t<AType>* a,
|
||||
const ScaleType* xa,
|
||||
const packed_type_t<BType>* b,
|
||||
const ScaleType* xb,
|
||||
CType* c)
|
||||
{
|
||||
using PackedAType = packed_type_t<AType>;
|
||||
constexpr auto packed_size_a = packed_size_v<AType>;
|
||||
using PackedBType = packed_type_t<BType>;
|
||||
constexpr auto packed_size_b = packed_size_v<BType>;
|
||||
|
||||
constexpr int WAVE_SIZE = 64;
|
||||
assert(threadIdx.x < WAVE_SIZE);
|
||||
assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);
|
||||
|
||||
using AFragT =
|
||||
vector_type<AType,
|
||||
BLOCK_M * BLOCK_K / WAVE_SIZE /
|
||||
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
|
||||
using BFragT =
|
||||
vector_type<BType,
|
||||
BLOCK_K * BLOCK_N / WAVE_SIZE /
|
||||
(ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
|
||||
using AFragT = vector_type<PackedAType, BLOCK_M * BLOCK_K / WAVE_SIZE / packed_size_a>::type;
|
||||
using BFragT = vector_type<PackedBType, BLOCK_K * BLOCK_N / WAVE_SIZE / packed_size_b>::type;
|
||||
|
||||
using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
||||
using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
|
||||
using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
||||
@@ -1011,9 +1033,13 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb,
|
||||
// Load the inputs.
|
||||
if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
fragA =
|
||||
load_mx_A_row_major<AType, AFragT, ScaleType, AScaleFragT, BLOCK_M, BLOCK_K, BLOCK_X>(
|
||||
a, xa, fragXa);
|
||||
fragA = load_mx_A_row_major<PackedAType,
|
||||
AFragT,
|
||||
ScaleType,
|
||||
AScaleFragT,
|
||||
BLOCK_M,
|
||||
BLOCK_K,
|
||||
BLOCK_X>(a, xa, fragXa);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1026,9 +1052,13 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb,
|
||||
}
|
||||
else
|
||||
{
|
||||
fragB =
|
||||
load_mx_B_col_major<BType, BFragT, ScaleType, BScaleFragT, BLOCK_K, BLOCK_N, BLOCK_X>(
|
||||
b, xb, fragXb);
|
||||
fragB = load_mx_B_col_major<PackedBType,
|
||||
BFragT,
|
||||
ScaleType,
|
||||
BScaleFragT,
|
||||
BLOCK_K,
|
||||
BLOCK_N,
|
||||
BLOCK_X>(b, xb, fragXb);
|
||||
}
|
||||
|
||||
// Scaled Matrix multiply-accumulate using MFMA units
|
||||
@@ -1151,6 +1181,11 @@ template <typename DeviceMFMA,
|
||||
index_t BLOCK_X>
|
||||
struct TestMXMFMA
|
||||
{
|
||||
using PackedAType = typename packed_type<ADataType>::type;
|
||||
static constexpr auto packed_size_a = packed_type<ADataType>::packed_size;
|
||||
using PackedBType = typename packed_type<BDataType>::type;
|
||||
static constexpr auto packed_size_b = packed_type<BDataType>::packed_size;
|
||||
|
||||
auto PrepareGemmTensors(const GemmParams& params, index_t init)
|
||||
{
|
||||
auto f_host_tensor_descriptor =
|
||||
@@ -1167,11 +1202,11 @@ struct TestMXMFMA
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m_k(
|
||||
Tensor<PackedAType> a_m_k(
|
||||
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
|
||||
Tensor<ScaleType> a_scales(
|
||||
f_host_tensor_descriptor(params.M, params.K / BLOCK_X, params.K / BLOCK_X, ALayout{}));
|
||||
Tensor<BDataType> b_n_k(
|
||||
Tensor<PackedBType> b_n_k(
|
||||
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
|
||||
Tensor<ScaleType> b_scales(
|
||||
f_host_tensor_descriptor(params.K / BLOCK_X, params.N, params.K / BLOCK_X, BLayout{}));
|
||||
@@ -1183,51 +1218,44 @@ struct TestMXMFMA
|
||||
switch(init)
|
||||
{
|
||||
case 0:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
|
||||
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{0.015625f}}); // 1/64
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<PackedAType>{1.0f});
|
||||
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{0.5f}});
|
||||
// NOTE: not all numbers are representable in FP8, BF8, etc.
|
||||
// 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 16 18 20 20 20 22 24 24 24 26 28 28 28 30 32
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<PackedBType, 1>{});
|
||||
b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
|
||||
break;
|
||||
case 1:
|
||||
// results in C = {K}
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<PackedAType>{1.0f});
|
||||
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{512.0f}});
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_1<PackedBType>{1.0f});
|
||||
b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f / 512}});
|
||||
break;
|
||||
case 2:
|
||||
// expect small round off errors
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<PackedAType>{-2.0, 2.0});
|
||||
a_scales.GenerateTensorValue(
|
||||
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_3<PackedBType>{-2.0, 2.0});
|
||||
b_scales.GenerateTensorValue(GeneratorTensor_2<ScaleType>{126, 129});
|
||||
break;
|
||||
case 3:
|
||||
// expect small round off errors
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_4<ADataType>(0, 1));
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_4<PackedAType>(0, 1, time(nullptr)));
|
||||
a_scales.GenerateTensorValue(
|
||||
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_4<BDataType>(0, 1));
|
||||
b_scales.GenerateTensorValue(
|
||||
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
|
||||
break;
|
||||
case 4:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1., 1.});
|
||||
a_scales.GenerateTensorValue(
|
||||
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1., 1.});
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_4<PackedBType>(0, 1, time(nullptr) / 2));
|
||||
b_scales.GenerateTensorValue(
|
||||
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
|
||||
break;
|
||||
|
||||
default:
|
||||
// all initial values are representable in FP8, BF8
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6}); // Z[-5,5]
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<PackedAType>{-6, 7}); // Z[-6,6]
|
||||
a_scales.GenerateTensorValue(
|
||||
GeneratorTensor_2<ScaleType>{122, 129}); // scales: [1/32,..., 2]
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 6}); // Z[-5,5]
|
||||
GeneratorTensor_2<ScaleType>{122, 129}); // scales: [1/32,..., 2]
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_2<PackedBType>{-6, 7}); // Z[-6,6]
|
||||
b_scales.GenerateTensorValue(
|
||||
GeneratorTensor_2<ScaleType>{122, 129}); // scales: [1/32,..., 2]
|
||||
|
||||
@@ -1272,9 +1300,9 @@ struct TestMXMFMA
|
||||
|
||||
auto host_tensors = PrepareGemmTensors(params, init);
|
||||
|
||||
const Tensor<ADataType>& a = std::get<0>(host_tensors);
|
||||
const Tensor<PackedAType>& a = std::get<0>(host_tensors);
|
||||
const Tensor<ScaleType>& a_scales = std::get<1>(host_tensors);
|
||||
const Tensor<BDataType>& b = std::get<2>(host_tensors);
|
||||
const Tensor<PackedBType>& b = std::get<2>(host_tensors);
|
||||
const Tensor<ScaleType>& b_scales = std::get<3>(host_tensors);
|
||||
Tensor<CDataType>& c_host = std::get<4>(host_tensors);
|
||||
Tensor<CDataType>& c_device = std::get<5>(host_tensors);
|
||||
@@ -1356,6 +1384,12 @@ template <typename DeviceMFMA,
|
||||
index_t BLOCK_K>
|
||||
struct TestMFMA
|
||||
{
|
||||
|
||||
using PackedAType = typename packed_type<ADataType>::type;
|
||||
static constexpr auto packed_size_a = packed_type<ADataType>::packed_size;
|
||||
using PackedBType = typename packed_type<BDataType>::type;
|
||||
static constexpr auto packed_size_b = packed_type<BDataType>::packed_size;
|
||||
|
||||
auto PrepareGemmTensors(const GemmParams& params, index_t init)
|
||||
{
|
||||
auto f_host_tensor_descriptor =
|
||||
@@ -1372,9 +1406,9 @@ struct TestMFMA
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m_k(
|
||||
Tensor<PackedAType> a_m_k(
|
||||
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_n_k(
|
||||
Tensor<PackedBType> b_n_k(
|
||||
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
|
||||
Tensor<CDataType> c_m_n_host_result(
|
||||
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
|
||||
@@ -1384,34 +1418,30 @@ struct TestMFMA
|
||||
switch(init)
|
||||
{
|
||||
case 0:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{0.015625f});
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<PackedAType>{0.625f});
|
||||
// NOTE: not all numbers are representable in FP8, BF8, etc.
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<PackedBType, 1>{});
|
||||
break;
|
||||
case 1:
|
||||
// results in C = {K}
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<PackedAType>{1.0f});
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_1<PackedBType>{1.0f});
|
||||
break;
|
||||
case 2:
|
||||
// expect small round off errors
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-5, 5});
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_3<BDataType>{-5, 5});
|
||||
// expect small round off errors that lead to FP8MFMA32x32x64 failures
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<PackedAType>{-5, 5});
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_3<PackedBType>{-5, 5});
|
||||
break;
|
||||
case 3:
|
||||
// expect small round off errors
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_4<ADataType>(-1, 3));
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_4<BDataType>(1, 3));
|
||||
break;
|
||||
case 4:
|
||||
// FP4 values case
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-4, 5});
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_2<BDataType>{-4, 5});
|
||||
// expect small round off errors that lead to FP8MFMA32x32x64 failures
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_4<PackedAType>(-1, 3));
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_4<PackedBType>(1, 3));
|
||||
break;
|
||||
|
||||
default:
|
||||
// all initial values are representable in FP8, BF8
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6});
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 6});
|
||||
// all initial values are representable in FP8/FP6 and BF8/BF6; FP4 cannot represent 5
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<PackedAType>{-6, 7}); // Z[-6,6]
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_2<PackedBType>{-6, 7});
|
||||
|
||||
break;
|
||||
}
|
||||
@@ -1453,10 +1483,10 @@ struct TestMFMA
|
||||
|
||||
auto host_tensors = PrepareGemmTensors(params, init);
|
||||
|
||||
const Tensor<ADataType>& a = std::get<0>(host_tensors);
|
||||
const Tensor<BDataType>& b = std::get<1>(host_tensors);
|
||||
Tensor<CDataType>& c_host = std::get<2>(host_tensors);
|
||||
Tensor<CDataType>& c_device = std::get<3>(host_tensors);
|
||||
const Tensor<PackedAType>& a = std::get<0>(host_tensors);
|
||||
const Tensor<PackedBType>& b = std::get<1>(host_tensors);
|
||||
Tensor<CDataType>& c_host = std::get<2>(host_tensors);
|
||||
Tensor<CDataType>& c_device = std::get<3>(host_tensors);
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
@@ -1464,8 +1494,8 @@ struct TestMFMA
|
||||
auto b_element_op = PassThrough{};
|
||||
auto c_element_op = PassThrough{};
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<PackedAType,
|
||||
PackedBType,
|
||||
CDataType,
|
||||
CPUAccDataType,
|
||||
PassThrough,
|
||||
|
||||
Reference in New Issue
Block a user