Merge commit '57e0f5df29abefd919c334c994628a994ba2868c' into develop

assistant-librarian[bot]
2025-05-19 22:06:56 +00:00
parent 0b87df9c4a
commit 9d088bc569
15 changed files with 1602 additions and 588 deletions

View File

@@ -141,8 +141,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
a_tensors_device.reserve(group_count);
b_tensors_device.reserve(group_count);
d_tensors_device.reserve(group_count);
c_tensors_device.reserve(group_count);
d_tensors_device.resize(group_count); // resize, not reserve: allocate and set the vector size so elements can be indexed
std::size_t flop = 0, num_btype = 0;
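The switch from reserve() to resize() matters because the tensors are presumably indexed per group afterwards: reserve() only allocates capacity and leaves size() at zero, so operator[] on the reserved vector would be undefined behavior. A minimal standalone sketch of the distinction (not part of the diff):
#include <cassert>
#include <vector>

int main()
{
    std::vector<int> reserved, resized;
    reserved.reserve(4); // capacity 4, size() still 0 -> reserved[0] would be UB
    resized.resize(4);   // size() 4, elements value-initialized -> resized[0] is valid
    assert(reserved.size() == 0);
    assert(resized.size() == 4 && resized[0] == 0);
    return 0;
}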

View File

@@ -360,10 +360,9 @@ struct Tensor
std::size_t GetElementSpaceSize() const
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
if constexpr(ck::is_packed_type_v<ck::remove_cvref_t<T>>)
{
return (mDesc.GetElementSpaceSize() + 1) / 2;
return (mDesc.GetElementSpaceSize() + 1) / ck::packed_size_v<ck::remove_cvref_t<T>>;
}
else
{
@@ -516,69 +515,31 @@ struct Tensor
template <typename... Is>
std::size_t GetOffsetFromMultiIndex(Is... is) const
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return mDesc.GetOffsetFromMultiIndex(is...) / 2;
}
else
{
return mDesc.GetOffsetFromMultiIndex(is...);
}
return mDesc.GetOffsetFromMultiIndex(is...) / ck::packed_size_v<ck::remove_cvref_t<T>>;
}
template <typename... Is>
T& operator()(Is... is)
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
}
return mData[mDesc.GetOffsetFromMultiIndex(is...) /
ck::packed_size_v<ck::remove_cvref_t<T>>];
}
template <typename... Is>
const T& operator()(Is... is) const
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
}
return mData[mDesc.GetOffsetFromMultiIndex(is...) /
ck::packed_size_v<ck::remove_cvref_t<T>>];
}
T& operator()(std::vector<std::size_t> idx)
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v<ck::remove_cvref_t<T>>];
}
const T& operator()(std::vector<std::size_t> idx) const
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v<ck::remove_cvref_t<T>>];
}
typename Data::iterator begin() { return mData.begin(); }
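The refactor above replaces the hard-coded divide-by-2 for pk_i4_t/f4x2_pk_t with a generic division by ck::packed_size_v, so wider packs like f6x32_pk_t work through the same accessors. A hedged sketch of the arithmetic, with a stand-in for the descriptor offset:
#include <cstddef>

// Logical element offset -> index into the packed storage array. For
// pk_i4_t and f4x2_pk_t the pack width is 2; for f6x32_pk_t it is 32.
template <int PackedSize>
constexpr std::size_t storage_index(std::size_t logical_offset)
{
    return logical_offset / PackedSize; // same formula the Tensor accessors now share
}

static_assert(storage_index<2>(7) == 3);   // two 4-bit values per byte
static_assert(storage_index<32>(40) == 1); // 32 six-bit values per f6x32_pk_t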

View File

@@ -67,6 +67,18 @@ struct GeneratorTensor_1<ck::f8_t>
return ck::type_convert<ck::f8_t>(value);
}
};
template <>
struct GeneratorTensor_1<ck::bf8_t>
{
float value = 1.0;
template <typename... Is>
ck::bf8_t operator()(Is...)
{
return ck::type_convert<ck::bf8_t>(value);
}
};
#endif
template <>
@@ -93,6 +105,38 @@ struct GeneratorTensor_1<ck::f4x2_pk_t>
}
};
template <>
struct GeneratorTensor_1<ck::f6x32_pk_t>
{
float value = 1.0;
template <typename... Is>
ck::f6x32_pk_t operator()(Is...)
{
ck::f6x32_pk_t r;
ck::static_for<0, 32, 1>{}([&](auto i) {
r.pack(ck::type_convert<ck::f6_t>(value), static_cast<ck::index_t>(i));
});
return r;
}
};
template <>
struct GeneratorTensor_1<ck::bf6x32_pk_t>
{
float value = 1.0;
template <typename... Is>
ck::bf6x32_pk_t operator()(Is...)
{
ck::bf6x32_pk_t r;
ck::static_for<0, 32, 1>{}([&](auto i) {
r.pack(ck::type_convert<ck::bf6_t>(value), static_cast<ck::index_t>(i));
});
return r;
}
};
template <>
struct GeneratorTensor_1<int8_t>
{
@@ -132,6 +176,44 @@ struct GeneratorTensor_2
}
};
template <>
struct GeneratorTensor_2<ck::f6x32_pk_t>
{
int min_value = 0;
int max_value = 1;
template <typename... Is>
ck::f6x32_pk_t operator()(Is...)
{
ck::f6x32_pk_t r;
ck::static_for<0, 32, 1>{}([&](auto i) {
float tmp = (std::rand() % (max_value - min_value)) + min_value;
r.pack(ck::type_convert<ck::f6_t>(tmp), static_cast<ck::index_t>(i));
});
return r;
}
};
template <>
struct GeneratorTensor_2<ck::bf6x32_pk_t>
{
int min_value = 0;
int max_value = 1;
template <typename... Is>
ck::bf6x32_pk_t operator()(Is...)
{
ck::bf6x32_pk_t r;
ck::static_for<0, 32, 1>{}([&](auto i) {
float tmp = (std::rand() % (max_value - min_value)) + min_value;
r.pack(ck::type_convert<ck::bf6_t>(tmp), static_cast<ck::index_t>(i));
});
return r;
}
};
template <>
struct GeneratorTensor_2<ck::bhalf_t>
{
@@ -342,6 +424,46 @@ struct GeneratorTensor_3<ck::f4x2_pk_t>
}
};
template <>
struct GeneratorTensor_3<ck::f6x32_pk_t>
{
float min_value = 0;
float max_value = 1;
template <typename... Is>
ck::f6x32_pk_t operator()(Is...)
{
ck::f6x32_pk_t r;
ck::static_for<0, 32, 1>{}([&](auto i) {
float rnd = float(std::rand()) / float(RAND_MAX);
float fp32 = min_value + rnd * (max_value - min_value);
r.pack(ck::type_convert<ck::f6_t>(fp32), static_cast<ck::index_t>(i));
});
return r;
}
};
template <>
struct GeneratorTensor_3<ck::bf6x32_pk_t>
{
float min_value = 0;
float max_value = 1;
template <typename... Is>
ck::bf6x32_pk_t operator()(Is...)
{
ck::bf6x32_pk_t r;
ck::static_for<0, 32, 1>{}([&](auto i) {
float rnd = float(std::rand()) / float(RAND_MAX);
float fp32 = min_value + rnd * (max_value - min_value);
r.pack(ck::type_convert<ck::bf6_t>(fp32), static_cast<ck::index_t>(i));
});
return r;
}
};
template <typename T>
struct GeneratorTensor_4
{
@@ -360,6 +482,69 @@ struct GeneratorTensor_4
}
};
template <>
struct GeneratorTensor_4<ck::f4x2_pk_t>
{
std::mt19937 generator;
std::normal_distribution<float> distribution;
GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1)
: generator(seed), distribution(mean, stddev){};
template <typename... Is>
ck::f4x2_pk_t operator()(Is...)
{
float fp32_tmp0 = distribution(generator);
float fp32_tmp1 = distribution(generator);
return ck::f4x2_pk_t{ck::type_convert<ck::f4x2_t>(ck::float2_t{fp32_tmp0, fp32_tmp1})};
}
};
template <>
struct GeneratorTensor_4<ck::f6x32_pk_t>
{
std::mt19937 generator;
std::normal_distribution<float> distribution;
GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1)
: generator(seed), distribution(mean, stddev){};
template <typename... Is>
ck::f6x32_pk_t operator()(Is...)
{
ck::f6x32_pk_t r;
ck::static_for<0, 32, 1>{}([&](auto i) {
r.pack(ck::type_convert<ck::f6_t>(distribution(generator)),
static_cast<ck::index_t>(i));
});
return r;
}
};
template <>
struct GeneratorTensor_4<ck::bf6x32_pk_t>
{
std::mt19937 generator;
std::normal_distribution<float> distribution;
GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1)
: generator(seed), distribution(mean, stddev){};
template <typename... Is>
ck::bf6x32_pk_t operator()(Is...)
{
ck::bf6x32_pk_t r;
ck::static_for<0, 32, 1>{}([&](auto i) {
r.pack(ck::type_convert<ck::bf6_t>(distribution(generator)),
static_cast<ck::index_t>(i));
});
return r;
}
};
struct GeneratorTensor_Checkboard
{
template <typename... Ts>
@@ -405,6 +590,53 @@ struct GeneratorTensor_Sequential
}
};
template <ck::index_t Dim>
struct GeneratorTensor_Sequential<ck::f4x2_pk_t, Dim>
{
template <typename... Ts>
ck::f4x2_pk_t operator()(Ts... Xs) const
{
std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
float tmp = dims[Dim];
return ck::type_convert<ck::f4x2_t>(ck::float2_t(tmp));
}
};
template <ck::index_t Dim>
struct GeneratorTensor_Sequential<ck::f6x32_pk_t, Dim>
{
template <typename... Ts>
ck::f6x32_pk_t operator()(Ts... Xs) const
{
std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
float tmp = dims[Dim];
ck::f6x32_pk_t r;
ck::static_for<0, 32, 1>{}(
[&](auto i) { r.pack(ck::type_convert<ck::f6_t>(tmp), static_cast<ck::index_t>(i)); });
return r;
}
};
template <ck::index_t Dim>
struct GeneratorTensor_Sequential<ck::bf6x32_pk_t, Dim>
{
template <typename... Ts>
ck::bf6x32_pk_t operator()(Ts... Xs) const
{
std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
float tmp = dims[Dim];
ck::bf6x32_pk_t r;
ck::static_for<0, 32, 1>{}(
[&](auto i) { r.pack(ck::type_convert<ck::bf6_t>(tmp), static_cast<ck::index_t>(i)); });
return r;
}
};
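A hedged usage sketch for the new packed-type generators, assuming they stay constructible as declared above; each call fills all 32 lanes of one packed struct:
ck::f6x32_pk_t ones  = GeneratorTensor_1<ck::f6x32_pk_t>{1.0f}();        // every lane = 1.0
ck::bf6x32_pk_t uni  = GeneratorTensor_3<ck::bf6x32_pk_t>{-1.f, 1.f}();  // 32 uniform draws
ck::f6x32_pk_t gauss = GeneratorTensor_4<ck::f6x32_pk_t>{0.f, 1.f}();    // 32 normal draws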
template <typename T, size_t NumEffectiveDim = 2>
struct GeneratorTensor_Diagonal
{

View File

@@ -498,7 +498,7 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
reg_a,
reg_b,
reg_c.template AsType<float16_t>()[Number<0>{}],
0, // cbsz
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
0, // blgp
0,
0,
@@ -511,6 +511,28 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
#endif
}
template <class FloatC>
__device__ static void Run(const bf8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float16_t>()[Number<0>{}],
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1, // blgp
0,
0,
0,
0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
{
@@ -536,6 +558,62 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const f6x32_t& reg_a, const f6x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
using arg_type = int32x8_t;
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
reg_c.template AsType<float16_t>()[Number<0>{}],
2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
2, // blgp
0, // OPSEL
0,
0, // OPSEL
0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const bf6x32_t& reg_a, const bf6x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
using arg_type = int32x8_t;
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
reg_c.template AsType<float16_t>()[Number<0>{}],
3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
3, // blgp
0, // OPSEL
0,
0, // OPSEL
0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
@@ -583,6 +661,43 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
#endif
}
template <class FloatC>
__device__ static void Run(const bf8x32_t& reg_a,
const int32_t& scale_a,
const bf8x32_t& reg_b,
const int32_t& scale_b,
FloatC& reg_c)
{
#if defined(__gfx950__)
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float16_t>()[Number<0>{}],
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1, // blgp
0, // OPSEL
scale_a,
0, // OPSEL
scale_b);
// XXX: Note on the scale_a and scale_b parameters:
// If the compiler detects that one or both scales are compile-time constants, it treats
// each such constant as an F32 constant. I.e., if scale_a was at some point declared as
// `e8m0_bexp_t a_scale{1.0f}`, the instruction only works if the scale_a parameter is
// assigned the value `bit_cast<int32_t>(static_cast<float>(a_scale))`.
// XXX: Note on the OPSEL parameters: the instruction always takes byte 0 as the scale
// value even when OPSEL is set otherwise.
#else
ignore = reg_a;
ignore = scale_a;
ignore = reg_b;
ignore = scale_b;
ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const bf8x32_t& reg_a,
const int32_t& scale_a,
@@ -620,6 +735,74 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
#endif
}
template <class FloatC>
__device__ static void Run(const f6x32_t& reg_a,
const int32_t scale_a,
const f6x32_t& reg_b,
const int32_t scale_b,
FloatC& reg_c)
{
#if defined(__gfx950__)
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
using arg_type = int32x8_t;
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
reg_c.template AsType<float16_t>()[Number<0>{}],
2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
2, // blgp
0, // OPSEL
scale_a,
0, // OPSEL
scale_b);
#else
ignore = reg_a;
ignore = scale_a;
ignore = reg_b;
ignore = scale_b;
ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const bf6x32_t& reg_a,
const int32_t scale_a,
const bf6x32_t& reg_b,
const int32_t scale_b,
FloatC& reg_c)
{
#if defined(__gfx950__)
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
using arg_type = int32x8_t;
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
reg_c.template AsType<float16_t>()[Number<0>{}],
3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
3, // blgp
0, // OPSEL
scale_a,
0, // OPSEL
scale_b);
#else
ignore = reg_a;
ignore = scale_a;
ignore = reg_b;
ignore = scale_b;
ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const f4x32_t& reg_a,
const int32_t scale_a,
@@ -639,7 +822,7 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
reg_c.template AsType<float16_t>()[Number<0>{}],
4, // cbsz
4, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
4, // blgp
0, // OPSEL
scale_a,
@@ -748,6 +931,101 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
#endif
}
template <class FloatC>
__device__ static void Run(const bf8x32_t& reg_a,
const int32_t& scale_a,
const f8x32_t& reg_b,
const int32_t& scale_b,
FloatC& reg_c)
{
#if defined(__gfx950__)
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float4_t>()[Number<0>{}],
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
0, // blgp
0, // OPSEL
scale_a,
0, // OPSEL
scale_b);
#else
ignore = reg_a;
ignore = scale_a;
ignore = reg_b;
ignore = scale_b;
ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const f6x32_t& reg_a,
const int32_t scale_a,
const f6x32_t& reg_b,
const int32_t scale_b,
FloatC& reg_c)
{
#if defined(__gfx950__)
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
using arg_type = int32x8_t;
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
reg_c.template AsType<float4_t>()[Number<0>{}],
2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
2, // blgp
0, // OPSEL
scale_a,
0, // OPSEL
scale_b);
#else
ignore = reg_a;
ignore = scale_a;
ignore = reg_b;
ignore = scale_b;
ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const bf6x32_t& reg_a,
const int32_t scale_a,
const bf6x32_t& reg_b,
const int32_t scale_b,
FloatC& reg_c)
{
#if defined(__gfx950__)
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
using arg_type = int32x8_t;
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
reg_c.template AsType<float4_t>()[Number<0>{}],
3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
3, // blgp
0, // OPSEL
scale_a,
0, // OPSEL
scale_b);
#else
ignore = reg_a;
ignore = scale_a;
ignore = reg_b;
ignore = scale_b;
ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const f4x32_t& reg_a,
const int32_t scale_a,
@@ -778,35 +1056,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
ignore = reg_b;
ignore = scale_b;
ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const bf8x32_t& reg_a,
const int32_t& scale_a,
const f8x32_t& reg_b,
const int32_t& scale_b,
FloatC& reg_c)
{
#if defined(__gfx950__)
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float4_t>()[Number<0>{}],
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
0, // blgp
0, // OPSEL
scale_a,
0, // OPSEL
scale_b);
#else
ignore = reg_a;
ignore = scale_a;
ignore = reg_b;
ignore = scale_b;
ignore = reg_c;
#endif
}
};
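Per the XXX note above, a hedged sketch of how a caller would prepare the scale arguments so a constant-folded e8m0 scale still reaches the builtin as an F32 bit pattern (assuming ck::bit_cast as used elsewhere in the library):
ck::e8m0_bexp_t a_scale{1.0f};
const int32_t scale_arg = ck::bit_cast<int32_t>(static_cast<float>(a_scale));
// pass scale_arg as the scale_a / scale_b parameter of Run(...)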
@@ -833,7 +1082,7 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
reg_a,
reg_b,
reg_c.template AsType<float4_t>()[Number<0>{}],
0, // cbsz
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
0, // blgp
0,
0,
@@ -846,6 +1095,29 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
#endif
}
template <class FloatC>
__device__ static void Run(const bf8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float4_t>()[Number<0>{}],
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1, // blgp
0,
0,
0,
0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
{
@@ -870,6 +1142,60 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const f6x32_t& reg_a, const f6x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
using arg_type = int32x8_t;
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
reg_c.template AsType<float4_t>()[Number<0>{}],
2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
2, // blgp
0, // OPSEL
0,
0, // OPSEL
0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const bf6x32_t& reg_a, const bf6x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
using arg_type = int32x8_t;
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
reg_c.template AsType<float4_t>()[Number<0>{}],
3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
3, // blgp
0, // OPSEL
0,
0, // OPSEL
0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
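The cbsz/blgp immediates hard-coded by each overload above are simply the operand-format selectors of the unified f8f6f4 MFMA builtins. A hedged helper (not part of the diff) that makes the mapping explicit:
// Format codes per the comment repeated above; cbsz selects the A-operand
// format and blgp the B-operand format of the f8f6f4 builtins.
enum class F8F6F4Fmt : int
{
    FP8_E4M3 = 0, FP8_E5M2 = 1, FP6_E2M3 = 2, FP6_E3M2 = 3, FP4_E2M1 = 4
};

constexpr int format_code(F8F6F4Fmt f) { return static_cast<int>(f); }

static_assert(format_code(F8F6F4Fmt::FP6_E3M2) == 3); // matches the bf6x32_t overloads
static_assert(format_code(F8F6F4Fmt::FP4_E2M1) == 4); // matches the f4x32_t overloads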

View File

@@ -32,8 +32,14 @@ using f4_t = unsigned _BitInt(4);
using f6_t = _BitInt(6); // e2m3 format
using bf6_t = unsigned _BitInt(6); // e3m2 format
// scalar_type
template <typename TV>
struct scalar_type;
struct f4x2_pk_t
{
static constexpr int packed_size = 2;
using type = uint8_t;
type data;
__host__ __device__ f4x2_pk_t() : data{type{}} {}
@@ -55,269 +61,82 @@ struct f4x2_pk_t
}
};
struct f6x16_pk_t
template <typename BitType, index_t pk_size>
struct f6_pk_t
{
// store 16 elements of f6_t in an array of 3 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 3>;
type data;
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
f6x16_pk_t() : data{type{}} {}
f6x16_pk_t(type init) : data{init} {}
using element_type = uint32_t; // fundamental type used for packed storage
template <index_t I>
__host__ __device__ inline f6_t unpack(Number<I>)
static constexpr index_t packed_size = pk_size;
static constexpr index_t num_bits_elem = 6;
static constexpr index_t num_bits_vec_elem = sizeof(element_type) * CHAR_BIT;
static_assert((packed_size * num_bits_elem) % num_bits_vec_elem == 0,
"Packed elements must fit exactly into the element storage.");
static constexpr index_t vector_size = (packed_size * num_bits_elem) / num_bits_vec_elem;
using storage_type = StaticallyIndexedArray_v2<element_type, vector_size>;
storage_type data; // packed data
using type = f6_pk_t<BitType, packed_size>;
__host__ __device__ constexpr f6_pk_t() : data{} {}
__host__ __device__ constexpr f6_pk_t(storage_type init) : data{init} {}
template <typename T, typename = enable_if_t<scalar_type<T>::vector_size == packed_size>>
__host__ __device__ f6_pk_t(const T& v) : data{}
{
static_assert(I < 16, "Index out of range for 16 f6_t elements.");
static_for<0, packed_size, 1>{}(
[&](auto i) { pack(v[static_cast<index_t>(i)], static_cast<index_t>(i)); });
}
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 3;
constexpr int bit_pos = I * num_bits_elem;
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
template <typename T>
__host__ __device__ void pack(const T x, const index_t i)
{
static_assert(is_integral<T>::value || is_same_v<T, BitType>,
"T must be an integral type or the packed bit type.");
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
uint32_t bits = static_cast<uint32_t>(x) & 0x3F;
const int bit_pos = i * num_bits_elem;
const int arr_index = bit_pos / num_bits_vec_elem;
const int bit_offset = bit_pos % num_bits_vec_elem;
const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = data.data_[arr_index];
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
data.data_[arr_index] = old_value;
// if it crosses into the next block, shift the remainder
if(overhang > 0 && (arr_index + 1) < vector_size)
{
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
uint32_t next_value = data.data_[arr_index + 1];
next_value |= (bits >> (num_bits_elem - overhang));
data.data_[arr_index + 1] = next_value;
}
}
__host__ __device__ static inline BitType unpack(const type& pk, const index_t i)
{
const int bit_pos = i * num_bits_elem;
const int arr_idx = bit_pos / num_bits_vec_elem;
const int bit_offset = bit_pos % num_bits_vec_elem;
const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t bits = pk.data.data_[arr_idx] >> bit_offset;
if(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (pk.data.data_[arr_idx + 1] & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
}
return static_cast<f6_t>(bits & 0x3F);
return static_cast<BitType>(bits & 0x3F);
}
__host__ __device__ inline type pack(const test_vec_t& x)
{
type packed{};
// for each of the 16 f6_t values, place its 6 bits in the correct position
ck::static_for<0, 16, 1>{}([&](auto i) {
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 3;
constexpr int bit_pos = i * num_bits_elem;
constexpr int arr_index = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = packed.At(Number<arr_index>{});
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
packed.At(Number<arr_index>{}) = old_value;
// if it crosses into the next block, shift the remainder
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
{
uint32_t next_value = packed.At(Number<arr_index + 1>{});
next_value |= (bits >> (num_bits_elem - overhang));
packed.At(Number<arr_index + 1>{}) = next_value;
}
});
return packed;
}
__host__ __device__ inline BitType unpack(const index_t i) const { return unpack(*this, i); }
};
struct f6x32_pk_t
{
// store 32 elements of f6_t in an array of 6 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 6>;
type data;
typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));
f6x32_pk_t() : data{type{}} {}
f6x32_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline f6_t unpack(Number<I>)
{
static_assert(I < 32, "Index out of range for 32 f6_t elements.");
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 6;
constexpr int bit_pos = I * num_bits_elem;
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
}
return static_cast<f6_t>(bits & 0x3F);
}
__host__ __device__ inline type pack(const test_vec_t& x)
{
type packed{};
// for each of the 32 f6_t values, place its 6 bits in the correct position
ck::static_for<0, 32, 1>{}([&](auto i) {
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 6;
constexpr int bit_pos = i * num_bits_elem;
constexpr int arr_index = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = packed.At(Number<arr_index>{});
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
packed.At(Number<arr_index>{}) = old_value;
// if it crosses into the next block, shift the remainder
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
{
uint32_t next_value = packed.At(Number<arr_index + 1>{});
next_value |= (bits >> (num_bits_elem - overhang));
packed.At(Number<arr_index + 1>{}) = next_value;
}
});
return packed;
}
};
struct bf6x16_pk_t
{
// store 16 elements of bf6_t in an array of 3 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 3>;
type data;
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
bf6x16_pk_t() : data{type{}} {}
bf6x16_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline bf6_t unpack(Number<I>)
{
static_assert(I < 16, "Index out of range for 16 bf6_t elements.");
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 3;
constexpr int bit_pos = I * num_bits_elem;
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
}
return static_cast<bf6_t>(bits & 0x3F);
}
__host__ __device__ inline type pack(const test_vec_t& x)
{
type packed{};
// for each of the 16 bf6_t values, place its 6 bits in the correct position
ck::static_for<0, 16, 1>{}([&](auto i) {
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 3;
constexpr int bit_pos = i * num_bits_elem;
constexpr int arr_index = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = packed.At(Number<arr_index>{});
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
packed.At(Number<arr_index>{}) = old_value;
// if it crosses into the next block, shift the remainder
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
{
uint32_t next_value = packed.At(Number<arr_index + 1>{});
next_value |= (bits >> (num_bits_elem - overhang));
packed.At(Number<arr_index + 1>{}) = next_value;
}
});
return packed;
}
};
struct bf6x32_pk_t
{
// store 32 elements of bf6_t in an array of 6 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 6>;
type data;
typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));
bf6x32_pk_t() : data{type{}} {}
bf6x32_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline bf6_t unpack(Number<I>)
{
static_assert(I < 32, "Index out of range for 32 bf6_t elements.");
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 6;
constexpr int bit_pos = I * num_bits_elem;
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
}
return static_cast<bf6_t>(bits & 0x3F);
}
__host__ __device__ inline type pack(const test_vec_t& x)
{
type packed{};
// for each of the 32 bf6_t values, place its 6 bits in the correct position
ck::static_for<0, 32, 1>{}([&](auto i) {
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 6;
constexpr int bit_pos = i * num_bits_elem;
constexpr int arr_index = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = packed.At(Number<arr_index>{});
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
packed.At(Number<arr_index>{}) = old_value;
// if it crosses into the next block, shift the remainder
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
{
uint32_t next_value = packed.At(Number<arr_index + 1>{});
next_value |= (bits >> (num_bits_elem - overhang));
packed.At(Number<arr_index + 1>{}) = next_value;
}
});
return packed;
}
};
using f6x16_pk_t = f6_pk_t<f6_t, 16>;
using f6x32_pk_t = f6_pk_t<f6_t, 32>;
using bf6x16_pk_t = f6_pk_t<bf6_t, 16>;
using bf6x32_pk_t = f6_pk_t<bf6_t, 32>;
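A hedged round-trip sketch for the consolidated f6_pk_t: pack() masks each value to its low 6 bits, and unpack() reads it back from the same lane.
#include <cassert>

__host__ void f6_pk_round_trip_sketch()
{
    ck::f6x32_pk_t pk{};
    for(ck::index_t i = 0; i < ck::f6x32_pk_t::packed_size; ++i)
        pk.pack(static_cast<ck::f6_t>(i & 0x3F), i); // lane i gets the low 6 bits of i
    for(ck::index_t i = 0; i < ck::f6x32_pk_t::packed_size; ++i)
        assert(pk.unpack(i) == static_cast<ck::f6_t>(i & 0x3F));
}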
// custom data type - pack int4 data
struct pk_i4_t
@@ -335,15 +154,14 @@ inline constexpr auto next_pow2(uint32_t x)
}
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
// native types: bool, f4_t, f6_t, bf6_t
// native types: bool
template <typename T>
inline constexpr bool is_native_type()
{
return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value ||
is_same<T, bhalf_t>::value || is_same<T, int32_t>::value || is_same<T, int8_t>::value ||
is_same<T, uint8_t>::value || is_same<T, f8_fnuz_t>::value ||
is_same<T, bf8_fnuz_t>::value || is_same<T, bool>::value || is_same<T, f4_t>::value ||
is_same<T, f6_t>::value || is_same<T, bf6_t>::value;
is_same<T, bhalf_t>::value || is_same<T, int32_t>::value ||
is_same<T, uint32_t>::value || is_same<T, int8_t>::value || is_same<T, uint8_t>::value ||
is_same<T, f8_fnuz_t>::value || is_same<T, bf8_fnuz_t>::value || is_same<T, bool>::value;
}
// scalar_type
@@ -484,6 +302,106 @@ struct scalar_type<bool>
static constexpr index_t vector_size = 1;
};
// Default behavior for types that do not need special handling
template <typename T>
struct packed_type
{
using type = T;
static constexpr index_t packed_size = 1; // number of packed elements
};
template <>
struct packed_type<int4_t>
{
using type = pk_i4_t;
static constexpr index_t packed_size = 2; // number of packed elements
};
template <>
struct packed_type<f4_t>
{
using type = f4x2_pk_t;
static constexpr index_t packed_size = 2; // number of packed elements
};
template <>
struct packed_type<f6_t>
{
using type = f6x32_pk_t;
static constexpr index_t packed_size = f6x32_pk_t::packed_size; // number of packed elements
};
template <>
struct packed_type<bf6_t>
{
using type = bf6x32_pk_t;
static constexpr index_t packed_size = bf6x32_pk_t::packed_size; // number of packed elements
};
template <typename T>
using packed_type_t = typename packed_type<T>::type;
// Check if the type has packed type specialization
template <typename T>
inline constexpr bool has_packed_type_v = !is_same_v<packed_type_t<T>, T>;
template <typename T>
struct element_type
{
private:
static constexpr auto get_element_type()
{
using U = remove_cvref_t<T>;
if constexpr(is_same_v<U, pk_i4_t>)
return int4_t{};
else if constexpr(is_same_v<U, f4x2_pk_t>)
return f4_t{};
else if constexpr(is_same_v<U, f6x16_pk_t>)
return f6_t{};
else if constexpr(is_same_v<U, bf6x16_pk_t>)
return bf6_t{};
else if constexpr(is_same_v<U, f6x32_pk_t>)
return f6_t{};
else if constexpr(is_same_v<U, bf6x32_pk_t>)
return bf6_t{};
else
return T{};
}
public:
using type = decltype(get_element_type());
};
template <typename T>
using element_type_t = typename element_type<T>::type;
template <typename T>
inline constexpr bool is_packed_type_v =
has_packed_type_v<element_type_t<T>> && is_same_v<T, packed_type_t<element_type_t<T>>>;
template <typename T>
struct packed_size
{
private:
static constexpr auto get_packed_size()
{
using U = remove_cvref_t<T>;
if constexpr(is_packed_type_v<U>)
return Number<packed_type<element_type_t<U>>::packed_size>{};
else
return Number<packed_type<U>::packed_size>{};
}
public:
using type = decltype(get_packed_size());
static constexpr auto value = get_packed_size();
};
template <typename T>
using packed_size_t = typename packed_size<T>::type;
template <typename T>
inline constexpr index_t packed_size_v = packed_size<T>::value;
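A hedged illustration of how the new traits compose, assuming the definitions above:
static_assert(ck::packed_size_v<float> == 1);           // default: one element per value
static_assert(ck::packed_size_v<ck::f4x2_pk_t> == 2);   // two f4_t per byte
static_assert(ck::packed_size_v<ck::f6x32_pk_t> == 32); // 32 f6_t per packed struct
static_assert(ck::is_packed_type_v<ck::f6x32_pk_t>);    // the carrier is a packed type
static_assert(!ck::is_packed_type_v<ck::f6_t>);         // the 6-bit element itself is not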
#if defined(_WIN32)
using int64_t = long long;
#else

View File

@@ -365,6 +365,88 @@ struct vector_type<T, 5, typename ck::enable_if_t<is_native_type<T>()>>
}
};
template <typename T>
struct vector_type<T, 6, typename ck::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d3_t __attribute__((ext_vector_type(3)));
typedef T d6_t __attribute__((ext_vector_type(6)));
using type = d6_t;
union
{
d6_t d6_;
StaticallyIndexedArray<d1_t, 6> d1x6_;
StaticallyIndexedArray<d2_t, 3> d2x3_;
StaticallyIndexedArray<d3_t, 2> d3x2_;
StaticallyIndexedArray<d6_t, 1> d6x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d3_t>::value || is_same<X, d6_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x6_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x3_;
}
else if constexpr(is_same<X, d3_t>::value)
{
return data_.d3x2_;
}
else if constexpr(is_same<X, d6_t>::value)
{
return data_.d6x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d3_t>::value || is_same<X, d6_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x6_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x3_;
}
else if constexpr(is_same<X, d3_t>::value)
{
return data_.d3x2_;
}
else if constexpr(is_same<X, d6_t>::value)
{
return data_.d6x1_;
}
else
{
return err;
}
}
};
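A hedged usage sketch of the new 6-wide specialization, which is what lets the mfma wrappers bit_cast f6x32_t payloads into int32x6_t and splat them into the builtins' 8-dword arguments:
ck::vector_type<int32_t, 6> v{};             // zero-initialized d6_t payload
v.AsType<int32_t>()(ck::Number<0>{}) = 0x3F; // write dword 0 through the d1x6_ view
ck::int32x6_t raw = v.AsType<ck::int32x6_t>()[ck::Number<0>{}]; // whole 6-dword vector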
template <typename T>
struct vector_type<T, 7, typename ck::enable_if_t<is_native_type<T>()>>
{
@@ -1221,25 +1303,25 @@ struct nnvb_data_t_selector<e8m0_bexp_t>
template <>
struct nnvb_data_t_selector<f6x16_pk_t>
{
using type = f6x16_pk_t::type;
using type = f6x16_pk_t::storage_type;
};
template <>
struct nnvb_data_t_selector<f6x32_pk_t>
{
using type = f6x32_pk_t::type;
using type = f6x32_pk_t::storage_type;
};
template <>
struct nnvb_data_t_selector<bf6x16_pk_t>
{
using type = bf6x16_pk_t::type;
using type = bf6x16_pk_t::storage_type;
};
template <>
struct nnvb_data_t_selector<bf6x32_pk_t>
{
using type = bf6x32_pk_t::type;
using type = bf6x32_pk_t::storage_type;
};
template <>
@@ -1406,12 +1488,23 @@ struct non_native_vector_base<T, N, ck::enable_if_t<sizeof(T) == 12 || sizeof(T)
};
template <typename T, index_t N>
struct scalar_type<non_native_vector_base<T, N>>
struct scalar_type<non_native_vector_base<
T,
N,
ck::enable_if_t<sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8>>>
{
using type = typename non_native_vector_base<T, N>::data_t;
static constexpr index_t vector_size = N;
};
template <typename T, index_t N>
struct scalar_type<
non_native_vector_base<T, N, ck::enable_if_t<sizeof(T) == 12 || sizeof(T) == 24>>>
{
using type = typename non_native_vector_base<T, N>::element_t;
static constexpr index_t vector_size = N * non_native_vector_base<T, N>::size_factor;
};
// non-native vector_type implementation
template <typename T>
struct vector_type<T, 1, typename ck::enable_if_t<!is_native_type<T>()>>
@@ -2025,6 +2118,7 @@ using bhalf32_t = typename vector_type<bhalf_t, 32>::type;
// i32
using int32x2_t = typename vector_type<int32_t, 2>::type;
using int32x4_t = typename vector_type<int32_t, 4>::type;
using int32x6_t = typename vector_type<int32_t, 6>::type;
using int32x8_t = typename vector_type<int32_t, 8>::type;
using int32x16_t = typename vector_type<int32_t, 16>::type;
using int32x32_t = typename vector_type<int32_t, 32>::type;

View File

@@ -66,7 +66,7 @@ __host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value)
: NumericUtils<f4_t>::data_max_positive_normal_mask;
}
if(std::abs(value) > NumericLimits<f4_t>::Max()) // covers inf case as well
if(std::abs(value) > NumericLimits<f4_t>::DataMaxNorm()) // covers inf case as well
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
@@ -74,8 +74,8 @@ __host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value)
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<f4_t>::DataMinSubnorm())
return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
: NumericUtils<f4_t>::positive_zero_mask;
return sign ? NumericUtils<f4_t>::negative_zero_mask
: NumericUtils<f4_t>::positive_zero_mask;
return res;
}
@@ -91,7 +91,7 @@ __host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
if(std::abs(value) > NumericLimits<f4_t>::Max()) // covers inf case as well
if(std::abs(value) > NumericLimits<f4_t>::DataMaxNorm()) // covers inf case as well
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
@@ -99,8 +99,8 @@ __host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<f4_t>::DataMinSubnorm())
return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
: NumericUtils<f4_t>::positive_zero_mask;
return sign ? NumericUtils<f4_t>::negative_zero_mask
: NumericUtils<f4_t>::positive_zero_mask;
return res;
}
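A hedged note on the `value < 0` to `sign` change above: the two predicates differ exactly for negative zero, where `value < 0` is false but the sign bit is set, so the underflow path now returns the correctly signed zero mask. Standalone illustration:
#include <cmath>
#include <cstdio>

int main()
{
    const float value = -0.0f;
    std::printf("value < 0: %d\n", value < 0);           // 0: compares equal to +0.0f
    std::printf("signbit:   %d\n", std::signbit(value)); // 1: the sign bit is set
    return 0;
}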

View File

@@ -201,7 +201,7 @@ __host__ __device__ inline f6_t sat_convert_to_type<f6_t>(float value)
: NumericUtils<f6_t>::data_max_positive_normal_mask;
}
if(std::abs(value) > NumericLimits<f6_t>::Max()) // covers inf case as well
if(std::abs(value) > NumericLimits<f6_t>::DataMaxNorm()) // covers inf case as well
return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
: NumericUtils<f6_t>::data_max_positive_normal_mask;
@@ -239,7 +239,7 @@ __host__ __device__ inline bf6_t sat_convert_to_type<bf6_t>(float value)
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
}
if(std::abs(value) > NumericLimits<bf6_t>::Max()) // covers inf case as well
if(std::abs(value) > NumericLimits<bf6_t>::DataMaxNorm()) // covers inf case as well
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
@@ -274,7 +274,7 @@ __host__ __device__ inline f6_t sat_convert_to_type_sr<f6_t>(float value, uint32
return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
: NumericUtils<f6_t>::data_max_positive_normal_mask;
if(std::abs(value) > NumericLimits<f6_t>::Max()) // covers inf case as well
if(std::abs(value) > NumericLimits<f6_t>::DataMaxNorm()) // covers inf case as well
return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
: NumericUtils<f6_t>::data_max_positive_normal_mask;
@@ -308,7 +308,7 @@ __host__ __device__ inline bf6_t sat_convert_to_type_sr<bf6_t>(float value, uint
if(std::isnan(value))
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
if(std::abs(value) > NumericLimits<bf6_t>::Max()) // covers inf case as well
if(std::abs(value) > NumericLimits<bf6_t>::DataMaxNorm()) // covers inf case as well
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
: NumericUtils<bf6_t>::data_max_positive_normal_mask;

View File

@@ -89,6 +89,14 @@ struct ReferenceGemm : public device::BaseOperator
v_a = type_convert<ComputeTypeA>(
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{})));
}
else if constexpr(is_same_v<ADataType, f6x16_pk_t> ||
is_same_v<ADataType, bf6x16_pk_t> ||
is_same_v<ADataType, f6x32_pk_t> ||
is_same_v<ADataType, bf6x32_pk_t>)
{
v_a = type_convert<ComputeTypeA>(
arg.a_m_k_(m, k).unpack(k % ADataType::packed_size));
}
else
{
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
@@ -115,6 +123,14 @@ struct ReferenceGemm : public device::BaseOperator
v_b = type_convert<ComputeTypeB>(
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{})));
}
else if constexpr(is_same_v<BDataType, f6x16_pk_t> ||
is_same_v<BDataType, bf6x16_pk_t> ||
is_same_v<BDataType, f6x32_pk_t> ||
is_same_v<BDataType, bf6x32_pk_t>)
{
v_b = type_convert<ComputeTypeB>(
arg.b_k_n_(k, n).unpack(k % BDataType::packed_size));
}
else
{
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
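The reference kernels above pair with the Tensor changes earlier in this commit: operator() divides the flattened offset by packed_size to pick the packed word, and unpack(k % packed_size) then selects the lane. A hedged sketch of the index split:
constexpr ck::index_t packed_size = ck::f6x32_pk_t::packed_size; // 32 lanes per word
const ck::index_t k    = 70;
const ck::index_t word = k / packed_size; // which f6x32_pk_t holds element k -> 2
const ck::index_t lane = k % packed_size; // which 6-bit lane inside that word -> 6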

View File

@@ -105,6 +105,16 @@ struct ReferenceMXGemm : public device::BaseOperator
type_convert<ComputeTypeA>(
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
}
else if constexpr(is_same_v<ADataType, f6x16_pk_t> ||
is_same_v<ADataType, bf6x16_pk_t> ||
is_same_v<ADataType, f6x32_pk_t> ||
is_same_v<ADataType, bf6x32_pk_t>)
{
a_m_k_scaled(m, k) =
type_convert<ComputeTypeA>(
arg.a_m_k_(m, k).unpack(k % ADataType::packed_size)) *
type_convert<ComputeTypeA>(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
}
else
{
a_m_k_scaled(m, k) =
@@ -134,6 +144,16 @@ struct ReferenceMXGemm : public device::BaseOperator
type_convert<ComputeTypeB>(
arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
}
else if constexpr(is_same_v<BDataType, f6x16_pk_t> ||
is_same_v<BDataType, bf6x16_pk_t> ||
is_same_v<BDataType, f6x32_pk_t> ||
is_same_v<BDataType, bf6x32_pk_t>)
{
b_k_n_scaled(k, n) =
type_convert<ComputeTypeB>(
arg.b_k_n_(k, n).unpack(k % BDataType::packed_size)) *
type_convert<ComputeTypeB>(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
}
else
{
b_k_n_scaled(k, n) =

View File

@@ -4,6 +4,7 @@
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/utility/env.hpp"
#include "ck/utility/scaled_type_convert.hpp"
using ck::bf6_convert_rne;
@@ -41,6 +42,11 @@ TEST(BF6, ConvertFP32Nearest)
ASSERT_NEAR(max_bf6,
type_convert<float>(bf6_convert_rne(std::numeric_limits<float>::infinity())),
0.0f);
// convert float +/-30 to bf6 and back, check if clipped to +/-max_bf6
ASSERT_NEAR(-max_bf6, type_convert<float>(bf6_convert_rne(-30.0f)), 0.0f);
ASSERT_NEAR(max_bf6, type_convert<float>(bf6_convert_rne(30.0f)), 0.0f);
// convert float value less than bf6 subnorm to bf6 and back, check if equal to 0.0
float less_than_subnorm = 0.03125f;
ASSERT_NEAR(0.0f, type_convert<float>(bf6_convert_rne(less_than_subnorm)), 0.0f);
@@ -266,21 +272,18 @@ TEST(BF6, TestAsType16x1)
vector_type<bf6x16_pk_t, vector_size> right_vec;
// check default CTOR
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
right_vec.template AsType<bf6x16_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
0);
ASSERT_EQ(right_vec.template AsType<bf6x16_pk_t>()(Number<0>{}).unpack(i), 0);
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<bf6x16_pk_t>()(Number<i>{}) = bf6x16_pk_t{}.pack(test_vec);
right_vec.template AsType<bf6x16_pk_t>()(Number<i>{}) = bf6x16_pk_t{test_vec};
});
// copy the vector
vector_type<bf6x16_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
left_vec.template AsType<bf6x16_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
static_cast<bf6_t>(test_vec[static_cast<int>(i)]));
ASSERT_EQ(left_vec.template AsType<bf6x16_pk_t>()(Number<0>{}).unpack(i),
static_cast<bf6_t>(test_vec[static_cast<int>(i)]));
});
}
@@ -329,23 +332,23 @@ TEST(BF6, TestAsType16x2)
// check default CTOR
ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
ASSERT_EQ(right_vec.template AsType<bf6x16_pk_t>()(Number<idx_vector>{})
.template unpack<>(Number<idx_element>{}),
0);
ASSERT_EQ(
right_vec.template AsType<bf6x16_pk_t>()(Number<idx_vector>{}).unpack(idx_element),
0);
});
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<bf6x16_pk_t>()(Number<i>{}) = bf6x16_pk_t{}.pack(test_vec[i]);
right_vec.template AsType<bf6x16_pk_t>()(Number<i>{}) = bf6x16_pk_t{test_vec[i]};
});
// copy the vector
vector_type<bf6x16_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
ASSERT_EQ(left_vec.template AsType<bf6x16_pk_t>()(Number<idx_vector>{})
.template unpack<>(Number<idx_element>{}),
static_cast<bf6_t>(test_vec[idx_vector][static_cast<int>(idx_element)]));
ASSERT_EQ(
left_vec.template AsType<bf6x16_pk_t>()(Number<idx_vector>{}).unpack(idx_element),
static_cast<bf6_t>(test_vec[idx_vector][static_cast<int>(idx_element)]));
});
});
}
@@ -369,20 +372,86 @@ TEST(BF6, TestAsType32x1)
vector_type<bf6x32_pk_t, vector_size> right_vec;
// check default CTOR
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
right_vec.template AsType<bf6x32_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
0);
ASSERT_EQ(right_vec.template AsType<bf6x32_pk_t>()(Number<0>{}).unpack(i), 0);
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<bf6x32_pk_t>()(Number<i>{}) = bf6x32_pk_t{}.pack(test_vec);
right_vec.template AsType<bf6x32_pk_t>()(Number<i>{}) = bf6x32_pk_t{test_vec};
});
// copy the vector
vector_type<bf6x32_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
left_vec.template AsType<bf6x32_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
static_cast<bf6_t>(test_vec[static_cast<int>(i)]));
ASSERT_EQ(left_vec.template AsType<bf6x32_pk_t>()(Number<0>{}).unpack(i),
static_cast<bf6_t>(test_vec[static_cast<int>(i)]));
});
}
TEST(BF6, TestAllValues)
{
constexpr std::array<float, 64> e3m2ValuesOCP = {
// clang-format off
0.0000000000, 0.0625000000, 0.1250000000, 0.1875000000,
0.2500000000, 0.3125000000, 0.3750000000, 0.4375000000,
0.5000000000, 0.6250000000, 0.7500000000, 0.8750000000,
1.0000000000, 1.2500000000, 1.5000000000, 1.7500000000,
2.0000000000, 2.5000000000, 3.0000000000, 3.5000000000,
4.0000000000, 5.0000000000, 6.0000000000, 7.0000000000,
8.0000000000, 10.0000000000, 12.0000000000, 14.0000000000,
16.0000000000, 20.0000000000, 24.0000000000, 28.0000000000,
-0.0000000000, -0.0625000000, -0.1250000000, -0.1875000000,
-0.2500000000, -0.3125000000, -0.3750000000, -0.4375000000,
-0.5000000000, -0.6250000000, -0.7500000000, -0.8750000000,
-1.0000000000, -1.2500000000, -1.5000000000, -1.7500000000,
-2.0000000000, -2.5000000000, -3.0000000000, -3.5000000000,
-4.0000000000, -5.0000000000, -6.0000000000, -7.0000000000,
-8.0000000000, -10.0000000000, -12.0000000000, -14.0000000000,
-16.0000000000, -20.0000000000, -24.0000000000, -28.0000000000
// clang-format on
};
constexpr uint8_t e3m2BitsOCP[] = {
// clang-format off
0b000000, 0b000001, 0b000010, 0b000011,
0b000100, 0b000101, 0b000110, 0b000111,
0b001000, 0b001001, 0b001010, 0b001011,
0b001100, 0b001101, 0b001110, 0b001111,
0b010000, 0b010001, 0b010010, 0b010011,
0b010100, 0b010101, 0b010110, 0b010111,
0b011000, 0b011001, 0b011010, 0b011011,
0b011100, 0b011101, 0b011110, 0b011111,
0b100000, 0b100001, 0b100010, 0b100011,
0b100100, 0b100101, 0b100110, 0b100111,
0b101000, 0b101001, 0b101010, 0b101011,
0b101100, 0b101101, 0b101110, 0b101111,
0b110000, 0b110001, 0b110010, 0b110011,
0b110100, 0b110101, 0b110110, 0b110111,
0b111000, 0b111001, 0b111010, 0b111011,
0b111100, 0b111101, 0b111110, 0b111111
// clang-format on
};
const bool ck_logging = ck::EnvIsEnabled(CK_ENV(CK_LOGGING));
if(ck_logging)
printf("BF6 Table\n");
ck::static_for<0, 64, 1>{}([&](auto i) {
float fp = type_convert<float>(bf6_t(e3m2BitsOCP[i]));
ASSERT_EQ(fp, e3m2ValuesOCP[i]);
bf6_t bf6 = type_convert<bf6_t>(e3m2ValuesOCP[i]);
ASSERT_EQ(bf6 & 0x3F, e3m2BitsOCP[i] & 0x3F);
if(ck_logging)
{
// Print the binary representation
printf("Bits: 0b");
for(int j = 5; j >= 0; --j)
{
printf("%c", (e3m2BitsOCP[i] & (1 << j)) ? '1' : '0');
}
printf(", 0x%02X, Value: %f\n", e3m2BitsOCP[i], e3m2ValuesOCP[i]);
}
});
}
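The three TestAllValues tables (e3m2 here, e2m1 and e2m3 in the FP4/FP6 tests below) all follow the OCP micro-float layout: a sign bit, an exponent field with bias 2^(E-1)-1, a mantissa field, and no inf/NaN encodings. A hedged reference decoder (not part of the diff) that reproduces every table entry:
#include <cmath>

// Decode an OCP micro-float bit pattern: E exponent bits, M mantissa bits,
// bias 2^(E-1)-1; exponent field 0 denotes a subnormal. e3m2 (bf6_t): E=3,
// M=2; e2m3 (f6_t): E=2, M=3; e2m1 (f4_t): E=2, M=1.
template <int E, int M>
float decode_ocp(unsigned bits)
{
    constexpr int bias    = (1 << (E - 1)) - 1;
    const float sign      = ((bits >> (E + M)) & 1u) ? -1.0f : 1.0f;
    const unsigned efield = (bits >> M) & ((1u << E) - 1u);
    const unsigned mfield = bits & ((1u << M) - 1u);
    const float frac      = static_cast<float>(mfield) / (1 << M);
    if(efield == 0)
        return sign * std::ldexp(frac, 1 - bias); // subnormal
    return sign * std::ldexp(1.0f + frac, static_cast<int>(efield) - bias);
}
// decode_ocp<3, 2>(0b011111) == 28.0f (max_bf6); decode_ocp<3, 2>(0b000001) == 0.0625f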

View File

@@ -5,6 +5,7 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/utility/scaled_type_convert.hpp"
#include "ck/utility/env.hpp"
using ck::e8m0_bexp_t;
using ck::f4_convert_rne;
@@ -38,6 +39,11 @@ TEST(FP4, ConvertFP32Nearest)
// convert maximal float to fp4 and back, check if clipped to 6.0
ASSERT_NEAR(
max_fp4, type_convert<float>(f4_convert_rne(std::numeric_limits<float>::max())), abs_tol);
// convert +/-7.0 to fp4 and back, check if clipped to +/-6.0
ASSERT_NEAR(-max_fp4, type_convert<float>(f4_convert_rne(-7.0f)), 0.0);
ASSERT_NEAR(max_fp4, type_convert<float>(f4_convert_rne(7.0f)), 0.0);
// positive norm float value to fp4 and back, check if holds
float pos_float = 1.0f;
ASSERT_NEAR(pos_float, type_convert<float>(f4_convert_rne(pos_float)), abs_tol);
@@ -468,3 +474,54 @@ TEST(FP4, TestAsType32)
test_vec.at(i + 1));
});
}
TEST(FP4, TestAllValues)
{
constexpr std::array<float, 16> e2m1ValuesOCP = {
// clang-format off
0.0000000000, 0.5000000000,
1.0000000000, 1.5000000000,
2.0000000000, 3.0000000000,
4.0000000000, 6.0000000000,
-0.0000000000, -0.5000000000,
-1.0000000000, -1.5000000000,
-2.0000000000, -3.0000000000,
-4.0000000000, -6.0000000000
// clang-format on
};
constexpr uint8_t e2m1BitsOCP[] = {
// clang-format off
0b0000, 0b0001,
0b0010, 0b0011,
0b0100, 0b0101,
0b0110, 0b0111,
0b1000, 0b1001,
0b1010, 0b1011,
0b1100, 0b1101,
0b1110, 0b1111
// clang-format on
};
const bool ck_logging = ck::EnvIsEnabled(CK_ENV(CK_LOGGING));
if(ck_logging)
printf("FP4 Table\n");
ck::static_for<0, 16, 1>{}([&](auto i) {
float fp = type_convert<float>(f4_t(e2m1BitsOCP[i]));
ASSERT_EQ(fp, e2m1ValuesOCP[i]);
f4_t fp4 = type_convert<f4_t>(e2m1ValuesOCP[i]);
ASSERT_EQ(fp4 & 0xF, e2m1BitsOCP[i] & 0xF);
if(ck_logging)
{
// Print the binary representation
printf("Bits: 0b");
for(int j = 3; j >= 0; --j)
{
printf("%c", (e2m1BitsOCP[i] & (1 << j)) ? '1' : '0');
}
printf(", 0x%02X, Value: %f\n", e2m1BitsOCP[i], e2m1ValuesOCP[i]);
}
});
}

View File

@@ -4,6 +4,7 @@
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/utility/env.hpp"
#include "ck/utility/scaled_type_convert.hpp"
using ck::e8m0_bexp_t;
@@ -34,6 +35,11 @@ TEST(FP6, ConvertFP32Nearest)
ASSERT_NEAR(0.0f, type_convert<float>(f6_convert_rne(0.0f)), 0.0f);
// convert maximal f6_t to float and check if equal to max_fp6
ASSERT_NEAR(max_fp6, type_convert<float>(f6_convert_rne(max_fp6)), 0.0f);
// convert +/-8.0 to fp6 and back, check if clipped to +/-max_fp6
ASSERT_NEAR(-max_fp6, type_convert<float>(f6_convert_rne(-8.0f)), 0.0f);
ASSERT_NEAR(max_fp6, type_convert<float>(f6_convert_rne(8.0f)), 0.0f);
// convert maximal float to fp6 and back, check if clipped to max_fp6
ASSERT_NEAR(
max_fp6, type_convert<float>(f6_convert_rne(std::numeric_limits<float>::max())), 0.0f);
@@ -265,20 +271,24 @@ TEST(FP6, TestAsType16x1)
vector_type<f6x16_pk_t, vector_size> right_vec;
// check default CTOR
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
right_vec.template AsType<f6x16_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}), 0);
ASSERT_EQ(right_vec.template AsType<f6x16_pk_t>()(Number<0>{}).unpack(i), 0);
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<f6x16_pk_t>()(Number<i>{}) = f6x16_pk_t{}.pack(test_vec);
right_vec.template AsType<f6x16_pk_t>()(Number<i>{}) = f6x16_pk_t{test_vec};
});
// copy the vector
vector_type<f6x16_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
left_vec.template AsType<f6x16_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
static_cast<f6_t>(test_vec[static_cast<int>(i)]));
ASSERT_EQ(left_vec.template AsType<f6x16_pk_t>()(Number<0>{}).unpack(i),
static_cast<f6_t>(test_vec[static_cast<int>(i)]))
<< " i = " << i << "; left = "
<< type_convert<float>(left_vec.template AsType<f6x16_pk_t>()(Number<0>{}).unpack(i))
<< " -- right = "
<< type_convert<float>(static_cast<f6_t>(test_vec[static_cast<int>(i)])) << " ("
<< static_cast<int>(test_vec[static_cast<int>(i)]) << ")" << std::endl;
});
}
@@ -327,23 +337,23 @@ TEST(FP6, TestAsType16x2)
// check default CTOR
ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
ASSERT_EQ(right_vec.template AsType<f6x16_pk_t>()(Number<idx_vector>{})
.template unpack<>(Number<idx_element>{}),
0);
ASSERT_EQ(
right_vec.template AsType<f6x16_pk_t>()(Number<idx_vector>{}).unpack(idx_element),
0);
});
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<f6x16_pk_t>()(Number<i>{}) = f6x16_pk_t{}.pack(test_vec[i]);
right_vec.template AsType<f6x16_pk_t>()(Number<i>{}) = f6x16_pk_t{test_vec[i]};
});
// copy the vector
vector_type<f6x16_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
ASSERT_EQ(left_vec.template AsType<f6x16_pk_t>()(Number<idx_vector>{})
.template unpack<>(Number<idx_element>{}),
static_cast<f6_t>(test_vec[idx_vector][static_cast<int>(idx_element)]));
ASSERT_EQ(
left_vec.template AsType<f6x16_pk_t>()(Number<idx_vector>{}).unpack(idx_element),
static_cast<f6_t>(test_vec[idx_vector][static_cast<int>(idx_element)]));
});
});
}
@@ -367,19 +377,77 @@ TEST(FP6, TestAsType32x1)
vector_type<f6x32_pk_t, vector_size> right_vec;
// check default CTOR
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
right_vec.template AsType<f6x32_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}), 0);
ASSERT_EQ(right_vec.template AsType<f6x32_pk_t>()(Number<0>{}).unpack(i), 0);
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<f6x32_pk_t>()(Number<i>{}) = f6x32_pk_t{}.pack(test_vec);
right_vec.template AsType<f6x32_pk_t>()(Number<i>{}) = f6x32_pk_t{test_vec};
});
// copy the vector
vector_type<f6x32_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
left_vec.template AsType<f6x32_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
static_cast<f6_t>(test_vec[static_cast<int>(i)]));
ASSERT_EQ(left_vec.template AsType<f6x32_pk_t>()(Number<0>{}).unpack(i),
static_cast<f6_t>(test_vec[static_cast<int>(i)]));
});
}
TEST(FP6, TestAllValues)
{
constexpr std::array<float, 64> e2m3ValuesOCP = {
// clang-format off
0.0000000000, 0.1250000000, 0.2500000000, 0.3750000000, 0.5000000000, 0.6250000000, 0.7500000000, 0.8750000000,
1.0000000000, 1.1250000000, 1.2500000000, 1.3750000000, 1.5000000000, 1.6250000000, 1.7500000000, 1.8750000000,
2.0000000000, 2.2500000000, 2.5000000000, 2.7500000000, 3.0000000000, 3.2500000000, 3.5000000000, 3.7500000000,
4.0000000000, 4.5000000000, 5.0000000000, 5.5000000000, 6.0000000000, 6.5000000000, 7.0000000000, 7.5000000000,
-0.0000000000, -0.1250000000, -0.2500000000, -0.3750000000, -0.5000000000, -0.6250000000, -0.7500000000, -0.8750000000,
-1.0000000000, -1.1250000000, -1.2500000000, -1.3750000000, -1.5000000000, -1.6250000000, -1.7500000000, -1.8750000000,
-2.0000000000, -2.2500000000, -2.5000000000, -2.7500000000, -3.0000000000, -3.2500000000, -3.5000000000, -3.7500000000,
-4.0000000000, -4.5000000000, -5.0000000000, -5.5000000000, -6.0000000000, -6.5000000000, -7.0000000000, -7.5000000000
// clang-format on
};
constexpr uint8_t e2m3BitsOCP[] = {
// clang-format off
0b000000, 0b000001, 0b000010, 0b000011,
0b000100, 0b000101, 0b000110, 0b000111,
0b001000, 0b001001, 0b001010, 0b001011,
0b001100, 0b001101, 0b001110, 0b001111,
0b010000, 0b010001, 0b010010, 0b010011,
0b010100, 0b010101, 0b010110, 0b010111,
0b011000, 0b011001, 0b011010, 0b011011,
0b011100, 0b011101, 0b011110, 0b011111,
0b100000, 0b100001, 0b100010, 0b100011,
0b100100, 0b100101, 0b100110, 0b100111,
0b101000, 0b101001, 0b101010, 0b101011,
0b101100, 0b101101, 0b101110, 0b101111,
0b110000, 0b110001, 0b110010, 0b110011,
0b110100, 0b110101, 0b110110, 0b110111,
0b111000, 0b111001, 0b111010, 0b111011,
0b111100, 0b111101, 0b111110, 0b111111
// clang-format on
};
const bool ck_logging = ck::EnvIsEnabled(CK_ENV(CK_LOGGING));
if(ck_logging)
printf("FP6 Table\n");
ck::static_for<0, 64, 1>{}([&](auto i) {
float fp = type_convert<float>(f6_t(e2m3BitsOCP[i]));
ASSERT_EQ(fp, e2m3ValuesOCP[i]);
f6_t fp6 = type_convert<f6_t>(e2m3ValuesOCP[i]);
ASSERT_EQ(fp6 & 0x3F, e2m3BitsOCP[i] & 0x3F);
if(ck_logging)
{
// Print the binary representation
printf("Bits: 0b");
for(int j = 5; j >= 0; --j)
{
printf("%c", (e2m3BitsOCP[i] & (1 << j)) ? '1' : '0');
}
printf(", 0x%02X, Value: %f\n", e2m3BitsOCP[i], e2m3ValuesOCP[i]);
}
});
}
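The table above is the e2m1 scheme widened to three mantissa bits: 1 sign bit, 2 exponent bits (bias 1, subnormals when the field is zero) and 3 mantissa bits. A minimal host-side decoder, shown only to illustrate the encoding and not part of the CK API:
#include <cmath>
#include <cstdint>
// Illustrative only: decode a 6-bit OCP e2m3 bit pattern to float.
// Layout: [s][ee][mmm], exponent bias 1; exp == 0 selects the subnormal range.
inline float decode_e2m3(uint8_t bits)
{
    const int sign = (bits >> 5) & 0x1;
    const int exp  = (bits >> 3) & 0x3;
    const int man  = bits & 0x7;
    const float mag = (exp == 0) ? man / 8.0f // subnormal: (man/8) * 2^0
                                 : (1.0f + man / 8.0f) * std::ldexp(1.0f, exp - 1);
    return sign ? -mag : mag;
}
// e.g. decode_e2m3(0b010001) == (1 + 1/8) * 2^(2-1) == 2.25, matching the table;
// the FP4 e2m1 table earlier follows the same scheme with a single mantissa bit.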

View File

@@ -5,9 +5,12 @@
#include "mx_mfma_op.hpp"
using ck::bf6_t;
using ck::bf8_t;
using ck::e8m0_bexp_t;
using ck::f4_t;
using ck::f4x2_pk_t;
using ck::f6_t;
using ck::f8_t;
using ck::half_t;
using ck::type_convert;
@@ -17,13 +20,15 @@ using ck::type_convert;
*
* @param init - selects initialization algorithm for A and B tensors
*/
template <typename AType, typename BType, typename CType, ck::MFMA_F8F6F4 mfma>
bool run_mfma_km_kn_nm_test(ck::index_t init)
template <typename ALayout,
typename BLayout,
typename CLayout,
typename AType,
typename BType,
typename CType,
ck::MFMA_F8F6F4 mfma>
bool run_mfma_test(ck::index_t init)
{
using ALayout = ck::tensor_layout::gemm::ColumnMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::ColumnMajor;
using AccType = float; // only MFMA_F32 instructions supported
using CPUAccType = AccType;
@@ -53,74 +58,153 @@ bool run_mfma_km_kn_nm_test(ck::index_t init)
return pass;
}
const ck::index_t common_init = -4; // a negative value lets each test pick its own initialization
TEST(MFMA, FP8MFMA16x16x128)
{
auto AB_init = 5;
auto pass = run_mfma_km_kn_nm_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
using ALayout = ck::tensor_layout::gemm::ColumnMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::ColumnMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mfma_test<ALayout,
BLayout,
CLayout,
f8_t,
f8_t,
half_t,
ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MFMA, FP8MFMA32x32x64)
TEST(MFMA, BF8MFMA16x16x128)
{
auto AB_init = 5;
auto pass = run_mfma_km_kn_nm_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::F32_32x32x64>(AB_init);
using ALayout = ck::tensor_layout::gemm::ColumnMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::ColumnMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mfma_test<ALayout,
BLayout,
CLayout,
bf8_t,
bf8_t,
half_t,
ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
/**
* @brief Run the test for the given MFMA instruction
*
* @param init - selects initialization algorithm for A and B tensors
*/
template <typename AType, typename BType, typename CType, ck::MFMA_F8F6F4 mfma>
bool run_mfma_mk_kn_mn_test(ck::index_t init)
TEST(MFMA, FP4MFMA16x16x128)
{
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using AccType = float; // only MFMA_F32 instructions supported
using CPUAccType = AccType;
ck::mfma_type<static_cast<ck::MfmaInstr>(mfma)> mfma_instr;
constexpr auto BLOCK_M = mfma_instr.m_per_blk;
constexpr auto BLOCK_N = mfma_instr.n_per_blk;
constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;
const auto mfma_kernel = ck::
matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K, ALayout, BLayout, CLayout>;
bool pass = true;
pass = ck::mfma_test::TestMFMA<decltype(mfma_kernel),
AType,
BType,
CType,
AccType,
CPUAccType,
ALayout,
BLayout,
CLayout,
BLOCK_M,
BLOCK_N,
BLOCK_K>{}(mfma_kernel, init);
return pass;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass =
run_mfma_test<ALayout, BLayout, CLayout, f4_t, f4_t, float, ck::MFMA_F8F6F4::F32_16x16x128>(
AB_init);
EXPECT_TRUE(pass);
}
TEST(MFMA, FP4MFMA16x16x128)
TEST(MFMA, FP6MFMA16x16x128)
{
auto AB_init = 4;
auto pass = run_mfma_mk_kn_mn_test<f4x2_pk_t, f4x2_pk_t, float, ck::MFMA_F8F6F4::F32_16x16x128>(
AB_init);
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass =
run_mfma_test<ALayout, BLayout, CLayout, f6_t, f6_t, float, ck::MFMA_F8F6F4::F32_16x16x128>(
AB_init);
EXPECT_TRUE(pass);
}
TEST(MFMA, BF6MFMA16x16x128)
{
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mfma_test<ALayout,
BLayout,
CLayout,
bf6_t,
bf6_t,
float,
ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MFMA, FP8MFMA32x32x64)
{
using ALayout = ck::tensor_layout::gemm::ColumnMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::ColumnMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass =
run_mfma_test<ALayout, BLayout, CLayout, f8_t, f8_t, float, ck::MFMA_F8F6F4::F32_32x32x64>(
AB_init);
EXPECT_TRUE(pass);
}
TEST(MFMA, BF8MFMA32x32x64)
{
using ALayout = ck::tensor_layout::gemm::ColumnMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::ColumnMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mfma_test<ALayout,
BLayout,
CLayout,
bf8_t,
bf8_t,
float,
ck::MFMA_F8F6F4::F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MFMA, FP4MFMA32x32x64)
{
auto AB_init = 4;
auto pass = run_mfma_mk_kn_mn_test<f4x2_pk_t, f4x2_pk_t, half_t, ck::MFMA_F8F6F4::F32_32x32x64>(
AB_init);
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass =
run_mfma_test<ALayout, BLayout, CLayout, f4_t, f4_t, half_t, ck::MFMA_F8F6F4::F32_32x32x64>(
AB_init);
EXPECT_TRUE(pass);
}
TEST(MFMA, FP6MFMA32x32x64)
{
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass =
run_mfma_test<ALayout, BLayout, CLayout, f6_t, f6_t, half_t, ck::MFMA_F8F6F4::F32_32x32x64>(
AB_init);
EXPECT_TRUE(pass);
}
TEST(MFMA, BF6MFMA32x32x64)
{
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mfma_test<ALayout,
BLayout,
CLayout,
bf6_t,
bf6_t,
half_t,
ck::MFMA_F8F6F4::F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
@@ -129,15 +213,18 @@ TEST(MFMA, FP4MFMA32x32x64)
*
* @param init - selects initialization algorithm for A and B tensors
*/
template <typename AType, typename BType, typename CType, ck::MFMA_F8F6F4 mfma>
bool run_mxmfma_mk_kn_mn_test(ck::index_t init)
template <typename ALayout,
typename BLayout,
typename CLayout,
typename AType,
typename BType,
typename CType,
ck::MFMA_F8F6F4 mfma>
bool run_mxmfma_test(ck::index_t init)
{
static_assert(mfma == ck::MFMA_F8F6F4::SCALE_F32_16x16x128 ||
mfma == ck::MFMA_F8F6F4::SCALE_F32_32x32x64,
"Only SCALE_F32_16x16x128 and SCALE_F32_32x32x64 are supported");
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using AccType = float; // only MFMA_F32 instructions supported
using ScaleType = ck::e8m0_bexp_t; // biased exponent type
@@ -181,34 +268,170 @@ bool run_mxmfma_mk_kn_mn_test(ck::index_t init)
TEST(MXMFMA, MXFP8MFMA16x16x128)
{
auto AB_init = 5;
auto pass =
run_mxmfma_mk_kn_mn_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mxmfma_test<ALayout,
BLayout,
CLayout,
f8_t,
f8_t,
float,
ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXFP8MFMA32x32x64)
{
auto AB_init = 5;
auto pass =
run_mxmfma_mk_kn_mn_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mxmfma_test<ALayout,
BLayout,
CLayout,
f8_t,
f8_t,
half_t,
ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXBF8MFMA16x16x128)
{
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mxmfma_test<ALayout,
BLayout,
CLayout,
bf8_t,
bf8_t,
float,
ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXBF8MFMA32x32x64)
{
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mxmfma_test<ALayout,
BLayout,
CLayout,
bf8_t,
bf8_t,
half_t,
ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXFP6MFMA16x16x128)
{
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mxmfma_test<ALayout,
BLayout,
CLayout,
f6_t,
f6_t,
float,
ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXFP6MFMA32x32x64)
{
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mxmfma_test<ALayout,
BLayout,
CLayout,
f6_t,
f6_t,
half_t,
ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXBF6MFMA16x16x128)
{
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mxmfma_test<ALayout,
BLayout,
CLayout,
bf6_t,
bf6_t,
float,
ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXBF6MFMA32x32x64)
{
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mxmfma_test<ALayout,
BLayout,
CLayout,
bf6_t,
bf6_t,
half_t,
ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXFP4MFMA16x16x128)
{
auto AB_init = 4;
auto pass =
run_mxmfma_mk_kn_mn_test<f4x2_pk_t, f4x2_pk_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(
AB_init);
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mxmfma_test<ALayout,
BLayout,
CLayout,
f4_t,
f4_t,
float,
ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXFP4MFMA32x32x64)
{
auto AB_init = 4;
auto pass =
run_mxmfma_mk_kn_mn_test<f4x2_pk_t, f4x2_pk_t, half_t, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(
AB_init);
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mxmfma_test<ALayout,
BLayout,
CLayout,
f4_t,
f4_t,
half_t,
ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}

View File

@@ -151,6 +151,8 @@ __device__ AFragT load_A_col_major(AType const* input_ptr)
// Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] |
// clang-format on
static_assert(!is_packed_type_v<AType>, "Packed type is not supported");
static constexpr int32_t WAVE_SIZE = 64;
// Here we want to load from rows of A in chunks of 16 elements each.
@@ -270,12 +272,28 @@ __device__ AFragT load_A_row_major(AType const* input_ptr)
// Reg 3 [8:15] | K26K27 | K58K59 | K90K91 | K122K123 | v[13] || Reg 3 [8:15] | K26K27 | K58K59 | v[13] |
// Reg 3 [16:23] | K28K29 | K60K61 | K92K93 | K124K125 | v[14] || Reg 3 [16:23] | K28K29 | K60K61 | v[14] |
// Reg 3 [24:31] | K30K31 | K62K63 | K94K95 | K126K127 | v[15] || Reg 3 [24:31] | K30K31 | K62K63 | v[15] |
// Register Mapping for 16x128 for FP6: || Register Mapping for 32x64 for FP6:
// Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | |
// M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 31 | Vector |
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element|
// Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------|
// Reg 0-2 [0:95] | K = 0-15 | K = 32-47 | K = 64-79 | K = 96-111 | v[0] || Reg 0-2 [0:95] | K = 0-15 | K = 32-47 | v[0] |
// Reg 3-5 [0:95] | K = 16-31 | K = 48-63 | K = 80-95 | K = 112-127 | v[0] || Reg 3-5 [0:95] | K = 16-31 | K = 48-63 | v[0] |
// clang-format on
static constexpr int32_t WAVE_SIZE = 64;
// FP8 chunk_size = 16, num_chunks = 2, packed_size = 1
// FP4 chunk_size = 32, num_chunks = 1, packed_size = 2
// FP6 chunk_size = 32, num_chunks = 1, packed_size = 32
constexpr index_t num_chunks = is_packed_type_v<AType> ? 1 : 2;
// Here we want to load from rows of A in chunks of 16 elements each (32 for packed types).
static constexpr uint32_t chunk_size = 16;
constexpr uint32_t chunk_size = is_packed_type_v<AType> ? 32 : 16;
// each chunk is separated by an offset
static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_M;
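// Worked example (illustrative, assuming the 16x128 shape, i.e. BLOCK_M = 16):
//   FP8: chunk_size = 16 -> chunk_offset = 16 * 64 / 16 = 64, so a thread's
//        second 16-element chunk sits 64 columns further along K.
//   FP6/FP4: chunk_size = 32 -> chunk_offset = 32 * 64 / 16 = 128, though with
//        num_chunks = 1 the offset is never actually used.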
@@ -283,43 +301,35 @@ __device__ AFragT load_A_row_major(AType const* input_ptr)
// To start the loading process, let's visualize in 2D coords.
// Each thread will load 32 elements.
// We need to know where they start, and where the next elements are.
auto startCoord2D =
std::make_pair(threadIdx.x % BLOCK_M, // Row {0-31} | {0-15}
(threadIdx.x / BLOCK_M) * chunk_size); // Col {0, 16} | {0, 16, 32, 48}
// FP8/6/4 Row {0-31} | {0-15}
// FP8 Col {0, 16} | {0, 16, 32, 48}
// FP6/4 Col {0, 32} | {0, 32, 64, 96}
auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, (threadIdx.x / BLOCK_M) * chunk_size);
// auto minorStepCoord2D = std::make_pair(0u, 1u); // read rows
auto majorStepCoord2D = std::make_pair(0, chunk_offset); // read a chunk from a row
// Flatten to 1D row_major offsets.
auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
// BLOCK_K is a stride in A matrix
auto startOffset = row_major(
startCoord2D, BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
// auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K /
// (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
auto kMajorOffset =
row_major(majorStepCoord2D,
BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
using ARawT = typename scalar_type<AFragT>::type;
using AScalarFragT = vector_type<ARawT, chunk_size>::type;
constexpr index_t num_chunks =
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 1 : 2);
using ARawT = typename scalar_type<AFragT>::type;
using AScalarChunkT = vector_type<ARawT, scalar_type<AFragT>::vector_size / num_chunks>::type;
union
{
AFragT frag;
AScalarFragT chunks[num_chunks];
AScalarChunkT chunks[num_chunks];
} fragA{};
const AScalarFragT* fragPtr;
const AScalarChunkT* fragPtr;
// BLOCK_K is a stride in A matrix
auto startOffset = row_major(startCoord2D, BLOCK_K) / packed_size_v<AType>;
auto kMajorOffset = row_major(majorStepCoord2D, BLOCK_K) / packed_size_v<AType>;
for(index_t chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++)
{
fragPtr = reinterpret_cast<AScalarFragT const*>(input_ptr + startOffset +
chunk_idx * kMajorOffset);
fragPtr = reinterpret_cast<AScalarChunkT const*>(input_ptr + startOffset +
chunk_idx * kMajorOffset);
fragA.chunks[chunk_idx] = *fragPtr;
}
@@ -488,12 +498,27 @@ __device__ BFragT load_B_col_major(BType const* input_ptr)
// Reg 3 [8:15] | K26K27 | K58K59 | K90K91 | K122K123 | v[13] || Reg 3 [8:15] | K26K27 | K58K59 | v[13] |
// Reg 3 [16:23] | K28K29 | K60K61 | K92K93 | K124K125 | v[14] || Reg 3 [16:23] | K28K29 | K60K61 | v[14] |
// Reg 3 [24:31] | K30K31 | K62K63 | K94K95 | K126K127 | v[15] || Reg 3 [24:31] | K30K31 | K62K63 | v[15] |
// Register Mapping for 16x128 for FP6: || Register Mapping for 32x64 for FP6:
// Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | |
// N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector |
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element|
// Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------|
// Reg 0-2 [0:95] | K = 0-15 | K = 32-47 | K = 64-79 | K = 96-111 | v[0] || Reg 0-2 [0:95] | K = 0-15 | K = 32-47 | v[0] |
// Reg 3-5 [0:95] | K = 16-31 | K = 48-63 | K = 80-95 | K = 112-127 | v[0] || Reg 3-5 [0:95] | K = 16-31 | K = 48-63 | v[0] |
// clang-format on
static constexpr int32_t WAVE_SIZE = 64;
// FP8 chunk_size = 16, num_chunks = 2, packed_size = 1
// FP4 chunk_size = 32, num_chunks = 1, packed_size = 2
// FP6 chunk_size = 32, num_chunks = 1, packed_size = 32
constexpr index_t num_chunks = is_packed_type_v<BType> ? 1 : 2;
// Here we want to load from cols of B in chunks of 16 elements each (32 for packed types).
static constexpr uint32_t chunk_size = 16;
constexpr uint32_t chunk_size = is_packed_type_v<BType> ? 32 : 16;
// each chunk is separated by an offset
static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_N; // 32 or 64
@@ -501,44 +526,36 @@ __device__ BFragT load_B_col_major(BType const* input_ptr)
// To start the loading process, let's visualize in 2D coords.
// Each thread will load 32 elements.
// We need to know where they start, and where the next elements are.
auto startCoord2D =
std::make_pair((threadIdx.x / BLOCK_N) * chunk_size, // Row {0, 16} | {0, 16, 32, 48}
threadIdx.x % BLOCK_N); // Col {0-31} | {0-15}
// FP8/6/4 Col {0-31} | {0-15}
// FP8 Row {0, 16} | {0, 16, 32, 48}
// FP6/4 Row {0, 32} | {0, 32, 64, 96}
auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * chunk_size, threadIdx.x % BLOCK_N);
// Flatten to 1D col_major offsets.
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
// auto minorStepCoord2D = std::make_pair(1u, 0u); // read cols
auto majorStepCoord2D = std::make_pair(chunk_offset, 0); // read a chunk from a col
// BLOCK_K is a stride in B matrix
auto startOffset = col_major(
startCoord2D, BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1));
// auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K /
// (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1));
auto kMajorOffset =
col_major(majorStepCoord2D,
BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1));
using BRawT = typename scalar_type<BFragT>::type;
using BScalarFragT = vector_type<BRawT, chunk_size>::type;
constexpr index_t num_chunks =
(ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 1 : 2);
using BRawT = typename scalar_type<BFragT>::type;
using BScalarChunkT = vector_type<BRawT, scalar_type<BFragT>::vector_size / num_chunks>::type;
union
{
BFragT frag;
BScalarFragT chunks[num_chunks];
BScalarChunkT chunks[num_chunks];
} fragB{};
const BScalarFragT* fragPtr;
const BScalarChunkT* fragPtr;
for(index_t chunk = 0; chunk < num_chunks; chunk++)
// BLOCK_K is a stride in B matrix
auto startOffset = col_major(startCoord2D, BLOCK_K) / packed_size_v<BType>;
auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_K) / packed_size_v<BType>;
for(index_t chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++)
{
fragPtr =
reinterpret_cast<BScalarFragT const*>(input_ptr + startOffset + chunk * kMajorOffset);
fragB.chunks[chunk] = *fragPtr;
fragPtr = reinterpret_cast<BScalarChunkT const*>(input_ptr + startOffset +
chunk_idx * kMajorOffset);
fragB.chunks[chunk_idx] = *fragPtr;
}
return fragB.frag;
@@ -904,20 +921,22 @@ template <typename AType,
typename ALayout,
typename BLayout,
typename CLayout>
__global__ void matmul(const AType* a, const BType* b, CType* c)
__global__ void matmul(const typename packed_type<AType>::type* a,
const typename packed_type<BType>::type* b,
CType* c)
{
using PackedAType = typename packed_type<AType>::type;
constexpr auto packed_size_a = packed_type<AType>::packed_size;
using PackedBType = typename packed_type<BType>::type;
constexpr auto packed_size_b = packed_type<BType>::packed_size;
constexpr int WAVE_SIZE = 64;
assert(threadIdx.x < WAVE_SIZE);
assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);
using AFragT =
vector_type<AType,
BLOCK_M * BLOCK_K / WAVE_SIZE /
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
using BFragT =
vector_type<BType,
BLOCK_K * BLOCK_N / WAVE_SIZE /
(ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
using AFragT = vector_type<PackedAType, BLOCK_M * BLOCK_K / WAVE_SIZE / packed_size_a>::type;
using BFragT = vector_type<PackedBType, BLOCK_K * BLOCK_N / WAVE_SIZE / packed_size_b>::type;
using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
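// Fragment sizing, worked for the 16x16x128 shape (illustrative):
//   f8_t: packed_size = 1  -> AFragT = vector_type<f8_t, 16 * 128 / 64>  = 32 bytes
//   f6_t: packed_size = 32 -> AFragT = vector_type<f6x32_pk_t, 32 / 32>  = 1 packed element (32 values)
//   f4_t: packed_size = 2  -> AFragT = vector_type<f4x2_pk_t, 32 / 2>    = 16 packed pairs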
@@ -931,11 +950,11 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
// Load the inputs.
if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
fragA = load_A_row_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
fragA = load_A_row_major<PackedAType, AFragT, BLOCK_M, BLOCK_K>(a);
}
else
{
fragA = load_A_col_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
fragA = load_A_col_major<PackedAType, AFragT, BLOCK_M, BLOCK_K>(a);
}
if constexpr(is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
@@ -944,7 +963,7 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
}
else
{
fragB = load_B_col_major<BType, BFragT, BLOCK_K, BLOCK_N>(b);
fragB = load_B_col_major<PackedBType, BFragT, BLOCK_K, BLOCK_N>(b);
}
// Matrix multiply-accumulate using MFMA units
@@ -979,21 +998,24 @@ template <typename AType,
typename ALayout,
typename BLayout,
typename CLayout>
__global__ void
matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, CType* c)
__global__ void matmul(const packed_type_t<AType>* a,
const ScaleType* xa,
const packed_type_t<BType>* b,
const ScaleType* xb,
CType* c)
{
using PackedAType = packed_type_t<AType>;
constexpr auto packed_size_a = packed_size_v<AType>;
using PackedBType = packed_type_t<BType>;
constexpr auto packed_size_b = packed_size_v<BType>;
constexpr int WAVE_SIZE = 64;
assert(threadIdx.x < WAVE_SIZE);
assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);
using AFragT =
vector_type<AType,
BLOCK_M * BLOCK_K / WAVE_SIZE /
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
using BFragT =
vector_type<BType,
BLOCK_K * BLOCK_N / WAVE_SIZE /
(ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
using AFragT = vector_type<PackedAType, BLOCK_M * BLOCK_K / WAVE_SIZE / packed_size_a>::type;
using BFragT = vector_type<PackedBType, BLOCK_K * BLOCK_N / WAVE_SIZE / packed_size_b>::type;
using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
@@ -1011,9 +1033,13 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb,
// Load the inputs.
if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
fragA =
load_mx_A_row_major<AType, AFragT, ScaleType, AScaleFragT, BLOCK_M, BLOCK_K, BLOCK_X>(
a, xa, fragXa);
fragA = load_mx_A_row_major<PackedAType,
AFragT,
ScaleType,
AScaleFragT,
BLOCK_M,
BLOCK_K,
BLOCK_X>(a, xa, fragXa);
}
else
{
@@ -1026,9 +1052,13 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb,
}
else
{
fragB =
load_mx_B_col_major<BType, BFragT, ScaleType, BScaleFragT, BLOCK_K, BLOCK_N, BLOCK_X>(
b, xb, fragXb);
fragB = load_mx_B_col_major<PackedBType,
BFragT,
ScaleType,
BScaleFragT,
BLOCK_K,
BLOCK_N,
BLOCK_X>(b, xb, fragXb);
}
// Scaled Matrix multiply-accumulate using MFMA units
@@ -1151,6 +1181,11 @@ template <typename DeviceMFMA,
index_t BLOCK_X>
struct TestMXMFMA
{
using PackedAType = typename packed_type<ADataType>::type;
static constexpr auto packed_size_a = packed_type<ADataType>::packed_size;
using PackedBType = typename packed_type<BDataType>::type;
static constexpr auto packed_size_b = packed_type<BDataType>::packed_size;
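// Illustrative mapping (per ck::packed_type): f4_t -> f4x2_pk_t (packed_size 2),
// f6_t -> f6x32_pk_t (packed_size 32); byte-sized types such as f8_t/bf8_t map to
// themselves (packed_size 1), so the host tensors below always hold whole packed
// elements.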
auto PrepareGemmTensors(const GemmParams& params, index_t init)
{
auto f_host_tensor_descriptor =
@@ -1167,11 +1202,11 @@ struct TestMXMFMA
}
};
Tensor<ADataType> a_m_k(
Tensor<PackedAType> a_m_k(
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
Tensor<ScaleType> a_scales(
f_host_tensor_descriptor(params.M, params.K / BLOCK_X, params.K / BLOCK_X, ALayout{}));
Tensor<BDataType> b_n_k(
Tensor<PackedBType> b_n_k(
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
Tensor<ScaleType> b_scales(
f_host_tensor_descriptor(params.K / BLOCK_X, params.N, params.K / BLOCK_X, BLayout{}));
@@ -1183,51 +1218,44 @@ struct TestMXMFMA
switch(init)
{
case 0:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{0.015625f}}); // 1/64
a_m_k.GenerateTensorValue(GeneratorTensor_1<PackedAType>{1.0f});
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{0.5f}});
// NOTE: not all numbers are representable in FP8, BF8, etc.
// 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 16 18 20 20 20 22 24 24 24 26 28 28 28 30 32
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<PackedBType, 1>{});
b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
break;
case 1:
// results in C = {K}
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
a_m_k.GenerateTensorValue(GeneratorTensor_1<PackedAType>{1.0f});
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{512.0f}});
b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
b_n_k.GenerateTensorValue(GeneratorTensor_1<PackedBType>{1.0f});
b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f / 512}});
break;
case 2:
// expect small round off errors
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
a_m_k.GenerateTensorValue(GeneratorTensor_3<PackedAType>{-2.0, 2.0});
a_scales.GenerateTensorValue(
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
b_n_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
b_n_k.GenerateTensorValue(GeneratorTensor_3<PackedBType>{-2.0, 2.0});
b_scales.GenerateTensorValue(GeneratorTensor_2<ScaleType>{126, 129});
break;
case 3:
// expect small round off errors
a_m_k.GenerateTensorValue(GeneratorTensor_4<ADataType>(0, 1));
a_m_k.GenerateTensorValue(GeneratorTensor_4<PackedAType>(0, 1, time(nullptr)));
a_scales.GenerateTensorValue(
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
b_n_k.GenerateTensorValue(GeneratorTensor_4<BDataType>(0, 1));
b_scales.GenerateTensorValue(
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
break;
case 4:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1., 1.});
a_scales.GenerateTensorValue(
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
b_n_k.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1., 1.});
b_n_k.GenerateTensorValue(GeneratorTensor_4<PackedBType>(0, 1, time(nullptr) / 2));
b_scales.GenerateTensorValue(
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
break;
default:
// all initial values are representable in FP8, BF8
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6}); // Z[-5,5]
a_m_k.GenerateTensorValue(GeneratorTensor_2<PackedAType>{-6, 7}); // Z[-6,6]
a_scales.GenerateTensorValue(
GeneratorTensor_2<ScaleType>{122, 129}); // scales: [1/32,..., 2]
b_n_k.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 6}); // Z[-5,5]
GeneratorTensor_2<ScaleType>{122, 129}); // scales: [1/32,..., 2]
b_n_k.GenerateTensorValue(GeneratorTensor_2<PackedBType>{-6, 7}); // Z[-6,6]
b_scales.GenerateTensorValue(
GeneratorTensor_2<ScaleType>{122, 129}); // scales: [1/32,..., 2]
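// e8m0_bexp_t is a biased exponent (bias 127): a generated integer e decodes to
// the scale 2^(e - 127), so 122 -> 1/32, 126 -> 1/2, 127 -> 1, 128 -> 2; this is
// where the ranges in the comments above come from.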
@@ -1272,9 +1300,9 @@ struct TestMXMFMA
auto host_tensors = PrepareGemmTensors(params, init);
const Tensor<ADataType>& a = std::get<0>(host_tensors);
const Tensor<PackedAType>& a = std::get<0>(host_tensors);
const Tensor<ScaleType>& a_scales = std::get<1>(host_tensors);
const Tensor<BDataType>& b = std::get<2>(host_tensors);
const Tensor<PackedBType>& b = std::get<2>(host_tensors);
const Tensor<ScaleType>& b_scales = std::get<3>(host_tensors);
Tensor<CDataType>& c_host = std::get<4>(host_tensors);
Tensor<CDataType>& c_device = std::get<5>(host_tensors);
@@ -1356,6 +1384,12 @@ template <typename DeviceMFMA,
index_t BLOCK_K>
struct TestMFMA
{
using PackedAType = typename packed_type<ADataType>::type;
static constexpr auto packed_size_a = packed_type<ADataType>::packed_size;
using PackedBType = typename packed_type<BDataType>::type;
static constexpr auto packed_size_b = packed_type<BDataType>::packed_size;
auto PrepareGemmTensors(const GemmParams& params, index_t init)
{
auto f_host_tensor_descriptor =
@@ -1372,9 +1406,9 @@ struct TestMFMA
}
};
Tensor<ADataType> a_m_k(
Tensor<PackedAType> a_m_k(
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
Tensor<BDataType> b_n_k(
Tensor<PackedBType> b_n_k(
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
@@ -1384,34 +1418,30 @@ struct TestMFMA
switch(init)
{
case 0:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{0.015625f});
a_m_k.GenerateTensorValue(GeneratorTensor_1<PackedAType>{0.625f});
// NOTE: not all numbers are representable in FP8, BF8, etc.
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<PackedBType, 1>{});
break;
case 1:
// results in C = {K}
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
a_m_k.GenerateTensorValue(GeneratorTensor_1<PackedAType>{1.0f});
b_n_k.GenerateTensorValue(GeneratorTensor_1<PackedBType>{1.0f});
break;
case 2:
// expect small round off errors
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-5, 5});
b_n_k.GenerateTensorValue(GeneratorTensor_3<BDataType>{-5, 5});
// expect small round-off errors that can lead to FP8MFMA32x32x64 failures
a_m_k.GenerateTensorValue(GeneratorTensor_3<PackedAType>{-5, 5});
b_n_k.GenerateTensorValue(GeneratorTensor_3<PackedBType>{-5, 5});
break;
case 3:
// expect small round off errors
a_m_k.GenerateTensorValue(GeneratorTensor_4<ADataType>(-1, 3));
b_n_k.GenerateTensorValue(GeneratorTensor_4<BDataType>(1, 3));
break;
case 4:
// FP4 values case
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-4, 5});
b_n_k.GenerateTensorValue(GeneratorTensor_2<BDataType>{-4, 5});
// expect small round-off errors that can lead to FP8MFMA32x32x64 failures
a_m_k.GenerateTensorValue(GeneratorTensor_4<PackedAType>(-1, 3));
b_n_k.GenerateTensorValue(GeneratorTensor_4<PackedBType>(1, 3));
break;
default:
// all initial values are representable in FP8, BF8
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6});
b_n_k.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 6});
// all initial values are representable in FP8/6 and BF8/6; FP4 cannot represent +/-5
a_m_k.GenerateTensorValue(GeneratorTensor_2<PackedAType>{-6, 7}); // Z[-6,6]
b_n_k.GenerateTensorValue(GeneratorTensor_2<PackedBType>{-6, 7});
break;
}
@@ -1453,10 +1483,10 @@ struct TestMFMA
auto host_tensors = PrepareGemmTensors(params, init);
const Tensor<ADataType>& a = std::get<0>(host_tensors);
const Tensor<BDataType>& b = std::get<1>(host_tensors);
Tensor<CDataType>& c_host = std::get<2>(host_tensors);
Tensor<CDataType>& c_device = std::get<3>(host_tensors);
const Tensor<PackedAType>& a = std::get<0>(host_tensors);
const Tensor<PackedBType>& b = std::get<1>(host_tensors);
Tensor<CDataType>& c_host = std::get<2>(host_tensors);
Tensor<CDataType>& c_device = std::get<3>(host_tensors);
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
@@ -1464,8 +1494,8 @@ struct TestMFMA
auto b_element_op = PassThrough{};
auto c_element_op = PassThrough{};
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<PackedAType,
PackedBType,
CDataType,
CPUAccDataType,
PassThrough,