mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Optimized GEMMs for MX FP4/8 (#2294)
- Adds V3 GEMM pipeline for MX FP4 and MX FP8
- Adds V3 GEMM pipeline for MX FP4 with preshuffling
- Adds MXFP4 GEMM tests (#2275)
- Adds MXFP4 GEMM examples
- Adds MXFP4 GEMMs to ckProfiler

Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com>
Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com>
Co-authored-by: aska-0096 <haocwang@amd.com>
Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com>
Co-authored-by: OscarXu <huaiguxu@amd.com>
Co-authored-by: mtgu0705 <mtgu@amd.com>
Co-authored-by: Ding, Yi <yi.ding@amd.com>
Co-authored-by: feifei14119 <feiw@amd.com>
Co-authored-by: Lin, Qun <qlin@amd.com>
Co-authored-by: joye <joye@amd.com>
Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
233e274077
commit
00247e3c29
@@ -77,33 +77,34 @@ struct ReferenceMXGemm : public device::BaseOperator
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
Tensor<ComputeTypeA> a_m_k_scaled(arg.a_m_k_.mDesc);
|
||||
Tensor<ComputeTypeB> b_k_n_scaled(arg.b_k_n_.mDesc);
|
||||
const ck::index_t M = arg.a_m_k_.mDesc.GetLengths()[0];
|
||||
const ck::index_t N = arg.b_k_n_.mDesc.GetLengths()[1];
|
||||
assert(arg.a_m_k_.mDesc.GetLengths()[1] == arg.b_k_n_.mDesc.GetLengths()[0]);
|
||||
const ck::index_t K = arg.a_m_k_.mDesc.GetLengths()[1];
|
||||
const ck::index_t SCALE_BLOCK = K / arg.a_m_kblock_scales_.mDesc.GetLengths()[1];
|
||||
Tensor<ComputeTypeA> a_m_k_scaled(HostTensorDescriptor({M, K}, {K, 1}));
|
||||
Tensor<ComputeTypeB> b_k_n_scaled(HostTensorDescriptor({K, N}, {1, K}));
|
||||
// printf("K: %d\n", K);
|
||||
|
||||
const auto M = arg.a_m_k_.mDesc.GetLengths()[0];
|
||||
const auto N = arg.b_k_n_.mDesc.GetLengths()[1];
|
||||
const auto K = arg.a_m_k_.mDesc.GetLengths()[1];
|
||||
const auto SCALE_BLOCK = K / arg.a_m_kblock_scales_.mDesc.GetLengths()[1];
|
||||
|
||||
for(size_t m = 0; m < M; m++)
|
||||
for(int m = 0; m < M; m++)
|
||||
{
|
||||
for(size_t k = 0; k < K; k++)
|
||||
for(int k = 0; k < K; k++)
|
||||
{
|
||||
if constexpr(is_same_v<ADataType, f4x2_pk_t>)
|
||||
{
|
||||
// TODO: add support for ColMajor layout as well
|
||||
if(k % 2 == 1)
|
||||
a_m_k_scaled(m, k) =
|
||||
type_convert<ComputeTypeA>(
|
||||
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{}))) *
|
||||
type_convert<ComputeTypeA>(
|
||||
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
|
||||
else
|
||||
a_m_k_scaled(m, k) =
|
||||
type_convert<ComputeTypeA>(
|
||||
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{}))) *
|
||||
type_convert<ComputeTypeA>(
|
||||
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
|
||||
{
|
||||
continue;
|
||||
}
|
||||
// TODO: add support for ColMajor layout as well
|
||||
auto a_pack = arg.a_m_k_(m, k);
|
||||
auto a_scale =
|
||||
type_convert<ComputeTypeA>(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
|
||||
auto a_f4_lo = f4_t(a_pack.template unpack<>(Number<0>{}));
|
||||
auto a_f4_hi = f4_t(a_pack.template unpack<>(Number<1>{}));
|
||||
|
||||
a_m_k_scaled(m, k) = type_convert<ComputeTypeA>(a_f4_lo) * a_scale;
|
||||
a_m_k_scaled(m, k + 1) = type_convert<ComputeTypeA>(a_f4_hi) * a_scale;
|
||||
}
|
||||
else if constexpr(is_same_v<ADataType, f6x16_pk_t> ||
|
||||
is_same_v<ADataType, bf6x16_pk_t> ||
|
||||
@@ -124,25 +125,24 @@ struct ReferenceMXGemm : public device::BaseOperator
|
||||
}
|
||||
}
|
||||
|
||||
for(size_t n = 0; n < N; n++)
|
||||
for(int n = 0; n < N; n++)
|
||||
{
|
||||
for(size_t k = 0; k < K; k++)
|
||||
for(int k = 0; k < K; k++)
|
||||
{
|
||||
if constexpr(is_same_v<BDataType, f4x2_pk_t>)
|
||||
{
|
||||
// TODO: add support for RowMajor layout as well
|
||||
if(k % 2 == 1)
|
||||
b_k_n_scaled(k, n) =
|
||||
type_convert<ComputeTypeB>(
|
||||
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{}))) *
|
||||
type_convert<ComputeTypeB>(
|
||||
arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
|
||||
else
|
||||
b_k_n_scaled(k, n) =
|
||||
type_convert<ComputeTypeB>(
|
||||
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{}))) *
|
||||
type_convert<ComputeTypeB>(
|
||||
arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
|
||||
{
|
||||
continue;
|
||||
}
|
||||
auto b_pack = arg.b_k_n_(k, n);
|
||||
auto b_scale =
|
||||
type_convert<ComputeTypeB>(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
|
||||
auto b_f4_lo = f4_t(b_pack.template unpack<>(Number<0>{}));
|
||||
auto b_f4_hi = f4_t(b_pack.template unpack<>(Number<1>{}));
|
||||
b_k_n_scaled(k, n) = type_convert<ComputeTypeB>(b_f4_lo) * b_scale;
|
||||
b_k_n_scaled(k + 1, n) = type_convert<ComputeTypeB>(b_f4_hi) * b_scale;
|
||||
}
|
||||
else if constexpr(is_same_v<BDataType, f6x16_pk_t> ||
|
||||
is_same_v<BDataType, bf6x16_pk_t> ||
|
||||
|
||||
@@ -23,6 +23,10 @@ using I32 = int32_t;
|
||||
using F8 = ck::f8_t;
|
||||
using BF8 = ck::bf8_t;
|
||||
using I4 = ck::pk_i4_t;
|
||||
using F4 = ck::f4x2_pk_t;
|
||||
|
||||
using E8M0 = ck::e8m0_bexp_t;
|
||||
using E8M0PK = int32_t;
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
@@ -42,8 +46,9 @@ using BF16_Tuple = ck::Tuple<BF16>;
|
||||
using F32_F32_Tuple = ck::Tuple<F32, F32>;
|
||||
|
||||
// GEMM layout
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using MFMA = ck::tensor_layout::gemm::MFMA;
|
||||
|
||||
using Row_Tuple = ck::Tuple<Row>;
|
||||
using Row_Row_Tuple = ck::Tuple<Row, Row>;
|
||||
|
||||
@@ -22,9 +22,9 @@ void add_device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instances(
|
||||
Col,
|
||||
Row,
|
||||
F8,
|
||||
e8m0_bexp_t,
|
||||
E8M0PK,
|
||||
F8,
|
||||
e8m0_bexp_t,
|
||||
E8M0PK,
|
||||
F16,
|
||||
32,
|
||||
PassThrough,
|
||||
@@ -36,23 +36,37 @@ void add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances(
|
||||
Col,
|
||||
Row,
|
||||
F8,
|
||||
e8m0_bexp_t,
|
||||
E8M0PK,
|
||||
F8,
|
||||
e8m0_bexp_t,
|
||||
E8M0PK,
|
||||
BF16,
|
||||
32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// Forward declaration; the definition is not part of this hunk.
// Takes, by reference, a vector of owning pointers to
// DeviceGemmMX<Row, Col, Row, F4, I32, F4, I32, F16, 32, PassThrough x3>.
// Reading the arguments against the DeviceOperationInstanceFactory
// specialization in this file, they are: A/B/C layouts, A data type,
// A scale type, B data type, B scale type, C data type, scale block size,
// and three element-wise operations.
// NOTE(review): presumably appends the default F4 x F4 -> F16 instances to
// `instances`, as the name suggests - confirm against the definition.
void add_device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMX<Row,
|
||||
Col,
|
||||
Row,
|
||||
F4,
|
||||
I32,
|
||||
F4,
|
||||
I32,
|
||||
F16,
|
||||
32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMX<Row,
|
||||
Row,
|
||||
Row,
|
||||
BF8,
|
||||
e8m0_bexp_t,
|
||||
E8M0PK,
|
||||
F8,
|
||||
e8m0_bexp_t,
|
||||
E8M0PK,
|
||||
F16,
|
||||
32,
|
||||
PassThrough,
|
||||
@@ -64,9 +78,9 @@ void add_device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instances(
|
||||
Col,
|
||||
Row,
|
||||
F8,
|
||||
e8m0_bexp_t,
|
||||
E8M0PK,
|
||||
F8,
|
||||
e8m0_bexp_t,
|
||||
E8M0PK,
|
||||
BF16,
|
||||
32,
|
||||
PassThrough,
|
||||
@@ -94,7 +108,8 @@ struct DeviceOperationInstanceFactory<
|
||||
ScaleBlockSize,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>>
|
||||
ck::tensor_operation::element_wise::PassThrough>,
|
||||
enable_if_t<!is_same_v<BLayout, MFMA>>> // non-weight-pre-shuffle
|
||||
{
|
||||
using DeviceOp = DeviceGemmMX<ALayout,
|
||||
BLayout,
|
||||
@@ -127,6 +142,11 @@ struct DeviceOperationInstanceFactory<
|
||||
|
||||
add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ADataType, F4> && is_same_v<BDataType, F4> &&
|
||||
is_same_v<CDataType, F16>)
|
||||
{
|
||||
add_device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
@@ -153,6 +173,73 @@ struct DeviceOperationInstanceFactory<
|
||||
}
|
||||
};
|
||||
|
||||
// Forward declaration; the definition is not part of this hunk.
// Takes, by reference, a vector of owning pointers to
// DeviceGemmMX<Row, MFMA, Row, F4, I32, F4, I32, F16, 32, PassThrough x3>,
// i.e. the B operand carries the MFMA layout tag instead of Row/Col.
// It is called from the DeviceOperationInstanceFactory specialization that is
// enabled when BLayout is MFMA.
// NOTE(review): presumably appends the default pre-shuffled-B F4 x F4 -> F16
// instances to `instances` - confirm against the definition.
void add_device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMX<Row,
|
||||
MFMA,
|
||||
Row,
|
||||
F4,
|
||||
I32,
|
||||
F4,
|
||||
I32,
|
||||
F16,
|
||||
32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
template <typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename CDataType,
|
||||
index_t ScaleBlockSize,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceGemmMX<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
CDataType,
|
||||
ScaleBlockSize,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>,
|
||||
enable_if_t<is_same_v<BLayout, MFMA>>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMX<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
CDataType,
|
||||
ScaleBlockSize,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, MFMA> && is_same_v<CLayout, Row>)
|
||||
{
|
||||
if constexpr(is_same_v<ADataType, F4> && is_same_v<BDataType, F4> &&
|
||||
is_same_v<CDataType, F16>)
|
||||
{
|
||||
add_device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
Reference in New Issue
Block a user