From b9c513b30dcc59265016b1bbc4ccccca1334fd07 Mon Sep 17 00:00:00 2001
From: linqunAMD
Date: Tue, 9 Sep 2025 11:22:36 +0800
Subject: [PATCH] Extend XDL kernel to Support RDNA3/4 - Part 3 (#2723)

Support Wave32/Wave64 in all XDL kernels:

1. Add the following helper functions/macros in device_base.hpp:
   - GET_NXDL_PER_WAVE_IMPL and GetNXdlPerWave2
   - INVOKER_RUN_IMPL and INVOKER_RUN3_IMPL
   - IsValidGemmCompilationParameter and IS_VALID_COMPILATION_PARAMETER_IMPL
2. Replace GridwiseGemm with GridwiseGemm32 and GridwiseGemm64, and use one of them according to the current GPU target.
3. Move gridwise-GEMM-related variables from Argument members to local variables in RunImp, to avoid duplicated GridwiseGemm::CheckValidity calls.
4. Add IsValidGemmCompilationParameter to all XDL kernels.

Known issues:
- DeviceBatchedGemmXdl and DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle are incorrect on gfx11.
- DeviceGemmMultipleDLayernorm_Xdl_CShuffle is incorrect on both gfx11 and gfx12.

[ROCm/composable_kernel commit: 0f8e33f81120e5734ef47a6a169ad85c6560cbd8]
---
 ...mm_softmax_gemm_operation_xdl_cshuffle.cpp | 82 +-- ...gemm_multiple_d_operation_xdl_cshuffle.cpp | 36 +- ...wd_multiple_abd_operation_xdl_cshuffle.cpp | 27 +- codegen/src/utils.cpp | 5 +- .../test/grouped_conv_fwd_multiple_d_v1.cpp | 3 +- .../test/grouped_conv_fwd_multiple_d_v2.cpp | 3 +- .../test/grouped_conv_fwd_multiple_d_v3.cpp | 3 +- .../test/grouped_conv_fwd_multiple_d_v4.cpp | 3 +- include/ck/host_utility/device_prop.hpp | 28 + .../block/blockwise_gemm_pipeline_xdlops.hpp | 6 +- .../block/blockwise_gemm_smfmac_xdlops.hpp | 6 +- .../gpu/device/device_base.hpp | 146 ++++ ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 298 ++++---- ...ed_contraction_multiple_d_xdl_cshuffle.hpp | 175 +++-- .../device_batched_gemm_e_permute_xdl.hpp | 98 +-- .../device_batched_gemm_gemm_xdl_cshuffle.hpp | 149 ++-- .../impl/device_batched_gemm_multi_d_xdl.hpp | 176 +++-- ...ultiple_d_gemm_multiple_d_xdl_cshuffle.hpp | 248 ++++--- ...atched_gemm_multiple_d_xdl_cshuffle_v3.hpp | 356 +++++---- ...evice_batched_gemm_reduce_xdl_cshuffle.hpp | 242 +++--- ...gemm_softmax_gemm_permute_xdl_cshuffle.hpp | 192 +++-- ...batched_gemm_softmax_gemm_xdl_cshuffle.hpp | 234 ++++-- .../device/impl/device_batched_gemm_xdl.hpp | 105 ++- ...evice_batched_gemm_xdl_fpAintB_b_scale.hpp | 250 ++++--- .../impl/device_cgemm_4gemm_xdl_cshuffle.hpp | 70 +- ..._contraction_multiple_abd_xdl_cshuffle.hpp | 159 ++-- ...ce_contraction_multiple_d_xdl_cshuffle.hpp | 131 ++-- ...rd_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp | 30 +- ...ice_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp | 51 +- ...fle_bias_activation_add_nhwc_kyxc_nhwk.hpp | 104 +-- ...shuffle_bias_activation_nhwc_kyxc_nhwk.hpp | 89 ++- ...onv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp | 106 +-- .../device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp | 38 +- ...evice_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp | 69 +- ...device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp | 51 +- ...vice_gemm_bias_add_reduce_xdl_cshuffle.hpp | 111 +-- .../device_gemm_multiple_abd_xdl_cshuffle.hpp | 43 +- ...gemm_multiple_d_layernorm_xdl_cshuffle.hpp | 196 +++-- ...emm_multiple_d_multiple_r_xdl_cshuffle.hpp | 200 ++--- .../device_gemm_multiple_d_xdl_cshuffle.hpp | 196 +++-- ...ultiple_d_xdl_cshuffle_lds_direct_load.hpp | 51 +- ...device_gemm_multiple_d_xdl_cshuffle_v3.hpp | 60 +- ...mm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp | 47 +- ...ultiple_d_xdl_cshuffle_v3_b_preshuffle.hpp | 58 +- ...xdl_cshuffle_v3_blockscale_bpreshuffle.hpp | 50 +- .../impl/device_gemm_reduce_xdl_cshuffle.hpp | 152 ++--
.../gpu/device/impl/device_gemm_xdl.hpp | 39 +- .../device/impl/device_gemm_xdl_cshuffle.hpp | 41 +- ...vice_gemm_xdl_cshuffle_lds_direct_load.hpp | 52 +- .../device_gemm_xdl_cshuffle_streamk_v3.hpp | 416 +++++++---- .../impl/device_gemm_xdl_cshuffle_v2.hpp | 40 +- .../impl/device_gemm_xdl_cshuffle_v3.hpp | 76 +- ...vice_gemm_xdl_cshuffle_v3_b_preshuffle.hpp | 51 +- .../device_gemm_xdl_cshuffle_v3_b_scale.hpp | 50 +- .../impl/device_gemm_xdl_cshuffle_v3_mx.hpp | 251 ++++--- .../impl/device_gemm_xdl_cshuffle_v3r1.hpp | 66 +- .../device_gemm_xdl_layernorm_cshuffle.hpp | 155 ++-- .../impl/device_gemm_xdl_skip_b_lds.hpp | 88 ++- .../impl/device_gemm_xdl_splitk_c_shuffle.hpp | 116 ++- ...m_xdl_splitk_c_shuffle_lds_direct_load.hpp | 121 +-- .../device/impl/device_gemm_xdl_streamk.hpp | 143 +++- .../device_gemm_xdl_waveletmodel_cshuffle.hpp | 103 ++- ...ed_contraction_multiple_d_xdl_cshuffle.hpp | 185 +++-- ...nv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 311 +++++--- ...onv_bwd_weight_multiple_d_xdl_cshuffle.hpp | 145 ++-- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 215 ++++-- ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 139 ++-- ...rouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 217 ++++-- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 487 +++++++----- ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 256 ++++--- ...fwd_multiple_d_multiple_r_xdl_cshuffle.hpp | 276 ++++--- ...d_multiple_d_xdl_large_tensor_cshuffle.hpp | 239 +++--- ...ce_grouped_gemm_multi_abd_xdl_fixed_nk.hpp | 195 ++--- ...ltiple_d_splitk_xdl_cshuffle_two_stage.hpp | 79 +- ...gemm_multiple_d_xdl_cshuffle_tile_loop.hpp | 690 ++++++++++-------- ...gemm_softmax_gemm_permute_xdl_cshuffle.hpp | 189 +++-- .../device/impl/device_grouped_gemm_xdl.hpp | 289 +++++--- .../impl/device_grouped_gemm_xdl_fixed_nk.hpp | 296 ++++---- ...evice_grouped_gemm_xdl_splitk_cshuffle.hpp | 190 +++-- .../gpu/device/impl/device_moe_gemm.hpp | 54 +- .../impl/device_moe_gemm_blockscale.hpp | 165 +++-- .../gpu/device/impl/device_moe_mx_gemm.hpp | 52 +- .../device/impl/device_moe_mx_gemm_bns.hpp | 51 +- .../impl/device_moe_mx_gemm_bpreshuffle.hpp | 161 ++-- ...tk_contraction_multiple_d_xdl_cshuffle.hpp | 187 +++-- .../gpu/grid/block_to_ctile_map.hpp | 3 +- ...iple_d_welford_first_half_xdl_cshuffle.hpp | 4 +- ...wise_batched_gemm_gemm_xdl_cshuffle_v1.hpp | 22 + ...iple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp | 62 +- ...ultiple_d_softmax_gemm_xdl_cshuffle_v1.hpp | 22 +- ...ched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp | 16 + ...e_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp | 67 +- ...ridwise_gemm_multiple_abd_xdl_cshuffle.hpp | 3 + ...emm_multiple_d_multiple_r_xdl_cshuffle.hpp | 4 +- .../gridwise_gemm_multiple_d_xdl_cshuffle.hpp | 3 + ...ultiple_d_xdl_cshuffle_lds_direct_load.hpp | 37 +- ...se_gemm_multiple_d_xdl_splitk_cshuffle.hpp | 4 +- .../gridwise_gemm_reduce_xdl_cshuffle_v1.hpp | 56 +- ...e_gemm_split_k_multiple_d_xdl_cshuffle.hpp | 4 +- ...emm_split_k_multiple_d_xdl_cshuffle_v2.hpp | 4 +- .../gridwise_gemm_xdl_cshuffle_conv_v3.hpp | 2 + .../gridwise_gemm_xdl_cshuffle_streamk_v3.hpp | 42 +- .../grid/gridwise_gemm_xdl_cshuffle_v1.hpp | 39 +- .../grid/gridwise_gemm_xdl_cshuffle_v2.hpp | 50 +- .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 77 +- ...wise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp | 52 +- .../gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp | 58 +- ...ridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp | 64 +- .../gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp | 112 ++- ..._gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp | 36 +- ...m_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp | 107 
++- ...fle_v3_multi_d_blockscale_b_preshuffle.hpp | 68 +- .../grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp | 62 +- ...se_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp | 62 +- ...ridwise_gemm_xdl_layernorm_cshuffle_v1.hpp | 63 +- ...ridwise_gemm_xdl_waveletmodel_cshuffle.hpp | 19 +- .../grid/gridwise_gemm_xdlops_bwd_weight.hpp | 93 ++- .../gridwise_gemm_xdlops_skip_b_lds_v1.hpp | 71 +- ...ise_gemm_xdlops_splitk_lds_direct_load.hpp | 51 +- .../gpu/grid/gridwise_gemm_xdlops_streamk.hpp | 85 ++- .../gpu/grid/gridwise_gemm_xdlops_v2r3.hpp | 128 ++-- .../gpu/grid/gridwise_gemm_xdlops_v2r4.hpp | 52 +- .../gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp | 51 +- .../gpu/grid/gridwise_gemm_xdlops_v3r1.hpp | 50 +- .../gpu/grid/gridwise_gemm_xdlops_v3r2.hpp | 57 +- .../gpu/grid/gridwise_gemm_xdlops_v3r3.hpp | 61 +- .../gpu/grid/gridwise_moe_gemm.hpp | 76 +- .../gpu/grid/gridwise_moe_gemm_blockscale.hpp | 84 ++- .../gpu/grid/gridwise_moe_mx_gemm.hpp | 46 +- .../gpu/grid/gridwise_moe_mx_gemm_bns.hpp | 42 +- .../grid/gridwise_moe_mx_gemm_bpreshuffle.hpp | 80 +- 131 files changed, 8731 insertions(+), 5329 deletions(-) diff --git a/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp b/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp index 6029ab0c7d..f233794ec1 100644 --- a/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp +++ b/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp" #include "ck/host/stringutils.hpp" @@ -76,28 +76,28 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( // Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Prefetch| // | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Stage| // | | | | | | | | | | | Wave| Wave| Wave| | - { 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, 1}, - { 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, 1}, - { 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, 1}, - { 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, 1}, - { 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1}, - { 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1}, - { 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1}, - { 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1}, - { 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1}, - { 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1}, - { 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1}, - { 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1}, + { 256, 256, 128, 32, 64, 32, 8, 8, 2, 16, 16, 4, 8, 4, 1}, + { 256, 256, 128, 32, 128, 32, 8, 8, 2, 16, 16, 4, 8, 8, 1}, + { 256, 128, 256, 32, 64, 32, 8, 8, 2, 16, 16, 2, 16, 4, 1}, + { 256, 128, 256, 32, 128, 32, 8, 8, 2, 16, 16, 2, 16, 8, 1}, + { 256, 128, 128, 64, 64, 32, 8, 8, 2, 16, 16, 2, 8, 4, 1}, + { 256, 128, 128, 32, 64, 32, 8, 8, 2, 16, 16, 2, 8, 4, 1}, + { 256, 128, 128, 64, 128, 32, 8, 8, 2, 16, 16, 2, 8, 8, 1}, + { 256, 128, 128, 32, 128, 32, 8, 8, 2, 16, 16, 2, 8, 8, 1}, + { 256, 128, 256, 32, 128, 32, 8, 8, 2, 16, 16, 2, 16, 8, 1}, + { 256, 128, 256, 32, 64, 32, 8, 8, 2, 16, 16, 2, 16, 4, 1}, + { 256, 128, 256, 64, 128, 32, 8, 8, 2, 16, 16, 2, 16, 8, 1}, + { 256, 128, 256, 64, 64, 32, 8, 8, 2, 16, 16, 2, 16, 4, 1}, // Padded fallback kernel - { 256, 
128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1}, - { 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, 1}, + { 256, 128, 128, 64, 128, 32, 8, 8, 2, 16, 16, 2, 8, 8, 1}, + { 256, 128, 64, 32, 128, 32, 8, 8, 2, 16, 16, 2, 4, 8, 1}, // Irregular k - { 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, 1}, - { 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, 1}, - { 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, 1}, - { 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, 1}, - { 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, 1}, - { 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, 1}, + { 256, 256, 128, 48, 64, 32, 4, 4, 2, 16, 16, 4, 8, 4, 1}, + { 256, 256, 128, 48, 128, 32, 4, 4, 2, 16, 16, 4, 8, 8, 1}, + { 256, 128, 256, 48, 64, 32, 4, 4, 2, 16, 16, 2, 16, 4, 1}, + { 256, 128, 256, 48, 128, 32, 4, 4, 2, 16, 16, 2, 16, 8, 1}, + { 256, 128, 128, 48, 64, 32, 4, 4, 2, 16, 16, 2, 8, 4, 1}, + { 256, 128, 128, 48, 128, 32, 4, 4, 2, 16, 16, 2, 8, 8, 1}, // clang-format on }; @@ -200,28 +200,28 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( // _MBlock_MWaveMPerXdl| ScalarPerVector // _NBlock_NWaveNPerXdl| _NWaveNPerXdl // | - { S<1, 32, 1, 8>, 8}, - { S<1, 32, 1, 8>, 8}, - { S<1, 32, 1, 8>, 8}, - { S<1, 32, 1, 8>, 8}, - { S<1, 32, 1, 8>, 8}, - { S<1, 32, 1, 8>, 8}, - { S<1, 32, 1, 8>, 8}, - { S<1, 32, 1, 8>, 8}, - { S<1, 16, 1,16>, 8}, - { S<1, 32, 1, 8>, 8}, - { S<1, 16, 1,16>, 8}, - { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 4}, + { S<1, 32, 1, 8>, 4}, + { S<1, 32, 1, 8>, 4}, + { S<1, 32, 1, 8>, 4}, + { S<1, 32, 1, 8>, 4}, + { S<1, 32, 1, 8>, 4}, + { S<1, 32, 1, 8>, 4}, + { S<1, 32, 1, 8>, 4}, + { S<1, 16, 1,16>, 4}, + { S<1, 32, 1, 8>, 4}, + { S<1, 16, 1,16>, 4}, + { S<1, 32, 1, 8>, 4}, // Padded fallback kernel - { S<1, 32, 1, 8>, 8}, - { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 4}, + { S<1, 32, 1, 8>, 4}, // Irregular k - { S<1, 32, 1, 8>, 8}, - { S<1, 32, 1, 8>, 8}, - { S<1, 32, 1, 8>, 8}, - { S<1, 32, 1, 8>, 8}, - { S<1, 32, 1, 8>, 8}, - { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 4}, + { S<1, 32, 1, 8>, 4}, + { S<1, 32, 1, 8>, 4}, + { S<1, 32, 1, 8>, 4}, + { S<1, 32, 1, 8>, 4}, + { S<1, 32, 1, 8>, 4}, // clang-format on }; diff --git a/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp index fe556615e0..b6cae670fe 100644 --- a/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp +++ b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
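
A pattern worth calling out in the retuned tables above: MPerXDL/NPerXDL shrink from 32x32 to 16x16 (the only instruction tile this patch allows on gfx11/gfx12, see the is_xdl_wmma_supported hunk further down), while MXdlPerWave/NXdlPerWave double, so each wave still covers the same per-block extent. A minimal standalone check of that invariant, with a hypothetical helper name (an illustrative sketch, not part of the patch):

    // Per-wave tile extent = XdlPerWave * PerXDL; the retune keeps it constant.
    constexpr int per_wave_extent(int xdl_per_wave, int per_xdl) { return xdl_per_wave * per_xdl; }
    // First row above: M side 2x32 -> 4x16, N side 4x32 -> 8x16.
    static_assert(per_wave_extent(2, 32) == per_wave_extent(4, 16), "M extent preserved");
    static_assert(per_wave_extent(4, 32) == per_wave_extent(8, 16), "N extent preserved");
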
#include "ck/host/device_gemm_multiple_d/operation.hpp" #include "ck/host/stringutils.hpp" @@ -81,16 +81,16 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( // Size| Block| Block| Block| | | XDL| XDL| Per| Per| Prefetch| // | | | | | | | | Wave| Wave| Stage| // | | | | | | | | | | | - { 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1}, - { 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, 1}, - { 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, 1}, - { 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, 1}, - { 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, 1}, - { 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1}, - { 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1}, - { 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1}, + { 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, 1}, + { 256, 128, 256, 32, 8, 8, 16, 16, 4, 8, 1}, + { 128, 128, 128, 32, 8, 8, 16, 16, 8, 4, 1}, + { 256, 128, 128, 32, 8, 8, 16, 16, 4, 4, 1}, + { 128, 128, 64, 32, 8, 8, 16, 16, 4, 4, 1}, + { 128, 64, 128, 32, 8, 8, 16, 16, 4, 4, 1}, + { 256, 128, 64, 32, 8, 8, 16, 16, 4, 2, 1}, + { 256, 64, 128, 32, 8, 8, 16, 16, 2, 4, 1}, // Irregular tile - { 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, 1}, + { 64, 32, 32, 32, 8, 8, 16, 16, 2, 2, 1}, // clang-format on }; @@ -194,14 +194,14 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( // _MBlock_MWaveMPerXdl| ScalarPerVector // _NBlock_NWaveNPerXdl| _NWaveNPerXdl // | - { S<1, 32, 1, 8>, 8}, - { S<1, 32, 1, 8>, 8}, - { S<1, 16, 1, 8>, 8}, - { S<1, 32, 1, 8>, 8}, - { S<1, 32, 1, 4>, 8}, - { S<1, 16, 1, 8>, 8}, - { S<1, 32, 1, 8>, 8}, - { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 4}, + { S<1, 32, 1, 8>, 4}, + { S<1, 16, 1, 8>, 4}, + { S<1, 32, 1, 8>, 4}, + { S<1, 32, 1, 4>, 4}, + { S<1, 16, 1, 8>, 4}, + { S<1, 32, 1, 8>, 4}, + { S<1, 32, 1, 8>, 4}, // Irregular tile { S<1, 16, 1, 4>, 1}, // clang-format on diff --git a/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp b/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp index a2f322c50f..26988255c3 100644 --- a/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp +++ b/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" #include @@ -55,12 +55,12 @@ std::vector Operation_Conv_Fwd_Xdl_Cshuffle::Cr // Size| Block| Block| Block| | | XDL| XDL| Per| Per| Prefetch| // | | | | | | | | Wave| Wave| Stage| // | | | | | | | | | | | - { 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, 1}, - { 256, 128, 256, 32, 8, 8, 32, 32, 4, 2, 1}, - { 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, 1}, - { 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, 1}, - { 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1}, - { 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, 1} + { 64, 64, 32, 32, 8, 8, 16, 16, 4, 2, 1}, + { 256, 128, 256, 32, 8, 8, 16, 16, 8, 4, 1}, + { 256, 128, 128, 32, 8, 8, 16, 16, 4, 4, 1}, + { 64, 64, 64, 32, 8, 8, 16, 16, 4, 4, 1}, + { 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, 1}, + { 128, 128, 128, 32, 8, 8, 16, 16, 8, 4, 1} // clang-format on }; @@ -116,11 +116,11 @@ std::vector Operation_Conv_Fwd_Xdl_Cshuffle::Cr // _NBlock_NWaveNPerXdl| _NWaveNPerXdl // | { S<1, 16, 1, 4>, 1}, - { S<1, 32, 1, 8>, 8}, - { S<1, 32, 1, 8>, 8}, + { S<1, 16, 1, 16>, 4}, + { S<1, 32, 1, 8>, 4}, { S<1, 16, 1, 4>, 1}, - { S<1, 32, 1, 8>, 8}, - { S<1, 16, 1, 8>, 8} + { S<1, 32, 1, 8>, 4}, + { S<1, 16, 1, 8>, 4} // clang-format on }; @@ -223,8 +223,9 @@ extern "C" __global__ void run_${name}( constexpr ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler(); // GridwiseGemm - using GridwiseGemm = DeviceConv::GridwiseGemm; - + using GridwiseGemm = ck::conditional_t; static constexpr auto I0 = ck::Number<0>{}; ck::tensor_operation::device::device_grouped_conv_fwd_multiple_abd_xdl_cshuffle< diff --git a/codegen/src/utils.cpp b/codegen/src/utils.cpp index c15a9fd7d3..4cfe7a117f 100644 --- a/codegen/src/utils.cpp +++ b/codegen/src/utils.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/host/utils.hpp" @@ -13,7 +13,8 @@ std::size_t integer_divide_ceil(std::size_t x, std::size_t y) const std::unordered_set& get_xdlop_archs() { - static std::unordered_set supported_archs{"gfx90a", "gfx908", "gfx942"}; + static std::unordered_set supported_archs{ + "gfx90a", "gfx908", "gfx942", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201"}; return supported_archs; } diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp index 9902caab04..15365aadf1 100644 --- a/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp +++ b/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp @@ -160,9 +160,10 @@ struct Epilogue Epilogue{1.0f, 1.0f}); out_host.SetZero(); ref_invoker.Run(ref_argument);**/ - + int i = 0; for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue)) { + std::cout << "Testing solution " << std::to_string(++i) << std::endl; // substitute instance values into the template auto src = ck::host::InterpolateString( conv_compile_check, diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp index 205283e7aa..d7ff793cb8 100644 --- a/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp +++ b/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp @@ -160,9 +160,10 @@ struct Epilogue Epilogue{1.0f, 1.0f}); out_host.SetZero(); ref_invoker.Run(ref_argument);**/ - + int i = 0; for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue)) { + std::cout << "Testing solution " << std::to_string(++i) << std::endl; // substitute instance values into the template auto src = ck::host::InterpolateString( conv_compile_check, diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp index 2b83af2432..1129dbc015 100644 --- a/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp +++ b/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp @@ -160,9 +160,10 @@ struct Epilogue Epilogue{1.0f, 1.0f}); out_host.SetZero(); ref_invoker.Run(ref_argument);**/ - + int i = 0; for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue)) { + std::cout << "Testing solution " << std::to_string(++i) << std::endl; // substitute instance values into the template auto src = ck::host::InterpolateString( conv_compile_check, diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp index fbe27e9c8b..5696178f68 100644 --- a/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp +++ b/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp @@ -160,9 +160,10 @@ struct Epilogue Epilogue{1.0f, 1.0f}); out_host.SetZero(); ref_invoker.Run(ref_argument);**/ - + int i = 0; for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue)) { + std::cout << "Testing solution " << std::to_string(++i) << std::endl; // substitute instance values into the template auto src = ck::host::InterpolateString( conv_compile_check, diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 2bc5a4414e..2e949bb1df 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -75,6 +75,34 @@ inline bool is_xdl_supported() ; } +template +inline bool is_xdl_wmma_supported() +{ + if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || + ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950") + { + return true; + } +#if defined(CK_ENABLE_DYNAMIC_WARP_SIZE) + else if(is_gfx12_supported() || is_gfx11_supported()) + { + if 
constexpr((MPerXDL != 16) || (NPerXDL != 16)) + { + return false; + } + if constexpr(sizeof(ADataType) > 2 || sizeof(BDataType) > 2) + { + return false; + } + return true; + } +#endif + else + { + return false; + } +} + inline bool is_lds_direct_load_supported() { // Check if direct loads from global memory to LDS are supported. diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp index 613886453b..b729271680 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp @@ -108,10 +108,8 @@ struct BlockwiseGemmXdlops_pipeline_v4 using ThisThreadBlock = ThisThreadBlock; - static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); - static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); - static_assert(MWaves > 0); - static_assert(NWaves > 0); + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); static constexpr index_t WaveSize = BlockSize / MWaves / NWaves; static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp index c553a57672..55e856b641 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp @@ -49,10 +49,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 using ThisThreadBlock = ThisThreadBlock; - static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); - static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); - static_assert(MWaves > 0); - static_assert(NWaves > 0); + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); static constexpr index_t WaveSize = BlockSize / MWaves / NWaves; static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index c946abb77d..e7ce7cbcf5 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -11,6 +11,7 @@ #include "ck/stream_config.hpp" #endif +#include "ck/utility/get_id.hpp" namespace ck { namespace tensor_operation { @@ -46,6 +47,151 @@ namespace device { #define REGISTER_EXTRA_PRINTING_METHODS GET_OBJECT_NAME_IMLP GET_TEMPLATE_INFO_IMPL #endif +template +static constexpr auto GetNXdlPerWave2() +{ + constexpr index_t Waves = IsWave64 ? 
BlockSize_ / 64 : BlockSize_ / 32; + constexpr index_t MWaves = MPerBlock_ / (MXdlPerWave_ * MPerXDL_); + static_assert(MWaves > 0); + + constexpr index_t NWaves = Waves / MWaves; + if constexpr(NWaves == 0) + { + return 0; + } + else + { + if constexpr(NPerBlock_ % (NPerXDL_ * NWaves) == 0) + { + return NPerBlock_ / (NWaves * NPerXDL_); + } + else + { + return 0; + } + } +} + +#define GET_NXDL_PER_WAVE_IMPL \ + template \ + static constexpr auto GetNXdlPerWave() \ + { \ + return GetNXdlPerWave2(); \ + } + +#define INVOKER_RUN_IMPL \ + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \ + { \ + if(get_warp_size() == 64) \ + { \ + if constexpr(NXdlPerWave64 > 0) \ + { \ + return RunImp(arg, stream_config); \ + } \ + } \ + else \ + { \ + if constexpr(NXdlPerWave32 > 0) \ + { \ + return RunImp(arg, stream_config); \ + } \ + } \ + return 0; \ + } + +#define INVOKER_RUN3_IMPL \ + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \ + { \ + if(get_warp_size() == 64) \ + { \ + if constexpr(NXdlPerWave64 > 0) \ + { \ + return RunImp(arg, stream_config); \ + } \ + } \ + else \ + { \ + if constexpr(NXdlPerWave32 > 0) \ + { \ + return RunImp( \ + reinterpret_cast(arg), \ + stream_config); \ + } \ + } \ + return 0; \ + } + +template +__device__ static bool constexpr IsValidGemmCompilationParameter() +{ +#if defined(__gfx11__) || defined(__gfx12__) + if constexpr(MPerXdl != 16 || NPerXdl != 16) + { + return false; + } +#endif + +#if defined(__gfx11__) + constexpr bool SupportMemOp = CGlobalMemoryDataOperation_ == InMemoryDataOperationEnum::Set; +#else + constexpr bool SupportMemOp = + sizeof(CDataType) >= 2 || (CGlobalMemoryDataOperation_ == InMemoryDataOperationEnum::Set); +#endif + if constexpr(SupportMemOp == false) + { + return false; + } + + if constexpr(MXdlPerWave > 0 && NXdlPerWave > 0) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + if constexpr(MWaves > 0 && NWaves > 0) + { + constexpr index_t WaveSize = BlockSize / (MWaves * NWaves); + return WaveSize == get_warp_size(); + } + } + return false; +} + +#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_) \ + template \ + __device__ static bool constexpr IsValidCompilationParameter() \ + { \ + return ck::tensor_operation::device::IsValidGemmCompilationParameter< \ + BlockSize, \ + MPerBlock, \ + NPerBlock, \ + MPerXdl, \ + NPerXdl, \ + MXdlPerWave, \ + NXdlPerWave, \ + CDataType_, \ + CGlobalMemoryDataOperation_>(); \ + } + #ifndef CK_CODE_GEN_RTC struct BaseArgument { diff --git a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index c71153768d..157e475267 100644 --- a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -94,79 +94,83 @@ __device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle( const Block2ETileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if defined(__gfx9__) - // offset base pointer for each work-group - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / 
num_blocks_per_batch); - - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); - const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); - - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - DsPointer p_ds_grid_grp; - - static constexpr index_t NumDTensor = - DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); - - static_for<0, NumDTensor, 1>{}( - [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - - if constexpr(isMultiA || isMultiB) +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - AsPointer p_as_grid_grp; - BsPointer p_bs_grid_grp; + // offset base pointer for each work-group + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const auto& as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); - static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size(); - static_for<0, NumATensor, 1>{}( - [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; }); + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - const auto& bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx); + DsPointer p_ds_grid_grp; - static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size(); - static_for<0, NumBTensor, 1>{}( - [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; }); + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); - GridwiseGemm::template Run( - p_as_grid_grp, - p_bs_grid_grp, - p_ds_grid_grp, - p_e_grid + e_batch_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock_, - block_2_ctile_map); - } - else - { - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - GridwiseGemm::template Run( - p_as_grid + a_batch_offset, - p_bs_grid + b_batch_offset, - p_ds_grid_grp, - p_e_grid + e_batch_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock_, - block_2_ctile_map); + if constexpr(isMultiA || isMultiB) + { + AsPointer p_as_grid_grp; + BsPointer p_bs_grid_grp; + + const auto& as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx); + + static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size(); + static_for<0, NumATensor, 1>{}( + [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; }); + + const auto& bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx); + + static constexpr 
index_t NumBTensor = BGridDesc_BK0_N_BK1::Size(); + static_for<0, NumBTensor, 1>{}( + [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; }); + + GridwiseGemm::template Run( + p_as_grid_grp, + p_bs_grid_grp, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); + } + else + { + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + + GridwiseGemm::template Run( + p_as_grid + a_batch_offset, + p_bs_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); + } } #else ignore = p_as_grid; @@ -353,6 +357,9 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ComputeDataType> { using DeviceOp = CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr bool isMultiA = is_detected::value; static constexpr bool isMultiB = is_detected::value; @@ -470,10 +477,13 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched // Use appropriate gridwise gemm - using GridwiseGemm = ck::conditional_t< + template + using GridwiseGemmBase = ck::conditional_t< isMultiA || isMultiB, GridwiseGemmMultipleABD_xdl_cshuffle, GridwiseGemmMultipleD_xdl_cshuffle>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // If ADataTypes or BDataTypes is tuple, user has to pass ck::Array with pointers. using APointers = ck::conditional_t&, const void*>; @@ -481,31 +491,74 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not // in initializer list what is required for single const pointer). 
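
The NXdlPerWave64/NXdlPerWave32 constants above come from GetNXdlPerWave2 in device_base.hpp: the block's wave count (BlockSize divided by the wave size) is split between M and N, and the helper back-solves the N repeat count, returning 0 when NPerBlock cannot be tiled evenly at that wave size. A constexpr re-derivation with one concrete tile (a sketch mirroring the helper, not the CK code itself):

    // Returns 0 when the tile cannot be covered at this wave size; the 0 later
    // disables that branch via `if constexpr(NXdlPerWave > 0)`.
    constexpr int nxdl_per_wave(bool wave64, int block_size, int m_per_block, int n_per_block,
                                int m_xdl_per_wave, int m_per_xdl, int n_per_xdl)
    {
        const int waves  = block_size / (wave64 ? 64 : 32);
        const int mwaves = m_per_block / (m_xdl_per_wave * m_per_xdl);
        const int nwaves = waves / mwaves;
        if(nwaves == 0 || n_per_block % (n_per_xdl * nwaves) != 0)
            return 0;
        return n_per_block / (nwaves * n_per_xdl);
    }
    // 256 threads, 128x128 block tile, 16x16 XDL, MXdlPerWave = 2:
    static_assert(nxdl_per_wave(true, 256, 128, 128, 2, 16, 16) == 8);  // wave64: 4 waves, NWaves = 1
    static_assert(nxdl_per_wave(false, 256, 128, 128, 2, 16, 16) == 4); // wave32: 8 waves, NWaves = 2
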
using AGridPointer = remove_cvref_t< - decltype(GetAGridPointer < isMultiA || isMultiB, GridwiseGemm, ADataType > ())>; + decltype(GetAGridPointer < isMultiA || isMultiB, GridwiseGemm64, ADataType > ())>; using BGridPointer = remove_cvref_t< - decltype(GetBGridPointer < isMultiA || isMultiB, GridwiseGemm, BDataType > ())>; + decltype(GetBGridPointer < isMultiA || isMultiB, GridwiseGemm64, BDataType > ())>; // desc for blockwise copy using AGridDesc_AK0_M_AK1 = - remove_cvref_t; using BGridDesc_BK0_N_BK1 = - remove_cvref_t; using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))>; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; // block-to-e-tile map using Block2ETileMap = - remove_cvref_t; + remove_cvref_t; // Argument struct Argument { + template + __host__ __device__ void init_ds_e_grid_desc() + { + if constexpr(isMultiA || isMultiB) + { + const auto as_grid_desc_ak0_m_ak1 = + generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number{}); + const auto bs_grid_desc_bk0_n_bk1 = + generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number{}); + + if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + } + else + { + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + } + } __device__ __host__ Argument( APointers p_as, BPointers p_bs, @@ -549,12 +602,12 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle e_grid_desc_m_n_{ DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, a_grid_desc_ak0_m_ak1_{ - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, b_grid_desc_bk0_n_bk1_{ - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, compute_ptr_offset_of_batch_{}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, @@ -655,43 +708,18 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; // populate desc for Ds/E - if constexpr(isMultiA || isMultiB) + if(get_warp_size() == 64) { - const auto as_grid_desc_ak0_m_ak1 = - generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number{}); - const auto bs_grid_desc_bk0_n_bk1 = - generate_tuple([&](auto) { return b_grid_desc_n_k_; 
}, Number{}); - - if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1, - bs_grid_desc_bk0_n_bk1, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_)) + if constexpr(NXdlPerWave64 > 0) { - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); - - ds_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n_); + init_ds_e_grid_desc(); } } else { - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, - b_grid_desc_n_k_, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_)) + if constexpr(NXdlPerWave32 > 0) { - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); - - ds_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n_); + init_ds_e_grid_desc(); } } } @@ -700,7 +728,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // pointers (tuple if multi AB, pointer if no) AGridPointer p_as_grid_; BGridPointer p_bs_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; + typename GridwiseGemm64::DsGridPointer p_ds_grid_; EDataType* p_e_grid_; // tensor descriptors for problem definiton @@ -746,7 +774,31 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ck::Array input_left_pads_; ck::Array input_right_pads_; }; - + template + static __device__ __host__ bool check_gemm_validity(const Argument& arg) + { + if constexpr(isMultiA || isMultiB) + { + // Genarate tuples with the same descriptors + const auto as_grid_desc_ak0_m_ak1 = + generate_tuple([&](auto) { return arg.a_grid_desc_m_k_; }, Number{}); + const auto bs_grid_desc_bk0_n_bk1 = + generate_tuple([&](auto) { return arg.b_grid_desc_n_k_; }, Number{}); + return GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + else + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } static __device__ __host__ bool IsSupportedArgument(const Argument& arg) { namespace ctc = tensor_layout::convolution; @@ -898,27 +950,21 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } // check Gridwise GEMM - if constexpr(isMultiA || isMultiB) + if(get_warp_size() == 64) { - // Genarate tuples with the same descriptors - const auto as_grid_desc_ak0_m_ak1 = - generate_tuple([&](auto) { return arg.a_grid_desc_m_k_; }, Number{}); - const auto bs_grid_desc_bk0_n_bk1 = - generate_tuple([&](auto) { return arg.b_grid_desc_n_k_; }, Number{}); - return GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1, - bs_grid_desc_bk0_n_bk1, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); + if constexpr(NXdlPerWave64 > 0) + { + return check_gemm_validity(arg); + } } else { - return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); + if constexpr(NXdlPerWave32 > 0) + { + return check_gemm_validity(arg); + } } + return false; } static __device__ __host__ auto MakeArgument( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp index f59ea3efde..be09d1b505 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp @@ -56,44 +56,48 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); - const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); - FloatDsPointer p_ds_grid_grp; + FloatDsPointer p_ds_grid_grp; - static constexpr index_t NumDTensor = - DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); - static_for<0, NumDTensor, 1>{}( - [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - GridwiseGemm::template Run( - p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_ds_grid_grp, - p_e_grid + e_batch_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -214,6 +218,10 @@ struct 
DeviceBatchedContractionMultipleD_Xdl_CShuffle { using DeviceOp = DeviceBatchedContractionMultipleD_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; @@ -546,7 +554,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle using ComputeDataType = ADataType; // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + template + using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, @@ -567,7 +576,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -589,28 +598,49 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // desc for blockwise copy using AGridDesc_AK0_M_AK1 = - remove_cvref_t; using BGridDesc_BK0_N_BK1 = - remove_cvref_t; using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))>; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; // block-to-e-tile map using Block2ETileMap = - remove_cvref_t; + remove_cvref_t; // Argument struct Argument : public BaseArgument { + template + void init_ds_e_grid_desc() + { + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + } + Argument(const void* p_a_grid, const void* p_b_grid, std::array p_ds_grid, @@ -642,12 +672,12 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle e_grid_desc_g_m_n_{ DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)}, a_grid_desc_ak0_m_ak1_{ - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, b_grid_desc_bk0_n_bk1_{ - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, cde_element_op_{cde_element_op}, @@ -677,19 +707,19 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle }); // populate desc for Ds/E - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, - b_grid_desc_n_k_, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_)) + if(get_warp_size() == 64) { - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - 
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); - - ds_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n_); + if constexpr(NXdlPerWave64 > 0) + { + init_ds_e_grid_desc(); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + init_ds_e_grid_desc(); + } } // for sanity check of vector memory access @@ -719,7 +749,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle // pointers const ADataType* p_a_grid_; const BDataType* p_b_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; + typename GridwiseGemm64::DsGridPointer p_ds_grid_; EDataType* p_e_grid_; // tensor descriptors for problem definiton @@ -767,7 +797,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, arg.b_grid_desc_n_k_, @@ -836,6 +867,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle return launch_kernel(integral_constant{}); } } + INVOKER_RUN_IMPL // polymorphic float Run(const BaseArgument* p_arg, @@ -847,16 +879,35 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_)) + bool valid = false; + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + valid = GridwiseGemm64::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + valid = GridwiseGemm32::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + if(!valid) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp index 8a8cf54e42..3d232b50bb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
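
The Invoker bodies stamped out by INVOKER_RUN_IMPL, as used above, pair a runtime warp-size query with compile-time pruning: the wave-32 call only exists in the binary when NXdlPerWave32 > 0, and likewise for wave-64, so an instance that fits only one wave size still compiles everywhere. A minimal sketch of that dispatch shape (hypothetical names, not the macro itself):

    template <int NXdlPerWave64, int NXdlPerWave32, typename Run64, typename Run32>
    float run_for_warp_size(int warp_size, Run64&& run64, Run32&& run32)
    {
        if(warp_size == 64)
        {
            if constexpr(NXdlPerWave64 > 0) // wave64 instantiation is viable
                return run64();
        }
        else
        {
            if constexpr(NXdlPerWave32 > 0) // wave32 instantiation is viable
                return run32();
        }
        return 0.f; // mirrors the macro's fall-through when no variant fits
    }
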
#pragma once #include @@ -74,34 +76,38 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map) { -#if defined(__gfx9__) - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - ck::Tuple<>{}, - p_e_grid + e_batch_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ck::Tuple<>{}, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + ck::Tuple<>{}, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ck::Tuple<>{}, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -172,6 +178,10 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -334,7 +344,8 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute + using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, @@ -359,7 +370,7 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using AGridDesc_AK0_M_AK1 = remove_cvref_t{}, - e_grid_desc_m_n_, - block_2_etile_map_)) - { - e_grid_desc_mblock_mperblock_nblock_nperblock = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); - } } void Print() const @@ -499,7 +501,9 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = 
StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, arg.b_grid_desc_n_k_, @@ -511,7 +515,9 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp index b23d864f5c..0ddcd63b2e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -59,36 +59,40 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); - const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_b1_grid + b1_batch_offset, - p_c_grid + c_batch_offset, - p_shared, - a_element_op, - b_element_op, - acc_element_op, - b1_element_op, - c_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - b1_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_ctile_map); + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_b1_grid + b1_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b1_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + 
block_2_ctile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -185,6 +189,10 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm(); + static constexpr auto MXdlPerWave32 = + GetNXdlPerWave2(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -346,7 +354,8 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm + using GridwiseGemmBase = GridwiseBatchedGemmGemm_Xdl_CShuffle< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, CShuffleDataType, @@ -373,7 +382,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -440,8 +451,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!DeviceOp::IsSupportedArgument(arg)) { throw std::runtime_error("wrong! unsupported argument"); } - + auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.c_grid_desc_m_n_); const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_; @@ -551,7 +552,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm 0) + { + return RunImp(arg, stream_config); + } + } + else + { + if constexpr(MXdlPerWave32 > 0) + { + return RunImp(arg, stream_config); + } + } + return 0; + } // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -587,11 +605,10 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm()) { return false; } - // Note: we need raw lengths since threadwise copy can not handle vector load when part of // vector is out of bounds const auto MRaw = arg.raw_lengths_m_n_k_o_[0]; @@ -617,11 +634,29 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + else + { + if constexpr(MXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp index 1f8c6b1508..9aff562744 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp @@ -82,45 +82,48 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map) { +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); -#if defined(__gfx9__) - const index_t num_blocks_per_batch = - 
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); - const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + DsPointer p_ds_grid_grp; - DsPointer p_ds_grid_grp; + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); - static constexpr index_t NumDTensor = - DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - static_for<0, NumDTensor, 1>{}( - [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - - GridwiseGemm::template Run( - p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_ds_grid_grp, - p_e_grid + e_batch_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock_, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_etile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -197,6 +200,10 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; @@ -326,7 +333,8 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD + using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype BDataType, ComputeDataType, @@ -347,7 +355,7 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // desc for blockwise copy using AGridDesc_AK0_M_AK1 = - remove_cvref_t; using BGridDesc_BK0_N_BK1 = - remove_cvref_t; using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( 
DsGridDesc_M_N{}))>; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; // block-to-e-tile map using Block2ETileMap = - remove_cvref_t; + remove_cvref_t; // Argument struct Argument : public BaseArgument { + template + void init_ds_e_grid_desc() + { + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + } + } Argument(const void* p_a_grid, const void* p_b_grid, std::array p_ds_grid, @@ -420,13 +448,13 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD(MRaw, NRaw, StrideE)}, a_grid_desc_ak0_m_ak1_{ - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, b_grid_desc_bk0_n_bk1_{ - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{}, compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideDs, BatchStrideE}, - block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, cde_element_op_{cde_element_op} @@ -445,19 +473,19 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD 0) + { + init_ds_e_grid_desc(); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + init_ds_e_grid_desc(); + } } } @@ -474,7 +502,7 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, arg.b_grid_desc_n_k_, @@ -577,6 +606,8 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD()) { return false; } - - return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp index ea5668d765..bb4af1a8df 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
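Every device op touched by this patch repeats the dual-instantiation pattern visible in the hunks above: the gridwise gemm becomes a template alias over the per-wave NXdlPerWave parameter, it is instantiated once for wave64 and once for wave32, and every entry point dispatches on get_warp_size() behind an if constexpr guard, so inside these class templates the branch for an unusable instantiation is discarded at compile time. The following minimal C++ sketch is illustrative only; GridwiseStub, the placeholder NXdlPerWave values, and get_warp_size_stub() are stand-ins, not code from the patch.

// Stand-ins: the real values come from GetNXdlPerWave (device_base.hpp),
// which yields 0 when no valid instantiation exists for that wave size.
static constexpr int NXdlPerWave64 = 2;
static constexpr int NXdlPerWave32 = 0;

inline int get_warp_size_stub() { return 64; } // real code: get_warp_size()

template <int NXdlPerWave_>
struct GridwiseStub // stands in for GridwiseGemmMultipleD_xdl_cshuffle etc.
{
    static bool CheckValidity() { return NXdlPerWave_ > 0; }
};

template <int NXdlPerWave_>
using GridwiseGemmBase = GridwiseStub<NXdlPerWave_>;
using GridwiseGemm64   = GridwiseGemmBase<NXdlPerWave64>;
using GridwiseGemm32   = GridwiseGemmBase<NXdlPerWave32>;

// Shape shared by Run, IsSupportedArgument and the Argument constructor:
// pick the instantiation matching the runtime wave size, but only if it
// exists; otherwise report the argument as unsupported.
inline bool IsSupportedSketch()
{
    if(get_warp_size_stub() == 64)
    {
        if constexpr(NXdlPerWave64 > 0)
            return GridwiseGemm64::CheckValidity();
    }
    else
    {
        if constexpr(NXdlPerWave32 > 0)
            return GridwiseGemm32::CheckValidity();
    }
    return false;
}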
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -33,7 +33,7 @@ template ()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); - const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); - static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) { - const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In))); - p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset; - }); + static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) { + const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In))); + p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset; + }); - static_for<0, p_d1s_grid.Size(), 1>{}([&](auto In) { - const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetD1BasePtr(g_idx, In))); - p_d1s_grid(In) = p_d1s_grid(In) + d1_batch_offset; - }); + static_for<0, p_d1s_grid.Size(), 1>{}([&](auto In) { + const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetD1BasePtr(g_idx, In))); + p_d1s_grid(In) = p_d1s_grid(In) + d1_batch_offset; + }); - GridwiseGemm::template Run(p_a0_grid + a_batch_offset, - p_b0_grid + b_batch_offset, - p_d0s_grid, - p_b1_grid + b1_batch_offset, - p_d1s_grid, - p_e1_grid + c_batch_offset, - p_shared, - a0_element_op, - b0_element_op, - cde0_element_op, - b1_element_op, - cde1_element_op, - a0_grid_desc_ak0_m_ak1, - b0_grid_desc_bk0_n_bk1, - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, - b1_grid_desc_bk0_n_bk1, - d1s_grid_desc_mblock_mperblock_nblock_nperblock, - e1_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_e1tile_map); + GridwiseGemm::template Run( + p_a0_grid + a_batch_offset, + p_b0_grid + b_batch_offset, + p_d0s_grid, + p_b1_grid + b1_batch_offset, + p_d1s_grid, + p_e1_grid + c_batch_offset, + p_shared, + a0_element_op, + b0_element_op, + cde0_element_op, + b1_element_op, + cde1_element_op, + a0_grid_desc_ak0_m_ak1, + b0_grid_desc_bk0_n_bk1, + d0s_griddesc_m_n, + b1_grid_desc_bk0_n_bk1, + d1s_grid_desc_mblock_mperblock_nblock_nperblock, + e1_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_e1tile_map); + } #else ignore = p_a0_grid; ignore = p_b0_grid; @@ 
-129,7 +133,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) ignore = cde1_element_op; ignore = a0_grid_desc_ak0_m_ak1; ignore = b0_grid_desc_bk0_n_bk1; - ignore = d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5; + ignore = d0s_griddesc_m_n; ignore = b1_grid_desc_bk0_n_bk1; ignore = d1s_grid_desc_mblock_mperblock_nblock_nperblock; ignore = e1_grid_desc_mblock_mperblock_nblock_nperblock; @@ -231,6 +235,21 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle { using DeviceOp = DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle; + static constexpr auto Gemm0MXdlPerWave64 = GetNXdlPerWave2(); + static constexpr auto Gemm0MXdlPerWave32 = GetNXdlPerWave2(); + static constexpr index_t NumD0Tensor = D0sDataType::Size(); static constexpr index_t NumD1Tensor = D1sDataType::Size(); @@ -443,7 +462,8 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle using E1GridDesc_M_N = decltype(MakeE1GridDescriptor_M_N(1, 1, 1)); // GridwiseGemm - using GridwiseGemm = GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< + template + using GridwiseGemmBase = GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0DataType, // TODO: distinguish A/B datatype Acc0DataType, D0sDataType, @@ -475,7 +495,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle B1K1, Gemm0MPerXdl, Gemm0NPerXdl, - Gemm0MXdlPerWave, + Gemm0MXdlPerWave_, Gemm0NXdlPerWave, Gemm1NXdlPerWave, A0BlockTransferThreadClusterLengths_AK0_M_AK1, @@ -509,15 +529,17 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using A0GridDesc_AK0_M_AK1 = - remove_cvref_t; using B0GridDesc_BK0_N_BK1 = - remove_cvref_t; using B1GridDesc_BK0_N_BK1 = - remove_cvref_t; // Argument @@ -565,15 +587,12 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle e1_grid_desc_m_n_{ DeviceOp::MakeE1GridDescriptor_M_N(MRaw, Gemm1NRaw, StrideE1)}, a0_grid_desc_ak0_m_ak1_{ - GridwiseGemm::MakeDefaultA0GridDescriptor_AK0_M_AK1(a0_grid_desc_m_k_)}, + GridwiseGemm64::MakeDefaultA0GridDescriptor_AK0_M_AK1(a0_grid_desc_m_k_)}, b0_grid_desc_bk0_n_bk1_{ - GridwiseGemm::MakeDefaultB0GridDescriptor_BK0_N_BK1(b0_grid_desc_n_k_)}, - d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{}, + GridwiseGemm64::MakeDefaultB0GridDescriptor_BK0_N_BK1(b0_grid_desc_n_k_)}, b1_grid_desc_bk0_n_bk1_{ - GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(b1_grid_desc_n_k_)}, - d1s_grid_desc_mblock_mperblock_nblock_nperblock_{}, - e1_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_e1tile_map_{GridwiseGemm::MakeDefaultBlock2E1TileMap(e1_grid_desc_m_n_)}, + GridwiseGemm64::MakeDefaultB1GridDescriptor_BK0_N_BK1(b1_grid_desc_n_k_)}, + block_2_e1tile_map_{GridwiseGemm64::MakeDefaultBlock2E1TileMap(e1_grid_desc_m_n_)}, a0_element_op_{a0_element_op}, b0_element_op_{b0_element_op}, cde0_element_op_{cde0_element_op}, @@ -597,18 +616,6 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle << ", " << d0s_grid_desc_m_n_[I0].GetLength(I1) << "}" << std::endl; std::cout << "b1_grid_desc_n_k_{" << b1_grid_desc_n_k_.GetLength(I0) << ", " << b1_grid_desc_n_k_.GetLength(I1) << "}" << std::endl; - std::cout << "d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{" - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I0) << ", " - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I1) << ", " - << 
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I2) << ", " - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I3) << ", " - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I4) << ", " - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I5) << ", " - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I6) << ", " - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I7) << ", " - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I8) << ", " - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I9) << "}" - << std::endl; std::cout << "e1_grid_desc_m_n_{" << e1_grid_desc_m_n_.GetLength(I0) << ", " << e1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } @@ -636,34 +643,15 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle d1s_grid_desc_m_n_(i) = DeviceOp::MakeE1GridDescriptor_M_N(MRaw, Gemm1NRaw, StrideD1s[i]); }); - - if(GridwiseGemm::CheckValidity(a0_grid_desc_m_k_, - b0_grid_desc_n_k_, - b1_grid_desc_n_k_, - e1_grid_desc_m_n_, - block_2_e1tile_map_)) - { - e1_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e1_grid_desc_m_n_); - - d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ = - GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5( - d0s_grid_desc_m_n_); - - d1s_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - d1s_grid_desc_m_n_); - } } // private: // pointers const A0DataType* p_a0_grid_; const B0DataType* p_b0_grid_; - typename GridwiseGemm::D0sGridPointer p_d0s_grid_; + typename GridwiseGemm64::D0sGridPointer p_d0s_grid_; const B1DataType* p_b1_grid_; - typename GridwiseGemm::D1sGridPointer p_d1s_grid_; + typename GridwiseGemm64::D1sGridPointer p_d1s_grid_; E1DataType* p_e1_grid_; // tensor descriptors for problem definiton @@ -677,16 +665,10 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle // tensor descriptors for block/thread-wise copy A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1_; B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1_; - typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 - d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; - typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - d1s_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e1_grid_desc_mblock_mperblock_nblock_nperblock_; // block-to-e1-tile map - typename GridwiseGemm::DefaultBlock2E1TileMap block_2_e1tile_map_; + typename GridwiseGemm64::DefaultBlock2E1TileMap block_2_e1tile_map_; // element-wise op A0ElementwiseOperation a0_element_op_; @@ -705,7 +687,8 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_, arg.b0_grid_desc_n_k_, @@ -716,6 +699,14 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); } + auto e1_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.e1_grid_desc_m_n_); + + auto d1s_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.d1s_grid_desc_m_n_); + const index_t grid_size = arg.block_2_e1tile_map_.CalculateGridSize(arg.e1_grid_desc_m_n_) * arg.batch_count_; @@ -736,7 +727,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle CDE1ElementwiseOperation, DeviceOp::A0GridDesc_AK0_M_AK1, DeviceOp::B0GridDesc_BK0_N_BK1, - typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, + DeviceOp::D0sGridDesc_M_N, DeviceOp::B1GridDesc_BK0_N_BK1, typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, @@ -762,10 +753,10 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle arg.cde1_element_op_, arg.a0_grid_desc_ak0_m_ak1_, arg.b0_grid_desc_bk0_n_bk1_, - arg.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, + arg.d0s_grid_desc_m_n_, arg.b1_grid_desc_bk0_n_bk1_, - arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.e1_grid_desc_mblock_mperblock_nblock_nperblock_, + d1s_grid_desc_mblock_mperblock_nblock_nperblock, + e1_grid_desc_mblock_mperblock_nblock_nperblock, arg.block_2_e1tile_map_, arg.batch_count_, arg.compute_base_ptr_of_batch_); @@ -783,6 +774,25 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle } } + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(get_warp_size() == 64) + { + if constexpr(Gemm0MXdlPerWave64 > 0) + { + return RunImp(arg, stream_config); + } + } + else + { + if constexpr(Gemm0MXdlPerWave32 > 0) + { + return RunImp(arg, stream_config); + } + } + return 0; + } + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -813,7 +823,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } @@ -836,11 +846,29 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle return false; } - return GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_, - arg.b0_grid_desc_n_k_, - arg.b1_grid_desc_n_k_, - arg.e1_grid_desc_m_n_, - arg.block_2_e1tile_map_); + if(get_warp_size() == 64) + { + if constexpr(Gemm0MXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a0_grid_desc_m_k_, + arg.b0_grid_desc_n_k_, + arg.b1_grid_desc_n_k_, + arg.e1_grid_desc_m_n_, + arg.block_2_e1tile_map_); + } + } + else + { + if constexpr(Gemm0MXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a0_grid_desc_m_k_, + arg.b0_grid_desc_n_k_, + arg.b1_grid_desc_n_k_, + arg.e1_grid_desc_m_n_, + arg.block_2_e1tile_map_); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp index 64d5fbd509..7c8c93cd91 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -37,35 +37,38 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif 
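The two v3 multi-d kernels that follow show the second half of the scheme, applied at the kernel itself: the architecture gate is widened from gfx9-only to gfx9/gfx11/gfx12, and the body is wrapped in if constexpr(IsValidCompilationParameter...), so an instantiation whose tile parameters cannot run on the compile target collapses to an empty kernel instead of failing to build. Below is a trimmed sketch of that guard shape only; Arg is a placeholder and the body is elided, since the full pointer and descriptor handling appears in the real kernels that follow.

// Trimmed sketch of the kernel-side guard; Arg is a placeholder type and
// the elided body stands for the batch-offset and Run(...) code below.
template <typename GridwiseGemm, typename Arg>
__global__ void kernel_guard_sketch(Arg karg)
{
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
    // Discarded at compile time when this instantiation is invalid for the
    // target, e.g. a wave64-only tile shape compiled for wave32 gfx11.
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
    {
        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
        // ... compute batch/split-k offsets, then GridwiseGemm::Run(...) ...
        (void)p_shared;
    }
#else
    ignore = karg; // ck::ignore, as in the real kernels: keep karg referenced
#endif
}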
kernel_batched_gemm_xdl_cshuffle_v3_multi_d(BatchedGemmArg karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - const index_t g_idx = blockIdx.z % karg.Batch; - const index_t k_idx = blockIdx.z / karg.Batch; + const index_t g_idx = blockIdx.z % karg.Batch; + const index_t k_idx = blockIdx.z / karg.Batch; - const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); - const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); - const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); - const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); - // populate pointer, desc for Ds - static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { - // D pointer - karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i]; - }); + // populate pointer, desc for Ds + static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { + // D pointer + karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i]; + }); - GridwiseGemm::template Run( - karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid + c_batch_offset, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid + c_batch_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -83,39 +86,42 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds(BatchedGemmArg karg) { -#if defined(__gfx9__) - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // Passing two lds pointers is the key to telling the compiler that ds_read/write + // operate on different lds chunks at the same time without order dependency + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - const index_t g_idx = blockIdx.z % karg.Batch; - const index_t k_idx
= blockIdx.z / karg.Batch; - const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); - const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); - const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); - const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); - // populate pointer, desc for Ds - static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { - // D pointer - karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i]; - }); + // populate pointer, desc for Ds + static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { + // D pointer + karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i]; + }); - GridwiseGemm::template Run_2Lds( - karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid + c_batch_offset, - p_shared_0, - p_shared_1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid + c_batch_offset, + p_shared_0, + p_shared_1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -185,12 +191,17 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 BElementwiseOperation, CElementwiseOperation> { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr index_t NumDTensor = DsDataType::Size(); using CDEShuffleBlockTransferScalarPerVectors_ = CDEShuffleBlockTransferScalarPerVectors; using CDataType_ = CDataType; // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3< + template + using GridwiseGemmBase = GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, @@ -214,7 +225,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -241,6 +252,9 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ComputeTypeB, LDSTypeA, LDSTypeB>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm = GridwiseGemm64; struct ComputePtrOffsetOfStridedBatch { @@ -270,7 +284,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { std::array ds_offset_; - static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { + static_for<0, NumDTensor, 1>{}([&](auto i) { ds_offset_[i] = static_cast(BatchStrideDs_[i]) * g_idx; }); @@ -289,32 +303,33 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 index_t BatchStrideC_; }; - struct Argument : public GridwiseGemm::Argument + template + struct ArgumentBase : public 
GridwiseGemm::Argument { index_t Batch; ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch; - Argument() = default; - Argument(const ADataType* p_a_grid_, - const BDataType* p_b_grid_, - std::array p_ds_grid_, - CDataType* p_e_grid_, - index_t M_, - index_t N_, - index_t K_, - index_t StrideA_, - index_t StrideB_, - std::array StrideDs_, - index_t StrideE_, - index_t BatchStrideA_, - index_t BatchStrideB_, - const std::array& BatchStrideDs_, - index_t BatchStrideE_, - index_t Batch_, - AElementwiseOperation a_element_op_, - BElementwiseOperation b_element_op_, - CElementwiseOperation c_element_op_, - index_t KBatch_) + ArgumentBase() = default; + ArgumentBase(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + std::array p_ds_grid_, + CDataType* p_e_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideE_, + index_t BatchStrideA_, + index_t BatchStrideB_, + const std::array& BatchStrideDs_, + index_t BatchStrideE_, + index_t Batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_, + index_t KBatch_) : GridwiseGemm::Argument{p_a_grid_, p_b_grid_, p_ds_grid_, @@ -336,6 +351,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { } }; + using Argument = ArgumentBase; struct ActiveWorkgroupsPerCU { @@ -359,31 +375,69 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 } }(); - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + if(get_warp_size() == 64) { - hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( - &max_occupancy, - kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds< - GridwiseGemm, - Argument, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>, - BlockSize, - dynamic_smem_size)); + if constexpr(NXdlPerWave64 > 0) + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm64, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>, + BlockSize, + dynamic_smem_size)); + } + else + { + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm64, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>, + BlockSize, + dynamic_smem_size)); + } + } } else { - hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( - &max_occupancy, - kernel_batched_gemm_xdl_cshuffle_v3_multi_d< - GridwiseGemm, - Argument, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>, - BlockSize, - dynamic_smem_size)); + if constexpr(NXdlPerWave32 > 0) + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm32, + ArgumentBase, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>, + BlockSize, + dynamic_smem_size)); + } + else + { + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm32, + ArgumentBase, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>, + BlockSize, + dynamic_smem_size)); + } + } } max_occupancy_ = std::max(1, max_occupancy); @@ -394,14 +448,16 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 // Invoker 
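In the invoker that follows, RunImp carries the templated body while the public Run comes from the INVOKER_RUN_IMPL macro added in device_base.hpp. The macro's definition is not visible in this part of the patch; judging from the hand-written Run methods earlier in the series, it is expected to expand to a warp-size dispatch of the following shape (member-function sketch, names as introduced by this patch):

// Presumed expansion of INVOKER_RUN_IMPL; NXdlPerWave64/NXdlPerWave32 and
// GridwiseGemm64/GridwiseGemm32 are the members this patch adds per op.
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
    if(get_warp_size() == 64) // wave64 targets (gfx9)
    {
        if constexpr(NXdlPerWave64 > 0)
        {
            return RunImp<GridwiseGemm64>(arg, stream_config);
        }
    }
    else // wave32 targets (gfx11/gfx12)
    {
        if constexpr(NXdlPerWave32 > 0)
        {
            return RunImp<GridwiseGemm32>(arg, stream_config);
        }
    }
    return 0; // no valid instantiation for this wave size
}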
struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { + using BatchGemmArgument = ArgumentBase; if(stream_config.log_level_ > 0) { arg.Print(); } - if(!GridwiseGemm::CheckValidity(arg)) + if(!GridwiseGemm::CheckValidity(reinterpret_cast(arg))) { throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); } @@ -423,7 +479,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 std::array DsSize; - Argument arg_ = arg; + BatchGemmArgument arg_ = reinterpret_cast(arg); const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); @@ -442,8 +498,12 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 using DDataType = remove_cvref_t>; DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); }); - ck::utility::RotatingMemWrapperMultiD rotating_mem( - arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + ck::utility::RotatingMemWrapperMultiD + rotating_mem(arg_, + stream_config.rotating_count, + size_a_buffer, + size_b_buffer, + DsSize); rotating_mem.Print(); auto run_flush_cache = [&]() { @@ -480,13 +540,14 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 stream_config.stream_id_)); }; - ave_time = launch_and_time_kernel_with_preprocess(stream_config, + BatchGemmArgument arg_ = reinterpret_cast(arg); + ave_time = launch_and_time_kernel_with_preprocess(stream_config, clear_workspace, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, - arg); + arg_); } }; @@ -515,7 +576,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy>; @@ -525,7 +586,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy>; @@ -541,7 +602,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -553,7 +614,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -567,7 +628,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -583,7 +644,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -599,7 +660,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -615,7 +676,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + 
BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -630,7 +691,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -646,7 +707,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -661,7 +722,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -673,7 +734,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -687,7 +748,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -703,7 +764,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -719,7 +780,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -735,7 +796,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -750,7 +811,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -766,7 +827,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -785,7 +846,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -796,7 +857,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -810,7 +871,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -821,7 +882,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -838,7 +899,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = 
kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -849,7 +910,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -863,7 +924,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -874,7 +935,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -893,7 +954,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, false, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy>; @@ -903,7 +964,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, false, InMemoryDataOperationEnum::Set, minimum_occupancy>; @@ -915,6 +976,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 return ave_time; } + INVOKER_RUN_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -931,11 +994,14 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) + { + return false; + } + if(is_gfx11_supported() && arg.KBatch > 1) { return false; } - if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) { return false; @@ -948,8 +1014,22 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { return false; } - - return GridwiseGemm::CheckValidity(arg); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -1093,7 +1173,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 << "BlkGemmPipelineVersion: " << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " << "BlkGemmPipelinePrefetchStages: " - << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages; + << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages; // clang-format on return str.str(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp index ffebad253b..d2b77a5901 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -60,41 +60,46 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const ComputeBasePrtOfBatch compute_base_ptr_of_batch_, const Block2CTileMap block_2_ctile_map) { -#if defined(__gfx9__) - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); +#if defined(__gfx9__) || 
defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch_.GetABasePtr(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch_.GetBBasePtr(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch_.GetCBasePtr(g_idx))); + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetBBasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetCBasePtr(g_idx))); - static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) { - const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch_.GetDBasePtr(g_idx, In))); - p_reduces_grid(In) = p_reduces_grid(In) + d_batch_offset; - }); + static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) { + const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetDBasePtr(g_idx, In))); + p_reduces_grid(In) = p_reduces_grid(In) + d_batch_offset; + }); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_c_grid + c_batch_offset, - p_reduces_grid, - p_shared, - a_element_op, - b_element_op, - c_element_op, - reduce_in_element_ops, - reduce_out_element_ops, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - reduce_grid_desc_mblock_mperblock, - block_2_ctile_map); + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_reduces_grid, + p_shared, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, + block_2_ctile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -172,6 +177,9 @@ template { using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -517,7 +525,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO }; // GridwiseGemm - using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + template + using GridwiseGemmBase = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, CShuffleDataType, @@ -546,7 +555,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, 
ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -571,6 +580,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -600,32 +611,18 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)}, - c_grid_desc_mblock_mperblock_nblock_nperblock_{}, - reduce_grid_desc_mblock_mperblock_{}, compute_base_ptr_of_batch_{ type_convert(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()), type_convert(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()), type_convert(c_grid_desc_m_n_.GetElementSpaceSize()), type_convert(reduce_grid_desc_m_.GetElementSpaceSize())}, - block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + block_2_ctile_map_{GridwiseGemm64::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op}, reduce_in_element_ops_{reduce_in_element_ops}, reduce_out_element_ops_{reduce_out_element_ops} { - if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, - b_grid_desc_bk0_n_bk1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n_); - - reduce_grid_desc_mblock_mperblock_ = - GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_); - } } // private: @@ -638,12 +635,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; CGridDesc_M_N c_grid_desc_m_n_; ReduceGridDesc_M reduce_grid_desc_m_; - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock - reduce_grid_desc_mblock_mperblock_; ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + typename GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map_; AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; CElementwiseOperation c_element_op_; @@ -656,8 +649,24 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.c_grid_desc_m_n_); + + auto reduce_grid_desc_mblock_mperblock = + GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(arg.reduce_grid_desc_m_); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { { @@ -680,15 +689,6 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO << arg.reduce_grid_desc_m_.GetLength(I0) << "}" << std::endl; } } - - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_)) - { - throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); - } - const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_; @@ -716,28 +716,27 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO typename GridwiseGemm::DefaultBlock2CTileMap, true>; - elapsed_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_reduces_grid_, - arg.Batch_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.reduce_in_element_ops_, - arg.reduce_out_element_ops_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.reduce_grid_desc_mblock_mperblock_, - arg.compute_base_ptr_of_batch_, - arg.block_2_ctile_map_); + elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_reduces_grid_, + arg.Batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + c_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, + arg.compute_base_ptr_of_batch_, + arg.block_2_ctile_map_); } else { @@ -759,33 +758,34 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO typename GridwiseGemm::DefaultBlock2CTileMap, false>; - elapsed_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_reduces_grid_, - arg.Batch_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.reduce_in_element_ops_, - arg.reduce_out_element_ops_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.reduce_grid_desc_mblock_mperblock_, - arg.compute_base_ptr_of_batch_, - arg.block_2_ctile_map_); + elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_reduces_grid_, + arg.Batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + c_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, + arg.compute_base_ptr_of_batch_, + arg.block_2_ctile_map_); } return elapsed_time; } + INVOKER_RUN_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -802,15 +802,31 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public 
DeviceGemmReduce<0, ReduceO static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - - return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index d835bb6c61..745cf2b722 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -36,7 +36,7 @@ template ()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); - const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); - static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) { - const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In))); - p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset; - }); + static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) { + const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In))); + p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset; + }); - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_b1_grid + b1_batch_offset, - p_c_grid + c_batch_offset, - p_d0s_grid, - p_shared, - a_element_op, - b_element_op, - c0de_element_op, - b1_element_op, - c1de_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - b1_grid_desc_bk0_n_bk1, - c1_grid_desc_mblock_mperblock_nblock_nperblock, - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, 
- block_2_ctile_map, - c0_matrix_mask); + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_b1_grid + b1_batch_offset, + p_c_grid + c_batch_offset, + p_d0s_grid, + p_shared, + a_element_op, + b_element_op, + c0de_element_op, + b1_element_op, + c1de_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b1_grid_desc_bk0_n_bk1, + c1_grid_desc_mblock_mperblock_nblock_nperblock, + d0s_griddesc_m_n, + block_2_ctile_map, + c0_matrix_mask); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -122,7 +126,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) ignore = b_grid_desc_bk0_n_bk1; ignore = b1_grid_desc_bk0_n_bk1; ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5; + ignore = d0s_griddesc_m_n; ignore = block_2_ctile_map; ignore = batch_count; ignore = compute_base_ptr_of_batch; @@ -218,6 +222,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle C1DEElementwiseOperation, MaskingSpec> { + static constexpr auto MXdlPerWave64 = + GetNXdlPerWave2(); + static constexpr auto MXdlPerWave32 = + GetNXdlPerWave2(); + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, "Number of dimension must be greater than 0"); @@ -377,7 +386,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle }; // GridwiseGemm - using GridwiseGemm = GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< + template + using GridwiseGemmBase = GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, CShuffleDataType, @@ -406,7 +416,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle B1K1, MPerXDL, NPerXDL, - MXdlPerWave, + MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, @@ -441,6 +451,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle Transform::matrix_padder.PadN, MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument // FIXME: constness @@ -485,6 +497,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, c1_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, c_gs_ms_gemm1ns_strides)}, + d0s_grid_desc_m_n_{DeviceOp::MakeD0sGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths, + acc0_biases_gs_ms_ns_strides)}, a_grid_desc_g_m_k_{ Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, b_grid_desc_g_n_k_{ @@ -495,9 +509,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle c_gs_ms_gemm1ns_strides)}, d0s_grid_desc_g_m_n_{DeviceOp::MakeD0sGridDescriptor_G_M_N( acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)}, - c1_grid_desc_mblock_mperblock_nblock_nperblock_{}, - d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{}, - block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c1_grid_desc_m_n_)}, + block_2_ctile_map_{GridwiseGemm64::MakeDefaultBlock2CTileMap(c1_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, c0de_element_op_{c0de_element_op}, @@ -538,23 +550,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle d0s_nl_ns_lengths_strides_[i].push_back( acc0_biases_gs_ms_ns_strides[i][NumDimG + NumDimM]); }); - - if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, - b_grid_desc_bk0_n_bk1_, - b1_grid_desc_bk0_n_bk1_, - c1_grid_desc_m_n_, - 
block_2_ctile_map_)) - { - c1_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c1_grid_desc_m_n_); - - D0sGridDesc_M_N d0s_grid_desc_m_n{DeviceOp::MakeD0sGridDescriptor_M_N( - acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)}; - d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ = - GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5( - d0s_grid_desc_m_n); - } } void Print() const @@ -578,26 +573,22 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle const BDataType* p_b_grid_; const B1DataType* p_b1_grid_; CDataType* p_c_grid_; - typename GridwiseGemm::D0sGridPointer p_d0s_grid_; + typename GridwiseGemm64::D0sGridPointer p_d0s_grid_; // tensor descriptor AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; C1GridDesc_M_N c1_grid_desc_m_n_; + D0sGridDesc_M_N d0s_grid_desc_m_n_; AGridDesc_G_M_K a_grid_desc_g_m_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_; C1GridDesc_G_M_N c1_grid_desc_g_m_n_; D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_; - typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c1_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 - d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_; - // block-to-c-tile map - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + typename GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map_; // element-wise op AElementwiseOperation a_element_op_; @@ -626,12 +617,16 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!DeviceOp::IsSupportedArgument(arg)) { throw std::runtime_error("wrong! 
unsupported argument"); } + auto c1_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.c1_grid_desc_m_n_); const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.c1_grid_desc_m_n_) * arg.batch_count_; @@ -657,7 +652,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::B1GridDesc_BK0_N_BK1, typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, - typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, + D0sGridDesc_M_N, typename GridwiseGemm::DefaultBlock2CTileMap, ComputeBasePtrOfStridedBatch, C0MatrixMask, @@ -681,8 +676,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_, - arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, + c1_grid_desc_mblock_mperblock_nblock_nperblock, + arg.d0s_grid_desc_m_n_, arg.block_2_ctile_map_, arg.batch_count_, arg.compute_base_ptr_of_batch_, @@ -703,6 +698,25 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle return ave_time; } + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(get_warp_size() == 64) + { + if constexpr(MXdlPerWave64 > 0) + { + return RunImp(arg, stream_config); + } + } + else + { + if constexpr(MXdlPerWave32 > 0) + { + return RunImp(arg, stream_config); + } + } + return 0; + } + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -724,11 +738,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle arg.Print(); } - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - // TODO ANT: Check if tensor specialization & strides mismatch // Check if C permute dimension matches GEMM + GEMM shape @@ -792,12 +805,29 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle return false; } } - - return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.b1_grid_desc_bk0_n_bk1_, - arg.c1_grid_desc_m_n_, - arg.block_2_ctile_map_); + if(get_warp_size() == 64) + { + if constexpr(MXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c1_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + else + { + if constexpr(MXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c1_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp index 1345d2b782..77379c8fb1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -1,16 +1,18 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once - #ifndef __HIPCC_RTC__ #include #include +#endif + +#include "ck/utility/common_header.hpp" +#ifndef __HIPCC_RTC__ #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #endif -#include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -64,37 +66,41 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const C0MatrixMask c0_matrix_mask) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); - const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_b1_grid + b1_batch_offset, - p_c_grid + c_batch_offset, - p_shared, - a_element_op, - b_element_op, - acc_element_op, - b1_element_op, - c_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - b1_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_ctile_map, - c0_matrix_mask); + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_b1_grid + b1_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b1_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_ctile_map, + c0_matrix_mask); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -202,7 +208,12 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle CElementwiseOperation, MaskOutUpperTriangle> { + using DeviceOp = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle; + static constexpr auto MXdlPerWave64 = + GetNXdlPerWave2(); + static constexpr auto MXdlPerWave32 = + 
GetNXdlPerWave2(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -369,7 +380,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle C0MatrixMask_impl>; // GridwiseGemm - using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< + template + using GridwiseGemmBase = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, CShuffleDataType, @@ -396,7 +408,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle B1K1, MPerXDL, NPerXDL, - MXdlPerWave, + MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, @@ -430,6 +442,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle LoopSched, matrix_padder.PadN, MaskOutUpperTriangle>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; #ifndef __HIPCC_RTC__ // Argument @@ -466,8 +480,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle b1_grid_desc_bk0_n_bk1_{ DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(NRaw, Gemm1NRaw, StrideB1)}, c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, Gemm1NRaw, StrideC)}, - c_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + block_2_ctile_map_{GridwiseGemm64::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, acc_element_op_{acc_element_op}, @@ -478,16 +491,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle c0_matrix_mask_{NRaw}, raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw} { - if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, - b_grid_desc_bk0_n_bk1_, - b1_grid_desc_bk0_n_bk1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n_); - } } // private: @@ -499,9 +502,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; CGridDesc_M_N c_grid_desc_m_n_; - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + typename GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map_; AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; AccElementwiseOperation acc_element_op_; @@ -522,7 +523,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, @@ -532,7 +534,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle { throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); } - + auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.c_grid_desc_m_n_); const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_; @@ -578,7 +582,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + c_grid_desc_mblock_mperblock_nblock_nperblock, arg.block_2_ctile_map_, arg.batch_count_, arg.compute_base_ptr_of_batch_, @@ -599,6 +603,25 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle return ave_time; } + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(get_warp_size() == 64) + { + if constexpr(MXdlPerWave64 > 0) + { + return RunImp(arg, stream_config); + } + } + else + { + if constexpr(MXdlPerWave32 > 0) + { + return RunImp(arg, stream_config); + } + } + return 0; + } + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -707,11 +730,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle #ifndef __HIPCC_RTC__ static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - // Note: we need raw lengths since threadwise copy can not handle vector load when part of // vector is out of bounds const auto MRaw = arg.raw_lengths_m_n_k_o_[0]; @@ -719,12 +741,31 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle const auto KRaw = arg.raw_lengths_m_n_k_o_[2]; const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3]; - return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.b1_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_) and - IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw); + if(get_warp_size() == 64) + { + if constexpr(MXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_) and + IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw); + } + } + else + { + if constexpr(MXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_) and + IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw); + } + } + return false; } // polymorphic @@ -916,7 +957,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle using CGridDesc_M_N = remove_cvref_t; // GridwiseGemm - using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< + template + using GridwiseGemmBase = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, CShuffleDataType, @@ -943,7 +985,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle B1K1, MPerXDL, NPerXDL, - MXdlPerWave, + MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, @@ -978,13 +1020,16 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle matrix_padder.PadN, MaskOutUpperTriangle>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1; CGridDesc_M_N c_grid_desc_m_n; C0MatrixMask c0_matrix_mask; - typename 
GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map; - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + typename GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map; + typename GridwiseGemm64::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_descriptor_mblock_mperblock_nblock_nperblock; // element-wise op @@ -1008,30 +1053,54 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)}, b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)}, c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)}, - block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)}, + block_2_ctile_map{GridwiseGemm64::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)}, c_grid_descriptor_mblock_mperblock_nblock_nperblock{ - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + GridwiseGemm64::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( c_grid_desc_m_n)}, - has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop( + has_main_k_block_loop{GridwiseGemm64::CalculateHasMainKBlockLoop( a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))}, c0_matrix_mask{c.GetLength(I1)}, a_element_op{a_element_op_}, b_element_op{b_element_op_}, b1_element_op{b1_element_op_}, c_element_op{c_element_op_}, - is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - b1_grid_desc_bk0_n_bk1, - c_grid_desc_m_n, - block_2_ctile_map) and - IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1), - b_grid_desc_bk0_n_bk1.GetLength(I1), - a_grid_desc_ak0_m_ak1.GetLength(I0) * - a_grid_desc_ak0_m_ak1.GetLength(I2), - b1_grid_desc_bk0_n_bk1.GetLength(I1))} + is_valid{false} { + if(get_warp_size() == 64) + { + if constexpr(MXdlPerWave64 > 0) + { + is_valid = GridwiseGemm64::CheckValidity(a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b1_grid_desc_bk0_n_bk1, + c_grid_desc_m_n, + block_2_ctile_map) and + IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1), + b_grid_desc_bk0_n_bk1.GetLength(I1), + a_grid_desc_ak0_m_ak1.GetLength(I0) * + a_grid_desc_ak0_m_ak1.GetLength(I2), + b1_grid_desc_bk0_n_bk1.GetLength(I1)) and + GridwiseGemm64::template IsValidCompilationParameter<>(); + } + } + else + { + if constexpr(MXdlPerWave32 > 0) + { + is_valid = GridwiseGemm32::CheckValidity(a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b1_grid_desc_bk0_n_bk1, + c_grid_desc_m_n, + block_2_ctile_map) and + IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1), + b_grid_desc_bk0_n_bk1.GetLength(I1), + a_grid_desc_ak0_m_ak1.GetLength(I0) * + a_grid_desc_ak0_m_ak1.GetLength(I2), + b1_grid_desc_bk0_n_bk1.GetLength(I1)) and + GridwiseGemm32::template IsValidCompilationParameter<>(); + } + } } - constexpr bool IsValid() const { return is_valid; } }; @@ -1061,12 +1130,15 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle #ifndef __HIPCC_RTC__ assert(desc.is_valid); #endif - __shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()]; + using GridwiseGemm = conditional_t; + __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()]; AccElementwiseOperation acc_element_op{scale}; if(desc.has_main_k_block_loop) { - Desc::GridwiseGemm::template Run( + GridwiseGemm::template Run( p_a_grid, p_b_grid, p_b1_grid, @@ -1086,7 +1158,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle } else { - Desc::GridwiseGemm::template Run( + GridwiseGemm::template Run( p_a_grid, p_b_grid, p_b1_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp index d3f067f170..b8f03e742f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -52,36 +52,40 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg) { -#if defined(__gfx9__) - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / karg.Batch); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / karg.Batch); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - const auto a_grid_desc_k0_m_k1 = - amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1( - karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA)); - const auto b_grid_desc_k0_n_k1 = - amd_wave_read_first_lane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1( - karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB)); - const auto c_grid_desc_m_n = amd_wave_read_first_lane(GridwiseGemm::MakeCGridDescriptor_M_N( - karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC)); + const auto a_grid_desc_k0_m_k1 = + amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1( + karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA)); + const auto b_grid_desc_k0_n_k1 = + amd_wave_read_first_lane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1( + karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB)); + const auto c_grid_desc_m_n = amd_wave_read_first_lane(GridwiseGemm::MakeCGridDescriptor_M_N( + karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC)); - GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + c_batch_offset, - p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_grid_desc_m_n); + GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + c_batch_offset, + p_shared, + a_grid_desc_k0_m_k1, + 
b_grid_desc_k0_n_k1, + c_grid_desc_m_n); + } #else ignore = karg; #endif @@ -135,6 +139,10 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -172,7 +180,8 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext< BlockSize, ADataType, // TODO: distinguish A/B datatype AccDataType, @@ -192,7 +201,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Problem = typename GridwiseGemm::Problem; + using Problem = typename GridwiseGemm64::Problem; // Argument struct Argument : public Problem, public BaseArgument @@ -255,14 +266,17 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm + float RunImp(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { karg.Print(); } - if(!GridwiseGemm::CheckValidity(karg)) + typename GridwiseGemm::Problem arg( + karg.M, karg.N, karg.K, karg.StrideA, karg.StrideB, karg.StrideC); + if(!GridwiseGemm::CheckValidity(arg)) { throw std::runtime_error( "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext has invalid setting"); @@ -293,6 +307,8 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm()) { return false; } - - return GridwiseGemm::CheckValidity(problem); + // temp disable on gfx11 + if(ck::is_gfx11_supported()) + { + return false; + } + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(problem); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(problem)); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp index 459ebd7f35..d19810694b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp @@ -37,27 +37,30 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif kernel_batched_gemm_b_scale_xdl_cshuffle_v3(BatchedGemmArg karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - const index_t g_idx = blockIdx.z % karg.Batch; - const index_t k_idx = blockIdx.z / karg.Batch; + const index_t g_idx = blockIdx.z % karg.Batch; + const index_t k_idx = blockIdx.z / karg.Batch; - const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); - const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); - const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); - const auto b_scale_batch_offset = karg.compute_ptr_offset_of_batch.GetSacleBPtrOffset(g_idx); + const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const auto b_batch_offset = 
karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + const auto b_scale_batch_offset = + karg.compute_ptr_offset_of_batch.GetSacleBPtrOffset(g_idx); - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); - - GridwiseGemm::template Run( - karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, - karg.p_c_grid + c_batch_offset + splitk_batch_offset.c_reduce_offset, - karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset, - p_shared, - karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + c_batch_offset + splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset, + p_shared, + karg); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -75,31 +78,34 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds(BatchedGemmArg karg) { -#if defined(__gfx9__) - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - const index_t g_idx = blockIdx.z % karg.Batch; - const index_t k_idx = blockIdx.z / karg.Batch; + const index_t g_idx = blockIdx.z % karg.Batch; + const index_t k_idx = blockIdx.z / karg.Batch; - const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); - const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); - const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); - const auto b_scale_batch_offset = karg.compute_ptr_offset_of_batch.GetSacleBPtrOffset(g_idx); + const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + const auto b_scale_batch_offset = + karg.compute_ptr_offset_of_batch.GetSacleBPtrOffset(g_idx); - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); - - GridwiseGemm::template Run_2Lds( - karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, - karg.p_c_grid + c_batch_offset + splitk_batch_offset.c_reduce_offset, - karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset, - p_shared_0, - p_shared_1, - karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); + 
GridwiseGemm::template Run_2Lds( + karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + c_batch_offset + splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset, + p_shared_0, + p_shared_1, + karg); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -171,8 +177,13 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale BElementwiseOperation, CElementwiseOperation> { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + // GridwiseGemm - using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< + template + using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, @@ -196,7 +207,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -223,6 +234,8 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale ComputeTypeB, PermuteA, PermuteB>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; static constexpr index_t APackedSize = []() { if constexpr(is_same_v, pk_i4_t>) @@ -277,31 +290,32 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale index_t BatchStrideScaleB_; }; - struct Argument : public GridwiseGemm::Argument + template + struct ArgumentBase : public GridwiseGemm::Argument { index_t Batch; ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch; - Argument(const ADataType* p_a_grid_, - const BDataType* p_b_grid_, - CDataType* p_c_grid_, - index_t M_, - index_t N_, - index_t K_, - index_t StrideA_, - index_t StrideB_, - index_t StrideC_, - index_t StrideScaleB_, - index_t BatchStrideA_, - index_t BatchStrideB_, - index_t BatchStrideC_, - index_t BatchStrideScaleB_, - const BScaleDataType* p_b_scale_grid_, - index_t Batch_, - index_t KBatch_, - AElementwiseOperation a_element_op_, - BElementwiseOperation b_element_op_, - CElementwiseOperation c_element_op_) + ArgumentBase(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t StrideScaleB_, + index_t BatchStrideA_, + index_t BatchStrideB_, + index_t BatchStrideC_, + index_t BatchStrideScaleB_, + const BScaleDataType* p_b_scale_grid_, + index_t Batch_, + index_t KBatch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) : GridwiseGemm::Argument(p_a_grid_, p_b_grid_, p_c_grid_, @@ -323,12 +337,16 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { } }; + using Argument = ArgumentBase; // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const ArgumentBase& arg, + const StreamConfig& stream_config = StreamConfig{}) { + using DeviceArgument = ArgumentBase; if(stream_config.log_level_ > 0) { arg.Print(); @@ -353,7 +371,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale const auto Run = [&](const auto& kernel) { if(stream_config.flush_cache) { - Argument arg_ = arg; + DeviceArgument arg_ = arg; const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); 
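// ---------------------------------------------------------------------------
// [Editor's sketch, not part of the patch] Why the kernel bodies above gained
// an if constexpr(...IsValidCompilationParameter...) wrapper: as far as can be
// inferred from these hunks, both wave variants of each kernel are compiled
// for every listed gfx target, so an instantiation whose parameters do not fit
// the current target must collapse to an empty body instead of emitting code
// that target cannot run. A reduced model, with hypothetical names (Gridwise,
// kernel_entry):
#include <cstdio>

template <int NXdlPerWave>
struct Gridwise
{
    template <int = 0> // mirrors the `template ...<>()` call syntax above
    static constexpr bool IsValidCompilationParameter()
    {
        return NXdlPerWave > 0; // the real check also validates tile shapes
    }
    static void Run() { std::printf("running NXdlPerWave=%d\n", NXdlPerWave); }
};

template <typename GridwiseGemm>
void kernel_entry()
{
    // A false predicate discards the body at compile time; the disabled
    // instantiation costs nothing in the final binary.
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
    {
        GridwiseGemm::Run();
    }
}

int main()
{
    kernel_entry<Gridwise<2>>(); // valid variant: body is compiled and runs
    kernel_entry<Gridwise<0>>(); // disabled variant: body compiled out
}
// ---------------------------------------------------------------------------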
@@ -365,7 +383,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) / BPackedSize; - ck::utility::RotatingMemWrapper rotating_mem( + ck::utility::RotatingMemWrapper rotating_mem( arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); rotating_mem.Print(); @@ -422,7 +440,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy>; @@ -432,7 +450,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy>; @@ -448,7 +466,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -460,7 +478,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -474,7 +492,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -490,7 +508,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -506,7 +524,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -522,7 +540,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -537,7 +555,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -553,7 +571,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -568,7 +586,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -580,7 +598,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -594,7 +612,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -610,7 +628,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = 
kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -626,7 +644,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -642,7 +660,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -657,7 +675,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -673,7 +691,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -692,7 +710,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -703,7 +721,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -717,7 +735,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -728,7 +746,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -745,7 +763,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -756,7 +774,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -770,7 +788,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -781,7 +799,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -800,7 +818,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, false, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy>; @@ -810,7 +828,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, false, InMemoryDataOperationEnum::Set, minimum_occupancy>; @@ -822,6 +840,27 @@ struct 
DeviceBatchedGemm_Xdl_CShuffleV3_BScale return ave_time; } + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return RunImp(arg, stream_config); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + using Argument32 = ArgumentBase; + return RunImp(reinterpret_cast(arg), + stream_config); + } + } + return 0; + } + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -838,11 +877,14 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) + { + return false; + } + if(is_gfx11_supported() && arg.KBatch > 1) { return false; } - if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) { return false; @@ -855,8 +897,22 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { return false; } - - return GridwiseGemm::CheckValidity(arg); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + using Argument32 = ArgumentBase; + return GridwiseGemm32::CheckValidity(reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -1003,7 +1059,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale << "BlkGemmPipelineVersion: " << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " << "BlkGemmPipelinePrefetchStages: " - << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages; + << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages; // clang-format on return str.str(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp index 4934993693..4aa6c18d04 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
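// ---------------------------------------------------------------------------
// [Editor's sketch, not part of the patch] The reinterpret_cast in Run() and
// IsSupportedArgument() above appears to rely on ArgumentBase instantiated for
// GridwiseGemm64 and for GridwiseGemm32 carrying identical members in
// identical order; the template parameter only selects which base Argument
// type is derived from, and both bases describe the same problem fields. A
// reduced model of that layout assumption (ProblemFor/ArgFor are hypothetical
// names):
#include <cassert>

template <int Wave>
struct ProblemFor
{
    int M, N, K; // identical members for every Wave value
};

template <int Wave>
struct ArgFor : ProblemFor<Wave>
{
    int batch;
};

int main()
{
    static_assert(sizeof(ArgFor<64>) == sizeof(ArgFor<32>), "layouts match");
    ArgFor<64> a64{{128, 256, 64}, 4};
    // Reinterpreting is how the wave32 path reuses an argument built through
    // the wave64 alias; it would be undefined behaviour if the layouts ever
    // diverged.
    const auto& a32 = reinterpret_cast<const ArgFor<32>&>(a64);
    assert(a32.M == 128 && a32.batch == 4);
    return 0;
}
// ---------------------------------------------------------------------------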
#pragma once @@ -75,6 +75,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle : public DeviceCGemm { using DeviceOp = DeviceCGemm_4Gemm_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -118,7 +121,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle } // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + template + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ALayout, BLayout, CLayout, @@ -142,7 +146,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -164,13 +168,15 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using CGridDesc_M_N = decltype(MakeDescriptor_M_N({1, 1}, {1, 1})); // Argument - struct Argument : public tensor_operation::device::BaseArgument, public GridwiseGemm::Problem + struct Argument : public tensor_operation::device::BaseArgument, public GridwiseGemm64::Problem { - using Problem = typename GridwiseGemm::Problem; + using Problem = typename GridwiseGemm64::Problem; Argument(const ADataType* p_a_grid_real_, const ADataType* p_a_grid_imag_, @@ -221,14 +227,17 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { arg.Print(); } - if(!GridwiseGemm::CheckValidity(arg)) + typename GridwiseGemm::Problem problem( + arg.M, arg.N, arg.K, arg.StrideA, arg.StrideB, arg.StrideC); + if(!GridwiseGemm::CheckValidity(problem)) { throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); } @@ -317,7 +326,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.p_a_grid_real, arg.p_b_grid_real, arg.p_aux_grid, - arg); + problem); ave_time += launch_and_time_kernel(stream_config, kernel, @@ -327,7 +336,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.p_a_grid_imag, arg.p_b_grid_imag, arg.p_aux_2_grid, - arg); + problem); // c_real = aux - aux_2 ave_time += launch_and_time_kernel( @@ -352,7 +361,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.p_a_grid_real, arg.p_b_grid_imag, arg.p_aux_grid, - arg); + problem); ave_time += launch_and_time_kernel(stream_config, kernel, @@ -362,7 +371,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.p_a_grid_imag, arg.p_b_grid_real, arg.p_aux_2_grid, - arg); + problem); // c_imag = aux + aux_2 ave_time += launch_and_time_kernel( @@ -395,7 +404,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.p_a_grid_real, arg.p_b_grid_real, arg.p_aux_grid, - arg); + problem); ave_time += launch_and_time_kernel(stream_config, kernel, @@ -405,7 +414,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.p_a_grid_imag, arg.p_b_grid_imag, arg.p_aux_2_grid, - arg); + problem); // c_real = aux - aux_2 ave_time += launch_and_time_kernel( @@ -430,7 +439,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.p_a_grid_real, arg.p_b_grid_imag, arg.p_aux_grid, - arg); + problem); ave_time += launch_and_time_kernel(stream_config, kernel, @@ -440,7 +449,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.p_a_grid_imag, arg.p_b_grid_real, arg.p_aux_2_grid, - arg); + problem); // c_imag = aux + aux_2 ave_time += launch_and_time_kernel( @@ -461,6 +470,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle return ave_time; } + INVOKER_RUN_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -477,12 +488,27 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - - return GridwiseGemm::CheckValidity(arg); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + typename GridwiseGemm32::Problem problem( + arg.M, arg.N, arg.K, arg.StrideA, arg.StrideB, arg.StrideC); + return GridwiseGemm32::CheckValidity(problem); + } + } + return false; } // polymorphic @@ -587,8 +613,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle static std::size_t GetCElementSpaceSize(index_t M, index_t N, index_t StrideC) { - const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N( - M, GridwiseGemm::CalculateMPadded(M), N, GridwiseGemm::CalculateNPadded(N), StrideC); + const auto c_grid_desc_m_n = + GridwiseGemm64::MakeCGridDescriptor_M_N(M, + GridwiseGemm64::CalculateMPadded(M), + N, + GridwiseGemm64::CalculateNPadded(N), + StrideC); return c_grid_desc_m_n.GetElementSpaceSize(); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp index 27f0a7af7c..cf5d5bd64e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. 
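// ---------------------------------------------------------------------------
// [Editor's sketch, not part of the patch] The RunImp above rebuilds the
// GridwiseGemm::Problem locally from the Argument's raw fields instead of
// reusing state cached at construction time, so each wave variant validates
// and launches against a Problem built for itself. A reduced model with
// illustrative names only:
#include <cstdio>

struct Argument
{
    int M, N, K; // raw problem definition only, variant-agnostic
};

template <int Wave>
struct Gridwise
{
    struct Problem { int M, N, K; };
    static bool CheckValidity(const Problem& p) { return p.K % 8 == 0; }
    static float Launch(const Problem& p)
    {
        std::printf("wave%d kernel on %dx%dx%d\n", Wave, p.M, p.N, p.K);
        return 0.0f;
    }
};

template <typename GridwiseGemm>
float RunImp(const Argument& arg)
{
    // Built here, once, for the single variant that is actually launched.
    typename GridwiseGemm::Problem problem{arg.M, arg.N, arg.K};
    if(!GridwiseGemm::CheckValidity(problem))
        return 0.0f; // mirrors the runtime_error path above
    return GridwiseGemm::Launch(problem);
}

int main() { return static_cast<int>(RunImp<Gridwise<64>>({256, 512, 64})); }
// ---------------------------------------------------------------------------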
#pragma once @@ -55,22 +55,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_as_grid, - p_bs_grid, - p_ds_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - as_grid_desc_ak0_m_ak1, - bs_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + GridwiseGemm::template Run( + p_as_grid, + p_bs_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } #else ignore = p_as_grid; ignore = p_bs_grid; @@ -160,6 +164,10 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle { using DeviceOp = DeviceContractionMultipleABD_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr index_t NumATensor = AsDataType::Size(); static constexpr index_t NumBTensor = BsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size(); @@ -172,7 +180,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle using ComputeDataType = EDataType; // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleABD_xdl_cshuffle< + template + using GridwiseGemmBase = GridwiseGemmMultipleABD_xdl_cshuffle< AsDataType, BsDataType, ComputeDataType, @@ -194,7 +203,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -217,6 +226,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; static constexpr auto matrix_padder = ck::tensor_operation::device::MatrixPadder{ @@ -385,21 +396,21 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle // desc for blockwise copy using AsGridDesc_AK0_M_AK1 = - remove_cvref_t; using BsGridDesc_BK0_N_BK1 = - remove_cvref_t; using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))>; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; // block-to-e-tile map using Block2ETileMap = - remove_cvref_t; + remove_cvref_t; // Argument struct Argument : public BaseArgument @@ -427,11 +438,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle bs_grid_desc_n_k_{}, ds_grid_desc_m_n_{}, e_grid_desc_m_n_{MakeEGridDescriptor_M_N(e_ms_ns_length, e_ms_ns_stride)}, - as_grid_desc_ak0_m_ak1_{}, - bs_grid_desc_bk0_n_bk1_{}, - ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, - e_grid_desc_mblock_mperblock_nblock_nperblock_{}, - 
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, cde_element_op_{cde_element_op} @@ -475,28 +482,6 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle MakeEGridDescriptor_M_N(d_ms_ns_lengths[i], d_ms_ns_strides[i]); }); - // populate desc for Ds/E - if(GridwiseGemm::CheckValidity(as_grid_desc_m_k_, - bs_grid_desc_n_k_, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_)) - { - as_grid_desc_ak0_m_ak1_ = - GridwiseGemm::MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_); - - bs_grid_desc_bk0_n_bk1_ = - GridwiseGemm::MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_); - - ds_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n_); - - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); - } - // for sanity check of vector memory access for(index_t i = 0; i < NumATensor; ++i) { @@ -521,9 +506,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle } // pointers - typename GridwiseGemm::AsGridPointer p_as_grid_; - typename GridwiseGemm::BsGridPointer p_bs_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; + typename GridwiseGemm64::AsGridPointer p_as_grid_; + typename GridwiseGemm64::BsGridPointer p_bs_grid_; + typename GridwiseGemm64::DsGridPointer p_ds_grid_; EDataType* p_e_grid_; // tensor descriptors for problem definiton @@ -532,13 +517,6 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle DsGridDesc_M_N ds_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_; - // tensor descriptors for block/thread-wise copy - AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1_; - BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1_; - DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock_; - EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; - // block-to-e-tile map Block2ETileMap block_2_etile_map_; @@ -564,7 +542,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_, arg.bs_grid_desc_n_k_, @@ -574,7 +553,19 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle { throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); } + auto as_grid_desc_ak0_m_ak1 = + GridwiseGemm::MakeDefaultAsGridDescriptor_AK0_M_AK1(arg.as_grid_desc_m_k_); + auto bs_grid_desc_bk0_n_bk1 = + GridwiseGemm::MakeDefaultBsGridDescriptor_BK0_N_BK1(arg.bs_grid_desc_n_k_); + + auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.ds_grid_desc_m_n_); + + auto e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.e_grid_desc_m_n_); const index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); @@ -609,10 +600,10 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle arg.a_element_op_, arg.b_element_op_, arg.cde_element_op_, - arg.as_grid_desc_ak0_m_ak1_, - arg.bs_grid_desc_bk0_n_bk1_, - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, arg.block_2_etile_map_); }; @@ -628,6 +619,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle } } + INVOKER_RUN_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -638,11 +631,12 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + using A0DataType = remove_cvref_t>; + using B0DataType = remove_cvref_t>; + if(!ck::is_xdl_wmma_supported()) { return false; } - // check vector load/store { bool valid_as_access = true; @@ -713,11 +707,30 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle } } - return GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_, - arg.bs_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); + if(get_warp_size() > 0) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.as_grid_desc_m_k_, + arg.bs_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.as_grid_desc_m_k_, + arg.bs_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index 615566a555..2a569a49e9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -53,23 +53,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - 
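Note one inconsistency in the hunk above: the IsSupportedArgument dispatch tests "if(get_warp_size() > 0)", which is always true and leaves the GridwiseGemm32 branch unreachable; every other device op in this patch tests "get_warp_size() == 64", so "> 0" is presumably a typo for "== 64". INVOKER_RUN_IMPL plausibly expands to the matching runtime dispatch between the two RunImp instantiations; a self-contained sketch of that shape, with toy stand-ins (the real macro lives in device_base.hpp):

#include <cstdio>
#include <stdexcept>

static int get_warp_size() { return 64; } // runtime device query in CK

static constexpr int NXdlPerWave64 = 4;
static constexpr int NXdlPerWave32 = 2; // <= 0 would compile this path out

struct GridwiseGemm64 { static constexpr float run() { return 1.f; } };
struct GridwiseGemm32 { static constexpr float run() { return 2.f; } };

template <typename GridwiseGemm>
static float RunImp() { return GridwiseGemm::run(); } // kernel launch in CK

// Assumed expansion of INVOKER_RUN_IMPL: pick the instantiation that matches
// the wave size of the GPU we are running on, guarded by if constexpr so an
// unbuildable instantiation is never touched.
static float Run()
{
    if(get_warp_size() == 64)
    {
        if constexpr(NXdlPerWave64 > 0)
            return RunImp<GridwiseGemm64>();
    }
    else
    {
        if constexpr(NXdlPerWave32 > 0)
            return RunImp<GridwiseGemm32>();
    }
    throw std::runtime_error("wrong! no valid GridwiseGemm for this wave size");
}

int main() { std::printf("%.1f\n", Run()); }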
ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -163,6 +166,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle { using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; @@ -314,7 +321,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + template + using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype BDataType, ComputeDataType, @@ -335,7 +343,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -357,24 +365,26 @@ struct DeviceContractionMultipleD_Xdl_CShuffle CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // desc for blockwise copy using AGridDesc_AK0_M_AK1 = - remove_cvref_t; using BGridDesc_BK0_N_BK1 = - remove_cvref_t; using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))>; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; // block-to-e-tile map using Block2ETileMap = - remove_cvref_t; + remove_cvref_t; // Argument struct Argument : public BaseArgument @@ -403,12 +413,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ds_grid_desc_m_n_{}, e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_ms_ns_lengths, e_ms_ns_strides)}, a_grid_desc_ak0_m_ak1_{ - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, b_grid_desc_bk0_n_bk1_{ - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, - ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, - e_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, cde_element_op_{cde_element_op} @@ -425,22 +433,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle DeviceOp::MakeEGridDescriptor_M_N(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]); }); - // populate desc for Ds/E - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, - b_grid_desc_n_k_, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_)) - { - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - 
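The kernel entry points above are now compiled for gfx11/gfx12 as well as gfx9, so each one wraps its body in "if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())": when the tile configuration cannot be realized for the target's wave size, the kernel compiles to an empty, still-linkable stub instead of failing deep inside the gridwise GEMM. A host-side toy model of that guard; the real predicate is generated by IS_VALID_COMPILATION_PARAMETER_IMPL in device_base.hpp and its criterion here is hypothetical.

struct GridwiseGemmToy
{
    static constexpr int NXdlPerWave = 2;

    template <int Dummy = 0> // templated so the check is evaluated lazily
    static constexpr bool IsValidCompilationParameter()
    {
        return NXdlPerWave > 0; // hypothetical criterion
    }

    static constexpr unsigned GetSharedMemoryNumberOfByte() { return 1024; }
    static void Run(char* /*p_shared*/) {}
};

template <typename GridwiseGemm>
void kernel_entry()
{
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
    {
        // only instantiated when the parameters are realizable
        // (__shared__ LDS on device; plain storage in this host toy)
        static char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
        GridwiseGemm::Run(p_shared);
    }
    // otherwise: an empty kernel body, which still links
}

template void kernel_entry<GridwiseGemmToy>();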
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); - - ds_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n_); - } - // for sanity check of vector memory access tie(a_continous_dim_, a_max_read_elems_) = CalculateMaxRead(a_ms_ks_lengths, a_ms_ks_strides); @@ -471,7 +463,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle // pointers const ADataType* p_a_grid_; const BDataType* p_b_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; + typename GridwiseGemm64::DsGridPointer p_ds_grid_; EDataType* p_e_grid_; // tensor descriptors for problem definiton @@ -512,7 +504,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, arg.b_grid_desc_n_k_, @@ -523,7 +516,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle throw std::runtime_error( "wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting"); } + auto e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.e_grid_desc_m_n_); + auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.ds_grid_desc_m_n_); const index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); @@ -562,8 +561,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle arg.cde_element_op_, arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, arg.block_2_etile_map_); }; @@ -577,6 +576,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle } } + INVOKER_RUN_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -587,21 +588,39 @@ struct DeviceContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - if(!ck::is_lds_direct_load_supported() && std::is_same::value) { return false; } + bool valid = false; + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + valid = GridwiseGemm64::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + valid = GridwiseGemm32::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_)) + if(!valid) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index 23440e24f6..ff652ebefb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ 
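A recurring refactor in this patch is visible in the two hunks above: the *_mblock_mperblock_nblock_nperblock descriptors are no longer cached on Argument behind a CheckValidity call in the constructor; RunImp validates once and then derives them as locals. That removes the duplicated constructor/Run validation and keeps Argument independent of which gridwise GEMM instantiation eventually runs. A condensed sketch of the resulting shape, with toy descriptor types:

#include <stdexcept>

struct DescMN { int M = 0, N = 0; };
struct DescBlocked { int MBlock, NBlock; };

struct GridwiseGemmToy
{
    static bool CheckValidity(const DescMN& d) { return d.M % 64 == 0 && d.N % 64 == 0; }
    static DescBlocked MakeBlockedDescriptor(const DescMN& d) { return {d.M / 64, d.N / 64}; }
};

struct Argument
{
    DescMN e_grid_desc_m_n_; // problem definition only; no cached blocked form
};

template <typename GridwiseGemm>
float RunImp(const Argument& arg)
{
    if(!GridwiseGemm::CheckValidity(arg.e_grid_desc_m_n_))
        throw std::runtime_error("wrong! GridwiseGemm has invalid setting");

    // derived locally, per chosen instantiation, after the single validity check
    const auto e_blocked = GridwiseGemm::MakeBlockedDescriptor(arg.e_grid_desc_m_n_);
    (void)e_blocked; // ...pass to the kernel launch...
    return 0.f;
}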
b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -72,6 +72,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ using DeviceOp = DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + using ADataType = OutDataType; using BDataType = InDataType; using CDataType = WeiDataType; @@ -281,7 +285,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ using CGridDesc_M_N = remove_cvref_t; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + template + using GridwiseGemmBase = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, ADataType, // TODO: distinguish A/B datatype AccDataType, @@ -300,7 +305,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ NPerXdl, K1, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -329,8 +334,11 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, true, true>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + template + using GridwiseGemmAtomicAddBase = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, ADataType, // TODO: distinguish A/B datatype AccDataType, @@ -349,7 +357,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ NPerXdl, K1, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -378,6 +386,9 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, true, true>; + using GridwiseGemmAtomicAdd64 = GridwiseGemmAtomicAddBase; + using GridwiseGemmAtomicAdd32 = GridwiseGemmAtomicAddBase; + // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); @@ -506,7 +517,9 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -635,6 +648,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ return ave_time; } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -650,11 +665,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - // 
vector load A/B matrix from global memory if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp index d4f89b3e09..2be31cabed 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp @@ -35,8 +35,8 @@ template (); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + using ADataType = OutDataType; using BDataType = WeiDataType; using CDataType = InDataType; @@ -374,7 +378,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K using CGridDesc_M_N = remove_cvref_t; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + template + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, ABDataType, // TODO: distinguish A/B datatype AccDataType, @@ -386,11 +391,11 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K MPerBlock, NPerBlock, K0PerBlock, - MPerXdl, - NPerXdl, + MPerXDL, + NPerXDL, K1, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -410,6 +415,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, 7, // CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -512,7 +519,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float ave_time = 0; for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) @@ -603,6 +611,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K return ave_time; } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -618,11 +628,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - if constexpr(ConvBackwardDataSpecialization == ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) { @@ -651,14 +660,30 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K } // Gridwise GEMM size + bool isWave64 = get_warp_size() == 64; for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) { - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], - arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m_n_container_[i])) + bool valid = false; + if(isWave64) { - return false; + if constexpr(NXdlPerWave64 > 0) + { + valid = GridwiseGemm64::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i]); + } } + else + { + if constexpr(NXdlPerWave32 
> 0) + { + valid = GridwiseGemm32::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i]); + } + } + if(!valid) + return false; } return true; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp index 5d039427d6..94a0dc8c84 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp @@ -68,6 +68,10 @@ struct using DeviceOp = DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + using ADataType = InDataType; using BDataType = WeiDataType; using CDataType = OutDataType; @@ -467,7 +471,8 @@ struct using Block2CTileMap = BlockToCTileMap_M00_N0_M01; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3< + template + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3< BlockSize, ABDataType, // TODO: distinguish A/B datatype AccDataType, @@ -488,7 +493,7 @@ struct NPerXDL, K1, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, @@ -509,6 +514,8 @@ struct CShuffleNXdlPerWavePerShuffle, CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, CBlockTransferScalarPerVector_NWaveNPerXdl>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -541,9 +548,6 @@ struct c_grid_desc_m_n_{}, c0_grid_desc_m_n_{}, c1_grid_desc_m_n_{}, - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, - c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, block_2_ctile_map_{}, in_element_op_{in_element_op}, wei_element_op_{wei_element_op}, @@ -578,27 +582,6 @@ struct c1_grid_desc_m_n_ = descs[I4]; block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_}; - - if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, - b_grid_desc_k0_n_k1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = - GridwiseGemm:: - MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( - c_grid_desc_m_n_); - - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = - GridwiseGemm:: - MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( - c0_grid_desc_m_n_); - - c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = - GridwiseGemm:: - MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( - c1_grid_desc_m_n_); - } } // private: @@ -612,15 +595,7 @@ struct CGridDesc_M_N c_grid_desc_m_n_; C0GridDesc_M_N c0_grid_desc_m_n_; C1GridDesc_M_N c1_grid_desc_m_n_; - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - 
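The backward-data ops keep one descriptor set per sub-problem, so IsSupportedArgument now validates every element of the container against the instantiation selected by the wave size, with the wave size queried once before the loop (the same loop reappears in device_convnd_bwd_data further down). The pattern, extracted into a standalone sketch with toy types:

#include <vector>

static int get_warp_size() { return 64; }
static constexpr int NXdlPerWave64 = 4;
static constexpr int NXdlPerWave32 = 2;

struct Desc { int K0, M, K1; };
struct GridwiseGemm64 { static bool CheckValidity(const Desc& d) { return d.M > 0; } };
struct GridwiseGemm32 { static bool CheckValidity(const Desc& d) { return d.M > 0; } };

bool AllSubProblemsValid(const std::vector<Desc>& descs)
{
    const bool isWave64 = get_warp_size() == 64; // hoisted out of the loop
    for(const auto& d : descs)
    {
        bool valid = false;
        if(isWave64)
        {
            if constexpr(NXdlPerWave64 > 0)
                valid = GridwiseGemm64::CheckValidity(d);
        }
        else
        {
            if constexpr(NXdlPerWave32 > 0)
                valid = GridwiseGemm32::CheckValidity(d);
        }
        if(!valid)
            return false; // one bad sub-problem rejects the whole argument
    }
    return true;
}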
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; - typename GridwiseGemm:: - C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; - typename GridwiseGemm:: - C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + Block2CTileMap block_2_ctile_map_; InElementwiseOperation in_element_op_; WeiElementwiseOperation wei_element_op_; @@ -643,7 +618,8 @@ struct { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { @@ -700,6 +676,20 @@ struct float ave_time = 0; + auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + arg.c_grid_desc_m_n_); + + auto c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + arg.c0_grid_desc_m_n_); + + auto c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + arg.c1_grid_desc_m_n_); if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { const auto kernel = kernel_gemm_xdlops_v3r3< @@ -736,9 +726,9 @@ struct arg.p_c1_grid_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, arg.in_element_op_, arg.wei_element_op_, arg.out_element_op_, @@ -780,9 +770,9 @@ struct arg.p_c1_grid_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, arg.in_element_op_, arg.wei_element_op_, arg.out_element_op_, @@ -792,6 +782,8 @@ struct return ave_time; } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -807,11 +799,10 @@ struct static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { @@ -851,10 +842,27 @@ struct } // Gridwise GEMM size - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - 
arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + return false; } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp index 242f5cd673..3a3b29a168 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp @@ -69,6 +69,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X using DeviceOp = DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + using ADataType = InDataType; using BDataType = WeiDataType; using CDataType = OutDataType; @@ -448,7 +452,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X using C0GridDesc_M_N = remove_cvref_t; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2< + template + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2< BlockSize, ABDataType, // TODO: distinguish A/B datatype AccDataType, @@ -468,7 +473,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X NPerXDL, K1, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, @@ -489,6 +494,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X CShuffleNXdlPerWavePerShuffle, CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, CBlockTransferScalarPerVector_NWaveNPerXdl>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -520,8 +527,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X b_grid_desc_k0_n_k1_{}, c_grid_desc_m_n_{}, c0_grid_desc_m_n_{}, - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, block_2_ctile_map_{}, M01_{M01}, N01_{N01}, @@ -556,23 +561,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X c_grid_desc_m_n_ = descs[I2]; c0_grid_desc_m_n_ = descs[I3]; block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); - - if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, - b_grid_desc_k0_n_k1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = - GridwiseGemm:: - MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( - 
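Throughout these hunks the Argument keeps a single block_2_ctile_map_ built from the wave64 instantiation (GridwiseGemm64::MakeDefaultBlock2CTileMap) and hands it to whichever instantiation runs. That is workable because the map is a function of the block tile (MPerBlock x NPerBlock), which both instantiations share; only the intra-block wave split differs. A toy illustration of that invariant:

template <int MPerBlock, int NPerBlock, int NXdlPerWave_>
struct GridwiseGemmToy
{
    struct Block2CTileMap { int m_blocks, n_blocks; };

    static Block2CTileMap MakeDefaultBlock2CTileMap(int M, int N)
    {
        // depends on the block tile only, not on NXdlPerWave_
        return {(M + MPerBlock - 1) / MPerBlock, (N + NPerBlock - 1) / NPerBlock};
    }
};

using Gemm64 = GridwiseGemmToy<256, 128, 4>;
using Gemm32 = GridwiseGemmToy<256, 128, 2>;

// same layout and meaning for either wave size, so one cached map suffices
static_assert(sizeof(Gemm64::Block2CTileMap) == sizeof(Gemm32::Block2CTileMap));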
c_grid_desc_m_n_); - - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = - GridwiseGemm:: - MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( - c0_grid_desc_m_n_); - } + GridwiseGemm64::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); } const ADataType* p_a_grid_; @@ -583,13 +572,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; CGridDesc_M_N c_grid_desc_m_n_; C0GridDesc_M_N c0_grid_desc_m_n_; - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; - typename GridwiseGemm:: - C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + typename GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map_; index_t M01_; index_t N01_; InElementwiseOperation in_element_op_; @@ -613,7 +596,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { @@ -667,6 +651,15 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X float ave_time = 0; + auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + arg.c_grid_desc_m_n_); + + auto c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + arg.c0_grid_desc_m_n_); if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { const auto kernel = kernel_gemm_xdlops_v3r2< @@ -699,8 +692,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X arg.p_c0_grid_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, arg.in_element_op_, arg.wei_element_op_, arg.out_element_op_, @@ -738,8 +731,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X arg.p_c0_grid_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, arg.in_element_op_, arg.wei_element_op_, arg.out_element_op_, @@ -749,6 +742,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X return ave_time; } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -764,11 +759,10 @@ struct 
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { @@ -808,10 +802,27 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X } // Gridwise GEMM size - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + return false; } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index 0d295a2418..3bc5f9af03 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -36,8 +36,8 @@ template < ck::index_t NPerBlock, ck::index_t K0PerBlock, ck::index_t K1, - ck::index_t MPerXdl, - ck::index_t NPerXdl, + ck::index_t MPerXDL, + ck::index_t NPerXDL, ck::index_t MXdlPerWave, ck::index_t NXdlPerWave, typename ABlockTransferThreadClusterLengths_K0_M_K1, @@ -72,6 +72,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W { using DeviceOp = DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + using ADataType = InDataType; using BDataType = WeiDataType; using CDataType = OutDataType; @@ -433,7 +437,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W using Block2CTileMap = BlockToCTileMap_M00_N0_M01; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1< + template + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1< BlockSize, ABDataType, // TODO: distinguish A/B datatype AccDataType, @@ -451,10 +456,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W K0PerBlock * K1, K1, // AK1 K1, // BK1 - MPerXdl, - NPerXdl, + MPerXDL, + NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, @@ -475,6 +480,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W CShuffleNXdlPerWavePerShuffle, CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, CBlockTransferScalarPerVector_NWaveNPerXdl>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -501,7 +508,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W a_grid_desc_k0_m_k1_{}, 
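Every device op in this patch swaps ck::is_xdl_supported() for ck::is_xdl_wmma_supported(), so arguments are accepted on the WMMA-capable RDNA3/4 targets as well as on the MFMA-capable gfx9 parts. The real predicate is added to ck/host_utility/device_prop.hpp; a hedged sketch of what it plausibly checks — the exact gfx list here is an assumption, not the library's code:

#include <string>

// Hypothetical stand-in: accept the XDL (MFMA) gfx9 targets as before, plus
// the WMMA-capable RDNA3 (gfx11*) and RDNA4 (gfx12*) targets.
inline bool is_xdl_wmma_supported_sketch(const std::string& device_name)
{
    const bool xdl = device_name.find("gfx908") == 0 ||
                     device_name.find("gfx90a") == 0 ||
                     device_name.find("gfx94") == 0; // gfx940/941/942
    const bool wmma = device_name.find("gfx11") == 0 ||
                      device_name.find("gfx12") == 0;
    return xdl || wmma;
}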
b_grid_desc_k0_n_k1_{}, c_grid_desc_m_n_{}, - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, block_2_ctile_map_{}, in_element_op_{in_element_op}, wei_element_op_{wei_element_op}, @@ -534,28 +540,13 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W c_grid_desc_m_n_ = descs[I2]; block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_}; - - if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, - b_grid_desc_k0_n_k1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = - GridwiseGemm:: - MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( - c_grid_desc_m_n_); - } } - const ADataType* p_a_grid_; const BDataType* p_b_grid_; CDataType* p_c_grid_; AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; CGridDesc_M_N c_grid_desc_m_n_; - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; Block2CTileMap block_2_ctile_map_; InElementwiseOperation in_element_op_; WeiElementwiseOperation wei_element_op_; @@ -578,8 +569,21 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + } + auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + arg.c_grid_desc_m_n_); if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << DeviceOp{}.GetTypeString() << std::endl; @@ -614,35 +618,25 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W std::cout << "arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_" "nwavenperxdl_{ " - << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + << c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl .GetLength(I0) << ", " - << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + << c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl .GetLength(I1) << ", " - << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + << c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl .GetLength(I2) << ", " - << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + << c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl .GetLength(I3) << ", " - << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + << c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl .GetLength(I4) << ", " - << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + << c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl .GetLength(I5) << "}" << std::endl; } - - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - 
arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_)) - { - throw std::runtime_error( - "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); - } - const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); @@ -679,7 +673,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W arg.p_c_grid_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, arg.in_element_op_, arg.wei_element_op_, arg.out_element_op_, @@ -713,7 +707,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W arg.p_c_grid_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, arg.in_element_op_, arg.wei_element_op_, arg.out_element_op_, @@ -723,6 +717,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W return ave_time; } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -738,11 +734,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { @@ -782,10 +777,27 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W } // Gridwise GEMM size - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + return false; } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp index c7aa54f1d9..cecfa48408 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -69,6 +69,10 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K { using DeviceOp = DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + using ADataType = InDataType; using BDataType = WeiDataType; using CDataType = OutDataType; @@ -324,7 +328,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K using CGridDesc_M_N = remove_cvref_t; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + template + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, ABDataType, // 
TODO: distinguish A/B datatype AccDataType, @@ -340,7 +345,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K NPerXDL, K1, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, @@ -360,6 +365,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, 7, // CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -430,7 +437,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { @@ -510,6 +518,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K return ave_time; } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -525,11 +535,10 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { @@ -569,8 +578,23 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K } // Gridwise GEMM size - return GridwiseGemm::CheckValidity( - arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); + } + } + return false; } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp index dc8499fcf2..09dc2a03db 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -56,32 +56,35 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if defined(__gfx9__) - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = - 
__builtin_amdgcn_readfirstlane(static_cast(a_batch_stride) * g_idx); - const long_index_t b_batch_offset = - __builtin_amdgcn_readfirstlane(static_cast(b_batch_stride) * g_idx); - const long_index_t c_batch_offset = - __builtin_amdgcn_readfirstlane(static_cast(c_batch_stride) * g_idx); + const long_index_t a_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast(a_batch_stride) * g_idx); + const long_index_t b_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast(b_batch_stride) * g_idx); + const long_index_t c_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast(c_batch_stride) * g_idx); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_c_grid + c_batch_offset, - p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -140,6 +143,10 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ { using DeviceOp = DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + using ADataType = InDataType; using BDataType = WeiDataType; using CDataType = OutDataType; @@ -263,7 +270,8 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + template + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, InDataType, AccDataType, @@ -282,7 +290,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ NPerXDL, K1, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, @@ -302,6 +310,8 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ Sequence<2, 3, 0, 1, 7, 5, 4, 6>, 7, CThreadTransferDstScalarPerVector>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); @@ -399,7 +409,9 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { @@ -507,6 +519,8 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ return ave_time; } + INVOKER_RUN_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -523,11 +537,10 @@ struct 
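The batched conv3d kernel above splits the grid evenly across batches and derives each block's batch index and per-tensor offsets from its 1-D block id; __builtin_amdgcn_readfirstlane keeps those values in scalar registers on the device. The same index math in plain host C++ (toy signature, no wavefront intrinsics):

#include <cstdint>

using index_t      = std::int32_t;
using long_index_t = std::int64_t;

struct BatchOffsets { long_index_t a, b, c; };

BatchOffsets ComputeBatchOffsets(index_t grid_size, index_t num_batches, index_t block_id,
                                 index_t a_batch_stride, index_t b_batch_stride,
                                 index_t c_batch_stride)
{
    const index_t num_blocks_per_batch = grid_size / num_batches;
    const index_t g_idx                = block_id / num_blocks_per_batch;

    // widen before multiplying so large tensors do not overflow 32-bit math
    return {static_cast<long_index_t>(a_batch_stride) * g_idx,
            static_cast<long_index_t>(b_batch_stride) * g_idx,
            static_cast<long_index_t>(c_batch_stride) * g_idx};
}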
DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp index 2881036bee..d074342127 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp @@ -36,8 +36,8 @@ template (); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + using ADataType = OutDataType; using BDataType = WeiDataType; using CDataType = InDataType; @@ -975,7 +979,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl using CGridDesc_M_N = remove_cvref_t; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + template + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, ABDataType, // TODO: distinguish A/B datatype AccDataType, @@ -987,11 +992,11 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl MPerBlock, NPerBlock, K0PerBlock, - MPerXdl, - NPerXdl, + MPerXDL, + NPerXDL, K1, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -1011,6 +1016,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, 7, // CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -1216,7 +1223,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float ave_time = 0; for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) @@ -1305,6 +1313,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl return ave_time; } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -1320,11 +1330,10 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - if constexpr(ConvBackwardDataSpecialization == ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) { @@ -1354,14 +1363,30 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl } // Gridwise GEMM size + bool isWave64 = get_warp_size() == 64; for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) { - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], - arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m_n_container_[i])) + bool valid = false; + if(isWave64) { - return false; + if constexpr(NXdlPerWave64 > 0) + { + valid = GridwiseGemm64::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i]); + } } + else + { + if constexpr(NXdlPerWave32 > 0) + { + valid = GridwiseGemm32::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + 
arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i]); + } + } + if(!valid) + return false; } return true; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp index 5bfef60af2..3929525987 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -81,6 +81,10 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO { using DeviceOp = DeviceGemmBiasAddReduce_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -383,7 +387,8 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1)); // GridwiseGemm - using GridwiseGemm = GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + template + using GridwiseGemmBase = GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, CShuffleDataType, @@ -417,7 +422,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -442,6 +447,8 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -477,11 +484,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO c0_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, 0)}, c1_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC1)}, reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)}, - c_grid_desc_mblock_mperblock_nblock_nperblock_{}, - c0_grid_desc_mblock_mperblock_nblock_nperblock_{}, - c1_grid_desc_mblock_mperblock_nblock_nperblock_{}, - reduce_grid_desc_mblock_mperblock_{}, - block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + block_2_ctile_map_{GridwiseGemm64::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op}, @@ -489,26 +492,6 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO reduce_in_element_ops_{reduce_in_element_ops}, reduce_out_element_ops_{reduce_out_element_ops} { - if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, - b_grid_desc_bk0_n_bk1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n_); - - 
c0_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c0_grid_desc_m_n_); - - c1_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c1_grid_desc_m_n_); - - reduce_grid_desc_mblock_mperblock_ = - GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_); - } } // private: @@ -524,15 +507,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO C0GridDesc_M_N c0_grid_desc_m_n_; C1GridDesc_M_N c1_grid_desc_m_n_; ReduceGridDesc_M reduce_grid_desc_m_; - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c0_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c1_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock - reduce_grid_desc_mblock_mperblock_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + typename GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map_; AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; CElementwiseOperation c_element_op_; @@ -546,7 +521,8 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, @@ -555,7 +531,20 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO { throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); } + auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.c_grid_desc_m_n_); + auto c0_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.c0_grid_desc_m_n_); + + auto c1_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.c1_grid_desc_m_n_); + + auto reduce_grid_desc_mblock_mperblock = + GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(arg.reduce_grid_desc_m_); const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); @@ -607,10 +596,10 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO arg.reduce_out_element_ops_, arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.reduce_grid_desc_mblock_mperblock_, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c1_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, arg.block_2_ctile_map_); } else @@ -657,16 +646,18 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO arg.reduce_out_element_ops_, arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.reduce_grid_desc_mblock_mperblock_, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c1_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, arg.block_2_ctile_map_); } return elapsed_time; } + INVOKER_RUN_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -683,15 +674,31 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - - return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp index 4d9dd192c9..db62dd340f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. 
All rights reserved. #pragma once @@ -80,6 +80,10 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr index_t NumATensor = AsDataType::Size(); static constexpr index_t NumBTensor = BsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size(); @@ -88,7 +92,8 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD>; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< + template + using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, @@ -111,7 +116,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -570,6 +579,8 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD()) { return false; } - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding || @@ -599,7 +609,22 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -739,7 +764,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t LayernormHDstVectorSize = PostShuffleScalarPerVector; static constexpr index_t LayernormGammaSrcVectorSize = PostShuffleScalarPerVector; @@ -268,8 +272,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; - static constexpr auto matrix_padder = MatrixPadder{ - GemmMPerBlock, GemmNPerBlock, GemmKPerBlock}; + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) { @@ -326,7 +330,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle static_assert(is_same::value); return DeviceOp:: - MakeEHGridDescriptor_M_N, GemmMPerBlock, GemmNPerBlock>( + MakeEHGridDescriptor_M_N, MPerBlock, NPerBlock>( MRaws[i], NRaws[i], DsStride[i]); }, Number{}); @@ -363,12 +367,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle // We have to separate mean var descriptor for gemm and layernorm bacause of different grid // layout(different padding) using GemmMeanVarGridDesc_M_NBlock = - decltype(MakeMeanVarDescriptor_M_N, GemmMPerBlock, GemmNPerBlock>(1, - 1)); + decltype(MakeMeanVarDescriptor_M_N, MPerBlock, NPerBlock>(1, 1)); using GemmCountGridDesc_M_NBlock = - decltype(MakeCountDescriptor_M_N, GemmMPerBlock, GemmNPerBlock>(1, - 1)); + decltype(MakeCountDescriptor_M_N, MPerBlock, NPerBlock>(1, 1)); using 
LayernormMeanVarGridDesc_M_NBlock = decltype(MakeMeanVarDescriptor_M_N, @@ -383,7 +385,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle using GammaBetaGridDesc_N = decltype(MakeDescriptor_X(1)); using EHGridDesc_M_N = decltype(MakeEHGridDescriptor_M_N, 1, 1>(1, 1, 1)); - using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< + template + using GridwiseGemmWelfordBase = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype AccDataType, CShuffleDataType, @@ -401,15 +404,15 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle GemmCountGridDesc_M_NBlock, NumGemmKPrefetchStage, BlockSize, - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, + MPerBlock, + NPerBlock, + KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -432,8 +435,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle PostShuffleScalarPerVector, LoopSched, PipelineVer>; + using GridwiseGemmWelford64 = GridwiseGemmWelfordBase; + using GridwiseGemmWelford32 = GridwiseGemmWelfordBase; - using Block2ETileMap = typename GridwiseGemmWelford::DefaultBlock2ETileMap; + using Block2ETileMap = typename GridwiseGemmWelford64::DefaultBlock2ETileMap; using GridwiseWelfordLayernorm = GridwiseWelfordSecondHalfLayernorm2d, - GemmMPerBlock, - GemmNPerBlock>(MRaw, NRaw, StrideH)}, + DeviceOp::MakeEHGridDescriptor_M_N, MPerBlock, NPerBlock>( + MRaw, NRaw, StrideH)}, layernorm_e_grid_desc_m_n_{ DeviceOp::MakeEHGridDescriptor_M_N, LayernormBlockTileSize_M_N::At(0), @@ -513,11 +517,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle LayernormBlockTileSize_M_N::At(1)>( MRaw, NRaw, StrideH)}, a_grid_desc_ak0_m_ak1_{ - GridwiseGemmWelford::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + GridwiseGemmWelford64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, b_grid_desc_bk0_n_bk1_{ - GridwiseGemmWelford::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + GridwiseGemmWelford64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, block_2_etile_map_{ - GridwiseGemmWelford::MakeDefaultBlock2ETileMap(gemm_e_grid_desc_m_n_)}, + GridwiseGemmWelford64::MakeDefaultBlock2ETileMap(gemm_e_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, cde_element_op_{cde_element_op}, @@ -525,16 +529,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle MRaw_{MRaw}, NRaw_{NRaw}, KRaw_{KRaw}, - gemm_nblock_{math::integer_divide_ceil(NRaw, GemmNPerBlock)}, + gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)}, epsilon_{static_cast(epsilon)} { // We don't need to pad in N dimension in gemm for mean/var/count. Set NPerTile 1. 
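The constructor logic that follows shows the runtime/compile-time split used throughout the patch: get_warp_size() picks a branch at run time, and an if constexpr on NXdlPerWave64/NXdlPerWave32 inside that branch keeps the instantiation that is impossible at the current wave width from ever being compiled. A minimal self-contained sketch of the idiom (C++17; the warp size is passed in here where the real code queries the device):

    // Gemm64/Gemm32 are the two gridwise instantiations described above.
    template <typename Gemm64, typename Gemm32, int NXdlPerWave64, int NXdlPerWave32>
    bool check_validity_dispatch(int warp_size)
    {
        if(warp_size == 64)
        {
            // The branch body is discarded at compile time when the Wave64 tiling
            // is not instantiable, so Gemm64 is never touched in that case.
            if constexpr(NXdlPerWave64 > 0)
                return Gemm64::CheckValidity();
        }
        else
        {
            if constexpr(NXdlPerWave32 > 0)
                return Gemm32::CheckValidity();
        }
        return false; // no valid instantiation for this warp size
    }

The same shape recurs in IsSupportedArgument, in Argument constructors such as the one below, and in the Run dispatchers generated by INVOKER_RUN_IMPL.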
gemm_mean_var_grid_desc_m_nblock_ = - DeviceOp::MakeMeanVarDescriptor_M_N, GemmMPerBlock, 1>( + DeviceOp::MakeMeanVarDescriptor_M_N, MPerBlock, 1>( MRaw, gemm_nblock_); gemm_count_grid_desc_m_nblock_ = - DeviceOp::MakeCountDescriptor_M_N, GemmMPerBlock, 1>( + DeviceOp::MakeCountDescriptor_M_N, MPerBlock, 1>( MRaw, gemm_nblock_); layernorm_mean_var_grid_desc_m_nblock_ = @@ -558,33 +562,66 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle // D desc ds_grid_desc_m_n_(i) = - DeviceOp::MakeEHGridDescriptor_M_N, - GemmMPerBlock, - GemmNPerBlock>(MRaw, NRaw, StrideDs[i]); + DeviceOp::MakeEHGridDescriptor_M_N, MPerBlock, NPerBlock>( + MRaw, NRaw, StrideDs[i]); }); // populate desc for Ds/E/mean/var/count - if(GridwiseGemmWelford::CheckValidity(a_grid_desc_m_k_, - b_grid_desc_n_k_, - ds_grid_desc_m_n_, - gemm_e_grid_desc_m_n_, - block_2_etile_map_)) + if(get_warp_size() == 64) { - ds_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemmWelford::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n_); + if constexpr(NXdlPerWave64 > 0) + { + if(GridwiseGemmWelford64::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + gemm_e_grid_desc_m_n_, + block_2_etile_map_)) + { + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmWelford64:: + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemmWelford::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - gemm_e_grid_desc_m_n_); + e_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmWelford64:: + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + gemm_e_grid_desc_m_n_); - gemm_mean_var_grid_desc_mblock_mperblock_nblock_ = - GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock( - gemm_mean_var_grid_desc_m_nblock_); + gemm_mean_var_grid_desc_mblock_mperblock_nblock_ = GridwiseGemmWelford64:: + MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock( + gemm_mean_var_grid_desc_m_nblock_); - gemm_count_grid_desc_mblock_mperblock_nblock_ = - GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock( - gemm_count_grid_desc_m_nblock_); + gemm_count_grid_desc_mblock_mperblock_nblock_ = GridwiseGemmWelford64:: + MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock( + gemm_count_grid_desc_m_nblock_); + } + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + if(GridwiseGemmWelford32::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + gemm_e_grid_desc_m_n_, + block_2_etile_map_)) + { + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmWelford32:: + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmWelford32:: + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + gemm_e_grid_desc_m_n_); + + gemm_mean_var_grid_desc_mblock_mperblock_nblock_ = GridwiseGemmWelford32:: + MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock( + gemm_mean_var_grid_desc_m_nblock_); + + gemm_count_grid_desc_mblock_mperblock_nblock_ = GridwiseGemmWelford32:: + MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock( + gemm_count_grid_desc_m_nblock_); + } + } } } @@ -602,7 +639,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle // pointers const ADataType* p_a_grid_; const BDataType* p_b_grid_; - typename GridwiseGemmWelford::DsGridPointer p_ds_grid_; + typename GridwiseGemmWelford64::DsGridPointer p_ds_grid_; void* p_workspace_e_grid_; void* 
p_workspace_mean_; void* p_workspace_var_; @@ -626,15 +663,15 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle EHGridDesc_M_N h_grid_desc_m_n_; // tensor descriptors for block/thread-wise copy - typename GridwiseGemmWelford::DefaultAGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; - typename GridwiseGemmWelford::DefaultBGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; - typename GridwiseGemmWelford::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + typename GridwiseGemmWelford64::DefaultAGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + typename GridwiseGemmWelford64::DefaultBGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + typename GridwiseGemmWelford64::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemmWelford::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + typename GridwiseGemmWelford64::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemmWelford::MeanVarGridDescriptor_MBlock_MPerBlock_NBlock + typename GridwiseGemmWelford64::MeanVarGridDescriptor_MBlock_MPerBlock_NBlock gemm_mean_var_grid_desc_mblock_mperblock_nblock_; - typename GridwiseGemmWelford::CountGridDescriptor_MBlock_MPerBlock_NBlock + typename GridwiseGemmWelford64::CountGridDescriptor_MBlock_MPerBlock_NBlock gemm_count_grid_desc_mblock_mperblock_nblock_; // block-to-e-tile map @@ -657,8 +694,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle struct Invoker : public BaseInvoker { using Argument = DeviceOp::Argument; - - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float avg_time = 0; @@ -787,7 +824,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle return launch_kernel(integral_constant{}); } } - + using GridwiseGemm32 = GridwiseGemmWelford32; + using GridwiseGemm64 = GridwiseGemmWelford64; + INVOKER_RUN_IMPL // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -857,11 +896,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - // check vector load/store { using Row = ck::tensor_layout::gemm::RowMajor; @@ -944,12 +982,36 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle return false; } } - - return GridwiseGemmWelford::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.gemm_e_grid_desc_m_n_, - arg.block_2_etile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemmWelford64::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.gemm_e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + else + { + return false; + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemmWelford32::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.gemm_e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + else + { + return false; + } + } } // polymorphic @@ -1060,9 +1122,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle str << "DeviceGemmMultipleDLayernorm_Xdl_CShuffle" << "<" << BlockSize << ", " - << GemmMPerBlock << ", " - << GemmNPerBlock << ", " - << GemmKPerBlock << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " << AK1 << ", " << BK1 << ", " << 
getGemmSpecializationString(GemmSpec) << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index 8ae6761769..ccc9c7f9b8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -60,26 +60,30 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, const Block2ETileMap block_2_etile_map) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_rs_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - qs_element_op, - rs_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - rs_grid_desc_mblock_mperblock, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_rs_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + qs_element_op, + rs_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + rs_grid_desc_mblock_mperblock, + block_2_etile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -185,6 +189,10 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle { using DeviceOp = DeviceGemmMultipleDMultipleR_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumRTensor = RsDataType::Size(); @@ -282,7 +290,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle using RGridDesc_M = decltype(MakeRGridDescriptor_M(1)); // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + template + using GridwiseGemmBase = GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, CShuffleDataType, @@ -312,7 +321,7 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -335,15 +344,17 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle CDEReduceThreadTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using AGridDesc_AK0_M_AK1 = - remove_cvref_t; using BGridDesc_BK0_N_BK1 = - remove_cvref_t; - using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; + using Block2ETileMap = typename 
GridwiseGemm64::DefaultBlock2ETileMap; // Argument struct Argument : public BaseArgument @@ -370,86 +381,52 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle p_ds_grid_{}, // FIXME p_e_grid_{static_cast(p_e_grid)}, p_rs_grid_{}, // FIXME + MRaw_(MRaw), + NRaw_(NRaw), a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)}, b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)}, r_grid_desc_m_{DeviceOp::MakeRGridDescriptor_M(MRaw)}, a_grid_desc_ak0_m_ak1_{ - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, b_grid_desc_bk0_n_bk1_{ - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, - ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, - e_grid_desc_mblock_mperblock_nblock_nperblock_{}, - rs_grid_desc_mblock_mperblock_{}, - block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, cde_element_op_{cde_element_op}, qs_element_op_{qs_element_op}, rs_element_op_{rs_element_op} { - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, - b_grid_desc_n_k_, - e_grid_desc_m_n_, - r_grid_desc_m_, - block_2_etile_map_)) - { - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); - - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t>; - - p_ds_grid_(i) = static_cast(p_ds_grid[i]); - - const auto d_grid_desc_m_n = - DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]); - - ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - d_grid_desc_m_n); - }); - - static_for<0, NumRTensor, 1>{}([&](auto i) { - using RDataType = remove_cvref_t>; - - p_rs_grid_(i) = static_cast(p_rs_grid[i]); - - rs_grid_desc_mblock_mperblock_(i) = - GridwiseGemm::MakeRGridDescriptor_MBlock_MPerBlock(r_grid_desc_m_); - }); - } + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + stride_ds_[i] = StrideDs[i]; + }); + static_for<0, NumRTensor, 1>{}([&](auto i) { + using RDataType = remove_cvref_t>; + p_rs_grid_(i) = static_cast(p_rs_grid[i]); + }); } // private: // pointers const ADataType* p_a_grid_; const BDataType* p_b_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; + typename GridwiseGemm64::DsGridPointer p_ds_grid_; EDataType* p_e_grid_; - typename GridwiseGemm::RsGridPointer p_rs_grid_; - + typename GridwiseGemm64::RsGridPointer p_rs_grid_; + index_t MRaw_; + index_t NRaw_; + std::array stride_ds_; // tensor descriptors AGridDesc_M_K a_grid_desc_m_k_; BGridDesc_N_K b_grid_desc_n_k_; EGridDesc_M_N e_grid_desc_m_n_; RGridDesc_M r_grid_desc_m_; - // tensor descriptors for block/thread-wise copy AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; - StaticallyIndexedArray< - typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, - NumDTensor> - ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different - // type from E - typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - 
e_grid_desc_mblock_mperblock_nblock_nperblock_; - - StaticallyIndexedArray - rs_grid_desc_mblock_mperblock_; - // block-to-e-tile map Block2ETileMap block_2_etile_map_; @@ -466,7 +443,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, arg.b_grid_desc_n_k_, @@ -476,6 +454,31 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle { throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); } + StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor> + ds_grid_desc_mblock_mperblock_nblock_nperblock = {}; + + StaticallyIndexedArray + rs_grid_desc_mblock_mperblock = {}; + + auto e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.e_grid_desc_m_n_); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + const auto d_grid_desc_m_n = + DeviceOp::MakeEGridDescriptor_M_N(arg.MRaw_, arg.NRaw_, arg.stride_ds_[i]); + ds_grid_desc_mblock_mperblock_nblock_nperblock(i) = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + d_grid_desc_m_n); + }); + + static_for<0, NumRTensor, 1>{}([&](auto i) { + rs_grid_desc_mblock_mperblock(i) = + GridwiseGemm64::MakeRGridDescriptor_MBlock_MPerBlock(arg.r_grid_desc_m_); + }); const index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); @@ -526,9 +529,9 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle arg.rs_element_op_, arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.rs_grid_desc_mblock_mperblock_, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + rs_grid_desc_mblock_mperblock, arg.block_2_etile_map_); }; @@ -546,6 +549,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle return ave_time; } + INVOKER_RUN_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -556,16 +561,33 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - - return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.e_grid_desc_m_n_, - arg.r_grid_desc_m_, - arg.block_2_etile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.e_grid_desc_m_n_, + arg.r_grid_desc_m_, + arg.block_2_etile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.e_grid_desc_m_n_, + arg.r_grid_desc_m_, + arg.block_2_etile_map_); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index c7481997a9..c051a080ea 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -54,23 +54,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -161,6 +164,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD { using DeviceOp = DeviceGemmMultipleD_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); @@ -247,7 +253,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD(1, 1, 1)); // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + template + using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, @@ -268,7 +275,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // desc for blockwise copy using AGridDesc_AK0_M_AK1 = - remove_cvref_t; using BGridDesc_BK0_N_BK1 = - remove_cvref_t; using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))>; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; // block-to-e-tile map using Block2ETileMap = - remove_cvref_t; + remove_cvref_t; #ifndef __HIPCC_RTC__ // Argument @@ -337,12 +346,10 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD(MRaw, NRaw, StrideE)}, a_grid_desc_ak0_m_ak1_{ - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, b_grid_desc_bk0_n_bk1_{ - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, - ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, - e_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, cde_element_op_{cde_element_op}, @@ -362,22 +369,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD(MRaw, NRaw, StrideDs[i]); }); - - // 
populate desc for Ds/E - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, - b_grid_desc_n_k_, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_)) - { - ds_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n_); - - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); - } } void Print() const @@ -393,7 +384,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, arg.b_grid_desc_n_k_, @@ -438,6 +427,13 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD()) + { + return false; + } + if(!IsSupported(arg.MRaw_, arg.NRaw_, arg.KRaw_)) { return false; } - return IsSupported(arg.MRaw_, arg.NRaw_, arg.KRaw_) and - GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + return false; } // polymorphic @@ -735,18 +754,18 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD; using AGridDesc_AK0_M_AK1 = - remove_cvref_t; using BGridDesc_BK0_N_BK1 = - remove_cvref_t; using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_tuple()))>; using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>; - using Block2ETileMap = remove_cvref_t; // tensor descriptors for problem definiton @@ -790,21 +809,21 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD 0) + { + return GridwiseGemm64::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map) and + IsSupported(MRaw, NRaw, KRaw) and + GridwiseGemm64::template IsValidCompilationParameter<>(); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map) and + IsSupported(MRaw, NRaw, KRaw) and + GridwiseGemm32::template IsValidCompilationParameter<>(); + } + } + return false; } constexpr index_t GetBlockSize() const { return BlockSize; } @@ -854,10 +894,12 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD; + __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()]; if(desc.has_main_k_block_loop) { GridwiseGemm::template Run( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp index 010a23c66c..66e727faa5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -74,10 +74,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad BElementwiseOperation, CDEElementwiseOperation> { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr auto I1 = Number<1>{}; static constexpr index_t NumDTensor = DsDataType::Size(); - using GridwiseGemm = GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad< + template + using GridwiseGemmBase = GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad< ALayout, BLayout, DsLayout, @@ -104,7 +108,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, @@ -121,13 +125,17 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, arg.b_grid_desc_n_k_, @@ -191,6 +199,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad } } + INVOKER_RUN3_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -200,11 +210,10 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - if(!ck::is_lds_direct_load_supported()) { return false; @@ -288,11 +297,29 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad } } - return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + return false; } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp index 0796614bd4..7e9020d796 100644 --- 
a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -82,9 +82,13 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK { static constexpr index_t NumDTensor = DsDataType::Size(); + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3< + template + using GridwiseGemmBase = GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, @@ -108,7 +112,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; - + using Argument = typename GridwiseGemm64::Argument; // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -169,7 +176,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK DsSize; - Argument arg_ = arg; + auto arg_ = arg; const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); @@ -188,8 +195,13 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK>; DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); }); - ck::utility::RotatingMemWrapperMultiD rotating_mem( - arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + ck::utility::RotatingMemWrapperMultiD + rotating_mem(arg_, + stream_config.rotating_count, + size_a_buffer, + size_b_buffer, + DsSize); rotating_mem.Print(); auto run_flush_cache = [&]() { @@ -624,6 +636,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK()) + { + return false; + } + if(is_gfx11_supported() && arg.KBatch > 1) { return false; } - if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) { return false; @@ -658,7 +675,22 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -781,7 +813,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3< + template + using GridwiseGemmBase = GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, @@ -122,7 +126,7 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, 
ABlockTransferSrcAccessOrder, @@ -140,7 +144,7 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, @@ -149,13 +153,17 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 ComputeTypeB, LDSTypeA, LDSTypeB>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -180,7 +188,7 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 const auto Run = [&](const auto& kernel) { if(stream_config.flush_cache) { - Argument arg_ = arg; + auto arg_ = arg; const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); @@ -192,7 +200,7 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); - ck::utility::RotatingMemWrapper rotating_mem( + ck::utility::RotatingMemWrapper rotating_mem( arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); rotating_mem.Print(); @@ -293,6 +301,8 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 return ave_time; } + INVOKER_RUN3_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -309,7 +319,7 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } @@ -328,7 +338,22 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 return false; } - return GridwiseGemm::CheckValidity(arg); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -456,7 +481,7 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 << "BlkGemmPipelineVersion: " << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " << "BlkGemmPipelinePrefetchStages: " - << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages; + << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages; // clang-format on return str.str(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp index 4761ee2026..832267fdd9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
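Before the b_preshuffle hunks continue, it is worth spelling out what the INVOKER_RUN3_IMPL line added above stands for. The macro itself lives in device_base.hpp and is not shown in this patch, so the following is a hypothetical reconstruction of the Run it generates: a forward to the RunImp template introduced in each Invoker, with a reinterpret_cast that mirrors the one visible in IsSupportedArgument and assumes GridwiseGemm64::Argument and GridwiseGemm32::Argument are layout-identical.

    // Hypothetical expansion of INVOKER_RUN3_IMPL inside an Invoker (sketch only):
    float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
    {
        if(get_warp_size() == 64)
        {
            if constexpr(NXdlPerWave64 > 0)
                return RunImp<GridwiseGemm64>(arg, stream_config);
        }
        else
        {
            if constexpr(NXdlPerWave32 > 0)
                // Argument aliases GridwiseGemm64::Argument; both instantiations
                // must keep layout-compatible Argument structs for this to hold.
                return RunImp<GridwiseGemm32>(
                    reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg),
                    stream_config);
        }
        return -1.f; // assumption: a negative elapsed time signals "not run"
    }

INVOKER_RUN_IMPL, used by the invokers whose RunImp takes the shared Argument type directly, presumably generates the same dispatch without the cast.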
#pragma once @@ -82,10 +82,14 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle BElementwiseOperation, CElementwiseOperation> { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle< + template + using GridwiseGemmBase = GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle< ALayout, BLayout, DsLayout, @@ -109,7 +113,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -136,15 +140,18 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle ComputeTypeB, LDSTypeA, LDSTypeB>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; - + using Argument = typename GridwiseGemm64::Argument; int GetPreShuffleParameters() override { return NPerXDL; } // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -172,7 +179,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle std::array DsSize; - Argument arg_ = arg; + auto arg_ = arg; const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); @@ -191,8 +198,13 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle using DDataType = remove_cvref_t>; DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); }); - ck::utility::RotatingMemWrapperMultiD rotating_mem( - arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + ck::utility::RotatingMemWrapperMultiD + rotating_mem(arg_, + stream_config.rotating_count, + size_a_buffer, + size_b_buffer, + DsSize); rotating_mem.Print(); auto run_flush_cache = [&]() { @@ -492,6 +504,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle return ave_time; } + INVOKER_RUN3_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -508,11 +522,14 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) + { + return false; + } + if(is_gfx11_supported() && arg.KBatch > 1) { return false; } - if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) { return false; @@ -531,7 +548,22 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle return false; } - return GridwiseGemm::CheckValidity(arg); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -652,7 +684,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle << "BlkGemmPipelineVersion: " << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " << "BlkGemmPipelinePrefetchStages: " - << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages; + << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages; // 
clang-format on return str.str(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp index c446ca59ea..4dcdff9153 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp @@ -93,9 +93,13 @@ struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle CElementwiseOperation> { static constexpr index_t NumDTensor = DsDataType::Size(); + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle< + template + using GridwiseGemmBase = GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle< ALayout, BLayout, DsLayout, @@ -122,7 +126,7 @@ struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -140,7 +144,7 @@ struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, @@ -149,15 +153,19 @@ struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle ComputeTypeB, LDSTypeA, LDSTypeB>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; int GetPreShuffleParameters() override { return NPerXDL; } // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -182,7 +190,7 @@ struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle const auto Run = [&](const auto& kernel) { if(stream_config.flush_cache) { - Argument arg_ = arg; + auto arg_ = arg; const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); @@ -194,7 +202,7 @@ struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); - ck::utility::RotatingMemWrapper rotating_mem( + ck::utility::RotatingMemWrapper rotating_mem( arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); rotating_mem.Print(); @@ -322,6 +330,8 @@ struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle return ave_time; } + INVOKER_RUN3_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -338,16 +348,19 @@ struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - // if(ScaleBlockM % MPerBlock != 0 
|| ScaleBlockN % NPerBlock != 0 || ScaleBlockK !=
        // KPerBlock)
        // {
        //     return false;
        // }
+        if(is_gfx11_supported() && arg.KBatch > 1)
+        {
+            return false;
+        }
         if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
         {
             return false;
@@ -367,7 +380,22 @@ struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle
             return false;
         }

-        return GridwiseGemm::CheckValidity(arg);
+        if(get_warp_size() == 64)
+        {
+            if constexpr(NXdlPerWave64 > 0)
+            {
+                return GridwiseGemm64::CheckValidity(arg);
+            }
+        }
+        else
+        {
+            if constexpr(NXdlPerWave32 > 0)
+            {
+                return GridwiseGemm32::CheckValidity(
+                    reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
+            }
+        }
+        return false;
     }

     // polymorphic
@@ -495,7 +523,7 @@ struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle
             << "BlkGemmPipelineVersion: " << BlkGemmPipelineVersionToString[BlkGemmPipelineVer]
             << ", "
             << "BlkGemmPipelinePrefetchStages: "
-            << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
+            << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;

         // clang-format on
         return str.str();
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp
index c88294edfa..14a1824508 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -78,6 +78,10 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
 {
     using DeviceOp = DeviceGemmReduce_Xdl_CShuffle;

+    GET_NXDL_PER_WAVE_IMPL
+    static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
+    static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
+
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
@@ -378,7 +382,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
     using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1));

     // GridwiseGemm
-    using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
+    template <index_t NXdlPerWave_>
+    using GridwiseGemmBase = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
         ADataType, // TODO: distinguish A/B datatype
         GemmAccDataType,
         CShuffleDataType,
@@ -407,7 +412,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
         MPerXDL,
         NPerXDL,
         MXdlPerWave,
-        NXdlPerWave,
+        NXdlPerWave_,
         ABlockTransferThreadClusterLengths_AK0_M_AK1,
         ABlockTransferThreadClusterArrangeOrder,
         ABlockTransferSrcAccessOrder,
@@ -432,6 +437,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
         CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
         CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
         LoopSched>;
+    using GridwiseGemm64 = GridwiseGemmBase<NXdlPerWave64>;
+    using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;

     // Argument
     struct Argument : public BaseArgument
@@ -459,27 +466,13 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
               b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
               c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
               reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)},
-              c_grid_desc_mblock_mperblock_nblock_nperblock_{},
-              reduce_grid_desc_mblock_mperblock_{},
-              block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
+              block_2_ctile_map_{GridwiseGemm64::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
               c_element_op_{c_element_op},
               reduce_in_element_ops_{reduce_in_element_ops},
               reduce_out_element_ops_{reduce_out_element_ops}
         {
-            if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
-                                           b_grid_desc_bk0_n_bk1_,
-                                           c_grid_desc_m_n_,
-                                           block_2_ctile_map_))
-            {
-                c_grid_desc_mblock_mperblock_nblock_nperblock_ =
-                    GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                        c_grid_desc_m_n_);
-
-                reduce_grid_desc_mblock_mperblock_ =
-                    GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_);
-            }
         }

         // private:
@@ -491,11 +484,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
         BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
         CGridDesc_M_N c_grid_desc_m_n_;
         ReduceGridDesc_M reduce_grid_desc_m_;
-        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-            c_grid_desc_mblock_mperblock_nblock_nperblock_;
-        typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock
-            reduce_grid_desc_mblock_mperblock_;
-        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
+        typename GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map_;
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
         CElementwiseOperation c_element_op_;
@@ -508,7 +497,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
     {
         using Argument = DeviceOp::Argument;

-        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        template <typename GridwiseGemm>
+        float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
             if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
             {
@@ -536,6 +526,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
             {
                 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
             }
+            auto c_grid_desc_mblock_mperblock_nblock_nperblock =
+                GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                    arg.c_grid_desc_m_n_);
+
+            auto reduce_grid_desc_mblock_mperblock =
+                GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(arg.reduce_grid_desc_m_);

             const index_t grid_size =
                 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
@@ -563,26 +559,25 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
                     typename GridwiseGemm::DefaultBlock2CTileMap,
                     true>;

-                elapsed_time =
-                    launch_and_time_kernel(stream_config,
-                                           kernel,
-                                           dim3(grid_size),
-                                           dim3(BlockSize),
-                                           0,
-                                           arg.p_a_grid_,
-                                           arg.p_b_grid_,
-                                           arg.p_c_grid_,
-                                           arg.p_reduces_grid_,
-                                           arg.a_element_op_,
-                                           arg.b_element_op_,
-                                           arg.c_element_op_,
-                                           arg.reduce_in_element_ops_,
-                                           arg.reduce_out_element_ops_,
-                                           arg.a_grid_desc_ak0_m_ak1_,
-                                           arg.b_grid_desc_bk0_n_bk1_,
-                                           arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                                           arg.reduce_grid_desc_mblock_mperblock_,
-                                           arg.block_2_ctile_map_);
+                elapsed_time = launch_and_time_kernel(stream_config,
+                                                      kernel,
+                                                      dim3(grid_size),
+                                                      dim3(BlockSize),
+                                                      0,
+                                                      arg.p_a_grid_,
+                                                      arg.p_b_grid_,
+                                                      arg.p_c_grid_,
+                                                      arg.p_reduces_grid_,
+                                                      arg.a_element_op_,
+                                                      arg.b_element_op_,
+                                                      arg.c_element_op_,
+                                                      arg.reduce_in_element_ops_,
+                                                      arg.reduce_out_element_ops_,
+                                                      arg.a_grid_desc_ak0_m_ak1_,
+                                                      arg.b_grid_desc_bk0_n_bk1_,
+                                                      c_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                      reduce_grid_desc_mblock_mperblock,
+                                                      arg.block_2_ctile_map_);
             }
             else
             {
@@ -603,31 +598,32 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
                     typename GridwiseGemm::DefaultBlock2CTileMap,
                     false>;

-                elapsed_time =
-                    launch_and_time_kernel(stream_config,
-                                           kernel,
-                                           dim3(grid_size),
-                                           dim3(BlockSize),
-                                           0,
-                                           arg.p_a_grid_,
-                                           arg.p_b_grid_,
-                                           arg.p_c_grid_,
-                                           arg.p_reduces_grid_,
-                                           arg.a_element_op_,
-                                           arg.b_element_op_,
-                                           arg.c_element_op_,
-                                           arg.reduce_in_element_ops_,
-                                           arg.reduce_out_element_ops_,
-                                           arg.a_grid_desc_ak0_m_ak1_,
-                                           arg.b_grid_desc_bk0_n_bk1_,
-                                           arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                                           arg.reduce_grid_desc_mblock_mperblock_,
-                                           arg.block_2_ctile_map_);
+                elapsed_time = launch_and_time_kernel(stream_config,
+                                                      kernel,
+                                                      dim3(grid_size),
+                                                      dim3(BlockSize),
+                                                      0,
+                                                      arg.p_a_grid_,
+                                                      arg.p_b_grid_,
+                                                      arg.p_c_grid_,
+                                                      arg.p_reduces_grid_,
+                                                      arg.a_element_op_,
+                                                      arg.b_element_op_,
+                                                      arg.c_element_op_,
+                                                      arg.reduce_in_element_ops_,
+                                                      arg.reduce_out_element_ops_,
+                                                      arg.a_grid_desc_ak0_m_ak1_,
+                                                      arg.b_grid_desc_bk0_n_bk1_,
+                                                      c_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                      reduce_grid_desc_mblock_mperblock,
+                                                      arg.block_2_ctile_map_);
             }

             return elapsed_time;
         }

+        INVOKER_RUN_IMPL
+
         // polymorphic
         float Run(const BaseArgument* p_arg,
                   const StreamConfig& stream_config = StreamConfig{}) override
@@ -644,15 +640,31 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(!ck::is_xdl_supported())
+        if(!ck::is_xdl_wmma_supported())
         {
             return false;
         }
-
-        return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
-                                           arg.b_grid_desc_bk0_n_bk1_,
-                                           arg.c_grid_desc_m_n_,
-                                           arg.block_2_ctile_map_);
+        if(get_warp_size() == 64)
+        {
+            if constexpr(NXdlPerWave64 > 0)
+            {
+                return GridwiseGemm64::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
+                                                     arg.b_grid_desc_bk0_n_bk1_,
+                                                     arg.c_grid_desc_m_n_,
+                                                     arg.block_2_ctile_map_);
+            }
+        }
+        else
+        {
+            if constexpr(NXdlPerWave32 > 0)
+            {
+                return
GridwiseGemm32::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp index 5188ece333..2e4abe9819 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -69,6 +69,10 @@ struct DeviceGemmXdl : public DeviceGemm { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -76,7 +80,8 @@ struct DeviceGemmXdl : public DeviceGemm{}; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext< + template + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext< BlockSize, ADataType, // TODO: distinguish A/B datatype AccDataType, @@ -96,7 +101,7 @@ struct DeviceGemmXdl : public DeviceGemm; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& karg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -160,6 +169,8 @@ struct DeviceGemmXdl : public DeviceGemm 0) + { + return GridwiseGemm64::CheckValidity(karg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(karg)); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp index dd41f4bca0..327b9523b8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
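The change above is the template for every device op in this patch: instantiate the gridwise GEMM once per wave size, then dispatch at runtime on get_warp_size() behind an if constexpr guard. The following condensed, self-contained sketch shows that pattern in one place; MyDeviceOp and MyGridwiseGemm are illustrative stand-ins, not the actual GET_NXDL_PER_WAVE_IMPL macro from device_base.hpp.

#include <cstdint>

using index_t = std::int32_t;

template <index_t NXdlPerWave_>
struct MyGridwiseGemm // stand-in for a GridwiseGemmBase<NXdlPerWave_> instantiation
{
    static bool CheckValidity() { return NXdlPerWave_ > 0; }
};

template <index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MXdlPerWave>
struct MyDeviceOp
{
    // Same arithmetic as the GetNXdlPerWave() helper this patch factors into a
    // macro: recompute NXdlPerWave for a given wave size, with 0 meaning "this
    // tile shape has no valid wave mapping".
    template <bool isWave64>
    static constexpr index_t GetNXdlPerWave()
    {
        constexpr index_t Waves  = isWave64 ? BlockSize / 64 : BlockSize / 32;
        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXDL);
        constexpr index_t NWaves = MWaves > 0 ? Waves / MWaves : 0;
        return (NWaves > 0 && NPerBlock % (NPerXDL * NWaves) == 0)
                   ? NPerBlock / (NWaves * NPerXDL)
                   : 0;
    }

    static constexpr index_t NXdlPerWave64 = GetNXdlPerWave<true>();
    static constexpr index_t NXdlPerWave32 = GetNXdlPerWave<false>();

    using GridwiseGemm64 = MyGridwiseGemm<NXdlPerWave64>;
    using GridwiseGemm32 = MyGridwiseGemm<NXdlPerWave32>;

    // warp_size plays the role of ck::get_warp_size() on the current GPU target.
    static bool IsSupported(index_t warp_size)
    {
        if(warp_size == 64)
        {
            if constexpr(NXdlPerWave64 > 0) // never touch an invalid instantiation
            {
                return GridwiseGemm64::CheckValidity();
            }
        }
        else
        {
            if constexpr(NXdlPerWave32 > 0)
            {
                return GridwiseGemm32::CheckValidity();
            }
        }
        return false;
    }
};

The if constexpr guards are load-bearing: one of the two instantiations may be invalid (NXdlPerWave == 0), and the guard keeps its CheckValidity call from being compiled at all in that case.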
#pragma once @@ -80,12 +80,17 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + template + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ALayout, BLayout, CLayout, @@ -109,7 +114,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -177,6 +186,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm()) { return false; } - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding || @@ -206,7 +216,22 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp index ac2e826725..8daaafaed1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. 
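Several hunks above funnel an Argument built for the wave64 instantiation into GridwiseGemm32::CheckValidity via reinterpret_cast. The cast is assumed safe because the two instantiations differ only in the NXdlPerWave non-type template parameter, which does not appear in any Argument data member, so both Argument types share one layout. A minimal sketch of that property, with illustrative types rather than the CK classes:

#include <type_traits>

using index_t = int;

template <index_t NXdlPerWave_>
struct GridwiseGemmSketch
{
    struct Argument // layout does not depend on NXdlPerWave_
    {
        const void* p_a;
        const void* p_b;
        void* p_c;
        index_t M, N, K;
    };

    static bool CheckValidity(const Argument&) { return NXdlPerWave_ > 0; }
};

using Gemm64 = GridwiseGemmSketch<2>;
using Gemm32 = GridwiseGemmSketch<1>;

// Both Argument types are standard-layout and structurally identical; this is
// the property the reinterpret_cast in the patch relies on.
static_assert(std::is_standard_layout_v<Gemm64::Argument> &&
              sizeof(Gemm64::Argument) == sizeof(Gemm32::Argument));

bool check_on_wave32(const Gemm64::Argument& arg)
{
    return Gemm32::CheckValidity(reinterpret_cast<const Gemm32::Argument&>(arg));
}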
#pragma once @@ -69,9 +69,14 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr auto I1 = Number<1>{}; - using GridwiseGemm = GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad< + template + using GridwiseGemmBase = GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad< ALayout, BLayout, ck::Tuple<>, @@ -98,7 +103,7 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, arg.b_grid_desc_n_k_, @@ -185,6 +194,8 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm()) { return false; } - if(!ck::is_lds_direct_load_supported()) { return false; @@ -264,11 +274,29 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + return false; } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp index 3171208830..c42228369f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
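The NXdlPerWave64/NXdlPerWave32 constants that each file now defines come from the wave-count arithmetic shown being removed from device_gemm_xdl_cshuffle_v3.hpp further down (it moved into GET_NXDL_PER_WAVE_IMPL). A worked example with assumed tile numbers makes the two results concrete:

// Free-function restatement of the helper, for checking shapes by hand.
constexpr int NXdlPerWave(int BlockSize,
                          int MPerBlock,
                          int NPerBlock,
                          int MPerXDL,
                          int NPerXDL,
                          int MXdlPerWave,
                          bool isWave64)
{
    const int Waves  = isWave64 ? BlockSize / 64 : BlockSize / 32;
    const int MWaves = MPerBlock / (MXdlPerWave * MPerXDL);
    const int NWaves = MWaves > 0 ? Waves / MWaves : 0;
    return (NWaves > 0 && NPerBlock % (NPerXDL * NWaves) == 0)
               ? NPerBlock / (NWaves * NPerXDL)
               : 0;
}

// 256-thread block, 128x128 tile, 32x32 XDL, MXdlPerWave = 2 (assumed numbers):
// wave64 has 4 waves -> 2 M-waves x 2 N-waves -> NXdlPerWave64 = 128 / (2 * 32) = 2;
// wave32 has 8 waves -> 2 M-waves x 4 N-waves -> NXdlPerWave32 = 128 / (4 * 32) = 1.
static_assert(NXdlPerWave(256, 128, 128, 32, 32, 2, true) == 2);
static_assert(NXdlPerWave(256, 128, 128, 32, 32, 2, false) == 1);

// Narrower tile (NPerBlock = 64): 4 N-waves no longer divide it, so the wave32
// instantiation gets NXdlPerWave == 0 and IsSupportedArgument falls through to false.
static_assert(NXdlPerWave(256, 128, 64, 32, 32, 2, true) == 1);
static_assert(NXdlPerWave(256, 128, 64, 32, 32, 2, false) == 0);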
#pragma once @@ -75,8 +75,13 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2 { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + // GridwiseGemm - using GridwiseGemm = GridwiseGemm_xdl_cshuffle_streamk_v3< + template + using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_streamk_v3< ALayout, BLayout, CLayout, @@ -98,7 +103,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; + // // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) @@ -176,8 +186,8 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2 rotating_mem( + auto arg_ = arg; + ck::utility::RotatingMemWrapper rotating_mem( arg_, stream_config.rotating_count, arg_.M * arg_.K * sizeof(ADataType), @@ -426,6 +436,8 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2()) { return false; } @@ -481,7 +498,22 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -489,30 +521,30 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2(p_arg)); } - + template static auto - MakeArgument(const ADataType* p_a, - const BDataType* p_b, - CDataType* p_c, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - index_t streamk_sel, - index_t Grid_size, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - StreamKReductionStrategy reduction_strategy = StreamKReductionStrategy::Atomic) + MakeArgumentImp(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t streamk_sel, + index_t Grid_size, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + StreamKReductionStrategy reduction_strategy = StreamKReductionStrategy::Atomic) { constexpr index_t minimum_occupancy = BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 
1 : 2; - index_t K_split = (K + KPerBlock - 1) / KPerBlock * KPerBlock; - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - int occupancy, num_cu; + index_t K_split = (K + KPerBlock - 1) / KPerBlock * KPerBlock; + + int occupancy = 1, num_cu = 1; const auto calculate_grid_size = [&](const auto& kernel) { hip_check_error( hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); @@ -524,185 +556,193 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2; - calculate_grid_size(kernel); - } - // Tail number could be One to Seven - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) - { - - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) { + const auto kernel = kernel_gemm_xdl_cshuffle_v3; + minimum_occupancy>; calculate_grid_size(kernel); } - else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) { - const auto kernel = kernel_gemm_xdl_cshuffle_v3; - calculate_grid_size(kernel); - } - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) { const auto kernel = kernel_gemm_xdl_cshuffle_v3; + TailNumber::One>; calculate_grid_size(kernel); } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three) + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) { const auto kernel = kernel_gemm_xdl_cshuffle_v3; + TailNumber::Full>; calculate_grid_size(kernel); } - } - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + } + // Tail number could be Odd or 
Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four) + + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { const auto kernel = - kernel_gemm_xdl_cshuffle_v3; + kernel_gemm_xdl_cshuffle_v3_2lds; calculate_grid_size(kernel); } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five) + else { const auto kernel = - kernel_gemm_xdl_cshuffle_v3; + kernel_gemm_xdl_cshuffle_v3_2lds; calculate_grid_size(kernel); } } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - calculate_grid_size(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - calculate_grid_size(kernel); - } - } - } - // Tail number could be Odd or Even - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3_2lds; - calculate_grid_size(kernel); - } else { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3_2lds; - calculate_grid_size(kernel); + + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } } } else { - - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { + const auto kernel = kernel_gemm_xdl_cshuffle_v3; + minimum_occupancy>; calculate_grid_size(kernel); } - else - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3; - calculate_grid_size(kernel); - } - } - } - else - { - // Tail number always 1 - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - - const auto kernel = kernel_gemm_xdl_cshuffle_v3; - calculate_grid_size(kernel); } } @@ -720,6 +760,62 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2 0; + return MakeArgumentImp(p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + streamk_sel, + Grid_size, + a_op, + b_op, + c_op, + reduction_strategy); + } + else + { + constexpr bool IsValid = NXdlPerWave32 > 0; + return MakeArgumentImp(p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + streamk_sel, + Grid_size, + a_op, + b_op, + c_op, + reduction_strategy); + } + } static auto MakeInvoker() { return Invoker{}; } // polymorphic @@ -799,7 +895,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2 { using DeviceOp = DeviceGemm_Xdl_CShuffleV2; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v2< + template + using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v2< ALayout, BLayout, CLayout, @@ -109,7 +113,7 @@ struct DeviceGemm_Xdl_CShuffleV2 : 
public DeviceGemm; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -174,6 +182,8 @@ struct DeviceGemm_Xdl_CShuffleV2 : public DeviceGemm()) { return false; } - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding || @@ -203,7 +212,22 @@ struct DeviceGemm_Xdl_CShuffleV2 : public DeviceGemm 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp index 1cb82d24eb..d100d96583 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp @@ -176,31 +176,8 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 { - template - static constexpr auto GetNXdlPerWave() - { - constexpr index_t Waves = isWave64 ? BlockSize / 64 : BlockSize / 32; - constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXDL); - static_assert(MWaves > 0); - - constexpr index_t NWaves = Waves / MWaves; - if constexpr(NWaves == 0) - { - return 0; - } - else - { - if constexpr(NPerBlock % (NPerXDL * NWaves) == 0) - { - return NPerBlock / (NWaves * NPerXDL); - } - else - { - return 0; - } - } - } // GridwiseGemm + GET_NXDL_PER_WAVE_IMPL static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); @@ -284,6 +261,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 float RunImp(const typename GridwiseGemm::Argument& arg, const StreamConfig& stream_config = StreamConfig{}) @@ -760,31 +742,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 0) - { - return RunImp(arg, stream_config); - } - } - else - { - if constexpr(NXdlPerWave32 > 0) - { - return RunImp( - reinterpret_cast(arg), - stream_config); - } - } - return 0; - } + INVOKER_RUN3_IMPL // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -801,11 +759,10 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2()) { return false; } - if(arg.KBatch > 1) { if(is_gfx11_supported()) @@ -824,14 +781,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 || @@ -855,10 +804,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2(arg)); } - else - { - return false; - } } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index faa235be50..ebd168a7d0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -77,8 +77,13 @@ struct 
DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + // GridwiseGemm - using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3_b_preshuffle< + template + using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v3_b_preshuffle< ALayout, BLayout, CLayout, @@ -100,7 +105,7 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; static constexpr index_t APackedSize = []() { if constexpr(is_same_v, pk_i4_t>) @@ -149,7 +156,9 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -175,7 +184,7 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle rotating_mem( + ck::utility::RotatingMemWrapper rotating_mem( arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); rotating_mem.Print(); @@ -377,6 +386,8 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle()) + { + return false; + } + if(is_gfx11_supported() && arg.KBatch > 1) { return false; } - if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) { return false; @@ -411,7 +425,22 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -516,9 +545,9 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + // GridwiseGemm - using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< + template + using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, @@ -109,7 +114,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; static constexpr index_t APackedSize = []() { if constexpr(is_same_v, pk_i4_t>) @@ -156,7 +163,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -181,7 +190,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale rotating_mem( + ck::utility::RotatingMemWrapper rotating_mem( arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); rotating_mem.Print(); @@ -622,6 +631,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale()) + { + return false; + } + + if(is_gfx11_supported() && arg.KBatch > 1) { return false; } @@ -655,8 +670,23 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { - return GridwiseGemm::CheckValidity(arg); + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + 
reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -783,7 +813,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale { - // GridwiseGemm - using GridwiseGemm = conditional_t< // - !is_same_v, - GridwiseGemmMX_xdl_cshuffle_v3< - ALayout, - BLayout, - CLayout, - ADataType, - AScaleDataType, - BDataType, - BScaleDataType, - GemmAccDataType, - CShuffleDataType, - CDataType, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - GemmSpec, - ScaleBlockSize, - BlockSize, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CShuffleBlockTransferScalarPerVector_NPerBlock, - BlkGemmPipeSched, - BlkGemmPipelineVer, - ComputeTypeA, - ComputeTypeB>, - GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< - ALayout, - BLayout, - CLayout, - ADataType, - AScaleDataType, - BDataType, - BScaleDataType, - GemmAccDataType, - CShuffleDataType, - CDataType, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - GemmSpec, - ScaleBlockSize, - BlockSize, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CShuffleBlockTransferScalarPerVector_NPerBlock, - BlkGemmPipeSched, - BlkGemmPipelineVer, - ComputeTypeA, - ComputeTypeB>>; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - using Argument = typename GridwiseGemm::Argument; + // GridwiseGemm + template + using GridwiseGemmMXBase = GridwiseGemmMX_xdl_cshuffle_v3< + ALayout, + BLayout, + CLayout, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + ScaleBlockSize, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave_, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + 
BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; + template + using GridwiseGemmMXBPreshuffleBase = GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< + ALayout, + BLayout, + CLayout, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + ScaleBlockSize, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave_, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; + + using GridwiseGemm64 = conditional_t< // + !is_same_v, + GridwiseGemmMXBase, + GridwiseGemmMXBPreshuffleBase>; + using GridwiseGemm32 = conditional_t< // + !is_same_v, + GridwiseGemmMXBase, + GridwiseGemmMXBPreshuffleBase>; + + using Argument = typename GridwiseGemm64::Argument; // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -299,7 +314,7 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX rotating_mem( + ck::utility::RotatingMemWrapper rotating_mem( arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); rotating_mem.Print(); @@ -401,6 +416,7 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -578,9 +609,9 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr index_t NumDTensor = DsDataType::Size(); using PassThrough = ck::tensor_operation::element_wise::PassThrough; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< + template + using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, @@ -112,7 +117,7 @@ struct DeviceGemm_Xdl_CShuffleV3R1 : public DeviceGemmV2R1; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - struct 
Argument : public GridwiseGemm::Argument + struct Argument : public GridwiseGemm64::Argument { Argument(const ADataType* p_a_grid_, const BDataType* p_b_grid_, @@ -152,17 +159,17 @@ struct DeviceGemm_Xdl_CShuffleV3R1 : public DeviceGemmV2R1 StrideDs_, index_t StrideC_, index_t k_batch_) - : GridwiseGemm::Argument(p_a_grid_, - p_b_grid_, - reinterpret_cast(p_c_grid_), - M_, - N_, - K_, - StrideA_, - StrideB_, - StrideC_, - k_batch_, - true), + : GridwiseGemm64::Argument(p_a_grid_, + p_b_grid_, + reinterpret_cast(p_c_grid_), + M_, + N_, + K_, + StrideA_, + StrideB_, + StrideC_, + k_batch_, + true), p_ds(p_ds_), StrideDs(StrideDs_) { @@ -278,9 +285,10 @@ struct DeviceGemm_Xdl_CShuffleV3R1 : public DeviceGemmV2R1 + float RunImp(const Argument& arg_, const StreamConfig& stream_config = StreamConfig{}) { - auto arg = *dynamic_cast(&arg_); + auto arg = *reinterpret_cast(&arg_); if(!(!(arg.IsReduceAdd() || NumDTensor > 0) && std::is_same::value)) @@ -542,6 +550,8 @@ struct DeviceGemm_Xdl_CShuffleV3R1 : public DeviceGemmV2R1()) { return false; } - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding || @@ -571,7 +580,22 @@ struct DeviceGemm_Xdl_CShuffleV3R1 : public DeviceGemmV2R1 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -677,7 +701,7 @@ struct DeviceGemm_Xdl_CShuffleV3R1 : public DeviceGemmV2R1(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -380,7 +384,8 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator using C0GridDesc_N = decltype(MakeGridDescriptor_N(1)); // GridwiseGemm - using GridwiseGemm = GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + template + using GridwiseGemmBase = GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, CShuffleDataType, @@ -406,7 +411,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -424,14 +429,16 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + NXdlPerWave_, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadCopySrcDstScalarPerVector_NPerBlock, LoopSched>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Block2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + using Block2CTileMap = typename GridwiseGemm64::DefaultBlock2CTileMap; // Argument struct Argument : public BaseArgument @@ -464,26 +471,12 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, c0_grid_desc_n_{MakeGridDescriptor_N(NRaw)}, - c_grid_desc_mblock_mperblock_nblock_nperblock_{}, - c0_grid_desc_nblock_nperblock_{}, 
block_2_ctile_map_{Block2CTileMap(c_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, acc_element_op_{acc_element_op}, c_element_op_{c_element_op} { - if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, - b_grid_desc_bk0_n_bk1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n_); - - c0_grid_desc_nblock_nperblock_ = - GridwiseGemm::MakeC0GridDescriptor_NBlock_NPerBlock(c0_grid_desc_n_); - } } // private: @@ -498,9 +491,6 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; CGridDesc_M_N c_grid_desc_m_n_; C0GridDesc_N c0_grid_desc_n_; - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock_; Block2CTileMap block_2_ctile_map_; AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; @@ -513,7 +503,8 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { @@ -538,7 +529,12 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator { throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); } + auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.c_grid_desc_m_n_); + auto c0_grid_desc_nblock_nperblock = + GridwiseGemm::MakeC0GridDescriptor_NBlock_NPerBlock(arg.c0_grid_desc_n_); const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); @@ -565,28 +561,27 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator Block2CTileMap, true>; - ave_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_c0_grid_bias_, - arg.p_c0_grid_add_, - arg.p_c0_grid_gamma_, - arg.p_c0_grid_beta_, - arg.a_element_op_, - arg.b_element_op_, - arg.acc_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.c0_grid_desc_nblock_nperblock_, - arg.block_2_ctile_map_); + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_bias_, + arg.p_c0_grid_add_, + arg.p_c0_grid_gamma_, + arg.p_c0_grid_beta_, + arg.a_element_op_, + arg.b_element_op_, + arg.acc_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_desc_nblock_nperblock, + arg.block_2_ctile_map_); } else { @@ -605,33 +600,33 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock, Block2CTileMap, false>; - ave_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_c0_grid_bias_, - arg.p_c0_grid_add_, - arg.p_c0_grid_gamma_, - arg.p_c0_grid_beta_, - arg.a_element_op_, - arg.b_element_op_, - 
arg.acc_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.c0_grid_desc_nblock_nperblock_, - arg.block_2_ctile_map_); + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_bias_, + arg.p_c0_grid_add_, + arg.p_c0_grid_gamma_, + arg.p_c0_grid_beta_, + arg.a_element_op_, + arg.b_element_op_, + arg.acc_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_desc_nblock_nperblock, + arg.block_2_ctile_map_); } return ave_time; } + INVOKER_RUN_IMPL // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -648,15 +643,31 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - - return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp index 7315fe75a3..bc192b7651 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp @@ -63,6 +63,10 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -188,7 +192,8 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1< BlockSize, ADataType, // TODO: distinguish A/B datatype AccDataType, @@ -207,7 +212,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm, // CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -246,8 +253,6 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { @@ -323,7 +311,8 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm()) { return false; } - - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + 
arg.M01_, + arg.N01_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp index 2666051c86..ef4593c320 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -75,6 +75,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -86,7 +90,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK + using GridwiseGemmBase = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, @@ -107,7 +112,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - struct Argument : public GridwiseGemm::Argument + struct Argument : public GridwiseGemm64::Argument { Argument(const ADataType* p_a_grid_, const BDataType* p_b_grid_, @@ -154,20 +161,20 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { - Print(karg); + Print(arg); } + typename GridwiseGemm::Argument karg(arg.p_a_grid, + arg.p_b_grid, + arg.p_c_grid, + arg.M, + arg.N, + arg.K, + arg.StrideA, + arg.StrideB, + arg.StrideC, + arg.MPadded, + arg.NPadded, + arg.KPadded, + arg.K0Padded, + arg.k_batch); const auto kbatch = karg.k_batch; - if(!GridwiseGemm::CheckValidity(karg)) { throw std::runtime_error( @@ -227,9 +248,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK(karg), b2c_map, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + arg.a_element_op, + arg.b_element_op, + arg.c_element_op); }; if(has_main_k0_block_loop) @@ -294,6 +315,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK()) + { + return false; + } + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(karg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(karg)); + } + } + return false; } // polymorphic @@ -347,10 +389,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK PipelineVersionToString{{PipelineVersion::v1, "v1"}, {PipelineVersion::v2, "v2"}}; - str << GridwiseGemm::GetTypeString() << " LoopScheduler: " << LoopSchedToString[LoopSched] + str << GridwiseGemm64::GetTypeString() << " LoopScheduler: " << LoopSchedToString[LoopSched] << ", PipelineVersion: " << PipelineVersionToString[PipelineVer]; return str.str(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp index eda966c48a..f9e164e7b3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -70,12 +70,17 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - using GridwiseGemm = GridwiseGemm_xdlops_splitk_lds_direct_load< + template + using GridwiseGemmBase = GridwiseGemm_xdlops_splitk_lds_direct_load< BlockSize, ADataType, BDataType, @@ -96,7 +101,7 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - struct Argument : public GridwiseGemm::Argument + struct Argument : public GridwiseGemm64::Argument { Argument(const ADataType* p_a_grid_, const BDataType* p_b_grid_, @@ -134,20 +141,20 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK + void Print(const Argument_& karg) + { + karg.Print(); + } - void Print(const Argument& karg) { karg.Print(); } - - float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -175,8 +186,8 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK(&karg); + if(!GridwiseGemm::CheckValidity(arg)) { throw std::runtime_error( "wrong! 
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid " @@ -199,17 +210,16 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK(karg), - b2c_map, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg, + b2c_map, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); }; if(has_main_k0_block_loop) @@ -274,6 +284,8 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK()) { return false; } - return GridwiseGemm::CheckValidity(karg); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(karg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(karg)); + } + } + return false; } // polymorphic @@ -327,10 +358,10 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk< + template + using GridwiseGemmBase = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk< BlockSize, BlockToCTileMap_GemmStreamK; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; // Invoker struct Invoker : public BaseInvoker { - void Print(const Argument& karg) { karg.Print(); } + template + void Print(const Argument_& karg) + { + karg.Print(); + } - float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& karg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -204,6 +217,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK(pArg); - if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy == - StreamKReductionStrategy::Reduction) + if(get_warp_size() == 64) { - return p_arg->block_mapping.get_workspace_size(sizeof(typename GridwiseGemm::FloatAcc)); + if constexpr(GridwiseGemm64::Block2CTileMap::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + return p_arg->block_mapping.get_workspace_size( + sizeof(typename GridwiseGemm64::FloatAcc)); + } } else { - return 0; + if constexpr(GridwiseGemm32::Block2CTileMap::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + return p_arg->block_mapping.get_workspace_size( + sizeof(typename GridwiseGemm32::FloatAcc)); + } } + return 0; } void SetWorkSpacePointer(BaseArgument* pArg, @@ -243,11 +268,26 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK()) { return false; } - return GridwiseGemm::CheckValidity(karg); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(karg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(karg)); + } + } + return false; } // polymorphic @@ -270,12 +310,38 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK; - int occupancy, num_cu; + int num_cu; hipError_t rtn; - rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor( - &occupancy, kernel, 
BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte()); - hip_check_error(rtn); + int occupancy = [&]() { + int occupancy_ = 0; + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + const auto kernel = kernel_gemm_xdlops_streamk; + rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor( + &occupancy_, + kernel, + BlockSize, + GridwiseGemm64::GetSharedMemoryNumberOfByte()); + hip_check_error(rtn); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + const auto kernel = kernel_gemm_xdlops_streamk; + rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor( + &occupancy_, + kernel, + BlockSize, + GridwiseGemm32::GetSharedMemoryNumberOfByte()); + hip_check_error(rtn); + } + } + return occupancy_; + }(); hipDeviceProp_t dev_prop; hipDevice_t dev; @@ -316,12 +382,39 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK; - int occupancy, num_cu; + int num_cu; hipError_t rtn; - rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor( - &occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte()); - hip_check_error(rtn); + + int occupancy = [&]() { + int occupancy_ = 0; + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + const auto kernel = kernel_gemm_xdlops_streamk; + rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor( + &occupancy_, + kernel, + BlockSize, + GridwiseGemm64::GetSharedMemoryNumberOfByte()); + hip_check_error(rtn); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + const auto kernel = kernel_gemm_xdlops_streamk; + rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor( + &occupancy_, + kernel, + BlockSize, + GridwiseGemm32::GetSharedMemoryNumberOfByte()); + hip_check_error(rtn); + } + } + return occupancy_; + }(); hipDeviceProp_t dev_prop; hipDevice_t dev; @@ -352,7 +445,11 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - e_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + e_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -132,6 +135,11 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm { + static constexpr auto BlockSize = math::max(TileLoadThreadGroupSize, TileMathThreadGroupSize); + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + using DeviceOp = DeviceGemm_Xdl_WaveletModel_CShuffle; static constexpr auto I0 = Number<0>{}; @@ -201,7 +209,8 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm(1, 1, 1)); // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle< + template + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle< ADataType, // TODO: distinguish A/B datatype GemmAcEDataType, CShuffleDataType, @@ -224,7 +233,7 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using AGridDesc_AK0_M_AK1 = - remove_cvref_t; using BGridDesc_BK0_N_BK1 = - remove_cvref_t; - using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; 
+ using Block2ETileMap = typename GridwiseGemm64::DefaultBlock2ETileMap; // Argument struct Argument : public BaseArgument @@ -277,22 +288,14 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm(MRaw, NRaw, StrideE)}, a_grid_desc_ak0_m_ak1_{ - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, b_grid_desc_bk0_n_bk1_{ - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, - e_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, cde_element_op_{cde_element_op} { - if(GridwiseGemm::CheckValidity( - a_grid_desc_m_k_, b_grid_desc_n_k_, e_grid_desc_m_n_, block_2_etile_map_)) - { - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); - } } void Print() const @@ -316,8 +319,6 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { #if 0 { @@ -359,6 +361,9 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm()) { return false; } - - return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp index 5449525306..68e63fcb5b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp @@ -37,47 +37,50 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - const index_t block_id = get_block_1d_id(); - - const auto contraction_arg_ptr = reinterpret_cast( - cast_pointer_to_generic_address_space(contraction_args)); - - index_t left = 0; - index_t right = group_count; - index_t group_id = index_t((left + right) / 2); - - while((!(block_id >= contraction_arg_ptr[group_id].block_start_ && - block_id < contraction_arg_ptr[group_id].block_end_)) && - left <= right) +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - if(block_id < contraction_arg_ptr[group_id].block_start_) - { - right = group_id; - } - else - { - left = group_id; - } - group_id = index_t((left + right) / 2); - } + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - 
GridwiseGemm::template Run( - contraction_arg_ptr[group_id].p_a_grid_, - contraction_arg_ptr[group_id].p_b_grid_, - contraction_arg_ptr[group_id].p_ds_grid_, - contraction_arg_ptr[group_id].p_e_grid_, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - contraction_arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, - contraction_arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, - contraction_arg_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, - contraction_arg_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, - contraction_arg_ptr[group_id].block_2_etile_map_); + const index_t block_id = get_block_1d_id(); + + const auto contraction_arg_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(contraction_args)); + + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + + while((!(block_id >= contraction_arg_ptr[group_id].block_start_ && + block_id < contraction_arg_ptr[group_id].block_end_)) && + left <= right) + { + if(block_id < contraction_arg_ptr[group_id].block_start_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + GridwiseGemm::template Run( + contraction_arg_ptr[group_id].p_a_grid_, + contraction_arg_ptr[group_id].p_b_grid_, + contraction_arg_ptr[group_id].p_ds_grid_, + contraction_arg_ptr[group_id].p_e_grid_, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + contraction_arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, + contraction_arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, + contraction_arg_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, + contraction_arg_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + contraction_arg_ptr[group_id].block_2_etile_map_); + } #else ignore = contraction_args; ignore = group_count; @@ -165,6 +168,9 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle { using DeviceOp = DeviceGroupedContractionMultipleD_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; @@ -357,7 +363,8 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle using ComputeDataType = ADataType; // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + template + using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype BDataType, ComputeDataType, @@ -378,7 +385,7 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -400,31 +407,33 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // desc for blockwise copy using AGridDesc_AK0_M_AK1 = - remove_cvref_t; using BGridDesc_BK0_N_BK1 = - remove_cvref_t; using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))>; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; struct 
GroupedContractionBlock2ETileMap { // block-to-e-tile map using Block2ETileMap = - remove_cvref_t; + remove_cvref_t; GroupedContractionBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, ck::index_t BlockStart) { - default_block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); + default_block_2_etile_map_ = GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); block_start_ = BlockStart; } @@ -457,7 +466,7 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle // pointers const ADataType* p_a_grid_; const BDataType* p_b_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; + typename GridwiseGemm64::DsGridPointer p_ds_grid_; EDataType* p_e_grid_; // tensor descriptors for block/thread-wise copy @@ -531,7 +540,7 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle contraction_descs[i].b_ns_ks_lengths, contraction_descs[i].b_ns_ks_strides); DsGridDesc_M_N ds_grid_desc_m_n; - typename GridwiseGemm::DsGridPointer p_ds_grid; + typename GridwiseGemm64::DsGridPointer p_ds_grid; // populate pointer, batch stride, desc for Ds static_for<0, NumDTensor, 1>{}([&](auto j) { @@ -550,19 +559,19 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle contraction_descs[i].e_ms_ns_lengths, contraction_descs[i].e_ms_ns_strides); const auto a_grid_desc_ak0_m_ak1 = - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); + GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); const auto b_grid_desc_bk0_n_bk1 = - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k); + GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k); const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_grid_desc_m_n); const auto e_grid_desc_mblock_mperblock_nblock_nperblock = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( e_grid_desc_m_n); const index_t grid_size_grp = - GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n) + GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n) .CalculateGridSize(e_grid_desc_m_n); const index_t BlockStart = grid_size_; @@ -592,11 +601,30 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle const index_t e_nz_stride = contraction_descs[i].e_ms_ns_strides[NumDimM + NumDimN - 1]; - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, - b_grid_desc_n_k, - ds_grid_desc_m_n, - e_grid_desc_m_n, - block_2_etile_map)) + bool valid = false; + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + valid = GridwiseGemm64::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + valid = GridwiseGemm32::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map); + } + } + if(valid) { contraction_multi_d_kernel_args_.push_back( {p_a_grid, @@ -642,7 +670,8 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { bool has_main_k_block_loop = true; @@ -701,6 +730,8 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle return ave_time; } + INVOKER_RUN_IMPL + // 
polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -711,11 +742,10 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - for(std::size_t i = 0; i < arg.group_count_; i++) { const auto a_grid_desc_m_k_ = arg.contraction_multi_d_device_args_[i].a_grid_desc_m_k_; @@ -744,11 +774,30 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle const auto ds_nz_stride_ = arg.contraction_multi_d_device_args_[i].ds_nz_stride_; const auto e_nz_stride_ = arg.contraction_multi_d_device_args_[i].e_nz_stride_; - if(!GridwiseGemm::CheckValidity(a_grid_desc_m_k_, - b_grid_desc_n_k_, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_)) + bool valid = false; + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + valid = GridwiseGemm64::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + valid = GridwiseGemm32::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_); + } + } + if(!valid) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 25923235c3..57ea476ced 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -96,83 +96,64 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const ComputePtrOffsetOfN compute_ptr_offset_of_n, const index_t KBatch) { -#if defined(__gfx9__) - // offset base pointer for each work-group - const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x); - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); - const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / KBatch); - const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch); - - const long_index_t a_batch_offset = - CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)) - : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = - CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)) - : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); - const long_index_t e_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); - - const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); - - const long_index_t a_n_offset = - CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); - const long_index_t b_n_offset = - CTranspose ? 
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0; - - const long_index_t e_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); - - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - DsPointer p_ds_grid_grp; - - static constexpr index_t NumDTensor = DsPointer::Size(); - - static_for<0, NumDTensor, 1>{}( - [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - - index_t left = 0; - index_t right = gemms_count; - index_t group_id = index_t((left + right) / 2); - while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ && - block_args_id < gemm_kernel_args[group_id].BlockEnd_)) && - left <= right) +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - if(block_args_id < gemm_kernel_args[group_id].BlockStart_) + // offset base pointer for each work-group + const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x); + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / KBatch); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch); + + const long_index_t a_batch_offset = + CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)) + : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)) + : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + const long_index_t a_n_offset = + CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t b_n_offset = + CTranspose ? 
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0; + + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = DsPointer::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + index_t left = 0; + index_t right = gemms_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ && + block_args_id < gemm_kernel_args[group_id].BlockEnd_)) && + left <= right) { - right = group_id; + if(block_args_id < gemm_kernel_args[group_id].BlockStart_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); } - else - { - left = group_id; - } - group_id = index_t((left + right) / 2); - } - if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm) - { - GridwiseGemm::template Run( - p_a_grid + a_batch_offset + a_n_offset, - p_b_grid + b_batch_offset + b_n_offset, - p_ds_grid_grp, - p_e_grid + e_batch_offset + e_n_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, - gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, - gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, - gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, - gemm_kernel_args[group_id].block_2_ctile_map_, - KBatch, - k_idx); - } - else - { - if(gemm_kernel_args[group_id].HasMainKBlockLoop_) + if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm) { - GridwiseGemm::template Run( + GridwiseGemm::template Run( p_a_grid + a_batch_offset + a_n_offset, p_b_grid + b_batch_offset + b_n_offset, p_ds_grid_grp, @@ -191,22 +172,44 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) } else { - GridwiseGemm::template Run( - p_a_grid + a_batch_offset + a_n_offset, - p_b_grid + b_batch_offset + b_n_offset, - p_ds_grid_grp, - p_e_grid + e_batch_offset + e_n_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, - gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, - gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, - gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, - gemm_kernel_args[group_id].block_2_ctile_map_, - KBatch, - k_idx); + if(gemm_kernel_args[group_id].HasMainKBlockLoop_) + { + GridwiseGemm::template Run( + p_a_grid + a_batch_offset + a_n_offset, + p_b_grid + b_batch_offset + b_n_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].block_2_ctile_map_, + KBatch, + k_idx); + } + else + { + GridwiseGemm::template Run( + p_a_grid + a_batch_offset + a_n_offset, + p_b_grid + b_batch_offset + b_n_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, 
+ gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].block_2_ctile_map_, + KBatch, + k_idx); + } } } #else @@ -316,6 +319,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 : 32; using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr GemmSpecialization GemmSpec = GemmSpecialization::MNKPadding; @@ -436,7 +442,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 #define GridwiseGemmMultiDTemplateParams \ ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ AElementwiseOp, BElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \ - MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ + MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, \ ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ @@ -451,7 +457,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 #define GridwiseGemmCTransposeTemplateParameters \ ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ BElementwiseOp, AElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \ - NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, MPerXDL, NXdlPerWave, MXdlPerWave, \ + NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, MPerXDL, NXdlPerWave_, MXdlPerWave, \ BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ @@ -463,17 +469,26 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType - using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle; - using GridwiseGemmCTranspose = std::conditional_t< - CTranspose, - GridwiseGemmMultipleD_xdl_cshuffle, - GridwiseGemm>; + template + using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle; + template + using GridwiseGemmCTransposeBase = + GridwiseGemmMultipleD_xdl_cshuffle; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; + + using GridwiseGemmCTranspose64 = + std::conditional_t, + GridwiseGemm64>; + using GridwiseGemmCTranspose32 = + std::conditional_t, GridwiseGemm32>; template static auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N e_grid_desc_m_n) { - return GridwiseGemmCTranspose::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + return GridwiseGemmCTranspose64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( e_grid_desc_m_n); } @@ -504,14 +519,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{})); using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + 
decltype(GridwiseGemmCTranspose64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{})); using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{})); // block-to-e-tile map using Block2ETileMap = - decltype(GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(EGridDesc_M_N{})); + decltype(GridwiseGemmCTranspose64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{})); using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap; @@ -917,14 +932,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 const auto GemmK = a_grid_desc_m_k.GetLength(I1); const bool HasMainKBlockLoop = - GridwiseGemmCTranspose::CalculateHasMainKBlockLoop(GemmK, k_batch_); + GridwiseGemmCTranspose64::CalculateHasMainKBlockLoop(GemmK, k_batch_); gemm_kernel_args_[gemms_count_ / MaxGroupedGemmGroupsNum][gemms_count_ % MaxGroupedGemmGroupsNum] = GemmArgs{a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, - GridwiseGemmCTranspose:: + GridwiseGemmCTranspose64:: MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_grid_desc_m_n), MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -1069,7 +1084,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // pointers const ADataType* p_a_grid_; const BDataType* p_b_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; + typename GridwiseGemm64::DsGridPointer p_ds_grid_; EDataType* p_e_grid_; // tensor descriptor for problem definition @@ -1122,7 +1137,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 { using Argument = DeviceOp::Argument; - template + template float RunMultiDGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float ave_time = 0; @@ -1281,7 +1298,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 return ave_time; } - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float ave_time = 0; @@ -1372,12 +1390,16 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 if constexpr(IsSplitKSupported) { ave_time += - RunMultiDGemm(arg, stream_config); + RunMultiDGemm(arg, stream_config); } } else { - ave_time += RunMultiDGemm(arg, stream_config); + ave_time += RunMultiDGemm(arg, stream_config); } // Transpose from NHWGC to NGCHW @@ -1427,6 +1449,31 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 return ave_time; } + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return RunImp(arg, stream_config); + } + else + { + return 0; + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return RunImp(arg, stream_config); + } + else + { + return 0; + } + } + } float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -1437,11 +1484,15 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + // gfx11 doesn't support float atomic + if(ck::is_gfx11_supported() && arg.k_batch_ > 1) + { + return false; + } + if(!ck::is_xdl_wmma_supported()) { return false; } - if(!is_bf16_atomic_supported() && std::is_same_v && arg.k_batch_ > 1) { @@ -1598,18 +1649,46 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } // Gridwise GEMM size + bool isWave64 = get_warp_size() == 64; for(std::size_t i = 0; i < arg.a_grid_desc_m_k_container_.size(); 
i++) { - if(!GridwiseGemmCTranspose::CheckValidity( - arg.a_grid_desc_m_k_container_[i], - arg.b_grid_desc_n_k_container_[i], - arg.ds_grid_desc_m_n_container_[i], - arg.e_grid_desc_m_n_container_[i], - arg.gemm_kernel_args_[i / MaxGroupedGemmGroupsNum][i % MaxGroupedGemmGroupsNum] - .block_2_ctile_map_, - arg.k_batch_)) + bool valid = true; + if(isWave64) { - return false; + if constexpr(NXdlPerWave64 > 0) + { + if(!GridwiseGemmCTranspose64::CheckValidity( + arg.a_grid_desc_m_k_container_[i], + arg.b_grid_desc_n_k_container_[i], + arg.ds_grid_desc_m_n_container_[i], + arg.e_grid_desc_m_n_container_[i], + arg.gemm_kernel_args_[i / MaxGroupedGemmGroupsNum] + [i % MaxGroupedGemmGroupsNum] + .block_2_ctile_map_, + arg.k_batch_)) + { + valid = false; + } + } + else + { + if(!GridwiseGemmCTranspose32::CheckValidity( + arg.a_grid_desc_m_k_container_[i], + arg.b_grid_desc_n_k_container_[i], + arg.ds_grid_desc_m_n_container_[i], + arg.e_grid_desc_m_n_container_[i], + arg.gemm_kernel_args_[i / MaxGroupedGemmGroupsNum] + [i % MaxGroupedGemmGroupsNum] + .block_2_ctile_map_, + arg.k_batch_)) + { + valid = false; + } + } + if(!valid) + { + return false; + } } } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index b761939642..934dc7ee8e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -61,31 +61,35 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if defined(__gfx9__) - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); - const long_index_t c_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t c_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); - __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; + __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_c_grid + c_batch_offset, - p_shared, - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_grid_desc_mblock_mperblock_nblock_nperblock, - a_element_op, - b_element_op, - c_element_op, - 
block_2_ctile_map); + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -125,8 +129,8 @@ template { using DeviceOp = DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); using ADataType = OutDataType; using BDataType = InDataType; @@ -298,7 +305,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; - using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + template + using GridwiseGemmBase = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, ADataType, BDataType, @@ -314,11 +322,11 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle MPerBlock, NPerBlock, K0PerBlock, - MPerXdl, - NPerXdl, + MPerXDL, + NPerXDL, K1, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -351,6 +359,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle PipelineVersion::v1, ComputeTypeA, ComputeTypeB>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; static constexpr auto MakeElementwiseInputSequence() { @@ -539,14 +549,15 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); + decltype(GridwiseGemm64::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); using Block2CTileMap = - decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + decltype(GridwiseGemm64::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); struct ActiveWorkgroupsPerCU { - ActiveWorkgroupsPerCU() + template + int GetMaxOccupancy() { constexpr int dynamic_smem_size = 0; int max_occupancy = 0; @@ -568,7 +579,26 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle true>, BlockSize, dynamic_smem_size)); - max_occupancy_ = std::max(1, max_occupancy); + return std::max(1, max_occupancy); + } + + ActiveWorkgroupsPerCU() + { + max_occupancy_ = 1; + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + max_occupancy_ = GetMaxOccupancy(); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + max_occupancy_ = GetMaxOccupancy(); + } + } } int max_occupancy_; }; @@ -605,7 +635,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle a_grid_desc_kbatch_k0_m_k1_{}, b_grid_desc_kbatch_k0_n_k1_{}, ce_grid_desc_m_n_{}, - c_grid_desc_mblock_mperblock_nblock_nperblock_{}, block_2_ctile_map_{}, compute_ptr_offset_of_batch_{}, M01_{M01}, @@ -695,7 +724,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle MakeDsGridDescriptor_M_N(ds_g_k_c_xs_lengths, ds_g_k_c_xs_strides); block_2_ctile_map_ = - GridwiseGemm::MakeCBlockClusterAdaptor(ce_grid_desc_m_n_, M01, N01, k_batch_); + GridwiseGemm64::MakeCBlockClusterAdaptor(ce_grid_desc_m_n_, M01, N01, k_batch_); elementwise_block_2_ctile_map_ = Block2TileMapElementwise{ ce_grid_desc_m_n_.GetLength(I0), ce_grid_desc_m_n_.GetLength(I1)}; @@ -708,16 +737,6 @@ struct 
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle end(filter_spatial_lengths_), index_t{1}, std::multiplies<>{}); - - if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, - b_grid_desc_kbatch_k0_n_k1_, - ce_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock( - ce_grid_desc_m_n_); - } } std::size_t GetWorkspaceSizeBytes() const @@ -733,7 +752,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; CGridDesc_M_N ce_grid_desc_m_n_; - CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; DsGridDesc_M_N ds_grid_descs_tuple_; Block2CTileMap block_2_ctile_map_; @@ -786,7 +804,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle << arg.ce_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, @@ -797,6 +816,9 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); } + auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock( + arg.ce_grid_desc_m_n_); const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); @@ -843,7 +865,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle arg.Conv_G_, arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + c_grid_desc_mblock_mperblock_nblock_nperblock, arg.block_2_ctile_map_, arg.compute_ptr_offset_of_batch_); }; @@ -900,6 +922,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle return avg_time; } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -915,7 +939,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } @@ -977,10 +1001,27 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle } // Gridwise GEMM size - return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.ce_grid_desc_m_n_, - arg.block_2_ctile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.ce_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.ce_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + return false; } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 4565074b3e..e38768b2fa 100644 --- 
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -57,33 +57,36 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, [[maybe_unused]] const index_t num_k_per_block) { -#if defined(__gfx9__) - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); - const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); - const long_index_t a_batch_offset = amd_wave_read_first_lane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = amd_wave_read_first_lane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = amd_wave_read_first_lane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset, - p_shared, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx); + GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -112,38 +115,41 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, [[maybe_unused]] const index_t num_k_per_block) { -#if defined(__gfx9__) - // offset base pointer for each work-group - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); - const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // offset base pointer for each work-group + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); - const long_index_t a_batch_offset = amd_wave_read_first_lane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = amd_wave_read_first_lane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = 
amd_wave_read_first_lane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + // Passing two LDS pointers is the key to telling the compiler that ds_read/write + // can operate on different LDS chunks at the same time without an ordering dependency + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset, - p_shared_0, - p_shared_1, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx); + GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -166,8 +172,8 @@ template ); using DeviceOp = DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); using ADataType = OutDataType; using BDataType = InDataType; @@ -393,7 +402,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle using CElementwiseGridDesc_M_N = remove_cvref_t())>; - using GridwiseGemm = GridwiseGemm_xdl_cshuffle_conv_v3< + template <index_t NXdlPerWave_> + using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_conv_v3< tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, tensor_layout::gemm::RowMajor, @@ -412,10 +422,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle KPerBlock, K1, K1, - MPerXdl, - NPerXdl, + MPerXDL, + NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -440,6 +450,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB>; + using GridwiseGemm64 = GridwiseGemmBase<NXdlPerWave64>; + using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>; using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; @@ -503,12 +515,13 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemm64::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( CGridDesc_M_N{}, 1, 1)); struct ActiveWorkgroupsPerCU { - ActiveWorkgroupsPerCU() + template + int GetMaxOccupancy() { constexpr int dynamic_smem_size = 0; constexpr index_t minimum_occupancy = @@ -549,7 +562,26 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle BlockSize, dynamic_smem_size)); } - max_occupancy_ = 
std::max(1, max_occupancy); + return std::max(1, max_occupancy); + } + + ActiveWorkgroupsPerCU() + { + max_occupancy_ = 1; + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + max_occupancy_ = GetMaxOccupancy(); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + max_occupancy_ = GetMaxOccupancy(); + } + } } int max_occupancy_; }; @@ -704,10 +736,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0]; compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0]; c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + GridwiseGemm64::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ce_grid_desc_m_n_, - GridwiseGemm::CalculateMBlock(GemmM), - GridwiseGemm::CalculateNBlock(GemmN)); + GridwiseGemm64::CalculateMBlock(GemmM), + GridwiseGemm64::CalculateNBlock(GemmN)); if constexpr(is_NGCHW_NGKHW() || is_NGCDHW_NGKDHW()) @@ -853,6 +885,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle << arg.ce_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } + template float RunGemmV3(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); @@ -1506,7 +1539,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle return ave_time; } - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float avg_time = 0.f; auto launch_elementwise_kernel = [&]() { @@ -1626,11 +1660,13 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle grid_size_a); } - avg_time += RunGemmV3(arg, stream_config); + avg_time += RunGemmV3(arg, stream_config); avg_time += launch_elementwise_kernel(); return avg_time; } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -1651,13 +1687,44 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle const index_t GemmK = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); - typename GridwiseGemm::Argument gemm_arg{ - nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; - - const auto num_k_loop = gemm_arg.AK0 / (KPerBlock / K1); - if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + if(get_warp_size() == 64) { - if(num_k_loop <= GridwiseGemm::BlockwiseGemmPipe::PrefetchStages) + if constexpr(NXdlPerWave64 > 0) + { + typename GridwiseGemm64::Argument gemm_arg{ + nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; + + const auto num_k_loop = gemm_arg.AK0 / (KPerBlock / K1); + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + } + else + { + return false; + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + typename GridwiseGemm32::Argument gemm_arg{ + nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; + + const auto num_k_loop = gemm_arg.AK0 / (KPerBlock / K1); + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= GridwiseGemm32::BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + } + else { return false; } @@ -1676,7 +1743,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle } return false; } - if(!ck::is_xdl_supported()) + 
if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 488dadf512..b361409e38 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -59,28 +59,33 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); - const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); - const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); - __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; + __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_c_grid + c_batch_offset, - p_shared, - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_grid_desc_mblock_mperblock_nblock_nperblock, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -119,8 +124,8 @@ template { using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); using ADataType = OutDataType; using BDataType = InDataType; @@ -360,7 +368,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle I1, I0>; - using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + template + using GridwiseGemmBase = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, ADataType, BDataType, @@ -376,11 +385,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle MPerBlock, NPerBlock, K0PerBlock, - MPerXdl, - NPerXdl, + MPerXDL, + NPerXDL, K1, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, 
ABlockTransferSrcAccessOrder, @@ -413,17 +422,20 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle PipelineVersion::v1, ComputeTypeA, ComputeTypeB>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); + decltype(GridwiseGemm64::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); using Block2CTileMap = - decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + decltype(GridwiseGemm64::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); struct ActiveWorkgroupsPerCU { - ActiveWorkgroupsPerCU() + template + static int GetMaxOccupancy() { constexpr int dynamic_smem_size = 0; int max_occupancy = 0; @@ -445,7 +457,25 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle false>, // Both true/false give the same occupancy. BlockSize, dynamic_smem_size)); - max_occupancy_ = std::max(1, max_occupancy); + return std::max(1, max_occupancy); + } + ActiveWorkgroupsPerCU() + { + max_occupancy_ = 1; + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + max_occupancy_ = GetMaxOccupancy(); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + max_occupancy_ = GetMaxOccupancy(); + } + } } int max_occupancy_; }; @@ -477,7 +507,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle a_grid_desc_kbatch_k0_m_k1_{}, b_grid_desc_kbatch_k0_n_k1_{}, c_grid_desc_m_n_{}, - c_grid_desc_mblock_mperblock_nblock_nperblock_{}, block_2_ctile_map_{}, compute_ptr_offset_of_batch_{}, M01_{M01}, @@ -563,22 +592,13 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle c_grid_desc_m_n_ = descs[I2]; block_2_ctile_map_ = - GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); + GridwiseGemm64::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); // A/B/C Batch Stride compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0]; compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0]; compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0]; - if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, - b_grid_desc_kbatch_k0_n_k1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_); - } - if constexpr(is_NGCHW_NGKHW() || is_NGCDHW_NGKDHW()) { @@ -671,7 +691,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; CGridDesc_M_N c_grid_desc_m_n_; - CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; Block2CTileMap block_2_ctile_map_; @@ -731,7 +750,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float avg_time = 0.f; @@ -739,6 +759,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const BDataType* p_b_grid = arg.p_b_grid_; CDataType* p_e_grid = arg.p_c_grid_; + auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm64::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock( + arg.c_grid_desc_m_n_); + if constexpr(is_NGCHW_GKCYX_NGKHW() || is_NGCDHW_GKCZYX_NGKDHW()) { @@ -845,7 +869,7 @@ struct 
DeviceGroupedConvBwdWeight_Xdl_CShuffle arg.Conv_G_, arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + c_grid_desc_mblock_mperblock_nblock_nperblock, arg.block_2_ctile_map_, arg.compute_ptr_offset_of_batch_); }; @@ -893,6 +917,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle return avg_time; } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -908,7 +934,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } @@ -1023,10 +1049,27 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle } // Gridwise GEMM size - return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + return false; } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 0793285dbd..8bf188be2e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -55,32 +55,35 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const index_t num_k_per_block) { -#if defined(__gfx9__) - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); - const long_index_t a_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); - const long_index_t e_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset, - p_shared, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - 
c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx); + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx); + } #else ignore = karg; ignore = a_grid_desc_ak0_m_ak1; @@ -113,38 +116,41 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const index_t num_k_per_block) { -#if defined(__gfx9__) - // offset base pointer for each work-group - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // offset base pointer for each work-group + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); - const long_index_t a_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); - const long_index_t e_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + // Passing two LDS pointers is the key to telling the compiler that ds_read/ds_write + // can operate on different LDS chunks at the same time without an ordering dependency + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset, - p_shared_0, - p_shared_1, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx); + GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx); + } #else ignore = karg; ignore = a_grid_desc_ak0_m_ak1; @@ -173,8 +179,8 @@ template ); using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffleV3; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); using ADataType = OutDataType; using BDataType = InDataType; @@ -330,7 +339,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; - using GridwiseGemm =
GridwiseGemm_xdl_cshuffle_conv_v3< + template + using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_conv_v3< tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, tensor_layout::gemm::RowMajor, @@ -349,10 +359,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 K0PerBlock, K1, K1, - MPerXdl, - NPerXdl, + MPerXDL, + NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -377,15 +387,18 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemm64::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( CGridDesc_M_N{}, 1, 1)); struct ActiveWorkgroupsPerCU { - ActiveWorkgroupsPerCU() + template + static int GetMaxOccupancy() { constexpr int dynamic_smem_size = 0; constexpr index_t minimum_occupancy = @@ -424,7 +437,26 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 BlockSize, dynamic_smem_size)); } - max_occupancy_ = std::max(1, max_occupancy); + return std::max(1, max_occupancy); + } + + ActiveWorkgroupsPerCU() + { + max_occupancy_ = 1; + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + max_occupancy_ = GetMaxOccupancy(); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + max_occupancy_ = GetMaxOccupancy(); + } + } } int max_occupancy_; }; @@ -556,10 +588,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + GridwiseGemm64::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( c_grid_desc_m_n_, - GridwiseGemm::CalculateMBlock(GemmM), - GridwiseGemm::CalculateNBlock(GemmN)); + GridwiseGemm64::CalculateMBlock(GemmM), + GridwiseGemm64::CalculateNBlock(GemmN)); } const ADataType* p_a_grid_; @@ -617,7 +649,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); @@ -1225,6 +1258,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 return ave_time; } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -1245,23 +1280,57 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); - typename GridwiseGemm::Argument gemm_arg{ - nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; - - const auto num_k_loop = gemm_arg.AK0 / (K0PerBlock / K1); - if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + if(get_warp_size() == 64) { - if(num_k_loop <= GridwiseGemm::BlockwiseGemmPipe::PrefetchStages) + if constexpr(NXdlPerWave64 > 0) + { + typename GridwiseGemm64::Argument gemm_arg{ + nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; + + const auto num_k_loop = 
gemm_arg.AK0 / (K0PerBlock / K1); + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + } + else + { + return false; + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + typename GridwiseGemm32::Argument gemm_arg{ + nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; + + const auto num_k_loop = gemm_arg.AK0 / (K0PerBlock / K1); + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + } + else { return false; } } - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) + { + return false; + } + if(is_gfx11_supported() && arg.k_batch_ > 1) { return false; } - if(!is_bf16_atomic_supported() && std::is_same_v && arg.k_batch_ > 1) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 1448914dd3..1412c960c7 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -101,106 +101,111 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const ComputePtrOffsetOfG compute_ptr_offset_of_groups, const ComputePtrOffsetOfN compute_ptr_offset_of_n) { -#if defined(__gfx9__) - - // offset base pointer for each work-group - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); - const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - - const long_index_t e_group_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); - const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); - const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); - - const long_index_t e_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); - - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - DsPointer p_ds_grid_grp; - - static constexpr index_t NumDTensor = - DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); - - static_for<0, NumDTensor, 1>{}( - [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i]; }); - - if constexpr(isMultiA || isMultiB) +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - AsPointer p_as_grid_grp; - BsPointer p_bs_grid_grp; + // offset base pointer for each work-group + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - const auto& as_group_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx); + const long_index_t e_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); + const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); + const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); - // compute_ptr_offset_of_n_ not need BatchStrideB so - // in case of MultiA is false but isMultiB is true - // BatchStrideA_ is not tuple. 
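
The `if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())` guard wrapped around these kernel bodies is what makes one kernel source safe to compile for gfx9, gfx11, and gfx12 at once: for a wave size the tuning parameters cannot support, the guard is false and the body is discarded instead of breaking the build. A minimal sketch of the pattern follows; all names here are illustrative stand-ins, not the definitions from device_base.hpp.

#include <hip/hip_runtime.h>

using index_t = int;

// Toy stand-in for a gridwise GEMM; only the guard matters here, and the
// meaning of NXdlPerWave_ is an assumption for illustration.
template <index_t NXdlPerWave_>
struct ToyGridwiseGemm
{
    template <index_t MinimumOccupancy = 1>
    __host__ __device__ static constexpr bool IsValidCompilationParameter()
    {
        // 0 encodes "this tuning config cannot be built for this wave size".
        return NXdlPerWave_ > 0 && MinimumOccupancy >= 1;
    }
};

template <typename GridwiseGemm>
__global__ void toy_kernel()
{
    // When the guard is false the body is discarded at compile time, so an
    // instantiation for the wrong wave size emits an empty kernel instead of
    // a build error.
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
    {
        // ... real kernel body would go here ...
    }
}

// Both instantiations compile; only the first one can do any work.
template __global__ void toy_kernel<ToyGridwiseGemm<4>>();
template __global__ void toy_kernel<ToyGridwiseGemm<0>>();
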
- if constexpr(isMultiA) + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i]; }); + + if constexpr(isMultiA || isMultiB) { - AsPointer p_as_grid_grp; - BsPointer p_bs_grid_grp; + AsPointer p_as_grid_grp; + BsPointer p_bs_grid_grp; - const auto& as_group_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx); + const auto& as_group_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx); + + // compute_ptr_offset_of_n_ does not need BatchStrideB, so + // when isMultiA is false but isMultiB is true, + // BatchStrideA_ is not a tuple. + if constexpr(isMultiA) { - const auto& as_n_offset = compute_ptr_offset_of_n.GetAsPtrOffset(n_idx); + const auto& as_n_offset = compute_ptr_offset_of_n.GetAsPtrOffset(n_idx); - static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size(); - static_for<0, NumATensor, 1>{}([&](auto i) { - p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + as_n_offset[i]; - }); + static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size(); + static_for<0, NumATensor, 1>{}([&](auto i) { + p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + as_n_offset[i]; + }); + } + else + { + const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx); + static_for<0, 1, 1>{}([&](auto i) { + p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + a_n_offset; + }); + } + + const auto& bs_group_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx); + + static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size(); + static_for<0, NumBTensor, 1>{}( + [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_group_offset[i]; }); + + GridwiseGemm::template Run( + p_as_grid_grp, + p_bs_grid_grp, + p_ds_grid_grp, + p_e_grid + e_group_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); } else { - const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx); - static_for<0, 1, 1>{}( - [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + a_n_offset; }); + const long_index_t b_group_offset = + CTranspose + ? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)) + : amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + const long_index_t a_group_offset = + CTranspose + ? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)) + : amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); + const long_index_t b_n_offset = + CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) + : 0; + const long_index_t a_n_offset = + CTranspose ?
0 + : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + + GridwiseGemm::template Run( + p_as_grid + a_group_offset + a_n_offset, + p_bs_grid + b_group_offset + b_n_offset, + p_ds_grid_grp, + p_e_grid + e_group_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); } - - const auto& bs_group_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx); - - static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size(); - static_for<0, NumBTensor, 1>{}( - [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_group_offset[i]; }); - - GridwiseGemm::template Run( - p_as_grid_grp, - p_bs_grid_grp, - p_ds_grid_grp, - p_e_grid + e_group_offset + e_n_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock_, - block_2_ctile_map); - } - else - { - const long_index_t b_group_offset = - CTranspose - ? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)) - : amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); - const long_index_t a_group_offset = - CTranspose - ? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)) - : amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); - const long_index_t b_n_offset = - CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0; - const long_index_t a_n_offset = - CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); - - GridwiseGemm::template Run( - p_as_grid + a_group_offset + a_n_offset, - p_bs_grid + b_group_offset + b_n_offset, - p_ds_grid_grp, - p_e_grid + e_group_offset + e_n_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock_, - block_2_ctile_map); } #else ignore = p_as_grid; @@ -316,6 +321,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle BComputeDataType> { using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static_assert(NumGroupsToMerge >= 1); @@ -478,7 +486,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ - KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ + KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, \ ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ @@ -495,7 +503,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ NumGemmKPrefetchStage, 
BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \ - NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ @@ -511,7 +519,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle GemmBDataType, GemmADataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, BElementwiseOperation, AElementwiseOperation, CDEElementwiseOperation, \ NumGemmKPrefetchStage, BlockSize, NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, \ - MPerXDL, NXdlPerWave, MXdlPerWave, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + MPerXDL, NXdlPerWave_, MXdlPerWave, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ @@ -524,14 +532,32 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle BComputeDataType, DoElementwiseBeforeCShuffle // Use appropriate gridwise gemm - using GridwiseGemm = std::conditional_t< - isMultiA || isMultiB, - GridwiseGemmMultipleABD_xdl_cshuffle, - GridwiseGemmMultipleD_xdl_cshuffle>; - using GridwiseGemmCTranspose = std::conditional_t< - CTranspose, - GridwiseGemmMultipleD_xdl_cshuffle, - GridwiseGemm>; + template + using GridwiseGemmMultipleABDBase = + GridwiseGemmMultipleABD_xdl_cshuffle; + template + using GridwiseGemmMultipleDBase = + GridwiseGemmMultipleD_xdl_cshuffle; + template + using GridwiseGemmMultipleDCTransposeBase = + GridwiseGemmMultipleD_xdl_cshuffle; + + using GridwiseGemm64 = + std::conditional_t, + GridwiseGemmMultipleDBase>; + using GridwiseGemm32 = std::conditional_t, + GridwiseGemmMultipleDBase>; + + using GridwiseGemmCTranspose64 = + std::conditional_t, + GridwiseGemm64>; + using GridwiseGemmCTranspose32 = + std::conditional_t, + GridwiseGemm32>; // If ADataTypes or BDataTypes is tuple, user has to pass std::array with pointers. using APointers = @@ -541,27 +567,27 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not // in initializer list what is required for single const pointer). 
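
Each macro body above swaps the fixed NXdlPerWave for the NXdlPerWave_ template parameter, so a single definition can be stamped out once per wave size and then selected with `std::conditional_t`, exactly as the `GridwiseGemm64`/`GridwiseGemm32` aliases do. A compilable sketch of that selection, using stand-in types and assumed `NXdlPerWave64`/`NXdlPerWave32` values (the real ones come from the GetNXdlPerWave helper in device_base.hpp):

#include <type_traits>

using index_t = int;

// Stand-ins for the two gridwise GEMM families chosen by isMultiA/isMultiB.
template <index_t NXdlPerWave_> struct MultipleABDImpl {};
template <index_t NXdlPerWave_> struct MultipleDImpl {};

inline constexpr bool isMultiA = false;
inline constexpr bool isMultiB = false;

// Assumed per-wave-size values of the NXdlPerWave tuning parameter.
inline constexpr index_t NXdlPerWave64 = 4;
inline constexpr index_t NXdlPerWave32 = 2;

// One alias per wave size; the host side later picks between them at run time
// via get_warp_size().
using GridwiseGemm64 = std::conditional_t<isMultiA || isMultiB,
                                          MultipleABDImpl<NXdlPerWave64>,
                                          MultipleDImpl<NXdlPerWave64>>;
using GridwiseGemm32 = std::conditional_t<isMultiA || isMultiB,
                                          MultipleABDImpl<NXdlPerWave32>,
                                          MultipleDImpl<NXdlPerWave32>>;

static_assert(std::is_same_v<GridwiseGemm64, MultipleDImpl<4>>,
              "with a single A and B, the MultipleD path is selected");
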
using AGridPointer = remove_cvref_t< - decltype(GetAGridPointer < isMultiA || isMultiB, GridwiseGemm, ADataType > ())>; + decltype(GetAGridPointer < isMultiA || isMultiB, GridwiseGemm64, ADataType > ())>; using BGridPointer = remove_cvref_t< - decltype(GetBGridPointer < isMultiA || isMultiB, GridwiseGemm, BDataType > ())>; + decltype(GetBGridPointer < isMultiA || isMultiB, GridwiseGemm64, BDataType > ())>; // desc for blockwise copy using AGridDesc_AK0_M_AK1 = - remove_cvref_t; using BGridDesc_BK0_N_BK1 = - remove_cvref_t; using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemmCTranspose64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))>; using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemmCTranspose::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemmCTranspose64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( EGridDesc_M_N{}))>; // block-to-e-tile map using Block2ETileMap = - remove_cvref_t; using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; @@ -643,6 +669,62 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // Argument struct Argument : public BaseArgument { + template + void InitGridDesc() + { + // populate desc for Ds/E + if constexpr(isMultiA || isMultiB) + { + const auto as_grid_desc_ak0_m_ak1 = + generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number{}); + const auto bs_grid_desc_bk0_n_bk1 = + generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number{}); + + if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + } + else + { + bool valid = false; + if constexpr(CTranspose) + { + valid = GridwiseGemmCTranspose::CheckValidity(b_grid_desc_n_k_, + a_grid_desc_m_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_); + } + else + { + valid = GridwiseGemmCTranspose::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_); + } + if(valid) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose:: + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose:: + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n_); + } + } + }; + Argument(APointers p_as, BPointers p_bs, const std::array& p_ds, @@ -709,13 +791,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle e_grid_desc_m_n_{ DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, a_grid_desc_ak0_m_ak1_{ - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, b_grid_desc_bk0_n_bk1_{ - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{}, block_2_etile_map_{ - 
GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + GridwiseGemmCTranspose64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, compute_ptr_offset_of_groups_{}, compute_ptr_offset_of_n_{}, a_element_op_{a_element_op}, @@ -822,58 +904,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle e_g_n_k_wos_strides_[0] * NumGroupsToMerge; compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_; - // populate desc for Ds/E - if constexpr(isMultiA || isMultiB) + if(get_warp_size() == 64) { - const auto as_grid_desc_ak0_m_ak1 = - generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number{}); - const auto bs_grid_desc_bk0_n_bk1 = - generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number{}); - - if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1, - bs_grid_desc_bk0_n_bk1, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_)) + if constexpr(NXdlPerWave64 > 0) { - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); - - ds_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n_); + InitGridDesc(); } } else { - bool valid = false; - if constexpr(CTranspose) + if constexpr(NXdlPerWave32 > 0) { - valid = GridwiseGemmCTranspose::CheckValidity(b_grid_desc_n_k_, - a_grid_desc_m_k_, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_); - } - else - { - valid = GridwiseGemmCTranspose::CheckValidity(a_grid_desc_m_k_, - b_grid_desc_n_k_, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_); - } - if(valid) - { - e_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose:: - MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); - - ds_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose:: - MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n_); + InitGridDesc(); } } - if constexpr(NeedTransposeKernel) { // Use not modified base strides @@ -970,7 +1014,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // pointers (tuple if multi AB, pointer if no) AGridPointer p_as_grid_; BGridPointer p_bs_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; + typename GridwiseGemm64::DsGridPointer p_ds_grid_; EDataType* p_e_grid_; // for checking IsSupportedArgument() @@ -1032,6 +1076,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { using Argument = DeviceOp::Argument; + template float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) @@ -1232,7 +1277,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } } - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float avg_time = 0.f; @@ -1285,7 +1331,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle a_grid_size); } - avg_time += RunGemm(arg, stream_config); + avg_time += RunGemm(arg, stream_config); if constexpr(NeedTransposeKernel) { @@ -1324,6 +1370,31 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle return avg_time; } + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return RunImp(arg, stream_config); + } + else + { + return 0; + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return RunImp(arg, 
stream_config); + } + else + { + return 0; + } + } + } float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -1350,7 +1421,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle return false; } } - if(!ck::is_xdl_supported()) + + if(!ck::is_xdl_wmma_supported()) { return false; } @@ -1616,38 +1688,85 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } // check Gridwise GEMM - if constexpr(isMultiA || isMultiB) + if(get_warp_size() == 64) { - // Genarate tuples with the same descriptors - const auto as_grid_desc_ak0_m_ak1 = - generate_tuple([&](auto) { return arg.a_grid_desc_m_k_; }, Number{}); - const auto bs_grid_desc_bk0_n_bk1 = - generate_tuple([&](auto) { return arg.b_grid_desc_n_k_; }, Number{}); - return GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1, - bs_grid_desc_bk0_n_bk1, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); + if constexpr(NXdlPerWave64 > 0) + { + if constexpr(isMultiA || isMultiB) + { + // Generate tuples with the same descriptors + const auto as_grid_desc_ak0_m_ak1 = generate_tuple( + [&](auto) { return arg.a_grid_desc_m_k_; }, Number{}); + const auto bs_grid_desc_bk0_n_bk1 = generate_tuple( + [&](auto) { return arg.b_grid_desc_n_k_; }, Number{}); + return GridwiseGemm64::CheckValidity(as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + else + { + if constexpr(CTranspose) + { + return GridwiseGemmCTranspose64::CheckValidity(arg.b_grid_desc_n_k_, + arg.a_grid_desc_m_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + else + { + return GridwiseGemmCTranspose64::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + } } else { - if constexpr(CTranspose) + + if constexpr(NXdlPerWave32 > 0) { - return GridwiseGemmCTranspose::CheckValidity(arg.b_grid_desc_n_k_, - arg.a_grid_desc_m_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); - } - else - { - return GridwiseGemmCTranspose::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); + if constexpr(isMultiA || isMultiB) + { + // Generate tuples with the same descriptors + const auto as_grid_desc_ak0_m_ak1 = generate_tuple( + [&](auto) { return arg.a_grid_desc_m_k_; }, Number{}); + const auto bs_grid_desc_bk0_n_bk1 = generate_tuple( + [&](auto) { return arg.b_grid_desc_n_k_; }, Number{}); + return GridwiseGemm32::CheckValidity(as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + else + { + if constexpr(CTranspose) + { + return GridwiseGemmCTranspose32::CheckValidity(arg.b_grid_desc_n_k_, + arg.a_grid_desc_m_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + else + { + return GridwiseGemmCTranspose32::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } } } + + return false; } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index bb31d64a93..dd2e429a01 100644 ---
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -82,54 +82,57 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const ComputePtrOffset compute_ptr_offset_of_groups, const ComputePtrOffset compute_ptr_offset_of_n) { -#if defined(__gfx9__) - // offset base pointer for each work-group - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); - const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // offset base pointer for each work-group + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); - const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); + const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); + const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); - static constexpr index_t NumDTensor = GridwiseGemm::NumDTensor; - using DsGridPointer = typename GridwiseGemm::DsGridPointer; - DsGridPointer p_ds_grid_grp{}; + static constexpr index_t NumDTensor = GridwiseGemm::NumDTensor; + using DsGridPointer = typename GridwiseGemm::DsGridPointer; + DsGridPointer p_ds_grid_grp{}; - static_for<0, NumDTensor, 1>{}([&](auto i) { - p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i]; - }); + static_for<0, NumDTensor, 1>{}([&](auto i) { + p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i]; + }); - const long_index_t a_group_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); - const long_index_t b_group_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); - const long_index_t e_group_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); + const long_index_t a_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); + const long_index_t b_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + const long_index_t e_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); - const long_index_t a_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); - const long_index_t e_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + const long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - using Block2CTileMap = typename GridwiseGemm::Block2CTileMapDefault; - const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4}; + using Block2CTileMap = typename GridwiseGemm::Block2CTileMapDefault; + const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4}; - GridwiseGemm::template Run( - karg.p_a_grid + a_group_offset + a_n_offset, - karg.p_b_grid + b_group_offset, - p_ds_grid_grp, - karg.p_c_grid + e_group_offset + e_n_offset, 
- p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op, - block_2_ctile_map, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_m_n, - c_grid_desc_m_n); + GridwiseGemm::template Run( + karg.p_a_grid + a_group_offset + a_n_offset, + karg.p_b_grid + b_group_offset, + p_ds_grid_grp, + karg.p_c_grid + e_group_offset + e_n_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op, + block_2_ctile_map, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n, + c_grid_desc_m_n); + } #else ignore = karg; ignore = a_grid_desc_ak0_m_ak1; @@ -164,58 +167,61 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const ComputePtrOffset compute_ptr_offset_of_groups, const ComputePtrOffset compute_ptr_offset_of_n) { -#if defined(__gfx9__) - // offset base pointer for each work-group - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); - const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // offset base pointer for each work-group + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); - const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); + const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); + const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); - static constexpr index_t NumDTensor = GridwiseGemm::NumDTensor; - using DsGridPointer = typename GridwiseGemm::DsGridPointer; - DsGridPointer p_ds_grid_grp{}; + static constexpr index_t NumDTensor = GridwiseGemm::NumDTensor; + using DsGridPointer = typename GridwiseGemm::DsGridPointer; + DsGridPointer p_ds_grid_grp{}; - static_for<0, NumDTensor, 1>{}([&](auto i) { - p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i]; - }); + static_for<0, NumDTensor, 1>{}([&](auto i) { + p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i]; + }); - const long_index_t a_group_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); - const long_index_t b_group_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); - const long_index_t e_group_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); + const long_index_t a_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); + const long_index_t b_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + const long_index_t e_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); - const long_index_t a_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); - const long_index_t e_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + const long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char 
p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + // Passing two LDS pointers is the key to telling the compiler that ds_read/ds_write + // can operate on different LDS chunks at the same time without an ordering dependency + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - using Block2CTileMap = typename GridwiseGemm::Block2CTileMapDefault; - const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4}; + using Block2CTileMap = typename GridwiseGemm::Block2CTileMapDefault; + const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4}; - GridwiseGemm::template Run_2Lds( - karg.p_a_grid + a_group_offset + a_n_offset, - karg.p_b_grid + b_group_offset, - p_ds_grid_grp, - karg.p_c_grid + e_group_offset + e_n_offset, - p_shared_0, - p_shared_1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op, - block_2_ctile_map, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_m_n, - c_grid_desc_m_n); + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + a_group_offset + a_n_offset, + karg.p_b_grid + b_group_offset, + p_ds_grid_grp, + karg.p_c_grid + e_group_offset + e_n_offset, + p_shared_0, + p_shared_1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op, + block_2_ctile_map, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n, + c_grid_desc_m_n); + } #else ignore = karg; ignore = a_grid_desc_ak0_m_ak1; @@ -318,6 +324,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 BComputeDataType> { using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr bool isMultiA = is_detected::value; static constexpr bool isMultiB = is_detected::value; @@ -469,7 +478,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 remove_cvref_t; // Use appropriate gridwise gemm - using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3< + template + using GridwiseGemmBase = GridwiseGemmMultiD_xdl_cshuffle_v3< tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, DsLayout, @@ -493,7 +503,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -521,6 +531,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ADataType, BDataType, DoElementwiseBeforeCShuffle>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // #undef GridwiseGemmV3TemplateParams @@ -860,6 +872,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { using Argument = DeviceOp::Argument; + template float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) @@ -1232,7 +1245,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 return ave_time; } - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float avg_time = 0.f; if constexpr(!isMultiABD) @@ -1288,7 +1302,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 a_grid_size); } - avg_time += RunGemm(arg, stream_config); + avg_time += RunGemm(arg, stream_config); //
Transpose result back to NGCHW if constexpr(is_NGCHW_GKCYX_NGKHW() || @@ -1330,6 +1344,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 return avg_time; } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -1373,7 +1389,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 } } - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { @@ -1383,7 +1399,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 } return false; } - // check ConvolutionForwardSpecialization if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) @@ -1628,23 +1643,52 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); - typename GridwiseGemm::Argument gemm_arg{nullptr, - nullptr, - {}, - nullptr, - GemmM, - GemmN, - GemmK, - I0, - I0, - {}, - I0, - I1 /*KBatch*/, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_}; + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + typename GridwiseGemm64::Argument gemm_arg{nullptr, + nullptr, + {}, + nullptr, + GemmM, + GemmN, + GemmK, + I0, + I0, + {}, + I0, + I1 /*KBatch*/, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_}; + return GridwiseGemm64::CheckValidity(gemm_arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + typename GridwiseGemm32::Argument gemm_arg{nullptr, + nullptr, + {}, + nullptr, + GemmM, + GemmN, + GemmK, + I0, + I0, + {}, + I0, + I1 /*KBatch*/, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_}; + return GridwiseGemm32::CheckValidity(gemm_arg); + } + } - return GridwiseGemm::CheckValidity(gemm_arg); + return false; } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp index d7859dbc46..adcda93720 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
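
Every device op touched by this patch repeats the same host-side dispatch: query the wave size at run time, then enter whichever branch has a usable NXdlPerWave constant and instantiate the matching GridwiseGemm. The following is a simplified model of what the INVOKER_RUN_IMPL macro is assumed to expand to; the real definition lives in device_base.hpp, and every type here is a stand-in:

using index_t = int;

struct Argument {};     // stand-in for the device op's Argument
struct StreamConfig {}; // stand-in for ck::StreamConfig

// Assumed per-wave-size constants; 0 means "no kernel for that wave size".
inline constexpr index_t NXdlPerWave64 = 4;
inline constexpr index_t NXdlPerWave32 = 2;

struct GridwiseGemm64 {};
struct GridwiseGemm32 {};

// Stand-in for ck::get_warp_size(), which queries the current GPU target.
inline int get_warp_size() { return 64; }

struct Invoker
{
    template <typename GridwiseGemm>
    float RunImp(const Argument&, const StreamConfig&)
    {
        return 0.f; // the real RunImp launches the GridwiseGemm kernel
    }

    float Run(const Argument& arg, const StreamConfig& cfg)
    {
        if(get_warp_size() == 64)
        {
            if constexpr(NXdlPerWave64 > 0)
                return RunImp<GridwiseGemm64>(arg, cfg);
        }
        else
        {
            if constexpr(NXdlPerWave32 > 0)
                return RunImp<GridwiseGemm32>(arg, cfg);
        }
        return 0.f; // nothing to run for this wave size
    }
};

int main()
{
    Invoker invoker;
    return static_cast<int>(invoker.Run(Argument{}, StreamConfig{}));
}
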
#pragma once @@ -155,55 +155,60 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const Block2ETileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if defined(__gfx9__) - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = amd_wave_read_first_lane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = amd_wave_read_first_lane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = amd_wave_read_first_lane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); - const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); - const auto rs_batch_offset = compute_ptr_offset_of_batch.GetRsPtrOffset(g_idx); + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + const auto rs_batch_offset = compute_ptr_offset_of_batch.GetRsPtrOffset(g_idx); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - DsPointer p_ds_grid_grp; + DsPointer p_ds_grid_grp; - static constexpr index_t NumDTensor = - DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); - static_for<0, NumDTensor, 1>{}( - [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - RsPointer p_rs_grid_grp; + RsPointer p_rs_grid_grp; - static constexpr index_t NumRTensor = RsGridDescriptor_MBlock_MPerBlock::Size(); + static constexpr index_t NumRTensor = RsGridDescriptor_MBlock_MPerBlock::Size(); - static_for<0, NumRTensor, 1>{}( - [&](auto i) { p_rs_grid_grp(i) = p_rs_grid[i] + rs_batch_offset[i]; }); + static_for<0, NumRTensor, 1>{}( + [&](auto i) { p_rs_grid_grp(i) = p_rs_grid[i] + rs_batch_offset[i]; }); - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_ds_grid_grp, - p_e_grid + e_batch_offset, - p_rs_grid_grp, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - qs_element_op, - rs_element_op, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock_, - rs_grid_desc_mblock_mperblock, - block_2_ctile_map); + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_rs_grid_grp, + p_shared, + 
a_element_op, + b_element_op, + cde_element_op, + qs_element_op, + rs_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + rs_grid_desc_mblock_mperblock, + block_2_ctile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -299,6 +304,9 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle QsElementwiseOperation> { using DeviceOp = DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumRTensor = RsDataType::Size(); @@ -429,7 +437,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle using RGridDesc_M = remove_cvref_t({}, {}))>; // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + template + using GridwiseGemmBase = GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, // TODO: distinguish A/B datatype AccDataType, CShuffleDataType, @@ -459,7 +468,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -482,15 +491,17 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle CDEBlockTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using AGridDesc_AK0_M_AK1 = - remove_cvref_t; using BGridDesc_BK0_N_BK1 = - remove_cvref_t; - using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; + using Block2ETileMap = typename GridwiseGemm64::DefaultBlock2ETileMap; // Argument struct Argument : public BaseArgument @@ -546,13 +557,10 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle r_grid_desc_m_{ DeviceOp::MakeRGridDescriptor_M(r_g_n_wos_lengths, r_g_n_wos_strides)}, a_grid_desc_ak0_m_ak1_{ - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, b_grid_desc_bk0_n_bk1_{ - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, - ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, - e_grid_desc_mblock_mperblock_nblock_nperblock_{}, - rs_grid_desc_mblock_mperblock_{}, - block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, compute_ptr_offset_of_batch_{}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, @@ -577,58 +585,39 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; - // populate desc for Ds/E - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, - b_grid_desc_n_k_, - e_grid_desc_m_n_, - r_grid_desc_m_, - block_2_etile_map_)) - { - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; - // 
populate pointer, batch stride, desc for Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t>; + // D pointer + p_ds_grid_(i) = static_cast(p_ds[i]); - // D pointer - p_ds_grid_(i) = static_cast(p_ds[i]); + // D batch stride + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; - // D batch stride - compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; + ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths[i], + ds_g_n_k_wos_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; - ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - ds_g_n_k_wos_lengths[i], - ds_g_n_k_wos_strides[i], - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads}; + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_d); + }); - // D desc - ds_grid_desc_m_n_(i) = - DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_d); + // populate pointer for Rs + static_for<0, NumRTensor, 1>{}([&](auto i) { + using RDataType = remove_cvref_t>; - ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n_(i)); - }); - - // populate pointer for Rs - static_for<0, NumRTensor, 1>{}([&](auto i) { - using RDataType = remove_cvref_t>; - - // R pointer - p_rs_grid_(i) = static_cast(p_rs[i]); - - rs_grid_desc_mblock_mperblock_(i) = - GridwiseGemm::MakeRGridDescriptor_MBlock_MPerBlock(r_grid_desc_m_); - }); - } + // R pointer + p_rs_grid_(i) = static_cast(p_rs[i]); + }); } void Print() const @@ -644,9 +633,9 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle // pointers const ADataType* p_a_grid_; const BDataType* p_b_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; + typename GridwiseGemm64::DsGridPointer p_ds_grid_; EDataType* p_e_grid_; - typename GridwiseGemm::RsGridPointer p_rs_grid_; + typename GridwiseGemm64::RsGridPointer p_rs_grid_; ConvToGemmFwdTransformer conv_to_gemm_transformer_; @@ -660,16 +649,6 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle // tensor descriptors for block/thread-wise copy AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; - StaticallyIndexedArray< - typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, - NumDTensor> - ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different - // type from E - typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock_; - - StaticallyIndexedArray - rs_grid_desc_mblock_mperblock_; // block-to-e-tile map Block2ETileMap block_2_etile_map_; @@ -703,7 +682,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, arg.b_grid_desc_n_k_, @@ -715,6 +695,32 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle "wrong! 
GridwiseGemmMultipleD_xdl_cshuffle has invalid setting"); } + StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor> + ds_grid_desc_mblock_mperblock_nblock_nperblock = {}; + + StaticallyIndexedArray + rs_grid_desc_mblock_mperblock = {}; + + auto e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.e_grid_desc_m_n_); + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(i) = + GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.ds_grid_desc_m_n_(i)); + }); + + // populate pointer for Rs + static_for<0, NumRTensor, 1>{}([&](auto i) { + rs_grid_desc_mblock_mperblock(i) = + GridwiseGemm64::MakeRGridDescriptor_MBlock_MPerBlock(arg.r_grid_desc_m_); + }); + const index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.a_g_n_c_wis_lengths_[0]; // Group count @@ -767,9 +773,9 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle arg.a_g_n_c_wis_lengths_[0], // Group count arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.rs_grid_desc_mblock_mperblock_, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + rs_grid_desc_mblock_mperblock, arg.block_2_etile_map_, arg.compute_ptr_offset_of_batch_); }; @@ -784,6 +790,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle } } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -794,7 +802,10 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { namespace ctc = tensor_layout::convolution; - + if(!is_xdl_wmma_supported()) + { + return false; + } // check device if(get_device_name() == "gfx908") { @@ -812,9 +823,16 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle return false; } } + else if(ck::is_gfx12_supported() || ck::is_gfx11_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } else { - return false; + // return false; } // check ConvolutionForwardSpecialization @@ -952,11 +970,29 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle } // check Gridwise GEMM - return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.e_grid_desc_m_n_, - arg.r_grid_desc_m_, - arg.block_2_etile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.e_grid_desc_m_n_, + arg.r_grid_desc_m_, + arg.block_2_etile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.e_grid_desc_m_n_, + arg.r_grid_desc_m_, + arg.block_2_etile_map_); + } + } + return false; } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index 8f3feee1c1..25afe46690 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp 
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -52,67 +52,70 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const ComputePtrOffset compute_ptr_offset_of_groups, const ComputePtrOffset compute_ptr_offset_of_n) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - const index_t block_id_x = __builtin_amdgcn_readfirstlane(blockIdx.x); - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); - const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - - const long_index_t a_group_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); - const long_index_t b_group_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); - const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); - const long_index_t e_group_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); - - const long_index_t a_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); - const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); - const long_index_t e_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); - - index_t left = 0; - index_t right = gemms_count; - index_t group_id = index_t((left + right) / 2); - while((!(block_id_x >= gemm_desc_kernel_args[group_id].BlockStart_ && - block_id_x < gemm_desc_kernel_args[group_id].BlockEnd_)) && - left <= right) +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - if(block_id_x < gemm_desc_kernel_args[group_id].BlockStart_) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id_x = __builtin_amdgcn_readfirstlane(blockIdx.x); + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); + + const long_index_t a_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); + const long_index_t b_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); + const long_index_t e_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); + + const long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + index_t left = 0; + index_t right = gemms_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_id_x >= gemm_desc_kernel_args[group_id].BlockStart_ && + block_id_x < gemm_desc_kernel_args[group_id].BlockEnd_)) && + left <= right) { - right = group_id; + if(block_id_x < gemm_desc_kernel_args[group_id].BlockStart_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); } - else - { - left = group_id; - } - group_id = index_t((left + right) / 2); + + using DsPointer = decltype(gemm_desc_kernel_args[Number<0>{}].ds_ptr_); + DsPointer p_ds_grid_grp; + static constexpr index_t NumDTensor = DsPointer::Size(); + static_for<0, NumDTensor, 1>{}([&](auto i) { + 
p_ds_grid_grp(i) = + gemm_desc_kernel_args[group_id].ds_ptr_[i] + ds_group_offset[i] + ds_n_offset[i]; + }); + + GridwiseGemm::template Run( + gemm_desc_kernel_args[group_id].a_ptr_ + a_group_offset + a_n_offset, + gemm_desc_kernel_args[group_id].b_ptr_ + b_group_offset, + p_ds_grid_grp, + gemm_desc_kernel_args[group_id].e_ptr_ + e_group_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + c_element_op, + gemm_desc_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_desc_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_desc_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_desc_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_desc_kernel_args[group_id].block_2_etile_map_); } - - using DsPointer = decltype(gemm_desc_kernel_args[Number<0>{}].ds_ptr_); - DsPointer p_ds_grid_grp; - static constexpr index_t NumDTensor = DsPointer::Size(); - static_for<0, NumDTensor, 1>{}([&](auto i) { - p_ds_grid_grp(i) = - gemm_desc_kernel_args[group_id].ds_ptr_[i] + ds_group_offset[i] + ds_n_offset[i]; - }); - - GridwiseGemm::template Run( - gemm_desc_kernel_args[group_id].a_ptr_ + a_group_offset + a_n_offset, - gemm_desc_kernel_args[group_id].b_ptr_ + b_group_offset, - p_ds_grid_grp, - gemm_desc_kernel_args[group_id].e_ptr_ + e_group_offset + e_n_offset, - p_shared, - a_element_op, - b_element_op, - c_element_op, - gemm_desc_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, - gemm_desc_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, - gemm_desc_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, - gemm_desc_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, - gemm_desc_kernel_args[group_id].block_2_etile_map_); #else ignore = gemm_desc_kernel_args; ignore = gemms_count; @@ -199,6 +202,9 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor BComputeDataType> { using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t MaxGemmsNum = 32; @@ -412,25 +418,28 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ AComputeDataType, DoElementwiseBeforeCShuffle // Use appropriate gridwise gemm - using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle; + template + using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // desc for blockwise copy using AGridDesc_AK0_M_AK1 = - remove_cvref_t; using BGridDesc_BK0_N_BK1 = - remove_cvref_t; using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))>; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; // block-to-e-tile map using Block2ETileMap = - remove_cvref_t; + remove_cvref_t; // Structure for each gemm(conv) struct GemmArgs { @@ -455,6 +464,43 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor // Argument struct Argument : public BaseArgument { + template + void init_gemm_args(const ADataType* a_ptr, + const BDataType* b_ptr, + DsPointer ds_ptr, + EDataType* e_ptr, + const AGridDesc_M_K& a_grid_desc_m_k, + const 
BGridDesc_N_K& b_grid_desc_n_k, + const DsGridDesc_M_N_& ds_grid_desc_m_n, + const EGridDescriptor_M_N_& e_grid_desc_m_n, + const Block2ETileMap& block_2_etile_map, + index_t BlockStart, + index_t BlockEnd) + { + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map)) + { + gemm_desc_kernel_args_(valid_gemms_count_) = + GemmArgs{a_ptr, + b_ptr, + ds_ptr, + e_ptr, + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k), + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k), + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n), + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n), + block_2_etile_map, + BlockStart, + BlockEnd}; + + valid_gemms_count_++; + } + } Argument(const void* p_a, const void* p_b, const std::array& p_ds, @@ -543,7 +589,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor generate_tuple([&](auto) { return e_grid_desc_m_n; }, Number{}); const auto block_2_etile_map = - GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); + GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); const index_t grid_size_grp = block_2_etile_map.CalculateGridSize(e_grid_desc_m_n); @@ -553,28 +599,39 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor grid_size_ += grid_size_grp; - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, - b_grid_desc_n_k, - ds_grid_desc_m_n, - e_grid_desc_m_n, - block_2_etile_map)) + if(get_warp_size() == 64) { - gemm_desc_kernel_args_(valid_gemms_count_) = GemmArgs{ - a_grid_ptrs[i], - static_cast(p_b), - ds_grid_ptrs[i], - c_grid_ptrs[i], - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k), - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k), - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n), - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n), - block_2_etile_map, - BlockStart, - BlockEnd}; - - valid_gemms_count_++; + if constexpr(NXdlPerWave64 > 0) + { + init_gemm_args(a_grid_ptrs[i], + static_cast(p_b), + ds_grid_ptrs[i], + c_grid_ptrs[i], + a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map, + BlockStart, + BlockEnd); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + init_gemm_args(a_grid_ptrs[i], + static_cast(p_b), + ds_grid_ptrs[i], + c_grid_ptrs[i], + a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map, + BlockStart, + BlockEnd); + } } } // N is the same for all convs @@ -649,7 +706,10 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor // Invoker struct Invoker : public BaseInvoker { - float Run(const DeviceOp::Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + + using Argument = DeviceOp::Argument; + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -703,6 +763,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor } } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -754,11 +816,10 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor return false; } } - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - // check ConvolutionForwardSpecialization if constexpr(ConvForwardSpecialization == 
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp index 764daf1750..f6ec0908eb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -45,94 +45,97 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - const index_t KBatch = 1; - - const index_t block_id = get_block_1d_id(); - - const auto gemm_desc_ptr = - reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); - - const index_t group_id = block_id / grid_size_grp; - - if(group_id >= group_count) - return; - - const index_t M = gemm_desc_ptr[group_id].M; - const index_t N = gemm_desc_ptr[group_id].N; - const index_t K = gemm_desc_ptr[group_id].K; - - if(M == 0 || N == 0 || K == 0) - return; - - const auto StrideAs = gemm_desc_ptr[group_id].StrideAs; - const auto StrideBs = gemm_desc_ptr[group_id].StrideBs; - const auto StrideDs = gemm_desc_ptr[group_id].StrideDs; - const auto StrideE = gemm_desc_ptr[group_id].StrideE; - - const auto e_grid_desc_m_n = - GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); - - const index_t BlockStart = group_id * grid_size_grp; - - const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch}; - - const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n); - - constexpr auto NumATensor = GridwiseGemm::AsGridPointer::Size(); - constexpr auto NumBTensor = GridwiseGemm::BsGridPointer::Size(); - constexpr auto NumDTensor = GridwiseGemm::DsGridPointer::Size(); - - typename GridwiseGemm::AsGridPointer p_as_grid_; - typename GridwiseGemm::BsGridPointer p_bs_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; - - static_for<0, NumATensor, 1>{}([&](auto i) { - using ADataType = remove_cvref_t; - p_as_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_as_grid[i]); - }); - - static_for<0, NumBTensor, 1>{}([&](auto i) { - using BDataType = remove_cvref_t; - p_bs_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_bs_grid[i]); - }); - - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t; - p_ds_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); - }); - - index_t id_off = 0; - index_t id_local = get_block_1d_id() - BlockStart; - - while(id_local < local_grid_size) +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - const auto block_2_etile_map = - GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off); + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm:: - template Run( - p_as_grid_, - p_bs_grid_, - p_ds_grid_, - gemm_desc_ptr[group_id].p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - M, - N, - K, - StrideAs, - StrideBs, - StrideDs, - StrideE, - block_2_etile_map); + const index_t KBatch = 1; - id_off 
+= grid_size_grp; - id_local += grid_size_grp; + const index_t block_id = get_block_1d_id(); + + const auto gemm_desc_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(gemm_descs_const)); + + const index_t group_id = block_id / grid_size_grp; + + if(group_id >= group_count) + return; + + const index_t M = gemm_desc_ptr[group_id].M; + const index_t N = gemm_desc_ptr[group_id].N; + const index_t K = gemm_desc_ptr[group_id].K; + + if(M == 0 || N == 0 || K == 0) + return; + + const auto StrideAs = gemm_desc_ptr[group_id].StrideAs; + const auto StrideBs = gemm_desc_ptr[group_id].StrideBs; + const auto StrideDs = gemm_desc_ptr[group_id].StrideDs; + const auto StrideE = gemm_desc_ptr[group_id].StrideE; + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); + + const index_t BlockStart = group_id * grid_size_grp; + + const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch}; + + const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n); + + constexpr auto NumATensor = GridwiseGemm::AsGridPointer::Size(); + constexpr auto NumBTensor = GridwiseGemm::BsGridPointer::Size(); + constexpr auto NumDTensor = GridwiseGemm::DsGridPointer::Size(); + + typename GridwiseGemm::AsGridPointer p_as_grid_; + typename GridwiseGemm::BsGridPointer p_bs_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t; + p_as_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_as_grid[i]); + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t; + p_bs_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_bs_grid[i]); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t; + p_ds_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); + }); + + index_t id_off = 0; + index_t id_local = get_block_1d_id() - BlockStart; + + while(id_local < local_grid_size) + { + const auto block_2_etile_map = + GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off); + + GridwiseGemm:: + template Run( + p_as_grid_, + p_bs_grid_, + p_ds_grid_, + gemm_desc_ptr[group_id].p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + M, + N, + K, + StrideAs, + StrideBs, + StrideDs, + StrideE, + block_2_etile_map); + + id_off += grid_size_grp; + id_local += grid_size_grp; + } } #else ignore = gemm_descs_const; @@ -203,6 +206,9 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK CDEElementwiseOperation> { using DeviceOp = DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumATensor = AsDataType::Size(); static constexpr index_t NumBTensor = BsDataType::Size(); @@ -215,7 +221,8 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK static constexpr index_t NumGemmKPrefetchStage = 1; // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleABD_xdl_cshuffle< + template + using GridwiseGemmBase = GridwiseGemmMultipleABD_xdl_cshuffle< AsDataType, BsDataType, ComputeType, @@ -237,7 +244,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -259,7 +266,8 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK 
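Unlike the kernels that binary-search per-group block ranges, this fixed-NK kernel gives every group the same number of blocks (grid_size_grp), so the owning group falls out of an integer division and each block then strides through its group's tiles. A condensed sketch of that scheduling, assuming a hypothetical process_tile callback in place of the GEMM body:

#include <hip/hip_runtime.h>
#include <cstdint>
using index_t = int32_t;

// Each group owns grid_size_grp consecutive block ids; a block walks its
// group's tiles with a stride of grid_size_grp, shifting the block-to-tile
// map by id_off on each pass.
template <typename F>
__device__ void run_fixed_nk_group(index_t block_id,
                                   index_t group_count,
                                   index_t grid_size_grp,
                                   index_t local_grid_size, // tiles in this group
                                   F process_tile)
{
    const index_t group_id = block_id / grid_size_grp;
    if(group_id >= group_count)
        return;

    const index_t block_start = group_id * grid_size_grp;
    index_t id_off            = 0;
    index_t id_local          = block_id - block_start;

    while(id_local < local_grid_size) // a group may have more tiles than blocks
    {
        process_tile(group_id, id_off);
        id_off += grid_size_grp;
        id_local += grid_size_grp;
    }
}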
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched>; - + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; template struct OffsettedBlockToCTileMapMLoops { @@ -508,7 +516,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK [&](auto j) { StrideDs[j] = gemm_descs[i].stride_Ds_[j]; }); const auto e_grid_desc_m_n = - GridwiseGemm::template MakeEGridDescriptor_M_N( + GridwiseGemm64::template MakeEGridDescriptor_M_N( AverM, N, StrideE); // block-to-e-tile map @@ -547,7 +555,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK } const auto e_grid_desc_sum_m_n = - GridwiseGemm::template MakeEGridDescriptor_M_N( + GridwiseGemm64::template MakeEGridDescriptor_M_N( sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_); const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1}; @@ -581,7 +589,8 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { bool has_main_k_block_loop = true; @@ -667,6 +676,8 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK return ave_time; } + INVOKER_RUN_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp index 7b5dd55a8f..4188bf537d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp @@ -91,6 +91,9 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage CDEElementwiseOperation> { using DeviceOp = DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); @@ -105,7 +108,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage using WorkspaceDataType = float; // First stage GridwiseGEMM kernel. 
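The per-wave instantiation pattern above parameterizes the gridwise GEMM on the N-direction tile count and aliases one variant per wave width. Spelled out with the template arguments written in full, it presumably has this shape (SomeGridwiseGemm and the constant values are illustrative stand-ins, not the patch's actual types or numbers):

#include <cstdint>
using index_t = int32_t;

// Stand-in for a gridwise GEMM template; only the parameter that differs
// between the two wave widths is shown.
template <index_t MXdlPerWave_, index_t NXdlPerWave_>
struct SomeGridwiseGemm
{
};

// Hypothetical per-wave tile counts as produced by GET_NXDL_PER_WAVE_IMPL.
static constexpr index_t MXdlPerWave   = 2;
static constexpr index_t NXdlPerWave64 = 2; // legal mapping for wave64
static constexpr index_t NXdlPerWave32 = 4; // legal mapping for wave32

template <index_t NXdlPerWave_>
using GridwiseGemmBase = SomeGridwiseGemm<MXdlPerWave, NXdlPerWave_>;

using GridwiseGemm64 = GridwiseGemmBase<NXdlPerWave64>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;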
- using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< + template + using GridwiseGemmBase = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, @@ -126,7 +130,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage NPerXDL, AK1, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -150,7 +154,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage LoopSched, PipelineVer, ComputeDataType>; - + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; template static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE) { @@ -220,8 +225,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage Number{}); } - using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; - using EGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; + using CGridDesc_M_N = typename GridwiseGemm64::CGridDesc_M_N; + using EGridDesc_M_N = typename GridwiseGemm64::CGridDesc_M_N; using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {})); using DsGridPointer = decltype(MakeDsGridPointer()); using CDGridDesc_M_N = decltype(concat_tuple(ck::Tuple{}, DsGridDesc_M_N{})); @@ -258,7 +263,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage // Block2CTileMap configuration parameter. static constexpr index_t B2E_M01 = 8; using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap; - using GemmKernelArgument = typename GridwiseGemm::Argument; + using GemmKernelArgument = typename GridwiseGemm64::Argument; struct GemmTransKernelArg { @@ -355,12 +360,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage const index_t stride_b = gemm_descs[i].stride_B_; const index_t stride_e = gemm_descs[i].stride_C_; - const index_t m_padded = GridwiseGemm::CalculateMPadded(M); - const index_t n_padded = GridwiseGemm::CalculateNPadded(N); - const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH); - const index_t k0_padded = GridwiseGemm::CalculateK0Padded(K, K_BATCH); + const index_t m_padded = GridwiseGemm64::CalculateMPadded(M); + const index_t n_padded = GridwiseGemm64::CalculateNPadded(N); + const index_t k_padded = GridwiseGemm64::CalculateKPadded(K, K_BATCH); + const index_t k0_padded = GridwiseGemm64::CalculateK0Padded(K, K_BATCH); - const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_e); + const auto c_grid_desc_m_n = + GridwiseGemm64::MakeCGridDescriptor_M_N(M, N, stride_e); DsGridDesc_M_N ds_grid_desc_m_n; DsGridPointer p_ds_grid; @@ -441,11 +447,11 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage { auto& karg = gemm_kernel_args_[i].karg_; - const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH); - const index_t k0_padded = GridwiseGemm::CalculateK0Padded(karg.K, K_BATCH); + const index_t k_padded = GridwiseGemm64::CalculateKPadded(karg.K, K_BATCH); + const index_t k0_padded = GridwiseGemm64::CalculateK0Padded(karg.K, K_BATCH); const auto c_grid_desc_m_n = - GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); + GridwiseGemm64::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); const auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH}; @@ -565,13 +571,14 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage /// /// @return The average kernel execution time (if time measurement is enabled.) 
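Each invoker keeps the body of its old Run in a RunImp templated on the gridwise GEMM and lets INVOKER_RUN_IMPL supply the public entry point. The macro's definition is not shown in these hunks, but judging from the hand-written dispatch in the grouped softmax-GEMM invoker further down, it presumably expands to something like:

// Presumed expansion of INVOKER_RUN_IMPL: pick the gridwise GEMM matching
// the runtime warp size, guarded so that only instantiations with a
// positive per-wave tile count are compiled; 0 means no valid dispatch.
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
    if(get_warp_size() == 64)
    {
        if constexpr(NXdlPerWave64 > 0)
        {
            return RunImp<GridwiseGemm64>(arg, stream_config);
        }
    }
    else
    {
        if constexpr(NXdlPerWave32 > 0)
        {
            return RunImp<GridwiseGemm32>(arg, stream_config);
        }
    }
    return 0;
}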
/// + template float Run(const Argument& arg, void* dev_gemm_args, void* dev_gemm_workspace, const StreamConfig& stream_config = StreamConfig{}) { auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] = - CheckArgument(arg, stream_config); + CheckArgument(arg, stream_config); if(dev_gemm_args == nullptr) { @@ -593,13 +600,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage if(all_have_main_k_block_loop) { - ave_time = - DispatchKernel(arg, dev_gemm_args, dev_gemm_workspace, stream_config); + ave_time = DispatchKernel( + arg, dev_gemm_args, dev_gemm_workspace, stream_config); } else { - ave_time = - DispatchKernel(arg, dev_gemm_args, dev_gemm_workspace, stream_config); + ave_time = DispatchKernel( + arg, dev_gemm_args, dev_gemm_workspace, stream_config); } return ave_time; @@ -619,7 +626,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage /// /// @return The average kernel execution time (if time measurement is enabled.) /// - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(arg.p_dev_gemm_kargs_ == nullptr) { @@ -637,9 +645,11 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage throw std::runtime_error(err.str()); } - return Run(arg, arg.p_dev_gemm_kargs_, arg.p_workspace_, stream_config); + return Run(arg, arg.p_dev_gemm_kargs_, arg.p_workspace_, stream_config); } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -647,6 +657,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage } private: + template auto CheckArgument(const Argument& arg, const StreamConfig& stream_config) const { bool all_have_kbatch_gt_one, all_have_main_k_block_loop; @@ -670,7 +681,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) { - const auto& gemm_arg = arg.gemm_kernel_args_[i].karg_; + const auto& gemm_arg = reinterpret_cast( + arg.gemm_kernel_args_[i].karg_); if(stream_config.log_level_ > 0) { gemm_arg.Print(); @@ -721,7 +733,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage return std::make_tuple(all_have_kbatch_gt_one, all_have_main_k_block_loop); } - template + template float DispatchKernel(const Argument& arg, void* dev_gemm_kargs, void* dev_gemm_workspace, @@ -818,11 +830,10 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - if((ck::type_convert(arg.gemm_kernel_args_.size()) + arg.skipped_group_count_) != arg.group_count_) { @@ -836,11 +847,27 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage } bool supported = true; + bool isWave64 = get_warp_size() == 64; for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) { const auto& gemm_arg = arg.gemm_kernel_args_[i].karg_; + bool group_arg_valid = false; + if(isWave64) + { + if constexpr(NXdlPerWave64 > 0) + { + group_arg_valid = GridwiseGemm64::CheckValidity(gemm_arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + group_arg_valid = GridwiseGemm32::CheckValidity( + reinterpret_cast(gemm_arg)); + } + } - bool group_arg_valid = GridwiseGemm::CheckValidity(gemm_arg); if(not group_arg_valid) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) diff --git 
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp index 70a395f2f7..d8d688aa06 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -68,126 +68,91 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) { -#if defined(__gfx9__) - - constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); - __shared__ uint8_t p_shared[shared_size]; - __shared__ uint8_t p_shared1[shared_size]; - - const auto gemm_desc_ptr = - reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); - - constexpr auto NumDTensor = DsDataType::Size(); - index_t tile_id = get_block_1d_id(); - index_t tile_offset = 0; - index_t group_id = -1; - index_t group_offset = 0; - index_t grid_size_grp = 0; - - index_t gemm_tile_id_start = 0; - index_t gemm_tile_id_end = 0; - - index_t M = 0, N = 0, K = 0; - - auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1); - - do +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - // Find corresponding GEMM group for our tile - while(!(tile_id >= gemm_tile_id_start && tile_id < gemm_tile_id_end) && - group_id < group_count) + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + __shared__ uint8_t p_shared[shared_size]; + __shared__ uint8_t p_shared1[shared_size]; + + const auto gemm_desc_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(gemm_descs_const)); + + constexpr auto NumDTensor = DsDataType::Size(); + index_t tile_id = get_block_1d_id(); + index_t tile_offset = 0; + index_t group_id = -1; + index_t group_offset = 0; + index_t grid_size_grp = 0; + + index_t gemm_tile_id_start = 0; + index_t gemm_tile_id_end = 0; + + index_t M = 0, N = 0, K = 0; + + auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1); + + do { - group_offset += grid_size_grp; - group_id++; - - if(group_id >= group_count) - return; - - M = gemm_desc_ptr[group_id].M; - N = gemm_desc_ptr[group_id].N; - K = gemm_desc_ptr[group_id].K; - - if(M == 0 || N == 0 || K == 0) + // Find corresponding GEMM group for our tile + while(!(tile_id >= gemm_tile_id_start && tile_id < gemm_tile_id_end) && + group_id < group_count) { - grid_size_grp = 0; - continue; - } + group_offset += grid_size_grp; + group_id++; - b2c_tile_map = - OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N, 4), group_offset, tile_offset); - grid_size_grp = b2c_tile_map.CalculateGridSize(M, N); + if(group_id >= group_count) + return; - gemm_tile_id_start = group_offset; - gemm_tile_id_end = group_offset + grid_size_grp; - } + M = gemm_desc_ptr[group_id].M; + N = gemm_desc_ptr[group_id].N; + K = gemm_desc_ptr[group_id].K; - using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); - DsGridPointer p_ds_grid; - - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t>; - p_ds_grid(i) = 
static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); - }); - - static constexpr index_t kbatch = 1; - static constexpr index_t k_grain = kbatch * KPerBlock; - index_t K_split = (K + k_grain - 1) / k_grain * KPerBlock; - - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - - // Update tile offset if we have moved within group - b2c_tile_map.UpdateTileOffset(tile_offset); - - using Problem = typename GridwiseGemm::Problem; - auto problem = Problem(gemm_desc_ptr[group_id].M, - gemm_desc_ptr[group_id].N, - gemm_desc_ptr[group_id].K, - gemm_desc_ptr[group_id].StrideA, - gemm_desc_ptr[group_id].StrideB, - gemm_desc_ptr[group_id].StrideDs, - gemm_desc_ptr[group_id].StrideE, - kbatch); - - if(has_main_k_block_loop) - { - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || - BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - GridwiseGemm::template Run( - static_cast(gemm_desc_ptr[group_id].p_a_grid), - static_cast(gemm_desc_ptr[group_id].p_b_grid), - p_ds_grid, - static_cast(gemm_desc_ptr[group_id].p_e_grid), - static_cast(p_shared), - problem, - a_element_op, - b_element_op, - cde_element_op, - b2c_tile_map); - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + if(M == 0 || N == 0 || K == 0) { - GridwiseGemm::template Run( - static_cast(gemm_desc_ptr[group_id].p_a_grid), - static_cast(gemm_desc_ptr[group_id].p_b_grid), - p_ds_grid, - static_cast(gemm_desc_ptr[group_id].p_e_grid), - static_cast(p_shared), - problem, - a_element_op, - b_element_op, - cde_element_op, - b2c_tile_map); + grid_size_grp = 0; + continue; } - else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) + + b2c_tile_map = OffsettedBlockToCTileMap( + LocalBlock2ETileMap(M, N, 4), group_offset, tile_offset); + grid_size_grp = b2c_tile_map.CalculateGridSize(M, N); + + gemm_tile_id_start = group_offset; + gemm_tile_id_end = group_offset + grid_size_grp; + } + + using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); + DsGridPointer p_ds_grid; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + p_ds_grid(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); + }); + + static constexpr index_t kbatch = 1; + static constexpr index_t k_grain = kbatch * KPerBlock; + index_t K_split = (K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + // Update tile offset if we have moved within group + b2c_tile_map.UpdateTileOffset(tile_offset); + + using Problem = typename GridwiseGemm::Problem; + auto problem = Problem(gemm_desc_ptr[group_id].M, + gemm_desc_ptr[group_id].N, + gemm_desc_ptr[group_id].K, + gemm_desc_ptr[group_id].StrideA, + gemm_desc_ptr[group_id].StrideB, + gemm_desc_ptr[group_id].StrideDs, + gemm_desc_ptr[group_id].StrideE, + kbatch); + + if(has_main_k_block_loop) + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) { GridwiseGemm::template Run 2) + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) { GridwiseGemm::template Run( + TailNumber::One>( static_cast(gemm_desc_ptr[group_id].p_a_grid), static_cast(gemm_desc_ptr[group_id].p_b_grid), p_ds_grid, @@ -224,16 +188,12 @@ 
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) cde_element_op, b2c_tile_map); } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three) + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) { GridwiseGemm::template Run( + TailNumber::Full>( static_cast(gemm_desc_ptr[group_id].p_a_grid), static_cast(gemm_desc_ptr[group_id].p_b_grid), p_ds_grid, @@ -245,84 +205,166 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) cde_element_op, b2c_tile_map); } - } - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four) + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) { - GridwiseGemm::template Run( + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + } + // Tail number could be Odd or 
Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + GridwiseGemm::template Run_2Lds( static_cast(gemm_desc_ptr[group_id].p_a_grid), static_cast(gemm_desc_ptr[group_id].p_b_grid), p_ds_grid, static_cast(gemm_desc_ptr[group_id].p_e_grid), static_cast(p_shared), + static_cast(p_shared1), problem, a_element_op, b_element_op, cde_element_op, b2c_tile_map); } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five) + else { - GridwiseGemm::template Run( - static_cast(gemm_desc_ptr[group_id].p_a_grid), - static_cast(gemm_desc_ptr[group_id].p_b_grid), - p_ds_grid, - static_cast(gemm_desc_ptr[group_id].p_e_grid), - static_cast(p_shared), - problem, - a_element_op, - b_element_op, - cde_element_op, - b2c_tile_map); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) - { - GridwiseGemm::template Run( - static_cast(gemm_desc_ptr[group_id].p_a_grid), - static_cast(gemm_desc_ptr[group_id].p_b_grid), - p_ds_grid, - static_cast(gemm_desc_ptr[group_id].p_e_grid), - static_cast(p_shared), - problem, - a_element_op, - b_element_op, - cde_element_op, - b2c_tile_map); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven) - { - GridwiseGemm::template Run( + GridwiseGemm::template Run_2Lds( static_cast(gemm_desc_ptr[group_id].p_a_grid), static_cast(gemm_desc_ptr[group_id].p_b_grid), p_ds_grid, static_cast(gemm_desc_ptr[group_id].p_e_grid), static_cast(p_shared), + static_cast(p_shared1), problem, a_element_op, b_element_op, @@ -331,39 +373,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) } } } - // Tail number could be Odd or Even - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + else { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { - GridwiseGemm::template Run_2Lds( + GridwiseGemm::template Run( static_cast(gemm_desc_ptr[group_id].p_a_grid), static_cast(gemm_desc_ptr[group_id].p_b_grid), p_ds_grid, static_cast(gemm_desc_ptr[group_id].p_e_grid), static_cast(p_shared), - static_cast(p_shared1), - problem, - a_element_op, - b_element_op, - cde_element_op, - b2c_tile_map); - } - else - { - GridwiseGemm::template Run_2Lds( - static_cast(gemm_desc_ptr[group_id].p_a_grid), - static_cast(gemm_desc_ptr[group_id].p_b_grid), - p_ds_grid, - static_cast(gemm_desc_ptr[group_id].p_e_grid), - static_cast(p_shared), - static_cast(p_shared1), problem, a_element_op, b_element_op, @@ -371,32 +393,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) b2c_tile_map); } } - } - else - { - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - GridwiseGemm::template Run( - static_cast(gemm_desc_ptr[group_id].p_a_grid), - static_cast(gemm_desc_ptr[group_id].p_b_grid), - p_ds_grid, - static_cast(gemm_desc_ptr[group_id].p_e_grid), - static_cast(p_shared), - problem, - a_element_op, - b_element_op, - cde_element_op, - b2c_tile_map); - } - } - tile_id += get_grid_size(); - tile_offset += get_grid_size(); + tile_id += get_grid_size(); + tile_offset += get_grid_size(); - } while(group_id < group_count); + } while(group_id < group_count); + } #else ignore = 
gemm_descs_const; ignore = group_count; @@ -467,10 +469,14 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop BElementwiseOperation, CDEElementwiseOperation> { - using DeviceOp = DeviceGroupedGemmMultipleDXdlCShuffleTileLoop; + using DeviceOp = DeviceGroupedGemmMultipleDXdlCShuffleTileLoop; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); - using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3< + template + using GridwiseGemmBase = GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, @@ -494,7 +500,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -519,6 +525,8 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using KernelArguments = GroupedGemmKernelArgument; using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; @@ -571,10 +579,15 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop // The oversubscription factor for the number of blocks that can simultaneously reside on // GPU. static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1; - static constexpr int BLOCK_WAVES = BlockSize / get_warp_size(); - static constexpr int CU_SIMDS = 4; + // static constexpr int BLOCK_WAVES = BlockSize / get_warp_size(); + static constexpr int CU_SIMDS = 4; // Assume we want to have at most 2 waves per SIMD - static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES); + // static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES); + static int GetCuBlocks() + { + int BLOCK_WAVES = BlockSize / get_warp_size(); + return math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES); + } }; // Invoker @@ -593,6 +606,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop /// /// @return The average kernel execution time (if time measurement is enabled.) /// + template float Run(const Argument& arg, const void* dev_gemm_args, const StreamConfig& stream_config = StreamConfig{}) @@ -606,7 +620,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop } float ave_time = 0; - ave_time = DispatchKernel(arg, dev_gemm_args, stream_config); + ave_time = DispatchKernel(arg, dev_gemm_args, stream_config); return ave_time; } @@ -624,7 +638,8 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop /// /// @return The average kernel execution time (if time measurement is enabled.) 
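BLOCK_WAVES depends on get_warp_size(), which is no longer a compile-time constant across targets, hence the compile-time CU_BLOCKS becomes the runtime GetCuBlocks(). With BlockSize = 256 this gives floor(2*4/4) = 2 resident workgroups per CU on a wave64 target but floor(2*4/8) = 1 on wave32. A self-contained sketch of the same arithmetic, assuming CU_SIMDS = 4 and the at-most-two-waves-per-SIMD budget from KernelConfig:

// Occupancy target per CU: how many workgroups fit the 2-waves-per-SIMD
// budget, given the workgroup size and the target's warp width.
int get_cu_blocks(int block_size, int warp_size)
{
    const int block_waves = block_size / warp_size; // waves per workgroup
    const int cu_simds    = 4;
    return (2 * cu_simds) / block_waves; // integer_divide_floor
}
// block_size 256: wave64 -> 8/4 = 2 blocks/CU, wave32 -> 8/8 = 1 block/CU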
/// - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(arg.p_dev_gemm_args_ == nullptr) { @@ -634,9 +649,11 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop throw std::runtime_error(err.str()); } - return Run(arg, arg.p_dev_gemm_args_, stream_config); + return Run(arg, arg.p_dev_gemm_args_, stream_config); } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -644,6 +661,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop } private: + template float DispatchKernel(const Argument& arg, const void* dev_gemm_args, const StreamConfig& stream_config) const @@ -686,11 +704,11 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop { std::cout << "MaxActiveBlocksPerCU: " << occ_num_blocks << ", available CUs count: " << cu_count << ", occup. grid size: " - << ck::math::min(occ_num_blocks, KernelConfig::CU_BLOCKS) * cu_count + << ck::math::min(occ_num_blocks, KernelConfig::GetCuBlocks()) * cu_count << std::endl; } - return cu_count * ck::math::min(occ_num_blocks, KernelConfig::CU_BLOCKS); + return cu_count * ck::math::min(occ_num_blocks, KernelConfig::GetCuBlocks()); } template @@ -730,36 +748,19 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - bool supported = true; constexpr index_t k_batch = 1; + bool isWave64 = get_warp_size() == 64; for(index_t i = 0; i < arg.group_count_; ++i) { std::array placeholder_p_ds_grid{}; std::array stride_Ds; std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin()); - using GridArg = typename GridwiseGemm::Argument; - GridArg gridwise_arg(nullptr, // p_a_grid, - nullptr, // p_b_grid, - placeholder_p_ds_grid, // p_ds_grid, - nullptr, // p_e_grid , - arg.gemm_descs_[i].M_, - arg.gemm_descs_[i].N_, - arg.gemm_descs_[i].K_, - arg.gemm_descs_[i].stride_A_, - arg.gemm_descs_[i].stride_B_, - stride_Ds, - arg.gemm_descs_[i].stride_C_, - k_batch, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_); - if((arg.gemm_descs_[i].K_ % AK1 != 0 || arg.gemm_descs_[i].K_ % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::NKPadding || @@ -768,8 +769,62 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop { return false; } + if(isWave64) + { + if constexpr(NXdlPerWave64 > 0) + { + using GridArg = typename GridwiseGemm64::Argument; + GridArg gridwise_arg(nullptr, // p_a_grid, + nullptr, // p_b_grid, + placeholder_p_ds_grid, // p_ds_grid, + nullptr, // p_e_grid , + arg.gemm_descs_[i].M_, + arg.gemm_descs_[i].N_, + arg.gemm_descs_[i].K_, + arg.gemm_descs_[i].stride_A_, + arg.gemm_descs_[i].stride_B_, + stride_Ds, + arg.gemm_descs_[i].stride_C_, + k_batch, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_); - supported = supported && GridwiseGemm::CheckValidity(gridwise_arg); + supported = supported && GridwiseGemm64::CheckValidity(gridwise_arg); + } + else + { + supported = false; + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + using GridArg = typename GridwiseGemm32::Argument; + GridArg gridwise_arg(nullptr, // p_a_grid, + nullptr, // p_b_grid, + placeholder_p_ds_grid, // p_ds_grid, + nullptr, // p_e_grid , + arg.gemm_descs_[i].M_, + arg.gemm_descs_[i].N_, + arg.gemm_descs_[i].K_, + arg.gemm_descs_[i].stride_A_, + 
arg.gemm_descs_[i].stride_B_, + stride_Ds, + arg.gemm_descs_[i].stride_C_, + k_batch, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_); + + supported = supported && GridwiseGemm32::CheckValidity(gridwise_arg); + } + else + { + supported = false; + } + } } return supported; @@ -780,6 +835,67 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop return IsSupportedArgument(*dynamic_cast(p_arg)); } + static int GetKernelOccupancy() + { + int occupancy = 0; + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + const auto kernel = kernel_grouped_gemm_multiple_d_xdl; + hip_check_error( + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); + } + } + else + { + + if constexpr(NXdlPerWave32 > 0) + { + const auto kernel = kernel_grouped_gemm_multiple_d_xdl; + hip_check_error( + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); + } + } + return occupancy; + } + static auto MakeArgument(std::vector& p_As, std::vector& p_Bs, std::vector>& p_Ds, @@ -789,28 +905,8 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop BElementwiseOperation b_elementwise_op, CDEElementwiseOperation cde_elementwise_op) { - const auto kernel = kernel_grouped_gemm_multiple_d_xdl; - int occupancy, num_cu; - hip_check_error( - hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); + int occupancy = GetKernelOccupancy(); + int num_cu; hipDeviceProp_t dev_prop; hipDevice_t dev; @@ -840,28 +936,8 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop BElementwiseOperation b_elementwise_op, CDEElementwiseOperation cde_elementwise_op) override { - const auto kernel = kernel_grouped_gemm_multiple_d_xdl; - int occupancy, num_cu; - hip_check_error( - hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); + int occupancy = GetKernelOccupancy(); + int num_cu; hipDeviceProp_t dev_prop; hipDevice_t dev; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index 784b2fd401..62dcbfb83b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
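GetKernelOccupancy() folds the wave dispatch into the occupancy query as well, since the kernel symbol itself differs between the wave64 and wave32 instantiations; the launch path then caps the persistent grid at min(occupancy, GetCuBlocks()) workgroups per CU. A standalone sketch of that sizing, assuming the CU count comes from hipDeviceProp_t (the helper and its name are illustrative):

#include <hip/hip_runtime.h>
#include <algorithm>

// Size a persistent grid: query occupancy for the selected kernel symbol,
// clamp it to the per-CU budget, and scale by the number of CUs.
template <typename Kernel>
int persistent_grid_size(Kernel kernel, int block_size, int cu_blocks)
{
    int occupancy = 0;
    (void)hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, block_size, 0);

    hipDeviceProp_t props{};
    int dev = 0;
    (void)hipGetDevice(&dev);
    (void)hipGetDeviceProperties(&props, dev);

    return props.multiProcessorCount * std::min(occupancy, cu_blocks);
}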
#pragma once @@ -43,63 +43,70 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const B1ElementwiseOperation b1_element_op, const CElementwiseOperation c_element_op) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - const index_t block_id = get_block_1d_id(); - - const auto arg_ptr = reinterpret_cast( - cast_pointer_to_generic_address_space(group_kernel_args)); - - index_t left = 0; - index_t right = group_count; - index_t group_id = index_t((left + right) / 2); - - while( - (!(block_id >= arg_ptr[group_id].block_start_ && block_id < arg_ptr[group_id].block_end_))) +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - if(block_id < arg_ptr[group_id].block_start_) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id = get_block_1d_id(); + + const auto arg_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(group_kernel_args)); + + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + + while((!(block_id >= arg_ptr[group_id].block_start_ && + block_id < arg_ptr[group_id].block_end_))) { - right = group_id; + if(block_id < arg_ptr[group_id].block_start_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); } - else - { - left = group_id; - } - group_id = index_t((left + right) / 2); + + // per-group batch offset + const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_; + const index_t g_idx = __builtin_amdgcn_readfirstlane( + (block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch); + + const long_index_t a_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast( + arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast( + arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(g_idx))); + const long_index_t b1_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast( + arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast( + arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx))); + + GridwiseGemm::template Run( + arg_ptr[group_id].p_a_grid_ + a_batch_offset, + arg_ptr[group_id].p_b_grid_ + b_batch_offset, + arg_ptr[group_id].p_b1_grid_ + b1_batch_offset, + arg_ptr[group_id].p_c_grid_ + c_batch_offset, + p_shared, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op, + arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, + arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, + arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, + arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg_ptr[group_id].block_2_ctile_map_, + arg_ptr[group_id].c0_matrix_mask_); } - - // per-group batch offset - const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_; - const index_t g_idx = __builtin_amdgcn_readfirstlane( - (block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch); - - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(g_idx))); - const long_index_t b1_batch_offset = 
__builtin_amdgcn_readfirstlane(static_cast( - arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx))); - - GridwiseGemm::template Run( - arg_ptr[group_id].p_a_grid_ + a_batch_offset, - arg_ptr[group_id].p_b_grid_ + b_batch_offset, - arg_ptr[group_id].p_b1_grid_ + b1_batch_offset, - arg_ptr[group_id].p_c_grid_ + c_batch_offset, - p_shared, - a_element_op, - b_element_op, - acc_element_op, - b1_element_op, - c_element_op, - arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, - arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, - arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, - arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg_ptr[group_id].block_2_ctile_map_, - arg_ptr[group_id].c0_matrix_mask_); #else ignore = group_kernel_args; ignore = group_count; @@ -198,6 +205,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle CElementwiseOperation, MaskingSpec> { + static constexpr auto MXdlPerWave64 = + GetNXdlPerWave2(); + static constexpr auto MXdlPerWave32 = + GetNXdlPerWave2(); + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, "Number of dimension must be greater than 0"); @@ -338,7 +350,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle }; // GridwiseGemm - using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< + template + using GridwiseGemmBase = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, CShuffleDataType, @@ -365,7 +378,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle B1K1, MPerXDL, NPerXDL, - MXdlPerWave, + MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, @@ -399,8 +412,10 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle LoopSched, Transform::matrix_padder.PadN, MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Block2CTileMap = OffsettedBlockToCTileMap; + using Block2CTileMap = OffsettedBlockToCTileMap; struct GroupKernelArg { @@ -414,7 +429,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + typename GridwiseGemm64::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; // batch & stride @@ -511,7 +526,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides); const auto c_grid_desc_mblock_mperblock_nblock_nperblock = - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + GridwiseGemm64::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( c_grid_desc_m_n); const index_t BlockStart = grid_size_; @@ -592,7 +607,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!DeviceOp::IsSupportedArgument(arg)) { @@ -664,6 +680,25 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle return ave_time; } + float Run(const Argument& arg, const StreamConfig& 
stream_config = StreamConfig{})
+        {
+            if(get_warp_size() == 64)
+            {
+                if constexpr(MXdlPerWave64 > 0)
+                {
+                    return RunImp<GridwiseGemm64>(arg, stream_config);
+                }
+            }
+            else
+            {
+                if constexpr(MXdlPerWave32 > 0)
+                {
+                    return RunImp<GridwiseGemm32>(arg, stream_config);
+                }
+            }
+            return 0;
+        }
+
         // polymorphic
         float Run(const BaseArgument* p_arg,
                   const StreamConfig& stream_config = StreamConfig{}) override
@@ -680,11 +715,10 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
 
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(!ck::is_xdl_supported())
+        if(!ck::is_xdl_wmma_supported())
         {
             return false;
         }
-
         // TODO ANT: Check if tensor specialization & strides mismatch
 
         bool all_has_main_k_block_loop  = true;
@@ -708,7 +742,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
             // Check if having main loop
             const auto K = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) *
                            kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
-            const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K);
+            const bool y = GridwiseGemm64::CalculateHasMainKBlockLoop(K);
             all_has_main_k_block_loop &= y;
             some_has_main_k_block_loop |= y;
@@ -753,14 +787,31 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 return false;
             }
 
-            if(!GridwiseGemm::CheckValidity(kernel_arg.a_grid_desc_ak0_m_ak1_,
-                                            kernel_arg.b_grid_desc_bk0_n_bk1_,
-                                            kernel_arg.b1_grid_desc_bk0_n_bk1_,
-                                            device_arg.c_grid_desc_m_n_,
-                                            kernel_arg.block_2_ctile_map_))
+            bool valid = false;
+            if(get_warp_size() == 64)
             {
-                return false;
+                if constexpr(MXdlPerWave64 > 0)
+                {
+                    valid = GridwiseGemm64::CheckValidity(kernel_arg.a_grid_desc_ak0_m_ak1_,
+                                                          kernel_arg.b_grid_desc_bk0_n_bk1_,
+                                                          kernel_arg.b1_grid_desc_bk0_n_bk1_,
+                                                          device_arg.c_grid_desc_m_n_,
+                                                          kernel_arg.block_2_ctile_map_);
+                }
             }
+            else
+            {
+                if constexpr(MXdlPerWave32 > 0)
+                {
+                    valid = GridwiseGemm32::CheckValidity(kernel_arg.a_grid_desc_ak0_m_ak1_,
+                                                          kernel_arg.b_grid_desc_bk0_n_bk1_,
+                                                          kernel_arg.b1_grid_desc_bk0_n_bk1_,
+                                                          device_arg.c_grid_desc_m_n_,
+                                                          kernel_arg.block_2_ctile_map_);
+                }
+            }
+            if(!valid)
+                return false;
         }
 
         // all gemm problems have to simultaneously meet has_main_k_block_loop or
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
index 2c5d1dd134..7a1944cc68 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
@@ -39,46 +39,49 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
                 const BElementwiseOperation b_element_op,
                 const CDEElementwiseOperation c_element_op)
 {
-#if defined(__gfx9__)
-    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
-
-    const index_t block_id = get_block_1d_id();
-
-    const auto gemm_desc_ptr =
-        reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
-
-    index_t left     = 0;
-    index_t right    = group_count;
-    index_t group_id = index_t((left + right) / 2);
-    while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ &&
-             block_id < gemm_desc_ptr[group_id].BlockEnd_)) &&
-          left <= right)
+#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
+    if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
     {
-        if(block_id < gemm_desc_ptr[group_id].BlockStart_)
-        {
-            right = group_id;
-        }
-        else
-        {
-            left = group_id;
-        }
-        group_id = index_t((left + right) / 2);
-    }
+        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
-    GridwiseGemm::template Run(
-        
gemm_desc_ptr[group_id].a_ptr_, - gemm_desc_ptr[group_id].b_ptr_, - gemm_desc_ptr[group_id].ds_ptr_, - gemm_desc_ptr[group_id].e_ptr_, - p_shared, - a_element_op, - b_element_op, - c_element_op, - gemm_desc_ptr[group_id].a_grid_desc_ak0_m_ak1_, - gemm_desc_ptr[group_id].b_grid_desc_bk0_n_bk1_, - gemm_desc_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, - gemm_desc_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, - gemm_desc_ptr[group_id].block_2_etile_map_); + const index_t block_id = get_block_1d_id(); + + const auto gemm_desc_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(gemm_descs_const)); + + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ && + block_id < gemm_desc_ptr[group_id].BlockEnd_)) && + left <= right) + { + if(block_id < gemm_desc_ptr[group_id].BlockStart_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + GridwiseGemm::template Run( + gemm_desc_ptr[group_id].a_ptr_, + gemm_desc_ptr[group_id].b_ptr_, + gemm_desc_ptr[group_id].ds_ptr_, + gemm_desc_ptr[group_id].e_ptr_, + p_shared, + a_element_op, + b_element_op, + c_element_op, + gemm_desc_ptr[group_id].a_grid_desc_ak0_m_ak1_, + gemm_desc_ptr[group_id].b_grid_desc_bk0_n_bk1_, + gemm_desc_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_desc_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_desc_ptr[group_id].block_2_etile_map_); + } #else ignore = gemm_descs_const; ignore = group_count; @@ -145,7 +148,9 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm { using DeviceOp = DeviceGroupedGemm_Xdl; - + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; @@ -231,7 +236,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm + using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype BDataType, ComputeDataType, @@ -252,7 +258,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using AGridDesc_AK0_M_AK1 = - remove_cvref_t; using BGridDesc_BK0_N_BK1 = - remove_cvref_t; using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))>; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; struct GroupedGemmBlock2ETileMap { using Block2ETileMap = - remove_cvref_t; + remove_cvref_t; GroupedGemmBlock2ETileMap() { - block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}); + block_2_etile_map_ = GridwiseGemm64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}); BlockStart_ = -1; } GroupedGemmBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, ck::index_t BlockStart) { - block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); + block_2_etile_map_ = GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); BlockStart_ = BlockStart; } @@ -334,7 +342,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm + void init_gridwise_gemm_desc(const ADataType* a_ptr, + const BDataType* b_ptr, + DsPointer ds_ptr, + 
EDataType* e_ptr, + const AGridDesc_M_K& a_grid_desc_m_k, + const BGridDesc_N_K& b_grid_desc_n_k, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2ETileMap& block_2_etile_map, + index_t BlockStart, + index_t BlockEnd) + { + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_ak0_m_ak1 = + GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); + + const auto b_grid_desc_bk0_n_bk1 = + GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k); + + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map)) + { + // tensor descriptors for block/thread-wise copy + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n); + + gemm_desc_kernel_arg_.push_back( + GemmBiasTransKernelArg{a_ptr, + b_ptr, + ds_ptr, + e_ptr, + a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map, + BlockStart, + BlockEnd}); + } + }; Argument(std::vector& p_As, std::vector& p_Bs, std::vector>& p_Ds, @@ -403,7 +469,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm{}([&](auto j) { using DDataType = remove_cvref_t>; @@ -427,13 +493,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm(M, N, StrideC); - // tensor descriptors for block/thread-wise copy - const auto a_grid_desc_ak0_m_ak1 = - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); - - const auto b_grid_desc_bk0_n_bk1 = - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k); - const index_t grid_size_grp = GroupedGemmBlock2ETileMap(e_grid_desc_m_n, 0) .block_2_etile_map_.CalculateGridSize(e_grid_desc_m_n); @@ -447,42 +506,41 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm{}([&](auto j) { - ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n[j]); - }); - - const auto e_grid_desc_mblock_mperblock_nblock_nperblock = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n); - - gemm_desc_kernel_arg_.push_back( - GemmBiasTransKernelArg{static_cast(p_As[i]), - static_cast(p_Bs[i]), - p_ds_grid, - static_cast(p_Es[i]), - a_grid_desc_m_k, - b_grid_desc_n_k, - ds_grid_desc_m_n, - e_grid_desc_m_n, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map, - BlockStart, - BlockEnd}); + if constexpr(NXdlPerWave64 > 0) + { + init_gridwise_gemm_desc( + static_cast(p_As[i]), + static_cast(p_Bs[i]), + p_ds_grid, + static_cast(p_Es[i]), + a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map, + BlockStart, + BlockEnd); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + init_gridwise_gemm_desc( + static_cast(p_As[i]), + static_cast(p_Bs[i]), + p_ds_grid, + static_cast(p_Es[i]), + a_grid_desc_m_k, + 
b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map, + BlockStart, + BlockEnd); + } } } } @@ -508,10 +566,11 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm + float RunImp(const Argument& arg, + const StreamConfig& stream_config = StreamConfig{}, + hipStream_t cpy_stream = nullptr, + hipEvent_t cpy_event = nullptr) { bool has_main_k_block_loop = true; @@ -626,6 +685,28 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm 0) + { + return RunImp(arg, stream_config, cpy_stream, cpy_event); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return RunImp(arg, stream_config, cpy_stream, cpy_event); + } + } + return 0; + } + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -636,11 +717,10 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm()) { return false; } - if((ck::type_convert(arg.gemm_desc_kernel_arg_.size()) + arg.skipped_group_count_) != arg.group_count_) { @@ -649,12 +729,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm(cast_pointer_to_generic_address_space(gemm_descs_const)); - - const index_t group_id = block_id / grid_size_grp; - - if(group_id >= group_count) - return; - - const index_t M = gemm_desc_ptr[group_id].M; - const index_t N = gemm_desc_ptr[group_id].N; - const index_t K = gemm_desc_ptr[group_id].K; - - if(M == 0 || N == 0 || K == 0) - return; - - const auto StrideA = gemm_desc_ptr[group_id].StrideA; - const auto StrideB = gemm_desc_ptr[group_id].StrideB; - const auto StrideDs = gemm_desc_ptr[group_id].StrideDs; - const auto StrideE = gemm_desc_ptr[group_id].StrideE; - - const auto e_grid_desc_m_n = - GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); - - const index_t BlockStart = group_id * grid_size_grp; - - const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch}; - - const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n); - - constexpr auto NumDTensor = DsDataType::Size(); - - using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); - - DsGridPointer p_ds_grid_; - - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t>; - // D pointer - p_ds_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); - }); - - index_t id_off = 0; - index_t id_local = get_block_1d_id() - BlockStart; - - const index_t mn_blocks = local_grid_size / KBatch; - - while(id_local < local_grid_size) +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - const auto block_2_etile_map = - GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off); + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - if constexpr(Zeroing) - { - auto barrier_count_finished = - barrier_count + group_id * barrier_size_grp + id_local % mn_blocks; - GridwiseGemm::template RunWithZeroing(gemm_desc_ptr[group_id].p_a_grid, - gemm_desc_ptr[group_id].p_b_grid, - p_ds_grid_, - gemm_desc_ptr[group_id].p_e_grid, - p_shared, - barrier_count_finished, - a_element_op, - b_element_op, - c_element_op, - M, - N, - K, - StrideA, - StrideB, - StrideDs, - StrideE, - KBatch, - block_2_etile_map); - } - else + const index_t block_id = get_block_1d_id(); + + const auto gemm_desc_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(gemm_descs_const)); + + const index_t group_id = block_id / grid_size_grp; + + if(group_id >= group_count) + return; + + const index_t M = 
gemm_desc_ptr[group_id].M; + const index_t N = gemm_desc_ptr[group_id].N; + const index_t K = gemm_desc_ptr[group_id].K; + + if(M == 0 || N == 0 || K == 0) + return; + + const auto StrideA = gemm_desc_ptr[group_id].StrideA; + const auto StrideB = gemm_desc_ptr[group_id].StrideB; + const auto StrideDs = gemm_desc_ptr[group_id].StrideDs; + const auto StrideE = gemm_desc_ptr[group_id].StrideE; + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); + + const index_t BlockStart = group_id * grid_size_grp; + + const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch}; + + const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n); + + constexpr auto NumDTensor = DsDataType::Size(); + + using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); + + DsGridPointer p_ds_grid_; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + // D pointer + p_ds_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); + }); + + index_t id_off = 0; + index_t id_local = get_block_1d_id() - BlockStart; + + const index_t mn_blocks = local_grid_size / KBatch; + + while(id_local < local_grid_size) { + const auto block_2_etile_map = + GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off); - GridwiseGemm::template Run(gemm_desc_ptr[group_id].p_a_grid, - gemm_desc_ptr[group_id].p_b_grid, - p_ds_grid_, - gemm_desc_ptr[group_id].p_e_grid, - p_shared, - nullptr, - a_element_op, - b_element_op, - c_element_op, - M, - N, - K, - StrideA, - StrideB, - StrideDs, - StrideE, - KBatch, - block_2_etile_map); - } + if constexpr(Zeroing) + { + auto barrier_count_finished = + barrier_count + group_id * barrier_size_grp + id_local % mn_blocks; + GridwiseGemm::template RunWithZeroing(gemm_desc_ptr[group_id].p_a_grid, + gemm_desc_ptr[group_id].p_b_grid, + p_ds_grid_, + gemm_desc_ptr[group_id].p_e_grid, + p_shared, + barrier_count_finished, + a_element_op, + b_element_op, + c_element_op, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + block_2_etile_map); + } + else + { - id_off += grid_size_grp; - id_local += grid_size_grp; + GridwiseGemm::template Run(gemm_desc_ptr[group_id].p_a_grid, + gemm_desc_ptr[group_id].p_b_grid, + p_ds_grid_, + gemm_desc_ptr[group_id].p_e_grid, + p_shared, + nullptr, + a_element_op, + b_element_op, + c_element_op, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + block_2_etile_map); + } + + id_off += grid_size_grp; + id_local += grid_size_grp; + } } #else ignore = gemm_descs_const; @@ -241,6 +244,9 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK { using DeviceOp = DeviceGroupedGemm_Xdl_Fixed_NK; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); @@ -252,7 +258,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK + using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_splitk_cshuffle< ADataType, // TODO: distinguish A/B datatype BDataType, AComputeType, @@ -274,7 +281,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; template struct OffsettedBlockToCTileMapMLoops @@ -479,7 +488,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK( + GridwiseGemm64::template MakeEGridDescriptor_M_N( AverM, N, StrideE); 
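// ---------------------------------------------------------------------------
// Editorial sketch, not part of the original patch: the dual-instantiation
// pattern this series applies to every device op, as one self-contained
// C++17 snippet. The helper name and the wave32 halving rule below are
// illustrative stand-ins for GET_NXDL_PER_WAVE_IMPL / GetNXdlPerWave in
// device_base.hpp, whose exact bodies this excerpt does not show.

constexpr int GetPerWaveSketch(int per_wave, int wave_size)
{
    // Keep the tuned value on wave64; on wave32 halve it, returning 0
    // ("not instantiable") when it cannot be split evenly.
    return wave_size == 64 ? per_wave : (per_wave % 2 == 0 ? per_wave / 2 : 0);
}

template <int NXdlPerWave_>
struct GridwiseGemmSketch // stand-in for the real gridwise GEMM
{
    static constexpr bool CheckValidity() { return NXdlPerWave_ > 0; }
};

template <int NXdlPerWave>
struct DeviceOpSketch
{
    static constexpr int NXdlPerWave64 = GetPerWaveSketch(NXdlPerWave, 64);
    static constexpr int NXdlPerWave32 = GetPerWaveSketch(NXdlPerWave, 32);
    using GridwiseGemm64 = GridwiseGemmSketch<NXdlPerWave64>;
    using GridwiseGemm32 = GridwiseGemmSketch<NXdlPerWave32>;

    static bool IsSupported(int warp_size)
    {
        if(warp_size == 64)
        {
            // The discarded branch is never instantiated, so a degenerate
            // GridwiseGemm can exist as an alias without being compiled.
            if constexpr(NXdlPerWave64 > 0)
                return GridwiseGemm64::CheckValidity();
        }
        else
        {
            if constexpr(NXdlPerWave32 > 0)
                return GridwiseGemm32::CheckValidity();
        }
        return false; // no usable instantiation for this wave size
    }
};

static_assert(DeviceOpSketch<4>::NXdlPerWave32 == 2, "4 splits to 2 on wave32");
static_assert(DeviceOpSketch<1>::NXdlPerWave32 == 0, "1 cannot split on wave32");

// This is why Run/RunImp and CheckValidity in the hunks above and below
// always appear as guarded 64/32 pairs instead of a single call.
// ---------------------------------------------------------------------------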
const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; @@ -550,7 +559,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK( + GridwiseGemm64::template MakeEGridDescriptor_M_N( AverM, N, StrideE); // block-to-e-tile map @@ -566,17 +575,55 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK( - AverM, N, K, StrideA, StrideB, StrideDs, StrideE, 1)) + if(get_warp_size() == 64) { - throw std::runtime_error( - "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + if constexpr(NXdlPerWave64 > 0) + { + if(!GridwiseGemm64::template CheckValidity( + AverM, N, K, StrideA, StrideB, StrideDs, StrideE, 1)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid " + "setting"); + } + } + else + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + if(!GridwiseGemm32::template CheckValidity( + AverM, N, K, StrideA, StrideB, StrideDs, StrideE, 1)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid " + "setting"); + } + } + else + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } } gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{ @@ -597,7 +644,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK( + GridwiseGemm64::template MakeEGridDescriptor_M_N( sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_); const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1}; @@ -631,7 +678,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { bool has_main_k_block_loop = true; @@ -783,6 +831,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK(cast_pointer_to_generic_address_space(gemm_descs_const)); - - index_t left = 0; - index_t right = group_count; - index_t group_id = index_t((left + right) / 2); - while((!(block_id >= gemm_desc_ptr[group_id].block_start_ && - block_id < gemm_desc_ptr[group_id].block_end_)) && - left <= right) +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - if(block_id < gemm_desc_ptr[group_id].block_start_) - { - right = group_id; - } - else - { - left = group_id; - } - group_id = index_t((left + right) / 2); - } + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + __shared__ uint8_t p_shared[shared_size]; - GridwiseGemm::template Run( - gemm_desc_ptr[group_id].karg_, - static_cast(p_shared), - gemm_desc_ptr[group_id].block_2_ctile_map_, - a_element_op, - b_element_op, - c_element_op); + const index_t block_id = get_block_1d_id(); + const auto gemm_desc_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(gemm_descs_const)); + + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_id >= gemm_desc_ptr[group_id].block_start_ && + block_id < gemm_desc_ptr[group_id].block_end_)) && + left <= right) + { + if(block_id < gemm_desc_ptr[group_id].block_start_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + GridwiseGemm::template Run( + gemm_desc_ptr[group_id].karg_, + static_cast(p_shared), + 
gemm_desc_ptr[group_id].block_2_ctile_map_, + a_element_op, + b_element_op, + c_element_op); + } #else ignore = gemm_descs_const; ignore = group_count; @@ -144,6 +147,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; @@ -153,7 +159,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK + using GridwiseGemmBase = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, @@ -174,7 +181,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; + using CGridDesc_M_N = typename GridwiseGemm64::CGridDesc_M_N; using Block2ETileMapKSplit = BlockToCTileMap_KSplit_M00_N0_M01Adapt; // Block2CTileMap configuration parameter. static constexpr index_t B2E_M01 = 8; using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap; - using KernelArgument = typename GridwiseGemm::Argument; + using KernelArgument = typename GridwiseGemm64::Argument; using PassThrough = ck::tensor_operation::element_wise::PassThrough; - struct GemmTransKernelArg + template + struct GemmTransKernelArgBase { - KernelArgument karg_; + KernelArgument_ karg_; GroupedGemmBlock2ETileMap block_2_ctile_map_; index_t block_start_, block_end_; - GemmTransKernelArg() = default; - GemmTransKernelArg(KernelArgument&& karg, - GroupedGemmBlock2ETileMap&& b2c_map, - index_t block_start, - index_t block_end) + GemmTransKernelArgBase() = default; + GemmTransKernelArgBase(KernelArgument_&& karg, + GroupedGemmBlock2ETileMap&& b2c_map, + index_t block_start, + index_t block_end) : karg_{karg}, block_2_ctile_map_{b2c_map}, block_start_{block_start}, @@ -224,6 +234,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK; static constexpr index_t DefaultKBatch = 1; @@ -277,12 +288,13 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK + float RunImp(const Argument& arg, + const StreamConfig& stream_config = StreamConfig{}, + hipStream_t cpy_stream = nullptr, + hipEvent_t cpy_event = nullptr) { + using GemmTransKernelArg_ = GemmTransKernelArgBase; + static_assert(sizeof(GemmTransKernelArg_) == sizeof(GemmTransKernelArg)); + static_assert(sizeof(typename GridwiseGemm::Argument) == + sizeof(typename GridwiseGemm64::Argument)); + index_t K0 = arg.gemm_kernel_args_[0].karg_.K0Padded; bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.k_batch > 1; bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) { - const auto& karg = arg.gemm_kernel_args_[i].karg_; + const auto& karg = reinterpret_cast( + arg.gemm_kernel_args_[i].karg_); if(stream_config.log_level_ > 0) { karg.Print(); @@ -439,7 +458,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK; @@ -500,7 +519,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK; @@ -513,7 +532,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK; @@ -523,7 +542,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK; @@ -534,6 +553,28 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public 
DeviceGroupedGemmSplitK 0) + { + return RunImp(arg, stream_config, cpy_stream, cpy_event); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return RunImp(arg, stream_config, cpy_stream, cpy_event); + } + } + return 0; + } + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -550,11 +591,14 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK()) + { + return false; + } + if(is_gfx11_supported() && arg.K_BATCH > 1) { return false; } - if((ck::type_convert(arg.gemm_kernel_args_.size()) + arg.skipped_group_count_) != arg.group_count_) { @@ -573,11 +617,27 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK 0) + { + group_arg_valid = GridwiseGemm64::CheckValidity(a); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + group_arg_valid = GridwiseGemm32::CheckValidity( + reinterpret_cast(a)); + } + } - bool group_arg_valid = GridwiseGemm::CheckValidity(a); if(not group_arg_valid) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp index 27d3c378ac..748ae28a50 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp @@ -87,8 +87,12 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); - using GridwiseGemm = + template + using GridwiseGemmBase = GridwiseMoeGemm; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; static constexpr index_t APackedSize = []() { if constexpr(is_same_v, pk_i4_t>) @@ -167,7 +173,9 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -195,7 +203,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle DsSize; - Argument arg_ = arg; + auto arg_ = arg; const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); @@ -214,8 +222,13 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle>; DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); }); - ck::utility::RotatingMemWrapperMultiD rotating_mem( - arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + ck::utility::RotatingMemWrapperMultiD + rotating_mem(arg_, + stream_config.rotating_count, + size_a_buffer, + size_b_buffer, + DsSize); rotating_mem.Print(); auto run_flush_cache = [&]() { @@ -376,6 +389,8 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle()) { return false; } - if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) { return false; @@ -418,8 +432,22 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -553,7 +581,7 @@ 
struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); - using GridwiseGemm = GridwiseMoeGemmBlockScale< - ALayout, - BLayout, - DsLayout, - CLayout, - ADataType, - BDataType, - GemmAccDataType, - CShuffleDataType, - DsDataType, - CDataType, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - GemmSpec, - BlockSize, - ScaleBlockM, - ScaleBlockN, - ScaleBlockK, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEShuffleBlockTransferScalarPerVectors, - BlkGemmPipeSched, - BlkGemmPipelineVer, - ActivationOP, - NSwizzle, - IsInputGemm, - MulRoutedWeight, - IndexType, - ComputeTypeA, - ComputeTypeB, - LDSTypeA, - LDSTypeB>; + template + using GridwiseGemmBase = GridwiseMoeGemmBlockScale< + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockM, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave_, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_), + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ActivationOP, + NSwizzle, + IsInputGemm, + MulRoutedWeight, + IndexType, + ComputeTypeA, + ComputeTypeB, + LDSTypeA, + LDSTypeB>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; static constexpr index_t APackedSize = []() { if constexpr(is_same_v, pk_i4_t>) @@ -180,7 +186,9 @@ struct DeviceMoeGemmBlockScale // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -207,7 
+215,7 @@ struct DeviceMoeGemmBlockScale std::array DsSize; - Argument arg_ = arg; + auto arg_ = arg; const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); @@ -226,8 +234,13 @@ struct DeviceMoeGemmBlockScale using DDataType = remove_cvref_t>; DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); }); - ck::utility::RotatingMemWrapperMultiD rotating_mem( - arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + ck::utility::RotatingMemWrapperMultiD + rotating_mem(arg_, + stream_config.rotating_count, + size_a_buffer, + size_b_buffer, + DsSize); rotating_mem.Print(); auto run_flush_cache = [&]() { @@ -385,6 +398,8 @@ struct DeviceMoeGemmBlockScale return ave_time; } + INVOKER_RUN3_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -406,11 +421,10 @@ struct DeviceMoeGemmBlockScale { return false; } - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) { return false; @@ -428,7 +442,22 @@ struct DeviceMoeGemmBlockScale return false; } - return GridwiseGemm::CheckValidity(arg); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -572,7 +601,7 @@ struct DeviceMoeGemmBlockScale << "BlkGemmPipelineVersion: " << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " << "BlkGemmPipelinePrefetchStages: " - << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages; + << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages; // clang-format on return str.str(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp index e7be94242b..9c14106033 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp @@ -90,8 +90,12 @@ struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); - using GridwiseGemm = + template + using GridwiseGemmBase = GridwiseMoeGemmMX; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; - + using Argument = typename GridwiseGemm64::Argument; static constexpr index_t APackedSize = packed_size_v; static constexpr index_t BPackedSize = packed_size_v; @@ -159,7 +164,9 @@ struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -187,7 +194,7 @@ struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle DsSize; - Argument arg_ = arg; + auto arg_ = arg; const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); @@ -206,8 +213,13 @@ struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle>; DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * 
sizeof(DDataType); }); - ck::utility::RotatingMemWrapperMultiD rotating_mem( - arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + ck::utility::RotatingMemWrapperMultiD + rotating_mem(arg_, + stream_config.rotating_count, + size_a_buffer, + size_b_buffer, + DsSize); rotating_mem.Print(); auto run_flush_cache = [&]() { @@ -333,6 +345,8 @@ struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle()) { return false; } - if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) { return false; @@ -376,7 +389,22 @@ struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -530,7 +558,7 @@ struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); - using GridwiseGemm = + template + using GridwiseGemmBase = GridwiseMoeGemmMXBNS; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; static constexpr index_t APackedSize = packed_size_v; static constexpr index_t BPackedSize = packed_size_v; @@ -159,7 +165,9 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -187,7 +195,7 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle DsSize; - Argument arg_ = arg; + auto arg_ = arg; const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); @@ -206,8 +214,13 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle>; DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); }); - ck::utility::RotatingMemWrapperMultiD rotating_mem( - arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + ck::utility::RotatingMemWrapperMultiD + rotating_mem(arg_, + stream_config.rotating_count, + size_a_buffer, + size_b_buffer, + DsSize); rotating_mem.Print(); auto run_flush_cache = [&]() { @@ -331,6 +344,8 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle()) { return false; } - if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) { return false; @@ -374,7 +388,22 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -528,7 +557,7 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); - using GridwiseGemm = GridwiseMoeGemmMX_BPreshuffle< - ALayout, - BLayout, - DsLayout, - CLayout, - ADataType, - AScaleDataType, - BDataType, - BScaleDataType, - GemmAccDataType, - CShuffleDataType, - DsDataType, - CDataType, - AElementwiseOperation, - BElementwiseOperation, - 
CElementwiseOperation, - GemmSpec, - ScaleBlockSize, - BlockSize, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEShuffleBlockTransferScalarPerVectors, - BlkGemmPipeSched, - BlkGemmPipelineVer, - ActivationOP, - NSwizzle, - IsInputGemm, - MulRoutedWeight, - IndexType, - ComputeTypeA, - ComputeTypeB>; + template + using GridwiseGemmBase = GridwiseMoeGemmMX_BPreshuffle< + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + ScaleBlockSize, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave_, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ActivationOP, + NSwizzle, + IsInputGemm, + MulRoutedWeight, + IndexType, + ComputeTypeA, + ComputeTypeB>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; static constexpr index_t APackedSize = packed_size_v; static constexpr index_t BPackedSize = packed_size_v; @@ -159,7 +165,9 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -187,7 +195,7 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle DsSize; - Argument arg_ = arg; + auto arg_ = arg; const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); @@ -206,8 +214,13 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle>; DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); }); - ck::utility::RotatingMemWrapperMultiD rotating_mem( - arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + ck::utility::RotatingMemWrapperMultiD + rotating_mem(arg_, + stream_config.rotating_count, 
+ size_a_buffer, + size_b_buffer, + DsSize); rotating_mem.Print(); auto run_flush_cache = [&]() { @@ -358,6 +371,8 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle()) { return false; } - if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) { return false; @@ -401,7 +415,22 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -555,7 +584,7 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); - const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); - FloatDsPointer p_ds_grid_grp; + FloatDsPointer p_ds_grid_grp; - static constexpr index_t NumDTensor = - DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); - static_for<0, NumDTensor, 1>{}( - [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_ds_grid_grp, - p_e_grid + e_batch_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_akb_ak0_m_ak1, - b_grid_desc_bkb_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_akb_ak0_m_ak1, + b_grid_desc_bkb_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -190,7 +195,9 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle CDEElementwiseOperation> { using DeviceOp = 
DeviceSplitKContractionMultipleD_Xdl_CShuffle; - + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; @@ -521,7 +528,8 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle }; // GridwiseGemm - using GridwiseGemm = GridwiseGemmSplitKMultipleD_xdl_cshuffle< + template + using GridwiseGemmBase = GridwiseGemmSplitKMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype AccDataType, CShuffleDataType, @@ -545,7 +553,7 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -567,9 +575,12 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // GridwiseGemm - using GridwiseGemmAtomicAdd = GridwiseGemmSplitKMultipleD_xdl_cshuffle< + template + using GridwiseGemmAtomicAddBase = GridwiseGemmSplitKMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype AccDataType, CShuffleDataType, @@ -593,7 +604,7 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -615,19 +626,39 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched>; + using GridwiseGemmAtomicAdd64 = GridwiseGemmAtomicAddBase; + using GridwiseGemmAtomicAdd32 = GridwiseGemmAtomicAddBase; using AGridDesc_AKB_AK0_M_AK1 = - remove_cvref_t; using BGridDesc_BKB_BK0_N_BK1 = - remove_cvref_t; - using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; + using Block2ETileMap = typename GridwiseGemm64::DefaultBlock2ETileMap; // Argument struct Argument : public BaseArgument { + template + void init_ds_e_grid_desc() + { + if(GridwiseGemm::CheckValidity(a_grid_desc_akb_ak0_m_ak1_, + b_grid_desc_bkb_bk0_n_bk1_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + } Argument(const void* p_a_grid, const void* p_b_grid, std::array p_ds_grid, @@ -659,14 +690,14 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)}, e_grid_desc_g_m_n_{ DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)}, - a_grid_desc_akb_ak0_m_ak1_{GridwiseGemm::MakeDefaultAGridDescriptor_AKB_AK0_M_AK1( + a_grid_desc_akb_ak0_m_ak1_{GridwiseGemm64::MakeDefaultAGridDescriptor_AKB_AK0_M_AK1( a_grid_desc_m_k_, split_k)}, - b_grid_desc_bkb_bk0_n_bk1_{GridwiseGemm::MakeDefaultBGridDescriptor_BKB_BK0_N_BK1( + b_grid_desc_bkb_bk0_n_bk1_{GridwiseGemm64::MakeDefaultBGridDescriptor_BKB_BK0_N_BK1( b_grid_desc_n_k_, split_k)}, ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, 
e_grid_desc_mblock_mperblock_nblock_nperblock_{}, block_2_etile_map_{ - GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_, split_k)}, + GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_, split_k)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, cde_element_op_{cde_element_op}, @@ -697,19 +728,19 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle }); // populate desc for Ds/E - if(GridwiseGemm::CheckValidity(a_grid_desc_akb_ak0_m_ak1_, - b_grid_desc_bkb_bk0_n_bk1_, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_)) + if(get_warp_size() == 64) { - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); - - ds_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n_); + if constexpr(NXdlPerWave64 > 0) + { + init_ds_e_grid_desc(); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + init_ds_e_grid_desc(); + } } // for sanity check of vector memory access @@ -755,7 +786,7 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle // pointers const ADataType* p_a_grid_; const BDataType* p_b_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; + typename GridwiseGemm64::DsGridPointer p_ds_grid_; EDataType* p_e_grid_; // tensor descriptors for problem definiton @@ -770,9 +801,9 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle // tensor descriptors for block/thread-wise copy AGridDesc_AKB_AK0_M_AK1 a_grid_desc_akb_ak0_m_ak1_; BGridDesc_BKB_BK0_N_BK1 b_grid_desc_bkb_bk0_n_bk1_; - typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + typename GridwiseGemm64::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + typename GridwiseGemm64::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; // block-to-e-tile map @@ -806,7 +837,8 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_akb_ak0_m_ak1_, arg.b_grid_desc_bkb_bk0_n_bk1_, @@ -818,6 +850,11 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle "wrong! 
GridwiseGemmMultipleD_xdl_cshuffle has invalid setting");
             }
 
+            using GridwiseGemmAtomicAdd =
+                std::conditional_t<std::is_same_v<GridwiseGemm, GridwiseGemm64>,
+                                   GridwiseGemmAtomicAdd64,
+                                   GridwiseGemmAtomicAdd32>;
+
             const index_t G = arg.e_grid_desc_g_m_n_.GetLength(I0);
 
             const index_t grid_size =
@@ -931,6 +968,8 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
             }
         }
 
+        INVOKER_RUN_IMPL
+
         // polymorphic
         float Run(const BaseArgument* p_arg,
                   const StreamConfig& stream_config = StreamConfig{}) override
@@ -941,19 +980,35 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
 
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(!ck::is_xdl_supported())
+        if(!ck::is_xdl_wmma_supported())
        {
             return false;
         }
-
-        if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_akb_ak0_m_ak1_,
-                                        arg.b_grid_desc_bkb_bk0_n_bk1_,
-                                        arg.ds_grid_desc_m_n_,
-                                        arg.e_grid_desc_m_n_,
-                                        arg.block_2_etile_map_))
+        bool valid = false;
+        if(get_warp_size() == 64)
         {
-            return false;
+            if constexpr(NXdlPerWave64 > 0)
+            {
+                valid = GridwiseGemm64::CheckValidity(arg.a_grid_desc_akb_ak0_m_ak1_,
+                                                      arg.b_grid_desc_bkb_bk0_n_bk1_,
+                                                      arg.ds_grid_desc_m_n_,
+                                                      arg.e_grid_desc_m_n_,
+                                                      arg.block_2_etile_map_);
+            }
         }
+        else
+        {
+            if constexpr(NXdlPerWave32 > 0)
+            {
+                valid = GridwiseGemm32::CheckValidity(arg.a_grid_desc_akb_ak0_m_ak1_,
+                                                      arg.b_grid_desc_bkb_bk0_n_bk1_,
+                                                      arg.ds_grid_desc_m_n_,
+                                                      arg.e_grid_desc_m_n_,
+                                                      arg.block_2_etile_map_);
+            }
+        }
+        if(!valid)
+            return false;
 
         // check vector access
         static_assert((ABlockTransferSrcVectorDim == 2 || ABlockTransferSrcVectorDim == 3) &&
diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
index 7eca68bbf8..b6bc634d74 100644
--- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
+++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
@@ -169,7 +169,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt
     }
 
     template <typename CGridDesc_M_N>
-    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
+    __host__ __device__ constexpr bool
+    CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
     {
         return true;
     }
diff --git a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
index 36dc8aa6ba..7a09b84a63 100644
--- a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
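// ---------------------------------------------------------------------------
// Editorial note, not part of the original patch: the grid headers in this
// part of the series adopt the new compile-time guard, either through
// IS_VALID_COMPILATION_PARAMETER_IMPL(CShuffleDataType) as in this welford
// header, or through a hand-written IsValidCompilationParameter() as in the
// batched-gemm headers further below. The macro itself is defined in
// device_base.hpp, which this excerpt does not include; a plausible minimal
// expansion, inferred purely from the call sites in this patch (the real
// body may differ), is:

#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)                        \
    template <InMemoryDataOperationEnum CGlobalMemoryDataOperation_ =          \
                  InMemoryDataOperationEnum::Set>                              \
    __device__ static constexpr bool IsValidCompilationParameter()             \
    {                                                                          \
        return ck::tensor_operation::device::IsValidGemmCompilationParameter<  \
            BlockSize,                                                         \
            MPerBlock,                                                         \
            NPerBlock,                                                         \
            MPerXdl,                                                           \
            NPerXdl,                                                           \
            MXdlPerWave,                                                       \
            NXdlPerWave,                                                       \
            CDataType_,                                                        \
            CGlobalMemoryDataOperation_>();                                    \
    }

// Each gridwise struct supplies its own tile constants, so a kernel body
// wrapped in if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
// collapses to an empty function whenever the tile shape cannot exist on the
// current wave size.
// ---------------------------------------------------------------------------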
#pragma once @@ -267,6 +267,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle e_grid_desc_m_n); } + IS_VALID_COMPILATION_PARAMETER_IMPL(CShuffleDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp index 70c641531b..a15f11a93f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp @@ -195,6 +195,28 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle return math::max(gemm0_bytes_end, gemm1_bytes_end, c_block_bytes_end); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + constexpr bool valid = ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + if constexpr(!valid) + { + return false; + } + + return true; + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp index 84d7b04495..b8f5a545aa 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -255,18 +255,63 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle return math::max(gemm0_bytes_end, gemm1_bytes_end, c1_block_bytes_end); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + Gemm0MPerBlock, + Gemm0NPerBlock, + Gemm0MPerXdl, + Gemm0NPerXdl, + Gemm0MXdlPerWave, + Gemm0NXdlPerWave, + E1DataType, + CGlobalMemoryDataOperation_>() && + ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + Gemm0MPerBlock, + Gemm1NPerBlock, + Gemm0MPerXdl, + Gemm0NPerXdl, + Gemm0MXdlPerWave, + Gemm1NXdlPerWave, + E1DataType, + CGlobalMemoryDataOperation_>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template - __host__ __device__ static constexpr bool - CheckValidity(const A0GridDesc_M_K& a0_grid_desc_m_k, - const B0GridDesc_N_K& b0_grid_desc_n_k, - const B1GridDesc_N_K& b1_grid_desc_n_k, - const E1GridDesc_M_N& e1_grid_desc_m_n, - const Block2E1TileMap& block_2_e1tile_map) + __host__ static constexpr bool CheckValidity(const A0GridDesc_M_K& a0_grid_desc_m_k, + const B0GridDesc_N_K& b0_grid_desc_n_k, + const B1GridDesc_N_K& b1_grid_desc_n_k, + const E1GridDesc_M_N& e1_grid_desc_m_n, + const Block2E1TileMap& block_2_e1tile_map) { static_assert((Gemm0MPerBlock % (Gemm0MPerXdl * Gemm0MXdlPerWave) == 0) && (Gemm0NPerBlock % (Gemm0NXdlPerWave * Gemm0NPerXdl)) == 0, "Invalid tuning param!"); 
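// ---------------------------------------------------------------------------
// Editorial sketch, not part of the original patch: the hunk continues below
// by wrapping the divisibility checks in if constexpr before anything is
// divided, presumably because a wave32 re-instantiation can drive an
// *XdlPerWave factor to zero, and a modulo with divisor 0 is ill-formed in a
// constant expression, so the zero test must short-circuit it. Names here
// are illustrative; the snippet compiles stand-alone as C++17.

template <int MPerBlock, int MPerXdl, int MXdlPerWave>
constexpr bool TileDividesSketch()
{
    if constexpr((MPerXdl * MXdlPerWave) == 0)
    {
        return false; // degenerate wave32 instantiation: never valid
    }
    else
    {
        return MPerBlock % (MPerXdl * MXdlPerWave) == 0; // divisor > 0 here
    }
}

static_assert(TileDividesSketch<128, 32, 2>(), "128 tiles into 32 x 2");
static_assert(!TileDividesSketch<128, 32, 0>(), "0 rejected without UB");

// The same hunk also rejects mismatched hardware at run time via the
// WaveSize != get_warp_size() check visible just after the guards.
// ---------------------------------------------------------------------------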
+ if constexpr((Gemm0MPerXdl * Gemm0MXdlPerWave) == 0 || + (Gemm0NXdlPerWave * Gemm0NPerXdl) == 0) + { + return false; + } + else + { + if constexpr((Gemm0MPerBlock % (Gemm0MPerXdl * Gemm0MXdlPerWave) != 0) || + (Gemm0NPerBlock % (Gemm0NXdlPerWave * Gemm0NPerXdl) != 0)) + { + return false; + } + else + { + if(WaveSize != get_warp_size()) + { + return false; + } + } + } const auto M = a0_grid_desc_m_k.GetLength(I0); const auto N = b0_grid_desc_n_k.GetLength(I0); @@ -527,8 +572,7 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle const CDE1ElementwiseOperation& cde1_element_op, const A0GridDesc_AK0_M_AK1& a0_grid_desc_ak0_m_ak1, const B0GridDesc_BK0_N_BK1& b0_grid_desc_bk0_n_bk1, - const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5& - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, + const D0sGridDesc_M_N& d0s_griddesc_m_n, const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& d1s_grid_desc_mblock_mperblock_nblock_nperblock, @@ -536,6 +580,8 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle e1_grid_desc_mblock_mperblock_nblock_nperblock, const Block2E1TileMap& block_2_e1tile_map) { + const auto d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = + MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(d0s_griddesc_m_n); const auto a0_grid_buf = make_dynamic_buffer( p_a0_grid, a0_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b0_grid_buf = make_dynamic_buffer( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp index 222cb3894c..0e8d003071 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp @@ -210,6 +210,22 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool @@ -445,11 +461,12 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock, - const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5& - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, + const D0sGridDesc_M_N& d0s_griddesc_m_n, const Block2CTileMap& block_2_ctile_map, const C0MatrixMask& c0_matrix_mask) { + const auto d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = + MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(d0s_griddesc_m_n); const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( @@ -1066,6 +1083,7 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle // main body if constexpr(num_gemm1_k_block_inner_loop > 1) { + static_for<0, num_gemm1_k_block_inner_loop - 
1, 1>{}([&](auto i) { a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1, make_tuple(Number{}, I0, I0), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index 2d00daf7f6..e0cf12e429 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -209,6 +209,22 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp index 96b737385a..0a4691b509 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -66,29 +66,34 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, const Block2CTileMap block_2_ctile_map) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_bias_grid, - p_d0_grid, - p_reduces_grid, - p_shared, - a_element_op, - b_element_op, - c_element_op, - c1_element_op, - reduce_in_element_ops, - reduce_out_element_ops, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c0_grid_desc_mblock_mperblock_nblock_nperblock, - c1_grid_desc_mblock_mperblock_nblock_nperblock, - reduce_grid_desc_mblock_mperblock, - block_2_ctile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_bias_grid, + p_d0_grid, + p_reduces_grid, + p_shared, + a_element_op, + b_element_op, + c_element_op, + c1_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c1_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, + block_2_ctile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -252,6 +257,22 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 c_block_size * sizeof(FloatCShuffle)); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index c37ffb6263..c198711dbb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -298,6 +298,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle e_grid_desc_m_n); } + IS_VALID_COMPILATION_PARAMETER_IMPL(EDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template {}, Sequence<1>{})); } + IS_VALID_COMPILATION_PARAMETER_IMPL(FloatE) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index 318ff59383..59d7f357ec 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -322,6 
+322,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle return true; } + IS_VALID_COMPILATION_PARAMETER_IMPL(EDataType) + template ()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -424,6 +428,8 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad using Block2ETileMap = remove_cvref_t; + IS_VALID_COMPILATION_PARAMETER_IMPL(EDataType) + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k, const BGridDesc_N_K& b_grid_desc_n_k, const DsGridDesc_M_N& ds_grid_desc_m_n, @@ -433,6 +439,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); + static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0, "KPerBlock must be divisible by AK1Value and BK1Value!"); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp index 85b5b5faab..872e1271e1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -351,6 +351,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle e_grid_desc_m_n); } + IS_VALID_COMPILATION_PARAMETER_IMPL(EDataType) + template ()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_reduces_grid, - p_shared, - a_element_op, - b_element_op, - c_element_op, - reduce_in_element_ops, - reduce_out_element_ops, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - reduce_grid_desc_mblock_mperblock, - block_2_ctile_map); + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_reduces_grid, + p_shared, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, + block_2_ctile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -225,6 +229,22 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 c_block_size * sizeof(FloatCShuffle)); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp index b4848c7077..8f7aac0171 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -283,6 +283,8 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle e_grid_desc_m_n, 8, split_k); } + IS_VALID_COMPILATION_PARAMETER_IMPL(EDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template ()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg, karg.p_workspace_); + GridwiseGemm::template Run( + karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg, karg.p_workspace_); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -58,20 +61,23 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // Passing two lds pointers is the key to tell the compiler that ds_read/write + // operate on different lds chunks at the same time without order dependency + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run_2Lds( - karg.p_a_grid, - karg.p_b_grid, - karg.p_c_grid, - p_shared_0, - p_shared_1, - karg, - karg.p_workspace_); + GridwiseGemm::template Run_2Lds( + karg.p_a_grid, + karg.p_b_grid, + karg.p_c_grid, + p_shared_0, + p_shared_1, + karg, + karg.p_workspace_); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -1012,6 +1018,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 c_block_size * sizeof(CShuffleDataType)); } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index 906bfe0912..4cd1a587e9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -24,11 +24,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg); + GridwiseGemm::template Run( + karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg); + } #else ignore = karg; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) @@ -48,10 +52,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) FloatC* __restrict__ p_c_grid, typename GridwiseGemm::Problem problem) { -#if(defined(__gfx908__) ||
defined(__gfx90a__) || defined(__gfx94__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, p_b_grid, p_c_grid, p_shared, problem); + GridwiseGemm::template Run( + p_a_grid, p_b_grid, p_c_grid, p_shared, problem); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -542,6 +551,22 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 c_block_size * sizeof(FloatCShuffle)); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Problem& problem) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp index 57624b218c..ccba4d4a94 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -25,14 +25,18 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v2(typename GridwiseGemm::Argument karg) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + // Passing two lds pointers is the key to tell the compiler that ds_read/write + // operate on different lds chunks at the same time without order dependency + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg); + GridwiseGemm::template Run( + karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg); + } #else ignore = karg; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) @@ -52,12 +56,16 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) FloatC* p_c_grid, typename GridwiseGemm::Problem problem) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - p_a_grid, p_b_grid, p_c_grid, p_shared_0, p_shared_1, problem); + GridwiseGemm::template Run( + p_a_grid, p_b_grid, p_c_grid, p_shared_0, p_shared_1, problem); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -581,6 +589,22 @@ struct GridwiseGemm_xdl_cshuffle_v2 c_block_size * sizeof(FloatCShuffle)); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Problem& problem) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index 5545192e3c..a6e4870ac7 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -1187,74 +1187,23 @@ struct GridwiseGemm_xdl_cshuffle_v3 return false; } - // Check tile size -#if defined(__gfx11__) || defined(__gfx12__) - if constexpr(MPerXdl != 16 || NPerXdl != 16) - { - return false; - } -#endif - // Check atomic caps -#if defined(__gfx11__) - constexpr bool
SupportMemOp = CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set; -#else - constexpr bool SupportMemOp = sizeof(CDataType) >= 2 || (CGlobalMemoryDataOperation == - InMemoryDataOperationEnum::Set); -#endif - if constexpr(SupportMemOp == false) - { - return false; - } - - // Check tile size - if constexpr(MXdlPerWave > 0 && NXdlPerWave > 0) - { - constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); - if constexpr(MWaves > 0 && NWaves > 0) - { - constexpr index_t WaveSize = BlockSize / (MWaves * NWaves); - if constexpr(WaveSize == get_warp_size()) - { - return true; - } - else - { - return false; - } - } - else - { - return false; - } - } - else - { - return false; - } + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + CDataType, + CGlobalMemoryDataOperation>(); } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { - if constexpr((MPerXdl * MXdlPerWave) == 0 || (NXdlPerWave * NPerXdl) == 0) - { - return false; - } - else - { - if constexpr((MPerBlock % (MPerXdl * MXdlPerWave) != 0) || - (NPerBlock % (NXdlPerWave * NPerXdl) != 0)) - { - return false; - } - else - { - if(BlockwiseGemmPipe::WaveSize != get_warp_size()) - { - return false; - } - } - } + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index b99113ef16..78546c4f99 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -35,17 +35,20 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_b_preshuffle(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); - GridwiseGemm::template Run( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_c_grid + splitk_batch_offset.c_reduce_offset, - p_shared, - karg); + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared, + karg); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -63,21 +66,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - // Pass two lds 
pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // Passing two lds pointers is the key to tell the compiler that ds_read/write + // operate on different lds chunks at the same time without order dependency + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); - GridwiseGemm::template Run_2Lds( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_c_grid + splitk_batch_offset.c_reduce_offset, - p_shared_0, - p_shared_1, - karg); + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared_0, + p_shared_1, + karg); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -885,6 +891,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle c_block_size * sizeof(CShuffleDataType)); } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) {
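The IsValidGemmCompilationParameter helper that all of these IsValidCompilationParameter definitions delegate to is declared in device_base.hpp and is not part of this excerpt. A minimal sketch of its likely shape, inferred from the inline checks this patch removes from gridwise_gemm_xdl_cshuffle_v3.hpp above, follows; the shipped helper may differ in detail:

namespace ck { namespace tensor_operation { namespace device {

// Sketch only: per-target compile-time legality checks for an XDL GEMM tile.
template <index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t MPerXdl,
          index_t NPerXdl,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename CDataType,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__device__ constexpr bool IsValidGemmCompilationParameter()
{
#if defined(__gfx11__) || defined(__gfx12__)
    // RDNA3/4 WMMA only provides 16x16 tiles.
    if constexpr(MPerXdl != 16 || NPerXdl != 16)
    {
        return false;
    }
#endif
#if defined(__gfx11__)
    // gfx11 lacks the global atomics needed for non-Set C updates.
    constexpr bool support_mem_op =
        CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set;
#else
    constexpr bool support_mem_op =
        sizeof(CDataType) >= 2 ||
        CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set;
#endif
    if constexpr(!support_mem_op)
    {
        return false;
    }
    // The block tile must decompose into whole waves of the target's warp
    // size (64 on gfx9, 32 on gfx11/gfx12).
    if constexpr(MXdlPerWave <= 0 || NXdlPerWave <= 0)
    {
        return false;
    }
    else
    {
        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
        if constexpr(MWaves <= 0 || NWaves <= 0)
        {
            return false;
        }
        else
        {
            return BlockSize / (MWaves * NWaves) == get_warp_size();
        }
    }
}

}}} // namespace ck::tensor_operation::device

Centralizing these checks is what lets the same kernel template compile cleanly for wave32 and wave64 targets in one binary.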
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp index a9c7556130..36141bc96f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp @@ -35,19 +35,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - GridwiseGemm::template Run( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_c_grid + splitk_batch_offset.c_reduce_offset, - karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, - p_shared, - karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, + p_shared, + karg); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -65,23 +67,25 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) //
kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // Passing two lds pointers is the key to tell the compiler that ds_read/write + // operate on different lds chunks at the same time without order dependency + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - GridwiseGemm::template Run_2Lds( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_c_grid + splitk_batch_offset.c_reduce_offset, - karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, - p_shared_0, - p_shared_1, - karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, + p_shared_0, + p_shared_1, + karg); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -1049,6 +1053,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 c_block_size * sizeof(CShuffleDataType)); } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp index 8ca8b7a2b9..35c8c6c3b4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp @@ -34,19 +34,22 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - karg.p_as_grid, - karg.p_bs_grid, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run( + karg.p_as_grid, + karg.p_bs_grid, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -64,23 +67,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - // Pass two lds
pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // Passing two lds pointers is the key to tell the compiler that ds_read/write + // operate on different lds chunks at the same time without order dependency + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run_2Lds( - karg.p_as_grid, - karg.p_bs_grid, - karg.p_ds_grid, - karg.p_c_grid, - p_shared_0, - p_shared_1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run_2Lds( + karg.p_as_grid, + karg.p_bs_grid, + karg.p_ds_grid, + karg.p_c_grid, + p_shared_0, + p_shared_1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -1112,6 +1118,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 c_block_size * sizeof(CShuffleDataType)); } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) {
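All of these kernel wrappers follow one pattern, spelled out here once: a translation unit may be compiled for gfx9 (wave64) and gfx11/gfx12 (wave32) at the same time, so a tuning instance that is legal on one target can be illegal on another. Guarding the kernel body with if constexpr on IsValidCompilationParameter() turns such instances into empty kernels on the offending target instead of hard compile errors; the host side is expected to filter them out via CheckValidity/IsSupportedArgument before launching. An illustrative skeleton follows — the kernel name and the Run signature are simplified placeholders, not the library's actual API:

// Illustrative only: the real kernels forward many more arguments to Run.
template <typename GridwiseGemm>
__global__ void kernel_xdl_guard_skeleton(typename GridwiseGemm::Argument karg)
{
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
    {
        // Only instantiated when the tuning parameters are legal for the
        // target currently being compiled.
        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
        GridwiseGemm::Run(karg, p_shared); // placeholder signature
    }
#else
    ignore = karg; // ck::ignore, silences unused-parameter warnings
#endif
}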
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index 676da3e925..5b19ff8542 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -38,21 +38,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -70,25 +73,28 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d_2lds(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char
p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // Passing two lds pointers is the key to tell the compiler that ds_read/write + // operate on different lds chunks at the same time without order dependency + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run_2Lds( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared_0, - p_shared_1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared_0, + p_shared_1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -1036,6 +1042,46 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 c_block_size * sizeof(CShuffleDataType)); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + constexpr bool valid = ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + CDataType, + CGlobalMemoryDataOperation_>(); + if constexpr(!valid) + { + return false; + } + + using MfmaInst = MfmaSelector; + + constexpr index_t KPerThread = + KPerBlock / (MfmaInst::GetKPerXdlops() / MfmaInst::GetK1PerXdlops()); + if constexpr(KPerThread % KPack != 0) + { + return false; + } + + if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0) + { + return false; + } + return true; + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { @@ -1043,6 +1089,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); + if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0) + { + return false; + } + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp index be3c6ebb35..8119cace3b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp @@ -38,21 +38,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char
p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - karg.p_a_grid, - karg.p_b_grid, - karg.p_ds_grid, - karg.p_c_grid, - karg.p_a_scale_grid, - karg.p_b_scale_grid, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run( + karg.p_a_grid, + karg.p_b_grid, + karg.p_ds_grid, + karg.p_c_grid, + karg.p_a_scale_grid, + karg.p_b_scale_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) @@ -443,7 +446,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 __host__ __device__ static constexpr auto MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) { - constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t NWaves = + NXdlPerWave * NPerXdl == 0 ? 1 : NPerBlock / (NXdlPerWave * NPerXdl); return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); } @@ -984,6 +988,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 c_block_size * sizeof(CShuffleDataType)); } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index dfcc20b3c2..2e95ec0d52 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -38,21 +38,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -70,23 +73,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char 
p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run_2Lds( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - p_shared1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -895,6 +901,46 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle c_block_size * sizeof(CShuffleDataType)); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + constexpr bool valid = ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + CDataType, + CGlobalMemoryDataOperation_>(); + if constexpr(!valid) + { + return false; + } + + using MfmaInst = MfmaSelector; + + constexpr index_t KPerThread = + KPerBlock / (MfmaInst::GetKPerXdlops() / MfmaInst::GetK1PerXdlops()); + if constexpr(KPerThread % KPack != 0) + { + return false; + } + + if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0) + { + return false; + } + return true; + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { @@ -902,6 +948,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); + if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0) + { + return false; + } + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp index d832bef2da..bf7ae1c6e8 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp @@ -39,21 +39,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle( typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if 
constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - karg.p_a_grid, - karg.p_b_grid, - karg.p_ds_grid, - karg.p_c_grid, - karg.p_a_scale_grid, - karg.p_b_scale_grid, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run( + karg.p_a_grid, + karg.p_b_grid, + karg.p_ds_grid, + karg.p_c_grid, + karg.p_a_scale_grid, + karg.p_b_scale_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) @@ -72,23 +75,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle_2lds( typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run_2Lds( - karg.p_a_grid, - karg.p_b_grid, - karg.p_ds_grid, - karg.p_c_grid, - karg.p_a_scale_grid, - karg.p_b_scale_grid, - p_shared, - p_shared1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run_2Lds( + karg.p_a_grid, + karg.p_b_grid, + karg.p_ds_grid, + karg.p_c_grid, + karg.p_a_scale_grid, + karg.p_b_scale_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) @@ -860,6 +866,8 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle c_block_size * sizeof(CShuffleDataType)); } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp index cb9c354701..1a356c372d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp @@ -39,20 +39,22 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) { -#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx950__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - GridwiseGemm::template Run( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, - karg.p_c_grid + splitk_batch_offset.c_reduce_offset, - p_shared, - karg); + auto 
splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared, + karg); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -71,24 +73,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) { -#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx950__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // Passing two lds pointers is the key to tell the compiler that ds_read/write + // operate on different lds chunks at the same time without order dependency + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - GridwiseGemm::template Run_2Lds( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, - karg.p_c_grid + splitk_batch_offset.c_reduce_offset, - p_shared_0, - p_shared_1, - karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared_0, + p_shared_1, + karg); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -1113,6 +1117,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 c_block_size * sizeof(CShuffleDataType)); } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp index 3ac9845b66..3d2ef9b6c4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp @@ -39,20 +39,22 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) { -#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx950__) + if constexpr(GridwiseGemm::template
IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - GridwiseGemm::template Run( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, - karg.p_c_grid + splitk_batch_offset.c_reduce_offset, - p_shared, - karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared, + karg); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -71,24 +73,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) { -#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx950__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // Passing two lds pointers is the key to tell the compiler that ds_read/write + // operate on different lds chunks at the same time without order dependency + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - GridwiseGemm::template Run_2Lds( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, - karg.p_c_grid + splitk_batch_offset.c_reduce_offset, - p_shared_0, - p_shared_1, - karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared_0, + p_shared_1, + karg); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -1027,6 +1031,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle c_block_size * sizeof(CShuffleDataType)); } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp index e90239b70a..e4152e0427 100644 ---
a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -57,28 +57,31 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock, const Block2CTileMap block_2_ctile_map) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - // TODO ANT: separate into MMA + Epilogue - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_c0_bias_grid, - p_c0_add_grid, - p_c0_gamma_grid, - p_c0_beta_grid, - p_shared, - a_element_op, - b_element_op, - acc_element_op, - c_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c0_grid_desc_nblock_nperblock, - block_2_ctile_map); +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + // TODO ANT: separate into MMA + Epilogue + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_c0_bias_grid, + p_c0_add_grid, + p_c0_gamma_grid, + p_c0_beta_grid, + p_shared, + a_element_op, + b_element_op, + acc_element_op, + c_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_desc_nblock_nperblock, + block_2_ctile_map); + } // TODO ANT: Run layernorm epilogue here #else ignore = p_a_grid; @@ -243,6 +246,22 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 c_lds_workspace_size * sizeof(FloatReduceAcc)); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp index 50363d832e..67f18de12f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
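The Run_2Lds hunks above all share one idea: the kernel hands two distinct __shared__ buffers to the gridwise GEMM so the loader can fill the next K tile while the MFMA stage consumes the current one, and the compiler never has to order ds_read against ds_write across the two chunks. A minimal, self-contained sketch of that ping-pong pattern (the tile size 256 and the per-thread reduction are illustrative, not CK's interface):

    #include <hip/hip_runtime.h>

    // Double-buffered ("ping-pong") LDS pipeline in miniature: two distinct
    // __shared__ objects, one being read while the other is being filled.
    __global__ void double_buffer_sketch(const float* __restrict__ in,
                                         float* __restrict__ out,
                                         int num_tiles)
    {
        __shared__ float lds0[256];
        __shared__ float lds1[256];

        // Prefetch tile 0 into lds0 before the pipeline starts.
        lds0[threadIdx.x] = in[threadIdx.x];
        __syncthreads();

        float acc = 0.0f;
        for(int t = 0; t < num_tiles; ++t)
        {
            float* compute_buf  = (t % 2 == 0) ? lds0 : lds1;
            float* prefetch_buf = (t % 2 == 0) ? lds1 : lds0;

            // Fill the *other* buffer while this one is consumed; distinct
            // buffers carry no ordering dependency between the two accesses.
            if(t + 1 < num_tiles)
                prefetch_buf[threadIdx.x] = in[(t + 1) * 256 + threadIdx.x];

            acc += compute_buf[threadIdx.x];

            // Make the prefetched tile visible before the buffers swap roles.
            __syncthreads();
        }
        out[threadIdx.x] = acc;
    }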
#pragma once @@ -77,6 +77,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle static constexpr auto BK1 = Number{}; static constexpr auto AK0PerBlock = Number{}; static constexpr auto BK0PerBlock = Number{}; + static constexpr auto BlockSize = math::max(TileLoadThreadGroupSize, TileMathThreadGroupSize); struct TileLoadThreadGroup { @@ -171,6 +172,22 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle c_block_size * sizeof(EDataTypeShuffle)); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + EDataType, + CGlobalMemoryDataOperation>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 344c7d6528..abb8c52e0f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -166,20 +166,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CElementwiseOperation c_element_op, const CBlockClusterAdaptor c_block_cluster_adaptor) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_grid_desc_mblock_mperblock_nblock_nperblock, - a_element_op, - b_element_op, - c_element_op, - c_block_cluster_adaptor); + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + c_block_cluster_adaptor); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -209,8 +213,8 @@ template + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MRepeat, + NRepeat, + FloatC, + CGlobalMemoryDataOperation>(); + } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool @@ -530,8 +549,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight static_assert(is_known_at_compile_time>::value, "wrong! 
K1 need to be known at compile-time"); - static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && - (NPerBlock % (NRepeat * NPerXDL)) == 0, + static_assert((MPerBlock % (MPerXdl * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerXdl)) == 0, "Invalid tuning param!"); const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2); @@ -604,14 +623,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight __host__ __device__ static constexpr auto GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl); return make_naive_tensor_descriptor_packed( make_tuple(I1, - Number{}, + Number{}, I1, - Number{})); + Number{})); } using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = @@ -751,8 +770,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(K1, MfmaSelector::selected_mfma.k_per_blk); @@ -764,8 +783,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight FloatAcc, decltype(a_k0_m_k1_block_desc), decltype(b_k0_n_k1_block_desc), - MPerXDL, - NPerXDL, + MPerXdl, + NPerXdl, MRepeat, NRepeat, KPack>{}; @@ -807,8 +826,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight // output: register to global memory { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl); constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); @@ -834,8 +853,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight static_assert(M1 == MWave, ""); static_assert(N1 == NWave, ""); - static_assert(M2 * M3 * M4 == MPerXDL, ""); - static_assert(N2 == NPerXDL, ""); + static_assert(M2 * M3 * M4 == MPerXdl, ""); + static_assert(N2 == NPerXdl, ""); constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( c_block_desc_mblock_mperblock_nblock_nperblock, @@ -845,11 +864,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight M1, M2, M3, - M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl make_freeze_transform(I0), // freeze nblock make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, N1, - N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL + N2))), // N1 = NWave, N2 = NPerXdl make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple( Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); @@ -920,9 +939,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight CElementwiseOperation, // ElementwiseOperation, CGlobalMemoryDataOperation, // DstInMemOp, Sequence<1, - CShuffleMRepeatPerShuffle * MWave * MPerXDL, + CShuffleMRepeatPerShuffle * MWave * MPerXdl, 1, - CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, FloatC, // typename SrcData, @@ -941,11 +960,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight c_element_op}; constexpr auto mxdlperwave_forward_step = - make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0); + 
make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXdl, 0, 0); constexpr auto nxdlperwave_forward_step = - make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXdl); constexpr auto nxdlperwave_backward_step = - make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXdl); static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { constexpr auto mxdlperwave = mxdlperwave_iter; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp index 0dbdac85bf..9e524c5a23 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp @@ -44,20 +44,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -86,8 +90,8 @@ template static constexpr auto K1 = Number{}; - static constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXDL); - static constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXDL); + static constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + static constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); static constexpr index_t WaveSize = BlockSize / (MWaves * NWaves); - static constexpr auto xdlops_gemm = XdlopsGemm{}; + static constexpr auto xdlops_gemm = XdlopsGemm{}; static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops; using ThisThreadBlock = ThisThreadBlock; @@ -164,6 +168,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 return (a_block_space_size_aligned) * sizeof(FloatAB); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, @@ -175,8 +194,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 static_assert(is_known_at_compile_time>::value, "wrong! 
K1 need to be known at compile-time"); - static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) && - (NPerBlock % (NXdlPerWave * NPerXDL)) == 0, + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); const auto M = a_grid_desc_k0_m_k1.GetLength(I1); @@ -242,7 +261,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 make_tuple(make_unmerge_transform( make_tuple(K0 / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread)), make_unmerge_transform(make_tuple( - N / (NXdlPerWave * NWaves * NPerXDL), NXdlPerWave, NWaves, NPerXDL)), + N / (NXdlPerWave * NWaves * NPerXdl), NXdlPerWave, NWaves, NPerXdl)), make_pass_through_transform(K1)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}, Sequence<7>{})); @@ -264,7 +283,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 __device__ static auto GetWaveKNIdx(const index_t thread_id) { constexpr auto wave_threadid_to_nk_idx_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(xdlops_gemm.K0PerXdlops, NPerXDL))), + make_tuple(make_merge_transform(make_tuple(xdlops_gemm.K0PerXdlops, NPerXdl))), make_tuple(Sequence<0, 1>{}), make_tuple(Sequence<0>{})); @@ -311,8 +330,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 MPerBlock, NPerBlock, K0PerBlock, - MPerXDL, - NPerXDL, + MPerXdl, + NPerXdl, MXdlPerWave, NXdlPerWave, K1>; @@ -510,8 +529,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 MPerBlock, NPerBlock, K0PerBlock, - MPerXDL, - NPerXDL, + MPerXdl, + NPerXdl, MXdlPerWave, NXdlPerWave, K1>{}; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp index a13ce732e6..c5f60a7413 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
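Every kernel touched by this patch now follows the same two-level guard: the preprocessor condition selects the GPU families the kernel may compile for, and an if constexpr on IsValidCompilationParameter() discards the body when the tuning parameters cannot be realized on the current target's wave size, so no dead MFMA code is emitted and no static_assert fires for rejected tunings. A runnable host-side reduction of the pattern (the stub predicate below stands in for ck::tensor_operation::device::IsValidGemmCompilationParameter, whose real checks are broader):

    #include <cstdio>

    // Stand-in gridwise GEMM with a compile-time validity predicate; the
    // condition here is only a sketch of the real wave-size checks.
    template <int BlockSize, int MPerXdl, int NPerXdl>
    struct GridwiseStub
    {
        template <int WaveSize = 64>
        static constexpr bool IsValidCompilationParameter()
        {
            return (BlockSize % WaveSize == 0) && (MPerXdl == NPerXdl);
        }

        static void Run() { std::printf("valid instantiation runs\n"); }
    };

    template <typename GridwiseGemm>
    void kernel_stub()
    {
        // Mirrors the kernels above: the body is discarded at compile time
        // for invalid tile/wave combinations instead of failing to build.
        if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
        {
            GridwiseGemm::Run();
        }
    }

    int main()
    {
        kernel_stub<GridwiseStub<256, 32, 32>>(); // valid: body instantiated
        kernel_stub<GridwiseStub<100, 32, 16>>(); // invalid: body discarded
    }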
#pragma once @@ -39,12 +39,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CElementwiseOperation c_element_op) { #if defined(__gfx9__) - constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); - __shared__ uint8_t p_shared[shared_size]; + __shared__ uint8_t p_shared[shared_size]; - GridwiseGemm::template Run( - karg, static_cast(p_shared), b2c_map, a_element_op, b_element_op, c_element_op); + GridwiseGemm::template Run( + karg, static_cast(p_shared), b2c_map, a_element_op, b_element_op, c_element_op); + } #else ignore = karg; ignore = b2c_map; @@ -70,8 +73,8 @@ template {}, + Number{}, I1, - Number{})); + Number{})); } // return block_id to C matrix tile idx (m0, n0, k_split) mapping @@ -711,8 +718,8 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load FloatAcc, decltype(a_k0_m_k1_block_desc), decltype(b_k0_n_k1_block_desc), - MPerXDL, - NPerXDL, + MPerXdl, + NPerXdl, MRepeat, NRepeat, K1, @@ -766,8 +773,8 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load // output: register to global memory { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl); constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); @@ -799,11 +806,11 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load M1, M2, M3, - M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl make_freeze_transform(I0), // freeze nblock make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, N1, - N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL + N2))), // N1 = NWave, N2 = NPerXdl make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple( Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); @@ -874,9 +881,9 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load CElementwiseOperation, // ElementwiseOperation, CGlobalMemoryDataOperation, // DstInMemOp, Sequence<1, - CShuffleMRepeatPerShuffle * MWave * MPerXDL, + CShuffleMRepeatPerShuffle * MWave * MPerXdl, 1, - CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, FloatC, // typename SrcData, @@ -895,11 +902,11 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load c_element_op}; constexpr auto mxdlperwave_forward_step = - make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0); + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXdl, 0, 0); constexpr auto nxdlperwave_forward_step = - make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXdl); constexpr auto nxdlperwave_backward_step = - make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXdl); static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { constexpr auto mxdlperwave = mxdlperwave_iter; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp 
b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp index 6aa61fcd38..a040409a6d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -37,23 +37,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) index_t StrideC, typename GridwiseGemm::Block2CTileMap block_mapping) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); - __shared__ uint8_t p_shared[shared_size]; + __shared__ uint8_t p_shared[shared_size]; - GridwiseGemm::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_workspace, - M, - N, - K, - StrideA, - StrideB, - StrideC, - block_mapping, - static_cast(p_shared)); + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_workspace, + M, + N, + K, + StrideA, + StrideB, + StrideC, + block_mapping, + static_cast(p_shared)); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -83,8 +86,8 @@ template ::value) @@ -380,27 +387,27 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk __host__ __device__ static constexpr auto GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle() { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl); + constexpr index_t NWave = NRepeat * NPerXdl == 0 ? 1 : NPerBlock / (NRepeat * NPerXdl); return make_naive_tensor_descriptor_packed( make_tuple(I1, - Number{}, + Number{}, I1, - Number{})); + Number{})); } __host__ __device__ static constexpr auto GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle() { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl); + constexpr index_t NWave = NRepeat * NPerXdl == 0 ? 
1 : NPerBlock / (NRepeat * NPerXdl); return make_naive_tensor_descriptor_packed( make_tuple(I1, - Number{}, + Number{}, I1, - Number{})); + Number{})); } __host__ __device__ static constexpr auto GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle() { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl); + constexpr index_t NWave = NRepeat * NPerXdl == 0 ? 1 : NPerBlock / (NRepeat * NPerXdl); return make_naive_tensor_descriptor_packed( make_tuple(Number{}, - Number{}, + Number{}, Number{}, - Number{})); + Number{})); } __host__ __device__ static constexpr auto GetClusterLengthReduction() @@ -490,8 +497,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk FloatAcc, decltype(a_block_desc_k0_m_k1), decltype(b_block_desc_k0_n_k1), - MPerXDL, - NPerXDL, + MPerXdl, + NPerXdl, MRepeat, NRepeat, K1>{}; @@ -829,8 +836,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk // output: register to global memory { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl); constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); @@ -871,12 +878,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk M1, M2, M3, - M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl make_freeze_transform(I0), // freeze nblock make_unmerge_transform( make_tuple(CShuffleNRepeatPerShuffle, N1, - N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL + N2))), // N1 = NWave, N2 = NPerXdl make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, @@ -948,9 +955,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk CElementwiseOperation, // ElementwiseOperation, // InMemoryDataOperationEnum::Set, // DstInMemOp, Sequence<1, - CShuffleMRepeatPerShuffle * MWave * MPerXDL, + CShuffleMRepeatPerShuffle * MWave * MPerXdl, 1, - CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, FloatCShuffle, // typename SrcData, @@ -977,9 +984,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk CElementwiseOperation, // ElementwiseOperation, // InMemoryDataOperationEnum::Set, // DstInMemOp, Sequence<1, - CShuffleMRepeatPerShuffle * MWave * MPerXDL, + CShuffleMRepeatPerShuffle * MWave * MPerXdl, 1, - CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, FloatCShuffle, // typename SrcData, @@ -1000,11 +1007,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk c_element_op}; constexpr auto mxdlperwave_forward_step = - make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0); + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXdl, 0, 0); constexpr auto nxdlperwave_forward_step = - make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXdl); constexpr auto nxdlperwave_backward_step = - make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXdl); static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { constexpr auto mxdlperwave = mxdlperwave_iter; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index 
ae9a8af813..aa7ce1f5b6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -38,16 +38,20 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const CGridDesc_M_N c_grid_desc_m_n) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_grid_desc_m_n); + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m_n); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -68,25 +72,29 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - const auto a_grid_desc_k0_m_k1 = - amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1( - karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA)); - const auto b_grid_desc_k0_n_k1 = - amd_wave_read_first_lane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1( - karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB)); - const auto c_grid_desc_m_n = amd_wave_read_first_lane(GridwiseGemm::MakeCGridDescriptor_M_N( - karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC)); + const auto a_grid_desc_k0_m_k1 = + amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1( + karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA)); + const auto b_grid_desc_k0_n_k1 = + amd_wave_read_first_lane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1( + karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB)); + const auto c_grid_desc_m_n = amd_wave_read_first_lane(GridwiseGemm::MakeCGridDescriptor_M_N( + karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC)); - GridwiseGemm::template Run(karg.p_a_grid, - karg.p_b_grid, - karg.p_c_grid, - p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_grid_desc_m_n); + GridwiseGemm::template Run(karg.p_a_grid, + karg.p_b_grid, + karg.p_c_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m_n); + } #else ignore = karg; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) @@ -103,8 +111,8 @@ template {}; // K1 should be Number<...> - static constexpr auto K1 = Number{}; + static constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && K1Value <= 4) || + (is_same::value && K1Value <= 8) || + ((is_same::value || is_same::value) && K1Value < 32)) + ? 
true + : false; + static constexpr auto is_scale_mfma = false; + static constexpr auto K1 = Number:: + selected_mfma.k_per_blk)>{}; using ThisThreadBlock = ThisThreadBlock; @@ -314,6 +332,22 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } + template __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, @@ -323,8 +357,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 static_assert(is_known_at_compile_time>::value, "wrong! K1 need to be known at compile-time"); - static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) && - (NPerBlock % (NXdlPerWave * NPerXDL)) == 0, + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); const auto M = a_grid_desc_k0_m_k1.GetLength(I1); @@ -356,8 +390,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 static_assert(is_known_at_compile_time>::value, "wrong! K1 need to be known at compile-time"); - static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) && - (NPerBlock % (NXdlPerWave * NPerXDL)) == 0, + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); // check gridwise gemm pipeline @@ -421,8 +455,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 FloatAcc, decltype(a_block_desc_k0_m_k1), decltype(b_block_desc_k0_n_k1), - MPerXDL, - NPerXDL, + MPerXdl, + NPerXdl, MXdlPerWave, NXdlPerWave, K1>; @@ -566,8 +600,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 FloatAcc, decltype(a_block_desc_k0_m_k1), decltype(b_block_desc_k0_n_k1), - MPerXDL, - NPerXDL, + MPerXdl, + NPerXdl, MXdlPerWave, NXdlPerWave, K1, @@ -704,8 +738,8 @@ template >::value, "wrong! K1 need to be known at compile-time"); - static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) && - (NPerBlock % (NXdlPerWave * NPerXDL)) == 0, + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp index f779e63752..a9a463e2c1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
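The is_single_rate_mfma expression above lost its template arguments in extraction; its shape, as it appears in related CK gridwise kernels, keys on the A/B element type and the K1 vector length: when the per-load K chunk is too short to feed a K-doubled (dual-rate) MFMA, the single-rate instruction is selected, and KPack then becomes max(K1, selected_mfma.k_per_blk). A compilable reconstruction with stand-in tag types (the exact type list is an assumption, not a quote of this file):

    #include <type_traits>

    // Stand-ins for ck::half_t, ck::bhalf_t, ck::f8_t, ck::bf8_t.
    struct half_t {};
    struct bhalf_t {};
    struct f8_t {};
    struct bf8_t {};

    // Assumed predicate: fp16/bf16 need K1 > 4, int8 needs K1 > 8, and
    // fp8/bf8 need K1 >= 32 to use the dual-rate MFMA path.
    template <typename FloatAB, int K1Value>
    constexpr bool is_single_rate_mfma =
        ((std::is_same_v<FloatAB, half_t> ||
          std::is_same_v<FloatAB, bhalf_t>) && K1Value <= 4) ||
        (std::is_same_v<FloatAB, signed char> && K1Value <= 8) ||
        ((std::is_same_v<FloatAB, f8_t> ||
          std::is_same_v<FloatAB, bf8_t>) && K1Value < 32);

    // fp16 with K1 = 8 can feed the dual-rate MFMA; K1 = 4 cannot.
    static_assert(!is_single_rate_mfma<half_t, 8>);
    static_assert(is_single_rate_mfma<half_t, 4>);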
#pragma once @@ -42,23 +42,27 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CElementwiseOperation c_element_op, const CBlockClusterAdaptor c_block_cluster_adaptor) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - __shared__ FloatAB p_shared_block[shared_block_size]; + __shared__ FloatAB p_shared_block[shared_block_size]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - a_element_op, - b_element_op, - c_element_op, - c_block_cluster_adaptor); + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + a_element_op, + b_element_op, + c_element_op, + c_block_cluster_adaptor); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -171,6 +175,22 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index 595a597318..d8f22b682d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -36,13 +36,16 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); - __shared__ uint8_t p_shared[shared_size]; + __shared__ uint8_t p_shared[shared_size]; - GridwiseGemm::template Run( - karg, static_cast(p_shared), b2c_map, a_element_op, b_element_op, c_element_op); + GridwiseGemm::template Run( + karg, static_cast(p_shared), b2c_map, a_element_op, b_element_op, c_element_op); + } #else ignore = karg; ignore = b2c_map; @@ -68,8 +71,8 @@ template {}, + Number{}, I1, - Number{})); + Number{})); } // return block_id to C matrix tile idx (m0, n0, k_split) mapping @@ -848,8 +855,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 FloatAcc, decltype(a_k0_m_k1_block_desc), decltype(b_k0_n_k1_block_desc), - MPerXDL, - NPerXDL, + MPerXdl, + 
NPerXdl, MRepeat, NRepeat, K1, @@ -899,8 +906,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 // output: register to global memory { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl); constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); @@ -932,11 +939,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 M1, M2, M3, - M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl make_freeze_transform(I0), // freeze nblock make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, N1, - N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL + N2))), // N1 = NWave, N2 = NPerXdl make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple( Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); @@ -1007,9 +1014,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 CElementwiseOperation, // ElementwiseOperation, CGlobalMemoryDataOperation, // DstInMemOp, Sequence<1, - CShuffleMRepeatPerShuffle * MWave * MPerXDL, + CShuffleMRepeatPerShuffle * MWave * MPerXdl, 1, - CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, FloatC, // typename SrcData, @@ -1028,11 +1035,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 c_element_op}; constexpr auto mxdlperwave_forward_step = - make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0); + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXdl, 0, 0); constexpr auto nxdlperwave_forward_step = - make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXdl); constexpr auto nxdlperwave_backward_step = - make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXdl); static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { constexpr auto mxdlperwave = mxdlperwave_iter; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp index 8822778b52..7d5a8da60f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
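The step constants above (forward in M, alternating forward and backward in N) make the CShuffle window sweep the C tile in boustrophedon order: even M rows walk N left to right, odd rows walk it back, so every move is a single adjacent step and the backward step is simply the negated forward one. A plain-loop rendering of the traversal with illustrative repeat counts:

    #include <cstdio>

    int main()
    {
        // Illustrative values for MRepeat/NRepeat and per-shuffle step sizes.
        const int MRepeat = 4, NRepeat = 4, MShuffle = 1, NShuffle = 1;

        for(int m = 0; m < MRepeat; m += MShuffle)
        {
            const bool forward = ((m / MShuffle) % 2 == 0);
            for(int i = 0; i < NRepeat; i += NShuffle)
            {
                // Even M rows move +N, odd rows move -N (the backward step),
                // so consecutive windows are always adjacent tiles.
                const int n = forward ? i : (NRepeat - NShuffle - i);
                std::printf("shuffle window at (m=%d, n=%d)\n", m, n);
            }
        }
        return 0;
    }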
#pragma once @@ -46,21 +46,25 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -230,6 +234,22 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 c_block_size * sizeof(FloatCShuffle)); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp index c3bbece33c..7c559d1f85 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -49,23 +49,27 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - p_a_grid, - p_b_grid, - p_c_grid, - p_c0_grid, - p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_c0_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -235,6 +239,22 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 c_block_size * sizeof(FloatC)); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool @@ -298,7 +318,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 const auto NBlock = N / NPerBlock; constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t NWave = + NXdlPerWave * NPerXdl == 0 ? 1 : NPerBlock / (NXdlPerWave * NPerXdl); const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = transform_tensor_descriptor( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp index 2e288efee2..83f8773a08 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
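The NXdlPerWave * NPerXdl == 0 ? 1 : ... rewrite in the v3r2 hunk (and in v3r3 below) is not cosmetic: a class-scope constexpr initializer is evaluated as soon as the template is instantiated, before any if constexpr can discard its uses, and a compile-time division by zero is ill-formed. Guarding with 1 keeps instantiations that the validity check will reject anyway compilable. A minimal illustration with hypothetical numbers:

    template <int NPerBlock, int NXdlPerWave, int NPerXdl>
    struct TileConfigSketch
    {
        // Without the ternary, NXdlPerWave * NPerXdl == 0 would turn this
        // initializer into a compile-time division by zero.
        static constexpr int NWave =
            NXdlPerWave * NPerXdl == 0 ? 1 : NPerBlock / (NXdlPerWave * NPerXdl);
    };

    static_assert(TileConfigSketch<128, 2, 32>::NWave == 2); // normal XDL tiling
    static_assert(TileConfigSketch<128, 0, 32>::NWave == 1); // degenerate target

    int main() { return 0; }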
#pragma once @@ -53,25 +53,29 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - p_a_grid, - p_b_grid, - p_c_grid, - p_c0_grid, - p_c1_grid, - p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_c0_grid, + p_c1_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -244,6 +248,22 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 c_block_size * sizeof(FloatC)); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool @@ -307,7 +327,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 const auto NBlock = N / NPerBlock; constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t NWave = + NXdlPerWave * NPerXdl == 0 ? 
1 : NPerBlock / (NXdlPerWave * NPerXdl); const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = transform_tensor_descriptor( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index b0a606cf38..b9c0d671db 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -45,24 +45,27 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_gemm(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -80,26 +83,29 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run_2Lds( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - p_shared1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run_2Lds( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -940,6 +946,8 @@ struct GridwiseMoeGemm c_block_size * 
sizeof(CShuffleDataType)); } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp index a8b759da38..5f5e24fb9f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp @@ -45,26 +45,29 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_gemm(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - karg.p_a_scale_grid, - karg.p_b_scale_grid, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + karg.p_a_scale_grid, + karg.p_b_scale_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -82,28 +85,31 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run_2Lds( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - karg.p_a_scale_grid, - karg.p_b_scale_grid, - p_shared, - p_shared1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run_2Lds( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + 
karg.p_c_grid, + karg.p_a_scale_grid, + karg.p_b_scale_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -956,6 +962,8 @@ struct GridwiseMoeGemmBlockScale c_block_size * sizeof(CShuffleDataType)); } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp index 34fcf0e935..9066decc0a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp @@ -49,6 +49,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) kernel_moe_mxgemm(typename GridwiseGemm::Argument karg) { #if defined(__gfx9__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -68,6 +70,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) karg.a_element_op, karg.b_element_op, karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -87,27 +90,30 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg) { #if defined(__gfx9__) - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run_2Lds( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared_0, - p_shared_1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run_2Lds( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared_0, + p_shared_1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -1167,6 +1173,8 @@ struct GridwiseMoeGemmMX } } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp 
b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp index 3a7b35683d..6854e64092 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp @@ -48,25 +48,28 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) kernel_moe_mxgemm(typename GridwiseGemm::Argument karg) { #if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -86,6 +89,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg) { #if defined(__gfx9__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -107,6 +112,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) karg.a_element_op, karg.b_element_op, karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -1099,6 +1105,8 @@ struct GridwiseMoeGemmMXBNS } } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp index 3c4f7a24c7..c367079aab 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp @@ -48,25 +48,28 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) kernel_moe_mxgemm(typename GridwiseGemm::Argument karg) { #if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - 
GridwiseGemm::template Run( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -85,27 +88,30 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg) { #if defined(__gfx9__) - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run_2Lds( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared_0, - p_shared_1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run_2Lds( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared_0, + p_shared_1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -1059,6 +1065,8 @@ struct GridwiseMoeGemmMX_BPreshuffle c_block_size * sizeof(CShuffleDataType)); } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) {
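The MoE and MX gridwise structs in the hunks above take the one-line IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) route instead of spelling the member out. Its expansion is not shown in this excerpt; inferring from the hand-written IsValidCompilationParameter members earlier in the patch, a plausible shape is the following (an assumption about the macro, not its verbatim definition):

    // Assumed expansion, modeled on the explicit members added to the older
    // gridwise kernels; the real macro lives in the device layer and may
    // differ in detail.
    #define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)                       \
        template <InMemoryDataOperationEnum CGlobalMemoryDataOperation_ =         \
                      InMemoryDataOperationEnum::Set>                             \
        __device__ static bool constexpr IsValidCompilationParameter()            \
        {                                                                         \
            return ck::tensor_operation::device::IsValidGemmCompilationParameter< \
                BlockSize,                                                        \
                MPerBlock,                                                        \
                NPerBlock,                                                        \
                MPerXdl,                                                          \
                NPerXdl,                                                          \
                MXdlPerWave,                                                      \
                NXdlPerWave,                                                      \
                CDataType_,                                                       \
                CGlobalMemoryDataOperation_>();                                   \
        }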