Optimized GEMMs for MX FP4/8 (#2294)

Adds V3 GEMM pipelines for MX FP4 and MX FP8
Adds V3 GEMM pipeline for MX FP4 with pre-shuffling
Adds MXFP4 GEMM tests (#2275)
Adds MXFP4 GEMM examples
Adds MXFP4 GEMMs to ckProfiler




Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com>
Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com>
Co-authored-by: aska-0096 <haocwang@amd.com>
Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com>
Co-authored-by: OscarXu <huaiguxu@amd.com>
Co-authored-by: mtgu0705 <mtgu@amd.com>
Co-authored-by: Ding, Yi <yi.ding@amd.com>
Co-authored-by: feifei14119 <feiw@amd.com>
Co-authored-by: Lin, Qun <qlin@amd.com>
Co-authored-by: joye <joye@amd.com>
Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
This commit is contained in:
Andriy Roshchenko
2025-06-05 13:54:15 -06:00
committed by GitHub
parent 233e274077
commit 00247e3c29
83 changed files with 8193 additions and 2165 deletions

View File

@@ -77,33 +77,34 @@ struct ReferenceMXGemm : public device::BaseOperator
ComputeTypeA,
ComputeTypeB>;
Tensor<ComputeTypeA> a_m_k_scaled(arg.a_m_k_.mDesc);
Tensor<ComputeTypeB> b_k_n_scaled(arg.b_k_n_.mDesc);
const ck::index_t M = arg.a_m_k_.mDesc.GetLengths()[0];
const ck::index_t N = arg.b_k_n_.mDesc.GetLengths()[1];
assert(arg.a_m_k_.mDesc.GetLengths()[1] == arg.b_k_n_.mDesc.GetLengths()[0]);
const ck::index_t K = arg.a_m_k_.mDesc.GetLengths()[1];
const ck::index_t SCALE_BLOCK = K / arg.a_m_kblock_scales_.mDesc.GetLengths()[1];
Tensor<ComputeTypeA> a_m_k_scaled(HostTensorDescriptor({M, K}, {K, 1}));
Tensor<ComputeTypeB> b_k_n_scaled(HostTensorDescriptor({K, N}, {1, K}));
// printf("K: %d\n", K);
const auto M = arg.a_m_k_.mDesc.GetLengths()[0];
const auto N = arg.b_k_n_.mDesc.GetLengths()[1];
const auto K = arg.a_m_k_.mDesc.GetLengths()[1];
const auto SCALE_BLOCK = K / arg.a_m_kblock_scales_.mDesc.GetLengths()[1];
for(size_t m = 0; m < M; m++)
for(int m = 0; m < M; m++)
{
for(size_t k = 0; k < K; k++)
for(int k = 0; k < K; k++)
{
if constexpr(is_same_v<ADataType, f4x2_pk_t>)
{
// TODO: add support for ColMajor layout as well
if(k % 2 == 1)
a_m_k_scaled(m, k) =
type_convert<ComputeTypeA>(
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{}))) *
type_convert<ComputeTypeA>(
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
else
a_m_k_scaled(m, k) =
type_convert<ComputeTypeA>(
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{}))) *
type_convert<ComputeTypeA>(
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
{
continue;
}
// TODO: add support for ColMajor layout as well
auto a_pack = arg.a_m_k_(m, k);
auto a_scale =
type_convert<ComputeTypeA>(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
auto a_f4_lo = f4_t(a_pack.template unpack<>(Number<0>{}));
auto a_f4_hi = f4_t(a_pack.template unpack<>(Number<1>{}));
a_m_k_scaled(m, k) = type_convert<ComputeTypeA>(a_f4_lo) * a_scale;
a_m_k_scaled(m, k + 1) = type_convert<ComputeTypeA>(a_f4_hi) * a_scale;
}
else if constexpr(is_same_v<ADataType, f6x16_pk_t> ||
is_same_v<ADataType, bf6x16_pk_t> ||
@@ -124,25 +125,24 @@ struct ReferenceMXGemm : public device::BaseOperator
}
}
for(size_t n = 0; n < N; n++)
for(int n = 0; n < N; n++)
{
for(size_t k = 0; k < K; k++)
for(int k = 0; k < K; k++)
{
if constexpr(is_same_v<BDataType, f4x2_pk_t>)
{
// TODO: add support for RowMajor layout as well
if(k % 2 == 1)
b_k_n_scaled(k, n) =
type_convert<ComputeTypeB>(
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{}))) *
type_convert<ComputeTypeB>(
arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
else
b_k_n_scaled(k, n) =
type_convert<ComputeTypeB>(
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{}))) *
type_convert<ComputeTypeB>(
arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
{
continue;
}
auto b_pack = arg.b_k_n_(k, n);
auto b_scale =
type_convert<ComputeTypeB>(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
auto b_f4_lo = f4_t(b_pack.template unpack<>(Number<0>{}));
auto b_f4_hi = f4_t(b_pack.template unpack<>(Number<1>{}));
b_k_n_scaled(k, n) = type_convert<ComputeTypeB>(b_f4_lo) * b_scale;
b_k_n_scaled(k + 1, n) = type_convert<ComputeTypeB>(b_f4_hi) * b_scale;
}
else if constexpr(is_same_v<BDataType, f6x16_pk_t> ||
is_same_v<BDataType, bf6x16_pk_t> ||

View File

@@ -23,6 +23,10 @@ using I32 = int32_t;
using F8 = ck::f8_t;
using BF8 = ck::bf8_t;
using I4 = ck::pk_i4_t;
using F4 = ck::f4x2_pk_t;
using E8M0 = ck::e8m0_bexp_t;
using E8M0PK = int32_t;
using Empty_Tuple = ck::Tuple<>;
@@ -42,8 +46,9 @@ using BF16_Tuple = ck::Tuple<BF16>;
using F32_F32_Tuple = ck::Tuple<F32, F32>;
// GEMM layout
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using MFMA = ck::tensor_layout::gemm::MFMA;
using Row_Tuple = ck::Tuple<Row>;
using Row_Row_Tuple = ck::Tuple<Row, Row>;

View File

@@ -22,9 +22,9 @@ void add_device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instances(
Col,
Row,
F8,
e8m0_bexp_t,
E8M0PK,
F8,
e8m0_bexp_t,
E8M0PK,
F16,
32,
PassThrough,
@@ -36,23 +36,37 @@ void add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances(
Col,
Row,
F8,
e8m0_bexp_t,
E8M0PK,
F8,
e8m0_bexp_t,
E8M0PK,
BF16,
32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// Registers the default V3 MX GEMM instances for FP4 A and B operands with
// Row x Col -> Row layouts (mk_nk_mn), FP16 output, scale block size 32, and
// PassThrough element-wise ops. The I32 scale type presumably carries packed
// E8M0 block-scale exponents (matches the E8M0PK = int32_t alias) — TODO confirm.
void add_device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instances(
std::vector<std::unique_ptr<DeviceGemmMX<Row,
Col,
Row,
F4,
I32,
F4,
I32,
F16,
32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instances(
std::vector<std::unique_ptr<DeviceGemmMX<Row,
Row,
Row,
BF8,
e8m0_bexp_t,
E8M0PK,
F8,
e8m0_bexp_t,
E8M0PK,
F16,
32,
PassThrough,
@@ -64,9 +78,9 @@ void add_device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instances(
Col,
Row,
F8,
e8m0_bexp_t,
E8M0PK,
F8,
e8m0_bexp_t,
E8M0PK,
BF16,
32,
PassThrough,
@@ -94,7 +108,8 @@ struct DeviceOperationInstanceFactory<
ScaleBlockSize,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
ck::tensor_operation::element_wise::PassThrough>,
enable_if_t<!is_same_v<BLayout, MFMA>>> // non-weight-pre-shuffle
{
using DeviceOp = DeviceGemmMX<ALayout,
BLayout,
@@ -127,6 +142,11 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances(op_ptrs);
}
else if constexpr(is_same_v<ADataType, F4> && is_same_v<BDataType, F4> &&
is_same_v<CDataType, F16>)
{
add_device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instances(op_ptrs);
}
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
@@ -153,6 +173,73 @@ struct DeviceOperationInstanceFactory<
}
};
// Registers the default V3 MX GEMM instances for FP4 A and B operands where B
// uses the MFMA layout (the weight pre-shuffled path, per the commit summary),
// with Row-major A and C, FP16 output, scale block size 32, and PassThrough
// element-wise ops. The I32 scale type presumably carries packed E8M0
// block-scale exponents (matches the E8M0PK = int32_t alias) — TODO confirm.
void add_device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instances(
std::vector<std::unique_ptr<DeviceGemmMX<Row,
MFMA,
Row,
F4,
I32,
F4,
I32,
F16,
32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename CDataType,
index_t ScaleBlockSize,
typename ALayout,
typename BLayout,
typename CLayout>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMX<ALayout,
BLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
CDataType,
ScaleBlockSize,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>,
enable_if_t<is_same_v<BLayout, MFMA>>>
{
using DeviceOp = DeviceGemmMX<ALayout,
BLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
CDataType,
ScaleBlockSize,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, MFMA> && is_same_v<CLayout, Row>)
{
if constexpr(is_same_v<ADataType, F4> && is_same_v<BDataType, F4> &&
is_same_v<CDataType, F16>)
{
add_device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation