From 72054549e719b8e19f97fcd7c9b81b68da85d020 Mon Sep 17 00:00:00 2001 From: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Date: Thu, 5 Jun 2025 13:54:15 -0600 Subject: [PATCH] Optimized GEMMs for MX FP4/8 (#2294) Adds V3 GEMM pipeline for MX FP4 and MX FP8 Adds V3 GEMM pipeline for MX FP4 with preshuffling Adds MXFP4 GEMM tests (#2275) Adds MXFP4 GEMM examples Adds MXFP4 GEMMs to ckProfiler Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Co-authored-by: Andriy Roshchenko Co-authored-by: aska-0096 Co-authored-by: lalala-sh Co-authored-by: OscarXu Co-authored-by: mtgu0705 Co-authored-by: Ding, Yi Co-authored-by: feifei14119 Co-authored-by: Lin, Qun Co-authored-by: joye Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> [ROCm/composable_kernel commit: 00247e3c297032a2cbdaae465113648ec1857d3f] --- CHANGELOG.md | 2 +- example/01_gemm/CMakeLists.txt | 6 + ..._add_fastgelu_xdl_lds_direct_load_fp32.cpp | 4 +- .../batched_gemm_xdl_fp8_rowwise_v3.cpp | 12 +- .../splitK_gemm_xdl_lds_direct_load_fp16.cpp | 4 +- example/67_gemm_microscaling/CMakeLists.txt | 37 +- example/67_gemm_microscaling/gemm_mx_bf8.cpp | 23 +- .../67_gemm_microscaling/gemm_mx_common.hpp | 260 +- example/67_gemm_microscaling/gemm_mx_fp4.cpp | 105 + .../gemm_mx_fp4_bpreshuffle.cpp | 105 + example/67_gemm_microscaling/gemm_mx_fp8.cpp | 23 +- .../67_gemm_microscaling/gemm_mx_fp8_bf8.cpp | 19 +- example/CMakeLists.txt | 8 +- ...blockwise_gemm_mx_pipeline_xdlops_base.hpp | 164 +- ...ipeline_xdlops_b_preshuffle_dequant_v3.hpp | 2 +- ...e_gemm_pipeline_xdlops_b_preshuffle_v1.hpp | 4 +- .../blockwise_gemm_pipeline_xdlops_base.hpp | 20 +- ...ipeline_xdlops_mx_bpreshuffle_selector.hpp | 68 + ...kwise_gemm_pipeline_xdlops_mx_selector.hpp | 55 +- ...kwise_gemm_pipeline_xdlops_v1_ab_scale.hpp | 2 +- .../blockwise_gemm_pipeline_xdlops_v1_mx.hpp | 525 ++-- .../blockwise_gemm_pipeline_xdlops_v3.hpp | 2 +- ...kwise_gemm_pipeline_xdlops_v3_ab_scale.hpp | 2 +- ...ckwise_gemm_pipeline_xdlops_v3_b_scale.hpp | 2 +- .../blockwise_gemm_pipeline_xdlops_v3_mx.hpp | 1090 ++++++++ ...gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp | 1042 ++++++++ .../blockwise_gemm_pipeline_xdlops_v5.hpp | 2 +- ...roup_tensor_slice_transfer_direct_load.hpp | 63 +- .../gpu/device/device_gemm_mx.hpp | 38 + .../impl/device_gemm_xdl_cshuffle_v3_mx.hpp | 563 +--- ...m_xdl_splitk_c_shuffle_lds_direct_load.hpp | 2 + .../element/unary_element_wise_operation.hpp | 7 + ...ultiple_d_xdl_cshuffle_lds_direct_load.hpp | 36 +- .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 3 +- .../gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp | 5 +- ...m_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp | 18 +- .../grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp | 986 +++---- ...se_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp | 2295 +++++++++++++++++ ...ise_gemm_xdlops_splitk_lds_direct_load.hpp | 33 +- .../threadwise_tensor_slice_transfer.hpp | 249 +- .../threadwise_tensor_slice_transfer_util.hpp | 12 + .../threadwise_tensor_slice_transfer_v3r1.hpp | 7 +- ...wise_tensor_slice_transfer_v3r1_gather.hpp | 9 +- ...ise_tensor_slice_transfer_v7r3_scatter.hpp | 2 - .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 119 +- include/ck/utility/amd_buffer_addressing.hpp | 24 +- .../amd_buffer_addressing_builtins.hpp | 10 +- include/ck/utility/amd_xdlops.hpp | 220 +- include/ck/utility/blkgemmpipe_scheduler.hpp | 14 +- include/ck/utility/data_type.hpp | 159 +- include/ck/utility/dtype_vector.hpp | 7 + include/ck/utility/functional2.hpp | 43 +- 
include/ck/utility/integral_constant.hpp | 14 +- include/ck/utility/type_convert.hpp | 5 + ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 2 +- .../cpu/reference_mx_gemm.hpp | 68 +- .../device_operation_instance_factory.hpp | 9 +- .../tensor_operation_instance/gpu/gemm_mx.hpp | 105 +- ...ect_load_f16_f16_f16_mk_nk_mn_instance.cpp | 26 +- ...ect_load_f32_f32_f32_km_kn_mn_instance.cpp | 4 +- ...ect_load_f32_f32_f32_km_nk_mn_instance.cpp | 4 +- ...ect_load_f32_f32_f32_mk_kn_mn_instance.cpp | 4 +- ...ect_load_f32_f32_f32_mk_nk_mn_instance.cpp | 4 +- .../gpu/gemm_mx/CMakeLists.txt | 4 + ...device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp | 33 +- ...l_bf8_f8_f16_mk_kn_mn_default_instance.cpp | 4 +- ...evice_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn.hpp | 73 + ..._f4_f4_f16_mk_mfma_mn_default_instance.cpp | 32 + .../device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn.hpp | 65 + ...dl_f4_f4_f16_mk_nk_mn_default_instance.cpp | 32 + ...device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp | 35 +- ...l_f8_f8_bf16_km_nk_mn_default_instance.cpp | 4 +- ...device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp | 29 +- ...l_f8_f8_bf16_mk_nk_mn_default_instance.cpp | 4 +- .../device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp | 29 +- ...dl_f8_f8_f16_mk_nk_mn_default_instance.cpp | 4 +- ...ect_load_f16_f16_f16_mk_nk_mn_instance.cpp | 46 +- .../include/profiler/profile_gemm_mx_impl.hpp | 534 ++++ profiler/src/CMakeLists.txt | 6 + profiler/src/profile_gemm_mx.cpp | 155 ++ test/gemm_mx/test_gemm_mx.cpp | 33 +- test/gemm_mx/test_gemm_mx_util.hpp | 434 +--- test/mx_mfma_op/mx_mfma_op.hpp | 45 +- 83 files changed, 8193 insertions(+), 2165 deletions(-) create mode 100644 example/67_gemm_microscaling/gemm_mx_fp4.cpp create mode 100644 example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_bpreshuffle_selector.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instance.cpp create mode 100644 profiler/include/profiler/profile_gemm_mx_impl.hpp create mode 100644 profiler/src/profile_gemm_mx.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ec0c1ecce..aecf16d83d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for GKCYX layout for grouped convolution backward weight (NGCHW/GKCYX/NGKHW). * Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW). * Added support for Stream-K version of mixed fp8/bf16 GEMM -* Added GEMM pipeline for microscaling (MX) data types +* Added GEMM pipeline for microscaling (MX) FP8/FP4 data types * Added support for FP16 2:4 structured sparsity to universal GEMM. 
* Added support for Split K for grouped convolution backward data. * Added logit soft-capping support for fMHA forward kernels. diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 24292be4fe..e6a26ecafd 100755 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -39,6 +39,12 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_streamk_v3) add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3) +set(GEMM_OPTIONS) +list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-16") +example_compile_options(example_gemm_xdl_fp8_v3 PRIVATE ${GEMM_OPTIONS}) +example_compile_options(example_gemm_xdl_bf16_v3 PRIVATE ${GEMM_OPTIONS}) + + list(APPEND gpu_list gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_lds_direct_load_fp32.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_lds_direct_load_fp32.cpp index de7af85fb3..67b3e646f7 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_lds_direct_load_fp32.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_lds_direct_load_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -34,7 +34,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<1, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 1, 8>, S<1, 0, 2>, 2, 1, 0, S<8, 1, 8>, S<1, 0, 2>, 2, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm #include #include @@ -71,9 +71,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD 256, // BlockSize 256, // MPerBlock 128, // NPerBlock - 32, // KPerBlock - 8, // AK1 - 8, // BK1 + 64, // KPerBlock + 16, // AK1 + 16, // BK1 32, // MPerXDL 32, // NPerXDL 4, // MXdlPerWave @@ -84,14 +84,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD 2, // ABlockTransferSrcVectorDim 8, // ABlockTransferSrcScalarPerVector 8, // ABlockTransferDstScalarPerVector_AK1 - 1, // ABlockLdsExtraM + 
0, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
- 1, // BBlockLdsExtraN
+ 0, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp
index 97a3f89e5e..fc55019fc4 100644
--- a/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp
+++ b/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include
#include
@@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | Wave| Wave| Lengths_KBatch_K0_M_K1| | | PerVector| | Lengths_KBatch_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
- < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 16, 16, 16, 1, 1, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, 1, 1, S<1, 32, 1, 4>, 4>;
+ < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 8, 16, 16, 1, 1, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>;
// clang-format on
#else
diff --git a/example/67_gemm_microscaling/CMakeLists.txt b/example/67_gemm_microscaling/CMakeLists.txt
index 1a1db51c37..86d90674e1 100644
--- a/example/67_gemm_microscaling/CMakeLists.txt
+++ b/example/67_gemm_microscaling/CMakeLists.txt
@@ -6,6 +6,39 @@ add_example_dependencies(example_gemm_mx example_gemm_mx_fp8)
add_example_executable(example_gemm_mx_bf8 gemm_mx_bf8.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_bf8)
-add_example_executable(example_gemm_mx_fp8_bf8 gemm_mx_fp8_bf8.cpp)
-add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bf8)
+#add_example_executable(example_gemm_mx_fp8_bf8 gemm_mx_fp8_bf8.cpp)
+# add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bf8) TODO: Fix RRR
+add_example_executable(example_gemm_mx_fp4 gemm_mx_fp4.cpp)
+add_example_dependencies(example_gemm_mx example_gemm_mx_fp4)
+
+add_example_executable(example_gemm_mx_fp4_bpreshuffle gemm_mx_fp4_bpreshuffle.cpp)
+add_example_dependencies(example_gemm_mx example_gemm_mx_fp4_bpreshuffle)
+
+#add_example_executable(example_moe_gemm1_xdl_mx_fp4 moe_gemm1_xdl_mx_fp4.cpp)
+# add_example_dependencies(example_gemm_mx
example_moe_gemm1_xdl_mx_fp4) TODO: Fix + +#add_example_executable(example_moe_gemm1_xdl_mx_fp4_bns moe_gemm1_xdl_mx_fp4_bns.cpp) +#add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4_bns) + +#add_example_executable(example_moe_gemm2_xdl_mx_fp4 moe_gemm2_xdl_mx_fp4.cpp) +# add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4) TODO: Fix + +#add_example_executable(example_moe_gemm2_xdl_mx_fp4_bns moe_gemm2_xdl_mx_fp4_bns.cpp) +#add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4_bns) + +set(FP4_MXGEMM_OPTIONS) +list(APPEND FP4_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --amdgpu-use-amdgpu-trackers=1") +#list(APPEND FP4_MXGEMM_OPTIONS -v --save-temps -Wno-gnu-line-marker -ftemplate-backtrace-limit=0) +example_compile_options(example_gemm_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS}) +example_compile_options(example_gemm_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS}) +# example_compile_options(example_moe_gemm1_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS}) +# example_compile_options(example_moe_gemm2_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS}) +# example_compile_options(example_moe_gemm1_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS}) +# example_compile_options(example_moe_gemm2_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS}) + +set(FP8_MXGEMM_OPTIONS) +list(APPEND FP8_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32") +#list(APPEND FP8_MXGEMM_OPTIONS -v --save-temps -Wno-gnu-line-marker -ftemplate-backtrace-limit=0) +example_compile_options(example_gemm_mx_fp8 PRIVATE ${FP8_MXGEMM_OPTIONS}) +example_compile_options(example_gemm_mx_bf8 PRIVATE ${FP8_MXGEMM_OPTIONS}) diff --git a/example/67_gemm_microscaling/gemm_mx_bf8.cpp b/example/67_gemm_microscaling/gemm_mx_bf8.cpp index 8e341fb591..58f2dcb010 100644 --- a/example/67_gemm_microscaling/gemm_mx_bf8.cpp +++ b/example/67_gemm_microscaling/gemm_mx_bf8.cpp @@ -21,11 +21,11 @@ using BElementOp = PassThrough; // elementwise transformation for B matrix using CElementOp = PassThrough; // elementwise transformation for C matrix constexpr ck::index_t ScaleBlockSize = 32; // scaling block size -constexpr ck::index_t KPerBlock = 128; +constexpr ck::index_t KPerBlock = 256; constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; -constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1; +constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3; using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< ALayout, // ALayout @@ -45,32 +45,32 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle ScaleBlockSize, // ScaleBlockSize: Scaling block size 128, // BlockSize: Thread block size 128, // MPerBlock - 16, // NPerBlock + 32, // NPerBlock KPerBlock, // KPerBlock 16, // AK1 16, // BK1 16, // MPerXDL 16, // NPerXDL 4, // MXdlPerWave - 1, // NXdlPerWave - S<8, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + 2, // NXdlPerWave + S<16, 8, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder 2, // ABlockTransferSrcVectorDim 16, // ABlockTransferSrcScalarPerVector 16, // ABlockTransferDstScalarPerVector_AK1 - false, // ABlockLdsExtraM - S<8, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + true, // ABlockLdsExtraM + S<16, 8, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 
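+ // Illustrative note (not in the original patch): the 16 x 8 x 1 thread
+ // clusters above map 16 * 8 * 1 = 128 threads, matching this example's
+ // BlockSize of 128.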
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
16, // BBlockTransferSrcScalarPerVector
16, // BBlockTransferDstScalarPerVector_BK1
- false, // BBlockLdsExtraN
- 1, // CShuffleMXdlPerWavePerShuffle
- 1, // CShuffleNXdlPerWavePerShuffle
+ true, // BBlockLdsExtraN
+ 2, // CShuffleMXdlPerWavePerShuffle
+ 2, // CShuffleNXdlPerWavePerShuffle
S<1, 16, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
- 2, // CShuffleBlockTransferScalarPerVector_NPerBlock
+ 4, // CShuffleBlockTransferScalarPerVector_NPerBlock
BlkGemmPSched, // BlkGemmPipeSched
BlkGemmPVer, // BlkGemmPipelineVer
ADataType, // ComputeTypeA
@@ -83,6 +83,7 @@ int main(int argc, char* argv[])
{
ADataType,
BDataType,
XDataType,
+ XDataType,
CDataType,
ALayout,
BLayout,
diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp
index 99ed2a23b9..30df8ccd37 100644
--- a/example/67_gemm_microscaling/gemm_mx_common.hpp
+++ b/example/67_gemm_microscaling/gemm_mx_common.hpp
@@ -23,8 +23,9 @@ template
using S = ck::Sequence;
-using Row = ck::tensor_layout::gemm::RowMajor;
-using Col = ck::tensor_layout::gemm::ColumnMajor;
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+using MFMA = ck::tensor_layout::gemm::MFMA;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
@@ -36,6 +37,8 @@ struct ExecutionConfig final
int init_method = 2; // (0=constant values, 1=integer values, 2=decimal values)
bool time_kernel = false; // (0=no, 1=yes)
int verbosity = 0; // (0=no info, 1=verbose info)
+ int warm_up = 10;
+ int repeat = 10;
};
struct ProblemSizeSplitK final
@@ -86,6 +89,8 @@ bool parse_cmd_args(int argc,
if(argc >= 12)
{
problem_size.KBatch = std::stoi(argv[11]);
+ }
+ // argc >= 12 only guarantees argv[11]; guard the two new reads separately
+ if(argc >= 14)
+ {
+ config.warm_up = std::stoi(argv[12]);
+ config.repeat = std::stoi(argv[13]);
}
}
else
@@ -103,10 +108,90 @@ bool parse_cmd_args(int argc,
return true;
}
+template
+void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
+{
+ int MNXdlPack = 2;
+ int KXdlPack = 2;
+
+ int XdlMNThread = 16;
+ int XdlKThread = 64 / XdlMNThread;
+
+ int K0 = K / KXdlPack / XdlKThread; // KRepeat
+
+ // The 4 16x128 building blocks are packed into one 32x256 block for F4.
+ // The 8 16x16x128 mfma are packed into one 32x32x256 mfma for F4.
+
+ // Unfold the MN32xK(256/32) scale buffer,
+ // 4 16 2 2
+ // fastest to slowest: XdlKThread -> XdlMNThread -> KXdlPack -> MNXdlPack,
+ // then MNRepeat -> KRepeat.
+
+ for(int n = 0; n < MN; ++n)
+ {
+ for(int k = 0; k < K; ++k)
+ {
+ int n0 = n / (XdlMNThread * MNXdlPack); // index in MNRepeat
+ int tempn = n % (XdlMNThread * MNXdlPack);
+ int n1 = tempn % XdlMNThread; // index in XdlMNThread
+ int n2 = tempn / XdlMNThread; // index in MNXdlPack
+
+ int k0 = k / (XdlKThread * KXdlPack); // index in KRepeat
+ int tempk = k % (XdlKThread * KXdlPack);
+ int k1 = tempk % XdlKThread; // index in XdlKThread
+ int k2 = tempk / XdlKThread; // index in KXdlPack
+
+ int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
+ k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
+ k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
+ k2 * MNXdlPack + n2;
+
+ if constexpr(KLast)
+ dst[outputIndex] = src[n * K + k];
+ else
+ dst[outputIndex] = src[k * MN + n];
+ }
+ }
+}
+
+void preShuffleBuffer(const ck::f4x2_pk_t* src, ck::f4x2_pk_t* dst, int N, int K, int NXdl)
+{
+ int KPack =
16; + int NLane = NXdl; + int KLane = 64 / NLane; + int K_pk = K / 2; + int K0 = K_pk / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K_pk; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K_pk + k]; + } + } +} + template bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& config) { + constexpr bool BPreShuffle = ck::is_same_v; + using BRefLayout = ck::conditional_t; auto M = problem_size.M; auto N = problem_size.N; @@ -131,28 +218,19 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c auto f_host_tensor_descriptor = [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { if constexpr(std::is_same_v) - { return HostTensorDescriptor({row, col}, {stride, 1}); - } else - { return HostTensorDescriptor({row, col}, {1, stride}); - } }; - auto f_get_default_stride = [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { if(stride == -1) { // give a chance if stride is -1, return a default packed stride if constexpr(std::is_same_v) - { return static_cast(col); - } else - { return static_cast(row); - } } else return static_cast(stride); @@ -172,16 +250,30 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c using AScaleLayout = Row; using BScaleLayout = Col; - auto Scale_Stride_AM = f_get_default_stride(M, K / ScaleBlockSize, -1, AScaleLayout{}); + auto Scale_Padded_M = (M + ScaleBlockSize - 1) / ScaleBlockSize * ScaleBlockSize; + auto Scale_Stride_AM = + f_get_default_stride(Scale_Padded_M, K / ScaleBlockSize, -1, AScaleLayout{}); auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{}); Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + auto b_k_n = + std::make_shared>(f_host_tensor_descriptor(K, N, StrideB, BRefLayout{})); + auto b_input = b_k_n; + if constexpr(BPreShuffle) + b_input = std::make_shared>( + f_host_tensor_descriptor(K, N, StrideB, BRefLayout{})); // use layout only for size + // scales for A and B Tensor a_m_k_scale(f_host_tensor_descriptor( - M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); // scales for A - Tensor b_k_n_scale(f_host_tensor_descriptor( - K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); // scales for B + Scale_Padded_M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); + Tensor b_k_n_scale( + f_host_tensor_descriptor(K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); + + // shuffled scales for A and B + Tensor a_shuffled_scale(f_host_tensor_descriptor( + Scale_Padded_M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); + Tensor b_shuffled_scale( + f_host_tensor_descriptor(K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); Tensor c_m_n_host_result( f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // host verification @@ -192,18 +284,31 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c { std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout 
<< "b_k_n: " << b_k_n->mDesc << std::endl; std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl; std::cout << "c_m_n_device_result: " << c_m_n_device_result.mDesc << std::endl; } + auto a_data_element = [](float x) { + if constexpr(ck::is_same_v) + return ck::type_convert(ck::float2_t(x)); + else + return ck::type_convert(x); + }; + auto b_data_element = [](float x) { + if constexpr(ck::is_same_v) + return ck::type_convert(ck::float2_t(x)); + else + return ck::type_convert(x); + }; + switch(config.init_method) { case 0: // Initializations for development and debugging - ck::utils::FillConstant{ck::type_convert(1.0f)}(a_m_k); - ck::utils::FillConstant{ck::type_convert(2.0f)}(a_m_k_scale); - ck::utils::FillConstant{ck::type_convert(0.5f)}(b_k_n); - ck::utils::FillConstant{ck::type_convert(1.0f)}(b_k_n_scale); + ck::utils::FillConstant{a_data_element(1.0f)}(a_m_k); + ck::utils::FillConstant{ck::type_convert(1.0f)}(a_m_k_scale); + ck::utils::FillConstant{b_data_element(2.0f)}(*b_k_n); + ck::utils::FillConstant{ck::type_convert(0.5f)}(b_k_n_scale); if(config.verbosity > 0) { std::cout << "Init A = {1}" << std::endl; @@ -216,29 +321,20 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] - - if constexpr(ck::is_same_v) - { - a_m_k_scale.GenerateTensorValue( - GeneratorTensor_2{125, 129}); // scales: {0.25, 0.5, 1, 2} - b_k_n_scale.GenerateTensorValue( - GeneratorTensor_2{125, 129}); // scales: {0.25, 0.5, 1, 2} - } - else - { - ck::utils::FillUniformDistributionIntegerValue{-1.0f, 1.0f}(a_m_k_scale); - ck::utils::FillUniformDistributionIntegerValue{-1.0f, 1.0f}(b_k_n_scale); - } - + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] + b_k_n->GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] + static_assert(ck::is_same_v); + a_m_k_scale.GenerateTensorValue( + GeneratorTensor_2{120, 129}); // scales: {0.25, 0.5, 1, 2} + b_k_n_scale.GenerateTensorValue( + GeneratorTensor_2{125, 129}); // scales: {0.25, 0.5, 1, 2} break; case 2: a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); a_m_k_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); + b_k_n->GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); b_k_n_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); break; @@ -249,20 +345,33 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c } } + preShuffleScaleBuffer>(a_m_k_scale.mData.data(), + a_shuffled_scale.mData.data(), + Scale_Padded_M, + K / ScaleBlockSize); + preShuffleScaleBuffer>( + b_k_n_scale.mData.data(), b_shuffled_scale.mData.data(), N, K / ScaleBlockSize); + if constexpr(BPreShuffle) + { + int NPerXdl = 16; // Fixed 16 + preShuffleBuffer(b_k_n->mData.data(), b_input->mData.data(), N, K, NPerXdl); + } + if(config.verbosity > 0) std::cout << "Device memory allocation..." 
<< std::endl; - DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem a_scale_device_buf(sizeof(XDataType) * a_m_k_scale.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); - DeviceMem b_scale_device_buf(sizeof(XDataType) * b_k_n_scale.mDesc.GetElementSpaceSize()); - DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.GetElementSpaceSize()); + DeviceMem a_scale_device_buf(sizeof(XDataType) * a_m_k_scale.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n->GetElementSpaceSize()); + DeviceMem b_scale_device_buf(sizeof(XDataType) * b_k_n_scale.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.GetElementSpaceSize()); if(config.verbosity > 0) std::cout << "Upload data to device..." << std::endl; a_device_buf.ToDevice(a_m_k.mData.data()); - a_scale_device_buf.ToDevice(a_m_k_scale.mData.data()); - b_device_buf.ToDevice(b_k_n.mData.data()); - b_scale_device_buf.ToDevice(b_k_n_scale.mData.data()); + a_scale_device_buf.ToDevice(a_shuffled_scale.mData.data()); + b_device_buf.ToDevice(b_input->mData.data()); + b_scale_device_buf.ToDevice(b_shuffled_scale.mData.data()); + if(config.verbosity > 0) std::cout << "Done." << std::endl; @@ -275,9 +384,9 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c auto invoker = device_op.MakeInvoker(); auto argument = device_op.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(a_scale_device_buf.GetDeviceBuffer()), + static_cast(a_scale_device_buf.GetDeviceBuffer()), static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(b_scale_device_buf.GetDeviceBuffer()), + static_cast(b_scale_device_buf.GetDeviceBuffer()), static_cast(c_device_buf.GetDeviceBuffer()), M, N, @@ -299,13 +408,26 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c "not consistent with the supported device_gemm arguments."); } + std::size_t total_size = + a_m_k.GetElementSpaceSizeInBytes() + b_k_n->GetElementSpaceSizeInBytes() + + a_m_k_scale.GetElementSpaceSizeInBytes() + b_k_n_scale.GetElementSpaceSizeInBytes() + + a_shuffled_scale.GetElementSpaceSizeInBytes() + + b_shuffled_scale.GetElementSpaceSizeInBytes(); + const auto total_cnt = ck::math::integer_divide_ceil(512 * 1024 * 1024, total_size); + const int rotating_count = std::max(1, std::min(config.repeat, static_cast(total_cnt))); if(config.verbosity > 0) { std::cout << "Computing GEMM on device..." << std::endl << std::endl; } - float ave_time = - invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, config.verbosity, 20, 50}); + float ave_time = invoker.Run(argument, + StreamConfig{nullptr, + config.time_kernel, + config.verbosity, + config.warm_up, + config.repeat, + rotating_count > 1, + rotating_count}); bool res_verified = true; if(config.do_verification > 0) @@ -332,7 +454,7 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c auto ref_argument = ref_gemm.MakeArgument(a_m_k, a_m_k_scale, - b_k_n, + *b_k_n, b_k_n_scale, c_m_n_host_result, PassThrough{}, @@ -347,20 +469,21 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c std::cout << "Comparing results..." 
<< std::endl; } - if(config.init_method == 0) - { - auto expected = static_cast(K); - auto computed = type_convert(c_m_n_device_result(1, 12)); + // if(config.init_method == 0) + // { + // auto expected = static_cast(K); + // auto computed = type_convert(c_m_n_device_result(1, 12)); - res_verified = res_verified && std::abs(expected - computed) <= 0.0f; - std::cout << "\nExpected vs Computed: " << expected << " vs " << computed - << ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl - << std::endl; - } + // res_verified = res_verified && std::abs(expected - computed) <= 0.0f; + // std::cout << "\nExpected vs Computed: " << expected << " vs " << computed + // << ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl + // << std::endl; + // } - res_verified = res_verified && ck::utils::check_err(c_m_n_device_result, - c_m_n_host_result, - "Error: Incorrect results!"); + res_verified = + res_verified && + ck::utils::check_err( + c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", 5e-1, 5e-1); if(config.verbosity > 0 && res_verified) std::cout << "Verification Successful!" << std::endl; @@ -377,13 +500,14 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c // partial sums(K/ScaleBlockSize)] // FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize; - std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + - sizeof(CDataType) * M * N + - sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize; + std::size_t num_btype = + sizeof(ADataType) * M * K / ck::packed_size_v + + sizeof(BDataType) * K * N / ck::packed_size_v + sizeof(CDataType) * M * N + + sizeof(XDataType) * M * K / ScaleBlockSize + sizeof(XDataType) * N * K / ScaleBlockSize; float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_btype / 1.E6 / ave_time; + float gb_per_sec = static_cast(num_btype) / 1e6f / ave_time; std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << device_op.GetTypeString() << std::endl; @@ -396,6 +520,7 @@ template , // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + true, // ABlockLdsExtraM + S<8, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + true, // BBlockLdsExtraN + 2, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 
0
+ : -1;
+}
diff --git a/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp
new file mode 100644
index 0000000000..562b2fdb17
--- /dev/null
+++ b/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp
@@ -0,0 +1,105 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "gemm_mx_common.hpp"
+
+using ADataType = ck::f4x2_pk_t;
+using BDataType = ck::f4x2_pk_t;
+// using ADataType = ck::f4_t;
+// using BDataType = ck::f4_t;
+
+using XDataType = ck::e8m0_bexp_t;
+using XPackedDataType = int32_t;
+
+using CDataType = ck::half_t;
+using AccDataType = float;
+using CShuffleDataType = CDataType;
+
+using ALayout = Row;
+using BLayout = MFMA;
+using CLayout = Row;
+
+using AElementOp = PassThrough; // elementwise transformation for A matrix
+using BElementOp = PassThrough; // elementwise transformation for B matrix
+using CElementOp = PassThrough; // elementwise transformation for C matrix
+
+constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
+constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
+constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
+
+constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
+constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
+constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
+
+// AB DataType: f4x2_pk_t
+// Mathematically, all numbers are represented as f4x2.
+using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
+ ALayout, // ALayout
+ BLayout, // BLayout
+ CLayout, // CLayout
+ ADataType, // ADataType
+ XPackedDataType, // AScaleDataType
+ BDataType, // BDataType
+ XPackedDataType, // BScaleDataType
+ CDataType, // CDataType
+ AccDataType, // GemmAccDataType
+ CShuffleDataType, // CShuffleDataType
+ AElementOp, // AElementwiseOperation
+ BElementOp, // BElementwiseOperation
+ CElementOp, // CElementwiseOperation
+ GemmSpec, // GemmSpec
+ ScaleBlockSize, // ScaleBlockSize: Scaling block size
+ 256, // BlockSize: Thread block size
+ 128, // MPerBlock
+ 512, // NPerBlock
+ KPerBlock, // KPerBlock
+ 16, // AK1
+ 16, // BK1
+ 16, // MPerXDL
+ 16, // NPerXDL
+ 8, // MXdlPerWave
+ 8, // NXdlPerWave
+ S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
+ S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
+ S<1, 0, 2>, // ABlockTransferSrcAccessOrder
+ 2, // ABlockTransferSrcVectorDim
+ 16, // ABlockTransferSrcScalarPerVector
+ 16, // ABlockTransferDstScalarPerVector_AK1
+ true, // ABlockLdsExtraM
+ S<8, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
+ S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
+ S<1, 0, 2>, // BBlockTransferSrcAccessOrder
+ 2, // BBlockTransferSrcVectorDim
+ 16, // BBlockTransferSrcScalarPerVector
+ 16, // BBlockTransferDstScalarPerVector_BK1
+ true, // BBlockLdsExtraN
+ 2, // CShuffleMXdlPerWavePerShuffle
+ 2, // CShuffleNXdlPerWavePerShuffle
+ S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
+ 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
+ BlkGemmPSched, // BlkGemmPipeSched
+ BlkGemmPVer, // BlkGemmPipelineVer
+ ADataType, // ComputeTypeA
+ BDataType // ComputeTypeB
+ >;
+
+int main(int argc, char* argv[])
+{
+ return run_mx_gemm_example(argc, argv)
+ ?
0 + : -1; +} diff --git a/example/67_gemm_microscaling/gemm_mx_fp8.cpp b/example/67_gemm_microscaling/gemm_mx_fp8.cpp index 9fc5666197..e6fe791178 100644 --- a/example/67_gemm_microscaling/gemm_mx_fp8.cpp +++ b/example/67_gemm_microscaling/gemm_mx_fp8.cpp @@ -25,7 +25,7 @@ constexpr ck::index_t KPerBlock = 256; constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; -constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1; +constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3; using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< ALayout, // ALayout @@ -49,26 +49,26 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle KPerBlock, // KPerBlock 16, // AK1 16, // BK1 - 32, // MPerXDL - 32, // NPerXDL - 2, // MXdlPerWave - 2, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + 16, // MPerXDL + 16, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder 2, // ABlockTransferSrcVectorDim 16, // ABlockTransferSrcScalarPerVector 16, // ABlockTransferDstScalarPerVector_AK1 - false, // ABlockLdsExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + true, // ABlockLdsExtraM + S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder 2, // BBlockTransferSrcVectorDim 16, // BBlockTransferSrcScalarPerVector 16, // BBlockTransferDstScalarPerVector_BK1 - false, // BBlockLdsExtraN - 1, // CShuffleMXdlPerWavePerShuffle - 1, // CShuffleNXdlPerWavePerShuffle + true, // BBlockLdsExtraN + 2, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock BlkGemmPSched, // BlkGemmPipeSched @@ -83,6 +83,7 @@ int main(int argc, char* argv[]) ADataType, BDataType, XDataType, + XDataType, CDataType, ALayout, BLayout, diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_bf8.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_bf8.cpp index ce4ebc0a40..fdc4ace471 100644 --- a/example/67_gemm_microscaling/gemm_mx_fp8_bf8.cpp +++ b/example/67_gemm_microscaling/gemm_mx_fp8_bf8.cpp @@ -24,7 +24,7 @@ constexpr ck::index_t ScaleBlockSize = 32; // scaling block size constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; -constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1; +constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3; using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< ALayout, // ALayout @@ -43,30 +43,30 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle GemmSpec, // GemmSpec ScaleBlockSize, // ScaleBlockSize: Scaling block size 256, // BlockSize: Thread block size - 256, // MPerBlock - 256, // NPerBlock - 128, // KPerBlock + 128, // MPerBlock + 128, // NPerBlock + 256, // KPerBlock 16, // AK1 8, // BK1 16, // MPerXDL 16, // NPerXDL - 8, // MXdlPerWave - 8, // NXdlPerWave - S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + 4, // MXdlPerWave + 4, // NXdlPerWave + S<16, 16, 1>, // 
ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder 2, // ABlockTransferSrcVectorDim 16, // ABlockTransferSrcScalarPerVector 16, // ABlockTransferDstScalarPerVector_AK1 false, // ABlockLdsExtraM - S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<32, 8, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 S<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder S<0, 2, 1>, // BBlockTransferSrcAccessOrder 1, // BBlockTransferSrcVectorDim 16, // BBlockTransferSrcScalarPerVector 8, // BBlockTransferDstScalarPerVector_BK1 false, // BBlockLdsExtraN - 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleMXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock @@ -82,6 +82,7 @@ int main(int argc, char* argv[]) ADataType, BDataType, XDataType, + XDataType, CDataType, ALayout, BLayout, diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index c86b434212..54d9f13453 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -222,12 +222,18 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) set(result 0) endif() - + #message("add_example returns ${result}") set(result ${result} PARENT_SCOPE) endfunction(add_example_executable_no_testing EXAMPLE_NAME) +function(example_compile_options EXAMPLE_NAME) + if(TARGET ${EXAMPLE_NAME}) + target_compile_options(${EXAMPLE_NAME} ${ARGN}) + endif() +endfunction(example_compile_options) + # add all example subdir file(GLOB dir_list LIST_DIRECTORIES true *) FOREACH(subdir ${dir_list}) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp index ebe075b55d..f366f309ff 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp @@ -35,6 +35,9 @@ struct BlockwiseGemmXdlops_mx_pipeline_base using ComputeTypeB = BDataType; using AccType = float; // for now only support V_MFMA_SCALE_F32 + static constexpr index_t APackedSize = packed_size_v; + static constexpr index_t BPackedSize = packed_size_v; + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -48,17 +51,24 @@ struct BlockwiseGemmXdlops_mx_pipeline_base static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0); static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); - static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); + // static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); + static constexpr index_t B_K1 = + BTileDesc{}.GetLength(Number < BTileDesc{}.GetNumOfDimension() == 4 ? 3 : 2 > {}); - static constexpr auto xdlops_gemm = - XdlopsGemm{}; + static constexpr auto xdlops_gemm = XdlopsGemm{}; static constexpr index_t AMmaKStride = KPack; static constexpr index_t BMmaKStride = KPack; //> store rows/cols into thread registers in chunks of 16 //> e.g. 
[k0,...,k15,k64,...,k79] or [k0,...,k15,k32,...,k47] - static constexpr index_t KThreadChunk = 16; + static constexpr index_t KThreadChunk = 16 / sizeof(ComputeTypeA); static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; static constexpr index_t KRepeat = KPerThread / KPack; @@ -67,22 +77,29 @@ struct BlockwiseGemmXdlops_mx_pipeline_base static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); - using HotLoopInstList = - ck::BlockwiseGemmXdlops_pipeline_hotloop_inst; + // Hardcode to 2, for better 8-bit access pattern + + static constexpr index_t MXdlPack = 2; + static constexpr index_t NXdlPack = 2; + static constexpr index_t KXdlPack = 2; + + using HotLoopInstList = ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< // + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + A_K1, + B_K1, + A_K1, + B_K1, + MRepeat, + NRepeat, + MPerXDL, + NPerXDL, + xdlops_gemm.KPerXdlops, + (packed_size_v > 1 || packed_size_v > 1)>; static_assert(KPerThread % KPack == 0, "Wrong KPack setting; try increasing KPerThread or decreasing KPack"); @@ -116,7 +133,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); - return make_tuple(0, waveId_m, xdlops_a_idx[I1], KThreadChunk * xdlops_a_idx[I0]); + return make_tuple(0, waveId_m, 0, xdlops_a_idx[I1], KThreadChunk * xdlops_a_idx[I0]); } __device__ static auto CalculateBThreadOriginDataIndex() @@ -127,7 +144,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); - return make_tuple(0, waveId_n, xdlops_b_idx[I1], KThreadChunk * xdlops_b_idx[I0]); + return make_tuple(0, waveId_n, 0, xdlops_b_idx[I1], KThreadChunk * xdlops_b_idx[I0]); } template @@ -142,24 +159,27 @@ struct BlockwiseGemmXdlops_mx_pipeline_base const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple( + make_unmerge_transform(make_tuple(MRepeat / MXdlPack, MWaves, MXdlPack, MPerXDL))), make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1, 2>{})); + make_tuple(Sequence<0, 1, 2, 3>{})); constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple( + make_unmerge_transform(make_tuple(NRepeat / NXdlPack, NWaves, NXdlPack, NPerXDL))), make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1, 2>{})); + make_tuple(Sequence<0, 1, 2, 3>{})); + // We pack 2 mfma in M/N direction, so we need to divide by 2 const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( - make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + make_tuple(m0 / MXdlPack, waveId_m, m0 % MXdlPack, blk_idx[I0]))[I0]; const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( - make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + make_tuple(n0 / NXdlPack, waveId_n, n0 % NXdlPack, blk_idx[I1]))[I0]; return make_tuple(c_thread_m, c_thread_n); } - using Tuple4 = decltype(CalculateAThreadOriginDataIndex()); + using Tuple5 = decltype(CalculateAThreadOriginDataIndex()); /** * @brief Constructor for BlockwiseGemmXdlops_mx_pipeline_base. @@ -179,13 +199,12 @@ struct BlockwiseGemmXdlops_mx_pipeline_base * repeat dimensions. 
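*
* Usage sketch (illustrative addition, not in the original header; template
* arguments elided):
*
* @code
* // Default construction uses the MFMA-derived thread origins, i.e.
* // (0, waveId, 0, xdlops lane index, KThreadChunk * k-group index) from
* // CalculateAThreadOriginDataIndex() / CalculateBThreadOriginDataIndex().
* BlockwiseGemmXdlops_mx_pipeline_base<...> blockwise_gemm{};
* @endcode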
*/ __host__ __device__ - BlockwiseGemmXdlops_mx_pipeline_base(Tuple4 a_origin = CalculateAThreadOriginDataIndex(), - Tuple4 b_origin = CalculateBThreadOriginDataIndex()) + BlockwiseGemmXdlops_mx_pipeline_base(Tuple5 a_origin = CalculateAThreadOriginDataIndex(), + Tuple5 b_origin = CalculateBThreadOriginDataIndex()) : a_thread_copy_(a_origin), b_thread_copy_(b_origin) { static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); - static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); @@ -221,6 +240,28 @@ struct BlockwiseGemmXdlops_mx_pipeline_base make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); } + // XDL output supporting C_xdl = A_xdl * B_xdl, packed mfma + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + I1, + Number{}, + Number{}, + M0, + M1, + M2, + N)); + } + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() { constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); @@ -262,6 +303,23 @@ struct BlockwiseGemmXdlops_mx_pipeline_base return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); } + // XDL output supporting C_xdl = A_xdl * B_xdl_packed mfma + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3( + c_block_desc_m0_n0_m1_n1_m2_n2); + } + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() { constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = @@ -314,45 +372,47 @@ struct BlockwiseGemmXdlops_mx_pipeline_base c_grid_desc_g_m0_n0_m1_n1_m2_n2); } - static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k; - static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k; + __host__ __device__ static constexpr auto GetCThreadDesc() { return c_thread_desc_; } + + static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_m3_k; + static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_n3_k; protected: // M1, N1 as double buffer index // Read buffer + Compute buffer // A[M0, M1, M2, KPack] - static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( - make_tuple(Number{}, I1, Number{}, Number{}), - make_tuple( - Number{}, Number{}, Number{}, I1)); + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, I1, Number{}, Number{}, Number{})); // B[N0, N1, N2, KPack] - static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( - make_tuple(Number{}, I1, Number{}, Number{}), - make_tuple( - Number{}, Number{}, Number{}, I1)); + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, I1, Number{}, Number{}, Number{})); // C[M, N, NumRegXdlops] - static constexpr auto 
c_thread_desc_ = make_naive_tensor_descriptor_packed(
- make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops()));
+ static constexpr auto c_thread_desc_ =
+ make_naive_tensor_descriptor_packed(make_tuple(Number{},
+ Number{},
+ Number{},
+ Number{},
+ xdlops_gemm.GetRegSizePerXdlops()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4,
- Sequence<0, 1, 2, 3>,
- 3,
+ Sequence<1, 1, 1, 1, KThreadChunk>,
+ Sequence<0, 1, 2, 3, 4>,
+ 4,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4,
- Sequence<0, 1, 2, 3>,
- 3,
+ Sequence<1, 1, 1, 1, KThreadChunk>,
+ Sequence<0, 1, 2, 3, 4>,
+ 4,
B_K1,
B_K1>;
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp
index e5fe92a50d..8b227a8aa1 100644
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp
@@ -145,7 +145,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3{};
+ XdlopsGemm{};
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp
index 1d27a74bd7..d8f11572a8 100644
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp
@@ -270,10 +270,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1, f8_t>)
+ // On gfx950, we have mfma instructions that require 32 f8 elements as input,
+ // split into 2 groups of 16 f8 elements.
+ // The 2 groups are not contiguous in the B preshuffled layout,
+ // and we do not want them to be contiguous in the B preshuffled layout,
+ // because a memory instruction can only read 16 f8 elements at a time.
+ return ((MPerXDL == 16 && NPerXDL == 16 && xdlops_gemm.KPerXdlops == 128) ||
+ (MPerXDL == 32 && NPerXDL == 32 && xdlops_gemm.KPerXdlops == 64))
+ ? 2
+ : 1;
+ else
+ return 1;
+ }();
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_bpreshuffle_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_bpreshuffle_selector.hpp
new file mode 100644
index 0000000000..7d21c44504
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_bpreshuffle_selector.hpp
@@ -0,0 +1,68 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
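+//
+// Selector for the blockwise MX GEMM pipeline with a preshuffled B operand.
+// Hedged usage sketch (illustrative; template arguments elided, the concrete
+// instantiation lives in the gridwise MX bpreshuffle kernel):
+//
+//   constexpr auto blockwise_gemm_pipeline =
+//       BlockGemmMXBPreshufflePipeline_Selector</* ... */>();
+//
+// Only BlockGemmPipelineVersion::v3 is handled; any other version reaches the
+// error branch below.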
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp" + +namespace ck { +template +constexpr auto BlockGemmMXBPreshufflePipeline_Selector() +{ + + // Hardware MX GEMM pipeline + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle{}; + } + else + { + std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl; + } +} + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp index c1433659d6..52ab86b6d4 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp @@ -4,38 +4,9 @@ #pragma once #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp" namespace ck { - -/** - * @brief Define matrix data types that have hardware support for MX GEMMs - */ -template -static constexpr bool is_scale_mfma_data_type() -{ - return is_same_v || is_same_v || is_same_v || - is_same_v || is_same_v; -} - -/** - * @brief Define scale data types that have hardware support for MX GEMMs - */ -template -static constexpr bool is_scale_mfma_scale_type() -{ - return is_same_v; -} - -/** - * @brief Combination of data types that have hardware support for MX GEMMs - */ -template -static constexpr bool scale_mfma_hw_support() -{ - return is_scale_mfma_data_type() && is_scale_mfma_data_type() && - is_scale_mfma_scale_type() && is_scale_mfma_scale_type(); -} - template {}; } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return BlockwiseGemmXdlops_pipeline_v3_mx{}; + } else { std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp index 8375e81fa0..ea4f5e4a28 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp @@ -205,7 +205,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() - static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + static constexpr auto AScalesPerXdlopsRun = + (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + static constexpr auto BScalesPerXdlopsRun = + (BPackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() - static constexpr auto ScalesPerXdlopsRunPerThread = - ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + static constexpr auto ScalesPerXdlopsRunPerThreadA = + AScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + static constexpr auto ScalesPerXdlopsRunPerThreadB = + BScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + + using mx_scale_t = e8m0_bexp_t; + static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + 
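// Worked example (illustrative, values as in the FP4 examples): with
+ // AScaleDataType = int32_t and mx_scale_t = e8m0_bexp_t (1 byte),
+ // scale_pack_size_a = 4, so a_scale_thread_vec_size
+ // = KXdlPack * MXdlPack / scale_pack_size_a = 2 * 2 / 4 = 1, i.e. a single
+ // packed int32 carries the four e8m0 scales of one 2x2 XDL pack.
+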
static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a; + static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b; __host__ static constexpr bool BlockHasHotloop(index_t num_loop) { @@ -232,76 +253,58 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - constexpr auto a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s)); - auto a_scale_thread_buf_copy = - make_static_buffer( - a_scale_thread_desc_copy.GetElementSpaceSize()); - a_scale_thread_copy.Run(a_scale_grid_desc, - a_scale_grid_buf, - a_scale_thread_desc_copy, - make_tuple(I0, I0), - a_scale_thread_buf_copy); + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_buf); - a_scale_thread_buf(Number{}) = - a_scale_thread_buf_copy[Number<0>{}]; - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, - make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); - }); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); }); a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize)); + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); }); // restore row id and advance to the next set of scales - a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, - make_multi_index(-MPerBlock, ScalesPerKBlockSize)); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); // Prefetch b_scales - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - constexpr auto b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); - auto b_scale_thread_buf_copy = - make_static_buffer( - b_scale_thread_desc_copy.GetElementSpaceSize()); - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc_copy, - make_tuple(I0, I0), - b_scale_thread_buf_copy); + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_buf); - b_scale_thread_buf(Number{}) = - b_scale_thread_buf_copy[Number<0>{}]; - b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, - make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); - }); + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); }); b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); }); // restore col id and advance to the next set of scales // NWaves * NPerXDL * NRepeat == NPerBlock - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, - make_multi_index(-NPerBlock, ScalesPerKBlockSize)); + 
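// The scale window moves above follow one pattern: walk K within a row
// group, hop MWaves rows while rewinding K, then restore the row origin and
// advance K past what was consumed. A host-side sketch that prints the same
// traversal, using small hypothetical extents (MWaves = 2, and 2 for both
// MRepeat / MXdlPack and KRepeat / KXdlPack) in place of the real constants:

#include <cstdio>

int main()
{
    const int MWaves = 2, MRepeatPack = 2, KRepeatPack = 2; // assumed extents

    int row = 0, k = 0; // window origin on the (row, k) scale grid
    for(int m0 = 0; m0 < MRepeatPack; ++m0)
    {
        for(int k0 = 0; k0 < KRepeatPack; ++k0)
        {
            std::printf("read scales at (row=%d, k=%d)\n", row, k);
            k += 1; // MoveSrcSliceWindow(..., (0, I1, 0))
        }
        row += MWaves;    // next row group for this wave ...
        k -= KRepeatPack; // ... rewinding K: (MWaves, -KRepeat / KXdlPack, 0)
    }
    // "restore row id and advance to the next set of scales":
    // (-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)
    row -= MWaves * MRepeatPack;
    k += KRepeatPack;
    std::printf("next K block starts at (row=%d, k=%d)\n", row, k);
}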
b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); // Local prefill 1 - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + __builtin_amdgcn_s_waitcnt(3952); // wait for EXP_CNT, LDS, GDS, Constant and Message + block_sync_lds(); // Initialize C c_thread_buf.Clear(); @@ -314,13 +317,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx 15 32 --> 47 | 64 --> 79 96 --> 111 | etc. @@ -335,160 +333,184 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx{}([&](auto k) { constexpr auto k_step = - k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); + k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, Number{}), - a_thread_buf); - }); + static_for<0, xdlops_gemm.K1PerXdlops / APackedSize / KThreadChunk, 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); }); static_for<0, NRepeat, 1>{}([&](auto n0) { // read block data in chunks to assemble correct thread vectors - static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run( - b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, Number{}), - b_thread_buf); - }); + static_for<0, xdlops_gemm.K1PerXdlops / BPackedSize / KThreadChunk, 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); }); }); - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); + // load for next k loop + block_sync_lds(); + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { constexpr index_t a_scale_offset = a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); constexpr index_t b_scale_offset = 
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - static_assert(0 < ScalesPerXdlopsRunPerThread, + static_assert(0 < ScalesPerXdlopsRunPerThreadA && + 0 < ScalesPerXdlopsRunPerThreadB, "Must have at least one scale per Xdlops per Thread."); - vector_type - a_scale_thread_vec; - vector_type - b_scale_thread_vec; + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; // Pack scale_thread_buf into scale_thread_vec - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { a_scale_thread_vec.template AsType()(s) = a_scale_thread_buf[Number{}]; + }); + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { b_scale_thread_vec.template AsType()(s) = b_scale_thread_buf[Number{}]; }); - using mfma_input_type_a = - typename vector_type::type; - using mfma_input_type_b = - typename vector_type::type; + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + vector_type a_thread_vec; + vector_type b_thread_vec; - // MFMA accumulation - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; + + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference( + Number{})); + }); + }); + }); }); }); }); // Prefetch a_scales - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - constexpr auto a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s)); - auto a_scale_thread_buf_copy = - make_static_buffer( - a_scale_thread_desc_copy.GetElementSpaceSize()); - a_scale_thread_copy.Run(a_scale_grid_desc, - a_scale_grid_buf, - a_scale_thread_desc_copy, - make_tuple(I0, I0), - a_scale_thread_buf_copy); + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_buf); - a_scale_thread_buf(Number{}) = - a_scale_thread_buf_copy[Number<0>{}]; - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, - make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); - }); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, 
+ make_multi_index(0, I1, 0)); }); a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, - make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize)); + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); }); // restore row id and advance to the next set of scales a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, make_multi_index(-MPerBlock, ScalesPerKBlockSize)); + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); // Prefetch b_scales - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - constexpr auto b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); - auto b_scale_thread_buf_copy = - make_static_buffer( - b_scale_thread_desc_copy.GetElementSpaceSize()); - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc_copy, - make_tuple(I0, I0), - b_scale_thread_buf_copy); + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_buf); - b_scale_thread_buf(Number{}) = - b_scale_thread_buf_copy[Number<0>{}]; - b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, - make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); - }); + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); }); b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, - make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); }); // restore col id and advance to the next set of scales // NWaves * NPerXDL * NRepeat == NPerBlock b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize)); + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + __builtin_amdgcn_s_waitcnt(3952); // wait for EXP_CNT and LGKM_CNT block_sync_lds(); - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); i += 1; } while(i < (num_loop - 1)); @@ -497,87 +519,128 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx{}([&](auto k) { constexpr auto k_step = - k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); + k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; static_for<0, MRepeat, 1>{}([&](auto m0) { - // read block data in chunks to assemble correct thread - static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, Number{}), - a_thread_buf); - }); + static_for<0, xdlops_gemm.K1PerXdlops / APackedSize / KThreadChunk, 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); }); static_for<0, NRepeat, 1>{}([&](auto n0) { - // read block data in chunks to assemble correct thread - static_for<0, 
xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { - constexpr auto b_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, Number{}), - b_thread_buf); - }); + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / BPackedSize / KThreadChunk, 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); }); }); - - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { constexpr index_t a_scale_offset = a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; + static_assert(0 < ScalesPerXdlopsRunPerThreadA && + 0 < ScalesPerXdlopsRunPerThreadB, + "Must have at least one scale per Xdlops per Thread."); - // Pack b_scale_thread_buf into b_scale_thread_vec - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { a_scale_thread_vec.template AsType()(s) = a_scale_thread_buf[Number{}]; + }); + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { b_scale_thread_vec.template AsType()(s) = b_scale_thread_buf[Number{}]; }); - using mfma_input_type_a = - typename vector_type::type; - using mfma_input_type_b = - typename vector_type::type; + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + vector_type a_thread_vec; + vector_type b_thread_vec; - // MFMA accumulation - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; + + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using 
mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); }); }); }); @@ -587,20 +650,16 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx{}, Number{}, Number{})); - - // Is used to copy data from a_scale_grid to a_scale_thread - static constexpr auto a_scale_thread_desc_copy = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{})); + make_tuple(Number{}, + Number{}, + Number{})); // TODO: make this field protected when b_scale_thread_copy_ is moved // here static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, Number{})); - - // Is used to copy data from b_scale_grid to b_scale_thread_buf - static constexpr auto b_scale_thread_desc_copy = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{})); + make_tuple(Number{}, + Number{}, + Number{})); protected: using Base::a_thread_copy_; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp index 171a232c0f..b5d6180ab3 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp @@ -177,8 +177,8 @@ struct BlockwiseGemmXdlops_pipeline_v3 +struct BlockwiseGemmXdlops_pipeline_v3_mx +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v3_mx + : BlockwiseGemmXdlops_mx_pipeline_base + +{ + + using Base = BlockwiseGemmXdlops_mx_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::MWaves; + using Base::NWaves; + using Base::WaveSize; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetWaveIdx; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_m3_k; + using Base::b_block_desc_n0_n1_n2_n3_k; + + using Base::AMmaKStride; + using Base::APackedSize; + using Base::BMmaKStride; + using Base::BPackedSize; + using Base::KThreadChunk; + + using Base::KXdlPack; + using Base::MXdlPack; + using Base::NXdlPack; + + using AccType = typename Base::AccType; + using Tuple5 = typename Base::Tuple5; + using ComputeTypeA = typename Base::ComputeTypeA; + using ComputeTypeB = typename Base::ComputeTypeB; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + static constexpr auto ScalesPerKBlockSize = + KPerBlock / ScaleBlockSize; // How many mx-vectors per K block + + //> How many mx-vectors in each row/col is processed in one call to 
xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRun = + (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + + //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRunPerThread = + ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + + using mx_scale_t = e8m0_bexp_t; + static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a; + static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack; + constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * APackedSize; + + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 
8 : 4; + + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale + num_buffer_load_b_scale; + + constexpr auto mfma_perstage_more = + math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total); + constexpr auto mfma_perstage_less = + math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total); + + constexpr auto mfma_stages_more = + num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + if constexpr(i < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_a_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_b_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, 
ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * + ds_read_a_mfma_rate, + 0); // DS read + } + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_bufs, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_bufs, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + // A and B scales + const AScaleGridDesc& a_scale_grid_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const BScaleGridDesc& b_scale_grid_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + + // Global prefetch 1 + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I0)); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I0)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // 
NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Local prefetch 1, sync the async load + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I0), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs(I0), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + }); + + // Global prefetch 2 + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I1)); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(I1)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + // loop over k with the step KPerBlock + index_t i = 0; + do + { + auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) { + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + + a_blockwise_copy.Run( + a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(scale_comp_buf)); + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(scale_comp_buf)); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(scale_mem_buf)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(scale_mem_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the 
next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type + a_scale_thread_vec; + vector_type + b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()( + ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()( + ik) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; + + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; + + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference( + Number{})); + }); + }); + }); + }); + }); + }); + + // k indexes mapping to threads for 32x32x64: + // t0 : |0 --> 15 32 --> 47 | 64 --> 79 96 --> 111 | etc. + // t32: |16 --> 31 48 --> 63 | 80 --> 95 112 --> 127 | etc. + // k = 0 k = 1 + + // k indexes mapping to threads for 16x16x128: + // t0 : |0 --> 15 64 --> 79 | 128 --> 143 192 --> 207| etc. + // t16: |16 --> 31 80 --> 95 | 144 --> 159 208 --> 223| etc. + // t32: |32 --> 47 96 --> 111| 160 --> 175 224 --> 239| etc. + // t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc. 
+ // k = 0 k = 1 + // __builtin_amdgcn_s_waitcnt(3952); + // block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, + xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(scale_mem_buf), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, + xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs(scale_mem_buf), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I1)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I1)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr 
auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; + + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; + + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I1), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs(I1), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + }); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 
1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; + + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; + + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; + + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; + + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + } + } + + // TODO: make this field protected when a_scale_thread_copy_ is moved + // here + static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + // TODO: make this field protected when b_scale_thread_copy_ 
is moved + // here + static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp new file mode 100644 index 0000000000..7e11304e2f --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp @@ -0,0 +1,1042 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle + : BlockwiseGemmXdlops_mx_pipeline_base + +{ + + using Base = BlockwiseGemmXdlops_mx_pipeline_base; + using Base::A_K1; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::MWaves; + using Base::NWaves; + using Base::WaveSize; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetWaveIdx; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_m3_k; + using Base::b_block_desc_n0_n1_n2_n3_k; + + using Base::AMmaKStride; + using Base::APackedSize; + using Base::BMmaKStride; + using Base::BPackedSize; + using Base::KThreadChunk; + + using Base::KXdlPack; + using Base::MXdlPack; + using Base::NXdlPack; + + using AccType = typename Base::AccType; + using Tuple5 = typename Base::Tuple5; + using ComputeTypeA = typename Base::ComputeTypeA; + using ComputeTypeB = typename Base::ComputeTypeB; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t LocalPrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 
0 : 1; + + static constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack; + static constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack; + static constexpr auto async_vmcnt = + num_buffer_load_a_scale + num_buffer_load_b_scale + HotLoopInstList::B_Buffer_Load_Inst_Num; + static constexpr auto async_vmcnt_encoding = 3952 + async_vmcnt % 16 + async_vmcnt / 16 * 16384; + + static constexpr auto ScalesPerKBlockSize = + KPerBlock / ScaleBlockSize; // How many mx-vectors per K block + + //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRun = + (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + + //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRunPerThread = + ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + + using mx_scale_t = e8m0_bexp_t; + static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a; + static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_stage1 = + num_buffer_load_inst_b + num_buffer_load_a_scale + num_buffer_load_b_scale; + + constexpr auto num_buffer_load_stage2 = num_buffer_load_inst_a; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * APackedSize; + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 
8 : 4; + constexpr auto ds_read_a_mfma_rate = + math::integer_divide_ceil(mfma_cycle - 8, 2 * ds_read_a_issue_cycle); + + // constexpr auto num_dsread_a_mfma = + // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + + constexpr auto num_total_stages = MRepeat; + + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages; + + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + + constexpr auto num_ds_read_a_prefetch_stages = 2; + + constexpr auto buffer_load_perstage_more = + math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_less = + math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_stage2 = + math::integer_divide_floor((num_buffer_load_stage2), 2); + + constexpr auto buffer_load_stages_more = + num_buffer_load_stage1 - + math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) * + ((num_total_stages - 2)); + + constexpr auto buffer_load_issue_point_interval_more = + num_mfma_perstage / buffer_load_perstage_more; + constexpr auto buffer_load_issue_point_interval_less = + num_mfma_perstage / buffer_load_perstage_less; + constexpr auto buffer_load_issue_point_interval_stage2 = + num_mfma_perstage / buffer_load_perstage_stage2; + + // Stage 1 + // global read more + static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + + if constexpr(imfma % buffer_load_issue_point_interval_more == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + + // global read less + static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_less == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + + // Stage 2, Sync + // lds synchronization, prefetch next loop local A + static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_bufs, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& 
b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_bufs, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + // A and B scales + const AScaleGridDesc& a_scale_grid_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const BScaleGridDesc& b_scale_grid_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + index_t num_loop) const + { + ignore = b_block_bufs; + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + + // Global prefetch 1 + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I0)); + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I0)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Local prefetch 1, sync the async load + __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); + block_sync_lds(); + static_for<0, LocalPrefetchStages, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple( + I0, 
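+        // Net effect of the scale-window walk above (illustrative sizes:
+        // MRepeat / MXdlPack = 2, KRepeat / KXdlPack = 2, MWaves = 2): each m0 row steps
+        // (0, +1, 0) twice through k0, then jumps (+MWaves, -2, 0) to the next row; the
+        // final (-MWaves * 2, +2, 0) move restores the row origin and leaves the window
+        // advanced by (0, KRepeat / KXdlPack, 0), i.e. at the scales of the next K tile.
+        // The b_scale walk is identical with N taking the place of M.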
I0, Number{}, I0, Number{}), + a_block_bufs(I0), + a_thread_desc_, + make_tuple( + I0, I0, Number{}, k, Number{}), + a_thread_buf); + }); + }); + }); + + // Global prefetch 2 + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I1)); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + __builtin_amdgcn_sched_barrier(0); + constexpr index_t SwitchM = MRepeat - LocalPrefetchStages; + // main body + if constexpr(HasMainLoop) + { + // loop over k with the step KPerBlock + index_t i = 0; + do + { + auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc, + b_block_origin_idx, + b_thread_bufs(scale_mem_buf)); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(scale_mem_buf)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(scale_mem_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset( + make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset( + make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type + a_scale_thread_vec; + vector_type + b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs( + 
scale_comp_buf)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_bufs + [scale_comp_buf][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + + if constexpr(m0.value == SwitchM) + { + __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); + block_sync_lds(); + a_blockwise_copy.Run(a_grid_desc, + a_grid_buf, + a_block_desc, + a_block_bufs(scale_comp_buf)); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + } + + constexpr auto lds_buf = + m0.value >= SwitchM ? scale_mem_buf : scale_comp_buf; + + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, + xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number<((m0 + LocalPrefetchStages) / MXdlPack) % + (MRepeat / MXdlPack)>{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(Number{}), + a_thread_desc_, + make_tuple(I0, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I1)); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I1)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I1)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % 
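+                            // Conceptually, each scale-carrying MFMA issued here computes
+                            //   c += sum_k (a[k] * 2^(e_a - 127)) * (b[k] * 2^(e_b - 127)),
+                            // where e_a / e_b are the e8m0 block exponents packed into
+                            // a_scale_thread_vec / b_scale_thread_vec, one exponent per
+                            // ScaleBlockSize contiguous K elements (32 in the MX formats).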
KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + if constexpr(m0.value == SwitchM) + { + __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); + block_sync_lds(); + } + + constexpr auto lds_buf = m0.value >= SwitchM ? 
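+                    // LDS double-buffer handoff: with illustrative values MRepeat = 8 and
+                    // LocalPrefetchStages = 2, SwitchM = 6; iterations m0 = 0..5 keep
+                    // reading A from the buffer currently being computed on, while
+                    // m0 = 6..7 prefetch from the freshly written buffer for the next
+                    // K tile, which is why lds_buf flips once m0 reaches SwitchM.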
I1 : I0; + + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number<((m0 + LocalPrefetchStages) / MXdlPack) % + (MRepeat / MXdlPack)>{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(Number{}), + a_thread_desc_, + make_tuple( + I0, I0, Number{}, k, Number{}), + a_thread_buf); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + if constexpr(m0.value < (MRepeat - LocalPrefetchStages)) + { + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number<((m0 + LocalPrefetchStages) / MXdlPack) % + (MRepeat / MXdlPack)>{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I1), + a_thread_desc_, + make_tuple(I0, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + } + }); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 
1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + if constexpr(m0.value < (MRepeat - LocalPrefetchStages)) + { + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number<((m0 + LocalPrefetchStages) / MXdlPack) % + (MRepeat / MXdlPack)>{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I0), + a_thread_desc_, + make_tuple(I0, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + } + }); + } + } + + // Length: A[ARegBuf, MWave, MXdlPack, KRepeat, KPack] + // Order: 1 0 3 2 4 + static constexpr auto ARegBuf = 2; + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{}, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4>, + 4, + A_K1, + A_K1>; + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()}; + + // TODO: make this field protected when a_scale_thread_copy_ is moved + // here + static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + // TODO: make this field protected when b_scale_thread_copy_ is moved + // here + static constexpr auto b_scale_thread_desc = 
make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + protected: + // using Base::a_thread_copy_; + // using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp index b6a4f05502..99934fa74e 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp @@ -188,7 +188,7 @@ struct BlockwiseGemmXdlops_pipeline_v5 @@ -61,6 +63,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; static constexpr auto block_slice_lengths = BlockSliceLengths{}; static constexpr auto thread_cluster_lengths = ThreadClusterLengths{}; @@ -96,8 +99,12 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad // VALID: ThreadClusterLengths = [4, 16, 4] or [2, 32, 4] or [1, 64, 4] since in the // first iteration, threads 0-63 write [0, 0, 0] - [0, 15, 7] -> 128 consecutive // elements = 64 consecutive DWORDs. +#if defined(__gfx950__) + int num_contiguous_dwords = 4; +#else int num_contiguous_dwords = 1; - bool is_contiguous = true; +#endif + bool is_contiguous = true; static_for<0, nDim, 1>{}([&](auto i) { if(is_contiguous) { @@ -141,11 +148,11 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad "When loading more than one element per thread at once, the contiguous " "dimension must be the same between source and destination."); - constexpr auto dword_bytes = 4; - constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData); - static_assert(bytes_per_thread_load == dword_bytes, - "Direct load transfer requires each thread to load exactly a single " - "DWORD of data."); + // constexpr auto dword_bytes = 4; + // constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData); + // static_assert(bytes_per_thread_load == dword_bytes, + // "Direct load transfer requires each thread to load exactly a single " + // "DWORD of data."); static_assert(nDim == remove_cvref_t::GetNumOfDimension() && nDim == remove_cvref_t::GetNumOfDimension() && @@ -156,18 +163,45 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad "The number of threads cannot be less than the number of elements in " "thread cluster lengths."); - static_assert( - AreThreadClusterLengthsValid(), - "Thread cluster lengths are incorrect. They must be set in a way that allows a single " - "wavefront to write contiguous DWORDs into LDS memory. "); + // static_assert( + // AreThreadClusterLengthsValid(), + // "Thread cluster lengths are incorrect. They must be set in a way that allows a single + // " "wavefront to write contiguous DWORDs into LDS memory. "); const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId())); + constexpr auto wave_cluster_lengths = generate_sequence_v2( + [&](auto i) { + // FIXME: wave parallelism is not always in that dimension. 
+ // The ThreadClusterLengths{} must be bigger than wave_num; + if constexpr(ThreadClusterArrangeOrder{}.At(i) == (nDim - 3)) + { + return Number{}; + } + else + { + return I1; + } + }, + Number{}); + + constexpr auto wave_thread_cluster_lengths = ThreadClusterLengths{} / wave_cluster_lengths; + constexpr auto wave_single_load_size = + wave_thread_cluster_lengths * thread_single_load_size; + constexpr auto wave_cluster_desc_ = + make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{}); + + const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId() / 64)); + const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size; + const auto wave_data_idx_begin = wave_cluster_idx * wave_single_load_size; SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin); - SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin); + // We don't need a threadwise offset for lds since it is calculated by HW. + // We still need to apply the wave-wise offset. + SetDstSliceOrigin(dst_desc, dst_block_slice_origin + wave_data_idx_begin); } __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) @@ -215,7 +249,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad // Loop over the destination block and copy data. static_ford{}([&](auto ordered_dst_access_idx) { const auto src_offset = src_coord_.GetOffset(); - const auto dst_offset = dst_coord_.GetOffset(); + const auto dst_offset = __builtin_amdgcn_readfirstlane(dst_coord_.GetOffset()); // Check if src data is not in the logic padding area. const bool is_src_valid = @@ -303,7 +337,8 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad } private: - static constexpr auto thread_cluster_desc_ = make_cluster_descriptor(ThreadClusterLengths{}); + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); SrcCoord src_coord_; DstCoord dst_coord_; diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_mx.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_mx.hpp index e89185a35c..0562e452ac 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_mx.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_mx.hpp @@ -45,6 +45,44 @@ struct DeviceGemmMX : public BaseOperator virtual std::unique_ptr MakeInvokerPointer() = 0; }; +template +struct DeviceGemmMX_BPreshuffle : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_a_scale, + const void* p_b, + const void* p_b_scale, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideAScale, + ck::index_t StrideB, + ck::index_t StrideBScale, + ck::index_t StrideC, + ck::index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual int GetPreShuffleParameters() = 0; +}; + } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp index 2c34be9007..ed168195ec 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp @@ -15,6 +15,7 @@ #include
"ck/tensor_operation/gpu/device/device_gemm_mx.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -162,56 +163,108 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX { // GridwiseGemm - using GridwiseGemm = GridwiseGemmMX_xdl_cshuffle_v3< - ALayout, - BLayout, - CLayout, - ADataType, - AScaleDataType, - BDataType, - BScaleDataType, - GemmAccDataType, - CShuffleDataType, - CDataType, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - GemmSpec, - ScaleBlockSize, - BlockSize, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CShuffleBlockTransferScalarPerVector_NPerBlock, - BlkGemmPipeSched, - BlkGemmPipelineVer, - ComputeTypeA, - ComputeTypeB>; + using GridwiseGemm = conditional_t< // + !is_same_v, + GridwiseGemmMX_xdl_cshuffle_v3< + ALayout, + BLayout, + CLayout, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + ScaleBlockSize, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>, + GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< + ALayout, + BLayout, + CLayout, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + ScaleBlockSize, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + 
ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>>; using Argument = typename GridwiseGemm::Argument; @@ -304,385 +357,45 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX 1) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - else - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } - // Tail number could be One to Seven - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) - { - if(arg.KBatch > 1) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Full) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Two>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Three) - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Three>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Four) - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Four>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Five) - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Five>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Six>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Seven) - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Seven>; - Run(kernel); - } - } - } - else - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - else 
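+            // The hand-written ladder removed here spelled out one kernel instantiation per
+            // (has_main_k_block_loop, KBatch > 1, TailNumber) combination; the replacement
+            // further down walks the same space generically. A rough sketch of the pattern
+            // (names are illustrative only):
+            //   static_for_product<Tuple<constant<false>, constant<true>>,
+            //                      Tuple<constant<false>, constant<true>>>{}(
+            //       [&](auto main_loop, auto atomic) {
+            //           if(main_loop.value == has_main_loop && atomic.value == use_atomics)
+            //               launch_kernel<main_loop.value, atomic.value>();
+            //       });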
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Full) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Three) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Four) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Five) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Seven) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } - } - } - // Tail number could be Odd or Even - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - if(arg.KBatch > 1) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } - } - else - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3_2lds; - Run(kernel); - } - else - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3_2lds; - Run(kernel); - } - } - } - else - { - if(arg.KBatch > 1) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - else - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } - else - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - else - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } - } - } - else - { - // Tail number always 1 + constexpr auto TailNumChoices = []() { if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(arg.KBatch > 1) + return Tuple>{}; + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + return Tuple, constant>{}; + else + static_assert(false, "Unexpected BlkGemmPipelineVer!"); + }(); + constexpr bool Use2LDS = []() { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + return false; + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + return true; + else + static_assert(false, "Unexpected 
BlkGemmPipelineVer!"); + }(); + const TailNumber tail_num = GridwiseGemm::CalculateKBlockLoopTailNum(K_split); + using BoolChoices = Tuple; + static_for_product>{}( + [&](auto mainloop_choice, auto KBatch_cond_choice, auto tail_num_choice) { + constexpr auto CGlobalMemoryDataOperation = + KBatch_cond_choice.value ? InMemoryDataOperationEnum::AtomicAdd + : InMemoryDataOperationEnum::Set; + if(mainloop_choice.value == has_main_k_block_loop && + KBatch_cond_choice.value == (arg.KBatch > 1) && + tail_num_choice.value == tail_num) { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; + const auto kernel = kernel_gemm_xdl_cshuffle_v3_mx< // + Use2LDS, + GridwiseGemm, + mainloop_choice.value, + CGlobalMemoryDataOperation, + minimum_occupancy, + tail_num_choice.value>; Run(kernel); } - else - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } - } - + }); return ave_time; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp index d704d04054..eda966c48a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp @@ -98,10 +98,12 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK + __host__ __device__ void operator()(f4x2_pk_t& y, + const f4x2_pk_t& x) const + { + y = x; + } + template <> __host__ __device__ void operator()(double& y, const double& x) const { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp index 7781d1def3..1e79d67f93 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp @@ -173,18 +173,34 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { - // A matrix in LDS memory, destination of blockwise copy. - return make_naive_tensor_descriptor( - make_tuple(AK0PerBlock, Number{}, AK1), - make_tuple(Number{} * AK1, AK1, I1)); + if constexpr(is_same_v) + { + // FIXME: our support to non-K contiguous layout is limited, only work in some specific + // setting + return make_naive_tensor_descriptor_packed( + make_tuple(AK0PerBlock, Number{}, AK1)); + } + else + { + return make_naive_tensor_descriptor(make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(AK1, Number{}, I1)); + } } __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { - // B matrix in LDS memory, destination of blockwise copy. 
- return make_naive_tensor_descriptor( - make_tuple(BK0PerBlock, Number{}, BK1), - make_tuple(Number{} * BK1, BK1, I1)); + if constexpr(is_same_v) + { + // FIXME: our support for non-K-contiguous layouts is limited; this only works in some + // specific settings + return make_naive_tensor_descriptor_packed( + make_tuple(BK0PerBlock, Number{}, BK1)); + } + else + { + return make_naive_tensor_descriptor(make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(BK1, Number{}, I1)); + } } __host__ __device__ static constexpr auto @@ -566,10 +582,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad ThreadGroupTensorSliceTransfer_DirectLoad, ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferSrcAccessOrder, ADataType, AComputeDataType, decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, 2, ABlockTransferScalarPerVector>( @@ -582,10 +600,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad ThreadGroupTensorSliceTransfer_DirectLoad, BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferSrcAccessOrder, BDataType, BComputeDataType, decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, 2, BBlockTransferScalarPerVector>( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index 0dbbc2a5e9..338674ae85 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -256,8 +256,9 @@ struct GridwiseGemm_xdl_cshuffle_v3 (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || (is_same::value && lcm_AK1_BK1 <= 8) || + // gfx950 double rate mfma16x16 requires at least 128 KPerBlock to consume ((is_same::value || is_same::value) && - lcm_AK1_BK1 < 32)) + KPerBlock < 128 && MPerXdl == 16)) ? true : false; static constexpr auto is_scale_mfma = false; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index 38ce9536ab..812e41ba58 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -184,8 +184,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || (is_same::value && lcm_AK1_BK1 <= 8) || + // gfx950 double rate mfma16x16 requires at least 128 KPerBlock to consume ((is_same::value || is_same::value) && - lcm_AK1_BK1 < 32)) + KPerBlock < 128 && MPerXdl == 16)) ?
true : false; static constexpr auto is_scale_mfma = false; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index 8fb955c561..cb22f99fc2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -173,15 +173,25 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle lcm_AK1_BK1 < 32)) ? true : false; - static constexpr auto is_scale_mfma = false; - static constexpr auto mfma = MfmaSelector{}; - static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk); - static constexpr index_t KGroup = mfma.selected_mfma.k_per_blk == 32 ? 2 : 1; + static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk); + static constexpr index_t KGroup = []() { + if constexpr(is_same_v, f8_t>) + // On gfx950, we have an mfma that requires 32 f8 elements as input, + // split into 2 groups of 16 f8 elements. + // The 2 groups are not contiguous in the B preshuffled layout, + // and we do not want them to be contiguous in the B preshuffled layout, + // because a memory instruction can only read 16 f8 elements at a time. + return mfma.selected_mfma.k_per_blk == 32 ? 2 : 1; + else + return 1; + }(); static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops(); static constexpr index_t KPackPerGroup = KPack / KGroup; static constexpr index_t KRepeat = KPerBlock / KLane / KPackPerGroup; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp index f877912329..e32301fcd2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp @@ -14,26 +14,30 @@ #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/utility/common_header.hpp" #include "ck/utility/env.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp" namespace ck { +#ifndef KERNEL_GEMM_XDL_CSHUFFLE_V3_MX +#define KERNEL_GEMM_XDL_CSHUFFLE_V3_MX // Currently we do not have an elegant way to put the single lds buffer & double lds buffer pipes in the same // kernel function. Blockers: // 1. Two separate declarations of the __shared__ pointer are the key to making sure data accesses operate on // two lds chunks. // 2.
Occupied __shared__ memory won't be released until the whole shader ends, i.e. AB and C may not use the same lds // buffer when we declare __shared__ inside blkgemmpipe -template -__global__ void +__global__ enable_if_t #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) - kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) + kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) { #if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -54,17 +58,18 @@ __global__ void #endif // end of if (defined(__gfx9__)) } -template -__global__ void +__global__ enable_if_t #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) - kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) + kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) { #if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ // Passing two lds pointers is the key to telling the compiler that ds_read/write @@ -76,9 +81,10 @@ __global__ void GridwiseGemm::template Run_2Lds( karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, karg.p_c_grid + splitk_batch_offset.c_reduce_offset, - karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, p_shared_0, p_shared_1, karg); @@ -87,6 +93,7 @@ __global__ void ignore = karg; #endif // end of if (defined(__gfx9__)) } +#endif template {}; static constexpr auto I6 = Number<6>{}; static constexpr auto I7 = Number<7>{}; + static constexpr auto I8 = Number<8>{}; + static constexpr auto I9 = Number<9>{}; // K1 should be Number<...> static constexpr auto AK0Number = Number{}; @@ -163,10 +172,19 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 static constexpr bool is_single_rate_mfma = false; static constexpr auto is_scale_mfma = true; + static constexpr auto MXdlPack = 2; + static constexpr auto NXdlPack = 2; + static constexpr auto KXdlPack = 2; + //> KPack is at least the k_per_blk of selected mfma // // Should be a multiple of k_per_blk.
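+    // Worked example of the packed KPack below (illustrative numbers): for MXFP4 stored
+    // as f4x2_pk_t, APackedSize = 2; if the selected scale MFMA has k_per_blk = 64 and
+    // lcm_AK1_BK1 = 16, then KPack = max(16, 64 / 2) = 32 packed elements, i.e. 64 FP4
+    // values per xdlops K step.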
// TODO: Move this to blockwise pipeline base + // KPack in packed data types for pk A/B + + static constexpr index_t APackedSize = packed_size_v; + static constexpr index_t BPackedSize = packed_size_v; + static constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk); + is_scale_mfma>::selected_mfma.k_per_blk / + APackedSize); using ThisThreadBlock = ThisThreadBlock; - static constexpr index_t APackedSize = []() { - if constexpr(is_same_v, pk_i4_t>) - return 2; - else - return 1; - }(); - - static constexpr index_t BPackedSize = []() { - if constexpr(is_same_v, pk_i4_t>) - return 2; - else - return 1; - }(); - __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) { return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); @@ -247,19 +252,33 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 return math::integer_divide_ceil(N, NPerBlock); } - template + template __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) { constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{}); constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); - return transform_tensor_descriptor( + constexpr auto permuted_desc = transform_tensor_descriptor( TileDesc_K0_MN_K1{}, + make_tuple(make_xor_with_modulo_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return transform_tensor_descriptor( + permuted_desc, make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple( - Number{}, Number{}, Number{}))), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}, + Number{}))), make_tuple(Sequence<0, 2>{}, Sequence<1>{}), - make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{})); } __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( @@ -304,12 +323,28 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // pad M, but not K const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)), make_right_pad_transform(M, MPad - M)), make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - return a_grid_desc_ak0_m_ak1; + const auto a_grid_desc_permuted = transform_tensor_descriptor( + a_grid_desc_ak0_m_ak1, + make_tuple(make_pass_through_transform(K / KPerBlock), + make_xor_with_modulo_transform(make_tuple(MPad, AK0Number)), + make_pass_through_transform(AK1Value)), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{})); + + const auto a_grid_desc = transform_tensor_descriptor( + a_grid_desc_permuted, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, AK0Number)), + make_pass_through_transform(MPad), + make_pass_through_transform(AK1Value)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + return a_grid_desc; } else if constexpr(GemmSpec == GemmSpecialization::KPadding || GemmSpec == GemmSpecialization::NKPadding) @@ -335,12 +370,29 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // 
not pad M or K const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)), make_pass_through_transform(M)), make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - return a_grid_desc_ak0_m_ak1; + const auto a_grid_desc_permuted = transform_tensor_descriptor( + a_grid_desc_ak0_m_ak1, + make_tuple(make_pass_through_transform(K / KPerBlock), + make_xor_with_modulo_transform(make_tuple(M, AK0Number)), + make_pass_through_transform(AK1Value)), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{})); + + const auto a_grid_desc = transform_tensor_descriptor( + a_grid_desc_permuted, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, AK0Number)), + make_pass_through_transform(M), + make_pass_through_transform(AK1Value)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_grid_desc; } } @@ -363,6 +415,10 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 static_assert(!(is_same_v, pk_i4_t> && GemmSpec != GemmSpecialization::Default), "pk_i4_t does not support padding"); + static_assert(!(is_same_v, f4x2_pk_t> && + (GemmSpec != GemmSpecialization::Default && + GemmSpec != GemmSpecialization::MPadding)), + "f4x2_pk_t does not support K padding"); if constexpr(GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding) @@ -423,12 +479,30 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // not pad N or K const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( b_grid_desc_nraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_pass_through_transform(N)), + make_tuple( + make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, BK1Value)), + make_pass_through_transform(N)), make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - return b_grid_desc_bk0_n_bk1; + const auto b_grid_desc_permuted = transform_tensor_descriptor( + b_grid_desc_bk0_n_bk1, + make_tuple(make_pass_through_transform(K / KPerBlock), + make_xor_with_modulo_transform(make_tuple(N, BK0Number)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{})); + + const auto b_grid_desc = transform_tensor_descriptor( + b_grid_desc_permuted, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, BK0Number)), + make_pass_through_transform(N), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc; } else { @@ -456,20 +530,22 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 template __host__ __device__ static constexpr auto - MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&) { constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); - return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + return MakeGemmMmaTileDescriptor( + ABlockDesc_AK0_M_AK1{}); } template __host__ __device__ static constexpr auto - MakeBMmaTileDescriptor_N0_N1_N2_K(const 
BBlockDesc_BK0_N_BK1&) + MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&) { constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); - return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + return MakeGemmMmaTileDescriptor( + BBlockDesc_BK0_N_BK1{}); } __host__ __device__ static auto @@ -627,10 +703,10 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 bool is_reduce_ = false) : Problem{M_, N_, - K_, - StrideA_, + K_ / APackedSize, + StrideA_ / APackedSize, StrideScaleA_, - StrideB_, + StrideB_ / BPackedSize, StrideScaleB_, StrideC_, k_batch_}, @@ -675,7 +751,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 { if constexpr(is_same_v) { - a_k_split_offset = k_id * karg.KRead / APackedSize; + a_k_split_offset = k_id * karg.KRead; } else if constexpr(is_same_v) { @@ -690,34 +766,22 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 { if constexpr(!PermuteB) { - b_k_split_offset = k_id * karg.KRead / BPackedSize; + b_k_split_offset = k_id * karg.KRead; } else { const int k0_offset = karg.KRead * karg.N; - b_k_split_offset = k_id * k0_offset / BPackedSize; + b_k_split_offset = k_id * k0_offset; } } // Calculate A scale offset - if constexpr(is_same_v) - { - a_scale_k_split_offset = k_id * karg.KRead / ScaleBlockSize; - } - else if constexpr(is_same_v) - { - a_scale_k_split_offset = k_id * karg.KRead / ScaleBlockSize * karg.StrideScaleA; - } + a_scale_k_split_offset = + k_id * karg.KRead / (ScaleBlockSize / APackedSize) * MXdlPack * MPerXdl; // Calculate B scale offset - if constexpr(is_same_v) - { - b_scale_k_split_offset = k_id * (karg.KRead / ScaleBlockSize) * karg.StrideScaleB; - } - else if constexpr(is_same_v) - { - b_scale_k_split_offset = k_id * karg.KRead / ScaleBlockSize; - } + b_scale_k_split_offset = + k_id * karg.KRead / (ScaleBlockSize / BPackedSize) * NXdlPack * NPerXdl; if(k_id < (karg.KBatch - 1)) { @@ -750,47 +814,28 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // A matrix in LDS memory, dst of blockwise copy if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { + // contiguous in LDS return make_naive_tensor_descriptor( make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); + make_tuple(AK1Number, Number{}, I1)); } // xor tensor transformation request more unnecessary vgpr usage, would cause register spill // in some cases. else if constexpr(is_same::value) { - constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize; - constexpr auto MLdsLayer = LdsSize < 1 ? 
1 : LdsSize; - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - AK0Number * Number{}, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); + constexpr auto a_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), make_pass_through_transform(AK1Number)), make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_ak0_mldslayer_m_ak1, - make_tuple(make_pass_through_transform(AK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; + return a_lds_block_desc_permuted; } else // ColumnMajor A { @@ -887,46 +932,27 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // B matrix in LDS memory, dst of blockwise copy if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { + // contiguous in lds return make_naive_tensor_descriptor( make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); + make_tuple(BK1Number, Number{}, I1)); } else if constexpr(is_same::value) { // NLdsLayer * K0 as logical Bank - constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize; - constexpr index_t NLdsLayer = LdsSize < 1 ? 
1 : LdsSize; - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - BK0Number * Number{}, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); + constexpr auto b_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( b_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), make_pass_through_transform(BK1Number)), make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_bk0_nldslayer_n_bk1, - make_tuple(make_pass_through_transform(BK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; + return b_lds_block_desc_permuted; } else // RowMajor B { @@ -1044,9 +1070,9 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 AccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), - decltype(MakeAMmaTileDescriptor_M0_M1_M2_K( + decltype(MakeAMmaTileDescriptor_M0_M1_M2_M3_K( GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), - decltype(MakeBMmaTileDescriptor_N0_N1_N2_K( + decltype(MakeBMmaTileDescriptor_N0_N1_N2_N3_K( GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, @@ -1081,8 +1107,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + - b_block_space_size_aligned * sizeof(BDataType) / BPackedSize), + return math::max((a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)), c_block_size * sizeof(CShuffleDataType)); } @@ -1093,7 +1119,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); - static_assert(KPerBlock % ScaleBlockSize == 0, + static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0, "KPerBlock should be multiple of ScaleBlockSize"); if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || @@ -1269,7 +1295,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 } } } - +#if 0 // check gridwise gemm pipeline const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); @@ -1280,7 +1306,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 return false; } } - +#endif // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } @@ -1318,6 +1344,18 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; // using Block2CTileMap = 
BlockToCTileMap_3DGrid_KSplit; + using mx_scale_t = e8m0_bexp_t; + static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + + static_assert(is_same_v && + is_same_v, + "A/B ElementwiseOperation should be PassThrough as load_to_lds is used!"); + template ( p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); - const AElementwiseOperation a_element_op{}; - const BElementwiseOperation b_element_op{}; const CElementwiseOperation c_element_op{}; // divide block work by [M, N] @@ -1392,67 +1428,42 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - // A matrix blockwise copy auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ADataType, - ADataType, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( + ThreadGroupTensorSliceTransfer_DirectLoad, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector>( a_grid_desc_ak0_m_ak1, make_multi_index(0, m_block_data_idx_on_grid, 0), - a_element_op, a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + make_multi_index(0, 0, 0)); // B matrix blockwise copy auto b_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - BDataType, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( + ThreadGroupTensorSliceTransfer_DirectLoad, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector>( b_grid_desc_bk0_n_bk1, make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, b_block_desc_bk0_n_bk1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + make_multi_index(0, 0, 0)); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size_aligned = math::integer_least_multiple( @@ -1463,9 +1474,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - 
reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * - sizeof(ADataType) / - APackedSize), + reinterpret_cast(static_cast(p_shared) + + a_block_space_size_aligned * sizeof(ADataType)), b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); @@ -1501,42 +1511,48 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 const auto waveId_m = wave_idx[I0]; const auto waveId_n = wave_idx[I1]; - static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; + // static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; - auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / - mfma.selected_mfma.num_threads_per_blk; + // auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / + // mfma.selected_mfma.num_threads_per_blk; - auto a_thread_offset_m = get_thread_local_1d_id() % MPerXdl + waveId_m * MPerXdl; + // A wave access continuous memory + auto thread_offset_shuffled = + get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack; - auto a_scale_thread_copy = - ThreadwiseTensorSliceTransfer_v2, // SliceLengths - Sequence<0, 1>, // DimAccessOrder - 1, // SrcVectorDim - 1, // SrcScalarPerVector - 1, // SrcScalarStrideInVector - true>( - a_scale_grid_desc_am_ak, - make_multi_index(block_m_id * MPerBlock + a_thread_offset_m, thread_offset_k)); + auto a_thread_offset_m = waveId_m; - auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl + waveId_n * NPerXdl; + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + AScaleDataType, + AScaleDataType, + decltype(a_scale_grid_desc_am_ak), + decltype(BlockwiseGemmPipe::a_scale_thread_desc), + Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m, + 0, + thread_offset_shuffled / scale_pack_size_a)); - auto b_scale_thread_copy = - ThreadwiseTensorSliceTransfer_v2, // SliceLengths - Sequence<0, 1>, // DimAccessOrder - 1, // SrcVectorDim - 1, // SrcScalarPerVector - 1, - true>( - b_scale_grid_desc_bn_ak, - make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, thread_offset_k)); + auto b_thread_offset_n = waveId_n; + + auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, @@ -1564,27 +1580,32 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, "wrong!"); + static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 && + CShuffleNXdlPerWavePerShuffle % NXdlPack == 0, + "wrong!"); constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t NWave = NPerBlock / 
(NXdlPerWave * NPerXdl); // TODO: hacky, fix it! constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); // TODO: hacky, fix it! // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9); constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); @@ -1598,19 +1619,25 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 make_tuple( make_freeze_transform(I0), make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 = MXdlPack + M3, // M3 * M4 * M5 = MPerXdl + M4, + M5)), make_freeze_transform(I0), make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl + Number{}, // N0 (NXdlPerWave) per + // shuffle + N1, // N1 = NWave + N2, // N2 = NXdlPack + N3))), // N3 = NPerXdl make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 6, 7, 8>{}, + Sequence<>{}, + Sequence<1, 3, 5, 9>{})); // calculate origin of thread output tensor on global memory // blockwise GEMM c matrix starting index @@ -1622,8 +1649,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), + make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), make_tuple(Sequence<0>{})); const auto m_thread_data_on_block_idx = @@ -1632,8 +1659,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), + make_tuple(Sequence<0, 1, 2, 3>{}), make_tuple(Sequence<0>{})); 
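These single-stage adaptors just invert a merge transform: CalculateBottomIndex peels the flat on-block m (or n) coordinate into its (M0..M5) (or (N0..N3)) components by division and modulo, fastest-moving dimension first. A stand-alone C++ sketch of that decomposition, with all lengths assumed for illustration (the kernel reads them from the C block descriptor's GetLength(I*)):

#include <array>
#include <cstddef>

// Invert a row-major merge: decompose a flat index into per-dimension
// coordinates (slowest dimension first), as the merge adaptor does.
template <std::size_t N>
std::array<int, N> decompose(int flat, const std::array<int, N>& lengths)
{
    std::array<int, N> idx{};
    for(int d = static_cast<int>(N) - 1; d >= 0; --d)
    {
        idx[d] = flat % lengths[d]; // peel the fastest-moving dimension
        flat /= lengths[d];
    }
    return idx;
}

int main()
{
    // Assumed lengths (M0, M1, M2, M3, M4, M5), illustration only; the kernel
    // takes them from c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I*).
    const std::array<int, 6> m_lengths{2, 2, 2, 4, 4, 1};
    const auto m_idx = decompose(37, m_lengths); // e.g. m_thread_data_on_block = 37
    return m_idx[0]; // plays the role of m_thread_data_on_block_idx[I0]
}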
const auto n_thread_data_on_block_idx = @@ -1641,36 +1668,39 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 make_multi_index(n_thread_data_on_block)); // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + 9, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + m_thread_data_on_block_idx[I5], + n_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; // shuffle: blockwise copy C from LDS to global auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< @@ -1700,12 +1730,23 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + Sequence KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack + const auto Padded_Scale_M = + math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize; const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( - make_tuple(problem.M, math::integer_divide_ceil(problem.K, ScaleBlockSize)), - make_tuple(problem.StrideScaleA, 1)); + make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / + (KXdlPack * 64 / MPerXdl), + 64 * KXdlPack * MXdlPack / scale_pack_size_a), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / APackedSize)) * + MPerXdl * MXdlPack / scale_pack_size_a, + 64 * KXdlPack * MXdlPack / scale_pack_size_a, + 1)); - // B Scale grid transposed const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( - make_tuple(problem.N, math::integer_divide_ceil(problem.K, ScaleBlockSize)), - make_tuple(problem.StrideScaleB, 1)); + make_tuple(problem.N / (NXdlPack * NPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / + (KXdlPack * 64 / NPerXdl), + 64 * KXdlPack * NXdlPack / scale_pack_size_b), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / BPackedSize)) * + NPerXdl * NXdlPack / scale_pack_size_b, + 64 * KXdlPack * NXdlPack / scale_pack_size_b, + 1)); Run( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( @@ -1845,12 +1896,14 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + // A Scale buffer + const auto a_scale_grid_buf = 
make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + // B Scale buffer const auto b_scale_grid_buf = make_dynamic_buffer( p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); - const AElementwiseOperation a_element_op{}; - const BElementwiseOperation b_element_op{}; const CElementwiseOperation c_element_op{}; // divide block work by [M, N] @@ -1886,67 +1939,42 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - // A matrix blockwise copy auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ADataType, - ADataType, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( + ThreadGroupTensorSliceTransfer_DirectLoad, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector>( a_grid_desc_ak0_m_ak1, make_multi_index(0, m_block_data_idx_on_grid, 0), - a_element_op, a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + make_multi_index(0, 0, 0)); // B matrix blockwise copy auto b_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - BDataType, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( + ThreadGroupTensorSliceTransfer_DirectLoad, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector>( b_grid_desc_bk0_n_bk1, make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, b_block_desc_bk0_n_bk1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + make_multi_index(0, 0, 0)); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size_aligned = math::integer_least_multiple( @@ -1957,7 +1985,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 auto b_block_buf_ping = make_dynamic_buffer( bit_cast(static_cast(p_shared_0) + - a_block_space_size_aligned * sizeof(ADataType) / APackedSize), + a_block_space_size_aligned * sizeof(ADataType)), b_block_desc_bk0_n_bk1.GetElementSpaceSize()); auto a_block_buf_pong = make_dynamic_buffer( @@ -1965,7 +1993,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 auto b_block_buf_pong = make_dynamic_buffer( bit_cast(bit_cast(p_shared_1) + - a_block_space_size_aligned * sizeof(ADataType) / APackedSize), + a_block_space_size_aligned * sizeof(ADataType)), 
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); @@ -1983,97 +2011,122 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); - // B scale - static constexpr auto mfma = - MfmaSelector{}; - static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); - static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops(); - static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; - static constexpr auto KPerThread = KPerBlock / K0PerXdlops; + // Initial thread mapping for: + // BlockSize = 256 + // MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MRepeat=NRepeat=2 MWaves=NWaves=2 + // For each [m0, n0] tile, there are 4 waves: + // tId in [ 0, 63] m x n = [ 0, 31] x [ 0, 31] waveId = [0, 0] + // tId in [ 64, 127] m x n = [ 0, 31] x [32, 63] waveId = [0, 1] + // tId in [128, 191] m x n = [32, 63] x [ 0, 31] waveId = [1, 0] + // tId in [192, 255] m x n = [32, 63] x [32, 63] waveId = [1, 1] - const index_t ScaleSliceSizeN = NXdlPerWave; - static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockSize - 1) / ScaleBlockSize; - static constexpr auto KBlockScaleSliceSizeK = - (KPerBlock + ScaleBlockSize - 1) / ScaleBlockSize; + // BlockSize = 128 + // MPerXdl=NPerXdl=16 and MPerBlock=128 NPerBlock=16 MRepeat=4 NRepeat=1 MWaves=2 NWaves=1 + // For each [m0, n0] tile, there are 2 waves: + // tId in [ 0, 63] m x n = [ 0, 15] x [0, 15] waveId = [0, 0] + // tId in [ 64, 127] m x n = [16, 31] x [0, 15] waveId = [1, 0] - constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); + // TODO: Document initial thread mapping for more combinations of parameters - constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; - auto b_thread_offset_n = - get_thread_local_1d_id() % NPerXdl + - (get_thread_local_1d_id() / BlockwiseGemmPipe::WaveSize) % NWaves * NPerXdl; - auto b_thread_offset_k = - (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / NPerXdl * KPerThread; + // static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; - auto b_scale_thread_copy = - ThreadwiseTensorSliceTransfer_v2, - Sequence<0, 1>, - 1, - ScaleSliceSizeK, - 1, - false>( - b_scale_grid_desc_bn_ak, - make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, - b_thread_offset_k / ScaleBlockSize)); + // auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / + // mfma.selected_mfma.num_threads_per_blk; - constexpr auto b_scale_thread_slice_copy_step = - make_tuple(make_multi_index(NWaves * NPerXdl, 0), - make_multi_index(-NPerBlock, 0), - make_multi_index(-NPerBlock, KBlockScaleSliceSizeK)); + // A wave access continuous memory + auto thread_offset_shuffled = + get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack; - blockwise_gemm_pipeline.template Run( - a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_bufs, - a_block_slice_copy_step, - b_grid_desc_bk0_n_bk1, - b_block_desc_bk0_n_bk1, - b_blockwise_copy, - b_grid_buf, - b_block_bufs, - b_block_slice_copy_step, - c_thread_buf, - b_scale_grid_desc_bn_ak, - b_scale_thread_desc, - b_scale_thread_copy, - b_scale_grid_buf, - b_scale_thread_slice_copy_step, - num_k_block_main_loop); + auto a_thread_offset_m = waveId_m; + + auto a_scale_thread_copy = 
ThreadwiseTensorSliceTransfer_v2< + AScaleDataType, + AScaleDataType, + decltype(a_scale_grid_desc_am_ak), + decltype(BlockwiseGemmPipe::a_scale_thread_desc), + Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m, + 0, + thread_offset_shuffled / scale_pack_size_a)); + + auto b_thread_offset_n = waveId_n; + + auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + a_scale_grid_desc_am_ak, + a_scale_thread_copy, + a_scale_grid_buf, + b_scale_grid_desc_bn_ak, + b_scale_thread_copy, + b_scale_grid_buf, + num_k_block_main_loop); // shuffle C and write out { static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, "wrong!"); + static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 && + CShuffleNXdlPerWavePerShuffle % NXdlPack == 0, + "wrong!"); constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); // TODO: hacky, fix it! constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); // TODO: hacky, fix it! 
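For intuition on the packed-scale indexing used by these copies: each lane of a wave owns KXdlPack * MXdlPack consecutive e8m0 exponents, loaded as wider packed elements so one vector load fetches all of them. A stand-alone sketch of the same offset arithmetic (every constant here is assumed for illustration; in the kernel, scale_pack_size comes from sizeof(AScaleDataType) / sizeof(e8m0_bexp_t)):

#include <cstdio>
#include <initializer_list>

int main()
{
    // Assumed values, mirroring a 64-lane wave with 2x2 Xdl packing and a
    // hypothetical 4-byte scale pack (4 e8m0 exponents per packed element).
    constexpr int WaveSize        = 64;
    constexpr int KXdlPack        = 2;
    constexpr int MXdlPack        = 2;
    constexpr int scale_pack_size = 4;

    static_assert(KXdlPack * MXdlPack % scale_pack_size == 0,
                  "scale pack data type too large");

    for(int tid : {0, 1, 63, 64})
    {
        // Each lane starts KXdlPack * MXdlPack e8m0 scales further along...
        const int thread_offset_shuffled = tid % WaveSize * KXdlPack * MXdlPack;
        // ...which the copy addresses in packed elements, not single scales.
        std::printf("tid %2d -> e8m0 offset %3d, packed offset %2d\n",
                    tid,
                    thread_offset_shuffled,
                    thread_offset_shuffled / scale_pack_size);
    }
    return 0;
}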
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9); constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); @@ -2087,19 +2140,25 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 make_tuple( make_freeze_transform(I0), make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 = MXdlPack + M3, // M3 * M4 * M5 = MPerXdl + M4, + M5)), make_freeze_transform(I0), make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl + Number{}, // N0 (NXdlPerWave) per + // shuffle + N1, // N1 = NWave + N2, // N2 = NXdlPack + N3))), // N3 = NPerXdl make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 6, 7, 8>{}, + Sequence<>{}, + Sequence<1, 3, 5, 9>{})); // calculate origin of thread output tensor on global memory // blockwise GEMM c matrix starting index @@ -2111,8 +2170,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), + make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), make_tuple(Sequence<0>{})); const auto m_thread_data_on_block_idx = @@ -2121,8 +2180,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), + make_tuple(Sequence<0, 1, 2, 3>{}), make_tuple(Sequence<0>{})); const auto n_thread_data_on_block_idx = @@ -2130,36 +2189,39 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 make_multi_index(n_thread_data_on_block)); // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 
3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + 9, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + m_thread_data_on_block_idx[I5], + n_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; // shuffle: blockwise copy C from LDS to global auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< @@ -2189,12 +2251,23 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + Sequence __device__ static void Run_2Lds(const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, const BDataType* p_b_grid, const BScaleDataType* p_b_scale_grid, CDataType* p_c_grid, @@ -2263,22 +2337,45 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( c_grid_desc_m_n, problem.MBlock, problem.NBlock); + // A/B shuffled scale for better 8-bit scale access pattern + // MNRepeat -> KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack + const auto Padded_Scale_M = + math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize; + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( + make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / + (KXdlPack * 64 / MPerXdl), + 64 * KXdlPack * MXdlPack / scale_pack_size_a), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / APackedSize)) * + MPerXdl * MXdlPack / scale_pack_size_a, + 64 * KXdlPack * MXdlPack / scale_pack_size_a, + 1)); + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( - make_tuple(problem.N, math::integer_divide_ceil(problem.K, ScaleBlockSize)), - make_tuple(problem.StrideScaleB, 1)); + make_tuple(problem.N / (NXdlPack * NPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / + (KXdlPack * 64 / NPerXdl), + 64 * KXdlPack * NXdlPack / scale_pack_size_b), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / BPackedSize)) * + NPerXdl * NXdlPack / scale_pack_size_b, + 64 * KXdlPack * NXdlPack / scale_pack_size_b, + 1)); Run_2Lds(p_a_grid, + p_a_scale_grid, p_b_grid, p_b_scale_grid, 
p_c_grid, @@ -2286,6 +2383,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 p_shared_1, problem, a_grid_desc_ak0_m_ak1, + a_scale_grid_desc_am_ak, b_grid_desc_bk0_n_bk1, b_scale_grid_desc_bn_ak, c_grid_desc_mblock_mperblock_nblock_nperblock); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp new file mode 100644 index 0000000000..a0e716ba8e --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp @@ -0,0 +1,2295 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_bpreshuffle_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp" + +namespace ck { + +#ifndef KERNEL_GEMM_XDL_CSHUFFLE_V3_MX +#define KERNEL_GEMM_XDL_CSHUFFLE_V3_MX +// Currently we do not have an elegant way to put a single-lds-buffer pipe and a double-lds-buffer pipe in the same +// kernel function. Blockers: +// 1. Two separate declarations of the __shared__ pointer are the key to making sure data accesses operate on +// two lds chunks. +// 2. 
Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may not use the same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ enable_if_t +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) +{ +#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared, + karg); + +#else + ignore = karg; +#endif // end of if (defined(__gfx950__)) +} + +template +__global__ enable_if_t +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) +{ +#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ + // Passing two lds pointers is the key to telling the compiler that ds_read/write + // operate on different lds chunks at the same time without an ordering dependency + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared_0, + p_shared_1, + karg); + +#else + ignore = karg; +#endif // end of if (defined(__gfx950__)) +} +#endif + +template +struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle +{ + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + static constexpr auto I8 = Number<8>{}; + static constexpr auto I9 = Number<9>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = false; + static constexpr auto is_scale_mfma = true; + + static constexpr auto MXdlPack = 2; + static constexpr auto NXdlPack = 2; + static constexpr auto KXdlPack = 2; + + //> KPack is at least the k_per_blk of the selected mfma + // + // Should be a multiple of k_per_blk. 
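To make the packed arithmetic concrete: with MX FP4 every f4x2_pk_t element carries two values, so the mfma's k_per_blk demand is converted to packed elements by dividing through APackedSize before taking the max. A minimal constexpr sketch under assumed numbers (AK1 = BK1 = 16 packed elements, k_per_blk = 64 unpacked values; only the formula mirrors the kernel):

#include <algorithm>
#include <numeric>

// Assumed MX FP4 configuration, for illustration only.
constexpr int APackedSize = 2;  // f4x2_pk_t: two FP4 values per element
constexpr int AK1Value    = 16; // in packed elements
constexpr int BK1Value    = 16;
constexpr int k_per_blk   = 64; // unpacked values consumed by one scaled mfma

constexpr int lcm_AK1_BK1 = std::lcm(AK1Value, BK1Value);
constexpr int KPack       = std::max(lcm_AK1_BK1, k_per_blk / APackedSize);

static_assert(KPack == 32, "the mfma demand (64 / 2 packed) dominates lcm(16, 16)");

int main() { return KPack; }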
+ // TODO: Move this to blockwise pipeline base + // KPack in packed data types for pk A/B + + static constexpr index_t APackedSize = packed_size_v; + static constexpr index_t BPackedSize = packed_size_v; + + static constexpr index_t KPack = + math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk / + APackedSize); + + static constexpr index_t NLane = NPerXdl; + static constexpr index_t KLane = 64 / NLane; + static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; + static constexpr index_t KRepeat = KPerBlock / KLane / KPack; + + using ThisThreadBlock = ThisThreadBlock; + + using mx_scale_t = e8m0_bexp_t; + static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ __device__ static auto CalculateBN0Shuffled(index_t N) + { + return math::integer_divide_ceil(N, NLane); + } + __host__ __device__ static auto CalculateBK0Shuffled(index_t K) + { + return math::integer_divide_ceil(K, KLane * KPack); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + if constexpr(IsXor) + { + constexpr auto permuted_desc = transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_xor_with_modulo_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return transform_tensor_descriptor( + permuted_desc, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + 
make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}, + Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{})); + } + else + { + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}, + Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{})); + } + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + const auto a_grid_desc_permuted = transform_tensor_descriptor( + a_grid_desc_ak0_m_ak1, + make_tuple(make_pass_through_transform(K / KPerBlock), + make_xor_with_modulo_transform(make_tuple(M, AK0Number)), + make_pass_through_transform(AK1Value)), + 
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{})); + + const auto a_grid_desc = transform_tensor_descriptor( + a_grid_desc_permuted, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, AK0Number)), + make_pass_through_transform(M), + make_pass_through_transform(AK1Value)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_grid_desc; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) + { + constexpr index_t NkSwizzleNumber = Number{}; + return make_naive_tensor_descriptor_packed( + make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber)); + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + static_assert(!(is_same_v, f4x2_pk_t> && + GemmSpec != GemmSpecialization::Default), + "f4x2_pk_t does not support padding"); + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + if constexpr(!PermuteB) + { + // not pad N or K + const auto 
b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple( + make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + const auto b_grid_desc_permuted = transform_tensor_descriptor( + b_grid_desc_bk0_n_bk1, + make_tuple(make_pass_through_transform(K / KPerBlock), + make_xor_with_modulo_transform(make_tuple(N, BK0Number)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{})); + + const auto b_grid_desc = transform_tensor_descriptor( + b_grid_desc_permuted, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, BK0Number)), + make_pass_through_transform(N), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc; + } + else + { + // Weight Tile Permute + constexpr index_t BK01 = KPerBlock / BK1Value; + // const index_t BK00 = BK0 / BK01; + const index_t BK0_ = StrideB / BK1Value; + const index_t BK00 = BK0_ / BK01; + + const auto b_grid_desc_bk00_n_bk01_bk1_permute = + make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); + + const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor( + b_grid_desc_bk00_n_bk01_bk1_permute, + make_tuple(make_merge_transform(make_tuple(BK00, BK01)), + make_pass_through_transform(make_tuple(N)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc_bk0_n_bk1_permute; + } + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor( + ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor( + BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == 
GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } +#endif + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideScaleA_, + index_t StrideB_, + index_t StrideScaleB_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideScaleA{StrideScaleA_}, + StrideB{StrideB_}, + StrideScaleB{StrideScaleB_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)}, + BN0Shuffled{CalculateBN0Shuffled(N_)}, + BK0Shuffled{CalculateBK0Shuffled(K_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SScaleA:" << StrideScaleA << ", " + << "SB:" << StrideB << ", " + << "SScaleB:" << StrideScaleB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideScaleA; + index_t StrideB; + index_t StrideScaleB; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + // FOR PRESHUFFLE ONLY + index_t BN0Shuffled; + index_t BK0Shuffled; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const AScaleDataType* p_a_scale_grid_, + const BDataType* p_b_grid_, + const BScaleDataType* p_b_scale_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideScaleA_, + index_t StrideB_, + index_t StrideScaleB_, + index_t StrideC_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_, + bool is_reduce_ = false) + : Problem{M_, + N_, + K_ / APackedSize, + StrideA_ / APackedSize, + StrideScaleA_, + StrideB_ / BPackedSize, + StrideScaleB_, + StrideC_, + k_batch_}, + p_a_grid{p_a_grid_}, + p_a_scale_grid{p_a_scale_grid_}, + p_b_grid{p_b_grid_}, + p_b_scale_grid{p_b_scale_grid_}, + p_c_grid{p_c_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_}, + is_reduce(is_reduce_) + { + } + + __host__ __device__ inline 
bool IsReduceAdd() const + { + return (Problem::KBatch > 1) && is_reduce; + } + + __host__ __device__ inline bool IsAtomicAdd() const + { + return (Problem::KBatch > 1) && (!is_reduce); + } + + const ADataType* p_a_grid; + const AScaleDataType* p_a_scale_grid; + const BDataType* p_b_grid; + const BScaleDataType* p_b_scale_grid; + CDataType* p_c_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + bool is_reduce; + }; + + struct SplitKBatchOffset + { + + __device__ SplitKBatchOffset(Argument& karg, index_t k_id) + { + if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = k_id * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + if constexpr(!PermuteB) + { + b_k_split_offset = k_id * karg.KRead * NPerXdl; + } + else + { + const int k0_offset = karg.KRead * karg.N; + b_k_split_offset = k_id * k0_offset; + } + } + + // Calculate A scale offset + a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize) * MXdlPack * + MPerXdl / scale_pack_size_a; + + // Calculate B scale offset + b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize) * NXdlPack * + NPerXdl / scale_pack_size_b; + + if(k_id < (karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + + if(karg.IsReduceAdd()) + { + c_reduce_offset = k_id * karg.M * karg.N; + } + else + { + c_reduce_offset = 0; + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t a_scale_k_split_offset; // New member for scale matrix offset + index_t b_scale_k_split_offset; // New member for scale matrix offset + index_t c_reduce_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // contiguous in LDS + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto a_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return a_lds_block_desc_permuted; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. 
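+ // kfold and mpair below regroup K0 and M so that one swizzle period stays
+ // within a 128-byte LDS window: kfold widens the fastest-varying K0 span of
+ // AK1-wide vectors up to 128 bytes, and mpair pairs up M tiles under the same
+ // limit (capped at M0). For example, with AK1 = 16, M0 = 4 and a 1-byte
+ // ADataType (fp8, or packed fp4x2): AK1 * M0 * sizeof(ADataType) = 64 <= 128,
+ // hence kfold = 128 / 64 = 2 in the computation below.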
+ constexpr auto WaveSize = 64; + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = WaveSize / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // K0 -> N0/NWave/NXdlPack -> NWave -> NXdlPack -> KLane -> NLane -> KPack + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + I1, + Number{}, + Number{}, + Number{})); + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + // constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using 
BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max(a_block_space_size_aligned * sizeof(ADataType), + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0, + "KPerBlock should be multiple of ScaleBlockSize"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! 
K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! 
" + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(is_same, half_t>::value || + is_same, float>::value || + is_same, bhalf_t>::value || + is_same, int32_t>::value)) + { + if(!karg.IsReduceAdd()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + if(karg.KBatch > 1) + { + return false; + } + } + } +#if 0 + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } +#endif + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + template + __device__ static void Run(const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const AScaleGridDesc_AM_AK& a_scale_grid_desc_am_ak, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // A Scale buffer + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + + // B Scale buffer + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + 
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0)); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + + a_block_space_size_aligned * sizeof(ADataType)), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // Initial thread mapping for: + // BlockSize = 256 + // MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MRepeat=NRepeat=2 MWaves=NWaves=2 + // For each [m0, n0] tile, there are 4 waves: + // tId in [ 0, 63] m x n = [ 0, 31] x [ 0, 31] waveId = [0, 0] + // tId in [ 64, 127] m x n = [ 0, 31] x [32, 63] waveId = [0, 1] + // tId in [128, 191] m x n 
= [32, 63] x [ 0, 31] waveId = [1, 0] + // tId in [192, 255] m x n = [32, 63] x [32, 63] waveId = [1, 1] + + // BlockSize = 128 + // MPerXdl=NPerXdl=16 and MPerBlock=128 NPerBlock=16 MRepeat=4 NRepeat=1 MWaves=2 NWaves=1 + // For each [m0, n0] tile, there are 2 waves: + // tId in [ 0, 63] m x n = [ 0, 15] x [0, 15] waveId = [0, 0] + // tId in [ 64, 127] m x n = [16, 31] x [0, 15] waveId = [1, 0] + + // TODO: Document initial thread mapping for more combinations of parameters + + const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + // static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; + + // auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / + // mfma.selected_mfma.num_threads_per_blk; + + // A wave access continuous memory + auto thread_offset_shuffled = + get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack; + + auto a_thread_offset_m = waveId_m; + + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + AScaleDataType, + AScaleDataType, + decltype(a_scale_grid_desc_am_ak), + decltype(BlockwiseGemmPipe::a_scale_thread_desc), + Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m, + 0, + thread_offset_shuffled / scale_pack_size_a)); + + auto b_thread_offset_n = waveId_n; + + auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + a_scale_grid_desc_am_ak, + a_scale_thread_copy, + a_scale_grid_buf, + b_scale_grid_desc_bn_ak, + b_scale_thread_copy, + b_scale_grid_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 && + CShuffleNXdlPerWavePerShuffle % NXdlPack == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + + // TODO: hacky, fix it! 
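+ // The 10-D descriptors used below index the C tile as
+ // [M0, N0, M1, N1, M2, N2, M3, M4, M5, N3], where M1 = MWave, N1 = NWave,
+ // M2 = MXdlPack, N2 = NXdlPack, M3 * M4 * M5 = MPerXdl, N3 = NPerXdl, and
+ // M0/N0 run over the remaining MXdlPerWave/NXdlPerWave repeats per shuffle.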
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 = MXdlPack + M3, // M3 * M4 * M5 = MPerXdl + M4, + M5)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per + // shuffle + N1, // N1 = NWave + N2, // N2 = NXdlPack + N3))), // N3 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 6, 7, 8>{}, + Sequence<>{}, + Sequence<1, 3, 5, 9>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), + make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + 9, + 1, + 
InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + m_thread_data_on_block_idx[I5], + n_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, 
problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + // A/B shuffled scale for better 8-bit scale access pattern + // MNRepeat -> KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed( + make_tuple(problem.M / (MXdlPack * MPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / + (KXdlPack * 64 / MPerXdl), + 64 * KXdlPack * MXdlPack / scale_pack_size_a)); + + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed( + make_tuple(problem.N / (NXdlPack * NPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / + (KXdlPack * 64 / NPerXdl), + 64 * KXdlPack * NXdlPack / scale_pack_size_b)); + + Run(p_a_grid, + p_a_scale_grid, + p_b_grid, + p_b_scale_grid, + p_c_grid, + p_shared, + problem, + a_grid_desc_ak0_m_ak1, + a_scale_grid_desc_am_ak, + b_grid_desc_bk0_n_bk1, + b_scale_grid_desc_bn_ak, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const AScaleGridDesc_AM_AK& a_scale_grid_desc_am_ak, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = + make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // A Scale buffer + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + + // B Scale buffer + const auto b_scale_grid_buf = + make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + 
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / NXdlPack); + + // lds max alignment + // constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0)); + + // dummys + auto b_block_buf_ping = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_buf_pong = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), // actually the thread desc + Sequence{}, + I1, + Number{}, + Number{}, + Number{}>, + Sequence<0, 1, 2, 3, 4>, + 4, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bk0_n_bk1, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + 0, + KPack * (get_thread_local_1d_id() % warpSize))); + + // LDS allocation for A and B: be careful of alignment + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // Initial thread mapping for: + // BlockSize = 256 + // MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MRepeat=NRepeat=2 MWaves=NWaves=2 + // For each [m0, n0] tile, there are 4 waves: + // tId in [ 0, 63] m x n = [ 0, 31] x [ 0, 31] waveId = [0, 0] + // tId in [ 64, 127] m x n = [ 0, 31] x [32, 63] waveId = [0, 1] + // tId in [128, 191] m x n = [32, 63] x [ 0, 31] waveId = [1, 0] + // tId in [192, 255] m x n = [32, 63] x [32, 63] waveId = [1, 1] + + // BlockSize = 128 + // MPerXdl=NPerXdl=16 and MPerBlock=128 NPerBlock=16 MRepeat=4 NRepeat=1 MWaves=2 NWaves=1 + // For each [m0, n0] tile, there are 2 waves: + // tId in [ 0, 63] m x n = [ 0, 15] x [0, 15] waveId = [0, 0] + // tId in [ 64, 127] m x n = [16, 31] x [0, 15] waveId = [1, 0] + + // TODO: Document initial thread mapping for more combinations of parameters + + const auto wave_idx = 
BlockwiseGemmPipe::GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + // static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; + + // auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / + // mfma.selected_mfma.num_threads_per_blk; + + // A wave access continuous memory + auto thread_offset_shuffled = + get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack; + + auto a_thread_offset_m = waveId_m; + + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + AScaleDataType, + AScaleDataType, + decltype(a_scale_grid_desc_am_ak), + decltype(BlockwiseGemmPipe::a_scale_thread_desc), + Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m, + 0, + thread_offset_shuffled / scale_pack_size_a)); + + auto b_thread_offset_n = waveId_n; + + auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + a_scale_grid_desc_am_ak, + a_scale_thread_copy, + a_scale_grid_buf, + b_scale_grid_desc_bn_ak, + b_scale_thread_copy, + b_scale_grid_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 && + CShuffleNXdlPerWavePerShuffle % NXdlPack == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + // constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + + // TODO: hacky, fix it! 
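+ // The epilogue below reuses the ping LDS buffer (p_shared_0) for the C
+ // shuffle, which is why GetSharedMemoryNumberOfByte() sizes a buffer as the
+ // max of the A tile and the C-shuffle tile rather than their sum; the
+ // block_sync_lds() calls in the copy loop keep the two uses from overlapping.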
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 = MXdlPack + M3, // M3 * M4 * M5 = MPerXdl + M4, + M5)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per + // shuffle + N1, // N1 = NWave + N2, // N2 = NXdlPack + N3))), // N3 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 6, 7, 8>{}, + Sequence<>{}, + Sequence<1, 3, 5, 9>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), + make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + 9, + 1, + 
InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + m_thread_data_on_block_idx[I5], + n_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, 
problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = + MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + // A/B shuffled scale for better 8-bit scale access pattern + // MNRepeat -> KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack + // We pad M unconditionally for the scale descriptors + const auto Padded_Scale_M = + math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize; + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( + make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / + (KXdlPack * 64 / MPerXdl), + 64 * KXdlPack * MXdlPack / scale_pack_size_a), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / APackedSize)) * + MPerXdl * MXdlPack / scale_pack_size_a, + 64 * KXdlPack * MXdlPack / scale_pack_size_a, + 1)); + + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(problem.N / (NXdlPack * NPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / + (KXdlPack * 64 / NPerXdl), + 64 * KXdlPack * NXdlPack / scale_pack_size_b), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / BPackedSize)) * + NPerXdl * NXdlPack / scale_pack_size_b, + 64 * KXdlPack * NXdlPack / scale_pack_size_b, + 1)); + + Run_2Lds(p_a_grid, + p_a_scale_grid, + p_b_grid, + p_b_scale_grid, + p_c_grid, + p_shared_0, + p_shared_1, + problem, + a_grid_desc_ak0_m_ak1, + a_scale_grid_desc_am_ak, + b_grid_desc_bk0_n_bk1, + b_scale_grid_desc_bn_ak, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp index bac8c32886..3e23008a5f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp @@ -76,10 +76,12 @@ template {}; // K1 should be Number<...> - static constexpr auto K1 = Number{}; - static constexpr auto M01 = 1; - static constexpr auto N01 = 1; + static constexpr auto K1 = Number{}; + static constexpr auto KPerBlock = Number{}; + static constexpr auto M01 = 1; + static constexpr auto N01 = 1; static constexpr auto gemm_padder = tensor_operation::device::GemmPadder{ @@ -613,8 +616,9 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load } else { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(K1, Number{}, I1)); } }(); @@ -630,9 +634,10 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load } else { - return make_naive_tensor_descriptor_aligned( + return make_naive_tensor_descriptor( make_tuple(Number<1>{}, Number{}, Number{}, K1), - max_lds_align); + make_tuple( + Number{} * Number{}, K1, Number{}, I1)); } }(); // B matrix in LDS memory, dst of blockwise copy @@ -645,8 +650,9 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load } else { - return
make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(K1, Number{}, I1)); } }(); @@ -662,9 +668,10 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load } else { - return make_naive_tensor_descriptor_aligned( + return make_naive_tensor_descriptor( make_tuple(Number<1>{}, Number{}, Number{}, K1), - max_lds_align); + make_tuple( + Number{} * Number{}, K1, Number{}, I1)); } }(); @@ -672,10 +679,12 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load ThreadGroupTensorSliceTransfer_DirectLoad, ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferSrcAccessOrder, FloatA, ComputeType, decltype(a_b_k0_m_k1_grid_desc), decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, 3, ABlockTransferSrcScalarPerVector>( @@ -688,10 +697,12 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load ThreadGroupTensorSliceTransfer_DirectLoad, BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferSrcAccessOrder, FloatB, ComputeType, decltype(b_b_k0_n_k1_grid_desc), decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, 3, BBlockTransferSrcScalarPerVector>( diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 2255505985..c17b88ccea 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -260,7 +260,8 @@ struct ThreadwiseTensorSliceTransfer_v2 static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, "wrong! Not divisible"); - if constexpr(is_same_v, pk_i4_t>) + if constexpr(is_same_v, pk_i4_t> || + is_same_v, f4x2_pk_t>) { static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1"); } @@ -422,6 +423,240 @@ struct ThreadwiseTensorSliceTransfer_v2 SrcCoord src_coord_; }; // namespace ck +template ::type = false> +struct ThreadwiseTensorSliceTransfer_v2_gather +{ + static_assert((InvalidElementAsNaN && !ck::is_integral::value) || + (!InvalidElementAsNaN), + "Filling invalid element as NaN is only for floating point types"); + + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + + static constexpr index_t PackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v2_gather( + const SrcDesc& src_desc, + const Index& src_slice_origin_idx, + const StaticallyIndexedArray& scale_gather_offsets) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx)), + scale_gather_offsets_(scale_gather_offsets) + { + static_assert(DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc need to known at compile-time"); + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, + "wrong! Not divisible"); + + if constexpr(is_same_v, pk_i4_t>) + { + static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1"); + } + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + auto adjusted_origin_idx = [&]() { + Index idx; + + static_for<0, nDim, 1>{}( + [&](auto i) { idx(i) = i.value == 0 ? 
0 : src_slice_origin_idx[Number{}]; }); + + return idx; + }(); + + src_coord_ = make_tensor_coordinate(src_desc, adjusted_origin_idx); + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc&, + const DstSliceOriginIdx&, + DstBuffer& dst_buf) + { + static_assert(DstDesc::IsKnownAtCompileTime(), + "wrong! DstDesc need to known at compile-time"); + + static_assert(is_known_at_compile_time>::value, + "wrong! DstSliceOrigin need to known at compile-time"); + + static_assert( + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); + + // DstDesc and dst_slice_origin_idx are known at compile-time + constexpr auto dst_desc = remove_cvref_t{}; + constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + // loop over tensor and copy + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, scale_gather_num, 1>{}([&](auto gather_idx) { + constexpr auto current_dst_origin = + to_multi_index(dst_slice_origin_idx) + make_multi_index(gather_idx, 0); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + typename vector_type_maker::type + src_vector; + + using src_vector_t = + typename vector_type_maker::type::type; + constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d); + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, + src_coord_); + + // copy data from src_buf into src_vector + src_vector.template AsType()(Number<0>{}) = + src_buf.template Get(src_coord_.GetOffset() / PackedSize + + scale_gather_offsets_(gather_idx), + is_src_valid); + + // copy data from src_vector into dst_buf + static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) { + constexpr index_t dst_offset = + dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + + src_data_idx + i * src_scalar_step_in_vector); + constexpr auto full_dst_offset = + dst_desc.CalculateOffset(current_dst_origin) + dst_offset; + + if constexpr(InvalidElementAsNaN) + { + dst_buf(full_dst_offset) = + is_src_valid + ? 
type_convert(src_vector.template AsType()[i]) + : NumericLimits::QuietNaN(); + } + else + { + dst_buf(Number{}) = + type_convert(src_vector.template AsType()[i]); + } + }); + + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + + move_tensor_coordinate( + src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step)); + } + }); + }); + + // printf("blockIdx.y: %d, tid: %d, dst_buf<%f>\n", + // blockIdx.y, + // threadIdx.x, + // dst_buf(Number<0>{})); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void + MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx, + const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step( + src_desc, adjusted_step_idx, src_move_slice_window_step_hack); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + private: + SrcCoord src_coord_; + StaticallyIndexedArray scale_gather_offsets_; +}; // namespace ck + // Assume: // 1. src_desc and dst_desc are not known at compile-time // 2. SrcBuffer and DstBuffer are DynamicBuffer @@ -1053,10 +1288,8 @@ struct ThreadwiseTensorSliceTransfer_v4 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), "wrong! SrcDesc and DstDesc need to known at compile-time"); - static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, - "wrong! 
Not divisible"); - - if constexpr(is_same_v, pk_i4_t>) + if constexpr(is_same_v, pk_i4_t> || + is_same_v, f4x2_pk_t>) { static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1"); } @@ -1236,16 +1469,16 @@ struct ThreadwiseTensorSliceTransfer_v4 { // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to // DstData) - vector_type_maker_t dst_tmp_vector; + vector_type_maker_t dst_tmp_vector; // TODO: if SrcData and DstData are vetor type, then static_cast may not compile - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) { dst_tmp_vector.template AsType()(i) = type_convert(src_tmp_vector.template AsType()[i]); }); // copy data from dst_tmp_vector into dst_buf - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) { constexpr index_t dst_offset = dst_desc.CalculateOffset( dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp index 96b95579f5..168f028e2a 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp @@ -62,6 +62,18 @@ struct lambda_scalar_per_access_for_src_and_dst } }; +template +struct lambda_wave_cluster_dimension +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + if((nDim - i) == 3) + return WaveNum; + else + return 1; + } +}; + } // namespace detail } // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index 7ccea96dda..79e22018a6 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -90,7 +90,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 src_element_op_(src_element_op), dst_element_op_(dst_element_op) { - if constexpr(is_same_v, pk_i4_t>) + if constexpr((packed_size_v) > 1) { static_assert(is_same_v, remove_cvref_t>, "SrcData != DstData"); @@ -99,7 +99,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0, "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type"); - static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose"); + static_assert(SrcVectorDim == DstVectorDim, + "Packed data type does not support transpose"); } } @@ -444,6 +445,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 { static_assert(!is_same_v, pk_i4_t>, "in-register transpose is not supported for pk_i4_t"); + static_assert(!is_same_v, f4x2_pk_t>, + "in-register transpose is not supported for f4x2_pk_t"); // each transpose does // DstScalarPerVector # of src vectors in src_thread_scratch_ // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp index bd6fe772e4..50f1e21beb 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp +++ 
b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -96,7 +96,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather dst_element_op_(dst_element_op), gather_offsets_(gather_offsets) { - if constexpr(is_same_v, pk_i4_t>) + if constexpr((packed_size_v) > 1) { static_assert(is_same_v, remove_cvref_t>, "SrcData != DstData"); @@ -105,7 +105,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0, "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type"); - static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose"); + static_assert(SrcVectorDim == DstVectorDim, + "Packed data type does not support transpose"); } } @@ -222,7 +223,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather auto gather_offset = gather_offsets_(ordered_src_access_idx[Number{}]); - const IndexType ld_offset = src_coord_.GetOffset() + gather_offset; + const IndexType ld_offset = src_coord_.GetOffset() / PackedSize + gather_offset; src_oob_thread_scratch_tuple_(thread_scratch_id) .template SetAsType(src_data_idx_seq, true); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp index 7cd0a0fc7f..9b1ff3dbf8 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp @@ -410,8 +410,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter using dst_vector_t = typename remove_cvref_t::type; IndexType dst_offset = scatter_offset + (dst_coords_[i].GetOffset()); const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize(); - // coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], - // dst_coords_[i]); constexpr InMemoryDataOperationEnum DstInMemOp = static_cast(DstInMemOps::At(i.value)); dst_bufs(i).template Update( diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index b825d7ab69..7da353d9ad 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -8,6 +8,35 @@ #include "ck/utility/amd_xdlops.hpp" namespace ck { +/** + * @brief Define matrix data types that have hardware support for MX GEMMs + */ +template +static constexpr bool is_scale_mfma_data_type() +{ + using U = element_type_t; + return is_same_v || is_same_v || is_same_v || + is_same_v || is_same_v; +} + +/** + * @brief Define scale data types that have hardware support for MX GEMMs + */ +template +static constexpr bool is_scale_mfma_scale_type() +{ + return is_same_v; +} + +/** + * @brief Combination of data types that have hardware support for MX GEMMs + */ +template +static constexpr bool scale_mfma_hw_support() +{ + return is_scale_mfma_data_type() && is_scale_mfma_data_type() && + is_scale_mfma_scale_type() && is_scale_mfma_scale_type(); +} enum struct MfmaInstr { @@ -847,6 +876,8 @@ struct mfma_type template const ScaleB& scale_b, FloatC& reg_c) const { - static_assert(scalar_type::vector_size == 1, "Expect single scale at this point."); - 
static_assert(scalar_type::vector_size == 1, "Expect single scale at this point."); - intrin_mfma_scale_f32_32x32x64f8f6f4::Run( - a, utils::get_exponent_value(scale_a), b, utils::get_exponent_value(scale_b), reg_c); + intrin_mfma_scale_f32_32x32x64f8f6f4::Run( + a, bit_cast(scale_a), b, bit_cast(scale_b), reg_c); } }; @@ -885,6 +914,8 @@ struct mfma_type template const ScaleB& scale_b, FloatC& reg_c) const { - static_assert(scalar_type::vector_size == 1, "Expect single scale at this point."); - static_assert(scalar_type::vector_size == 1, "Expect single scale at this point."); - intrin_mfma_scale_f32_16x16x128f8f6f4::Run( - a, utils::get_exponent_value(scale_a), b, utils::get_exponent_value(scale_b), reg_c); + intrin_mfma_scale_f32_16x16x128f8f6f4::Run( + a, bit_cast(scale_a), b, bit_cast(scale_b), reg_c); } }; @@ -1117,7 +1146,7 @@ struct MfmaSelector #endif } - // Use singal rate mfma instruction for this special case A (f8_t) * B (pk_i4_t) + // Use single rate mfma instruction for this special case A (f8_t) * B (pk_i4_t) // See example gemm_xdl_fp8_pk_i4_bpreshuffle_v3 // TODO: explore optimization opportunity by using new mfma instructions on gfx950 template <> @@ -1153,6 +1182,16 @@ struct MfmaSelector { return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4; } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4; + } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; + } template <> constexpr auto GetMfma() @@ -1290,10 +1329,10 @@ struct MfmaSelector #endif } - static constexpr auto selected_mfma = mfma_type, MPerXdlops, NPerXdlops, - additional_type, + element_type_t, is_single_rate_mfma, is_scale_mfma>()>{}; @@ -1375,7 +1414,8 @@ struct XdlopsGemm MPerXdlops == 64, "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops"); - static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk"); + static_assert(KPack * 2 % mfma_instr.k_per_blk == 0, + "KPack should be a multiple of k_per_blk"); } // XDL output supporting C = A * B @@ -1413,6 +1453,49 @@ struct XdlopsGemm Sequence<7>{})); } + // XDL output supporting C = A * B + // M3_N3 -> M3_M4_M5_N3 + template + __host__ __device__ static constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3( + const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2) + { + const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + const auto M2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I4); + const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5); + + return transform_tensor_descriptor( + c_desc_m0_n0_m1_n1_m2_n2, + make_tuple(make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_pass_through_transform(M2), + make_pass_through_transform(N2), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6, 7, 8>{}, + Sequence<9>{})); + } + // transposed XDL output supporting C' = B' * A' // M2_N2 -> M2_N2_N3_N4 template @@ -1518,7 +1601,13 @@ struct XdlopsGemm 
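// For reference on the scaled Run() below, which feeds one 32-bit scale operand per
// k-pack into the scale MFMA instructions: a minimal host-side sketch of the MX
// block-scaling arithmetic involved, assuming E8M0 scales are biased exponents with
// bias 127 (scale = 2^(e - 127)) as in the OCP MX spec; scaled_mx_dot is an
// illustrative name, not a CK or HIP symbol:
#include <cmath>
#include <cstdint>

inline float scaled_mx_dot(
    const float* a, std::uint8_t e8m0_a, const float* b, std::uint8_t e8m0_b, int k)
{
    const float scale_a = std::ldexp(1.0f, static_cast<int>(e8m0_a) - 127); // 2^(e-127)
    const float scale_b = std::ldexp(1.0f, static_cast<int>(e8m0_b) - 127);
    float acc = 0.0f; // scaled MFMAs accumulate in f32
    for(int i = 0; i < k; ++i)
        acc += (a[i] * scale_a) * (b[i] * scale_b); // dequantize, then multiply-accumulate
    return acc;
}
// In the device code, a_scale_thread[k] / b_scale_thread[k] appear to carry these
// biased exponents packed into 32-bit registers, one slice per KPack step.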
}); } - template + template __device__ void Run(const FloatA& p_a_wave, const ScaleA& a_scale_thread, const FloatB& p_b_wave, @@ -1528,12 +1617,12 @@ struct XdlopsGemm static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { if constexpr(!TransposeC) { - mfma_instr.template run( + mfma_instr.template run( p_a_wave[k], a_scale_thread[k], p_b_wave[k], b_scale_thread[k], p_c_thread); } else { - mfma_instr.template run( + mfma_instr.template run( p_b_wave[k], b_scale_thread[k], p_a_wave[k], a_scale_thread[k], p_c_thread); } }); diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 62e3220b5a..783fc661ce 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -430,7 +430,9 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); using r_t = typename vector_type::type; @@ -1018,18 +1020,18 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, const index_t src_element_space_size) { // Direct loads require that each thread reads and writes exactly a single DWORD. - constexpr auto dword_bytes = 4; constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread; +#if defined(__gfx950__) + constexpr auto dword_bytes = 4; + static_assert(bytes_per_thread == dword_bytes || bytes_per_thread == dword_bytes * 3 || + bytes_per_thread == dword_bytes * 4); +#elif defined(__gfx942__) + constexpr auto dword_bytes = 4; static_assert(bytes_per_thread == dword_bytes); - -#ifndef CK_CODE_GEN_RTC - const uint32_t* global_ptr = - reinterpret_cast(reinterpret_cast(global_base_ptr)); -#else - const uint32_t* global_ptr = - reinterpret_cast(reinterpret_cast(global_base_ptr)); #endif - const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size); + + const int32x4_t src_resource = + make_wave_buffer_resource(global_base_ptr, src_element_space_size); const index_t global_offset_bytes = is_valid ? 
global_offset * sizeof(T) : 0x80000000; #if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM @@ -1057,7 +1059,7 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, #endif llvm_amdgcn_raw_buffer_load_lds( - src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); + src_resource, lds_ptr, bytes_per_thread, global_offset_bytes, 0, 0, 0); #endif } #endif diff --git a/include/ck/utility/amd_buffer_addressing_builtins.hpp b/include/ck/utility/amd_buffer_addressing_builtins.hpp index 296c1d44d7..1836e9461d 100644 --- a/include/ck/utility/amd_buffer_addressing_builtins.hpp +++ b/include/ck/utility/amd_buffer_addressing_builtins.hpp @@ -843,14 +843,8 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread; static_assert(bytes_per_thread == dword_bytes); -#ifndef CK_CODE_GEN_RTC - const uint32_t* global_ptr = - reinterpret_cast(reinterpret_cast(global_base_ptr)); -#else - const uint32_t* global_ptr = - reinterpret_cast(reinterpret_cast(global_base_ptr)); -#endif - const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size); + const int32x4_t src_resource = + make_wave_buffer_resource(global_base_ptr, src_element_space_size); const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000; #if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index ed3354dfb5..9a28c5f332 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -662,11 +662,11 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32> } }; -template +template struct intrin_mfma_scale_f32_32x32x64f8f6f4; -template <> -struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> +template +struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32, OpselA, OpselB> { template __device__ static void Run(const f8x32_t& reg_a, @@ -682,11 +682,11 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> reg_a, reg_b, reg_c.template AsType()[Number<0>{}], - 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} - 0, // blgp - 0, // OPSEL + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 0, // blgp + OpselA, // OPSEL scale_a, - 0, // OPSEL + OpselB, // OPSEL scale_b); // XXX: Note on the scale_a and scale_b parameters: // If compiler detects that one or both scales are constant values, it will treat that @@ -719,11 +719,11 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> reg_a, reg_b, reg_c.template AsType()[Number<0>{}], - 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} - 1, // blgp - 0, // OPSEL + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + OpselA, // OPSEL scale_a, - 0, // OPSEL + OpselB, // OPSEL scale_b); // XXX: Note on the scale_a and scale_b parameters: // If compiler detects that one or both scales are constant values, it will treat that @@ -756,11 +756,11 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> reg_a, reg_b, reg_c.template AsType()[Number<0>{}], - 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} - 0, // blgp - 0, // OPSEL + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 0, // blgp + OpselA, // OPSEL scale_a, - 0, // OPSEL + OpselB, // OPSEL scale_b); // XXX: Note on the scale_a and scale_b parameters: // If compiler detects that one or both scales are constant values, it will treat that @@ -798,11 +798,11 @@ struct 
intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, reg_c.template AsType()[Number<0>{}], - 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} - 2, // blgp - 0, // OPSEL + 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 2, // blgp + OpselA, // OPSEL scale_a, - 0, // OPSEL + OpselB, // OPSEL scale_b); #else ignore = reg_a; @@ -832,11 +832,11 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, reg_c.template AsType()[Number<0>{}], - 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} - 3, // blgp - 0, // OPSEL + 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 3, // blgp + OpselA, // OPSEL scale_a, - 0, // OPSEL + OpselB, // OPSEL scale_b); #else ignore = reg_a; @@ -866,11 +866,11 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, reg_c.template AsType()[Number<0>{}], - 4, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} - 4, // blgp - 0, // OPSEL + 4, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 4, // blgp + OpselA, // OPSEL scale_a, - 0, // OPSEL + OpselB, // OPSEL scale_b); #else ignore = reg_a; @@ -881,13 +881,60 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> #endif } }; +#define BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS 1 -template +#ifndef BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS +#define BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS 0 +#endif + +template struct intrin_mfma_scale_f32_16x16x128f8f6f4; -template <> -struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> +template +struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB> { + +#define V_MFMA_SCALE_F32_16X16X128_F8F6F4(OPF_F8F6F4_CTRL_A, \ + OPF_F8F6F4_CTRL_B, \ + F8F6F4_VEC_TYPE_A, \ + F8F6F4_VEC_TYPE_B, \ + OPSEL_A_L, \ + OPSEL_A_H, \ + OPSEL_B_L, \ + OPSEL_B_H) \ + if constexpr((OpselA == 1 * OPSEL_A_L + 2 * OPSEL_A_H) && \ + (OpselB == 1 * OPSEL_B_L + 2 * OPSEL_B_H)) \ + asm volatile("v_mfma_scale_f32_16x16x128_f8f6f4 %0, %1, %2, %3, %4, %5 " \ + "op_sel:[" #OPSEL_A_L "," #OPSEL_A_H "] " \ + "op_sel_hi:[" #OPSEL_B_L "," #OPSEL_B_H "] " \ + "cbsz:" #OPF_F8F6F4_CTRL_A " blgp:" #OPF_F8F6F4_CTRL_B \ + : "+v"(reg_c.template AsType()(Number<0>{})) \ + : "v"(bit_cast(reg_a)), \ + "v"(bit_cast(reg_b)), \ + "v"(reg_c.template AsType()[Number<0>{}]), \ + "v"(scale_a), \ + "v"(scale_b)) +#define BOOL4_CASES(F) \ + do \ + { \ + F(0, 0, 0, 0); \ + F(0, 0, 0, 1); \ + F(0, 0, 1, 0); \ + F(0, 0, 1, 1); \ + F(0, 1, 0, 0); \ + F(0, 1, 0, 1); \ + F(0, 1, 1, 0); \ + F(0, 1, 1, 1); \ + F(1, 0, 0, 0); \ + F(1, 0, 0, 1); \ + F(1, 0, 1, 0); \ + F(1, 0, 1, 1); \ + F(1, 1, 0, 0); \ + F(1, 1, 0, 1); \ + F(1, 1, 1, 0); \ + F(1, 1, 1, 1); \ + } while(0) + template __device__ static void Run(const f8x32_t& reg_a, const int32_t& scale_a, @@ -896,18 +943,24 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> FloatC& reg_c) { #if defined(__gfx950__) +#if BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 
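+ // Editorial gloss (hedged) on the immediates below: cbsz/blgp select the A/B
+ // operand formats {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1},
+ // and the two OPSEL immediates appear to select which E8M0 byte of the 32-bit
+ // scale_a / scale_b operands is consumed (OpselA/OpselB each encode two bits,
+ // low + 2*high). All of these must be compile-time literals, which is why the
+ // opsel values are template parameters and why the inline-asm fallback expands
+ // all 16 combinations via BOOL4_CASES and keeps exactly one with if constexpr.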
reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], - 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} - 0, // blgp - 0, // OPSEL + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 0, // blgp + OpselA, // OPSEL scale_a, - 0, // OPSEL + OpselB, // OPSEL scale_b); +#else +#define f8_cases(...) V_MFMA_SCALE_F32_16X16X128_F8F6F4(0, 0, int32x8_t, int32x8_t, __VA_ARGS__) + BOOL4_CASES(f8_cases); +#undef f8_cases +#endif #else ignore = reg_a; ignore = scale_a; @@ -925,18 +978,23 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> FloatC& reg_c) { #if defined(__gfx950__) +#if BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], - 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} - 1, // blgp - 0, // OPSEL + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + OpselA, // OPSEL scale_a, - 0, // OPSEL + OpselB, // OPSEL scale_b); +#else +#define bf8_cases(...) V_MFMA_SCALE_F32_16X16X128_F8F6F4(1, 1, int32x8_t, int32x8_t, __VA_ARGS__) + BOOL4_CASES(bf8_cases); +#endif #else ignore = reg_a; ignore = scale_a; @@ -954,18 +1012,24 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> FloatC& reg_c) { #if defined(__gfx950__) +#if BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], - 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} - 1, // blgp - 0, // OPSEL + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + OpselA, // OPSEL scale_a, - 0, // OPSEL + OpselB, // OPSEL scale_b); +#else +#define f8bf8_cases(...) V_MFMA_SCALE_F32_16X16X128_F8F6F4(0, 1, int32x8_t, int32x8_t, __VA_ARGS__) + BOOL4_CASES(f8bf8_cases); +#undef f8bf8_cases +#endif #else ignore = reg_a; ignore = scale_a; @@ -983,18 +1047,24 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> FloatC& reg_c) { #if defined(__gfx950__) +#if BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], - 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} - 0, // blgp - 0, // OPSEL + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 0, // blgp + OpselA, // OPSEL scale_a, - 0, // OPSEL + OpselB, // OPSEL scale_b); +#else +#define bf8f8_cases(...) 
V_MFMA_SCALE_F32_16X16X128_F8F6F4(1, 0, int32x8_t, int32x8_t, __VA_ARGS__) + BOOL4_CASES(bf8f8_cases); +#undef bf8f8_cases +#endif #else ignore = reg_a; ignore = scale_a; @@ -1022,11 +1092,11 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, reg_c.template AsType()[Number<0>{}], - 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} - 2, // blgp - 0, // OPSEL + 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 2, // blgp + OpselA, // OPSEL scale_a, - 0, // OPSEL + OpselB, // OPSEL scale_b); #else ignore = reg_a; @@ -1055,11 +1125,11 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, reg_c.template AsType()[Number<0>{}], - 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} - 3, // blgp - 0, // OPSEL + 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 3, // blgp + OpselA, // OPSEL scale_a, - 0, // OPSEL + OpselB, // OPSEL scale_b); #else ignore = reg_a; @@ -1071,29 +1141,43 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> } template - __device__ static void Run(const f4x32_t& reg_a, - const int32_t scale_a, - const f4x32_t& reg_b, - const int32_t scale_b, - FloatC& reg_c) + __device__ static void + Run(const f4x32_t& reg_a, // misalignment between pk_f4_t, 32 and f4_t, 32 + const int32_t scale_a, + const f4x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) { #if defined(__gfx950__) +#if BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS int32x4_t arg_a = bit_cast(reg_a); int32x4_t arg_b = bit_cast(reg_b); - - using arg_type = int32x8_t; - + using arg_type = int32x8_t; reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, reg_c.template AsType()[Number<0>{}], - 4, // cbsz - 4, // blgp - 0, // OPSEL + 4, // cbsz + 4, // blgp + OpselA, // OPSEL scale_a, - 0, // OPSEL + OpselB, // OPSEL scale_b); +#else +#define f4_cases(...) V_MFMA_SCALE_F32_16X16X128_F8F6F4(4, 4, int32x4_t, int32x4_t, __VA_ARGS__) + BOOL4_CASES(f4_cases); +#undef f4_cases +#endif #else ignore = reg_a; ignore = scale_a; @@ -1102,7 +1186,9 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> ignore = reg_b; ignore = reg_c; #endif } -}; +#undef BOOL4_CASES +#undef V_MFMA_SCALE_F32_16X16X128_F8F6F4 +}; template struct intrin_mfma_f32_16x16x128f8f6f4; diff --git a/include/ck/utility/blkgemmpipe_scheduler.hpp b/include/ck/utility/blkgemmpipe_scheduler.hpp index 6c788fb41e..861b81b1f6 100644 --- a/include/ck/utility/blkgemmpipe_scheduler.hpp +++ b/include/ck/utility/blkgemmpipe_scheduler.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
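// The IsF4F6 flag introduced in the scheduler below halves the modeled MFMA issue
// cycles, matching the double-rate F4/F6 MFMA path. A standalone restatement of the
// C_MFMA_Inst_Cycle table that follows (hypothetical host-side sketch, assuming the
// gfx950 double-rate behavior this patch targets):
#include <cstdio>

constexpr int mfma_inst_cycle(int n_per_xdl, int k_per_xdl, bool is_f4f6)
{
    const int speedup = is_f4f6 ? 2 : 1; // C_MFMA_SpeedUp below
    if(n_per_xdl == 16)
        return (k_per_xdl == 128 ? 32 : 16) / speedup;
    else // n_per_xdl == 32
        return (k_per_xdl == 64 ? 64 : 32) / speedup;
}

int main()
{
    // e.g. a 16x16x128 MX FP4 MFMA: 32 cycles at single rate, 16 at double rate
    std::printf("%d %d\n", mfma_inst_cycle(16, 128, false), mfma_inst_cycle(16, 128, true));
    return 0;
}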
#pragma once @@ -71,7 +71,8 @@ template + index_t KPerXDL, + bool IsF4F6 = false> struct BlockwiseGemmXdlops_pipeline_hotloop_inst { static constexpr index_t WaveSize = 64; @@ -99,14 +100,16 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst static constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + static constexpr index_t C_MFMA_SpeedUp = IsF4F6 ? 2 : 1; + static constexpr index_t C_MFMA_Inst_Cycle = []() { if constexpr(NPerXDL == 16) { - return KPerXDL == 128 ? 32 : 16; + return KPerXDL == 128 ? 32 / C_MFMA_SpeedUp : 16 / C_MFMA_SpeedUp; } else if constexpr(NPerXDL == 32) { - return KPerXDL == 64 ? 64 : 32; + return KPerXDL == 64 ? 64 / C_MFMA_SpeedUp : 32 / C_MFMA_SpeedUp; } }(); @@ -123,7 +126,7 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst KPerXDL); printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: " - "%d, %d\n C MFMA inst: %d\n" + "%d, %d\n C MFMA inst: %d C MFMA cycle: %d\n" "A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: " "%d/ %d\n", A_Buffer_Load_Inst_Num, @@ -133,6 +136,7 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst A_LDS_Read_Inst_Num, B_LDS_Read_Inst_Num, C_MFMA_Inst_Num, + C_MFMA_Inst_Cycle, A_LDS_Read_Width, B_LDS_Read_Width, ALDSWriteWidth, diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index b90ff237dc..ad9bb45158 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -43,8 +43,8 @@ struct f4x2_pk_t using type = uint8_t; type data; - __host__ __device__ f4x2_pk_t() : data{type{}} {} - __host__ __device__ f4x2_pk_t(type init) : data{init} {} + __host__ __device__ constexpr f4x2_pk_t() : data{type{}} {} + __host__ __device__ constexpr f4x2_pk_t(const type init) : data{init} {} template __host__ __device__ inline type unpack(Number) const @@ -165,6 +165,17 @@ inline constexpr bool is_native_type() is_same::value || is_same::value || is_same::value; } +template +struct is_f8f6f4 +{ + static constexpr bool value = + is_same_v || is_same_v || is_same_v || is_same_v || + is_same_v || is_same_v || is_same_v || + is_same_v || is_same_v || is_same_v; +}; +template +inline constexpr bool is_f8f6f4_v = is_f8f6f4::value; + // scalar_type template struct scalar_type; @@ -303,105 +314,87 @@ struct scalar_type static constexpr index_t vector_size = 1; }; -// Default behavior for types that do not need special handling template -struct packed_type -{ - using type = T; - static constexpr index_t packed_size = 1; // number of packed elements -}; - -template <> -struct packed_type -{ - using type = pk_i4_t; - static constexpr index_t packed_size = 2; // number of packed elements -}; - -template <> -struct packed_type -{ - using type = f4x2_pk_t; - static constexpr index_t packed_size = 2; // number of packed elements -}; - -template <> -struct packed_type -{ - using type = f6x32_pk_t; - static constexpr index_t packed_size = f6x32_pk_t::packed_size; // number of packed elements -}; - -template <> -struct packed_type -{ - using type = bf6x32_pk_t; - static constexpr index_t packed_size = bf6x32_pk_t::packed_size; // number of packed elements -}; - -template -using packed_type_t = typename packed_type::type; - -// Check if the type has packed type specialization -template -inline constexpr bool has_packed_type_v = !is_same_v, T>; - -template -struct element_type +struct packed_type_info { private: - static constexpr auto get_element_type() + static constexpr auto 
get_packed_type_info() { using U = remove_cvref_t; if constexpr(is_same_v) - return int4_t{}; + return ck::Tuple, int4_t>{}; else if constexpr(is_same_v) - return f4_t{}; + return ck::Tuple, f4_t>{}; else if constexpr(is_same_v) - return f6_t{}; + return ck::Tuple, f6_t>{}; else if constexpr(is_same_v) - return bf6_t{}; + return ck::Tuple, bf6_t>{}; else if constexpr(is_same_v) - return f6_t{}; + return ck::Tuple, f6_t>{}; else if constexpr(is_same_v) - return bf6_t{}; + return ck::Tuple, bf6_t>{}; + else + return ck::Tuple, T>{}; + } + + public: + using element_type = remove_cvref_t{}))>; + static constexpr auto packed_size = + static_cast(get_packed_type_info().At(ck::Number<0>{})); +}; +template +using element_type_t = typename packed_type_info::element_type; + +template +inline constexpr index_t packed_size_v = packed_type_info::packed_size; + +template +inline constexpr bool is_packed_type_v = packed_size_v > 1; + +template +struct packed_type_maker +{ + private: + static constexpr auto get_packed_type() + { + using U = remove_cvref_t; + if constexpr(is_same_v) + { + static_assert(N == 0 || N == 2, "Packed size N for int4_t must be 2."); + return pk_i4_t{}; + } + else if constexpr(is_same_v) + { + static_assert(N == 0 || N == 2, "Packed size N for f4_t must be 2."); + return f4x2_pk_t{}; + } + else if constexpr(is_same_v) + { + static_assert(N == 0 || N == 16 || N == 32, "Packed size N for f6_t must be 16 or 32."); + if constexpr(N == 16) + return f6x16_pk_t{}; + else if constexpr(N == 0 || N == 32) + return f6x32_pk_t{}; + } + else if constexpr(is_same_v) + { + static_assert(N == 0 || N == 16 || N == 32, + "Packed size N for bf6_t must be 16 or 32."); + if constexpr(N == 16) + return bf6x16_pk_t{}; + else if constexpr(N == 0 || N == 32) + return bf6x32_pk_t{}; + } else return T{}; } public: - using type = decltype(get_element_type()); -}; -template -using element_type_t = typename element_type::type; - -template -inline constexpr bool is_packed_type_v = - has_packed_type_v>&& is_same_v>>; - -template -struct packed_size -{ - private: - static constexpr auto get_packed_size() - { - using U = remove_cvref_t; - if constexpr(is_packed_type_v) - return Number>::packed_size>{}; - else - return Number::packed_size>{}; - } - - public: - using type = decltype(get_packed_size()); - static constexpr auto value = get_packed_size(); + using packed_type = remove_cvref_t; }; -template -using packed_size_t = typename packed_size::type; - -template -inline constexpr index_t packed_size_v = packed_size::value; +template +using packed_type_t = typename packed_type_maker::packed_type; #if defined(_WIN32) using int64_t = long long; diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index 65eed0624c..049221cea1 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -1330,6 +1330,12 @@ struct nnvb_data_t_selector using type = pk_i4_t::type; }; +template <> +struct nnvb_data_t_selector +{ + using type = f4x2_pk_t::type; +}; + template struct non_native_vector_base< T, @@ -2222,6 +2228,7 @@ using f6x32_t = typename vector_type::type; using bf6x16_t = typename vector_type::type; using bf6x32_t = typename vector_type::type; +using e8m0x4_bexp_t = typename vector_type::type; // pack int4 using pk_i4x2_t = typename vector_type::type; using pk_i4x4_t = typename vector_type::type; diff --git a/include/ck/utility/functional2.hpp b/include/ck/utility/functional2.hpp index a11963cb47..16213173f3 100644 --- a/include/ck/utility/functional2.hpp 
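The packed_type_info / packed_type_maker traits just above drive the SrcScalarPerVector / PackedSize loop trips and offset divisions seen throughout the transfer code earlier in this patch. A simplified standalone mock of that bookkeeping (illustrative names, not the CK definitions):

#include <cstdint>

// A packed storage type holds packed_size elements per object, so element
// counts divide by packed_size when converted to storage-granular loop trips
// and buffer offsets.
struct f4_elem {};                      // stand-in for the 4-bit f4_t element
struct f4x2_pk { std::uint8_t data; };  // stand-in for f4x2_pk_t: two f4 per byte

template <typename T>
struct packed_info { using element = T; static constexpr int packed_size = 1; };
template <>
struct packed_info<f4x2_pk> { using element = f4_elem; static constexpr int packed_size = 2; };

template <typename T>
inline constexpr int packed_size_v_mock = packed_info<T>::packed_size;

static_assert(packed_size_v_mock<f4x2_pk> == 2, "two f4 values share one byte");
static_assert(packed_size_v_mock<float> == 1, "native types are unpacked");
// e.g. a 16-element f4 slice issues 16 / 2 = 8 storage accesses, which is why
// the transfer loops above iterate SrcScalarPerVector / PackedSize times.
static_assert(16 / packed_size_v_mock<f4x2_pk> == 8, "storage-granular trip count");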
+++ b/include/ck/utility/functional2.hpp @@ -1,10 +1,11 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/utility/functional.hpp" #include "ck/utility/sequence.hpp" +#include "ck/utility/tuple.hpp" namespace ck { @@ -70,4 +71,44 @@ struct static_for<0, N, 1> : detail::make_applier using detail::make_applier::operator(); }; +template +struct static_for_range +{ + template + __host__ __device__ constexpr void operator()(F f) const + { + // tweak -fbracket-depth if compilation fails. Clang default limit is 256 + (f(Is{}), ...); + } +}; + +template +struct static_for_product; +template +struct static_for_product> : public static_for_range +{ +}; +template +struct static_for_product, Rest...> +{ + template + __host__ __device__ constexpr void operator()(F f) const + { + static_for_product>{}([&](auto i0) { // + static_for_product{}([&](auto... is) { // + f(i0, is...); + }); + }); + } +}; + +struct identity +{ + template + __host__ __device__ constexpr T&& operator()(T&& arg) const noexcept + { + return forward(arg); + } +}; + } // namespace ck diff --git a/include/ck/utility/integral_constant.hpp b/include/ck/utility/integral_constant.hpp index 75f35d762c..a7fa64d710 100644 --- a/include/ck/utility/integral_constant.hpp +++ b/include/ck/utility/integral_constant.hpp @@ -5,14 +5,22 @@ namespace ck { +template +struct constant +{ + using value_type = decltype(v); + using type = constant; // using injected-class-name + static constexpr value_type value = v; + __host__ __device__ constexpr operator value_type() const noexcept { return value; } + __host__ __device__ constexpr value_type operator()() const noexcept { return value; } +}; + template -struct integral_constant +struct integral_constant : constant { static constexpr T value = v; typedef T value_type; typedef integral_constant type; - __host__ __device__ constexpr operator value_type() const noexcept { return value; } - __host__ __device__ constexpr value_type operator()() const noexcept { return value; } }; template diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 9b1321dea3..5865f1dd78 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -1586,6 +1586,11 @@ inline __host__ __device__ f4x2_t type_convert(float2_t x) return f4_convert_rne(x); #endif } +template <> +inline __host__ __device__ f4x2_pk_t type_convert(float2_t x) +{ + return static_cast(type_convert(x)); +} // convert vector of 32 fp32 to vector of 32 fp4 template <> diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 1a1b729394..7d06d871a9 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -112,7 +112,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy make_tuple(number{}, number{}))), make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); - return a_lds_block_desc; + return a_lds_block_desc; #endif } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp index 3fc39911dd..6a2b007ef5 100644 --- 
a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp @@ -77,33 +77,34 @@ struct ReferenceMXGemm : public device::BaseOperator ComputeTypeA, ComputeTypeB>; - Tensor a_m_k_scaled(arg.a_m_k_.mDesc); - Tensor b_k_n_scaled(arg.b_k_n_.mDesc); + const ck::index_t M = arg.a_m_k_.mDesc.GetLengths()[0]; + const ck::index_t N = arg.b_k_n_.mDesc.GetLengths()[1]; + assert(arg.a_m_k_.mDesc.GetLengths()[1] == arg.b_k_n_.mDesc.GetLengths()[0]); + const ck::index_t K = arg.a_m_k_.mDesc.GetLengths()[1]; + const ck::index_t SCALE_BLOCK = K / arg.a_m_kblock_scales_.mDesc.GetLengths()[1]; + Tensor a_m_k_scaled(HostTensorDescriptor({M, K}, {K, 1})); + Tensor b_k_n_scaled(HostTensorDescriptor({K, N}, {1, K})); + // printf("K: %d\n", K); - const auto M = arg.a_m_k_.mDesc.GetLengths()[0]; - const auto N = arg.b_k_n_.mDesc.GetLengths()[1]; - const auto K = arg.a_m_k_.mDesc.GetLengths()[1]; - const auto SCALE_BLOCK = K / arg.a_m_kblock_scales_.mDesc.GetLengths()[1]; - - for(size_t m = 0; m < M; m++) + for(int m = 0; m < M; m++) { - for(size_t k = 0; k < K; k++) + for(int k = 0; k < K; k++) { if constexpr(is_same_v) { - // TODO: add support for ColMajor layout as well if(k % 2 == 1) - a_m_k_scaled(m, k) = - type_convert( - f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{}))) * - type_convert( - arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); - else - a_m_k_scaled(m, k) = - type_convert( - f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{}))) * - type_convert( - arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + { + continue; + } + // TODO: add support for ColMajor layout as well + auto a_pack = arg.a_m_k_(m, k); + auto a_scale = + type_convert(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + auto a_f4_lo = f4_t(a_pack.template unpack<>(Number<0>{})); + auto a_f4_hi = f4_t(a_pack.template unpack<>(Number<1>{})); + + a_m_k_scaled(m, k) = type_convert(a_f4_lo) * a_scale; + a_m_k_scaled(m, k + 1) = type_convert(a_f4_hi) * a_scale; } else if constexpr(is_same_v || is_same_v || @@ -124,25 +125,24 @@ struct ReferenceMXGemm : public device::BaseOperator } } - for(size_t n = 0; n < N; n++) + for(int n = 0; n < N; n++) { - for(size_t k = 0; k < K; k++) + for(int k = 0; k < K; k++) { if constexpr(is_same_v) { // TODO: add support for RowMajor layout as well if(k % 2 == 1) - b_k_n_scaled(k, n) = - type_convert( - f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{}))) * - type_convert( - arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); - else - b_k_n_scaled(k, n) = - type_convert( - f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{}))) * - type_convert( - arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); + { + continue; + } + auto b_pack = arg.b_k_n_(k, n); + auto b_scale = + type_convert(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); + auto b_f4_lo = f4_t(b_pack.template unpack<>(Number<0>{})); + auto b_f4_hi = f4_t(b_pack.template unpack<>(Number<1>{})); + b_k_n_scaled(k, n) = type_convert(b_f4_lo) * b_scale; + b_k_n_scaled(k + 1, n) = type_convert(b_f4_hi) * b_scale; } else if constexpr(is_same_v || is_same_v || diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index 0cb2c2bd79..274273d576 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ 
b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -23,6 +23,10 @@ using I32 = int32_t; using F8 = ck::f8_t; using BF8 = ck::bf8_t; using I4 = ck::pk_i4_t; +using F4 = ck::f4x2_pk_t; + +using E8M0 = ck::e8m0_bexp_t; +using E8M0PK = int32_t; using Empty_Tuple = ck::Tuple<>; @@ -42,8 +46,9 @@ using BF16_Tuple = ck::Tuple; using F32_F32_Tuple = ck::Tuple; // GEMM layout -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using MFMA = ck::tensor_layout::gemm::MFMA; using Row_Tuple = ck::Tuple; using Row_Row_Tuple = ck::Tuple; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx.hpp index 4af5143f45..ec75a0cfb0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx.hpp @@ -22,9 +22,9 @@ void add_device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instances( Col, Row, F8, - e8m0_bexp_t, + E8M0PK, F8, - e8m0_bexp_t, + E8M0PK, F16, 32, PassThrough, @@ -36,23 +36,37 @@ void add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances( Col, Row, F8, - e8m0_bexp_t, + E8M0PK, F8, - e8m0_bexp_t, + E8M0PK, BF16, 32, PassThrough, PassThrough, PassThrough>>>& instances); +void add_device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instances( + std::vector>>& instances); + void add_device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instances( std::vector> + ck::tensor_operation::element_wise::PassThrough>, + enable_if_t>> // non-weight-pre-shuffle { using DeviceOp = DeviceGemmMX && is_same_v && + is_same_v) + { + add_device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instances(op_ptrs); + } } else if constexpr(is_same_v && is_same_v && is_same_v) @@ -153,6 +173,73 @@ struct DeviceOperationInstanceFactory< } }; +void add_device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instances( + std::vector>>& instances); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMX, + enable_if_t>> +{ + using DeviceOp = DeviceGemmMX; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp index 4c12e515e8..a99416f80b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp @@ -34,19 +34,19 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances = // ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| 
SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 64, 16, 16, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<4, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 16, 32, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 16, 32, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 16, 32, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 64, 64, 64, 16, 16, 16, 16, 2, 2, S<4, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<4, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 32, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 16, 32, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 16, 32, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 32, 16, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 32, 16, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 64, 16, 
16, 128, 32, 32, 16, 16, 1, 1, S<1, 4, 16>, S<1, 0, 2>, 2, 2, 0, S<1, 4, 16>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 64, 16, 16, 128, 32, 32, 16, 16, 1, 1, S<1, 4, 16>, S<1, 0, 2>, 2, 2, 0, S<1, 4, 16>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4> + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 64, 8, 8, 32, 32, 1, 1, S<8, 8, 4>, S<1, 0, 2>, 2, 2, 0, S<8, 8, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 16, 32, 32, 8, 8, 16, 16, 1, 1, S<4, 8, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 8, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 8, 4>, S<1, 0, 2>, 2, 2, 0, S<8, 8, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 8, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 16, 32, 32, 8, 8, 16, 16, 1, 1, S<4, 8, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 8, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 1, 4>, S<1, 0, 2>, 2, 2, 0, S<16, 1, 4>, S<1, 0, 2>, 2, 2, 
0, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 1, 4>, S<1, 0, 2>, 2, 2, 0, S<16, 1, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instance.cpp index 94f75d0e0f..7e8daef867 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instance.cpp @@ -32,8 +32,8 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances = // ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 4> + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 32, 32, 4, 4, 32, 32, 1, 1, S<1, 16, 4>, S<0, 1, 2>, 1, 1, 0, S<1, 16, 4>, S<0, 1, 2>, 1, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 32, 32, 32, 4, 4, 32, 32, 1, 1, S<1, 16, 4>, S<0, 1, 2>, 1, 1, 0, S<1, 16, 4>, S<0, 1, 2>, 1, 1, 0, 1, 1, S<1, 8, 1, 8>, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instance.cpp index 0f4ebc350b..976b7bbe86 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instance.cpp @@ -32,8 +32,8 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances = // ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| 
Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4> + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 32, 32, 4, 4, 32, 32, 1, 1, S<1, 16, 4>, S<0, 1, 2>, 1, 1, 0, S<8, 2, 4>, S<1, 0, 2>, 2, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 32, 32, 32, 4, 4, 32, 32, 1, 1, S<1, 16, 4>, S<0, 1, 2>, 1, 1, 0, S<8, 2, 4>, S<1, 0, 2>, 2, 1, 0, 1, 1, S<1, 8, 1, 8>, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instance.cpp index d2bc9351b6..bf65b9af76 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instance.cpp @@ -31,8 +31,8 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances = // ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 
64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 4> + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 32, 32, 4, 4, 32, 32, 1, 1, S<8, 2, 4>, S<1, 0, 2>, 2, 1, 0, S<1, 16, 4>, S<0, 1, 2>, 1, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 32, 32, 32, 4, 4, 32, 32, 1, 1, S<8, 2, 4>, S<1, 0, 2>, 2, 1, 0, S<1, 16, 4>, S<0, 1, 2>, 1, 1, 0, 1, 1, S<1, 8, 1, 8>, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instance.cpp index 2c208c01f3..2a65566f8e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instance.cpp @@ -32,8 +32,8 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances = // ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4> + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 0, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 0, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 0, 1, 1, S<1, 8, 1, 8>, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_mx/CMakeLists.txt index 0442bed130..bb67a9edae 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/CMakeLists.txt @@ -6,6 +6,8 @@ list(APPEND GEMM_MX_INSTANCES 
device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instance.cpp + device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instance.cpp + device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instance.cpp ) @@ -13,6 +15,8 @@ set_source_files_properties(device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f set_source_files_properties(device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") add_instance_library(device_gemm_mx_instance ${GEMM_MX_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp index 8dc21cbf1f..c5a44281df 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp @@ -13,12 +13,13 @@ namespace tensor_operation { namespace device { namespace instance { -using F8 = f8_t; -using BF8 = bf8_t; -using F16 = half_t; -using BF16 = bhalf_t; -using F32 = float; -using E8M0 = ck::e8m0_bexp_t; +using F8 = f8_t; +using BF8 = bf8_t; +using F16 = half_t; +using BF16 = bhalf_t; +using F32 = float; +using E8M0 = ck::e8m0_bexp_t; +using E8M0PK = int32_t; using Row = tensor_layout::gemm::RowMajor; using Col = tensor_layout::gemm::ColumnMajor; @@ -40,17 +41,19 @@ static constexpr auto ScaleBlockSize = 32; template using device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_instances = std::tuple< +#if 0 // TODO: Fix RRR // clang-format off - //#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| - //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| 
DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| - //#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 64, 16, 128, 16, 4, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 256, 16, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 64, 256, 16, 4, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 16, 32, 512, 16, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + //#########################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | Type| | Type| | | | Operation| Operation| 
Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 64, 16, 128, 16, 4, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 256, 16, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 64, 256, 16, 4, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 16, 32, 512, 16, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> // clang-format on +#endif >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instance.cpp index 2b6ccdbeda..e865b2f7df 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instance.cpp @@ -13,9 +13,9 @@ void add_device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instances( Row, Row, BF8, - E8M0, + E8M0PK, F8, - E8M0, + E8M0PK, F16, 32, PassThrough, diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn.hpp new file mode 100644 index 0000000000..03ea71883a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn.hpp @@ -0,0 +1,73 @@ +// 
SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F4 = f4x2_pk_t; +using F16 = half_t; +using F32 = float; +using E8M0 = ck::e8m0_bexp_t; +using E8M0PK = int32_t; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; +using MFMA = tensor_layout::gemm::MFMA; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMPadding = GemmSpecialization::MPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto ScaleBlockSize = 32; + +template +using device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#####################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#####################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + // 
DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + + DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + + DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + + DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instance.cpp new file mode 100644 index 0000000000..d955148d2c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
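// This default_instance translation unit registers the instance tuple defined
// above with CK's instance factory, so client code can enumerate the new MXFP4
// kernels without naming a tile shape. A minimal sketch of that pattern
// follows. The DeviceGemmMX template-argument order mirrors the instance files
// in this patch (layouts, data/scale types, CData, ScaleBlockSize, elementwise
// ops); run_first_supported_mx_fp4_gemm is a hypothetical helper, and the
// MakeArgumentPointer parameter order is an assumption here -- consult
// device_gemm_mx.hpp for the authoritative signature.

#include "ck/library/tensor_operation_instance/gpu/gemm_mx.hpp"

void run_first_supported_mx_fp4_gemm(const void* p_a, const void* p_a_scale,
                                     const void* p_b, const void* p_b_scale,
                                     void* p_c, ck::index_t M, ck::index_t N,
                                     ck::index_t K, ck::index_t StrideA,
                                     ck::index_t StrideB, ck::index_t StrideC)
{
    using namespace ck::tensor_operation::device;
    using namespace ck::tensor_operation::device::instance;

    // Aliases (Row, Col, F4, E8M0PK, F16, PassThrough) as defined in the
    // instance headers above; ScaleBlockSize is 32 throughout this patch.
    using DeviceOp = DeviceGemmMX<Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16,
                                  32, PassThrough, PassThrough, PassThrough>;

    for(auto& op : DeviceOperationInstanceFactory<DeviceOp>::GetInstances())
    {
        // Assumed order: data/scale pointers, sizes, strides, elementwise ops.
        auto arg = op->MakeArgumentPointer(p_a, p_a_scale, p_b, p_b_scale, p_c,
                                           M, N, K, StrideA, StrideB, StrideC,
                                           PassThrough{}, PassThrough{}, PassThrough{});
        if(op->IsSupportedArgument(arg.get()))
        {
            op->MakeInvokerPointer()->Run(arg.get(), StreamConfig{nullptr, false});
            return;
        }
    }
}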
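// Separately: throughout this patch the scale-tensor element type moves from
// one ck::e8m0_bexp_t per block scale to `using E8M0PK = int32_t;`. Since
// e8m0_bexp_t is a single biased-exponent byte, an int32_t plausibly packs
// four block scales per word; pack_e8m0_x4 below is a hypothetical helper
// illustrating that packing under a little-endian, byte-0-first assumption.
// The ordering the MX pipelines actually consume is defined by the kernels,
// not by this sketch.

#include <cstdint>

// Pack four 8-bit E8M0 biased exponents into one int32_t scale word
// (s0 lands in the least significant byte).
inline std::int32_t pack_e8m0_x4(std::uint8_t s0, std::uint8_t s1,
                                 std::uint8_t s2, std::uint8_t s3)
{
    return static_cast<std::int32_t>(static_cast<std::uint32_t>(s0) |
                                     (static_cast<std::uint32_t>(s1) << 8) |
                                     (static_cast<std::uint32_t>(s2) << 16) |
                                     (static_cast<std::uint32_t>(s3) << 24));
}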
+ +#include "device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn.hpp new file mode 100644 index 0000000000..1ebb400fdd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn.hpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F4 = f4x2_pk_t; +using F16 = half_t; +using F32 = float; +using E8M0 = ck::e8m0_bexp_t; +using E8M0PK = int32_t; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMPadding = GemmSpecialization::MPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto ScaleBlockSize = 32; + +template +using device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#############################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#############################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#############################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| 
PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 256, 128, 16, 16, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 256, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 128, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 256, 128, 16, 16, 16, 16, 6, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 256, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 
2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, 128, 16, 16, 16, 16, 2, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instance.cpp new file mode 100644 index 0000000000..597879c414 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp index d3f74b2907..c9bc4d25bb 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp @@ -13,11 +13,12 @@ namespace tensor_operation { namespace device { namespace instance { -using F8 = f8_t; -using F16 = half_t; -using BF16 = bhalf_t; -using F32 = float; -using E8M0 = ck::e8m0_bexp_t; +using F8 = f8_t; +using F16 = half_t; +using BF16 = bhalf_t; +using F32 = float; +using E8M0 = ck::e8m0_bexp_t; +using E8M0PK = int32_t; using Row = tensor_layout::gemm::RowMajor; using Col = tensor_layout::gemm::ColumnMajor; @@ -39,19 +40,21 @@ static constexpr auto ScaleBlockSize = 32; template using device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_instances = std::tuple< +#if 0 // TODO: Fix CCR // clang-format off - //#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| 
BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| - //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| - //#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 4, 16, 32, 32, 2, 2, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 16, 256, 128, 4, 16, 16, 16, 1, 4, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 8, 16, 16, 16, 1, 1, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 128, 8, 16, 16, 16, 8, 8, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 64, 4, 16, 32, 32, 4, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 128, 128, 4, 16, 16, 16, 4, 8, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<8, 16, 1>, S<1, 0, 2>, 
S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + //#########################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 4, 16, 32, 32, 2, 2, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 16, 256, 128, 4, 16, 16, 16, 1, 4, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 8, 16, 16, 16, 1, 1, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 128, 8, 16, 16, 16, 8, 8, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 64, 4, 
16, 32, 32, 4, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 128, 128, 4, 16, 16, 16, 4, 8, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> // clang-format on +#endif >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp index c75e779fea..4f9c372c93 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp @@ -13,9 +13,9 @@ void add_device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instances( Col, Row, F8, - E8M0, + E8M0PK, F8, - E8M0, + E8M0PK, BF16, 32, PassThrough, diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp index ac09df7ea2..3645026c60 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp @@ -13,11 +13,12 @@ namespace tensor_operation { namespace device { namespace instance { -using F8 = f8_t; -using F16 = half_t; -using BF16 = bhalf_t; -using F32 = float; -using E8M0 = ck::e8m0_bexp_t; +using F8 = f8_t; +using F16 = half_t; +using BF16 = bhalf_t; +using F32 = float; +using E8M0 = ck::e8m0_bexp_t; +using E8M0PK = int32_t; using Row = tensor_layout::gemm::RowMajor; using Col = tensor_layout::gemm::ColumnMajor; @@ -40,15 +41,15 @@ static constexpr auto ScaleBlockSize = 32; template using device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_instances = std::tuple< // clang-format off - //#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| - //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| - 
//#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 256, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + //###########################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //###########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //###########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| 
_NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 256, 16, 16, 16, 16, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 64, 256, 16, 16, 16, 16, 4, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, 256, 16, 16, 16, 16, 2, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 32, 256, 16, 16, 16, 16, 4, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, 256, 16, 16, 16, 16, 2, 2, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp index 05914e06b5..a4c3451c47 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp @@ -13,9 +13,9 @@ void add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances( Col, Row, F8, - E8M0, + E8M0PK, F8, - E8M0, + E8M0PK, BF16, 32, PassThrough, diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp index 68363de523..f7ef5562e4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp @@ -13,11 +13,12 @@ namespace tensor_operation { namespace device { namespace instance { -using F8 = f8_t; -using F16 = half_t; -using 
BF16 = bhalf_t; -using F32 = float; -using E8M0 = ck::e8m0_bexp_t; +using F8 = f8_t; +using F16 = half_t; +using BF16 = bhalf_t; +using F32 = float; +using E8M0 = ck::e8m0_bexp_t; +using E8M0PK = int32_t; using Row = tensor_layout::gemm::RowMajor; using Col = tensor_layout::gemm::ColumnMajor; @@ -40,15 +41,15 @@ static constexpr auto ScaleBlockSize = 32; template using device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_instances = std::tuple< // clang-format off - //#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| - //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| - //#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 256, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 16, 16, 16, 16, 1, 1, S<8, 8, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + //###########################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //###########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //###########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 256, 16, 16, 16, 16, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 64, 256, 16, 16, 16, 16, 4, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, 256, 16, 16, 16, 16, 2, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 32, 256, 16, 16, 16, 16, 4, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, 256, 16, 16, 16, 16, 2, 2, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3> // clang-format on >; diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp index f4e59cf92d..1cacee7aea 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp @@ -13,9 +13,9 @@ void add_device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instances( Col, Row, F8, - E8M0, + E8M0PK, F8, - E8M0, + E8M0PK, F16, 32, PassThrough, diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp index f0a54ee400..0b1f08474b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -37,30 +37,30 @@ using device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instances = st //#######################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Prefetch| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //#######################################| | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | Wave| Wave| Lengths_KBatch_K0_M_K1| | | PerVector| | Lengths_KBatch_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 16, 128, 4, 16, 16, 16, 1, 2, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 1, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 16, 16, 4, 16, 16, 16, 1, 1, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 16, 16, 8, 16, 16, 16, 1, 1, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, + 
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 16, 128, 8, 8, 16, 16, 1, 2, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 2, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 2, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 2, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 2, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 1, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 16, 1, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 32, 32, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 16, 64, 8, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 64, 4, 32, 16, 16, 1, 2, S<1, 2, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 2, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 64, 16, 16, 4, 32, 16, 16, 1, 1, S<1, 1, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 256, 64, 16, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 256, 16, 64, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 
16, 1, 8>, 4>, - DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 64, 16, 16, 8, 16, 16, 16, 1, 1, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 128, 128, 8, 8, 32, 32, 2, 2, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 32, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 16, 64, 16, 8, 16, 16, 1, 1, S<1, 16, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 16, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 64, 16, 8, 16, 16, 1, 2, S<1, 16, 2, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 16, 2, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 1, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 16, 1, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 256, 64, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 256, 16, 64, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 1, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 16, 1, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 128, 4, 32, 16, 16, 1, 2, S<1, 4, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 4, 16>, S<0, 
2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 32, 32, 8, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 32, 32, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 64, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 4, 32, 16, 16, 1, 1, S<1, 1, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 64, 16, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4> + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 8, 8, 32, 32, 2, 2, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 128, 16, 8, 16, 16, 1, 2, S<1, 16, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 16, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 32, 32, 16, 8, 16, 16, 1, 1, S<1, 16, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 16, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 32, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 64, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 1, 4>, 
S<0, 2, 1, 3>, 3, 2, 0, S<1, 16, 1, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 64, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4> // clang-format on >; diff --git a/profiler/include/profiler/profile_gemm_mx_impl.hpp b/profiler/include/profiler/profile_gemm_mx_impl.hpp new file mode 100644 index 0000000000..8135bf4475 --- /dev/null +++ b/profiler/include/profiler/profile_gemm_mx_impl.hpp @@ -0,0 +1,534 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp" +#include "ck/library/tensor_operation_instance/gpu/gemm_mx.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/data_type.hpp" + +namespace ck { +namespace profiler { + +#if 1 +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int tempk = k % (XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + + // k2 * MNXdlPack))); + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +void preShuffleBuffer(const ck::f4x2_pk_t* src, ck::f4x2_pk_t* dst, int N, int K, int NXdl) +{ + int KPack = 16; + int NLane = NXdl; + int KLane = 64 / NLane; + int K_pk = K / 2; + int K0 = K_pk / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K_pk; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack 
* NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K_pk + k]; + } + } +} +#endif + +template +bool profile_gemm_mx_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int KBatch, + int n_warmup, + int n_iter, + uint64_t rotating = 0) +{ + using tensor_operation::device::instance::Col; + using tensor_operation::device::instance::E8M0; + using tensor_operation::device::instance::E8M0PK; + using tensor_operation::device::instance::MFMA; + using tensor_operation::device::instance::Row; + + constexpr bool BPreShuffle = is_same_v; + using BRefLayout = conditional_t; + + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); + }; + + using XDataType = E8M0; + using XPackedDataType = E8M0PK; + using AScaleLayout = Row; + using BScaleLayout = Col; + + auto f_host_tensor_descriptor = + [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + return HostTensorDescriptor({row, col}, {stride, 1}); + else + return HostTensorDescriptor({row, col}, {1, stride}); + }; + auto f_get_default_stride = + [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + return static_cast(col); + else + return static_cast(row); + } + else + return static_cast(stride); + }; + + auto Scale_Padded_M = (M + 32 - 1) / 32 * 32; + auto Scale_Stride_AM = + f_get_default_stride(Scale_Padded_M, K / ScaleBlockSize, -1, AScaleLayout{}); + auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + auto b_k_n = + std::make_shared>(f_host_tensor_descriptor(K, N, StrideB, BRefLayout{})); + auto b_input = b_k_n; + if constexpr(BPreShuffle) + b_input = std::make_shared>( + f_host_tensor_descriptor(K, N, StrideB, BRefLayout{})); // use layout only for size + + // scales for A and B + Tensor a_m_k_scale(f_host_tensor_descriptor( + Scale_Padded_M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); + Tensor b_k_n_scale( + f_host_tensor_descriptor(K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); + + // shuffled scales for A and B + Tensor a_shuffled_scale(f_host_tensor_descriptor( + Scale_Padded_M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); + Tensor b_shuffled_scale( + f_host_tensor_descriptor(K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::size_t total_gemm_needed = + a_m_k.GetElementSpaceSizeInBytes() + b_k_n->GetElementSpaceSizeInBytes() + + a_m_k_scale.GetElementSpaceSizeInBytes() + b_k_n_scale.GetElementSpaceSizeInBytes() + + a_shuffled_scale.GetElementSpaceSizeInBytes() + + b_shuffled_scale.GetElementSpaceSizeInBytes(); + int rotating_count = std::max( + 1, + std::min(n_iter, + static_cast(std::ceil(static_cast(rotating) / total_gemm_needed)))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n->mDesc << std::endl; + std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl; + std::cout << "c_m_n: 
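// A worked example of the shuffle mapping implemented by preShuffleScaleBuffer
// above, using its hard-coded constants (MNXdlPack = KXdlPack = 2,
// XdlMNThread = 16, XdlKThread = 64/16 = 4): each (n0, k0) tile is a
// contiguous block of 2*2*16*4 = 256 scales, ordered XdlKThread ->
// XdlMNThread -> KXdlPack -> MNXdlPack as the function comment states.
// For instance, with KLast == true and n = 17, k = 5:
//   n0 = 17/32 = 0,  n1 = (17%32) % 16 = 1,  n2 = (17%32) / 16 = 1
//   k0 = 5/8   = 0,  k1 = (5%8)  % 4  = 1,  k2 = (5%8)  / 4  = 1
//   outputIndex = k1*64 + n1*4 + k2*2 + n2 = 64 + 4 + 2 + 1 = 71
// so dst[71] = src[17*K + 5]. preShuffleBuffer applies the same N0/K0-tiled
// idea to the fp4 B matrix, with NLane = 16, KLane = 4, KPack = 16 over the
// byte-packed (K/2) view.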
" << c_m_n_device_result.mDesc << std::endl; + std::cout << "rotating count: " << rotating_count << std::endl; + + auto a_data_element = [](float x) { + if constexpr(ck::is_same_v) + return ck::type_convert(ck::float2_t(x)); + else + return ck::type_convert(x); + }; + auto b_data_element = [](float x) { + if constexpr(ck::is_same_v) + return ck::type_convert(ck::float2_t(x)); + else + return ck::type_convert(x); + }; + + switch(init_method) + { + case 0: // Initializations for development and debugging + ck::utils::FillConstant{a_data_element(1.0f)}(a_m_k); + ck::utils::FillConstant{ck::type_convert(2.0f)}(a_m_k_scale); + ck::utils::FillConstant{b_data_element(0.5f)}(*b_k_n); + ck::utils::FillConstant{ck::type_convert(1.0f)}(b_k_n_scale); + if(do_log) + { + std::cout << "Init A = {1}" << std::endl; + std::cout << "Init A scale = {2.0}" << std::endl; + std::cout << "Init B = {0.5}" << std::endl; + std::cout << "Init B scale = {1.0}" << std::endl; + std::cout << "Expect C = {K}" << std::endl; + } + break; + + case 1: + + a_m_k.GenerateTensorValue(GeneratorTensor_2{-4, 5}); // Z[-4,4] + b_k_n->GenerateTensorValue(GeneratorTensor_2{-4, 5}); // Z[-4,4] + + a_m_k_scale.GenerateTensorValue( + GeneratorTensor_2{125, 129}); // scales: {0.25, 0.5, 1, 2} + b_k_n_scale.GenerateTensorValue( + GeneratorTensor_2{125, 129}); // scales: {0.25, 0.5, 1, 2} + break; + + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); + a_m_k_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); + + b_k_n->GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); + b_k_n_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); + break; + } + +#if 1 + preShuffleScaleBuffer>(a_m_k_scale.mData.data(), + a_shuffled_scale.mData.data(), + Scale_Padded_M, + K / ScaleBlockSize); + preShuffleScaleBuffer>( + b_k_n_scale.mData.data(), b_shuffled_scale.mData.data(), N, K / ScaleBlockSize); + if constexpr(BPreShuffle) + { + int NPerXdl = 16; // Fixed 16 + preShuffleBuffer(b_k_n->mData.data(), b_input->mData.data(), N, K, NPerXdl); + } +#endif + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + if(do_log > 0) + std::cout << "Device memory allocation..." << std::endl; + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.GetElementSpaceSize()); + DeviceMem a_scale_device_buf(sizeof(XDataType) * a_m_k_scale.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n->GetElementSpaceSize()); + DeviceMem b_scale_device_buf(sizeof(XDataType) * b_k_n_scale.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.GetElementSpaceSize()); + + if(do_log > 0) + std::cout << "Upload data to device..." << std::endl; + a_device_buf.ToDevice(a_m_k.mData.data()); + a_scale_device_buf.ToDevice(a_shuffled_scale.mData.data()); + b_device_buf.ToDevice(b_input->mData.data()); + b_scale_device_buf.ToDevice(b_shuffled_scale.mData.data()); + + if(do_log > 0) + std::cout << "Done." << std::endl; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMX; + std::cout << "finding op instances..." 
<< std::endl; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // Run reference GEMM + if(do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm< // + ADataType, + BDataType, + CDataType, + float, // AccDataType + XDataType, + AElementOp, + BElementOp, + CElementOp, + float, // ComputeTypeA + float // ComputeTypeB + >; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k, + a_m_k_scale, + *b_k_n, + b_k_n_scale, + c_m_n_host_result, + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + } + + std::string best_op_name; + std::optional best_op_object_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + float best_kbatch = 0; + bool pass = true; + + // profile device GEMM instances + for(auto& op_ptr : op_ptrs) + { + std::vector kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38}; // use these when KBatch <= 0 + + if(KBatch > 0) + { + kbatch_list = {KBatch}; + } + + for(std::size_t i = 0; i < kbatch_list.size(); i++) + { + auto kbatch_curr = kbatch_list[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(a_scale_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(b_scale_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideC, + kbatch_curr, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + + // re-init C to zero before profiling next kernel + c_device_buf.SetZero(); + + invoker_ptr->Run(argument_ptr.get(), + StreamConfig{nullptr, false, 0, n_warmup, n_iter}); + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(do_log) + { + + if(init_method == 0) + { + auto expected = static_cast(K); + auto computed = type_convert(c_m_n_device_result(0, 12)); + + pass = pass & (std::abs(expected - computed) <= 0.0f); + std::cout << "\nExpected vs Computed: " << expected << " vs " + << computed << ((pass) ? 
" (PASSED!)" : " (FAILED!)") + << std::endl + << std::endl; + } + else + { + if constexpr(is_same_v || + is_same_v) + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") + << "\n"; + else + std::cout << "A: WIP PRINT PACKED TYPE\n"; + LogRangeAsType(std::cout << "a_scale : ", a_m_k_scale.mData, ",") + << "\n"; + if constexpr(is_same_v || + is_same_v) + LogRangeAsType(std::cout << "b : ", b_k_n->mData, ",") + << "\n"; + else + std::cout << "B: WIP PRINT PACKED TYPE\n"; + LogRangeAsType(std::cout << "b_scale: ", b_k_n_scale.mData, ",") + << "\n"; + LogRangeAsType( + std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << "\n"; + LogRangeAsType( + std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } + + pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); + } + + std::string op_name = op_ptr->GetTypeString(); + std::optional op_obj_name = op_ptr->GetObjectName(); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), + StreamConfig{nullptr, + time_kernel, + 0, + n_warmup, + n_iter, + rotating_count > 1, + rotating_count}); + + // Output size(M*N) * [dot product(2K) + product of scales(K/ScaleBlockSize) + + // scaling of partial sums(K/ScaleBlockSize)] + // FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize + std::size_t flop = + std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize; + + // TODO: fp6? + std::size_t num_btype = sizeof(ADataType) * M * K / packed_size_v + + sizeof(BDataType) * K * N / packed_size_v + + sizeof(CDataType) * M * N + + sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch " + << kbatch_curr << std::endl; + + if(tflops > best_tflops && ave_time > 1e-10) + { + best_op_name = op_name; + best_op_object_name = op_obj_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + best_kbatch = kbatch_curr; + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" + << std::endl; + } + } + } + + if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f32"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = bf16"; + } + std::cout << " ALayout = " << ALayout::name; + std::cout << " BLayout = " << BLayout::name; + std::cout << " CLayout = " << CLayout::name; + + std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA + << " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch + << " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec + << " GB/s, " << best_op_name << std::endl; + + if(best_op_object_name) + std::cout << best_op_object_name.value() << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 4f4a1f5356..72a12e718c 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -63,6 +63,9 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp) list(APPEND PROFILER_OPS profile_gemm_ab_scale.cpp) endif() + if(SUPPORTED_GPU_TARGETS MATCHES "gfx95") + list(APPEND PROFILER_OPS 
profile_gemm_mx.cpp) + endif() list(APPEND PROFILER_OPS profile_batched_gemm.cpp) list(APPEND PROFILER_OPS profile_batched_gemm_reduce.cpp) list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp) @@ -168,6 +171,9 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance) list(APPEND DEVICE_INSTANCES device_gemm_ab_scale_instance) endif() + if(SUPPORTED_GPU_TARGETS MATCHES "gfx95") + list(APPEND DEVICE_INSTANCES device_gemm_mx_instance) + endif() list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance) list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance) list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance) diff --git a/profiler/src/profile_gemm_mx.cpp b/profiler/src/profile_gemm_mx.cpp new file mode 100644 index 0000000000..9fd6f29464 --- /dev/null +++ b/profiler/src/profile_gemm_mx.cpp @@ -0,0 +1,155 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_gemm_mx_impl.hpp" +#include "profiler_operation_registry.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + MK_MFMA_MN, // 2 +}; + +enum struct GemmDataType +{ + F4_F4_F16, // 0 + F8_F8_F16, // 1 + F8_F8_BF16, // 2 +}; + +#define OP_NAME "gemm_mx" +#define OP_DESC "GEMM_mx" + +int profile_gemm_mx(int argc, char* argv[]) +{ + if(argc != 11 && argc != 14 && argc != 18) + { + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: f4->f16 ;\n"); + printf(" 1: fp8->f16 ;\n"); + printf(" 2: fp8->bf16 )\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n] ;\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n] ;\n"); + printf(" 2: A[k, m] * BPreShuff = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); + printf("optional:\n"); + printf("arg14: number of kbatch (default 1)\n"); + printf("arg15: number of warm-up cycles (default 1)\n"); + printf("arg16: number of iterations (default 10)\n"); + printf("arg17: memory for rotating buffer (default 0, size in MB)\n"); + exit(1); + } + int arg_index = 2; + const auto data_type = static_cast(std::stoi(argv[arg_index++])); + const auto layout = static_cast(std::stoi(argv[arg_index++])); + const bool do_verification = std::stoi(argv[arg_index++]); + const int init_method = std::stoi(argv[arg_index++]); + const bool do_log = std::stoi(argv[arg_index++]); + const bool time_kernel = std::stoi(argv[arg_index++]); + + const int M = std::stoi(argv[arg_index++]); + const int N = std::stoi(argv[arg_index++]); + const int K = std::stoi(argv[arg_index++]); + + int StrideA = -1, StrideB = -1, StrideC = -1; + if(argc > arg_index) + { + StrideA = std::stoi(argv[arg_index++]); + StrideB = std::stoi(argv[arg_index++]); + StrideC = std::stoi(argv[arg_index++]); + } + + int KBatch = 1; + int n_warmup = 1; + int n_iter = 10; + uint64_t rotating = 0; + if(argc > arg_index) + { + KBatch = std::stoi(argv[arg_index++]); + n_warmup = std::stoi(argv[arg_index++]); + n_iter = std::stoi(argv[arg_index++]); + rotating = std::stoull(argv[arg_index++]) * 1024 * 1024; + } + + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + using F4 = ck::f4x2_pk_t; + using F8 = ck::f8_t; + + using Row = 
ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + using MFMA = ck::tensor_layout::gemm::MFMA; + + auto profile = + [&](auto a_type, auto b_type, auto c_type, auto a_layout, auto b_layout, auto c_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using CDataType = decltype(c_type); + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CLayout = decltype(c_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideC = ck::is_same_v ? N : M; + + bool pass = ck::profiler::profile_gemm_mx_impl( // + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideC < 0) ? DefaultStrideC : StrideC, + KBatch, + n_warmup, + n_iter, + rotating); + + return pass ? 0 : 1; + }; + + if(data_type == GemmDataType::F4_F4_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F4{}, F4{}, F16{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F4_F4_F16 && layout == GemmMatrixLayout::MK_MFMA_MN) + { + return profile(F4{}, F4{}, F16{}, Row{}, MFMA{}, Row{}); + } + else if(data_type == GemmDataType::F8_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F8{}, F8{}, F16{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F8{}, F8{}, BF16{}, Row{}, Col{}, Row{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_mx); diff --git a/test/gemm_mx/test_gemm_mx.cpp b/test/gemm_mx/test_gemm_mx.cpp index 2c976a217f..a3449cb1bb 100644 --- a/test/gemm_mx/test_gemm_mx.cpp +++ b/test/gemm_mx/test_gemm_mx.cpp @@ -12,7 +12,7 @@ using F8 = ck::f8_t; using BF8 = ck::bf8_t; using F6 = ck::f6_t; using BF6 = ck::bf6_t; -using F4 = ck::f4_t; +using F4 = ck::f4x2_pk_t; using F16 = ck::half_t; using BF16 = ck::bhalf_t; using F32 = float; @@ -52,22 +52,23 @@ class TestGemmMX_KM_NK }; // clang-format off -using KernelTypes_F8_MK_NK = ::testing::Types< +using KernelTypes_MK_NK = ::testing::Types< #if defined(CK_ENABLE_FP8) // ADataType, BDataType, CDataType, ScaleBlockSize std::tuple< F8, F8, F16, ck::Number<32> >, - std::tuple< F8, F8, BF16, ck::Number<32> > + std::tuple< F8, F8, BF16, ck::Number<32> >, #endif + std::tuple< F4, F4, F16, ck::Number<32> > >; -using KernelTypes_BF8_F8_MK_KN = ::testing::Types< +using KernelTypes_MK_KN = ::testing::Types< #if defined(CK_ENABLE_FP8) // ADataType, BDataType, CDataType, ScaleBlockSize std::tuple< BF8, F8, F16, ck::Number<32> > #endif >; -using KernelTypes_F8_KM_NK = ::testing::Types< +using KernelTypes_KM_NK = ::testing::Types< #if defined(CK_ENABLE_FP8) // ADataType, BDataType, CDataType, ScaleBlockSize std::tuple< F8, F8, BF16, ck::Number<32> > @@ -75,9 +76,9 @@ using KernelTypes_F8_KM_NK = ::testing::Types< >; // clang-format on -TYPED_TEST_SUITE(TestGemmMX_MK_NK, KernelTypes_F8_MK_NK); -TYPED_TEST_SUITE(TestGemmMX_MK_KN, KernelTypes_BF8_F8_MK_KN); -TYPED_TEST_SUITE(TestGemmMX_KM_NK, KernelTypes_F8_KM_NK); +TYPED_TEST_SUITE(TestGemmMX_MK_NK, KernelTypes_MK_NK); +TYPED_TEST_SUITE(TestGemmMX_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmMX_KM_NK, KernelTypes_KM_NK); /// A: RowMajor /// B: ColMajor @@ -214,7 +215,8 @@ TYPED_TEST(TestGemmMX_MK_KN, Large) 
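For reference, a plausible 14-argument invocation of the profile_gemm_mx entry
point above (strides supplied explicitly, KBatch/warm-up/iteration/rotating
left at their defaults); the ckProfiler binary name comes from this repo,
while the problem size is purely illustrative:

    ckProfiler gemm_mx 1 1 1 2 0 1 4096 4096 4096 4096 4096 4096

That is: fp8 A/B with f16 C (data type 1), A[m, k] * B[n, k] = C[m, n]
(layout 1), verification on, decimal initialization, no tensor dump, timed
kernels, and M = N = K = StrideA = StrideB = StrideC = 4096. For that size the
FLOP model in profile_gemm_mx_impl.hpp above, 2*M*N*K + 2*M*N*K/ScaleBlockSize
with ScaleBlockSize = 32, counts 2 * 4096^3 * (1 + 1/32) ~= 141.7 GFLOP per
GEMM.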
TYPED_TEST(TestGemmMX_KM_NK, SmallN) { constexpr int M = 256; - std::vector Ns{1, 2, 3, 4, 5, 6}; + std::vector Ns{32, 64}; + // std::vector Ns{1, 2, 3, 4, 5, 6}; constexpr int K = 512; constexpr int StrideA = M; @@ -222,16 +224,16 @@ TYPED_TEST(TestGemmMX_KM_NK, SmallN) for(int N : Ns) { - const auto new_N = N * 8; - const auto StrideC = new_N; - this->Run(M, new_N, K, StrideA, StrideB, StrideC); + const auto StrideC = N; + this->Run(M, N, K, StrideA, StrideB, StrideC); } } TYPED_TEST(TestGemmMX_KM_NK, MidLargeN) { constexpr int M = 256; - std::vector Ns{127, 255, 312, 799, 1573}; + std::vector Ns{128, 256, 2048}; + // std::vector Ns{127, 255, 312, 799, 1573}; constexpr int K = 512; constexpr int StrideA = M; @@ -239,9 +241,8 @@ TYPED_TEST(TestGemmMX_KM_NK, MidLargeN) for(int N : Ns) { - const auto new_N = (N + 7) / 8 * 8; - const auto StrideC = new_N; - this->Run(M, new_N, K, StrideA, StrideB, StrideC); + const auto StrideC = N; + this->Run(M, N, K, StrideA, StrideB, StrideC); } } diff --git a/test/gemm_mx/test_gemm_mx_util.hpp b/test/gemm_mx/test_gemm_mx_util.hpp index 02833daeb4..675a3de127 100644 --- a/test/gemm_mx/test_gemm_mx_util.hpp +++ b/test/gemm_mx/test_gemm_mx_util.hpp @@ -18,6 +18,7 @@ #include "ck/library/tensor_operation_instance/gpu/gemm_mx.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp" #include "ck/library/utility/check_err.hpp" +#include "profiler/profile_gemm_mx_impl.hpp" namespace ck { namespace test { @@ -27,401 +28,6 @@ using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; } // namespace -template -bool profile_gemm_mx_impl(int do_verification, - int init_method, - bool do_log, - bool time_kernel, - int M, - int N, - int K, - int StrideA, - int StrideB, - int StrideC, - int KBatch, - int n_warmup, - int n_iter, - uint64_t rotating = 0) -{ - if(K % ScaleBlockSize != 0) - { - throw std::runtime_error("wrong! 
K must be multiple of ScaleBlockSize."); - }; - - using ScaleDataType = e8m0_bexp_t; - using AScaleLayout = Row; - using BScaleLayout = Col; - - bool pass = true; - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; - - if(is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - auto f_get_default_stride = - [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { - if(stride == -1) - { - // give a chance if stride is -1, return a default packed stride - if constexpr(std::is_same_v) - { - return static_cast(col); - } - else - { - return static_cast(row); - } - } - else - return static_cast(stride); - }; - - auto Scale_Stride_AM = f_get_default_stride(M, K / ScaleBlockSize, -1, AScaleLayout{}); - auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{}); - - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - - Tensor a_m_k_scale(f_host_tensor_descriptor( - M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); // scales for A - Tensor b_k_n_scale(f_host_tensor_descriptor( - K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); // scales for B - - Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - - std::size_t total_gemm_needed = - a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes() + - a_m_k_scale.GetElementSpaceSizeInBytes() + b_k_n_scale.GetElementSpaceSizeInBytes(); - int rotating_count = std::max( - 1, - std::min(n_iter, - static_cast(std::ceil(static_cast(rotating) / total_gemm_needed)))); - - std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; - std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl; - std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl; - std::cout << "rotating count: " << rotating_count << std::endl; - - switch(init_method) - { - case 0: // Initializations for development and debugging - ck::utils::FillConstant{ck::type_convert(1.0f)}(a_m_k); - ck::utils::FillConstant{ck::type_convert(2.0f)}(a_m_k_scale); - ck::utils::FillConstant{ck::type_convert(0.5f)}(b_k_n); - ck::utils::FillConstant{ck::type_convert(1.0f)}(b_k_n_scale); - if(do_log) - { - std::cout << "Init A = {1}" << std::endl; - std::cout << "Init A scale = {2.0}" << std::endl; - std::cout << "Init B = {0.5}" << std::endl; - std::cout << "Init B scale = {1.0}" << std::endl; - std::cout << "Expect C = {K}" << std::endl; - } - break; - - case 1: - - a_m_k.GenerateTensorValue(GeneratorTensor_2{-4, 5}); // Z[-4,4] - b_k_n.GenerateTensorValue(GeneratorTensor_2{-4, 5}); // Z[-4,4] - - a_m_k_scale.GenerateTensorValue( - GeneratorTensor_2{125, 129}); // scales: {0.25, 0.5, 1, 2} - b_k_n_scale.GenerateTensorValue( - GeneratorTensor_2{125, 129}); // scales: {0.25, 0.5, 1, 2} - - break; - - default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); - a_m_k_scale.GenerateTensorValue( - GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); // R[2^-125, 1] - - b_k_n.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); - b_k_n_scale.GenerateTensorValue( - GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); - break; - } - - using 
AElementOp = ck::tensor_operation::element_wise::PassThrough; - using BElementOp = ck::tensor_operation::element_wise::PassThrough; - using CElementOp = ck::tensor_operation::element_wise::PassThrough; - - const auto a_element_op = AElementOp{}; - const auto b_element_op = BElementOp{}; - const auto c_element_op = CElementOp{}; - - if(do_log > 0) - std::cout << "Device memory allocation..." << std::endl; - - DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem a_scale_device_buf(sizeof(ScaleDataType) * a_m_k_scale.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); - DeviceMem b_scale_device_buf(sizeof(ScaleDataType) * b_k_n_scale.mDesc.GetElementSpaceSize()); - DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); - - if(do_log > 0) - std::cout << "Upload data to device..." << std::endl; - a_device_buf.ToDevice(a_m_k.mData.data()); - a_scale_device_buf.ToDevice(a_m_k_scale.mData.data()); - b_device_buf.ToDevice(b_k_n.mData.data()); - b_scale_device_buf.ToDevice(b_k_n_scale.mData.data()); - - if(do_log > 0) - std::cout << "Done." << std::endl; - - using DeviceOp = ck::tensor_operation::device::DeviceGemmMX; - - // get device op instances - const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< - DeviceOp>::GetInstances(); - - std::cout << "found " << op_ptrs.size() << " instances" << std::endl; - - // Run reference GEMM - if(do_verification) - { - using ReferenceGemmInstance = - ck::tensor_operation::host::ReferenceMXGemm; - - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument(a_m_k, - a_m_k_scale, - b_k_n, - b_k_n_scale, - c_m_n_host_result, - a_element_op, - b_element_op, - c_element_op); - - ref_invoker.Run(ref_argument); - } - - std::string best_op_name; - std::optional best_op_object_name; - float best_ave_time = 0; - float best_tflops = 0; - float best_gb_per_sec = 0; - float best_kbatch = 0; - - // profile device GEMM instances - for(auto& op_ptr : op_ptrs) - { - std::vector kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38}; // use these when KBatch <= 0 - - if(KBatch > 0) - { - kbatch_list = {KBatch}; - } - - for(std::size_t i = 0; i < kbatch_list.size(); i++) - { - auto kbatch_curr = kbatch_list[i]; - - auto argument_ptr = op_ptr->MakeArgumentPointer( - static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(a_scale_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(b_scale_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - Scale_Stride_AM, - StrideB, - Scale_Stride_BN, - StrideC, - kbatch_curr, - a_element_op, - b_element_op, - c_element_op); - - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - - // re-init C to zero before profiling next kernel - c_device_buf.SetZero(); - - invoker_ptr->Run(argument_ptr.get(), - StreamConfig{nullptr, false, 0, n_warmup, n_iter}); - - if(do_verification) - { - c_device_buf.FromDevice(c_m_n_device_result.mData.data()); - - if(do_log) - { - - if(init_method == 0) - { - auto expected = static_cast(K); - auto computed = type_convert(c_m_n_device_result(0, 12)); - - pass = pass & (std::abs(expected - computed) <= 0.0f); - std::cout << "\nExpected vs Computed: " << expected << " vs " - << computed << ((pass) ? 
" (PASSED!)" : " (FAILED!)") - << std::endl - << std::endl; - } - else - { - LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") - << std::endl; - LogRangeAsType(std::cout << "a_scale : ", a_m_k_scale.mData, ",") - << std::endl; - LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") - << std::endl; - LogRangeAsType(std::cout << "b_scale: ", b_k_n_scale.mData, ",") - << std::endl; - LogRangeAsType( - std::cout << "c_host : ", c_m_n_host_result.mData, ",") - << std::endl; - LogRangeAsType( - std::cout << "c_device: ", c_m_n_device_result.mData, ",") - << std::endl; - } - } - - pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); - } - - std::string op_name = op_ptr->GetTypeString(); - std::optional op_obj_name = op_ptr->GetObjectName(); - - float ave_time = invoker_ptr->Run(argument_ptr.get(), - StreamConfig{nullptr, - time_kernel, - 0, - n_warmup, - n_iter, - rotating_count > 1, - rotating_count}); - - // Output size(M*N) * [dot product(2K) + product of scales(K/ScaleBlockSize) + - // scaling of partial sums(K/ScaleBlockSize)] - // FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize - std::size_t flop = - std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize; - - std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + - sizeof(CDataType) * M * N + - sizeof(ScaleDataType) * (M * K + K * N) / ScaleBlockSize; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops - << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch " - << kbatch_curr << std::endl; - - if(tflops > best_tflops && ave_time > 1e-10) - { - best_op_name = op_name; - best_op_object_name = op_obj_name; - best_tflops = tflops; - best_ave_time = ave_time; - best_gb_per_sec = gb_per_sec; - best_kbatch = kbatch_curr; - } - } - else - { - std::cout << op_ptr->GetTypeString() << " does not support this problem" - << std::endl; - } - } - } - - if constexpr(is_same::value) - { - std::cout << "Best Perf for datatype = f32"; - } - else if constexpr(is_same::value) - { - std::cout << "Best Perf for datatype = f16"; - } - else if constexpr(is_same::value) - { - std::cout << "Best Perf for datatype = bf16"; - } - else if constexpr(is_same::value) - { - std::cout << "Best Perf for datatype = int8"; - } - - if constexpr(is_same::value) - { - std::cout << " ALayout = RowMajor"; - } - else if constexpr(is_same::value) - { - std::cout << " ALayout = ColumnMajor"; - } - - if constexpr(is_same::value) - { - std::cout << " BLayout = RowMajor"; - } - else if constexpr(is_same::value) - { - std::cout << " BLayout = ColumnMajor"; - } - - std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA - << " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch - << " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec - << " GB/s, " << best_op_name << std::endl; - - if(best_op_object_name) - std::cout << best_op_object_name.value() << std::endl; - - return pass; -} - template class TestGemmMX : public testing::Test { @@ -471,25 +77,25 @@ class TestGemmMX : public testing::Test int n_warmup = 1, int n_iter = 10) { - bool pass = ck::test::profile_gemm_mx_impl(verify_, - init_method_, - log_, - bench_, - M, - N, - K, - StrideA, - StrideB, - StrideC, - kbatch, - n_warmup, - n_iter); + bool pass = ck::profiler::profile_gemm_mx_impl(verify_, + init_method_, + log_, + bench_, 
+ M, + N, + K, + StrideA, + StrideB, + StrideC, + kbatch, + n_warmup, + n_iter); EXPECT_TRUE(pass); } }; diff --git a/test/mx_mfma_op/mx_mfma_op.hpp b/test/mx_mfma_op/mx_mfma_op.hpp index 4cab411cb4..21a0484d19 100644 --- a/test/mx_mfma_op/mx_mfma_op.hpp +++ b/test/mx_mfma_op/mx_mfma_op.hpp @@ -74,7 +74,11 @@ struct mfma_scale_type_selector<16, 16> AccumFragT& fragAcc) { auto op = mfma_type{}; - op.template run<16, 16>(fragA, scale_a[Number<0>{}], fragB, scale_b[Number<0>{}], fragAcc); + op.template run<16, 16, 0, 0>(fragA, + ck::utils::get_exponent_value(scale_a[Number<0>{}]), + fragB, + ck::utils::get_exponent_value(scale_b[Number<0>{}]), + fragAcc); } }; @@ -93,7 +97,11 @@ struct mfma_scale_type_selector<32, 32> AccumFragT& fragAcc) { auto op = mfma_type{}; - op.template run<32, 32>(fragA, scale_a[Number<0>{}], fragB, scale_b[Number<0>{}], fragAcc); + op.template run<32, 32, 0, 0>(fragA, + ck::utils::get_exponent_value(scale_a[Number<0>{}]), + fragB, + ck::utils::get_exponent_value(scale_b[Number<0>{}]), + fragAcc); } }; @@ -921,14 +929,12 @@ template -__global__ void matmul(const typename packed_type::type* a, - const typename packed_type::type* b, - CType* c) +__global__ void matmul(const packed_type_t* a, const packed_type_t* b, CType* c) { - using PackedAType = typename packed_type::type; - constexpr auto packed_size_a = packed_type::packed_size; - using PackedBType = typename packed_type::type; - constexpr auto packed_size_b = packed_type::packed_size; + using PackedAType = packed_type_t; + constexpr auto packed_size_a = packed_size_v; + using PackedBType = packed_type_t; + constexpr auto packed_size_b = packed_size_v; constexpr int WAVE_SIZE = 64; assert(threadIdx.x < WAVE_SIZE); @@ -1005,9 +1011,9 @@ __global__ void matmul(const packed_type_t* a, CType* c) { using PackedAType = packed_type_t; - constexpr auto packed_size_a = packed_size_v; + constexpr auto packed_size_a = packed_size_v; using PackedBType = packed_type_t; - constexpr auto packed_size_b = packed_size_v; + constexpr auto packed_size_b = packed_size_v; constexpr int WAVE_SIZE = 64; assert(threadIdx.x < WAVE_SIZE); @@ -1181,10 +1187,10 @@ template struct TestMXMFMA { - using PackedAType = typename packed_type::type; - static constexpr auto packed_size_a = packed_type::packed_size; - using PackedBType = typename packed_type::type; - static constexpr auto packed_size_b = packed_type::packed_size; + using PackedAType = packed_type_t; + static constexpr auto packed_size_a = packed_size_v; + using PackedBType = packed_type_t; + static constexpr auto packed_size_b = packed_size_v; auto PrepareGemmTensors(const GemmParams& params, index_t init) { @@ -1384,11 +1390,10 @@ template struct TestMFMA { - - using PackedAType = typename packed_type::type; - static constexpr auto packed_size_a = packed_type::packed_size; - using PackedBType = typename packed_type::type; - static constexpr auto packed_size_b = packed_type::packed_size; + using PackedAType = packed_type_t; + static constexpr auto packed_size_a = packed_size_v; + using PackedBType = packed_type_t; + static constexpr auto packed_size_b = packed_size_v; auto PrepareGemmTensors(const GemmParams& params, index_t init) {
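The mx_mfma_op hunks above migrate spelled-out trait accesses
(typename packed_type<T>::type and packed_type<T>::packed_size) to the alias
and variable templates packed_type_t<T> and packed_size_v<T>. A minimal sketch
of that relationship follows; the names mirror their use in this patch, the
f4_t specialization is an assumption consistent with the patch's f4x2_pk_t
packing (two fp4 values per byte, hence K_pk = K / 2 in preShuffleBuffer), and
the real definitions live in include/ck/utility/data_type.hpp:

    // Primary template: most element types are their own packed carrier.
    template <typename T>
    struct packed_type
    {
        using type                       = T;
        static constexpr int packed_size = 1;
    };

    // Assumed specialization: fp4 travels two-per-byte as f4x2_pk_t.
    template <>
    struct packed_type<f4_t>
    {
        using type                       = f4x2_pk_t;
        static constexpr int packed_size = 2;
    };

    // The shorthands this patch migrates to.
    template <typename T>
    using packed_type_t = typename packed_type<T>::type;

    template <typename T>
    inline constexpr auto packed_size_v = packed_type<T>::packed_size;

The alias and variable templates keep call sites such as
sizeof(ADataType) * M * K / packed_size_v<ADataType> free of typename/::type
noise, which is exactly the cleanup the matmul kernels and the
TestMXMFMA/TestMFMA structs above perform.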