Optimized GEMMs for MX FP4/8 (#2294)

Adds V3 GEMM pipeline for MX FP4 and MX FP8 
Adds V3 GEMM pipeline for MX FP4 with preshuffling
Adds MXFP4 GEMM tests (#2275)
Adds MXFP4 GEMM examples
Adds MXFP4 GEMMs to ckProfiler




Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com>
Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com>
Co-authored-by: aska-0096 <haocwang@amd.com>
Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com>
Co-authored-by: OscarXu <huaiguxu@amd.com>
Co-authored-by: mtgu0705 <mtgu@amd.com>
Co-authored-by: Ding, Yi <yi.ding@amd.com>
Co-authored-by: feifei14119 <feiw@amd.com>
Co-authored-by: Lin, Qun <qlin@amd.com>
Co-authored-by: joye <joye@amd.com>
Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
This commit is contained in:
Andriy Roshchenko
2025-06-05 13:54:15 -06:00
committed by GitHub
parent 233e274077
commit 00247e3c29
83 changed files with 8193 additions and 2165 deletions

View File

@@ -34,19 +34,19 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances =
// ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 64, 16, 16, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<4, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 16, 32, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 16, 32, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 16, 32, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 64, 64, 64, 16, 16, 16, 16, 2, 2, S<4, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<4, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 32, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 16, 32, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 16, 32, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 32, 16, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 32, 16, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 64, 16, 16, 128, 32, 32, 16, 16, 1, 1, S<1, 4, 16>, S<1, 0, 2>, 2, 2, 0, S<1, 4, 16>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 64, 16, 16, 128, 32, 32, 16, 16, 1, 1, S<1, 4, 16>, S<1, 0, 2>, 2, 2, 0, S<1, 4, 16>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 64, 8, 8, 32, 32, 1, 1, S<8, 8, 4>, S<1, 0, 2>, 2, 2, 0, S<8, 8, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 16, 32, 32, 8, 8, 16, 16, 1, 1, S<4, 8, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 8, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 8, 4>, S<1, 0, 2>, 2, 2, 0, S<8, 8, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 8, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 16, 32, 32, 8, 8, 16, 16, 1, 1, S<4, 8, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 8, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 1, 4>, S<1, 0, 2>, 2, 2, 0, S<16, 1, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 1, 4>, S<1, 0, 2>, 2, 2, 0, S<16, 1, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>
// clang-format on
>;

View File

@@ -32,8 +32,8 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances =
// ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 32, 32, 4, 4, 32, 32, 1, 1, S<1, 16, 4>, S<0, 1, 2>, 1, 1, 0, S<1, 16, 4>, S<0, 1, 2>, 1, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 32, 32, 32, 4, 4, 32, 32, 1, 1, S<1, 16, 4>, S<0, 1, 2>, 1, 1, 0, S<1, 16, 4>, S<0, 1, 2>, 1, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>
// clang-format on
>;

View File

@@ -32,8 +32,8 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances =
// ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 32, 32, 4, 4, 32, 32, 1, 1, S<1, 16, 4>, S<0, 1, 2>, 1, 1, 0, S<8, 2, 4>, S<1, 0, 2>, 2, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 32, 32, 32, 4, 4, 32, 32, 1, 1, S<1, 16, 4>, S<0, 1, 2>, 1, 1, 0, S<8, 2, 4>, S<1, 0, 2>, 2, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>
// clang-format on
>;

View File

@@ -31,8 +31,8 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances =
// ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 32, 32, 4, 4, 32, 32, 1, 1, S<8, 2, 4>, S<1, 0, 2>, 2, 1, 0, S<1, 16, 4>, S<0, 1, 2>, 1, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 32, 32, 32, 4, 4, 32, 32, 1, 1, S<8, 2, 4>, S<1, 0, 2>, 2, 1, 0, S<1, 16, 4>, S<0, 1, 2>, 1, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>
// clang-format on
>;

View File

@@ -32,8 +32,8 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances =
// ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 0, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 0, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>
// clang-format on
>;

View File

@@ -6,6 +6,8 @@ list(APPEND GEMM_MX_INSTANCES
device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp
device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp
device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instance.cpp
device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instance.cpp
device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instance.cpp
)
@@ -13,6 +15,8 @@ set_source_files_properties(device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f
set_source_files_properties(device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
add_instance_library(device_gemm_mx_instance ${GEMM_MX_INSTANCES})

View File

@@ -13,12 +13,13 @@ namespace tensor_operation {
namespace device {
namespace instance {
using F8 = f8_t;
using BF8 = bf8_t;
using F16 = half_t;
using BF16 = bhalf_t;
using F32 = float;
using E8M0 = ck::e8m0_bexp_t;
using F8 = f8_t;
using BF8 = bf8_t;
using F16 = half_t;
using BF16 = bhalf_t;
using F32 = float;
using E8M0 = ck::e8m0_bexp_t;
using E8M0PK = int32_t;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
@@ -40,17 +41,19 @@ static constexpr auto ScaleBlockSize = 32;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_instances = std::tuple<
#if 0 // TODO: Fix RRR
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 64, 16, 128, 16, 4, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 256, 16, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 64, 256, 16, 4, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 16, 32, 512, 16, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
//#########################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 64, 16, 128, 16, 4, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 256, 16, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 64, 256, 16, 4, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 16, 32, 512, 16, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
// clang-format on
#endif
>;
} // namespace instance

View File

@@ -13,9 +13,9 @@ void add_device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instances(
Row,
Row,
BF8,
E8M0,
E8M0PK,
F8,
E8M0,
E8M0PK,
F16,
32,
PassThrough,

View File

@@ -0,0 +1,73 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F4 = f4x2_pk_t;
using F16 = half_t;
using F32 = float;
using E8M0 = ck::e8m0_bexp_t;
using E8M0PK = int32_t;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
using MFMA = tensor_layout::gemm::MFMA;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMPadding = GemmSpecialization::MPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto ScaleBlockSize = 32;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_instances = std::tuple<
// clang-format off
//#####################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#####################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#####################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
// DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 256, 128, 16, 16, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
// DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 384, 128, 16, 16, 16, 16, 2, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
// DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 512, 128, 16, 16, 16, 16, 2, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 256, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 384, 128, 16, 16, 16, 16, 4, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 512, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 128, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 256, 128, 16, 16, 16, 16, 6, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 384, 128, 16, 16, 16, 16, 6, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 512, 128, 16, 16, 16, 16, 6, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 256, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 384, 128, 16, 16, 16, 16, 8, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 512, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,32 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instances(
std::vector<std::unique_ptr<DeviceGemmMX<Row,
MFMA,
Row,
F4,
E8M0PK,
F4,
E8M0PK,
F16,
32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_instances<Intrawave, GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,65 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F4 = f4x2_pk_t;
using F16 = half_t;
using F32 = float;
using E8M0 = ck::e8m0_bexp_t;
using E8M0PK = int32_t;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMPadding = GemmSpecialization::MPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto ScaleBlockSize = 32;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_instances = std::tuple<
// clang-format off
//#############################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#############################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#############################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 256, 128, 16, 16, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 256, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 128, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 256, 128, 16, 16, 16, 16, 6, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 256, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, 128, 16, 16, 16, 16, 2, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,32 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instances(
std::vector<std::unique_ptr<DeviceGemmMX<Row,
Col,
Row,
F4,
E8M0PK,
F4,
E8M0PK,
F16,
32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_instances<Intrawave, GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -13,11 +13,12 @@ namespace tensor_operation {
namespace device {
namespace instance {
using F8 = f8_t;
using F16 = half_t;
using BF16 = bhalf_t;
using F32 = float;
using E8M0 = ck::e8m0_bexp_t;
using F8 = f8_t;
using F16 = half_t;
using BF16 = bhalf_t;
using F32 = float;
using E8M0 = ck::e8m0_bexp_t;
using E8M0PK = int32_t;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
@@ -39,19 +40,21 @@ static constexpr auto ScaleBlockSize = 32;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_instances = std::tuple<
#if 0 // TODO: Fix CCR
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 4, 16, 32, 32, 2, 2, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 16, 256, 128, 4, 16, 16, 16, 1, 4, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 8, 16, 16, 16, 1, 1, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 128, 8, 16, 16, 16, 8, 8, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 64, 4, 16, 32, 32, 4, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 128, 128, 4, 16, 16, 16, 4, 8, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
//#########################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 4, 16, 32, 32, 2, 2, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 16, 256, 128, 4, 16, 16, 16, 1, 4, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 8, 16, 16, 16, 1, 1, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 128, 8, 16, 16, 16, 8, 8, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 64, 4, 16, 32, 32, 4, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 128, 128, 4, 16, 16, 16, 4, 8, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
// clang-format on
#endif
>;
} // namespace instance

View File

@@ -13,9 +13,9 @@ void add_device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instances(
Col,
Row,
F8,
E8M0,
E8M0PK,
F8,
E8M0,
E8M0PK,
BF16,
32,
PassThrough,

View File

@@ -13,11 +13,12 @@ namespace tensor_operation {
namespace device {
namespace instance {
using F8 = f8_t;
using F16 = half_t;
using BF16 = bhalf_t;
using F32 = float;
using E8M0 = ck::e8m0_bexp_t;
using F8 = f8_t;
using F16 = half_t;
using BF16 = bhalf_t;
using F32 = float;
using E8M0 = ck::e8m0_bexp_t;
using E8M0PK = int32_t;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
@@ -40,15 +41,15 @@ static constexpr auto ScaleBlockSize = 32;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 256, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
//###########################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//###########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//###########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 256, 16, 16, 16, 16, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 64, 256, 16, 16, 16, 16, 4, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, 256, 16, 16, 16, 16, 2, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 32, 256, 16, 16, 16, 16, 4, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, 256, 16, 16, 16, 16, 2, 2, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>
// clang-format on
>;

View File

@@ -13,9 +13,9 @@ void add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances(
Col,
Row,
F8,
E8M0,
E8M0PK,
F8,
E8M0,
E8M0PK,
BF16,
32,
PassThrough,

View File

@@ -13,11 +13,12 @@ namespace tensor_operation {
namespace device {
namespace instance {
using F8 = f8_t;
using F16 = half_t;
using BF16 = bhalf_t;
using F32 = float;
using E8M0 = ck::e8m0_bexp_t;
using F8 = f8_t;
using F16 = half_t;
using BF16 = bhalf_t;
using F32 = float;
using E8M0 = ck::e8m0_bexp_t;
using E8M0PK = int32_t;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
@@ -40,15 +41,15 @@ static constexpr auto ScaleBlockSize = 32;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 256, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
//###########################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//###########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//###########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 256, 16, 16, 16, 16, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 64, 256, 16, 16, 16, 16, 4, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, 256, 16, 16, 16, 16, 2, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 32, 256, 16, 16, 16, 16, 4, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, 256, 16, 16, 16, 16, 2, 2, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>
// clang-format on
>;

View File

@@ -13,9 +13,9 @@ void add_device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instances(
Col,
Row,
F8,
E8M0,
E8M0PK,
F8,
E8M0,
E8M0PK,
F16,
32,
PassThrough,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
@@ -37,30 +37,30 @@ using device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instances = st
//#######################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Prefetch| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#######################################| | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | Wave| Wave| Lengths_KBatch_K0_M_K1| | | PerVector| | Lengths_KBatch_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 16, 128, 4, 16, 16, 16, 1, 2, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 1, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 16, 16, 4, 16, 16, 16, 1, 1, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 16, 16, 8, 16, 16, 16, 1, 1, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 16, 128, 8, 8, 16, 16, 1, 2, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 2, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 2, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 2, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 2, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 1, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 16, 1, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 32, 32, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 16, 64, 8, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 64, 4, 32, 16, 16, 1, 2, S<1, 2, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 2, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 64, 16, 16, 4, 32, 16, 16, 1, 1, S<1, 1, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 256, 64, 16, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 256, 16, 64, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 64, 16, 16, 8, 16, 16, 16, 1, 1, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 128, 128, 8, 8, 32, 32, 2, 2, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 32, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 16, 64, 16, 8, 16, 16, 1, 1, S<1, 16, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 16, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 64, 16, 8, 16, 16, 1, 2, S<1, 16, 2, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 16, 2, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 1, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 16, 1, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 256, 64, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 256, 16, 64, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 1, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 16, 1, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 128, 4, 32, 16, 16, 1, 2, S<1, 4, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 32, 32, 8, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 32, 32, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 64, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 4, 32, 16, 16, 1, 1, S<1, 1, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 64, 16, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 8, 8, 32, 32, 2, 2, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 128, 16, 8, 16, 16, 1, 2, S<1, 16, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 16, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 32, 32, 16, 8, 16, 16, 1, 1, S<1, 16, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 16, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 32, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 64, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 1, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 16, 1, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 64, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>
// clang-format on
>;