New instances for gemm_multiply_multiply_weightpreshuffle operator (#2061)

* Add new instances for weight_preshuffle for f8->bf16

* Add new instances for weight_preshuffle for f8->f16

* clang formatted

---------

Co-authored-by: Khushbu Agarwal <khuagar@amd.com>
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Khushbu Agarwal
2025-04-08 15:14:53 -07:00
committed by GitHub
parent 2c8132126c
commit 263ff689e0
2 changed files with 18 additions and 2 deletions

View File

@@ -100,7 +100,17 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 512, 16, 16, 16, 16, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 512, 16, 16, 16, 16, 1, 4, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 512, 16, 16, 16, 16, 1, 4, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 16, 512, 16, 16, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 4>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 512, 16, 16, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 64, 512, 16, 16, 16, 16, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 64, 512, 16, 16, 32, 32, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 16, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 16, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>
// clang-format on
>;

View File

@@ -115,7 +115,13 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 256, 16, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 512, 256, 16, 16, 16, 16, 1, 8, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 512, 256, 16, 16, 16, 16, 1, 8, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 512, 16, 16, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 256, 16, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 128, 8, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>
// clang-format on
>;