diff --git a/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp b/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp
index 6029ab0c7d..f233794ec1 100644
--- a/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp
+++ b/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
 #include "ck/host/stringutils.hpp"
@@ -76,28 +76,28 @@ std::vector Operation_Xdl_CShuffle::CreateOperations(
        // Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Prefetch|
        //     | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Stage|
        //     | | | | | | | | | | | Wave| Wave| Wave| |
-       { 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, 1},
-       { 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, 1},
-       { 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, 1},
-       { 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, 1},
-       { 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1},
-       { 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1},
-       { 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
-       { 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
-       { 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1},
-       { 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1},
-       { 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1},
-       { 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1},
+       { 256, 256, 128, 32, 64, 32, 8, 8, 2, 16, 16, 4, 8, 4, 1},
+       { 256, 256, 128, 32, 128, 32, 8, 8, 2, 16, 16, 4, 8, 8, 1},
+       { 256, 128, 256, 32, 64, 32, 8, 8, 2, 16, 16, 2, 16, 4, 1},
+       { 256, 128, 256, 32, 128, 32, 8, 8, 2, 16, 16, 2, 16, 8, 1},
+       { 256, 128, 128, 64, 64, 32, 8, 8, 2, 16, 16, 2, 8, 4, 1},
+       { 256, 128, 128, 32, 64, 32, 8, 8, 2, 16, 16, 2, 8, 4, 1},
+       { 256, 128, 128, 64, 128, 32, 8, 8, 2, 16, 16, 2, 8, 8, 1},
+       { 256, 128, 128, 32, 128, 32, 8, 8, 2, 16, 16, 2, 8, 8, 1},
+       { 256, 128, 256, 32, 128, 32, 8, 8, 2, 16, 16, 2, 16, 8, 1},
+       { 256, 128, 256, 32, 64, 32, 8, 8, 2, 16, 16, 2, 16, 4, 1},
+       { 256, 128, 256, 64, 128, 32, 8, 8, 2, 16, 16, 2, 16, 8, 1},
+       { 256, 128, 256, 64, 64, 32, 8, 8, 2, 16, 16, 2, 16, 4, 1},
        // Padded fallback kernel
-       { 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
-       { 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, 1},
+       { 256, 128, 128, 64, 128, 32, 8, 8, 2, 16, 16, 2, 8, 8, 1},
+       { 256, 128, 64, 32, 128, 32, 8, 8, 2, 16, 16, 2, 4, 8, 1},
        // Irregular k
-       { 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, 1},
-       { 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, 1},
-       { 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, 1},
-       { 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, 1},
-       { 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, 1},
-       { 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, 1},
+       { 256, 256, 128, 48, 64, 32, 4, 4, 2, 16, 16, 4, 8, 4, 1},
+       { 256, 256, 128, 48, 128, 32, 4, 4, 2, 16, 16, 4, 8, 8, 1},
+       { 256, 128, 256, 48, 64, 32, 4, 4, 2, 16, 16, 2, 16, 4, 1},
+       { 256, 128, 256, 48, 128, 32, 4, 4, 2, 16, 16, 2, 16, 8, 1},
+       { 256, 128, 128, 48, 64, 32, 4, 4, 2, 16, 16, 2, 8, 4, 1},
+       { 256, 128, 128, 48, 128, 32, 4, 4, 2, 16, 16, 2, 8, 8, 1},
        // clang-format on
    };
@@ -200,28 +200,28 @@ std::vector Operation_Xdl_CShuffle::CreateOperations(
        // _MBlock_MWaveMPerXdl| ScalarPerVector
        // _NBlock_NWaveNPerXdl| _NWaveNPerXdl
        // |
-       { S<1, 32, 1, 8>, 8},
-       { S<1, 32, 1, 8>, 8},
-       { S<1, 32, 1, 8>, 8},
-       { S<1, 32, 1, 8>, 8},
-       { S<1, 32, 1, 8>, 8},
-       { S<1, 32, 1, 8>, 8},
-       { S<1, 32, 1, 8>, 8},
-       { S<1, 32, 1, 8>, 8},
-       { S<1, 16, 1,16>, 8},
-       { S<1, 32, 1, 8>, 8},
-       { S<1, 16, 1,16>, 8},
-       { S<1, 32, 1, 8>, 8},
+       { S<1, 32, 1, 8>, 4},
+       { S<1, 32, 1, 8>, 4},
+       { S<1, 32, 1, 8>, 4},
+       { S<1, 32, 1, 8>, 4},
+       { S<1, 32, 1, 8>, 4},
+       { S<1, 32, 1, 8>, 4},
+       { S<1, 32, 1, 8>, 4},
+       { S<1, 32, 1, 8>, 4},
+       { S<1, 16, 1,16>, 4},
+       { S<1, 32, 1, 8>, 4},
+       { S<1, 16, 1,16>, 4},
+       { S<1, 32, 1, 8>, 4},
        // Padded fallback kernel
-       { S<1, 32, 1, 8>, 8},
-       { S<1, 32, 1, 8>, 8},
+       { S<1, 32, 1, 8>, 4},
+       { S<1, 32, 1, 8>, 4},
        // Irregular k
-       { S<1, 32, 1, 8>, 8},
-       { S<1, 32, 1, 8>, 8},
-       { S<1, 32, 1, 8>, 8},
-       { S<1, 32, 1, 8>, 8},
-       { S<1, 32, 1, 8>, 8},
-       { S<1, 32, 1, 8>, 8},
+       { S<1, 32, 1, 8>, 4},
+       { S<1, 32, 1, 8>, 4},
+       { S<1, 32, 1, 8>, 4},
+       { S<1, 32, 1, 8>, 4},
+       { S<1, 32, 1, 8>, 4},
+       { S<1, 32, 1, 8>, 4},
        // clang-format on
    };
diff --git a/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp
index fe556615e0..b6cae670fe 100644
--- a/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp
+++ b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #include "ck/host/device_gemm_multiple_d/operation.hpp"
 #include "ck/host/stringutils.hpp"
@@ -81,16 +81,16 @@ std::vector Operation_Xdl_CShuffle::CreateOperations(
        // Size| Block| Block| Block| | | XDL| XDL| Per| Per| Prefetch|
        //     | | | | | | | | Wave| Wave| Stage|
        //     | | | | | | | | | | |
-       { 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1},
-       { 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, 1},
-       { 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, 1},
-       { 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, 1},
-       { 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, 1},
-       { 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1},
-       { 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1},
-       { 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1},
+       { 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, 1},
+       { 256, 128, 256, 32, 8, 8, 16, 16, 4, 8, 1},
+       { 128, 128, 128, 32, 8, 8, 16, 16, 8, 4, 1},
+       { 256, 128, 128, 32, 8, 8, 16, 16, 4, 4, 1},
+       { 128, 128, 64, 32, 8, 8, 16, 16, 4, 4, 1},
+       { 128, 64, 128, 32, 8, 8, 16, 16, 4, 4, 1},
+       { 256, 128, 64, 32, 8, 8, 16, 16, 4, 2, 1},
+       { 256, 64, 128, 32, 8, 8, 16, 16, 2, 4, 1},
        // Irregular tile
-       { 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, 1},
+       { 64, 32, 32, 32, 8, 8, 16, 16, 2, 2, 1},
        // clang-format on
    };
@@ -194,14 +194,14 @@ std::vector Operation_Xdl_CShuffle::CreateOperations(
        // _MBlock_MWaveMPerXdl| ScalarPerVector
        // _NBlock_NWaveNPerXdl| _NWaveNPerXdl
        // |
-       { S<1, 32, 1, 8>, 8},
-       { S<1, 32, 1, 8>, 8},
-       { S<1, 16, 1, 8>, 8},
-       { S<1, 32, 1, 8>, 8},
-       { S<1, 32, 1, 4>, 8},
-       { S<1, 16, 1, 8>, 8},
-       { S<1, 32, 1, 8>, 8},
-       { S<1, 32, 1, 8>, 8},
+       { S<1, 32, 1, 8>, 4},
+       { S<1, 32, 1, 8>, 4},
+       { S<1, 16, 1, 8>, 4},
+       { S<1, 32, 1, 8>, 4},
+       { S<1, 32, 1, 4>, 4},
+       { S<1, 16, 1, 8>, 4},
+       { S<1, 32, 1, 8>, 4},
+       { S<1, 32, 1, 8>, 4},
        // Irregular tile
        { S<1, 16, 1, 4>, 1},
        // clang-format on
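The retuned rows above all move from 32x32 MFMA tiles to 16x16 tiles while doubling the per-wave repeat counts, so the block tile and thread count are unchanged and the same instance list can also map onto wave32 WMMA hardware. A minimal standalone sketch of the invariant these tables preserve (names are illustrative, not CK's; the real checks live in the gridwise GEMM's static asserts):

```cpp
// Sketch: a block tile is covered by (MWaves x NWaves) waves, each wave
// iterating MXdlPerWave x NXdlPerWave times over an MPerXdl x NPerXdl tile.
constexpr bool tile_is_consistent(int BlockSize, int MPerBlock, int NPerBlock,
                                  int MPerXdl, int NPerXdl,
                                  int MXdlPerWave, int NXdlPerWave, int WaveSize)
{
    const int MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
    const int NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
    return MWaves > 0 && NWaves > 0 && BlockSize == MWaves * NWaves * WaveSize;
}

// Old row: 256 threads, 256x128 tile, 32x32 xdl, 4x2 repeats -> 4 waves of 64.
static_assert(tile_is_consistent(256, 256, 128, 32, 32, 4, 2, 64), "old row");
// New row: same tile, 16x16 xdl, 8x4 repeats -> still 4 waves of 64.
static_assert(tile_is_consistent(256, 256, 128, 16, 16, 8, 4, 64), "new row");
```

The halved CDE `ScalarPerVector` (8 to 4) in the shuffle tables follows the same retuning: with 16-wide XDL output tiles, a 4-wide store vector is the widest that still divides the per-wave output slice.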
diff --git a/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp b/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp
index a2f322c50f..26988255c3 100644
--- a/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp
+++ b/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
 #include
@@ -55,12 +55,12 @@ std::vector Operation_Conv_Fwd_Xdl_Cshuffle::Cr
        // Size| Block| Block| Block| | | XDL| XDL| Per| Per| Prefetch|
        //     | | | | | | | | Wave| Wave| Stage|
        //     | | | | | | | | | | |
-       { 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, 1},
-       { 256, 128, 256, 32, 8, 8, 32, 32, 4, 2, 1},
-       { 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, 1},
-       { 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, 1},
-       { 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1},
-       { 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, 1}
+       { 64, 64, 32, 32, 8, 8, 16, 16, 4, 2, 1},
+       { 256, 128, 256, 32, 8, 8, 16, 16, 8, 4, 1},
+       { 256, 128, 128, 32, 8, 8, 16, 16, 4, 4, 1},
+       { 64, 64, 64, 32, 8, 8, 16, 16, 4, 4, 1},
+       { 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, 1},
+       { 128, 128, 128, 32, 8, 8, 16, 16, 8, 4, 1}
        // clang-format on
    };
@@ -116,11 +116,11 @@ std::vector Operation_Conv_Fwd_Xdl_Cshuffle::Cr
        // _NBlock_NWaveNPerXdl| _NWaveNPerXdl
        // |
        { S<1, 16, 1, 4>, 1},
-       { S<1, 32, 1, 8>, 8},
-       { S<1, 32, 1, 8>, 8},
+       { S<1, 16, 1, 16>, 4},
+       { S<1, 32, 1, 8>, 4},
        { S<1, 16, 1, 4>, 1},
-       { S<1, 32, 1, 8>, 8},
-       { S<1, 16, 1, 8>, 8}
+       { S<1, 32, 1, 8>, 4},
+       { S<1, 16, 1, 8>, 4}
        // clang-format on
    };
@@ -223,8 +223,9 @@ extern "C" __global__ void run_${name}(
    constexpr ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler();
 
    // GridwiseGemm
-   using GridwiseGemm = DeviceConv::GridwiseGemm;
-
+   using GridwiseGemm = ck::conditional_t;
    static constexpr auto I0 = ck::Number<0>{};
 
    ck::tensor_operation::device::device_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
diff --git a/codegen/src/utils.cpp b/codegen/src/utils.cpp
index c15a9fd7d3..4cfe7a117f 100644
--- a/codegen/src/utils.cpp
+++ b/codegen/src/utils.cpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #include "ck/host/utils.hpp"
@@ -13,7 +13,8 @@ std::size_t integer_divide_ceil(std::size_t x, std::size_t y)
 
 const std::unordered_set& get_xdlop_archs()
 {
-    static std::unordered_set supported_archs{"gfx90a", "gfx908", "gfx942"};
+    static std::unordered_set supported_archs{
+        "gfx90a", "gfx908", "gfx942", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201"};
     return supported_archs;
 }
diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp
index 9902caab04..15365aadf1 100644
--- a/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp
+++ b/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp
@@ -160,9 +160,10 @@ struct Epilogue
                               Epilogue{1.0f, 1.0f});
    out_host.SetZero();
    ref_invoker.Run(ref_argument);**/
-
+    int i = 0;
    for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
    {
+        std::cout << "Testing solution " << std::to_string(++i) << std::endl;
        // substitute instance values into the template
        auto src = ck::host::InterpolateString(
            conv_compile_check,
diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp
index 205283e7aa..d7ff793cb8 100644
--- a/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp
+++ b/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp
@@ -160,9 +160,10 @@ struct Epilogue
                               Epilogue{1.0f, 1.0f});
    out_host.SetZero();
    ref_invoker.Run(ref_argument);**/
-
+    int i = 0;
    for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
    {
+        std::cout << "Testing solution " << std::to_string(++i) << std::endl;
        // substitute instance values into the template
        auto src = ck::host::InterpolateString(
            conv_compile_check,
diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp
index 2b83af2432..1129dbc015 100644
--- a/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp
+++ b/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp
@@ -160,9 +160,10 @@ struct Epilogue
                               Epilogue{1.0f, 1.0f});
    out_host.SetZero();
    ref_invoker.Run(ref_argument);**/
-
+    int i = 0;
    for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
    {
+        std::cout << "Testing solution " << std::to_string(++i) << std::endl;
        // substitute instance values into the template
        auto src = ck::host::InterpolateString(
            conv_compile_check,
diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp
index fbe27e9c8b..5696178f68 100644
--- a/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp
+++ b/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp
@@ -160,9 +160,10 @@ struct Epilogue
                               Epilogue{1.0f, 1.0f});
    out_host.SetZero();
    ref_invoker.Run(ref_argument);**/
-
+    int i = 0;
    for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
    {
+        std::cout << "Testing solution " << std::to_string(++i) << std::endl;
        // substitute instance values into the template
        auto src = ck::host::InterpolateString(
            conv_compile_check,
diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp
index 2bc5a4414e..2e949bb1df 100644
--- a/include/ck/host_utility/device_prop.hpp
+++ b/include/ck/host_utility/device_prop.hpp
@@ -75,6 +75,34 @@ inline bool is_xdl_supported()
        ;
 }
 
+template
+inline bool is_xdl_wmma_supported()
+{
+    if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
+       ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")
+    {
+        return true;
+    }
+#if defined(CK_ENABLE_DYNAMIC_WARP_SIZE)
+    else if(is_gfx12_supported() || is_gfx11_supported())
+    {
+        if constexpr((MPerXDL != 16) || (NPerXDL != 16))
+        {
+            return false;
+        }
+        if constexpr(sizeof(ADataType) > 2 || sizeof(BDataType) > 2)
+        {
+            return false;
+        }
+        return true;
+    }
+#endif
+    else
+    {
+        return false;
+    }
+}
+
 inline bool is_lds_direct_load_supported()
 {
    // Check if direct loads from global memory to LDS are supported.
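`is_xdl_wmma_supported` generalizes the old `is_xdl_supported` gate: gfx9 MFMA targets always pass, while the gfx11/gfx12 WMMA targets pass only for 16x16 tiles and A/B element types of at most two bytes. The template parameter list was elided in this patch view; the names `MPerXDL`/`NPerXDL`/`ADataType`/`BDataType` come from the function body. A standalone restatement of the predicate, as a sketch only:

```cpp
#include <string>

// Hypothetical host-side mirror of the gate above; not CK's API.
bool xdl_wmma_gate(const std::string& arch, int m_per_xdl, int n_per_xdl,
                   std::size_t a_bytes, std::size_t b_bytes)
{
    if(arch == "gfx908" || arch == "gfx90a" || arch == "gfx942" || arch == "gfx950")
        return true; // native MFMA targets: any supported tile shape
    const bool is_wmma = arch.rfind("gfx11", 0) == 0 || arch.rfind("gfx12", 0) == 0;
    if(is_wmma)
        // WMMA path: only 16x16 tiles and <=16-bit A/B element types
        return m_per_xdl == 16 && n_per_xdl == 16 && a_bytes <= 2 && b_bytes <= 2;
    return false;
}
```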
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
index 613886453b..b729271680 100644
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
@@ -108,10 +108,8 @@ struct BlockwiseGemmXdlops_pipeline_v4
 
    using ThisThreadBlock = ThisThreadBlock;
 
-    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
-    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
-    static_assert(MWaves > 0);
-    static_assert(NWaves > 0);
+    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
+    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
 
    static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
 
    static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp
index c553a57672..55e856b641 100644
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp
@@ -49,10 +49,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
 
    using ThisThreadBlock = ThisThreadBlock;
 
-    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
-    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
-    static_assert(MWaves > 0);
-    static_assert(NWaves > 0);
+    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
+    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
 
    static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
 
    static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp
index c946abb77d..e7ce7cbcf5 100644
--- a/include/ck/tensor_operation/gpu/device/device_base.hpp
+++ b/include/ck/tensor_operation/gpu/device/device_base.hpp
@@ -11,6 +11,7 @@
 #include "ck/stream_config.hpp"
 #endif
 
+#include "ck/utility/get_id.hpp"
 
 namespace ck {
 namespace tensor_operation {
@@ -46,6 +47,151 @@ namespace device {
 #define REGISTER_EXTRA_PRINTING_METHODS GET_OBJECT_NAME_IMLP GET_TEMPLATE_INFO_IMPL
 #endif
 
+template
+static constexpr auto GetNXdlPerWave2()
+{
+    constexpr index_t Waves  = IsWave64 ? BlockSize_ / 64 : BlockSize_ / 32;
+    constexpr index_t MWaves = MPerBlock_ / (MXdlPerWave_ * MPerXDL_);
+    static_assert(MWaves > 0);
+
+    constexpr index_t NWaves = Waves / MWaves;
+    if constexpr(NWaves == 0)
+    {
+        return 0;
+    }
+    else
+    {
+        if constexpr(NPerBlock_ % (NPerXDL_ * NWaves) == 0)
+        {
+            return NPerBlock_ / (NWaves * NPerXDL_);
+        }
+        else
+        {
+            return 0;
+        }
+    }
+}
+
+#define GET_NXDL_PER_WAVE_IMPL \
+    template \
+    static constexpr auto GetNXdlPerWave() \
+    { \
+        return GetNXdlPerWave2(); \
+    }
+
+#define INVOKER_RUN_IMPL \
+    float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \
+    { \
+        if(get_warp_size() == 64) \
+        { \
+            if constexpr(NXdlPerWave64 > 0) \
+            { \
+                return RunImp(arg, stream_config); \
+            } \
+        } \
+        else \
+        { \
+            if constexpr(NXdlPerWave32 > 0) \
+            { \
+                return RunImp(arg, stream_config); \
+            } \
+        } \
+        return 0; \
+    }
+
+#define INVOKER_RUN3_IMPL \
+    float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \
+    { \
+        if(get_warp_size() == 64) \
+        { \
+            if constexpr(NXdlPerWave64 > 0) \
+            { \
+                return RunImp(arg, stream_config); \
+            } \
+        } \
+        else \
+        { \
+            if constexpr(NXdlPerWave32 > 0) \
+            { \
+                return RunImp( \
+                    reinterpret_cast(arg), \
+                    stream_config); \
+            } \
+        } \
+        return 0; \
+    }
+
+template
+__device__ static bool constexpr IsValidGemmCompilationParameter()
+{
+#if defined(__gfx11__) || defined(__gfx12__)
+    if constexpr(MPerXdl != 16 || NPerXdl != 16)
+    {
+        return false;
+    }
+#endif
+
+#if defined(__gfx11__)
+    constexpr bool SupportMemOp = CGlobalMemoryDataOperation_ == InMemoryDataOperationEnum::Set;
+#else
+    constexpr bool SupportMemOp =
+        sizeof(CDataType) >= 2 || (CGlobalMemoryDataOperation_ == InMemoryDataOperationEnum::Set);
+#endif
+    if constexpr(SupportMemOp == false)
+    {
+        return false;
+    }
+
+    if constexpr(MXdlPerWave > 0 && NXdlPerWave > 0)
+    {
+        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
+        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
+        if constexpr(MWaves > 0 && NWaves > 0)
+        {
+            constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
+            return WaveSize == get_warp_size();
+        }
+    }
+    return false;
+}
+
+#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_) \
+    template \
+    __device__ static bool constexpr IsValidCompilationParameter() \
+    { \
+        return ck::tensor_operation::device::IsValidGemmCompilationParameter< \
+            BlockSize, \
+            MPerBlock, \
+            NPerBlock, \
+            MPerXdl, \
+            NPerXdl, \
+            MXdlPerWave, \
+            NXdlPerWave, \
+            CDataType_, \
+            CGlobalMemoryDataOperation_>(); \
+    }
+
 #ifndef CK_CODE_GEN_RTC
 struct BaseArgument
 {
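`GetNXdlPerWave2` rederives the N-direction repeat count from the wave count implied by the block size and warp width, returning 0 when a wave width cannot cover the tile; that 0 is what the `if constexpr(NXdlPerWave.. > 0)` guards downstream prune on. A worked example of the math, as a standalone mirror (the function name and layout here are ours, not CK's):

```cpp
// Sketch of the wave-count derivation used by GetNXdlPerWave2.
constexpr int n_xdl_per_wave(bool wave64, int BlockSize, int MPerBlock,
                             int NPerBlock, int MPerXDL, int NPerXDL, int MXdlPerWave)
{
    const int Waves  = wave64 ? BlockSize / 64 : BlockSize / 32;
    const int MWaves = MPerBlock / (MXdlPerWave * MPerXDL);
    const int NWaves = MWaves > 0 ? Waves / MWaves : 0;
    if(NWaves == 0 || NPerBlock % (NPerXDL * NWaves) != 0)
        return 0; // this warp width cannot tile the block -> instance disabled
    return NPerBlock / (NWaves * NPerXDL);
}

// 256 threads, 128x128 block tile, 16x16 xdl, MXdlPerWave = 4:
static_assert(n_xdl_per_wave(true, 256, 128, 128, 16, 16, 4) == 4,
              "wave64: 4 waves -> 2x2 wave grid -> NXdlPerWave = 4");
static_assert(n_xdl_per_wave(false, 256, 128, 128, 16, 16, 4) == 2,
              "wave32: 8 waves -> 2x4 wave grid -> NXdlPerWave = 2");
```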
diff --git a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
index c71153768d..157e475267 100644
--- a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
@@ -94,79 +94,83 @@ __device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
    const Block2ETileMap block_2_ctile_map,
    const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
-#if defined(__gfx9__)
-    // offset base pointer for each work-group
-    const index_t num_blocks_per_batch =
-        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
-    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-
-    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
-    const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
-
-    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
-
-    DsPointer p_ds_grid_grp;
-
-    static constexpr index_t NumDTensor =
-        DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
-
-    static_for<0, NumDTensor, 1>{}(
-        [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
-
-    if constexpr(isMultiA || isMultiB)
+#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
+    if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
    {
-        AsPointer p_as_grid_grp;
-        BsPointer p_bs_grid_grp;
+        // offset base pointer for each work-group
+        const index_t num_blocks_per_batch =
+            __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
+        const index_t g_idx =
+            __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
 
-        const auto& as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx);
+        const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
+            static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+        const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
 
-        static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
-        static_for<0, NumATensor, 1>{}(
-            [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; });
+        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
-        const auto& bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx);
+        DsPointer p_ds_grid_grp;
 
-        static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size();
-        static_for<0, NumBTensor, 1>{}(
-            [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; });
+        static constexpr index_t NumDTensor =
+            DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
 
-        GridwiseGemm::template Run(
-            p_as_grid_grp,
-            p_bs_grid_grp,
-            p_ds_grid_grp,
-            p_e_grid + e_batch_offset,
-            p_shared,
-            a_element_op,
-            b_element_op,
-            cde_element_op,
-            a_grid_desc_k0_m_k1,
-            b_grid_desc_k0_n_k1,
-            ds_grid_desc_mblock_mperblock_nblock_nperblock,
-            e_grid_desc_mblock_mperblock_nblock_nperblock_,
-            block_2_ctile_map);
-    }
-    else
-    {
-        const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-            static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-        const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-            static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
+        static_for<0, NumDTensor, 1>{}(
+            [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
 
-        GridwiseGemm::template Run(
-            p_as_grid + a_batch_offset,
-            p_bs_grid + b_batch_offset,
-            p_ds_grid_grp,
-            p_e_grid + e_batch_offset,
-            p_shared,
-            a_element_op,
-            b_element_op,
-            cde_element_op,
-            a_grid_desc_k0_m_k1,
-            b_grid_desc_k0_n_k1,
-            ds_grid_desc_mblock_mperblock_nblock_nperblock,
-            e_grid_desc_mblock_mperblock_nblock_nperblock_,
-            block_2_ctile_map);
+        if constexpr(isMultiA || isMultiB)
+        {
+            AsPointer p_as_grid_grp;
+            BsPointer p_bs_grid_grp;
+
+            const auto& as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx);
+
+            static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
+            static_for<0, NumATensor, 1>{}(
+                [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; });
+
+            const auto& bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx);
+
+            static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size();
+            static_for<0, NumBTensor, 1>{}(
+                [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; });
+
+            GridwiseGemm::template Run(
+                p_as_grid_grp,
+                p_bs_grid_grp,
+                p_ds_grid_grp,
+                p_e_grid + e_batch_offset,
+                p_shared,
+                a_element_op,
+                b_element_op,
+                cde_element_op,
+                a_grid_desc_k0_m_k1,
+                b_grid_desc_k0_n_k1,
+                ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                e_grid_desc_mblock_mperblock_nblock_nperblock_,
+                block_2_ctile_map);
+        }
+        else
+        {
+            const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
+                static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
+            const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
+                static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
+
+            GridwiseGemm::template Run(
+                p_as_grid + a_batch_offset,
+                p_bs_grid + b_batch_offset,
+                p_ds_grid_grp,
+                p_e_grid + e_batch_offset,
+                p_shared,
+                a_element_op,
+                b_element_op,
+                cde_element_op,
+                a_grid_desc_k0_m_k1,
+                b_grid_desc_k0_n_k1,
+                ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                e_grid_desc_mblock_mperblock_nblock_nperblock_,
+                block_2_ctile_map);
+        }
    }
 #else
    ignore = p_as_grid;
@@ -353,6 +357,9 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                                 ComputeDataType>
 {
    using DeviceOp = CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
+    GET_NXDL_PER_WAVE_IMPL
+    static constexpr auto NXdlPerWave64 = GetNXdlPerWave();
+    static constexpr auto NXdlPerWave32 = GetNXdlPerWave();
 
    static constexpr bool isMultiA = is_detected::value;
    static constexpr bool isMultiB = is_detected::value;
@@ -470,10 +477,13 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
        CDEBlockTransferScalarPerVector_NPerBlock, LoopSched
    // Use appropriate gridwise gemm
-    using GridwiseGemm = ck::conditional_t<
+    template
+    using GridwiseGemmBase = ck::conditional_t<
        isMultiA || isMultiB,
        GridwiseGemmMultipleABD_xdl_cshuffle,
        GridwiseGemmMultipleD_xdl_cshuffle>;
+    using GridwiseGemm64 = GridwiseGemmBase;
+    using GridwiseGemm32 = GridwiseGemmBase;
 
    // If ADataTypes or BDataTypes is tuple, user has to pass ck::Array with pointers.
    using APointers = ck::conditional_t&, const void*>;
    // Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not
    // in initializer list what is required for single const pointer).
    using AGridPointer = remove_cvref_t<
-        decltype(GetAGridPointer < isMultiA || isMultiB, GridwiseGemm, ADataType > ())>;
+        decltype(GetAGridPointer < isMultiA || isMultiB, GridwiseGemm64, ADataType > ())>;
    using BGridPointer = remove_cvref_t<
-        decltype(GetBGridPointer < isMultiA || isMultiB, GridwiseGemm, BDataType > ())>;
+        decltype(GetBGridPointer < isMultiA || isMultiB, GridwiseGemm64, BDataType > ())>;
 
    // desc for blockwise copy
    using AGridDesc_AK0_M_AK1 =
-        remove_cvref_t;
+        remove_cvref_t;
    using BGridDesc_BK0_N_BK1 =
-        remove_cvref_t;
+        remove_cvref_t;
    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
-        decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+        decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
            DsGridDesc_M_N{}))>;
-    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
-        remove_cvref_t;
+    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
+        remove_cvref_t;
 
    // block-to-e-tile map
    using Block2ETileMap =
-        remove_cvref_t;
+        remove_cvref_t;
 
    // Argument
    struct Argument
    {
+        template
+        __host__ __device__ void init_ds_e_grid_desc()
+        {
+            if constexpr(isMultiA || isMultiB)
+            {
+                const auto as_grid_desc_ak0_m_ak1 =
+                    generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number{});
+                const auto bs_grid_desc_bk0_n_bk1 =
+                    generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number{});
+
+                if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1,
+                                               bs_grid_desc_bk0_n_bk1,
+                                               ds_grid_desc_m_n_,
+                                               e_grid_desc_m_n_,
+                                               block_2_etile_map_))
+                {
+                    e_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                            e_grid_desc_m_n_);
+
+                    ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                            ds_grid_desc_m_n_);
+                }
+            }
+            else
+            {
+                if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
+                                               b_grid_desc_n_k_,
+                                               ds_grid_desc_m_n_,
+                                               e_grid_desc_m_n_,
+                                               block_2_etile_map_))
+                {
+                    e_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                            e_grid_desc_m_n_);
+
+                    ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                            ds_grid_desc_m_n_);
+                }
+            }
+        }
        __device__ __host__ Argument(
            APointers p_as,
            BPointers p_bs,
@@ -549,12 +602,12 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
          e_grid_desc_m_n_{
              DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)},
          a_grid_desc_ak0_m_ak1_{
-              GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
+              GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
          b_grid_desc_bk0_n_bk1_{
-              GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
+              GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
          ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
          e_grid_desc_mblock_mperblock_nblock_nperblock_{},
-          block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
+          block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
          compute_ptr_offset_of_batch_{},
          a_element_op_{a_element_op},
          b_element_op_{b_element_op},
@@ -655,43 +708,18 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
            compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
 
            // populate desc for Ds/E
-            if constexpr(isMultiA || isMultiB)
+            if(get_warp_size() == 64)
            {
-                const auto as_grid_desc_ak0_m_ak1 =
-                    generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number{});
-                const auto bs_grid_desc_bk0_n_bk1 =
-                    generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number{});
-
-                if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1,
-                                               bs_grid_desc_bk0_n_bk1,
-                                               ds_grid_desc_m_n_,
-                                               e_grid_desc_m_n_,
-                                               block_2_etile_map_))
+                if constexpr(NXdlPerWave64 > 0)
                {
-                    e_grid_desc_mblock_mperblock_nblock_nperblock_ =
-                        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                            e_grid_desc_m_n_);
-
-                    ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
-                        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                            ds_grid_desc_m_n_);
+                    init_ds_e_grid_desc();
                }
            }
            else
            {
-                if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
-                                               b_grid_desc_n_k_,
-                                               ds_grid_desc_m_n_,
-                                               e_grid_desc_m_n_,
-                                               block_2_etile_map_))
+                if constexpr(NXdlPerWave32 > 0)
                {
-                    e_grid_desc_mblock_mperblock_nblock_nperblock_ =
-                        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                            e_grid_desc_m_n_);
-
-                    ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
-                        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                            ds_grid_desc_m_n_);
+                    init_ds_e_grid_desc();
                }
            }
        }
@@ -700,7 +728,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
        // pointers (tuple if multi AB, pointer if no)
        AGridPointer p_as_grid_;
        BGridPointer p_bs_grid_;
-        typename GridwiseGemm::DsGridPointer p_ds_grid_;
+        typename GridwiseGemm64::DsGridPointer p_ds_grid_;
        EDataType* p_e_grid_;
 
        // tensor descriptors for problem definiton
@@ -746,7 +774,31 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
        ck::Array input_left_pads_;
        ck::Array input_right_pads_;
    };
-
+    template
+    static __device__ __host__ bool check_gemm_validity(const Argument& arg)
+    {
+        if constexpr(isMultiA || isMultiB)
+        {
+            // Generate tuples with the same descriptors
+            const auto as_grid_desc_ak0_m_ak1 =
+                generate_tuple([&](auto) { return arg.a_grid_desc_m_k_; }, Number{});
+            const auto bs_grid_desc_bk0_n_bk1 =
+                generate_tuple([&](auto) { return arg.b_grid_desc_n_k_; }, Number{});
+            return GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1,
+                                               bs_grid_desc_bk0_n_bk1,
+                                               arg.ds_grid_desc_m_n_,
+                                               arg.e_grid_desc_m_n_,
+                                               arg.block_2_etile_map_);
+        }
+        else
+        {
+            return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
+                                               arg.b_grid_desc_n_k_,
+                                               arg.ds_grid_desc_m_n_,
+                                               arg.e_grid_desc_m_n_,
+                                               arg.block_2_etile_map_);
+        }
+    }
    static __device__ __host__ bool IsSupportedArgument(const Argument& arg)
    {
        namespace ctc = tensor_layout::convolution;
@@ -898,27 +950,21 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
        }
 
        // check Gridwise GEMM
-        if constexpr(isMultiA || isMultiB)
+        if(get_warp_size() == 64)
        {
-            // Genarate tuples with the same descriptors
-            const auto as_grid_desc_ak0_m_ak1 =
-                generate_tuple([&](auto) { return arg.a_grid_desc_m_k_; }, Number{});
-            const auto bs_grid_desc_bk0_n_bk1 =
-                generate_tuple([&](auto) { return arg.b_grid_desc_n_k_; }, Number{});
-            return GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1,
-                                               bs_grid_desc_bk0_n_bk1,
-                                               arg.ds_grid_desc_m_n_,
-                                               arg.e_grid_desc_m_n_,
-                                               arg.block_2_etile_map_);
+            if constexpr(NXdlPerWave64 > 0)
+            {
+                return check_gemm_validity(arg);
+            }
        }
        else
        {
-            return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
-                                               arg.b_grid_desc_n_k_,
-                                               arg.ds_grid_desc_m_n_,
-                                               arg.e_grid_desc_m_n_,
-                                               arg.block_2_etile_map_);
+            if constexpr(NXdlPerWave32 > 0)
+            {
+                return check_gemm_validity(arg);
+            }
        }
+        return false;
    }
 
    static __device__ __host__ auto MakeArgument(
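The pattern this file introduces repeats across the remaining device ops: the gridwise GEMM is instantiated twice (wave64 and wave32 variants), and a runtime `get_warp_size()` check picks one, with `if constexpr` pruning the variant whose derived repeat count came out as 0 so invalid code is never compiled. Reduced to a toy (names are ours, not CK's):

```cpp
// Sketch of the dual-instantiation dispatch idiom used above.
template <int NXdlPerWave>
struct GridwiseToy
{
    static float Run() { return static_cast<float>(NXdlPerWave); }
};

template <int NXdlPerWave64, int NXdlPerWave32>
float dispatch(int warp_size)
{
    if(warp_size == 64)
    {
        if constexpr(NXdlPerWave64 > 0)      // pruned if wave64 cannot tile
            return GridwiseToy<NXdlPerWave64>::Run();
    }
    else
    {
        if constexpr(NXdlPerWave32 > 0)      // pruned if wave32 cannot tile
            return GridwiseToy<NXdlPerWave32>::Run();
    }
    return 0; // no valid instance for this warp width
}
```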
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
index f59ea3efde..be09d1b505 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
@@ -56,44 +56,48 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
    const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
    const Block2ETileMap block_2_etile_map)
 {
-#if defined(__gfx9__)
-    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
+    if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
+    {
+        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
-    const index_t num_blocks_per_batch =
-        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
-    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
+        const index_t num_blocks_per_batch =
+            __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
+        const index_t g_idx =
+            __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
 
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+        const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
+            static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
+        const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
+            static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
+        const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
+            static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
 
-    const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
+        const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
 
-    FloatDsPointer p_ds_grid_grp;
+        FloatDsPointer p_ds_grid_grp;
 
-    static constexpr index_t NumDTensor =
-        DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
+        static constexpr index_t NumDTensor =
+            DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
 
-    static_for<0, NumDTensor, 1>{}(
-        [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
+        static_for<0, NumDTensor, 1>{}(
+            [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
 
-    GridwiseGemm::template Run(
-        p_a_grid + a_batch_offset,
-        p_b_grid + b_batch_offset,
-        p_ds_grid_grp,
-        p_e_grid + e_batch_offset,
-        p_shared,
-        a_element_op,
-        b_element_op,
-        cde_element_op,
-        a_grid_desc_ak0_m_ak1,
-        b_grid_desc_bk0_n_bk1,
-        ds_grid_desc_mblock_mperblock_nblock_nperblock,
-        e_grid_desc_mblock_mperblock_nblock_nperblock,
-        block_2_etile_map);
+        GridwiseGemm::template Run(
+            p_a_grid + a_batch_offset,
+            p_b_grid + b_batch_offset,
+            p_ds_grid_grp,
+            p_e_grid + e_batch_offset,
+            p_shared,
+            a_element_op,
+            b_element_op,
+            cde_element_op,
+            a_grid_desc_ak0_m_ak1,
+            b_grid_desc_bk0_n_bk1,
+            ds_grid_desc_mblock_mperblock_nblock_nperblock,
+            e_grid_desc_mblock_mperblock_nblock_nperblock,
+            block_2_etile_map);
+    }
 #else
    ignore = p_a_grid;
    ignore = p_b_grid;
@@ -214,6 +218,10 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
 {
    using DeviceOp = DeviceBatchedContractionMultipleD_Xdl_CShuffle;
+    GET_NXDL_PER_WAVE_IMPL
+    static constexpr auto NXdlPerWave64 = GetNXdlPerWave();
+    static constexpr auto NXdlPerWave32 = GetNXdlPerWave();
+
    static constexpr index_t NumDTensor = DsDataType::Size();
 
    static constexpr auto I0 = Number<0>{};
@@ -546,7 +554,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
    using ComputeDataType = ADataType;
 
    // GridwiseGemm
-    using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
+    template
+    using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
        ADataType,
        BDataType,
        ComputeDataType,
@@ -567,7 +576,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
        MPerXDL,
        NPerXDL,
        MXdlPerWave,
-        NXdlPerWave,
+        NXdlPerWave_,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
@@ -589,28 +598,49 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEBlockTransferScalarPerVector_NPerBlock,
        LoopSched>;
+    using GridwiseGemm64 = GridwiseGemmBase;
+    using GridwiseGemm32 = GridwiseGemmBase;
 
    // desc for blockwise copy
    using AGridDesc_AK0_M_AK1 =
-        remove_cvref_t;
+        remove_cvref_t;
    using BGridDesc_BK0_N_BK1 =
-        remove_cvref_t;
+        remove_cvref_t;
    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
-        decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+        decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
            DsGridDesc_M_N{}))>;
-    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
-        remove_cvref_t;
+    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
+        remove_cvref_t;
 
    // block-to-e-tile map
    using Block2ETileMap =
-        remove_cvref_t;
+        remove_cvref_t;
 
    // Argument
    struct Argument : public BaseArgument
    {
+        template
+        void init_ds_e_grid_desc()
+        {
+            if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
+                                           b_grid_desc_n_k_,
+                                           ds_grid_desc_m_n_,
+                                           e_grid_desc_m_n_,
+                                           block_2_etile_map_))
+            {
+                e_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                    GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                        e_grid_desc_m_n_);
+
+                ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                    GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                        ds_grid_desc_m_n_);
+            }
+        }
+
        Argument(const void* p_a_grid,
                 const void* p_b_grid,
                 std::array p_ds_grid,
@@ -642,12 +672,12 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
          e_grid_desc_g_m_n_{
              DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)},
          a_grid_desc_ak0_m_ak1_{
-              GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
+              GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
          b_grid_desc_bk0_n_bk1_{
-              GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
+              GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
          ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
          e_grid_desc_mblock_mperblock_nblock_nperblock_{},
-          block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
+          block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
          a_element_op_{a_element_op},
          b_element_op_{b_element_op},
          cde_element_op_{cde_element_op},
@@ -677,19 +707,19 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
            });
 
            // populate desc for Ds/E
-            if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
-                                           b_grid_desc_n_k_,
-                                           ds_grid_desc_m_n_,
-                                           e_grid_desc_m_n_,
-                                           block_2_etile_map_))
+            if(get_warp_size() == 64)
            {
-                e_grid_desc_mblock_mperblock_nblock_nperblock_ =
-                    GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                        e_grid_desc_m_n_);
-
-                ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
-                    GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                        ds_grid_desc_m_n_);
+                if constexpr(NXdlPerWave64 > 0)
+                {
+                    init_ds_e_grid_desc();
+                }
+            }
+            else
+            {
+                if constexpr(NXdlPerWave32 > 0)
+                {
+                    init_ds_e_grid_desc();
+                }
            }
 
            // for sanity check of vector memory access
@@ -719,7 +749,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
        // pointers
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
-        typename GridwiseGemm::DsGridPointer p_ds_grid_;
+        typename GridwiseGemm64::DsGridPointer p_ds_grid_;
        EDataType* p_e_grid_;
 
        // tensor descriptors for problem definiton
@@ -767,7 +797,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
    {
        using Argument = DeviceOp::Argument;
 
-        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        template
+        float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
                                            arg.b_grid_desc_n_k_,
@@ -836,6 +867,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
                return launch_kernel(integral_constant{});
            }
        }
+        INVOKER_RUN_IMPL
 
        // polymorphic
        float Run(const BaseArgument* p_arg,
@@ -847,16 +879,35 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
    static bool IsSupportedArgument(const Argument& arg)
    {
-        if(!ck::is_xdl_supported())
+        if(!ck::is_xdl_wmma_supported())
        {
            return false;
        }
 
-        if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
-                                        arg.b_grid_desc_n_k_,
-                                        arg.ds_grid_desc_m_n_,
-                                        arg.e_grid_desc_m_n_,
-                                        arg.block_2_etile_map_))
+        bool valid = false;
+        if(get_warp_size() == 64)
+        {
+            if constexpr(NXdlPerWave64 > 0)
+            {
+                valid = GridwiseGemm64::CheckValidity(arg.a_grid_desc_m_k_,
+                                                      arg.b_grid_desc_n_k_,
+                                                      arg.ds_grid_desc_m_n_,
+                                                      arg.e_grid_desc_m_n_,
+                                                      arg.block_2_etile_map_);
+            }
+        }
+        else
+        {
+            if constexpr(NXdlPerWave32 > 0)
+            {
+                valid = GridwiseGemm32::CheckValidity(arg.a_grid_desc_m_k_,
+                                                      arg.b_grid_desc_n_k_,
+                                                      arg.ds_grid_desc_m_n_,
+                                                      arg.e_grid_desc_m_n_,
+                                                      arg.block_2_etile_map_);
+            }
+        }
+        if(!valid)
        {
            return false;
        }
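The kernel wrappers are now guarded with `if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())`, so an instantiation that is wrong for the current target (for example a wave64-shaped tile compiled for gfx11) still compiles, just to an empty kernel, instead of tripping the static asserts that were removed from the blockwise GEMMs. A toy analogue of that guard (`WAVE_OK` stands in for the real predicate):

```cpp
// Sketch: compile-time pruning of an invalid kernel instantiation (HIP/CUDA).
template <bool WAVE_OK>
__global__ void toy_kernel(float* out)
{
    if constexpr(WAVE_OK)
    {
        out[threadIdx.x] = 1.0f; // real work only in valid instantiations
    }
    // else: body elided at compile time; the symbol still exists and links
}
```

Host-side dispatch (the `INVOKER_RUN_IMPL` / `IsSupportedArgument` changes) is what guarantees the empty variant is never actually launched.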
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp
index 8a8cf54e42..3d232b50bb 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp
@@ -1,3 +1,5 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
 
 #include
@@ -74,34 +76,38 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
    const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
    const Block2ETileMap block_2_etile_map)
 {
-#if defined(__gfx9__)
-    const index_t num_blocks_per_batch =
-        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
-    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
+#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
+    if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
+    {
+        const index_t num_blocks_per_batch =
+            __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
+        const index_t g_idx =
+            __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
 
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
+        const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
+            static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
+        const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
+            static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
+        const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
+            static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
 
-    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
-    GridwiseGemm::template Run(
-        p_a_grid + a_batch_offset,
-        p_b_grid + b_batch_offset,
-        ck::Tuple<>{},
-        p_e_grid + e_batch_offset,
-        p_shared,
-        a_element_op,
-        b_element_op,
-        cde_element_op,
-        a_grid_desc_ak0_m_ak1,
-        b_grid_desc_bk0_n_bk1,
-        ck::Tuple<>{},
-        e_grid_desc_mblock_mperblock_nblock_nperblock,
-        block_2_etile_map);
+        GridwiseGemm::template Run(
+            p_a_grid + a_batch_offset,
+            p_b_grid + b_batch_offset,
+            ck::Tuple<>{},
+            p_e_grid + e_batch_offset,
+            p_shared,
+            a_element_op,
+            b_element_op,
+            cde_element_op,
+            a_grid_desc_ak0_m_ak1,
+            b_grid_desc_bk0_n_bk1,
+            ck::Tuple<>{},
+            e_grid_desc_mblock_mperblock_nblock_nperblock,
+            block_2_etile_map);
+    }
 #else
    ignore = p_a_grid;
    ignore = p_b_grid;
@@ -172,6 +178,10 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute
 {
+    GET_NXDL_PER_WAVE_IMPL
+    static constexpr auto NXdlPerWave64 = GetNXdlPerWave();
+    static constexpr auto NXdlPerWave32 = GetNXdlPerWave();
+
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
@@ -334,7 +344,8 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute
 
-    using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
+    template
+    using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
        ADataType,
        BDataType,
        ComputeDataType,
@@ -359,7 +370,7 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute
        MPerXDL,
        NPerXDL,
        MXdlPerWave,
-        NXdlPerWave,
+        NXdlPerWave_,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        LoopSched>;
+    using GridwiseGemm64 = GridwiseGemmBase;
+    using GridwiseGemm32 = GridwiseGemmBase;
 
    using AGridDesc_AK0_M_AK1 = remove_cvref_t{},
-                                           e_grid_desc_m_n_,
-                                           block_2_etile_map_))
-            {
-                e_grid_desc_mblock_mperblock_nblock_nperblock =
-                    GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                        e_grid_desc_m_n_);
-            }
        }
 
        void Print() const
@@ -499,7 +501,9 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute
-        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        template
+        float RunImp(const typename GridwiseGemm::Argument& arg,
+                     const StreamConfig& stream_config = StreamConfig{})
        {
            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
                                            arg.b_grid_desc_n_k_,
@@ -511,7 +515,9 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute
+        INVOKER_RUN_IMPL
 
-        if(!ck::is_xdl_supported())
+        if(!ck::is_xdl_wmma_supported())
        {
            return false;
        }
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp
index b23d864f5c..0ddcd63b2e 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
@@ -59,36 +59,40 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
    const index_t batch_count,
    const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
 {
-#if defined(__gfx9__)
-    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
-    const index_t num_blocks_per_batch =
-        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
-    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
+#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
+    if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
+    {
+        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+        const index_t num_blocks_per_batch =
+            __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
+        const index_t g_idx =
+            __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
 
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
-    const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
-    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
+        const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
+            static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
+        const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
+            static_cast(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
+        const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
+            static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
+        const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
+            static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
 
-    GridwiseGemm::template Run(p_a_grid + a_batch_offset,
-                               p_b_grid + b_batch_offset,
-                               p_b1_grid + b1_batch_offset,
-                               p_c_grid + c_batch_offset,
-                               p_shared,
-                               a_element_op,
-                               b_element_op,
-                               acc_element_op,
-                               b1_element_op,
-                               c_element_op,
-                               a_grid_desc_ak0_m_ak1,
-                               b_grid_desc_bk0_n_bk1,
-                               b1_grid_desc_bk0_n_bk1,
-                               c_grid_desc_mblock_mperblock_nblock_nperblock,
-                               block_2_ctile_map);
+        GridwiseGemm::template Run(p_a_grid + a_batch_offset,
+                                   p_b_grid + b_batch_offset,
+                                   p_b1_grid + b1_batch_offset,
+                                   p_c_grid + c_batch_offset,
+                                   p_shared,
+                                   a_element_op,
+                                   b_element_op,
+                                   acc_element_op,
+                                   b1_element_op,
+                                   c_element_op,
+                                   a_grid_desc_ak0_m_ak1,
+                                   b_grid_desc_bk0_n_bk1,
+                                   b1_grid_desc_bk0_n_bk1,
+                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
+                                   block_2_ctile_map);
+    }
 #else
    ignore = p_a_grid;
    ignore = p_b_grid;
@@ -185,6 +189,10 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm
 {
+    static constexpr auto MXdlPerWave64 = GetNXdlPerWave2();
+    static constexpr auto MXdlPerWave32 =
+        GetNXdlPerWave2();
+
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
@@ -346,7 +354,8 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm
 
-    using GridwiseGemm = GridwiseBatchedGemmGemm_Xdl_CShuffle<
+    template
+    using GridwiseGemmBase = GridwiseBatchedGemmGemm_Xdl_CShuffle<
        ADataType, // TODO: distinguish A/B datatype
        GemmAccDataType,
        CShuffleDataType,
@@ -373,7 +382,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm
        LoopSched>;
+    using GridwiseGemm64 = GridwiseGemmBase;
+    using GridwiseGemm32 = GridwiseGemmBase;
 
    // Argument
    struct Argument : public BaseArgument
@@ -440,8 +451,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm
-        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        template
+        float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(!DeviceOp::IsSupportedArgument(arg))
            {
                throw std::runtime_error("wrong! unsupported argument");
            }
-
+            auto c_grid_desc_mblock_mperblock_nblock_nperblock =
+                GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                    arg.c_grid_desc_m_n_);
            const index_t grid_size =
                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
@@ -551,7 +552,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        {
+            if(get_warp_size() == 64)
+            {
+                if constexpr(MXdlPerWave64 > 0)
+                {
+                    return RunImp(arg, stream_config);
+                }
+            }
+            else
+            {
+                if constexpr(MXdlPerWave32 > 0)
+                {
+                    return RunImp(arg, stream_config);
+                }
+            }
+            return 0;
+        }
 
        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
@@ -587,11 +605,10 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm
-        if(!ck::is_xdl_supported())
+        if(!ck::is_xdl_wmma_supported())
        {
            return false;
        }
-
        // Note: we need raw lengths since threadwise copy can not handle vector load when part of
        // vector is out of bounds
        const auto MRaw = arg.raw_lengths_m_n_k_o_[0];
@@ -617,11 +634,29 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm
-        return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
-                                           arg.b_grid_desc_bk0_n_bk1_,
-                                           arg.b1_grid_desc_bk0_n_bk1_,
-                                           arg.c_grid_desc_m_n_,
-                                           arg.block_2_ctile_map_);
+        if(get_warp_size() == 64)
+        {
+            if constexpr(MXdlPerWave64 > 0)
+            {
+                return GridwiseGemm64::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
+                                                     arg.b_grid_desc_bk0_n_bk1_,
+                                                     arg.b1_grid_desc_bk0_n_bk1_,
+                                                     arg.c_grid_desc_m_n_,
+                                                     arg.block_2_ctile_map_);
+            }
+        }
+        else
+        {
+            if constexpr(MXdlPerWave32 > 0)
+            {
+                return GridwiseGemm32::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
+                                                     arg.b_grid_desc_bk0_n_bk1_,
+                                                     arg.b1_grid_desc_bk0_n_bk1_,
+                                                     arg.c_grid_desc_m_n_,
+                                                     arg.block_2_ctile_map_);
+            }
+        }
+        return false;
    }
 
    // polymorphic
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
index 1f8c6b1508..9aff562744 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
@@ -82,45 +82,48 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
    const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
    const Block2ETileMap block_2_etile_map)
 {
+#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
+    if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
+    {
+        const index_t num_blocks_per_batch =
+            __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
+        const index_t g_idx =
+            __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
 
-#if defined(__gfx9__)
-    const index_t num_blocks_per_batch =
-        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
-    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
+        const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
+            static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
+        const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
+            static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
+        const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
+            static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
 
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+        const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
 
-    const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
+        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
-    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+        DsPointer p_ds_grid_grp;
 
-    DsPointer p_ds_grid_grp;
+        static constexpr index_t NumDTensor =
+            DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
 
-    static constexpr index_t NumDTensor =
-        DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
+        static_for<0, NumDTensor, 1>{}(
+            [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
 
-    static_for<0, NumDTensor, 1>{}(
-        [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
-
-    GridwiseGemm::template Run(
-        p_a_grid + a_batch_offset,
-        p_b_grid + b_batch_offset,
-        p_ds_grid_grp,
-        p_e_grid + e_batch_offset,
-        p_shared,
-        a_element_op,
-        b_element_op,
-        cde_element_op,
-        a_grid_desc_k0_m_k1,
-        b_grid_desc_k0_n_k1,
-        ds_grid_desc_mblock_mperblock_nblock_nperblock,
-        e_grid_desc_mblock_mperblock_nblock_nperblock_,
-        block_2_etile_map);
+        GridwiseGemm::template Run(
+            p_a_grid + a_batch_offset,
+            p_b_grid + b_batch_offset,
+            p_ds_grid_grp,
+            p_e_grid + e_batch_offset,
+            p_shared,
+            a_element_op,
+            b_element_op,
+            cde_element_op,
+            a_grid_desc_k0_m_k1,
+            b_grid_desc_k0_n_k1,
+            ds_grid_desc_mblock_mperblock_nblock_nperblock,
+            e_grid_desc_mblock_mperblock_nblock_nperblock_,
+            block_2_etile_map);
+    }
 #else
    ignore = p_a_grid;
    ignore = p_b_grid;
@@ -197,6 +200,10 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD
 {
+    GET_NXDL_PER_WAVE_IMPL
+    static constexpr auto NXdlPerWave64 = GetNXdlPerWave();
+    static constexpr auto NXdlPerWave32 = GetNXdlPerWave();
+
    static constexpr index_t NumDTensor = DsDataType::Size();
 
    static constexpr auto I0 = Number<0>{};
@@ -326,7 +333,8 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD
 
-    using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
+    template
+    using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
        ADataType, // TODO: distinguish A/B datatype
        BDataType,
        ComputeDataType,
@@ -347,7 +355,7 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD
        LoopSched>;
+    using GridwiseGemm64 = GridwiseGemmBase;
+    using GridwiseGemm32 = GridwiseGemmBase;
 
    // desc for blockwise copy
    using AGridDesc_AK0_M_AK1 =
-        remove_cvref_t;
+        remove_cvref_t;
    using BGridDesc_BK0_N_BK1 =
-        remove_cvref_t;
+        remove_cvref_t;
    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
-        decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+        decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
            DsGridDesc_M_N{}))>;
-    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
-        remove_cvref_t;
+    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
+        remove_cvref_t;
 
    // block-to-e-tile map
    using Block2ETileMap =
-        remove_cvref_t;
+        remove_cvref_t;
 
    // Argument
    struct Argument : public BaseArgument
    {
+        template
+        void init_ds_e_grid_desc()
+        {
+            if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
+                                           b_grid_desc_n_k_,
+                                           ds_grid_desc_m_n_,
+                                           e_grid_desc_m_n_,
+                                           block_2_etile_map_))
+            {
+                ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                    GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                        ds_grid_desc_m_n_);
+
+                e_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                    GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                        e_grid_desc_m_n_);
+            }
+        }
        Argument(const void* p_a_grid,
                 const void* p_b_grid,
                 std::array p_ds_grid,
@@ -420,13 +448,13 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD
          e_grid_desc_m_n_{(MRaw, NRaw, StrideE)},
          a_grid_desc_ak0_m_ak1_{
-              GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
+              GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
          b_grid_desc_bk0_n_bk1_{
-              GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
+              GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
          ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
          e_grid_desc_mblock_mperblock_nblock_nperblock_{},
          compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideDs, BatchStrideE},
-          block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
+          block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
          a_element_op_{a_element_op},
          b_element_op_{b_element_op},
          cde_element_op_{cde_element_op}
@@ -445,19 +473,19 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD
+            if(get_warp_size() == 64)
            {
+                if constexpr(NXdlPerWave64 > 0)
+                {
+                    init_ds_e_grid_desc();
+                }
+            }
+            else
+            {
+                if constexpr(NXdlPerWave32 > 0)
+                {
+                    init_ds_e_grid_desc();
+                }
            }
        }
@@ -474,7 +502,7 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD
-        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        template
+        float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
                                            arg.b_grid_desc_n_k_,
@@ -577,6 +606,8 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD
+        INVOKER_RUN_IMPL
 
-        if(!ck::is_xdl_supported())
+        if(!ck::is_xdl_wmma_supported())
        {
            return false;
        }
-
-        return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
-                                           arg.b_grid_desc_n_k_,
-                                           arg.ds_grid_desc_m_n_,
-                                           arg.e_grid_desc_m_n_,
-                                           arg.block_2_etile_map_);
+        if(get_warp_size() == 64)
+        {
+            if constexpr(NXdlPerWave64 > 0)
+            {
+                return GridwiseGemm64::CheckValidity(arg.a_grid_desc_m_k_,
+                                                     arg.b_grid_desc_n_k_,
+                                                     arg.ds_grid_desc_m_n_,
+                                                     arg.e_grid_desc_m_n_,
+                                                     arg.block_2_etile_map_);
+            }
+        }
+        else
+        {
+            if constexpr(NXdlPerWave32 > 0)
+            {
+                return GridwiseGemm32::CheckValidity(arg.a_grid_desc_m_k_,
+                                                     arg.b_grid_desc_n_k_,
+                                                     arg.ds_grid_desc_m_n_,
+                                                     arg.e_grid_desc_m_n_,
+                                                     arg.block_2_etile_map_);
+            }
+        }
+        return false;
    }
 
    // polymorphic
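On the host side, `IsSupportedArgument` now mirrors the invoker's dispatch: only the wave-width-matched gridwise GEMM's `CheckValidity` runs, and if neither per-wave instance is viable the argument is simply reported unsupported rather than asserting. A toy outline of that control flow (names and the callable parameters are ours, not CK's):

```cpp
// Sketch of the host-side support check used by the rewritten device ops.
template <int NXdlPerWave64, int NXdlPerWave32, class Check64, class Check32, class Arg>
bool is_supported(int warp_size, Check64 check64, Check32 check32, const Arg& arg)
{
    if(warp_size == 64)
    {
        if constexpr(NXdlPerWave64 > 0)
            return check64(arg); // wave64 gridwise GEMM's CheckValidity
    }
    else
    {
        if constexpr(NXdlPerWave32 > 0)
            return check32(arg); // wave32 gridwise GEMM's CheckValidity
    }
    return false; // no instance exists for this warp width
}
```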
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -33,7 +33,7 @@ template ()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); - const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); - static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) { - const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In))); - p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset; - }); + static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) { + const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In))); + p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset; + }); - static_for<0, p_d1s_grid.Size(), 1>{}([&](auto In) { - const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetD1BasePtr(g_idx, In))); - p_d1s_grid(In) = p_d1s_grid(In) + d1_batch_offset; - }); + static_for<0, p_d1s_grid.Size(), 1>{}([&](auto In) { + const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetD1BasePtr(g_idx, In))); + p_d1s_grid(In) = p_d1s_grid(In) + d1_batch_offset; + }); - GridwiseGemm::template Run(p_a0_grid + a_batch_offset, - p_b0_grid + b_batch_offset, - p_d0s_grid, - p_b1_grid + b1_batch_offset, - p_d1s_grid, - p_e1_grid + c_batch_offset, - p_shared, - a0_element_op, - b0_element_op, - cde0_element_op, - b1_element_op, - cde1_element_op, - a0_grid_desc_ak0_m_ak1, - b0_grid_desc_bk0_n_bk1, - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, - b1_grid_desc_bk0_n_bk1, - d1s_grid_desc_mblock_mperblock_nblock_nperblock, - e1_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_e1tile_map); + GridwiseGemm::template Run( + p_a0_grid + a_batch_offset, + p_b0_grid + b_batch_offset, + p_d0s_grid, + p_b1_grid + b1_batch_offset, + p_d1s_grid, + p_e1_grid + c_batch_offset, + p_shared, + a0_element_op, + b0_element_op, + cde0_element_op, + b1_element_op, + cde1_element_op, + a0_grid_desc_ak0_m_ak1, + b0_grid_desc_bk0_n_bk1, + d0s_griddesc_m_n, + b1_grid_desc_bk0_n_bk1, + d1s_grid_desc_mblock_mperblock_nblock_nperblock, + e1_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_e1tile_map); + } #else ignore = p_a0_grid; ignore = p_b0_grid; @@ 
-129,7 +133,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) ignore = cde1_element_op; ignore = a0_grid_desc_ak0_m_ak1; ignore = b0_grid_desc_bk0_n_bk1; - ignore = d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5; + ignore = d0s_griddesc_m_n; ignore = b1_grid_desc_bk0_n_bk1; ignore = d1s_grid_desc_mblock_mperblock_nblock_nperblock; ignore = e1_grid_desc_mblock_mperblock_nblock_nperblock; @@ -231,6 +235,21 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle { using DeviceOp = DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle; + static constexpr auto Gemm0MXdlPerWave64 = GetNXdlPerWave2(); + static constexpr auto Gemm0MXdlPerWave32 = GetNXdlPerWave2(); + static constexpr index_t NumD0Tensor = D0sDataType::Size(); static constexpr index_t NumD1Tensor = D1sDataType::Size(); @@ -443,7 +462,8 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle using E1GridDesc_M_N = decltype(MakeE1GridDescriptor_M_N(1, 1, 1)); // GridwiseGemm - using GridwiseGemm = GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< + template + using GridwiseGemmBase = GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0DataType, // TODO: distinguish A/B datatype Acc0DataType, D0sDataType, @@ -475,7 +495,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle B1K1, Gemm0MPerXdl, Gemm0NPerXdl, - Gemm0MXdlPerWave, + Gemm0MXdlPerWave_, Gemm0NXdlPerWave, Gemm1NXdlPerWave, A0BlockTransferThreadClusterLengths_AK0_M_AK1, @@ -509,15 +529,17 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using A0GridDesc_AK0_M_AK1 = - remove_cvref_t; using B0GridDesc_BK0_N_BK1 = - remove_cvref_t; using B1GridDesc_BK0_N_BK1 = - remove_cvref_t; // Argument @@ -565,15 +587,12 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle e1_grid_desc_m_n_{ DeviceOp::MakeE1GridDescriptor_M_N(MRaw, Gemm1NRaw, StrideE1)}, a0_grid_desc_ak0_m_ak1_{ - GridwiseGemm::MakeDefaultA0GridDescriptor_AK0_M_AK1(a0_grid_desc_m_k_)}, + GridwiseGemm64::MakeDefaultA0GridDescriptor_AK0_M_AK1(a0_grid_desc_m_k_)}, b0_grid_desc_bk0_n_bk1_{ - GridwiseGemm::MakeDefaultB0GridDescriptor_BK0_N_BK1(b0_grid_desc_n_k_)}, - d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{}, + GridwiseGemm64::MakeDefaultB0GridDescriptor_BK0_N_BK1(b0_grid_desc_n_k_)}, b1_grid_desc_bk0_n_bk1_{ - GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(b1_grid_desc_n_k_)}, - d1s_grid_desc_mblock_mperblock_nblock_nperblock_{}, - e1_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_e1tile_map_{GridwiseGemm::MakeDefaultBlock2E1TileMap(e1_grid_desc_m_n_)}, + GridwiseGemm64::MakeDefaultB1GridDescriptor_BK0_N_BK1(b1_grid_desc_n_k_)}, + block_2_e1tile_map_{GridwiseGemm64::MakeDefaultBlock2E1TileMap(e1_grid_desc_m_n_)}, a0_element_op_{a0_element_op}, b0_element_op_{b0_element_op}, cde0_element_op_{cde0_element_op}, @@ -597,18 +616,6 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle << ", " << d0s_grid_desc_m_n_[I0].GetLength(I1) << "}" << std::endl; std::cout << "b1_grid_desc_n_k_{" << b1_grid_desc_n_k_.GetLength(I0) << ", " << b1_grid_desc_n_k_.GetLength(I1) << "}" << std::endl; - std::cout << "d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{" - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I0) << ", " - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I1) << ", " - << 
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I2) << ", " - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I3) << ", " - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I4) << ", " - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I5) << ", " - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I6) << ", " - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I7) << ", " - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I8) << ", " - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I9) << "}" - << std::endl; std::cout << "e1_grid_desc_m_n_{" << e1_grid_desc_m_n_.GetLength(I0) << ", " << e1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } @@ -636,34 +643,15 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle d1s_grid_desc_m_n_(i) = DeviceOp::MakeE1GridDescriptor_M_N(MRaw, Gemm1NRaw, StrideD1s[i]); }); - - if(GridwiseGemm::CheckValidity(a0_grid_desc_m_k_, - b0_grid_desc_n_k_, - b1_grid_desc_n_k_, - e1_grid_desc_m_n_, - block_2_e1tile_map_)) - { - e1_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e1_grid_desc_m_n_); - - d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ = - GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5( - d0s_grid_desc_m_n_); - - d1s_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - d1s_grid_desc_m_n_); - } } // private: // pointers const A0DataType* p_a0_grid_; const B0DataType* p_b0_grid_; - typename GridwiseGemm::D0sGridPointer p_d0s_grid_; + typename GridwiseGemm64::D0sGridPointer p_d0s_grid_; const B1DataType* p_b1_grid_; - typename GridwiseGemm::D1sGridPointer p_d1s_grid_; + typename GridwiseGemm64::D1sGridPointer p_d1s_grid_; E1DataType* p_e1_grid_; // tensor descriptors for problem definiton @@ -677,16 +665,10 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle // tensor descriptors for block/thread-wise copy A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1_; B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1_; - typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 - d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; - typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - d1s_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e1_grid_desc_mblock_mperblock_nblock_nperblock_; // block-to-e1-tile map - typename GridwiseGemm::DefaultBlock2E1TileMap block_2_e1tile_map_; + typename GridwiseGemm64::DefaultBlock2E1TileMap block_2_e1tile_map_; // element-wise op A0ElementwiseOperation a0_element_op_; @@ -705,7 +687,8 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_, arg.b0_grid_desc_n_k_, @@ -716,6 +699,14 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); } + auto e1_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.e1_grid_desc_m_n_); + + auto d1s_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.d1s_grid_desc_m_n_); + const index_t grid_size = arg.block_2_e1tile_map_.CalculateGridSize(arg.e1_grid_desc_m_n_) * arg.batch_count_; @@ -736,7 +727,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle CDE1ElementwiseOperation, DeviceOp::A0GridDesc_AK0_M_AK1, DeviceOp::B0GridDesc_BK0_N_BK1, - typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, + DeviceOp::D0sGridDesc_M_N, DeviceOp::B1GridDesc_BK0_N_BK1, typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, @@ -762,10 +753,10 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle arg.cde1_element_op_, arg.a0_grid_desc_ak0_m_ak1_, arg.b0_grid_desc_bk0_n_bk1_, - arg.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, + arg.d0s_grid_desc_m_n_, arg.b1_grid_desc_bk0_n_bk1_, - arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.e1_grid_desc_mblock_mperblock_nblock_nperblock_, + d1s_grid_desc_mblock_mperblock_nblock_nperblock, + e1_grid_desc_mblock_mperblock_nblock_nperblock, arg.block_2_e1tile_map_, arg.batch_count_, arg.compute_base_ptr_of_batch_); @@ -783,6 +774,25 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle } } + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(get_warp_size() == 64) + { + if constexpr(Gemm0MXdlPerWave64 > 0) + { + return RunImp(arg, stream_config); + } + } + else + { + if constexpr(Gemm0MXdlPerWave32 > 0) + { + return RunImp(arg, stream_config); + } + } + return 0; + } + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -813,7 +823,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } @@ -836,11 +846,29 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle return false; } - return GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_, - arg.b0_grid_desc_n_k_, - arg.b1_grid_desc_n_k_, - arg.e1_grid_desc_m_n_, - arg.block_2_e1tile_map_); + if(get_warp_size() == 64) + { + if constexpr(Gemm0MXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a0_grid_desc_m_k_, + arg.b0_grid_desc_n_k_, + arg.b1_grid_desc_n_k_, + arg.e1_grid_desc_m_n_, + arg.block_2_e1tile_map_); + } + } + else + { + if constexpr(Gemm0MXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a0_grid_desc_m_k_, + arg.b0_grid_desc_n_k_, + arg.b1_grid_desc_n_k_, + arg.e1_grid_desc_m_n_, + arg.block_2_e1tile_map_); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp index 64d5fbd509..7c8c93cd91 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -37,35 +37,38 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif 
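// The kernels in this file now compile for gfx9, gfx11 and gfx12, and wrap
// their bodies in `if constexpr(GridwiseGemm::IsValidCompilationParameter())`:
// an instantiation whose tile configuration is illegal for the compile target
// collapses to an empty kernel instead of emitting invalid MFMA/WMMA code.
// A hedged sketch of the shape of that guard (the Run call is a placeholder,
// not the real signature):
//
//     template <typename GridwiseGemm, typename KArg>
//     __global__ void kernel_sketch(KArg karg)
//     {
//     #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
//         if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
//         {
//             __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
//             GridwiseGemm::Run(karg, p_shared); // placeholder call
//         }
//     #else
//         ignore = karg; // non-matching targets compile an empty body
//     #endif
//     }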
kernel_batched_gemm_xdl_cshuffle_v3_multi_d(BatchedGemmArg karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - const index_t g_idx = blockIdx.z % karg.Batch; - const index_t k_idx = blockIdx.z / karg.Batch; + const index_t g_idx = blockIdx.z % karg.Batch; + const index_t k_idx = blockIdx.z / karg.Batch; - const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); - const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); - const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); - const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); - // populate pointer, desc for Ds - static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { - // D pointer - karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i]; - }); + // populate pointer, desc for Ds + static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { + // D pointer + karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i]; + }); - GridwiseGemm::template Run( - karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid + c_batch_offset, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid + c_batch_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -83,39 +86,42 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds(BatchedGemmArg karg) { -#if defined(__gfx9__) - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - const index_t g_idx = blockIdx.z % karg.Batch; - const index_t k_idx = blockIdx.z / karg.Batch; + const index_t g_idx = blockIdx.z % karg.Batch; + const index_t k_idx 
= blockIdx.z / karg.Batch; - const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); - const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); - const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); - const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); - // populate pointer, desc for Ds - static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { - // D pointer - karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i]; - }); + // populate pointer, desc for Ds + static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { + // D pointer + karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i]; + }); - GridwiseGemm::template Run_2Lds( - karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid + c_batch_offset, - p_shared_0, - p_shared_1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid + c_batch_offset, + p_shared_0, + p_shared_1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -185,12 +191,17 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 BElementwiseOperation, CElementwiseOperation> { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr index_t NumDTensor = DsDataType::Size(); using CDEShuffleBlockTransferScalarPerVectors_ = CDEShuffleBlockTransferScalarPerVectors; using CDataType_ = CDataType; // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3< + template + using GridwiseGemmBase = GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, @@ -214,7 +225,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -241,6 +252,9 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ComputeTypeB, LDSTypeA, LDSTypeB>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm = GridwiseGemm64; struct ComputePtrOffsetOfStridedBatch { @@ -270,7 +284,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { std::array ds_offset_; - static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { + static_for<0, NumDTensor, 1>{}([&](auto i) { ds_offset_[i] = static_cast(BatchStrideDs_[i]) * g_idx; }); @@ -289,32 +303,33 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 index_t BatchStrideC_; }; - struct Argument : public GridwiseGemm::Argument + template + struct ArgumentBase : public 
GridwiseGemm::Argument { index_t Batch; ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch; - Argument() = default; - Argument(const ADataType* p_a_grid_, - const BDataType* p_b_grid_, - std::array p_ds_grid_, - CDataType* p_e_grid_, - index_t M_, - index_t N_, - index_t K_, - index_t StrideA_, - index_t StrideB_, - std::array StrideDs_, - index_t StrideE_, - index_t BatchStrideA_, - index_t BatchStrideB_, - const std::array& BatchStrideDs_, - index_t BatchStrideE_, - index_t Batch_, - AElementwiseOperation a_element_op_, - BElementwiseOperation b_element_op_, - CElementwiseOperation c_element_op_, - index_t KBatch_) + ArgumentBase() = default; + ArgumentBase(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + std::array p_ds_grid_, + CDataType* p_e_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideE_, + index_t BatchStrideA_, + index_t BatchStrideB_, + const std::array& BatchStrideDs_, + index_t BatchStrideE_, + index_t Batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_, + index_t KBatch_) : GridwiseGemm::Argument{p_a_grid_, p_b_grid_, p_ds_grid_, @@ -336,6 +351,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { } }; + using Argument = ArgumentBase; struct ActiveWorkgroupsPerCU { @@ -359,31 +375,69 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 } }(); - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + if(get_warp_size() == 64) { - hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( - &max_occupancy, - kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds< - GridwiseGemm, - Argument, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>, - BlockSize, - dynamic_smem_size)); + if constexpr(NXdlPerWave64 > 0) + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm64, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>, + BlockSize, + dynamic_smem_size)); + } + else + { + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm64, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>, + BlockSize, + dynamic_smem_size)); + } + } } else { - hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( - &max_occupancy, - kernel_batched_gemm_xdl_cshuffle_v3_multi_d< - GridwiseGemm, - Argument, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>, - BlockSize, - dynamic_smem_size)); + if constexpr(NXdlPerWave32 > 0) + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm32, + ArgumentBase, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>, + BlockSize, + dynamic_smem_size)); + } + else + { + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm32, + ArgumentBase, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>, + BlockSize, + dynamic_smem_size)); + } + } } max_occupancy_ = std::max(1, max_occupancy); @@ -394,14 +448,16 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 // Invoker 
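A note on the `ArgumentBase`/`reinterpret_cast` plumbing above and in the Invoker below: casting an `ArgumentBase` built against GridwiseGemm64 to the GridwiseGemm32 variant is only sound if the two instantiations are layout-identical, i.e. `GridwiseGemm::Argument` holds the same members for every NXdlPerWave. A guard one could state under that assumption (a sketch, not part of this diff):

// Assumed invariant behind the reinterpret_cast dispatch: both argument
// instantiations share one layout and differ only in the tag type.
static_assert(sizeof(ArgumentBase<GridwiseGemm64>) ==
                  sizeof(ArgumentBase<GridwiseGemm32>),
              "wave64/wave32 Argument types must be layout-compatible");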
struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { + using BatchGemmArgument = ArgumentBase; if(stream_config.log_level_ > 0) { arg.Print(); } - if(!GridwiseGemm::CheckValidity(arg)) + if(!GridwiseGemm::CheckValidity(reinterpret_cast(arg))) { throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); } @@ -423,7 +479,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 std::array DsSize; - Argument arg_ = arg; + BatchGemmArgument arg_ = reinterpret_cast(arg); const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); @@ -442,8 +498,12 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 using DDataType = remove_cvref_t>; DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); }); - ck::utility::RotatingMemWrapperMultiD rotating_mem( - arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + ck::utility::RotatingMemWrapperMultiD + rotating_mem(arg_, + stream_config.rotating_count, + size_a_buffer, + size_b_buffer, + DsSize); rotating_mem.Print(); auto run_flush_cache = [&]() { @@ -480,13 +540,14 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 stream_config.stream_id_)); }; - ave_time = launch_and_time_kernel_with_preprocess(stream_config, + BatchGemmArgument arg_ = reinterpret_cast(arg); + ave_time = launch_and_time_kernel_with_preprocess(stream_config, clear_workspace, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, - arg); + arg_); } }; @@ -515,7 +576,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy>; @@ -525,7 +586,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy>; @@ -541,7 +602,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -553,7 +614,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -567,7 +628,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -583,7 +644,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -599,7 +660,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -615,7 +676,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + 
BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -630,7 +691,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -646,7 +707,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -661,7 +722,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -673,7 +734,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -687,7 +748,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -703,7 +764,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -719,7 +780,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -735,7 +796,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -750,7 +811,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -766,7 +827,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -785,7 +846,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -796,7 +857,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -810,7 +871,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -821,7 +882,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -838,7 +899,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = 
kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -849,7 +910,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -863,7 +924,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -874,7 +935,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -893,7 +954,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, false, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy>; @@ -903,7 +964,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, - Argument, + BatchGemmArgument, false, InMemoryDataOperationEnum::Set, minimum_occupancy>; @@ -915,6 +976,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 return ave_time; } + INVOKER_RUN_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -931,11 +994,14 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) + { + return false; + } + if(is_gfx11_supported() && arg.KBatch > 1) { return false; } - if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) { return false; @@ -948,8 +1014,22 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 { return false; } - - return GridwiseGemm::CheckValidity(arg); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -1093,7 +1173,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 << "BlkGemmPipelineVersion: " << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " << "BlkGemmPipelinePrefetchStages: " - << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages; + << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages; // clang-format on return str.str(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp index ffebad253b..d2b77a5901 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -60,41 +60,46 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const ComputeBasePrtOfBatch compute_base_ptr_of_batch_, const Block2CTileMap block_2_ctile_map) { -#if defined(__gfx9__) - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); +#if defined(__gfx9__) || 
defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch_.GetABasePtr(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch_.GetBBasePtr(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch_.GetCBasePtr(g_idx))); + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetBBasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetCBasePtr(g_idx))); - static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) { - const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch_.GetDBasePtr(g_idx, In))); - p_reduces_grid(In) = p_reduces_grid(In) + d_batch_offset; - }); + static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) { + const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetDBasePtr(g_idx, In))); + p_reduces_grid(In) = p_reduces_grid(In) + d_batch_offset; + }); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_c_grid + c_batch_offset, - p_reduces_grid, - p_shared, - a_element_op, - b_element_op, - c_element_op, - reduce_in_element_ops, - reduce_out_element_ops, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - reduce_grid_desc_mblock_mperblock, - block_2_ctile_map); + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_reduces_grid, + p_shared, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, + block_2_ctile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -172,6 +177,9 @@ template { using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -517,7 +525,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO }; // GridwiseGemm - using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + template + using GridwiseGemmBase = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, CShuffleDataType, @@ -546,7 +555,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, 
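// Why the batch offsets in the kernel above go through
// __builtin_amdgcn_readfirstlane (a real AMDGPU compiler builtin): every lane
// computes the same per-batch offset, and broadcasting lane 0's value marks
// the result wave-uniform, so the compiler can keep it in scalar registers.
// The pattern, as used throughout this diff:
//
//     const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
//         static_cast<long_index_t>(compute_base_ptr_of_batch_.GetABasePtr(g_idx)));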
ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -571,6 +580,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -600,32 +611,18 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)}, - c_grid_desc_mblock_mperblock_nblock_nperblock_{}, - reduce_grid_desc_mblock_mperblock_{}, compute_base_ptr_of_batch_{ type_convert(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()), type_convert(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()), type_convert(c_grid_desc_m_n_.GetElementSpaceSize()), type_convert(reduce_grid_desc_m_.GetElementSpaceSize())}, - block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + block_2_ctile_map_{GridwiseGemm64::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op}, reduce_in_element_ops_{reduce_in_element_ops}, reduce_out_element_ops_{reduce_out_element_ops} { - if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, - b_grid_desc_bk0_n_bk1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n_); - - reduce_grid_desc_mblock_mperblock_ = - GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_); - } } // private: @@ -638,12 +635,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; CGridDesc_M_N c_grid_desc_m_n_; ReduceGridDesc_M reduce_grid_desc_m_; - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock - reduce_grid_desc_mblock_mperblock_; ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + typename GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map_; AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; CElementwiseOperation c_element_op_; @@ -656,8 +649,24 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.c_grid_desc_m_n_); + + auto reduce_grid_desc_mblock_mperblock = + GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(arg.reduce_grid_desc_m_); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { { @@ -680,15 +689,6 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO << arg.reduce_grid_desc_m_.GetLength(I0) << "}" << std::endl; } } - - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_)) - { - throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); - } - const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_; @@ -716,28 +716,27 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO typename GridwiseGemm::DefaultBlock2CTileMap, true>; - elapsed_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_reduces_grid_, - arg.Batch_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.reduce_in_element_ops_, - arg.reduce_out_element_ops_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.reduce_grid_desc_mblock_mperblock_, - arg.compute_base_ptr_of_batch_, - arg.block_2_ctile_map_); + elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_reduces_grid_, + arg.Batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + c_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, + arg.compute_base_ptr_of_batch_, + arg.block_2_ctile_map_); } else { @@ -759,33 +758,34 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO typename GridwiseGemm::DefaultBlock2CTileMap, false>; - elapsed_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_reduces_grid_, - arg.Batch_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.reduce_in_element_ops_, - arg.reduce_out_element_ops_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.reduce_grid_desc_mblock_mperblock_, - arg.compute_base_ptr_of_batch_, - arg.block_2_ctile_map_); + elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_reduces_grid_, + arg.Batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + c_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, + arg.compute_base_ptr_of_batch_, + arg.block_2_ctile_map_); } return elapsed_time; } + INVOKER_RUN_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -802,15 +802,31 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public 
DeviceGemmReduce<0, ReduceO static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - - return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index d835bb6c61..745cf2b722 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -36,7 +36,7 @@ template ()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); - const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); - static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) { - const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In))); - p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset; - }); + static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) { + const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In))); + p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset; + }); - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_b1_grid + b1_batch_offset, - p_c_grid + c_batch_offset, - p_d0s_grid, - p_shared, - a_element_op, - b_element_op, - c0de_element_op, - b1_element_op, - c1de_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - b1_grid_desc_bk0_n_bk1, - c1_grid_desc_mblock_mperblock_nblock_nperblock, - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, 
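// Note on the descriptor change in this kernel's parameter list: the blocked
// d0s descriptor (d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5) is replaced by
// the plain d0s_griddesc_m_n. The blocked form's type depends on the chosen
// gridwise-GEMM instantiation, so with two instantiations in play the device
// op passes the instantiation-independent M x N view instead (an inference
// from the surrounding hunks, not stated in the diff itself). The same move
// shows up in the host-side RunImp, sketched here:
//
//     template <typename GridwiseGemm>
//     float RunImp(const Argument& arg, const StreamConfig& sc)
//     {
//         // built per call rather than cached in Argument, because the
//         // descriptor type differs between GridwiseGemm64 and GridwiseGemm32
//         auto c_desc = GridwiseGemm::MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
//             arg.c1_grid_desc_m_n_);
//         // ... launch with c_desc ...
//     }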
- block_2_ctile_map, - c0_matrix_mask); + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_b1_grid + b1_batch_offset, + p_c_grid + c_batch_offset, + p_d0s_grid, + p_shared, + a_element_op, + b_element_op, + c0de_element_op, + b1_element_op, + c1de_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b1_grid_desc_bk0_n_bk1, + c1_grid_desc_mblock_mperblock_nblock_nperblock, + d0s_griddesc_m_n, + block_2_ctile_map, + c0_matrix_mask); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -122,7 +126,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) ignore = b_grid_desc_bk0_n_bk1; ignore = b1_grid_desc_bk0_n_bk1; ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5; + ignore = d0s_griddesc_m_n; ignore = block_2_ctile_map; ignore = batch_count; ignore = compute_base_ptr_of_batch; @@ -218,6 +222,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle C1DEElementwiseOperation, MaskingSpec> { + static constexpr auto MXdlPerWave64 = + GetNXdlPerWave2(); + static constexpr auto MXdlPerWave32 = + GetNXdlPerWave2(); + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, "Number of dimension must be greater than 0"); @@ -377,7 +386,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle }; // GridwiseGemm - using GridwiseGemm = GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< + template + using GridwiseGemmBase = GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, CShuffleDataType, @@ -406,7 +416,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle B1K1, MPerXDL, NPerXDL, - MXdlPerWave, + MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, @@ -441,6 +451,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle Transform::matrix_padder.PadN, MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument // FIXME: constness @@ -485,6 +497,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, c1_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, c_gs_ms_gemm1ns_strides)}, + d0s_grid_desc_m_n_{DeviceOp::MakeD0sGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths, + acc0_biases_gs_ms_ns_strides)}, a_grid_desc_g_m_k_{ Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, b_grid_desc_g_n_k_{ @@ -495,9 +509,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle c_gs_ms_gemm1ns_strides)}, d0s_grid_desc_g_m_n_{DeviceOp::MakeD0sGridDescriptor_G_M_N( acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)}, - c1_grid_desc_mblock_mperblock_nblock_nperblock_{}, - d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{}, - block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c1_grid_desc_m_n_)}, + block_2_ctile_map_{GridwiseGemm64::MakeDefaultBlock2CTileMap(c1_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, c0de_element_op_{c0de_element_op}, @@ -538,23 +550,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle d0s_nl_ns_lengths_strides_[i].push_back( acc0_biases_gs_ms_ns_strides[i][NumDimG + NumDimM]); }); - - if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, - b_grid_desc_bk0_n_bk1_, - b1_grid_desc_bk0_n_bk1_, - c1_grid_desc_m_n_, - 
block_2_ctile_map_)) - { - c1_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c1_grid_desc_m_n_); - - D0sGridDesc_M_N d0s_grid_desc_m_n{DeviceOp::MakeD0sGridDescriptor_M_N( - acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)}; - d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ = - GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5( - d0s_grid_desc_m_n); - } } void Print() const @@ -578,26 +573,22 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle const BDataType* p_b_grid_; const B1DataType* p_b1_grid_; CDataType* p_c_grid_; - typename GridwiseGemm::D0sGridPointer p_d0s_grid_; + typename GridwiseGemm64::D0sGridPointer p_d0s_grid_; // tensor descriptor AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; C1GridDesc_M_N c1_grid_desc_m_n_; + D0sGridDesc_M_N d0s_grid_desc_m_n_; AGridDesc_G_M_K a_grid_desc_g_m_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_; C1GridDesc_G_M_N c1_grid_desc_g_m_n_; D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_; - typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c1_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 - d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_; - // block-to-c-tile map - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + typename GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map_; // element-wise op AElementwiseOperation a_element_op_; @@ -626,12 +617,16 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!DeviceOp::IsSupportedArgument(arg)) { throw std::runtime_error("wrong! 
unsupported argument"); } + auto c1_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.c1_grid_desc_m_n_); const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.c1_grid_desc_m_n_) * arg.batch_count_; @@ -657,7 +652,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::B1GridDesc_BK0_N_BK1, typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, - typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, + D0sGridDesc_M_N, typename GridwiseGemm::DefaultBlock2CTileMap, ComputeBasePtrOfStridedBatch, C0MatrixMask, @@ -681,8 +676,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_, - arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, + c1_grid_desc_mblock_mperblock_nblock_nperblock, + arg.d0s_grid_desc_m_n_, arg.block_2_ctile_map_, arg.batch_count_, arg.compute_base_ptr_of_batch_, @@ -703,6 +698,25 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle return ave_time; } + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(get_warp_size() == 64) + { + if constexpr(MXdlPerWave64 > 0) + { + return RunImp(arg, stream_config); + } + } + else + { + if constexpr(MXdlPerWave32 > 0) + { + return RunImp(arg, stream_config); + } + } + return 0; + } + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -724,11 +738,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle arg.Print(); } - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - // TODO ANT: Check if tensor specialization & strides mismatch // Check if C permute dimension matches GEMM + GEMM shape @@ -792,12 +805,29 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle return false; } } - - return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.b1_grid_desc_bk0_n_bk1_, - arg.c1_grid_desc_m_n_, - arg.block_2_ctile_map_); + if(get_warp_size() == 64) + { + if constexpr(MXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c1_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + else + { + if constexpr(MXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c1_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp index 1345d2b782..77379c8fb1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -1,16 +1,18 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
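The include shuffle at the head of this file is the hipRTC-compatibility half of the change: host-only headers stay behind `#ifndef __HIPCC_RTC__`, while `ck/utility/common_header.hpp` moves out of that guard so runtime-compiled kernels see it too. The resulting layout, sketched with placeholder names for the two standard-library includes whose names were lost above:

#ifndef __HIPCC_RTC__
#include <sstream>  // placeholder: host-only standard headers
#include <typeinfo> // placeholder
#endif

#include "ck/utility/common_header.hpp" // visible to both host and hipRTC builds

#ifndef __HIPCC_RTC__
#include "ck/host_utility/device_prop.hpp"   // host-only
#include "ck/host_utility/kernel_launch.hpp" // host-only
#endif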
#pragma once - #ifndef __HIPCC_RTC__ #include #include +#endif + +#include "ck/utility/common_header.hpp" +#ifndef __HIPCC_RTC__ #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #endif -#include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -64,37 +66,41 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const C0MatrixMask c0_matrix_mask) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); - const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_b1_grid + b1_batch_offset, - p_c_grid + c_batch_offset, - p_shared, - a_element_op, - b_element_op, - acc_element_op, - b1_element_op, - c_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - b1_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_ctile_map, - c0_matrix_mask); + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_b1_grid + b1_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b1_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_ctile_map, + c0_matrix_mask); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -202,7 +208,12 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle CElementwiseOperation, MaskOutUpperTriangle> { + using DeviceOp = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle; + static constexpr auto MXdlPerWave64 = + GetNXdlPerWave2(); + static constexpr auto MXdlPerWave32 = + 
GetNXdlPerWave2(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -369,7 +380,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle C0MatrixMask_impl>; // GridwiseGemm - using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< + template + using GridwiseGemmBase = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, CShuffleDataType, @@ -396,7 +408,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle B1K1, MPerXDL, NPerXDL, - MXdlPerWave, + MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, @@ -430,6 +442,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle LoopSched, matrix_padder.PadN, MaskOutUpperTriangle>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; #ifndef __HIPCC_RTC__ // Argument @@ -466,8 +480,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle b1_grid_desc_bk0_n_bk1_{ DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(NRaw, Gemm1NRaw, StrideB1)}, c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, Gemm1NRaw, StrideC)}, - c_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + block_2_ctile_map_{GridwiseGemm64::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, acc_element_op_{acc_element_op}, @@ -478,16 +491,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle c0_matrix_mask_{NRaw}, raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw} { - if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, - b_grid_desc_bk0_n_bk1_, - b1_grid_desc_bk0_n_bk1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n_); - } } // private: @@ -499,9 +502,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; CGridDesc_M_N c_grid_desc_m_n_; - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + typename GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map_; AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; AccElementwiseOperation acc_element_op_; @@ -522,7 +523,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, @@ -532,7 +534,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle { throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); } - + auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.c_grid_desc_m_n_); const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_; @@ -578,7 +582,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + c_grid_desc_mblock_mperblock_nblock_nperblock, arg.block_2_ctile_map_, arg.batch_count_, arg.compute_base_ptr_of_batch_, @@ -599,6 +603,25 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle return ave_time; } + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(get_warp_size() == 64) + { + if constexpr(MXdlPerWave64 > 0) + { + return RunImp(arg, stream_config); + } + } + else + { + if constexpr(MXdlPerWave32 > 0) + { + return RunImp(arg, stream_config); + } + } + return 0; + } + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -707,11 +730,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle #ifndef __HIPCC_RTC__ static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - // Note: we need raw lengths since threadwise copy can not handle vector load when part of // vector is out of bounds const auto MRaw = arg.raw_lengths_m_n_k_o_[0]; @@ -719,12 +741,31 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle const auto KRaw = arg.raw_lengths_m_n_k_o_[2]; const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3]; - return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.b1_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_) and - IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw); + if(get_warp_size() == 64) + { + if constexpr(MXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_) and + IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw); + } + } + else + { + if constexpr(MXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_) and + IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw); + } + } + return false; } // polymorphic @@ -916,7 +957,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle using CGridDesc_M_N = remove_cvref_t; // GridwiseGemm - using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< + template + using GridwiseGemmBase = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, CShuffleDataType, @@ -943,7 +985,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle B1K1, MPerXDL, NPerXDL, - MXdlPerWave, + MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, @@ -978,13 +1020,16 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle matrix_padder.PadN, MaskOutUpperTriangle>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1; CGridDesc_M_N c_grid_desc_m_n; C0MatrixMask c0_matrix_mask; - typename 
GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map; - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + typename GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map; + typename GridwiseGemm64::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_descriptor_mblock_mperblock_nblock_nperblock; // element-wise op @@ -1008,30 +1053,54 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)}, b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)}, c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)}, - block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)}, + block_2_ctile_map{GridwiseGemm64::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)}, c_grid_descriptor_mblock_mperblock_nblock_nperblock{ - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + GridwiseGemm64::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( c_grid_desc_m_n)}, - has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop( + has_main_k_block_loop{GridwiseGemm64::CalculateHasMainKBlockLoop( a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))}, c0_matrix_mask{c.GetLength(I1)}, a_element_op{a_element_op_}, b_element_op{b_element_op_}, b1_element_op{b1_element_op_}, c_element_op{c_element_op_}, - is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - b1_grid_desc_bk0_n_bk1, - c_grid_desc_m_n, - block_2_ctile_map) and - IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1), - b_grid_desc_bk0_n_bk1.GetLength(I1), - a_grid_desc_ak0_m_ak1.GetLength(I0) * - a_grid_desc_ak0_m_ak1.GetLength(I2), - b1_grid_desc_bk0_n_bk1.GetLength(I1))} + is_valid{false} { + if(get_warp_size() == 64) + { + if constexpr(MXdlPerWave64 > 0) + { + is_valid = GridwiseGemm64::CheckValidity(a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b1_grid_desc_bk0_n_bk1, + c_grid_desc_m_n, + block_2_ctile_map) and + IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1), + b_grid_desc_bk0_n_bk1.GetLength(I1), + a_grid_desc_ak0_m_ak1.GetLength(I0) * + a_grid_desc_ak0_m_ak1.GetLength(I2), + b1_grid_desc_bk0_n_bk1.GetLength(I1)) and + GridwiseGemm64::template IsValidCompilationParameter<>(); + } + } + else + { + if constexpr(MXdlPerWave32 > 0) + { + is_valid = GridwiseGemm32::CheckValidity(a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b1_grid_desc_bk0_n_bk1, + c_grid_desc_m_n, + block_2_ctile_map) and + IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1), + b_grid_desc_bk0_n_bk1.GetLength(I1), + a_grid_desc_ak0_m_ak1.GetLength(I0) * + a_grid_desc_ak0_m_ak1.GetLength(I2), + b1_grid_desc_bk0_n_bk1.GetLength(I1)) and + GridwiseGemm32::template IsValidCompilationParameter<>(); + } + } } - constexpr bool IsValid() const { return is_valid; } }; @@ -1061,12 +1130,15 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle #ifndef __HIPCC_RTC__ assert(desc.is_valid); #endif - __shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()]; + using GridwiseGemm = conditional_t; + __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()]; AccElementwiseOperation acc_element_op{scale}; if(desc.has_main_k_block_loop) { - Desc::GridwiseGemm::template Run( + GridwiseGemm::template Run( p_a_grid, p_b_grid, p_b1_grid, @@ -1086,7 +1158,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle } else { - Desc::GridwiseGemm::template Run( + GridwiseGemm::template Run( p_a_grid, p_b_grid, p_b1_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp index d3f067f170..b8f03e742f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -52,36 +52,40 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg) { -#if defined(__gfx9__) - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / karg.Batch); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / karg.Batch); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - const auto a_grid_desc_k0_m_k1 = - amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1( - karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA)); - const auto b_grid_desc_k0_n_k1 = - amd_wave_read_first_lane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1( - karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB)); - const auto c_grid_desc_m_n = amd_wave_read_first_lane(GridwiseGemm::MakeCGridDescriptor_M_N( - karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC)); + const auto a_grid_desc_k0_m_k1 = + amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1( + karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA)); + const auto b_grid_desc_k0_n_k1 = + amd_wave_read_first_lane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1( + karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB)); + const auto c_grid_desc_m_n = amd_wave_read_first_lane(GridwiseGemm::MakeCGridDescriptor_M_N( + karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC)); - GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + c_batch_offset, - p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_grid_desc_m_n); + GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + c_batch_offset, + p_shared, + a_grid_desc_k0_m_k1, + 
b_grid_desc_k0_n_k1, + c_grid_desc_m_n); + } #else ignore = karg; #endif @@ -135,6 +139,10 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -172,7 +180,8 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext< BlockSize, ADataType, // TODO: distinguish A/B datatype AccDataType, @@ -192,7 +201,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Problem = typename GridwiseGemm::Problem; + using Problem = typename GridwiseGemm64::Problem; // Argument struct Argument : public Problem, public BaseArgument @@ -255,14 +266,17 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm + float RunImp(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { karg.Print(); } - if(!GridwiseGemm::CheckValidity(karg)) + typename GridwiseGemm::Problem arg( + karg.M, karg.N, karg.K, karg.StrideA, karg.StrideB, karg.StrideC); + if(!GridwiseGemm::CheckValidity(arg)) { throw std::runtime_error( "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext has invalid setting"); @@ -293,6 +307,8 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm()) { return false; } - - return GridwiseGemm::CheckValidity(problem); + // temp disable on gfx11 + if(ck::is_gfx11_supported()) + { + return false; + } + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(problem); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(problem)); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp index 459ebd7f35..d19810694b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp @@ -37,27 +37,30 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif kernel_batched_gemm_b_scale_xdl_cshuffle_v3(BatchedGemmArg karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - const index_t g_idx = blockIdx.z % karg.Batch; - const index_t k_idx = blockIdx.z / karg.Batch; + const index_t g_idx = blockIdx.z % karg.Batch; + const index_t k_idx = blockIdx.z / karg.Batch; - const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); - const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); - const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); - const auto b_scale_batch_offset = karg.compute_ptr_offset_of_batch.GetSacleBPtrOffset(g_idx); + const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const auto b_batch_offset = 
karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + const auto b_scale_batch_offset = + karg.compute_ptr_offset_of_batch.GetSacleBPtrOffset(g_idx); - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); - - GridwiseGemm::template Run( - karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, - karg.p_c_grid + c_batch_offset + splitk_batch_offset.c_reduce_offset, - karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset, - p_shared, - karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + c_batch_offset + splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset, + p_shared, + karg); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -75,31 +78,34 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds(BatchedGemmArg karg) { -#if defined(__gfx9__) - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - const index_t g_idx = blockIdx.z % karg.Batch; - const index_t k_idx = blockIdx.z / karg.Batch; + const index_t g_idx = blockIdx.z % karg.Batch; + const index_t k_idx = blockIdx.z / karg.Batch; - const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); - const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); - const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); - const auto b_scale_batch_offset = karg.compute_ptr_offset_of_batch.GetSacleBPtrOffset(g_idx); + const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + const auto b_scale_batch_offset = + karg.compute_ptr_offset_of_batch.GetSacleBPtrOffset(g_idx); - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); - - GridwiseGemm::template Run_2Lds( - karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, - karg.p_c_grid + c_batch_offset + splitk_batch_offset.c_reduce_offset, - karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset, - p_shared_0, - p_shared_1, - karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); + 
GridwiseGemm::template Run_2Lds( + karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + c_batch_offset + splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset, + p_shared_0, + p_shared_1, + karg); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -171,8 +177,13 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale BElementwiseOperation, CElementwiseOperation> { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + // GridwiseGemm - using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< + template + using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, @@ -196,7 +207,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -223,6 +234,8 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale ComputeTypeB, PermuteA, PermuteB>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; static constexpr index_t APackedSize = []() { if constexpr(is_same_v, pk_i4_t>) @@ -277,31 +290,32 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale index_t BatchStrideScaleB_; }; - struct Argument : public GridwiseGemm::Argument + template + struct ArgumentBase : public GridwiseGemm::Argument { index_t Batch; ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch; - Argument(const ADataType* p_a_grid_, - const BDataType* p_b_grid_, - CDataType* p_c_grid_, - index_t M_, - index_t N_, - index_t K_, - index_t StrideA_, - index_t StrideB_, - index_t StrideC_, - index_t StrideScaleB_, - index_t BatchStrideA_, - index_t BatchStrideB_, - index_t BatchStrideC_, - index_t BatchStrideScaleB_, - const BScaleDataType* p_b_scale_grid_, - index_t Batch_, - index_t KBatch_, - AElementwiseOperation a_element_op_, - BElementwiseOperation b_element_op_, - CElementwiseOperation c_element_op_) + ArgumentBase(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t StrideScaleB_, + index_t BatchStrideA_, + index_t BatchStrideB_, + index_t BatchStrideC_, + index_t BatchStrideScaleB_, + const BScaleDataType* p_b_scale_grid_, + index_t Batch_, + index_t KBatch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) : GridwiseGemm::Argument(p_a_grid_, p_b_grid_, p_c_grid_, @@ -323,12 +337,16 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { } }; + using Argument = ArgumentBase; // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const ArgumentBase& arg, + const StreamConfig& stream_config = StreamConfig{}) { + using DeviceArgument = ArgumentBase; if(stream_config.log_level_ > 0) { arg.Print(); @@ -353,7 +371,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale const auto Run = [&](const auto& kernel) { if(stream_config.flush_cache) { - Argument arg_ = arg; + DeviceArgument arg_ = arg; const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); 
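The flush-cache timing path being patched in these hunks keeps several replicas of the A and B tensors and re-points a copy of the kernel argument at a fresh replica before each timed launch, so repeated runs read cold memory instead of a warm L2. A plain-C++ sketch of that rotation is below; RotatingBuffers and next_copy are hypothetical names, and the real RotatingMemWrapper additionally rebuilds the grid descriptors around the rotated pointers.

// Sketch of rotating-buffer timing: N replicas of each input, advanced one
// slot per launch so the cache cannot serve the same bytes twice in a row.
// RotatingBuffers/next_copy are hypothetical, not the CK API.
#include <cstddef>
#include <vector>

struct RotatingBuffers
{
    std::vector<const float*> a_copies; // device pointers to replicas of A
    std::vector<const float*> b_copies; // device pointers to replicas of B
    std::size_t slot = 0;

    // Select the replica pair for the next timed launch.
    void next_copy(const float*& p_a, const float*& p_b)
    {
        p_a = a_copies[slot % a_copies.size()];
        p_b = b_copies[slot % b_copies.size()];
        ++slot;
    }
};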
@@ -365,7 +383,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) / BPackedSize; - ck::utility::RotatingMemWrapper rotating_mem( + ck::utility::RotatingMemWrapper rotating_mem( arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); rotating_mem.Print(); @@ -422,7 +440,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy>; @@ -432,7 +450,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy>; @@ -448,7 +466,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -460,7 +478,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -474,7 +492,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -490,7 +508,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -506,7 +524,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -522,7 +540,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -537,7 +555,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -553,7 +571,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -568,7 +586,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -580,7 +598,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -594,7 +612,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -610,7 +628,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = 
kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -626,7 +644,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -642,7 +660,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -657,7 +675,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -673,7 +691,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -692,7 +710,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -703,7 +721,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -717,7 +735,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -728,7 +746,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -745,7 +763,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -756,7 +774,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -770,7 +788,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -781,7 +799,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -800,7 +818,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, false, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy>; @@ -810,7 +828,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< GridwiseGemm, - Argument, + DeviceArgument, false, InMemoryDataOperationEnum::Set, minimum_occupancy>; @@ -822,6 +840,27 @@ struct 
DeviceBatchedGemm_Xdl_CShuffleV3_BScale return ave_time; } + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return RunImp(arg, stream_config); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + using Argument32 = ArgumentBase; + return RunImp(reinterpret_cast(arg), + stream_config); + } + } + return 0; + } + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -838,11 +877,14 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) + { + return false; + } + if(is_gfx11_supported() && arg.KBatch > 1) { return false; } - if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) { return false; @@ -855,8 +897,22 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale { return false; } - - return GridwiseGemm::CheckValidity(arg); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + using Argument32 = ArgumentBase; + return GridwiseGemm32::CheckValidity(reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -1003,7 +1059,7 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale << "BlkGemmPipelineVersion: " << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " << "BlkGemmPipelinePrefetchStages: " - << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages; + << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages; // clang-format on return str.str(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp index 4934993693..4aa6c18d04 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
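The cgemm file whose header starts above computes a complex GEMM as four real GEMMs plus two elementwise combines, which is exactly what the patched Invoker below launches: aux = Re(A)Re(B) and aux_2 = Im(A)Im(B) give c_real = aux - aux_2, then aux = Re(A)Im(B) and aux_2 = Im(A)Re(B) give c_imag = aux + aux_2. A scalar reference of that decomposition, assuming row-major packed layouts and intended only for checking the arithmetic:

// Naive reference for the 4-GEMM complex decomposition used by
// DeviceCGemm_4Gemm_Xdl_CShuffle: C = A * B with A, B, C held as separate
// real/imaginary planes. Row-major, unit strides; reference only.
#include <vector>

void cgemm_4gemm_ref(int M, int N, int K,
                     const std::vector<float>& a_re, const std::vector<float>& a_im,
                     const std::vector<float>& b_re, const std::vector<float>& b_im,
                     std::vector<float>& c_re, std::vector<float>& c_im)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float rr = 0, ii = 0, ri = 0, ir = 0;
            for(int k = 0; k < K; ++k)
            {
                rr += a_re[m * K + k] * b_re[k * N + n]; // aux
                ii += a_im[m * K + k] * b_im[k * N + n]; // aux_2
                ri += a_re[m * K + k] * b_im[k * N + n];
                ir += a_im[m * K + k] * b_re[k * N + n];
            }
            c_re[m * N + n] = rr - ii; // c_real = aux - aux_2
            c_im[m * N + n] = ri + ir; // c_imag = aux + aux_2
        }
}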
#pragma once @@ -75,6 +75,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle : public DeviceCGemm { using DeviceOp = DeviceCGemm_4Gemm_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -118,7 +121,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle } // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + template + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ALayout, BLayout, CLayout, @@ -142,7 +146,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -164,13 +168,15 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using CGridDesc_M_N = decltype(MakeDescriptor_M_N({1, 1}, {1, 1})); // Argument - struct Argument : public tensor_operation::device::BaseArgument, public GridwiseGemm::Problem + struct Argument : public tensor_operation::device::BaseArgument, public GridwiseGemm64::Problem { - using Problem = typename GridwiseGemm::Problem; + using Problem = typename GridwiseGemm64::Problem; Argument(const ADataType* p_a_grid_real_, const ADataType* p_a_grid_imag_, @@ -221,14 +227,17 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { arg.Print(); } - if(!GridwiseGemm::CheckValidity(arg)) + typename GridwiseGemm::Problem problem( + arg.M, arg.N, arg.K, arg.StrideA, arg.StrideB, arg.StrideC); + if(!GridwiseGemm::CheckValidity(problem)) { throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); } @@ -317,7 +326,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.p_a_grid_real, arg.p_b_grid_real, arg.p_aux_grid, - arg); + problem); ave_time += launch_and_time_kernel(stream_config, kernel, @@ -327,7 +336,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.p_a_grid_imag, arg.p_b_grid_imag, arg.p_aux_2_grid, - arg); + problem); // c_real = aux - aux_2 ave_time += launch_and_time_kernel( @@ -352,7 +361,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.p_a_grid_real, arg.p_b_grid_imag, arg.p_aux_grid, - arg); + problem); ave_time += launch_and_time_kernel(stream_config, kernel, @@ -362,7 +371,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.p_a_grid_imag, arg.p_b_grid_real, arg.p_aux_2_grid, - arg); + problem); // c_imag = aux + aux_2 ave_time += launch_and_time_kernel( @@ -395,7 +404,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.p_a_grid_real, arg.p_b_grid_real, arg.p_aux_grid, - arg); + problem); ave_time += launch_and_time_kernel(stream_config, kernel, @@ -405,7 +414,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.p_a_grid_imag, arg.p_b_grid_imag, arg.p_aux_2_grid, - arg); + problem); // c_real = aux - aux_2 ave_time += launch_and_time_kernel( @@ -430,7 +439,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.p_a_grid_real, arg.p_b_grid_imag, arg.p_aux_grid, - arg); + problem); ave_time += launch_and_time_kernel(stream_config, kernel, @@ -440,7 +449,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.p_a_grid_imag, arg.p_b_grid_real, arg.p_aux_2_grid, - arg); + problem); // c_imag = aux + aux_2 ave_time += launch_and_time_kernel( @@ -461,6 +470,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle return ave_time; } + INVOKER_RUN_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -477,12 +488,27 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - - return GridwiseGemm::CheckValidity(arg); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + typename GridwiseGemm32::Problem problem( + arg.M, arg.N, arg.K, arg.StrideA, arg.StrideB, arg.StrideC); + return GridwiseGemm32::CheckValidity(problem); + } + } + return false; } // polymorphic @@ -587,8 +613,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle static std::size_t GetCElementSpaceSize(index_t M, index_t N, index_t StrideC) { - const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N( - M, GridwiseGemm::CalculateMPadded(M), N, GridwiseGemm::CalculateNPadded(N), StrideC); + const auto c_grid_desc_m_n = + GridwiseGemm64::MakeCGridDescriptor_M_N(M, + GridwiseGemm64::CalculateMPadded(M), + N, + GridwiseGemm64::CalculateNPadded(N), + StrideC); return c_grid_desc_m_n.GetElementSpaceSize(); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp index 27f0a7af7c..cf5d5bd64e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. 
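The contraction files that follow repeat a refactor already seen above: block-tiled copy descriptors that used to be cached on the Argument (behind a CheckValidity guard in the constructor) are now derived inside RunImp, since their types depend on which gridwise instantiation, wave64 or wave32, actually runs, while the Argument stays shared between both paths. A compilable sketch of the shape of that change; all names are illustrative, and the integer division assumes the problem is already padded to the block tile.

// Sketch: tiled descriptors derived per launch from the problem-level
// descriptor instead of being cached on a shared Argument. Illustrative only.
struct Desc_M_N { int M; int N; };            // problem-level descriptor

template <class Gemm>
struct TiledDesc { int MBlock; int NBlock; }; // type depends on the Gemm

struct Argument
{
    Desc_M_N e_grid_desc_m_n; // shared by the wave64 and wave32 paths
};

template <class Gemm> // Gemm fixes MPerBlock/NPerBlock for one wave size
TiledDesc<Gemm> make_tiled_desc(const Argument& arg)
{
    // Assumes M and N already padded to multiples of the block tile.
    return {arg.e_grid_desc_m_n.M / Gemm::MPerBlock,
            arg.e_grid_desc_m_n.N / Gemm::NPerBlock};
}

struct Gemm64 { static constexpr int MPerBlock = 256; static constexpr int NPerBlock = 128; };
struct Gemm32 { static constexpr int MPerBlock = 128; static constexpr int NPerBlock = 128; };

template <class Gemm>
float run_imp(const Argument& arg)
{
    auto tiled = make_tiled_desc<Gemm>(arg); // built here, not in Argument
    return static_cast<float>(tiled.MBlock + tiled.NBlock); // placeholder
}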
#pragma once @@ -55,22 +55,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_as_grid, - p_bs_grid, - p_ds_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - as_grid_desc_ak0_m_ak1, - bs_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + GridwiseGemm::template Run( + p_as_grid, + p_bs_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } #else ignore = p_as_grid; ignore = p_bs_grid; @@ -160,6 +164,10 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle { using DeviceOp = DeviceContractionMultipleABD_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr index_t NumATensor = AsDataType::Size(); static constexpr index_t NumBTensor = BsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size(); @@ -172,7 +180,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle using ComputeDataType = EDataType; // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleABD_xdl_cshuffle< + template + using GridwiseGemmBase = GridwiseGemmMultipleABD_xdl_cshuffle< AsDataType, BsDataType, ComputeDataType, @@ -194,7 +203,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -217,6 +226,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; static constexpr auto matrix_padder = ck::tensor_operation::device::MatrixPadder{ @@ -385,21 +396,21 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle // desc for blockwise copy using AsGridDesc_AK0_M_AK1 = - remove_cvref_t; using BsGridDesc_BK0_N_BK1 = - remove_cvref_t; using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))>; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; // block-to-e-tile map using Block2ETileMap = - remove_cvref_t; + remove_cvref_t; // Argument struct Argument : public BaseArgument @@ -427,11 +438,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle bs_grid_desc_n_k_{}, ds_grid_desc_m_n_{}, e_grid_desc_m_n_{MakeEGridDescriptor_M_N(e_ms_ns_length, e_ms_ns_stride)}, - as_grid_desc_ak0_m_ak1_{}, - bs_grid_desc_bk0_n_bk1_{}, - ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, - e_grid_desc_mblock_mperblock_nblock_nperblock_{}, - 
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, cde_element_op_{cde_element_op} @@ -475,28 +482,6 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle MakeEGridDescriptor_M_N(d_ms_ns_lengths[i], d_ms_ns_strides[i]); }); - // populate desc for Ds/E - if(GridwiseGemm::CheckValidity(as_grid_desc_m_k_, - bs_grid_desc_n_k_, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_)) - { - as_grid_desc_ak0_m_ak1_ = - GridwiseGemm::MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_); - - bs_grid_desc_bk0_n_bk1_ = - GridwiseGemm::MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_); - - ds_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n_); - - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); - } - // for sanity check of vector memory access for(index_t i = 0; i < NumATensor; ++i) { @@ -521,9 +506,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle } // pointers - typename GridwiseGemm::AsGridPointer p_as_grid_; - typename GridwiseGemm::BsGridPointer p_bs_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; + typename GridwiseGemm64::AsGridPointer p_as_grid_; + typename GridwiseGemm64::BsGridPointer p_bs_grid_; + typename GridwiseGemm64::DsGridPointer p_ds_grid_; EDataType* p_e_grid_; // tensor descriptors for problem definiton @@ -532,13 +517,6 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle DsGridDesc_M_N ds_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_; - // tensor descriptors for block/thread-wise copy - AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1_; - BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1_; - DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock_; - EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; - // block-to-e-tile map Block2ETileMap block_2_etile_map_; @@ -564,7 +542,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_, arg.bs_grid_desc_n_k_, @@ -574,7 +553,19 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle { throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); } + auto as_grid_desc_ak0_m_ak1 = + GridwiseGemm::MakeDefaultAsGridDescriptor_AK0_M_AK1(arg.as_grid_desc_m_k_); + auto bs_grid_desc_bk0_n_bk1 = + GridwiseGemm::MakeDefaultBsGridDescriptor_BK0_N_BK1(arg.bs_grid_desc_n_k_); + + auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.ds_grid_desc_m_n_); + + auto e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.e_grid_desc_m_n_); const index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); @@ -609,10 +600,10 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle arg.a_element_op_, arg.b_element_op_, arg.cde_element_op_, - arg.as_grid_desc_ak0_m_ak1_, - arg.bs_grid_desc_bk0_n_bk1_, - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, arg.block_2_etile_map_); }; @@ -628,6 +619,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle } } + INVOKER_RUN_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -638,11 +631,12 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + using A0DataType = remove_cvref_t<tuple_element_t<0, AsDataType>>; + using B0DataType = remove_cvref_t<tuple_element_t<0, BsDataType>>; + if(!ck::is_xdl_wmma_supported()) { return false; } - // check vector load/store { bool valid_as_access = true; @@ -713,11 +707,30 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle } } - return GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_, - arg.bs_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.as_grid_desc_m_k_, + arg.bs_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.as_grid_desc_m_k_, + arg.bs_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index 615566a555..2a569a49e9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -53,23 +53,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, -
ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -163,6 +166,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle { using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; @@ -314,7 +321,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + template + using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype BDataType, ComputeDataType, @@ -335,7 +343,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -357,24 +365,26 @@ struct DeviceContractionMultipleD_Xdl_CShuffle CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // desc for blockwise copy using AGridDesc_AK0_M_AK1 = - remove_cvref_t; using BGridDesc_BK0_N_BK1 = - remove_cvref_t; using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))>; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; // block-to-e-tile map using Block2ETileMap = - remove_cvref_t; + remove_cvref_t; // Argument struct Argument : public BaseArgument @@ -403,12 +413,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ds_grid_desc_m_n_{}, e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_ms_ns_lengths, e_ms_ns_strides)}, a_grid_desc_ak0_m_ak1_{ - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, b_grid_desc_bk0_n_bk1_{ - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, - ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, - e_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, cde_element_op_{cde_element_op} @@ -425,22 +433,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle DeviceOp::MakeEGridDescriptor_M_N(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]); }); - // populate desc for Ds/E - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, - b_grid_desc_n_k_, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_)) - { - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - 
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); - - ds_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n_); - } - // for sanity check of vector memory access tie(a_continous_dim_, a_max_read_elems_) = CalculateMaxRead(a_ms_ks_lengths, a_ms_ks_strides); @@ -471,7 +463,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle // pointers const ADataType* p_a_grid_; const BDataType* p_b_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; + typename GridwiseGemm64::DsGridPointer p_ds_grid_; EDataType* p_e_grid_; // tensor descriptors for problem definiton @@ -512,7 +504,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, arg.b_grid_desc_n_k_, @@ -523,7 +516,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle throw std::runtime_error( "wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting"); } + auto e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.e_grid_desc_m_n_); + auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.ds_grid_desc_m_n_); const index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); @@ -562,8 +561,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle arg.cde_element_op_, arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, arg.block_2_etile_map_); }; @@ -577,6 +576,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle } } + INVOKER_RUN_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -587,21 +588,39 @@ struct DeviceContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - if(!ck::is_lds_direct_load_supported() && std::is_same::value) { return false; } + bool valid = false; + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + valid = GridwiseGemm64::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + valid = GridwiseGemm32::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_)) + if(!valid) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index 23440e24f6..ff652ebefb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -72,6 +72,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ using DeviceOp = DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + using ADataType = OutDataType; using BDataType = InDataType; using CDataType = WeiDataType; @@ -281,7 +285,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ using CGridDesc_M_N = remove_cvref_t; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + template + using GridwiseGemmBase = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, ADataType, // TODO: distinguish A/B datatype AccDataType, @@ -300,7 +305,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ NPerXdl, K1, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -329,8 +334,11 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, true, true>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + template + using GridwiseGemmAtomicAddBase = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, ADataType, // TODO: distinguish A/B datatype AccDataType, @@ -349,7 +357,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ NPerXdl, K1, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -378,6 +386,9 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, true, true>; + using GridwiseGemmAtomicAdd64 = GridwiseGemmAtomicAddBase; + using GridwiseGemmAtomicAdd32 = GridwiseGemmAtomicAddBase; + // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); @@ -506,7 +517,9 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -635,6 +648,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ return ave_time; } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -650,11 +665,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - // 
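In this file `Run` becomes a `RunImp` template over the gridwise GEMM (the bare `+ template` presumably carries `<typename GridwiseGemm>`), and the old non-template `Run(const Argument&)` entry point is replaced by the `INVOKER_RUN_IMPL` macro. A plausible expansion of what that macro centralizes, offered as a hypothetical reconstruction rather than the macro's actual body: pick the instantiation by the runtime wave size and forward, letting `if constexpr` skip instantiations that were disabled.

```cpp
#include <iostream>
#include <stdexcept>

using index_t = int;

// Runtime stand-in; on a real system this reflects the active GPU
// (wave64 on gfx9, wave32 on gfx11/gfx12).
inline index_t get_warp_size() { return 64; }

struct Argument { index_t n; };
struct StreamConfig {};

template <index_t NXdl>
struct GridwiseGemmStandIn { static constexpr index_t nxdl = NXdl; };

constexpr index_t NXdlPerWave64 = 4;
constexpr index_t NXdlPerWave32 = 2;
using GridwiseGemm64 = GridwiseGemmStandIn<NXdlPerWave64>;
using GridwiseGemm32 = GridwiseGemmStandIn<NXdlPerWave32>;

struct Invoker
{
    template <typename GridwiseGemm>
    float RunImp(const Argument&, const StreamConfig& = StreamConfig{})
    {
        // The real RunImp validates, builds descriptors, and launches the
        // kernel for the chosen instantiation; here we just report the pick.
        std::cout << "running with NXdlPerWave = " << GridwiseGemm::nxdl << '\n';
        return 0.f;
    }

    // Presumed shape of INVOKER_RUN_IMPL's expansion (assumption, see above).
    float Run(const Argument& arg, const StreamConfig& sc = StreamConfig{})
    {
        if(get_warp_size() == 64)
        {
            if constexpr(NXdlPerWave64 > 0)
                return RunImp<GridwiseGemm64>(arg, sc);
        }
        else
        {
            if constexpr(NXdlPerWave32 > 0)
                return RunImp<GridwiseGemm32>(arg, sc);
        }
        throw std::runtime_error("unsupported wave size for this instance");
    }
};

int main() { Invoker{}.Run(Argument{1}); } // prints: running with NXdlPerWave = 4
```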
vector load A/B matrix from global memory if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp index d4f89b3e09..2be31cabed 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp @@ -35,8 +35,8 @@ template (); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + using ADataType = OutDataType; using BDataType = WeiDataType; using CDataType = InDataType; @@ -374,7 +378,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K using CGridDesc_M_N = remove_cvref_t; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + template + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, ABDataType, // TODO: distinguish A/B datatype AccDataType, @@ -386,11 +391,11 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K MPerBlock, NPerBlock, K0PerBlock, - MPerXdl, - NPerXdl, + MPerXDL, + NPerXDL, K1, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -410,6 +415,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, 7, // CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -512,7 +519,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float ave_time = 0; for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) @@ -603,6 +611,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K return ave_time; } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -618,11 +628,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - if constexpr(ConvBackwardDataSpecialization == ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) { @@ -651,14 +660,30 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K } // Gridwise GEMM size + bool isWave64 = get_warp_size() == 64; for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) { - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], - arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m_n_container_[i])) + bool valid = false; + if(isWave64) { - return false; + if constexpr(NXdlPerWave64 > 0) + { + valid = GridwiseGemm64::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i]); + } } + else + { + if constexpr(NXdlPerWave32 
> 0) + { + valid = GridwiseGemm32::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i]); + } + } + if(!valid) + return false; } return true; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp index 5d039427d6..94a0dc8c84 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp @@ -68,6 +68,10 @@ struct using DeviceOp = DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + using ADataType = InDataType; using BDataType = WeiDataType; using CDataType = OutDataType; @@ -467,7 +471,8 @@ struct using Block2CTileMap = BlockToCTileMap_M00_N0_M01; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3< + template + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3< BlockSize, ABDataType, // TODO: distinguish A/B datatype AccDataType, @@ -488,7 +493,7 @@ struct NPerXDL, K1, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, @@ -509,6 +514,8 @@ struct CShuffleNXdlPerWavePerShuffle, CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, CBlockTransferScalarPerVector_NWaveNPerXdl>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -541,9 +548,6 @@ struct c_grid_desc_m_n_{}, c0_grid_desc_m_n_{}, c1_grid_desc_m_n_{}, - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, - c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, block_2_ctile_map_{}, in_element_op_{in_element_op}, wei_element_op_{wei_element_op}, @@ -578,27 +582,6 @@ struct c1_grid_desc_m_n_ = descs[I4]; block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_}; - - if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, - b_grid_desc_k0_n_k1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = - GridwiseGemm:: - MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( - c_grid_desc_m_n_); - - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = - GridwiseGemm:: - MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( - c0_grid_desc_m_n_); - - c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = - GridwiseGemm:: - MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( - c1_grid_desc_m_n_); - } } // private: @@ -612,15 +595,7 @@ struct CGridDesc_M_N c_grid_desc_m_n_; C0GridDesc_M_N c0_grid_desc_m_n_; C1GridDesc_M_N c1_grid_desc_m_n_; - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - 
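The backward-data ops (the loop above, and the same shape again in `device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp` further down) keep one descriptor set per GEMM in a container, roughly one per filter-position subset, so the wave-size dispatch moves inside the loop: every entry must pass validity for the instantiation that will actually run. A self-contained model of that loop, with integer stand-ins for the descriptors:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

using index_t = int;
inline index_t get_warp_size() { return 64; } // device query stand-in

constexpr index_t NXdlPerWave64 = 4;
constexpr index_t NXdlPerWave32 = 2;

// Stand-in validity check; the real one inspects the grid descriptors.
template <index_t NXdl>
struct GridwiseGemmStandIn
{
    static bool CheckValidity(index_t m, index_t n)
    {
        return m % NXdl == 0 && n % NXdl == 0;
    }
};
using GridwiseGemm64 = GridwiseGemmStandIn<NXdlPerWave64>;
using GridwiseGemm32 = GridwiseGemmStandIn<NXdlPerWave32>;

// Mirrors the IsSupportedArgument loop above: one bad entry disqualifies
// the whole argument, and the wave size is read once outside the loop.
bool all_gemms_supported(const std::vector<index_t>& m_per_gemm,
                         const std::vector<index_t>& n_per_gemm)
{
    const bool isWave64 = get_warp_size() == 64;
    for(std::size_t i = 0; i < m_per_gemm.size(); ++i)
    {
        bool valid = false;
        if(isWave64)
        {
            if constexpr(NXdlPerWave64 > 0)
                valid = GridwiseGemm64::CheckValidity(m_per_gemm[i], n_per_gemm[i]);
        }
        else
        {
            if constexpr(NXdlPerWave32 > 0)
                valid = GridwiseGemm32::CheckValidity(m_per_gemm[i], n_per_gemm[i]);
        }
        if(!valid)
            return false;
    }
    return true;
}

int main()
{
    std::vector<index_t> m{256, 128}, n{256, 64};
    std::cout << (all_gemms_supported(m, n) ? "supported" : "unsupported") << '\n';
}
```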
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; - typename GridwiseGemm:: - C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; - typename GridwiseGemm:: - C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + Block2CTileMap block_2_ctile_map_; InElementwiseOperation in_element_op_; WeiElementwiseOperation wei_element_op_; @@ -643,7 +618,8 @@ struct { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { @@ -700,6 +676,20 @@ struct float ave_time = 0; + auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + arg.c_grid_desc_m_n_); + + auto c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + arg.c0_grid_desc_m_n_); + + auto c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + arg.c1_grid_desc_m_n_); if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { const auto kernel = kernel_gemm_xdlops_v3r3< @@ -736,9 +726,9 @@ struct arg.p_c1_grid_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, arg.in_element_op_, arg.wei_element_op_, arg.out_element_op_, @@ -780,9 +770,9 @@ struct arg.p_c1_grid_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, arg.in_element_op_, arg.wei_element_op_, arg.out_element_op_, @@ -792,6 +782,8 @@ struct return ave_time; } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -807,11 +799,10 @@ struct static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { @@ -851,10 +842,27 @@ struct } // Gridwise GEMM size - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - 
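The deletion above repeats across these files: the `*_mblock_*`-style descriptors disappear from `Argument` and are rebuilt inside the templated `RunImp`. The reason is type-level: those descriptor types are nested in the gridwise class, so caching them in `Argument` would pin it to one instantiation, while rebuilding them per launch keeps `Argument` instantiation-neutral. A toy model follows; the differing tile shapes below only exist to make the dependence visible, since in the diff the two instantiations differ in `NXdlPerWave`, not in `MPerBlock`/`NPerBlock`:

```cpp
#include <iostream>

using index_t = int;

// Stand-in "descriptor": just the block-tiled view of the problem.
struct BlockDesc { index_t mblock, nblock; };

template <index_t MPerBlock, index_t NPerBlock>
struct GridwiseGemmStandIn
{
    // Models MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock: the
    // result (and its type, in the real code) belongs to the instantiation.
    static BlockDesc MakeBlockDesc(index_t m, index_t n)
    {
        return {(m + MPerBlock - 1) / MPerBlock, (n + NPerBlock - 1) / NPerBlock};
    }
};

// Argument keeps only the instantiation-independent M x N view.
struct Argument { index_t m, n; };

template <typename GridwiseGemm>
float RunImp(const Argument& arg)
{
    // Derived at launch time for whichever gridwise GEMM was selected.
    const BlockDesc d = GridwiseGemm::MakeBlockDesc(arg.m, arg.n);
    std::cout << d.mblock << " x " << d.nblock << " tiles\n";
    return 0.f;
}

int main()
{
    Argument arg{1024, 1024};
    RunImp<GridwiseGemmStandIn<128, 256>>(arg); // 8 x 4 tiles
    RunImp<GridwiseGemmStandIn<256, 128>>(arg); // 4 x 8 tiles
}
```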
arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + return false; } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp index 242f5cd673..3a3b29a168 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp @@ -69,6 +69,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X using DeviceOp = DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + using ADataType = InDataType; using BDataType = WeiDataType; using CDataType = OutDataType; @@ -448,7 +452,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X using C0GridDesc_M_N = remove_cvref_t; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2< + template + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2< BlockSize, ABDataType, // TODO: distinguish A/B datatype AccDataType, @@ -468,7 +473,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X NPerXDL, K1, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, @@ -489,6 +494,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X CShuffleNXdlPerWavePerShuffle, CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, CBlockTransferScalarPerVector_NWaveNPerXdl>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -520,8 +527,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X b_grid_desc_k0_n_k1_{}, c_grid_desc_m_n_{}, c0_grid_desc_m_n_{}, - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, block_2_ctile_map_{}, M01_{M01}, N01_{N01}, @@ -556,23 +561,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X c_grid_desc_m_n_ = descs[I2]; c0_grid_desc_m_n_ = descs[I3]; block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); - - if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, - b_grid_desc_k0_n_k1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = - GridwiseGemm:: - MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( - 
c_grid_desc_m_n_); - - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = - GridwiseGemm:: - MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( - c0_grid_desc_m_n_); - } + GridwiseGemm64::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); } const ADataType* p_a_grid_; @@ -583,13 +572,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; CGridDesc_M_N c_grid_desc_m_n_; C0GridDesc_M_N c0_grid_desc_m_n_; - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; - typename GridwiseGemm:: - C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + typename GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map_; index_t M01_; index_t N01_; InElementwiseOperation in_element_op_; @@ -613,7 +596,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { @@ -667,6 +651,15 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X float ave_time = 0; + auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + arg.c_grid_desc_m_n_); + + auto c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + arg.c0_grid_desc_m_n_); if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { const auto kernel = kernel_gemm_xdlops_v3r2< @@ -699,8 +692,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X arg.p_c0_grid_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, arg.in_element_op_, arg.wei_element_op_, arg.out_element_op_, @@ -738,8 +731,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X arg.p_c0_grid_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, arg.in_element_op_, arg.wei_element_op_, arg.out_element_op_, @@ -749,6 +742,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X return ave_time; } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -764,11 +759,10 @@ struct 
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { @@ -808,10 +802,27 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X } // Gridwise GEMM size - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + return false; } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index 0d295a2418..3bc5f9af03 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -36,8 +36,8 @@ template < ck::index_t NPerBlock, ck::index_t K0PerBlock, ck::index_t K1, - ck::index_t MPerXdl, - ck::index_t NPerXdl, + ck::index_t MPerXDL, + ck::index_t NPerXDL, ck::index_t MXdlPerWave, ck::index_t NXdlPerWave, typename ABlockTransferThreadClusterLengths_K0_M_K1, @@ -72,6 +72,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W { using DeviceOp = DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + using ADataType = InDataType; using BDataType = WeiDataType; using CDataType = OutDataType; @@ -433,7 +437,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W using Block2CTileMap = BlockToCTileMap_M00_N0_M01; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1< + template + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1< BlockSize, ABDataType, // TODO: distinguish A/B datatype AccDataType, @@ -451,10 +456,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W K0PerBlock * K1, K1, // AK1 K1, // BK1 - MPerXdl, - NPerXdl, + MPerXDL, + NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, @@ -475,6 +480,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W CShuffleNXdlPerWavePerShuffle, CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, CBlockTransferScalarPerVector_NWaveNPerXdl>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -501,7 +508,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W a_grid_desc_k0_m_k1_{}, 
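Every `IsSupportedArgument` in the patch swaps `ck::is_xdl_supported()` for `ck::is_xdl_wmma_supported()`: the coarse device gate widens from MFMA-only gfx9 to also admit the WMMA targets, matching the `__gfx11__`/`__gfx12__` additions to the kernel guards. A sketch of what such a predicate looks like; the real function queries the active device, and the exact prefix list below is an assumption:

```cpp
#include <initializer_list>
#include <iostream>
#include <string>

// Illustrative approximation of the widened gate, keyed off arch names.
bool is_xdl_wmma_supported_like(const std::string& arch)
{
    auto has_prefix = [&](const char* p) { return arch.rfind(p, 0) == 0; };
    return has_prefix("gfx9")      // MFMA / XDL targets (gfx908, gfx90a, gfx94x, ...)
           || has_prefix("gfx11")  // WMMA, wave32 (RDNA3)
           || has_prefix("gfx12"); // WMMA, wave32 (RDNA4)
}

int main()
{
    for(const char* a : {"gfx90a", "gfx1100", "gfx1030"})
        std::cout << a << ": " << (is_xdl_wmma_supported_like(a) ? "yes" : "no") << '\n';
    // gfx90a: yes, gfx1100: yes, gfx1030: no (RDNA2 has no supported path here)
}
```

The per-wave `if constexpr(NXdlPerWaveXX > 0)` checks then do the fine-grained filtering that the old single gate used to imply.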
b_grid_desc_k0_n_k1_{}, c_grid_desc_m_n_{}, - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, block_2_ctile_map_{}, in_element_op_{in_element_op}, wei_element_op_{wei_element_op}, @@ -534,28 +540,13 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W c_grid_desc_m_n_ = descs[I2]; block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_}; - - if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, - b_grid_desc_k0_n_k1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = - GridwiseGemm:: - MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( - c_grid_desc_m_n_); - } } - const ADataType* p_a_grid_; const BDataType* p_b_grid_; CDataType* p_c_grid_; AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; CGridDesc_M_N c_grid_desc_m_n_; - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; Block2CTileMap block_2_ctile_map_; InElementwiseOperation in_element_op_; WeiElementwiseOperation wei_element_op_; @@ -578,8 +569,21 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + } + auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + arg.c_grid_desc_m_n_); if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << DeviceOp{}.GetTypeString() << std::endl; @@ -614,35 +618,25 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W std::cout << "arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_" "nwavenperxdl_{ " - << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + << c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl .GetLength(I0) << ", " - << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + << c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl .GetLength(I1) << ", " - << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + << c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl .GetLength(I2) << ", " - << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + << c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl .GetLength(I3) << ", " - << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + << c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl .GetLength(I4) << ", " - << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + << c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl .GetLength(I5) << "}" << std::endl; } - - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - 
arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_)) - { - throw std::runtime_error( - "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); - } - const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); @@ -679,7 +673,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W arg.p_c_grid_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, arg.in_element_op_, arg.wei_element_op_, arg.out_element_op_, @@ -713,7 +707,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W arg.p_c_grid_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, arg.in_element_op_, arg.wei_element_op_, arg.out_element_op_, @@ -723,6 +717,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W return ave_time; } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -738,11 +734,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { @@ -782,10 +777,27 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W } // Gridwise GEMM size - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + return false; } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp index c7aa54f1d9..cecfa48408 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -69,6 +69,10 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K { using DeviceOp = DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + using ADataType = InDataType; using BDataType = WeiDataType; using CDataType = OutDataType; @@ -324,7 +328,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K using CGridDesc_M_N = remove_cvref_t; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + template + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, ABDataType, // 
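By this point the same four-branch shape has appeared in every `IsSupportedArgument` and `Run`: wave64 selects `GridwiseGemm64` when `NXdlPerWave64 > 0`, wave32 selects `GridwiseGemm32` when `NXdlPerWave32 > 0`, and anything else falls through to `return false;`, i.e. the instance reports itself unsupported instead of erroring. The diff spells the branches out per file; purely as an editorial sketch, the invariant factors into one helper:

```cpp
#include <iostream>

using index_t = int;
inline index_t get_warp_size() { return 32; } // device query stand-in

// Editorial sketch, not code from the diff: the recurring dispatch shape.
// `f` receives the selected gridwise type as a tag; a disabled instantiation
// is never touched, and the fall-through means "unsupported", not "error".
template <index_t NXdl64, index_t NXdl32, typename G64, typename G32, typename F>
bool dispatch_by_wave(F&& f)
{
    if(get_warp_size() == 64)
    {
        if constexpr(NXdl64 > 0)
            return f(G64{});
    }
    else
    {
        if constexpr(NXdl32 > 0)
            return f(G32{});
    }
    return false;
}

template <index_t NXdl>
struct GridwiseGemmStandIn { static constexpr index_t nxdl = NXdl; };

int main()
{
    using G64 = GridwiseGemmStandIn<4>;
    using G32 = GridwiseGemmStandIn<2>;
    const bool ok = dispatch_by_wave<4, 2, G64, G32>(
        [](auto g) { return decltype(g)::nxdl > 0; }); // stands in for CheckValidity
    std::cout << (ok ? "supported" : "unsupported") << '\n'; // supported (wave32 path)
}
```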
TODO: distinguish A/B datatype AccDataType, @@ -340,7 +345,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K NPerXDL, K1, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, @@ -360,6 +365,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, 7, // CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -430,7 +437,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { @@ -510,6 +518,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K return ave_time; } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -525,11 +535,10 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { @@ -569,8 +578,23 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K } // Gridwise GEMM size - return GridwiseGemm::CheckValidity( - arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); + } + } + return false; } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp index dc8499fcf2..09dc2a03db 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -56,32 +56,35 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if defined(__gfx9__) - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = - 
__builtin_amdgcn_readfirstlane(static_cast(a_batch_stride) * g_idx); - const long_index_t b_batch_offset = - __builtin_amdgcn_readfirstlane(static_cast(b_batch_stride) * g_idx); - const long_index_t c_batch_offset = - __builtin_amdgcn_readfirstlane(static_cast(c_batch_stride) * g_idx); + const long_index_t a_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast(a_batch_stride) * g_idx); + const long_index_t b_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast(b_batch_stride) * g_idx); + const long_index_t c_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast(c_batch_stride) * g_idx); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_c_grid + c_batch_offset, - p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -140,6 +143,10 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ { using DeviceOp = DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + using ADataType = InDataType; using BDataType = WeiDataType; using CDataType = OutDataType; @@ -263,7 +270,8 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + template + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, InDataType, AccDataType, @@ -282,7 +290,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ NPerXDL, K1, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, @@ -302,6 +310,8 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ Sequence<2, 3, 0, 1, 7, 5, 4, 6>, 7, CThreadTransferDstScalarPerVector>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); @@ -399,7 +409,9 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { @@ -507,6 +519,8 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ return ave_time; } + INVOKER_RUN_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -523,11 +537,10 @@ struct 
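The batched conv3d kernel above shows the device-side half of the scheme: the arch guard grows from `__gfx9__` alone to `__gfx9__ || __gfx11__ || __gfx12__`, and the whole body is wrapped in `if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())`, so an instantiation that is invalid for the compiled wave size becomes an empty kernel rather than a compile error. A host-compilable model of that guard, with stand-in names throughout:

```cpp
#include <iostream>

template <int NXdlPerWave>
struct GridwiseGemmStandIn
{
    // Mirrors IsValidCompilationParameter<>() being a member template,
    // hence the `template` disambiguator at the call site below.
    template <int = 0>
    static constexpr bool IsValidCompilationParameter()
    {
        return NXdlPerWave > 0;
    }
    static void Run() { std::cout << "kernel body executed\n"; }
};

template <typename GridwiseGemm>
void kernel_entry_model()
{
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
    {
        // Real code: compute batch offsets, allocate LDS, call Run(...).
        GridwiseGemm::Run();
    }
    // else: this instantiation compiles to an empty kernel.
}

int main()
{
    kernel_entry_model<GridwiseGemmStandIn<2>>(); // valid: body runs
    kernel_entry_model<GridwiseGemmStandIn<0>>(); // disabled: nothing happens
}
```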
DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp index 2881036bee..d074342127 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp @@ -36,8 +36,8 @@ template (); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + using ADataType = OutDataType; using BDataType = WeiDataType; using CDataType = InDataType; @@ -975,7 +979,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl using CGridDesc_M_N = remove_cvref_t; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + template + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, ABDataType, // TODO: distinguish A/B datatype AccDataType, @@ -987,11 +992,11 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl MPerBlock, NPerBlock, K0PerBlock, - MPerXdl, - NPerXdl, + MPerXDL, + NPerXDL, K1, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -1011,6 +1016,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, 7, // CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -1216,7 +1223,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float ave_time = 0; for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) @@ -1305,6 +1313,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl return ave_time; } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -1320,11 +1330,10 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - if constexpr(ConvBackwardDataSpecialization == ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) { @@ -1354,14 +1363,30 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl } // Gridwise GEMM size + bool isWave64 = get_warp_size() == 64; for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) { - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], - arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m_n_container_[i])) + bool valid = false; + if(isWave64) { - return false; + if constexpr(NXdlPerWave64 > 0) + { + valid = GridwiseGemm64::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i]); + } } + else + { + if constexpr(NXdlPerWave32 > 0) + { + valid = GridwiseGemm32::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + 
arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i]); + } + } + if(!valid) + return false; } return true; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp index 5bfef60af2..3929525987 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -81,6 +81,10 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO { using DeviceOp = DeviceGemmBiasAddReduce_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -383,7 +387,8 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1)); // GridwiseGemm - using GridwiseGemm = GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + template + using GridwiseGemmBase = GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, CShuffleDataType, @@ -417,7 +422,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -442,6 +447,8 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -477,11 +484,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO c0_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, 0)}, c1_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC1)}, reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)}, - c_grid_desc_mblock_mperblock_nblock_nperblock_{}, - c0_grid_desc_mblock_mperblock_nblock_nperblock_{}, - c1_grid_desc_mblock_mperblock_nblock_nperblock_{}, - reduce_grid_desc_mblock_mperblock_{}, - block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + block_2_ctile_map_{GridwiseGemm64::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op}, @@ -489,26 +492,6 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO reduce_in_element_ops_{reduce_in_element_ops}, reduce_out_element_ops_{reduce_out_element_ops} { - if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, - b_grid_desc_bk0_n_bk1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n_); - - 
c0_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c0_grid_desc_m_n_); - - c1_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c1_grid_desc_m_n_); - - reduce_grid_desc_mblock_mperblock_ = - GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_); - } } // private: @@ -524,15 +507,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO C0GridDesc_M_N c0_grid_desc_m_n_; C1GridDesc_M_N c1_grid_desc_m_n_; ReduceGridDesc_M reduce_grid_desc_m_; - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c0_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c1_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock - reduce_grid_desc_mblock_mperblock_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + typename GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map_; AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; CElementwiseOperation c_element_op_; @@ -546,7 +521,8 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, @@ -555,7 +531,20 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO { throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); } + auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.c_grid_desc_m_n_); + auto c0_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.c0_grid_desc_m_n_); + + auto c1_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.c1_grid_desc_m_n_); + + auto reduce_grid_desc_mblock_mperblock = + GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(arg.reduce_grid_desc_m_); const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); @@ -607,10 +596,10 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO arg.reduce_out_element_ops_, arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.reduce_grid_desc_mblock_mperblock_, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c1_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, arg.block_2_ctile_map_); } else @@ -657,16 +646,18 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO arg.reduce_out_element_ops_, arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.reduce_grid_desc_mblock_mperblock_, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c1_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, arg.block_2_ctile_map_); } return elapsed_time; } + INVOKER_RUN_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -683,15 +674,31 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - - return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp index 4d9dd192c9..db62dd340f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. 
All rights reserved. #pragma once @@ -80,6 +80,10 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr index_t NumATensor = AsDataType::Size(); static constexpr index_t NumBTensor = BsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size(); @@ -88,7 +92,8 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD>; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< + template + using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, @@ -111,7 +116,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -570,6 +579,8 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD()) { return false; } - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding || @@ -599,7 +609,22 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -739,7 +764,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t LayernormHDstVectorSize = PostShuffleScalarPerVector; static constexpr index_t LayernormGammaSrcVectorSize = PostShuffleScalarPerVector; @@ -268,8 +272,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; - static constexpr auto matrix_padder = MatrixPadder{ - GemmMPerBlock, GemmNPerBlock, GemmKPerBlock}; + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) { @@ -326,7 +330,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle static_assert(is_same::value); return DeviceOp:: - MakeEHGridDescriptor_M_N, GemmMPerBlock, GemmNPerBlock>( + MakeEHGridDescriptor_M_N, MPerBlock, NPerBlock>( MRaws[i], NRaws[i], DsStride[i]); }, Number{}); @@ -363,12 +367,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle // We have to separate mean var descriptor for gemm and layernorm bacause of different grid // layout(different padding) using GemmMeanVarGridDesc_M_NBlock = - decltype(MakeMeanVarDescriptor_M_N, GemmMPerBlock, GemmNPerBlock>(1, - 1)); + decltype(MakeMeanVarDescriptor_M_N, MPerBlock, NPerBlock>(1, 1)); using GemmCountGridDesc_M_NBlock = - decltype(MakeCountDescriptor_M_N, GemmMPerBlock, GemmNPerBlock>(1, - 1)); + decltype(MakeCountDescriptor_M_N, MPerBlock, NPerBlock>(1, 1)); using 
LayernormMeanVarGridDesc_M_NBlock = decltype(MakeMeanVarDescriptor_M_N, @@ -383,7 +385,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle using GammaBetaGridDesc_N = decltype(MakeDescriptor_X(1)); using EHGridDesc_M_N = decltype(MakeEHGridDescriptor_M_N, 1, 1>(1, 1, 1)); - using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< + template + using GridwiseGemmWelfordBase = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype AccDataType, CShuffleDataType, @@ -401,15 +404,15 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle GemmCountGridDesc_M_NBlock, NumGemmKPrefetchStage, BlockSize, - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, + MPerBlock, + NPerBlock, + KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -432,8 +435,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle PostShuffleScalarPerVector, LoopSched, PipelineVer>; + using GridwiseGemmWelford64 = GridwiseGemmWelfordBase; + using GridwiseGemmWelford32 = GridwiseGemmWelfordBase; - using Block2ETileMap = typename GridwiseGemmWelford::DefaultBlock2ETileMap; + using Block2ETileMap = typename GridwiseGemmWelford64::DefaultBlock2ETileMap; using GridwiseWelfordLayernorm = GridwiseWelfordSecondHalfLayernorm2d, - GemmMPerBlock, - GemmNPerBlock>(MRaw, NRaw, StrideH)}, + DeviceOp::MakeEHGridDescriptor_M_N, MPerBlock, NPerBlock>( + MRaw, NRaw, StrideH)}, layernorm_e_grid_desc_m_n_{ DeviceOp::MakeEHGridDescriptor_M_N, LayernormBlockTileSize_M_N::At(0), @@ -513,11 +517,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle LayernormBlockTileSize_M_N::At(1)>( MRaw, NRaw, StrideH)}, a_grid_desc_ak0_m_ak1_{ - GridwiseGemmWelford::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + GridwiseGemmWelford64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, b_grid_desc_bk0_n_bk1_{ - GridwiseGemmWelford::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + GridwiseGemmWelford64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, block_2_etile_map_{ - GridwiseGemmWelford::MakeDefaultBlock2ETileMap(gemm_e_grid_desc_m_n_)}, + GridwiseGemmWelford64::MakeDefaultBlock2ETileMap(gemm_e_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, cde_element_op_{cde_element_op}, @@ -525,16 +529,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle MRaw_{MRaw}, NRaw_{NRaw}, KRaw_{KRaw}, - gemm_nblock_{math::integer_divide_ceil(NRaw, GemmNPerBlock)}, + gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)}, epsilon_{static_cast(epsilon)} { // We don't need to pad in N dimension in gemm for mean/var/count. Set NPerTile 1. 
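The layernorm op is the one place where the MBlock-style descriptors stay cached in `Argument`, so its constructor (just below) has to branch on the wave size up front and ends up carrying two near-identical copies of the population code; the members themselves are declared with `GridwiseGemmWelford64`'s nested types, so the wave32 branch implicitly relies on both instantiations producing layout-identical descriptors. A generic lambda would keep a single copy of the body, sketched here with stand-in types as an editorial sketch rather than code from the diff:

```cpp
#include <iostream>

using index_t = int;
inline index_t get_warp_size() { return 64; } // device query stand-in

constexpr index_t NXdlPerWave64 = 4;
constexpr index_t NXdlPerWave32 = 2;

template <index_t NXdl>
struct GridwiseGemmWelfordStandIn
{
    static bool CheckValidity(index_t m, index_t n) { return m % NXdl == 0 && n % NXdl == 0; }
    static index_t MakeDesc(index_t m) { return m / NXdl; } // stand-in for Make*Descriptor_*
};
using GridwiseGemmWelford64 = GridwiseGemmWelfordStandIn<NXdlPerWave64>;
using GridwiseGemmWelford32 = GridwiseGemmWelfordStandIn<NXdlPerWave32>;

void populate_descriptors(index_t m, index_t n, index_t& out_desc)
{
    // One copy of the population body, reused for either instantiation.
    auto populate = [&](auto gridwise_tag) {
        using GW = decltype(gridwise_tag);
        if(GW::CheckValidity(m, n))
            out_desc = GW::MakeDesc(m);
    };

    if(get_warp_size() == 64)
    {
        if constexpr(NXdlPerWave64 > 0)
            populate(GridwiseGemmWelford64{});
    }
    else
    {
        if constexpr(NXdlPerWave32 > 0)
            populate(GridwiseGemmWelford32{});
    }
}

int main()
{
    index_t desc = -1;
    populate_descriptors(1024, 512, desc);
    std::cout << desc << '\n'; // 256 with the wave64 stand-in
}
```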
gemm_mean_var_grid_desc_m_nblock_ = - DeviceOp::MakeMeanVarDescriptor_M_N, GemmMPerBlock, 1>( + DeviceOp::MakeMeanVarDescriptor_M_N, MPerBlock, 1>( MRaw, gemm_nblock_); gemm_count_grid_desc_m_nblock_ = - DeviceOp::MakeCountDescriptor_M_N, GemmMPerBlock, 1>( + DeviceOp::MakeCountDescriptor_M_N, MPerBlock, 1>( MRaw, gemm_nblock_); layernorm_mean_var_grid_desc_m_nblock_ = @@ -558,33 +562,66 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle // D desc ds_grid_desc_m_n_(i) = - DeviceOp::MakeEHGridDescriptor_M_N, - GemmMPerBlock, - GemmNPerBlock>(MRaw, NRaw, StrideDs[i]); + DeviceOp::MakeEHGridDescriptor_M_N, MPerBlock, NPerBlock>( + MRaw, NRaw, StrideDs[i]); }); // populate desc for Ds/E/mean/var/count - if(GridwiseGemmWelford::CheckValidity(a_grid_desc_m_k_, - b_grid_desc_n_k_, - ds_grid_desc_m_n_, - gemm_e_grid_desc_m_n_, - block_2_etile_map_)) + if(get_warp_size() == 64) { - ds_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemmWelford::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n_); + if constexpr(NXdlPerWave64 > 0) + { + if(GridwiseGemmWelford64::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + gemm_e_grid_desc_m_n_, + block_2_etile_map_)) + { + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmWelford64:: + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemmWelford::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - gemm_e_grid_desc_m_n_); + e_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmWelford64:: + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + gemm_e_grid_desc_m_n_); - gemm_mean_var_grid_desc_mblock_mperblock_nblock_ = - GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock( - gemm_mean_var_grid_desc_m_nblock_); + gemm_mean_var_grid_desc_mblock_mperblock_nblock_ = GridwiseGemmWelford64:: + MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock( + gemm_mean_var_grid_desc_m_nblock_); - gemm_count_grid_desc_mblock_mperblock_nblock_ = - GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock( - gemm_count_grid_desc_m_nblock_); + gemm_count_grid_desc_mblock_mperblock_nblock_ = GridwiseGemmWelford64:: + MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock( + gemm_count_grid_desc_m_nblock_); + } + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + if(GridwiseGemmWelford32::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + gemm_e_grid_desc_m_n_, + block_2_etile_map_)) + { + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmWelford32:: + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmWelford32:: + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + gemm_e_grid_desc_m_n_); + + gemm_mean_var_grid_desc_mblock_mperblock_nblock_ = GridwiseGemmWelford32:: + MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock( + gemm_mean_var_grid_desc_m_nblock_); + + gemm_count_grid_desc_mblock_mperblock_nblock_ = GridwiseGemmWelford32:: + MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock( + gemm_count_grid_desc_m_nblock_); + } + } } } @@ -602,7 +639,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle // pointers const ADataType* p_a_grid_; const BDataType* p_b_grid_; - typename GridwiseGemmWelford::DsGridPointer p_ds_grid_; + typename GridwiseGemmWelford64::DsGridPointer p_ds_grid_; void* p_workspace_e_grid_; void* 
p_workspace_mean_; void* p_workspace_var_; @@ -626,15 +663,15 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle EHGridDesc_M_N h_grid_desc_m_n_; // tensor descriptors for block/thread-wise copy - typename GridwiseGemmWelford::DefaultAGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; - typename GridwiseGemmWelford::DefaultBGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; - typename GridwiseGemmWelford::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + typename GridwiseGemmWelford64::DefaultAGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + typename GridwiseGemmWelford64::DefaultBGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + typename GridwiseGemmWelford64::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemmWelford::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + typename GridwiseGemmWelford64::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemmWelford::MeanVarGridDescriptor_MBlock_MPerBlock_NBlock + typename GridwiseGemmWelford64::MeanVarGridDescriptor_MBlock_MPerBlock_NBlock gemm_mean_var_grid_desc_mblock_mperblock_nblock_; - typename GridwiseGemmWelford::CountGridDescriptor_MBlock_MPerBlock_NBlock + typename GridwiseGemmWelford64::CountGridDescriptor_MBlock_MPerBlock_NBlock gemm_count_grid_desc_mblock_mperblock_nblock_; // block-to-e-tile map @@ -657,8 +694,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle struct Invoker : public BaseInvoker { using Argument = DeviceOp::Argument; - - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float avg_time = 0; @@ -787,7 +824,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle return launch_kernel(integral_constant{}); } } - + using GridwiseGemm32 = GridwiseGemmWelford32; + using GridwiseGemm64 = GridwiseGemmWelford64; + INVOKER_RUN_IMPL // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -857,11 +896,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - // check vector load/store { using Row = ck::tensor_layout::gemm::RowMajor; @@ -944,12 +982,36 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle return false; } } - - return GridwiseGemmWelford::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.gemm_e_grid_desc_m_n_, - arg.block_2_etile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemmWelford64::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.gemm_e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + else + { + return false; + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemmWelford32::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.gemm_e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + else + { + return false; + } + } } // polymorphic @@ -1060,9 +1122,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle str << "DeviceGemmMultipleDLayernorm_Xdl_CShuffle" << "<" << BlockSize << ", " - << GemmMPerBlock << ", " - << GemmNPerBlock << ", " - << GemmKPerBlock << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " << AK1 << ", " << BK1 << ", " << 
getGemmSpecializationString(GemmSpec) << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index 8ae6761769..ccc9c7f9b8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -60,26 +60,30 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, const Block2ETileMap block_2_etile_map) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_rs_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - qs_element_op, - rs_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - rs_grid_desc_mblock_mperblock, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_rs_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + qs_element_op, + rs_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + rs_grid_desc_mblock_mperblock, + block_2_etile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -185,6 +189,10 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle { using DeviceOp = DeviceGemmMultipleDMultipleR_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumRTensor = RsDataType::Size(); @@ -282,7 +290,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle using RGridDesc_M = decltype(MakeRGridDescriptor_M(1)); // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + template + using GridwiseGemmBase = GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, CShuffleDataType, @@ -312,7 +321,7 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -335,15 +344,17 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle CDEReduceThreadTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using AGridDesc_AK0_M_AK1 = - remove_cvref_t; using BGridDesc_BK0_N_BK1 = - remove_cvref_t; - using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; + using Block2ETileMap = typename 
GridwiseGemm64::DefaultBlock2ETileMap; // Argument struct Argument : public BaseArgument @@ -370,86 +381,52 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle p_ds_grid_{}, // FIXME p_e_grid_{static_cast(p_e_grid)}, p_rs_grid_{}, // FIXME + MRaw_(MRaw), + NRaw_(NRaw), a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)}, b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)}, r_grid_desc_m_{DeviceOp::MakeRGridDescriptor_M(MRaw)}, a_grid_desc_ak0_m_ak1_{ - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, b_grid_desc_bk0_n_bk1_{ - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, - ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, - e_grid_desc_mblock_mperblock_nblock_nperblock_{}, - rs_grid_desc_mblock_mperblock_{}, - block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, cde_element_op_{cde_element_op}, qs_element_op_{qs_element_op}, rs_element_op_{rs_element_op} { - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, - b_grid_desc_n_k_, - e_grid_desc_m_n_, - r_grid_desc_m_, - block_2_etile_map_)) - { - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); - - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t>; - - p_ds_grid_(i) = static_cast(p_ds_grid[i]); - - const auto d_grid_desc_m_n = - DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]); - - ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - d_grid_desc_m_n); - }); - - static_for<0, NumRTensor, 1>{}([&](auto i) { - using RDataType = remove_cvref_t>; - - p_rs_grid_(i) = static_cast(p_rs_grid[i]); - - rs_grid_desc_mblock_mperblock_(i) = - GridwiseGemm::MakeRGridDescriptor_MBlock_MPerBlock(r_grid_desc_m_); - }); - } + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + stride_ds_[i] = StrideDs[i]; + }); + static_for<0, NumRTensor, 1>{}([&](auto i) { + using RDataType = remove_cvref_t>; + p_rs_grid_(i) = static_cast(p_rs_grid[i]); + }); } // private: // pointers const ADataType* p_a_grid_; const BDataType* p_b_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; + typename GridwiseGemm64::DsGridPointer p_ds_grid_; EDataType* p_e_grid_; - typename GridwiseGemm::RsGridPointer p_rs_grid_; - + typename GridwiseGemm64::RsGridPointer p_rs_grid_; + index_t MRaw_; + index_t NRaw_; + std::array stride_ds_; // tensor descriptors AGridDesc_M_K a_grid_desc_m_k_; BGridDesc_N_K b_grid_desc_n_k_; EGridDesc_M_N e_grid_desc_m_n_; RGridDesc_M r_grid_desc_m_; - // tensor descriptors for block/thread-wise copy AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; - StaticallyIndexedArray< - typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, - NumDTensor> - ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different - // type from E - typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - 
e_grid_desc_mblock_mperblock_nblock_nperblock_; - - StaticallyIndexedArray - rs_grid_desc_mblock_mperblock_; - // block-to-e-tile map Block2ETileMap block_2_etile_map_; @@ -466,7 +443,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template <typename GridwiseGemm> + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, arg.b_grid_desc_n_k_, @@ -476,6 +454,31 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle { throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); } + StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor> + ds_grid_desc_mblock_mperblock_nblock_nperblock = {}; + + StaticallyIndexedArray + rs_grid_desc_mblock_mperblock = {}; + + auto e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.e_grid_desc_m_n_); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + const auto d_grid_desc_m_n = + DeviceOp::MakeEGridDescriptor_M_N(arg.MRaw_, arg.NRaw_, arg.stride_ds_[i]); + ds_grid_desc_mblock_mperblock_nblock_nperblock(i) = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + d_grid_desc_m_n); + }); + + static_for<0, NumRTensor, 1>{}([&](auto i) { + rs_grid_desc_mblock_mperblock(i) = + GridwiseGemm64::MakeRGridDescriptor_MBlock_MPerBlock(arg.r_grid_desc_m_); + }); const index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); @@ -526,9 +529,9 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle arg.rs_element_op_, arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.rs_grid_desc_mblock_mperblock_, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + rs_grid_desc_mblock_mperblock, arg.block_2_etile_map_); }; @@ -546,6 +549,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle return ave_time; } + INVOKER_RUN_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -556,16 +561,33 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - - return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.e_grid_desc_m_n_, - arg.r_grid_desc_m_, - arg.block_2_etile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.e_grid_desc_m_n_, + arg.r_grid_desc_m_, + arg.block_2_etile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.e_grid_desc_m_n_, + arg.r_grid_desc_m_, + arg.block_2_etile_map_); + } + } + return false; } // polymorphic
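The change above to DeviceGemmMultipleDMultipleR_Xdl_CShuffle is the shape this patch stamps into every device op that follows: the gridwise GEMM becomes a template over NXdlPerWave_, a wave64 and a wave32 variant are both instantiated, and IsSupportedArgument selects between them at run time via get_warp_size(), with each branch guarded by if constexpr so a variant whose NXdlPerWave comes out as 0 (no valid tiling for that wave size) is never instantiated. A minimal compilable sketch of that dispatch shape follows; GridwiseExample, DeviceExample, and the nxdl constants are illustrative stand-ins, not CK identifiers.

#include <iostream>

// Stand-in for one GridwiseGemmBase<NXdlPerWave_> instantiation.
template <int NXdlPerWave>
struct GridwiseExample
{
    static_assert(NXdlPerWave > 0, "only reachable behind an if-constexpr guard");
    static bool CheckValidity() { return true; }
};

// Stand-in for a device op that carries both wave-size variants.
template <int Nxdl64, int Nxdl32>
struct DeviceExample
{
    static bool IsSupported(int wave_size) // wave_size plays the role of get_warp_size()
    {
        if(wave_size == 64)
        {
            if constexpr(Nxdl64 > 0)
                return GridwiseExample<Nxdl64>::CheckValidity();
        }
        else
        {
            if constexpr(Nxdl32 > 0)
                return GridwiseExample<Nxdl32>::CheckValidity();
        }
        return false; // no valid instantiation for the detected wave size
    }
};

int main()
{
    // Nxdl32 == 0 models a tuning config with no wave32 tiling.
    std::cout << std::boolalpha
              << DeviceExample<4, 0>::IsSupported(64) << ' '   // true
              << DeviceExample<4, 0>::IsSupported(32) << '\n'; // false
}

The if constexpr guards matter because IsSupported lives in a templated entity, so the discarded branch is never instantiated; without the guard, instantiating the zero-NXdlPerWave variant would trip the static_assert (in CK, the corresponding checks inside the gridwise pipeline) even on hardware that only ever takes the other branch.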
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
index c7481997a9..c051a080ea 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
@@ -54,23 +54,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -161,6 +164,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD { using DeviceOp = DeviceGemmMultipleD_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave<64>(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave<32>(); static constexpr index_t NumDTensor = DsDataType::Size(); @@ -247,7 +253,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD(1, 1, 1)); // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + template <index_t NXdlPerWave_> + using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, @@ -268,7 +275,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD; + using GridwiseGemm64 = GridwiseGemmBase<NXdlPerWave64>; + using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>; // desc for blockwise copy using AGridDesc_AK0_M_AK1 = - remove_cvref_t; using BGridDesc_BK0_N_BK1 = - remove_cvref_t; using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))>; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; // block-to-e-tile map using Block2ETileMap = - remove_cvref_t; + remove_cvref_t; #ifndef __HIPCC_RTC__ // Argument @@ -337,12 +346,10 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD(MRaw, NRaw, StrideE)}, a_grid_desc_ak0_m_ak1_{ - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, b_grid_desc_bk0_n_bk1_{ - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, - ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, - e_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, cde_element_op_{cde_element_op}, @@ -362,22 +369,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD(MRaw, NRaw, StrideDs[i]); }); - - // populate desc for Ds/E - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, - b_grid_desc_n_k_, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_)) - { - ds_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n_); - - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); - } } void Print() const @@ -393,7 +384,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, arg.b_grid_desc_n_k_, @@ -438,6 +427,13 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD()) + { + return false; + } + if(!IsSupported(arg.MRaw_, arg.NRaw_, arg.KRaw_)) { return false; } - return IsSupported(arg.MRaw_, arg.NRaw_, arg.KRaw_) and - GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + return false; } // polymorphic @@ -735,18 +754,18 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD; using AGridDesc_AK0_M_AK1 = - remove_cvref_t; using BGridDesc_BK0_N_BK1 = - remove_cvref_t; using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_tuple()))>; using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>; - using Block2ETileMap = remove_cvref_t; // tensor descriptors for problem definition @@ -790,21 +809,21 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD 0) + { + return GridwiseGemm64::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map) and + IsSupported(MRaw, NRaw, KRaw) and + GridwiseGemm64::template IsValidCompilationParameter<>(); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map) and + IsSupported(MRaw, NRaw, KRaw) and + GridwiseGemm32::template IsValidCompilationParameter<>(); + } + } + return false; } constexpr index_t GetBlockSize() const { return BlockSize; } @@ -854,10 +894,12 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD; + __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()]; if(desc.has_main_k_block_loop) { GridwiseGemm::template Run( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp index 010a23c66c..66e727faa5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -74,10 +74,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad BElementwiseOperation, CDEElementwiseOperation> { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr auto I1 = Number<1>{}; static constexpr index_t NumDTensor = DsDataType::Size(); - using GridwiseGemm = GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad< + template + using GridwiseGemmBase = GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad< ALayout, BLayout, DsLayout, @@ -104,7 +108,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, @@ -121,13 +125,17 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, arg.b_grid_desc_n_k_, @@ -191,6 +199,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad } } + INVOKER_RUN3_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -200,11 +210,10 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - if(!ck::is_lds_direct_load_supported()) { return false; @@ -288,11 +297,29 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad } } - return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + return false; } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp index 0796614bd4..7e9020d796 100644 --- 
a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -82,9 +82,13 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK { static constexpr index_t NumDTensor = DsDataType::Size(); + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3< + template + using GridwiseGemmBase = GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, @@ -108,7 +112,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; - + using Argument = typename GridwiseGemm64::Argument; // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -169,7 +176,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK DsSize; - Argument arg_ = arg; + auto arg_ = arg; const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); @@ -188,8 +195,13 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK>; DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); }); - ck::utility::RotatingMemWrapperMultiD rotating_mem( - arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + ck::utility::RotatingMemWrapperMultiD + rotating_mem(arg_, + stream_config.rotating_count, + size_a_buffer, + size_b_buffer, + DsSize); rotating_mem.Print(); auto run_flush_cache = [&]() { @@ -624,6 +636,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK()) + { + return false; + } + if(is_gfx11_supported() && arg.KBatch > 1) { return false; } - if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) { return false; @@ -658,7 +675,22 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -781,7 +813,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3< + template + using GridwiseGemmBase = GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, @@ -122,7 +126,7 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, 
ABlockTransferSrcAccessOrder, @@ -140,7 +144,7 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, @@ -149,13 +153,17 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 ComputeTypeB, LDSTypeA, LDSTypeB>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -180,7 +188,7 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 const auto Run = [&](const auto& kernel) { if(stream_config.flush_cache) { - Argument arg_ = arg; + auto arg_ = arg; const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); @@ -192,7 +200,7 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); - ck::utility::RotatingMemWrapper rotating_mem( + ck::utility::RotatingMemWrapper rotating_mem( arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); rotating_mem.Print(); @@ -293,6 +301,8 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 return ave_time; } + INVOKER_RUN3_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -309,7 +319,7 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } @@ -328,7 +338,22 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 return false; } - return GridwiseGemm::CheckValidity(arg); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -456,7 +481,7 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 << "BlkGemmPipelineVersion: " << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " << "BlkGemmPipelinePrefetchStages: " - << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages; + << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages; // clang-format on return str.str(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp index 4761ee2026..832267fdd9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -82,10 +82,14 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle BElementwiseOperation, CElementwiseOperation> { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle< + template + using GridwiseGemmBase = GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle< ALayout, BLayout, DsLayout, @@ -109,7 +113,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -136,15 +140,18 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle ComputeTypeB, LDSTypeA, LDSTypeB>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; - + using Argument = typename GridwiseGemm64::Argument; int GetPreShuffleParameters() override { return NPerXDL; } // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -172,7 +179,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle std::array DsSize; - Argument arg_ = arg; + auto arg_ = arg; const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); @@ -191,8 +198,13 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle using DDataType = remove_cvref_t>; DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); }); - ck::utility::RotatingMemWrapperMultiD rotating_mem( - arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + ck::utility::RotatingMemWrapperMultiD + rotating_mem(arg_, + stream_config.rotating_count, + size_a_buffer, + size_b_buffer, + DsSize); rotating_mem.Print(); auto run_flush_cache = [&]() { @@ -492,6 +504,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle return ave_time; } + INVOKER_RUN3_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -508,11 +522,14 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) + { + return false; + } + if(is_gfx11_supported() && arg.KBatch > 1) { return false; } - if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) { return false; @@ -531,7 +548,22 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle return false; } - return GridwiseGemm::CheckValidity(arg); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -652,7 +684,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle << "BlkGemmPipelineVersion: " << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " << "BlkGemmPipelinePrefetchStages: " - << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages; + << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages; // 
clang-format on return str.str(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp index c446ca59ea..4dcdff9153 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp @@ -93,9 +93,13 @@ struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle CElementwiseOperation> { static constexpr index_t NumDTensor = DsDataType::Size(); + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle< + template + using GridwiseGemmBase = GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle< ALayout, BLayout, DsLayout, @@ -122,7 +126,7 @@ struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -140,7 +144,7 @@ struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, @@ -149,15 +153,19 @@ struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle ComputeTypeB, LDSTypeA, LDSTypeB>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; int GetPreShuffleParameters() override { return NPerXDL; } // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -182,7 +190,7 @@ struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle const auto Run = [&](const auto& kernel) { if(stream_config.flush_cache) { - Argument arg_ = arg; + auto arg_ = arg; const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); @@ -194,7 +202,7 @@ struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); - ck::utility::RotatingMemWrapper rotating_mem( + ck::utility::RotatingMemWrapper rotating_mem( arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); rotating_mem.Print(); @@ -322,6 +330,8 @@ struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle return ave_time; } + INVOKER_RUN3_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -338,16 +348,19 @@ struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - // if(ScaleBlockM % MPerBlock != 0 
|| ScaleBlockN % NPerBlock != 0 || ScaleBlockK != // KPerBlock) // { // return false; // } + if(is_gfx11_supported() && arg.KBatch > 1) + { + return false; + } if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) { return false; @@ -367,7 +380,22 @@ struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle return false; } - return GridwiseGemm::CheckValidity(arg); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg)); + } + } + return false; } // polymorphic @@ -495,7 +523,7 @@ struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle << "BlkGemmPipelineVersion: " << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " << "BlkGemmPipelinePrefetchStages: " - << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages; + << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages; // clang-format on return str.str();
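Alongside the wave-size dispatch, the patch keeps making one companion change, visible in the device_gemm_reduce file that follows: the block-mapped grid descriptors are no longer precomputed and cached inside the Argument constructor (which would pin their types to a single gridwise variant); the Argument retains only the raw M/N/K descriptors, and the templated RunImp rebuilds the gridwise-specific descriptors immediately before launch. A small sketch of that motion, with illustrative stand-in names (GridwiseExample, ArgExample, RunImp below are not the CK definitions):

#include <iostream>

// Stand-in for one gridwise variant; MPerBlock stands in for the tiling
// parameters that differ between the wave64 and wave32 instantiations.
template <int MPerBlock>
struct GridwiseExample
{
    struct CGridDesc_MBlock { int mblock; }; // stand-in block-mapped descriptor
    static CGridDesc_MBlock MakeCGridDesc(int m_raw)
    {
        return {(m_raw + MPerBlock - 1) / MPerBlock};
    }
};

struct ArgExample
{
    int m_raw; // only raw problem sizes live in the argument now
};

template <typename Gridwise>
float RunImp(const ArgExample& arg)
{
    // rebuilt locally, just before the kernel launch, with the chosen variant
    const auto c_desc = Gridwise::MakeCGridDesc(arg.m_raw);
    std::cout << "MBlock tiles: " << c_desc.mblock << '\n';
    return 0.0f;
}

int main()
{
    ArgExample arg{1000};
    RunImp<GridwiseExample<128>>(arg); // descriptor shaped for one variant
    RunImp<GridwiseExample<64>>(arg);  // same argument, differently tiled descriptor
}

The INVOKER_RUN_IMPL / INVOKER_RUN3_IMPL macros sprinkled through the patch presumably expand to the non-template Run overload that performs the get_warp_size() check and forwards to RunImp<GridwiseGemm64> or RunImp<GridwiseGemm32>; their definitions are outside this diff, so that reading is an inference from the call sites.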
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp
index c88294edfa..14a1824508 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp
@@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -78,6 +78,10 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio { using DeviceOp = DeviceGemmReduce_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave<64>(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave<32>(); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -378,7 +382,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1)); // GridwiseGemm - using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + template <index_t NXdlPerWave_> + using GridwiseGemmBase = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, CShuffleDataType, @@ -407,7 +412,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -432,6 +437,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched>; + using GridwiseGemm64 = GridwiseGemmBase<NXdlPerWave64>; + using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>; // Argument struct Argument : public BaseArgument @@ -459,27 +466,13 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)}, - c_grid_desc_mblock_mperblock_nblock_nperblock_{}, - reduce_grid_desc_mblock_mperblock_{}, - block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + block_2_ctile_map_{GridwiseGemm64::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op}, reduce_in_element_ops_{reduce_in_element_ops}, reduce_out_element_ops_{reduce_out_element_ops} { - if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, - b_grid_desc_bk0_n_bk1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n_); - - reduce_grid_desc_mblock_mperblock_ = - GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_); - } } // private: @@ -491,11 +484,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; CGridDesc_M_N c_grid_desc_m_n_; ReduceGridDesc_M reduce_grid_desc_m_; - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock - reduce_grid_desc_mblock_mperblock_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + typename GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map_; AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; CElementwiseOperation c_element_op_; @@ -508,7 +497,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template <typename GridwiseGemm> + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { @@ -536,6 +526,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio { throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); } + auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.c_grid_desc_m_n_); + + auto reduce_grid_desc_mblock_mperblock = + GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(arg.reduce_grid_desc_m_); const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); @@ -563,26 +559,25 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio typename GridwiseGemm::DefaultBlock2CTileMap, true>; - elapsed_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_reduces_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.reduce_in_element_ops_, - arg.reduce_out_element_ops_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.reduce_grid_desc_mblock_mperblock_, - arg.block_2_ctile_map_); + elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_reduces_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + c_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, + arg.block_2_ctile_map_); } else { @@ -603,31 +598,32 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio typename GridwiseGemm::DefaultBlock2CTileMap, false>; - elapsed_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_reduces_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.reduce_in_element_ops_, - arg.reduce_out_element_ops_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.reduce_grid_desc_mblock_mperblock_, - arg.block_2_ctile_map_); + elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_reduces_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + c_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, + arg.block_2_ctile_map_); } return elapsed_time; } + INVOKER_RUN_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -644,15 +640,31 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - - return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return 
GridwiseGemm32::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp index 5188ece333..2e4abe9819 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -69,6 +69,10 @@ struct DeviceGemmXdl : public DeviceGemm { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -76,7 +80,8 @@ struct DeviceGemmXdl : public DeviceGemm{}; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext< + template + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext< BlockSize, ADataType, // TODO: distinguish A/B datatype AccDataType, @@ -96,7 +101,7 @@ struct DeviceGemmXdl : public DeviceGemm; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& karg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -160,6 +169,8 @@ struct DeviceGemmXdl : public DeviceGemm 0) + { + return GridwiseGemm64::CheckValidity(karg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(karg)); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp index dd41f4bca0..327b9523b8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -80,12 +80,17 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + template + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ALayout, BLayout, CLayout, @@ -109,7 +114,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -177,6 +186,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm()) { return false; } - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding || @@ -206,7 +216,22 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp index ac2e826725..8daaafaed1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -69,9 +69,14 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr auto I1 = Number<1>{}; - using GridwiseGemm = GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad< + template + using GridwiseGemmBase = GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad< ALayout, BLayout, ck::Tuple<>, @@ -98,7 +103,7 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, arg.b_grid_desc_n_k_, @@ -185,6 +194,8 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm()) { return false; } - if(!ck::is_lds_direct_load_supported()) { return false; @@ -264,11 +274,29 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + return false; } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp index 3171208830..c42228369f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -75,8 +75,13 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2 { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + // GridwiseGemm - using GridwiseGemm = GridwiseGemm_xdl_cshuffle_streamk_v3< + template + using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_streamk_v3< ALayout, BLayout, CLayout, @@ -98,7 +103,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; + // // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) @@ -176,8 +186,8 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2 rotating_mem( + auto arg_ = arg; + ck::utility::RotatingMemWrapper rotating_mem( arg_, stream_config.rotating_count, arg_.M * arg_.K * sizeof(ADataType), @@ -426,6 +436,8 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2()) { return false; } @@ -481,7 +498,22 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -489,30 +521,30 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2(p_arg)); } - + template static auto - MakeArgument(const ADataType* p_a, - const BDataType* p_b, - CDataType* p_c, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - index_t streamk_sel, - index_t Grid_size, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - StreamKReductionStrategy reduction_strategy = StreamKReductionStrategy::Atomic) + MakeArgumentImp(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t streamk_sel, + index_t Grid_size, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + StreamKReductionStrategy reduction_strategy = StreamKReductionStrategy::Atomic) { constexpr index_t minimum_occupancy = BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 
1 : 2; - index_t K_split = (K + KPerBlock - 1) / KPerBlock * KPerBlock; - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - int occupancy, num_cu; + index_t K_split = (K + KPerBlock - 1) / KPerBlock * KPerBlock; + + int occupancy = 1, num_cu = 1; const auto calculate_grid_size = [&](const auto& kernel) { hip_check_error( hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); @@ -524,185 +556,193 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2; - calculate_grid_size(kernel); - } - // Tail number could be One to Seven - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) - { - - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) { + const auto kernel = kernel_gemm_xdl_cshuffle_v3; + minimum_occupancy>; calculate_grid_size(kernel); } - else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) { - const auto kernel = kernel_gemm_xdl_cshuffle_v3; - calculate_grid_size(kernel); - } - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) { const auto kernel = kernel_gemm_xdl_cshuffle_v3; + TailNumber::One>; calculate_grid_size(kernel); } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three) + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) { const auto kernel = kernel_gemm_xdl_cshuffle_v3; + TailNumber::Full>; calculate_grid_size(kernel); } - } - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + } + // Tail number could be Odd or 
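Two details in the hunk above are easy to miss: K_split rounds K up to a whole number of KPerBlock tiles before the tail-number queries, and `occupancy`/`num_cu` are now initialized to 1 instead of being left indeterminate before the HIP occupancy call. A standalone HIP sketch of the same grid-sizing recipe, with a trivial placeholder kernel (the tile constants are examples, not a CK instance):

#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void dummy_kernel() {}

int main()
{
    constexpr int BlockSize = 256;
    constexpr int KPerBlock = 32;
    const int K = 1000;
    // Round K up to a whole number of KPerBlock tiles, as K_split does above.
    const int K_split = (K + KPerBlock - 1) / KPerBlock * KPerBlock;

    int occupancy = 1; // initialized up front, matching the patch's fix
    if(hipOccupancyMaxActiveBlocksPerMultiprocessor(
           &occupancy, dummy_kernel, BlockSize, /*dynamic LDS*/ 0) != hipSuccess)
        return 1;

    hipDeviceProp_t prop{};
    if(hipGetDeviceProperties(&prop, 0) != hipSuccess)
        return 1;

    // One grid slot per resident block per CU, the stream-k starting point.
    const int grid_size = prop.multiProcessorCount * occupancy;
    std::printf("K_split=%d grid_size=%d\n", K_split, grid_size);
    return 0;
}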
Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four) + + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { const auto kernel = - kernel_gemm_xdl_cshuffle_v3; + kernel_gemm_xdl_cshuffle_v3_2lds; calculate_grid_size(kernel); } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five) + else { const auto kernel = - kernel_gemm_xdl_cshuffle_v3; + kernel_gemm_xdl_cshuffle_v3_2lds; calculate_grid_size(kernel); } } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - calculate_grid_size(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - calculate_grid_size(kernel); - } - } - } - // Tail number could be Odd or Even - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3_2lds; - calculate_grid_size(kernel); - } else { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3_2lds; - calculate_grid_size(kernel); + + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } } } else { - - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { + const auto kernel = kernel_gemm_xdl_cshuffle_v3; + minimum_occupancy>; calculate_grid_size(kernel); } - else - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3; - calculate_grid_size(kernel); - } - } - } - else - { - // Tail number always 1 - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - - const auto kernel = kernel_gemm_xdl_cshuffle_v3; - calculate_grid_size(kernel); } } @@ -720,6 +760,62 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2 0; + return MakeArgumentImp(p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + streamk_sel, + Grid_size, + a_op, + b_op, + c_op, + reduction_strategy); + } + else + { + constexpr bool IsValid = NXdlPerWave32 > 0; + return MakeArgumentImp(p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + streamk_sel, + Grid_size, + a_op, + b_op, + c_op, + reduction_strategy); + } + } static auto MakeInvoker() { return Invoker{}; } // polymorphic @@ -799,7 +895,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2 { using DeviceOp = DeviceGemm_Xdl_CShuffleV2; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v2< + template + using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v2< ALayout, BLayout, CLayout, @@ -109,7 +113,7 @@ struct DeviceGemm_Xdl_CShuffleV2 : 
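MakeArgument itself becomes a thin front end: it computes IsValid = NXdlPerWave64 > 0 (or the wave32 counterpart) at compile time and forwards to MakeArgumentImp, so the kernel-selection and grid-sizing code inside the implementation can be compiled out wholesale when a tiling has no legal form at that wave size. A reduced model of the flag forwarding, with invented placeholder types:

#include <cstdio>

constexpr int NXdlPerWave64 = 4;
constexpr int NXdlPerWave32 = 0; // e.g. no legal wave32 tiling for this config

int warp_size() { return 64; } // stand-in for ck::get_warp_size()

struct Argument { int M, N, K, grid_size; };

template <int NXdlPerWave>
struct Gemm
{
    static int GridSize(int M, int N) { return (M / 128) * (N / 128); }
};

template <typename GridwiseGemm, bool IsValid>
Argument make_argument_imp(int M, int N, int K)
{
    int grid = 0;
    if constexpr(IsValid)
    {
        // only a valid instantiation may reach its kernels
        grid = GridwiseGemm::GridSize(M, N);
    }
    return {M, N, K, grid};
}

Argument make_argument(int M, int N, int K)
{
    if(warp_size() == 64)
        return make_argument_imp<Gemm<NXdlPerWave64>, (NXdlPerWave64 > 0)>(M, N, K);
    return make_argument_imp<Gemm<NXdlPerWave32>, (NXdlPerWave32 > 0)>(M, N, K);
}

int main() { std::printf("grid=%d\n", make_argument(256, 256, 64).grid_size); }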
public DeviceGemm; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -174,6 +182,8 @@ struct DeviceGemm_Xdl_CShuffleV2 : public DeviceGemm()) { return false; } - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding || @@ -203,7 +212,22 @@ struct DeviceGemm_Xdl_CShuffleV2 : public DeviceGemm 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp index 1cb82d24eb..d100d96583 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp @@ -176,31 +176,8 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 { - template - static constexpr auto GetNXdlPerWave() - { - constexpr index_t Waves = isWave64 ? BlockSize / 64 : BlockSize / 32; - constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXDL); - static_assert(MWaves > 0); - - constexpr index_t NWaves = Waves / MWaves; - if constexpr(NWaves == 0) - { - return 0; - } - else - { - if constexpr(NPerBlock % (NPerXDL * NWaves) == 0) - { - return NPerBlock / (NWaves * NPerXDL); - } - else - { - return 0; - } - } - } // GridwiseGemm + GET_NXDL_PER_WAVE_IMPL static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); @@ -284,6 +261,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 float RunImp(const typename GridwiseGemm::Argument& arg, const StreamConfig& stream_config = StreamConfig{}) @@ -760,31 +742,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 0) - { - return RunImp(arg, stream_config); - } - } - else - { - if constexpr(NXdlPerWave32 > 0) - { - return RunImp( - reinterpret_cast(arg), - stream_config); - } - } - return 0; - } + INVOKER_RUN3_IMPL // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -801,11 +759,10 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2()) { return false; } - if(arg.KBatch > 1) { if(is_gfx11_supported()) @@ -824,14 +781,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 || @@ -855,10 +804,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2(arg)); } - else - { - return false; - } } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index faa235be50..ebd168a7d0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -77,8 +77,13 @@ struct 
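The v3 header is where the per-file GetNXdlPerWave definition (removed above) moves into the shared GET_NXDL_PER_WAVE_IMPL macro. The logic is worth keeping in view, since every NXdlPerWave64/32 constant in this patch comes from it: divide the block into waves, spend MWaves on the M tiling, give N whatever waves remain, and return 0 whenever the N tile cannot be split evenly, which is exactly the sentinel the `if constexpr(... > 0)` guards test. A standalone rendering of the removed body, with example tile parameters:

#include <cstdio>

constexpr int BlockSize   = 256;
constexpr int MPerBlock   = 128;
constexpr int NPerBlock   = 128;
constexpr int MPerXDL     = 16;
constexpr int NPerXDL     = 16;
constexpr int MXdlPerWave = 2;

template <bool IsWave64>
constexpr int get_n_xdl_per_wave()
{
    constexpr int waves  = IsWave64 ? BlockSize / 64 : BlockSize / 32;
    constexpr int mwaves = MPerBlock / (MXdlPerWave * MPerXDL);
    static_assert(mwaves > 0, "tile must occupy at least one wave in M");
    constexpr int nwaves = waves / mwaves;
    if constexpr(nwaves == 0)
        return 0; // no waves left over for the N dimension
    else if constexpr(NPerBlock % (NPerXDL * nwaves) == 0)
        return NPerBlock / (nwaves * NPerXDL);
    else
        return 0; // N tile does not divide evenly: instantiation is invalid
}

int main()
{
    std::printf("wave64: %d, wave32: %d\n",
                get_n_xdl_per_wave<true>(), get_n_xdl_per_wave<false>());
}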
DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + // GridwiseGemm - using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3_b_preshuffle< + template + using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v3_b_preshuffle< ALayout, BLayout, CLayout, @@ -100,7 +105,7 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; static constexpr index_t APackedSize = []() { if constexpr(is_same_v, pk_i4_t>) @@ -149,7 +156,9 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -175,7 +184,7 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle rotating_mem( + ck::utility::RotatingMemWrapper rotating_mem( arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); rotating_mem.Print(); @@ -377,6 +386,8 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle()) + { + return false; + } + if(is_gfx11_supported() && arg.KBatch > 1) { return false; } - if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) { return false; @@ -411,7 +425,22 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -516,9 +545,9 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + // GridwiseGemm - using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< + template + using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, @@ -109,7 +114,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; static constexpr index_t APackedSize = []() { if constexpr(is_same_v, pk_i4_t>) @@ -156,7 +163,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -181,7 +190,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale rotating_mem( + ck::utility::RotatingMemWrapper rotating_mem( arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); rotating_mem.Print(); @@ -622,6 +631,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale()) + { + return false; + } + + if(is_gfx11_supported() && arg.KBatch > 1) { return false; } @@ -655,8 +670,23 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { - return GridwiseGemm::CheckValidity(arg); + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + 
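Both the B-preshuffle and B-scale ops cross the type boundary by reinterpret_cast-ing the stored wave64 Argument to the wave32 instantiation's Argument type. That is formally type punning between distinct class types; it holds up only because the Argument layout is determined by the problem description and never depends on the NXdlPerWave parameter. A sketch of the assumption, with static_asserts one could add as a guard (the asserts are a suggestion, not part of the patch):

#include <type_traits>

template <int NXdlPerWave>
struct GridwiseGemmBase
{
    // fields depend only on problem size, never on NXdlPerWave
    struct Argument { int M, N, K, StrideA, StrideB, StrideC, KBatch; };
};

using Arg64 = GridwiseGemmBase<4>::Argument;
using Arg32 = GridwiseGemmBase<8>::Argument;

// Distinct types, but bit-identical layout; make the assumption checkable:
static_assert(sizeof(Arg64) == sizeof(Arg32), "Argument layout must match");
static_assert(std::is_trivially_copyable_v<Arg64> &&
                  std::is_trivially_copyable_v<Arg32>,
              "reinterpret-cast dispatch needs trivially copyable Arguments");

bool check(const Arg64& arg)
{
    // the pattern used throughout the patch, shown in isolation
    const auto& arg32 = reinterpret_cast<const Arg32&>(arg);
    return arg32.K == arg.K;
}

int main() { return check(Arg64{128, 128, 64, 64, 64, 128, 1}) ? 0 : 1; }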
reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -783,7 +813,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale { - // GridwiseGemm - using GridwiseGemm = conditional_t< // - !is_same_v, - GridwiseGemmMX_xdl_cshuffle_v3< - ALayout, - BLayout, - CLayout, - ADataType, - AScaleDataType, - BDataType, - BScaleDataType, - GemmAccDataType, - CShuffleDataType, - CDataType, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - GemmSpec, - ScaleBlockSize, - BlockSize, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CShuffleBlockTransferScalarPerVector_NPerBlock, - BlkGemmPipeSched, - BlkGemmPipelineVer, - ComputeTypeA, - ComputeTypeB>, - GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< - ALayout, - BLayout, - CLayout, - ADataType, - AScaleDataType, - BDataType, - BScaleDataType, - GemmAccDataType, - CShuffleDataType, - CDataType, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - GemmSpec, - ScaleBlockSize, - BlockSize, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CShuffleBlockTransferScalarPerVector_NPerBlock, - BlkGemmPipeSched, - BlkGemmPipelineVer, - ComputeTypeA, - ComputeTypeB>>; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - using Argument = typename GridwiseGemm::Argument; + // GridwiseGemm + template + using GridwiseGemmMXBase = GridwiseGemmMX_xdl_cshuffle_v3< + ALayout, + BLayout, + CLayout, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + ScaleBlockSize, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave_, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + 
BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; + template + using GridwiseGemmMXBPreshuffleBase = GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< + ALayout, + BLayout, + CLayout, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + ScaleBlockSize, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave_, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; + + using GridwiseGemm64 = conditional_t< // + !is_same_v, + GridwiseGemmMXBase, + GridwiseGemmMXBPreshuffleBase>; + using GridwiseGemm32 = conditional_t< // + !is_same_v, + GridwiseGemmMXBase, + GridwiseGemmMXBPreshuffleBase>; + + using Argument = typename GridwiseGemm64::Argument; // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -299,7 +314,7 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX rotating_mem( + ck::utility::RotatingMemWrapper rotating_mem( arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); rotating_mem.Print(); @@ -401,6 +416,7 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -578,9 +609,9 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr index_t NumDTensor = DsDataType::Size(); using PassThrough = ck::tensor_operation::element_wise::PassThrough; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< + template + using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, @@ -112,7 +117,7 @@ struct DeviceGemm_Xdl_CShuffleV3R1 : public DeviceGemmV2R1; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - struct 
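For the MX op the old single conditional_t choosing between the plain and B-preshuffle gridwise gemm could not be templated in place, so the patch first names each family as an NXdlPerWave_-parameterized alias and only then takes the is_same_v-based conditional_t per wave size: a two-axis selection, pipeline family times wave width. A compact model of that product, with invented layout tags standing in for the condition's operands:

#include <type_traits>

struct RowMajor {};
struct PreshuffledB {}; // hypothetical stand-in for the preshuffled B layout

template <int N> struct MXBase        { static constexpr int n = N; };
template <int N> struct MXBPreshuffle { static constexpr int n = -N; };

template <typename BLayout, int NXdlPerWave>
using GridwiseGemmFor =
    std::conditional_t<!std::is_same_v<BLayout, PreshuffledB>,
                       MXBase<NXdlPerWave>,
                       MXBPreshuffle<NXdlPerWave>>;

using GridwiseGemm64 = GridwiseGemmFor<RowMajor, 8>;
using GridwiseGemm32 = GridwiseGemmFor<RowMajor, 4>;

static_assert(GridwiseGemm64::n == 8 && GridwiseGemm32::n == 4);

int main() {}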
Argument : public GridwiseGemm::Argument + struct Argument : public GridwiseGemm64::Argument { Argument(const ADataType* p_a_grid_, const BDataType* p_b_grid_, @@ -152,17 +159,17 @@ struct DeviceGemm_Xdl_CShuffleV3R1 : public DeviceGemmV2R1 StrideDs_, index_t StrideC_, index_t k_batch_) - : GridwiseGemm::Argument(p_a_grid_, - p_b_grid_, - reinterpret_cast(p_c_grid_), - M_, - N_, - K_, - StrideA_, - StrideB_, - StrideC_, - k_batch_, - true), + : GridwiseGemm64::Argument(p_a_grid_, + p_b_grid_, + reinterpret_cast(p_c_grid_), + M_, + N_, + K_, + StrideA_, + StrideB_, + StrideC_, + k_batch_, + true), p_ds(p_ds_), StrideDs(StrideDs_) { @@ -278,9 +285,10 @@ struct DeviceGemm_Xdl_CShuffleV3R1 : public DeviceGemmV2R1 + float RunImp(const Argument& arg_, const StreamConfig& stream_config = StreamConfig{}) { - auto arg = *dynamic_cast(&arg_); + auto arg = *reinterpret_cast(&arg_); if(!(!(arg.IsReduceAdd() || NumDTensor > 0) && std::is_same::value)) @@ -542,6 +550,8 @@ struct DeviceGemm_Xdl_CShuffleV3R1 : public DeviceGemmV2R1()) { return false; } - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding || @@ -571,7 +580,22 @@ struct DeviceGemm_Xdl_CShuffleV3R1 : public DeviceGemmV2R1 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -677,7 +701,7 @@ struct DeviceGemm_Xdl_CShuffleV3R1 : public DeviceGemmV2R1(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -380,7 +384,8 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator using C0GridDesc_N = decltype(MakeGridDescriptor_N(1)); // GridwiseGemm - using GridwiseGemm = GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + template + using GridwiseGemmBase = GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, CShuffleDataType, @@ -406,7 +411,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -424,14 +429,16 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + NXdlPerWave_, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadCopySrcDstScalarPerVector_NPerBlock, LoopSched>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Block2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + using Block2CTileMap = typename GridwiseGemm64::DefaultBlock2CTileMap; // Argument struct Argument : public BaseArgument @@ -464,26 +471,12 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, c0_grid_desc_n_{MakeGridDescriptor_N(NRaw)}, - c_grid_desc_mblock_mperblock_nblock_nperblock_{}, - c0_grid_desc_nblock_nperblock_{}, 
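In the R1 op the device-side Argument is not a plain alias but derives from GridwiseGemm64::Argument, adding the D-tensor pointers and strides on top; the invoker then works on the base subobject the gridwise kernel understands. (The patch also swaps the old dynamic_cast for a reinterpret_cast, which drops the requirement that Argument be polymorphic.) Reduced shape of that layering, with simplified types:

#include <array>
#include <cstdio>

struct GridwiseArgBase // stands in for GridwiseGemm64::Argument
{
    const float* p_a;
    const float* p_b;
    float* p_c;
    int M, N, K;
};

struct Argument : public GridwiseArgBase // device op adds the D-tensor state
{
    std::array<const void*, 2> p_ds;
    std::array<int, 2> StrideDs;
};

float run(const Argument& arg)
{
    // the gridwise kernel only ever sees the base subobject
    const GridwiseArgBase& karg = arg;
    return static_cast<float>(karg.M + karg.N);
}

int main()
{
    Argument a{{nullptr, nullptr, nullptr, 128, 128, 64}, {}, {}};
    std::printf("%.0f\n", run(a));
}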
block_2_ctile_map_{Block2CTileMap(c_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, acc_element_op_{acc_element_op}, c_element_op_{c_element_op} { - if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, - b_grid_desc_bk0_n_bk1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n_); - - c0_grid_desc_nblock_nperblock_ = - GridwiseGemm::MakeC0GridDescriptor_NBlock_NPerBlock(c0_grid_desc_n_); - } } // private: @@ -498,9 +491,6 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; CGridDesc_M_N c_grid_desc_m_n_; C0GridDesc_N c0_grid_desc_n_; - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock_; Block2CTileMap block_2_ctile_map_; AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; @@ -513,7 +503,8 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { @@ -538,7 +529,12 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator { throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); } + auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.c_grid_desc_m_n_); + auto c0_grid_desc_nblock_nperblock = + GridwiseGemm::MakeC0GridDescriptor_NBlock_NPerBlock(arg.c0_grid_desc_n_); const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); @@ -565,28 +561,27 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator Block2CTileMap, true>; - ave_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_c0_grid_bias_, - arg.p_c0_grid_add_, - arg.p_c0_grid_gamma_, - arg.p_c0_grid_beta_, - arg.a_element_op_, - arg.b_element_op_, - arg.acc_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.c0_grid_desc_nblock_nperblock_, - arg.block_2_ctile_map_); + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_bias_, + arg.p_c0_grid_add_, + arg.p_c0_grid_gamma_, + arg.p_c0_grid_beta_, + arg.a_element_op_, + arg.b_element_op_, + arg.acc_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_desc_nblock_nperblock, + arg.block_2_ctile_map_); } else { @@ -605,33 +600,33 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock, Block2CTileMap, false>; - ave_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_c0_grid_bias_, - arg.p_c0_grid_add_, - arg.p_c0_grid_gamma_, - arg.p_c0_grid_beta_, - arg.a_element_op_, - arg.b_element_op_, - 
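The LayerNorm op shows a second way to keep Argument instantiation-neutral: instead of conditionally caching the blocked C/C0 descriptors in the constructor (the removed branch above), RunImp derives them from c_grid_desc_m_n_ and c0_grid_desc_n_ right before launch, so the cached members, which were typed on the old single GridwiseGemm, can be deleted outright. Minimal model of moving derived state from the argument into the run path:

#include <cstdio>

struct CGridDescMN { int M, N; };

template <int NXdlPerWave>
struct Gridwise
{
    struct Blocked { int MBlock, NBlock; };
    static Blocked MakeBlocked(const CGridDescMN& d)
    {
        return {d.M / 128, d.N / (16 * NXdlPerWave)};
    }
};

struct Argument
{
    CGridDescMN c_grid_desc_m_n; // instantiation-agnostic state only
};

template <typename G>
void run(const Argument& arg)
{
    // blocked descriptor derived here, not cached in Argument
    const auto blocked = G::MakeBlocked(arg.c_grid_desc_m_n);
    std::printf("%d x %d blocks\n", blocked.MBlock, blocked.NBlock);
}

int main()
{
    const Argument arg{{256, 256}};
    run<Gridwise<2>>(arg);
}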
arg.acc_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.c0_grid_desc_nblock_nperblock_, - arg.block_2_ctile_map_); + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_bias_, + arg.p_c0_grid_add_, + arg.p_c0_grid_gamma_, + arg.p_c0_grid_beta_, + arg.a_element_op_, + arg.b_element_op_, + arg.acc_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_desc_nblock_nperblock, + arg.block_2_ctile_map_); } return ave_time; } + INVOKER_RUN_IMPL // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -648,15 +643,31 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - - return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp index 7315fe75a3..bc192b7651 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp @@ -63,6 +63,10 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -188,7 +192,8 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1< BlockSize, ADataType, // TODO: distinguish A/B datatype AccDataType, @@ -207,7 +212,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm, // CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -246,8 +253,6 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { @@ -323,7 +311,8 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm()) { return false; } - - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + 
arg.M01_, + arg.N01_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp index 2666051c86..ef4593c320 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -75,6 +75,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -86,7 +90,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK + using GridwiseGemmBase = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, @@ -107,7 +112,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - struct Argument : public GridwiseGemm::Argument + struct Argument : public GridwiseGemm64::Argument { Argument(const ADataType* p_a_grid_, const BDataType* p_b_grid_, @@ -154,20 +161,20 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { - Print(karg); + Print(arg); } + typename GridwiseGemm::Argument karg(arg.p_a_grid, + arg.p_b_grid, + arg.p_c_grid, + arg.M, + arg.N, + arg.K, + arg.StrideA, + arg.StrideB, + arg.StrideC, + arg.MPadded, + arg.NPadded, + arg.KPadded, + arg.K0Padded, + arg.k_batch); const auto kbatch = karg.k_batch; - if(!GridwiseGemm::CheckValidity(karg)) { throw std::runtime_error( @@ -227,9 +248,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK(karg), b2c_map, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + arg.a_element_op, + arg.b_element_op, + arg.c_element_op); }; if(has_main_k0_block_loop) @@ -294,6 +315,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK()) + { + return false; + } + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(karg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(karg)); + } + } + return false; } // polymorphic @@ -347,10 +389,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK PipelineVersionToString{{PipelineVersion::v1, "v1"}, {PipelineVersion::v2, "v2"}}; - str << GridwiseGemm::GetTypeString() << " LoopScheduler: " << LoopSchedToString[LoopSched] + str << GridwiseGemm64::GetTypeString() << " LoopScheduler: " << LoopSchedToString[LoopSched] << ", PipelineVersion: " << PipelineVersionToString[PipelineVer]; return str.str(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp 
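The split-k invoker takes yet another route across the Argument types: its RunImp rebuilds a typename GridwiseGemm::Argument from the stored fields (p_a_grid through k_batch) rather than punning the reference. That costs a field list that must be kept in sync with the real struct, but it involves no aliasing assumptions at all. The trade-off in miniature:

#include <cstdio>

template <int NXdlPerWave>
struct Gemm
{
    struct Argument
    {
        const float *p_a, *p_b;
        float* p_c;
        int M, N, K, k_batch;
    };
};

template <typename Dst, typename Src>
typename Dst::Argument rebuild(const typename Src::Argument& a)
{
    // field-by-field copy: no layout assumptions, but this list must track
    // the real Argument exactly
    return {a.p_a, a.p_b, a.p_c, a.M, a.N, a.K, a.k_batch};
}

int main()
{
    Gemm<8>::Argument a64{nullptr, nullptr, nullptr, 128, 128, 64, 2};
    auto a32 = rebuild<Gemm<4>, Gemm<8>>(a64);
    std::printf("k_batch=%d\n", a32.k_batch);
}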
b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp index eda966c48a..f9e164e7b3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -70,12 +70,17 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - using GridwiseGemm = GridwiseGemm_xdlops_splitk_lds_direct_load< + template + using GridwiseGemmBase = GridwiseGemm_xdlops_splitk_lds_direct_load< BlockSize, ADataType, BDataType, @@ -96,7 +101,7 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - struct Argument : public GridwiseGemm::Argument + struct Argument : public GridwiseGemm64::Argument { Argument(const ADataType* p_a_grid_, const BDataType* p_b_grid_, @@ -134,20 +141,20 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK + void Print(const Argument_& karg) + { + karg.Print(); + } - void Print(const Argument& karg) { karg.Print(); } - - float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -175,8 +186,8 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK(&karg); + if(!GridwiseGemm::CheckValidity(arg)) { throw std::runtime_error( "wrong! 
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid " @@ -199,17 +210,16 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK(karg), - b2c_map, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg, + b2c_map, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); }; if(has_main_k0_block_loop) @@ -274,6 +284,8 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK()) { return false; } - return GridwiseGemm::CheckValidity(karg); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(karg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(karg)); + } + } + return false; } // polymorphic @@ -327,10 +358,10 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk< + template + using GridwiseGemmBase = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk< BlockSize, BlockToCTileMap_GemmStreamK; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; // Invoker struct Invoker : public BaseInvoker { - void Print(const Argument& karg) { karg.Print(); } + template + void Print(const Argument_& karg) + { + karg.Print(); + } - float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& karg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -204,6 +217,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK(pArg); - if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy == - StreamKReductionStrategy::Reduction) + if(get_warp_size() == 64) { - return p_arg->block_mapping.get_workspace_size(sizeof(typename GridwiseGemm::FloatAcc)); + if constexpr(GridwiseGemm64::Block2CTileMap::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + return p_arg->block_mapping.get_workspace_size( + sizeof(typename GridwiseGemm64::FloatAcc)); + } } else { - return 0; + if constexpr(GridwiseGemm32::Block2CTileMap::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + return p_arg->block_mapping.get_workspace_size( + sizeof(typename GridwiseGemm32::FloatAcc)); + } } + return 0; } void SetWorkSpacePointer(BaseArgument* pArg, @@ -243,11 +268,26 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK()) { return false; } - return GridwiseGemm::CheckValidity(karg); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(karg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(karg)); + } + } + return false; } // polymorphic @@ -270,12 +310,38 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK; - int occupancy, num_cu; + int num_cu; hipError_t rtn; - rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor( - &occupancy, kernel, 
BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte()); - hip_check_error(rtn); + int occupancy = [&]() { + int occupancy_ = 0; + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + const auto kernel = kernel_gemm_xdlops_streamk; + rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor( + &occupancy_, + kernel, + BlockSize, + GridwiseGemm64::GetSharedMemoryNumberOfByte()); + hip_check_error(rtn); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + const auto kernel = kernel_gemm_xdlops_streamk; + rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor( + &occupancy_, + kernel, + BlockSize, + GridwiseGemm32::GetSharedMemoryNumberOfByte()); + hip_check_error(rtn); + } + } + return occupancy_; + }(); hipDeviceProp_t dev_prop; hipDevice_t dev; @@ -316,12 +382,39 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK; - int occupancy, num_cu; + int num_cu; hipError_t rtn; - rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor( - &occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte()); - hip_check_error(rtn); + + int occupancy = [&]() { + int occupancy_ = 0; + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + const auto kernel = kernel_gemm_xdlops_streamk; + rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor( + &occupancy_, + kernel, + BlockSize, + GridwiseGemm64::GetSharedMemoryNumberOfByte()); + hip_check_error(rtn); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + const auto kernel = kernel_gemm_xdlops_streamk; + rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor( + &occupancy_, + kernel, + BlockSize, + GridwiseGemm32::GetSharedMemoryNumberOfByte()); + hip_check_error(rtn); + } + } + return occupancy_; + }(); hipDeviceProp_t dev_prop; hipDevice_t dev; @@ -352,7 +445,11 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - e_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + e_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -132,6 +135,11 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm { + static constexpr auto BlockSize = math::max(TileLoadThreadGroupSize, TileMathThreadGroupSize); + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + using DeviceOp = DeviceGemm_Xdl_WaveletModel_CShuffle; static constexpr auto I0 = Number<0>{}; @@ -201,7 +209,8 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm(1, 1, 1)); // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle< + template + using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle< ADataType, // TODO: distinguish A/B datatype GemmAcEDataType, CShuffleDataType, @@ -224,7 +233,7 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using AGridDesc_AK0_M_AK1 = - remove_cvref_t; using BGridDesc_BK0_N_BK1 = - remove_cvref_t; - using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; 
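In the stream-k heuristics above, `occupancy` can no longer be filled by one unconditional hipOccupancyMaxActiveBlocksPerMultiprocessor call, because the kernel being queried now depends on the selected instantiation; the patch therefore wraps the branchy logic in an immediately invoked lambda so the variable always gets a defined value (0 when neither instantiation applies). The idiom in isolation, with a placeholder in place of the real occupancy query:

#include <cstdio>

constexpr int NXdlPerWave64 = 4;
constexpr int NXdlPerWave32 = 2;

int warp_size() { return 64; }                              // stand-in
int query_occupancy(int blocks_per_cu) { return blocks_per_cu; } // stand-in

int main()
{
    const int occupancy = [&]() {
        int occupancy_ = 0; // defined fallback when no branch fires
        if(warp_size() == 64)
        {
            if constexpr(NXdlPerWave64 > 0)
                occupancy_ = query_occupancy(2);
        }
        else
        {
            if constexpr(NXdlPerWave32 > 0)
                occupancy_ = query_occupancy(4);
        }
        return occupancy_;
    }();
    std::printf("occupancy=%d\n", occupancy);
}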
+ using Block2ETileMap = typename GridwiseGemm64::DefaultBlock2ETileMap; // Argument struct Argument : public BaseArgument @@ -277,22 +288,14 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm(MRaw, NRaw, StrideE)}, a_grid_desc_ak0_m_ak1_{ - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, b_grid_desc_bk0_n_bk1_{ - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, - e_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, cde_element_op_{cde_element_op} { - if(GridwiseGemm::CheckValidity( - a_grid_desc_m_k_, b_grid_desc_n_k_, e_grid_desc_m_n_, block_2_etile_map_)) - { - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); - } } void Print() const @@ -316,8 +319,6 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { #if 0 { @@ -359,6 +361,9 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm()) { return false; } - - return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + return false; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp index 5449525306..68e63fcb5b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp @@ -37,47 +37,50 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - const index_t block_id = get_block_1d_id(); - - const auto contraction_arg_ptr = reinterpret_cast( - cast_pointer_to_generic_address_space(contraction_args)); - - index_t left = 0; - index_t right = group_count; - index_t group_id = index_t((left + right) / 2); - - while((!(block_id >= contraction_arg_ptr[group_id].block_start_ && - block_id < contraction_arg_ptr[group_id].block_end_)) && - left <= right) +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - if(block_id < contraction_arg_ptr[group_id].block_start_) - { - right = group_id; - } - else - { - left = group_id; - } - group_id = index_t((left + right) / 2); - } + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - 
GridwiseGemm::template Run( - contraction_arg_ptr[group_id].p_a_grid_, - contraction_arg_ptr[group_id].p_b_grid_, - contraction_arg_ptr[group_id].p_ds_grid_, - contraction_arg_ptr[group_id].p_e_grid_, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - contraction_arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, - contraction_arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, - contraction_arg_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, - contraction_arg_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, - contraction_arg_ptr[group_id].block_2_etile_map_); + const index_t block_id = get_block_1d_id(); + + const auto contraction_arg_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(contraction_args)); + + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + + while((!(block_id >= contraction_arg_ptr[group_id].block_start_ && + block_id < contraction_arg_ptr[group_id].block_end_)) && + left <= right) + { + if(block_id < contraction_arg_ptr[group_id].block_start_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + GridwiseGemm::template Run( + contraction_arg_ptr[group_id].p_a_grid_, + contraction_arg_ptr[group_id].p_b_grid_, + contraction_arg_ptr[group_id].p_ds_grid_, + contraction_arg_ptr[group_id].p_e_grid_, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + contraction_arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, + contraction_arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, + contraction_arg_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, + contraction_arg_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + contraction_arg_ptr[group_id].block_2_etile_map_); + } #else ignore = contraction_args; ignore = group_count; @@ -165,6 +168,9 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle { using DeviceOp = DeviceGroupedContractionMultipleD_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; @@ -357,7 +363,8 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle using ComputeDataType = ADataType; // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + template + using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype BDataType, ComputeDataType, @@ -378,7 +385,7 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -400,31 +407,33 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // desc for blockwise copy using AGridDesc_AK0_M_AK1 = - remove_cvref_t; using BGridDesc_BK0_N_BK1 = - remove_cvref_t; using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))>; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; struct 
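The grouped-contraction kernel body that moved inside the IsValidCompilationParameter guard above includes the group lookup: every group owns a contiguous [block_start_, block_end_) slice of the grid, and each workgroup binary-searches those slices for the one containing its block id. The same loop, lifted to host code for readability:

#include <cstdio>

struct GroupArg { int block_start_, block_end_; };

int find_group(const GroupArg* args, int group_count, int block_id)
{
    int left = 0, right = group_count;
    int group_id = (left + right) / 2;
    while(!(block_id >= args[group_id].block_start_ &&
            block_id < args[group_id].block_end_) &&
          left <= right)
    {
        if(block_id < args[group_id].block_start_)
            right = group_id; // target lies in an earlier slice
        else
            left = group_id; // target lies in a later slice
        group_id = (left + right) / 2;
    }
    return group_id;
}

int main()
{
    GroupArg groups[] = {{0, 10}, {10, 35}, {35, 40}};
    std::printf("block 12 -> group %d\n", find_group(groups, 3, 12)); // -> 1
}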
GroupedContractionBlock2ETileMap { // block-to-e-tile map using Block2ETileMap = - remove_cvref_t; + remove_cvref_t; GroupedContractionBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, ck::index_t BlockStart) { - default_block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); + default_block_2_etile_map_ = GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); block_start_ = BlockStart; } @@ -457,7 +466,7 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle // pointers const ADataType* p_a_grid_; const BDataType* p_b_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; + typename GridwiseGemm64::DsGridPointer p_ds_grid_; EDataType* p_e_grid_; // tensor descriptors for block/thread-wise copy @@ -531,7 +540,7 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle contraction_descs[i].b_ns_ks_lengths, contraction_descs[i].b_ns_ks_strides); DsGridDesc_M_N ds_grid_desc_m_n; - typename GridwiseGemm::DsGridPointer p_ds_grid; + typename GridwiseGemm64::DsGridPointer p_ds_grid; // populate pointer, batch stride, desc for Ds static_for<0, NumDTensor, 1>{}([&](auto j) { @@ -550,19 +559,19 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle contraction_descs[i].e_ms_ns_lengths, contraction_descs[i].e_ms_ns_strides); const auto a_grid_desc_ak0_m_ak1 = - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); + GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); const auto b_grid_desc_bk0_n_bk1 = - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k); + GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k); const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_grid_desc_m_n); const auto e_grid_desc_mblock_mperblock_nblock_nperblock = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( e_grid_desc_m_n); const index_t grid_size_grp = - GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n) + GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n) .CalculateGridSize(e_grid_desc_m_n); const index_t BlockStart = grid_size_; @@ -592,11 +601,30 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle const index_t e_nz_stride = contraction_descs[i].e_ms_ns_strides[NumDimM + NumDimN - 1]; - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, - b_grid_desc_n_k, - ds_grid_desc_m_n, - e_grid_desc_m_n, - block_2_etile_map)) + bool valid = false; + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + valid = GridwiseGemm64::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + valid = GridwiseGemm32::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map); + } + } + if(valid) { contraction_multi_d_kernel_args_.push_back( {p_a_grid, @@ -642,7 +670,8 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { bool has_main_k_block_loop = true; @@ -701,6 +730,8 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle return ave_time; } + INVOKER_RUN_IMPL + // 
polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -711,11 +742,10 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - for(std::size_t i = 0; i < arg.group_count_; i++) { const auto a_grid_desc_m_k_ = arg.contraction_multi_d_device_args_[i].a_grid_desc_m_k_; @@ -744,11 +774,30 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle const auto ds_nz_stride_ = arg.contraction_multi_d_device_args_[i].ds_nz_stride_; const auto e_nz_stride_ = arg.contraction_multi_d_device_args_[i].e_nz_stride_; - if(!GridwiseGemm::CheckValidity(a_grid_desc_m_k_, - b_grid_desc_n_k_, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_)) + bool valid = false; + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + valid = GridwiseGemm64::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + valid = GridwiseGemm32::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_); + } + } + if(!valid) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 25923235c3..57ea476ced 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -96,83 +96,64 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const ComputePtrOffsetOfN compute_ptr_offset_of_n, const index_t KBatch) { -#if defined(__gfx9__) - // offset base pointer for each work-group - const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x); - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); - const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / KBatch); - const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch); - - const long_index_t a_batch_offset = - CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)) - : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = - CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)) - : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); - const long_index_t e_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); - - const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); - - const long_index_t a_n_offset = - CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); - const long_index_t b_n_offset = - CTranspose ? 
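The widened #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) guard plus the IsValidCompilationParameter if constexpr above turns an invalid kernel configuration into an empty kernel body instead of a hard compile error, so one translation unit can carry both wave64 and wave32 variants. A toy sketch of the guard, with GridwiseGemmToy as a hypothetical stand-in:

    #include <cstdio>

    struct GridwiseGemmToy
    {
        template <int WaveSize = 64>
        static constexpr bool IsValidCompilationParameter()
        {
            return WaveSize == 64; // e.g. a config tuned only for wave64
        }
    };

    template <class GridwiseGemm>
    void kernel_body()
    {
        if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
        {
            std::printf("run kernel\n"); // the real kernel work goes here
        }
        // else: the body compiles to nothing instead of erroring out
    }

    int main() { kernel_body<GridwiseGemmToy>(); }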
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0; - - const long_index_t e_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); - - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - DsPointer p_ds_grid_grp; - - static constexpr index_t NumDTensor = DsPointer::Size(); - - static_for<0, NumDTensor, 1>{}( - [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - - index_t left = 0; - index_t right = gemms_count; - index_t group_id = index_t((left + right) / 2); - while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ && - block_args_id < gemm_kernel_args[group_id].BlockEnd_)) && - left <= right) +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - if(block_args_id < gemm_kernel_args[group_id].BlockStart_) + // offset base pointer for each work-group + const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x); + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / KBatch); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch); + + const long_index_t a_batch_offset = + CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)) + : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)) + : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + const long_index_t a_n_offset = + CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t b_n_offset = + CTranspose ? 
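The CTranspose ternaries above swap the roles of the A and B offsets when the GEMM is evaluated with the operands exchanged (the transposed-output path). A stand-alone sketch of that swap, with OffsetsToy as a hypothetical stand-in for ComputePtrOffsetOfBatch:

    #include <cstdio>

    struct OffsetsToy // hypothetical stand-in for ComputePtrOffsetOfBatch
    {
        long GetAPtrOffset(int g) const { return 100L * g; }
        long GetBPtrOffset(int g) const { return 200L * g; }
    };

    template <bool CTranspose>
    void PrintOffsets(const OffsetsToy& off, int g_idx)
    {
        const long a_off = CTranspose ? off.GetBPtrOffset(g_idx) : off.GetAPtrOffset(g_idx);
        const long b_off = CTranspose ? off.GetAPtrOffset(g_idx) : off.GetBPtrOffset(g_idx);
        std::printf("a_off=%ld b_off=%ld\n", a_off, b_off);
    }

    int main()
    {
        const OffsetsToy off{};
        PrintOffsets<false>(off, 3); // a_off=300 b_off=600
        PrintOffsets<true>(off, 3);  // a_off=600 b_off=300
    }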
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0; + + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = DsPointer::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + index_t left = 0; + index_t right = gemms_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ && + block_args_id < gemm_kernel_args[group_id].BlockEnd_)) && + left <= right) { - right = group_id; + if(block_args_id < gemm_kernel_args[group_id].BlockStart_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); } - else - { - left = group_id; - } - group_id = index_t((left + right) / 2); - } - if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm) - { - GridwiseGemm::template Run( - p_a_grid + a_batch_offset + a_n_offset, - p_b_grid + b_batch_offset + b_n_offset, - p_ds_grid_grp, - p_e_grid + e_batch_offset + e_n_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, - gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, - gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, - gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, - gemm_kernel_args[group_id].block_2_ctile_map_, - KBatch, - k_idx); - } - else - { - if(gemm_kernel_args[group_id].HasMainKBlockLoop_) + if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm) { - GridwiseGemm::template Run( + GridwiseGemm::template Run( p_a_grid + a_batch_offset + a_n_offset, p_b_grid + b_batch_offset + b_n_offset, p_ds_grid_grp, @@ -191,22 +172,44 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) } else { - GridwiseGemm::template Run( - p_a_grid + a_batch_offset + a_n_offset, - p_b_grid + b_batch_offset + b_n_offset, - p_ds_grid_grp, - p_e_grid + e_batch_offset + e_n_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, - gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, - gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, - gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, - gemm_kernel_args[group_id].block_2_ctile_map_, - KBatch, - k_idx); + if(gemm_kernel_args[group_id].HasMainKBlockLoop_) + { + GridwiseGemm::template Run( + p_a_grid + a_batch_offset + a_n_offset, + p_b_grid + b_batch_offset + b_n_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].block_2_ctile_map_, + KBatch, + k_idx); + } + else + { + GridwiseGemm::template Run( + p_a_grid + a_batch_offset + a_n_offset, + p_b_grid + b_batch_offset + b_n_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, 
+ gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].block_2_ctile_map_, + KBatch, + k_idx); + } } } #else @@ -316,6 +319,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 : 32; using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr GemmSpecialization GemmSpec = GemmSpecialization::MNKPadding; @@ -436,7 +442,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 #define GridwiseGemmMultiDTemplateParams \ ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ AElementwiseOp, BElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \ - MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ + MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, \ ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ @@ -451,7 +457,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 #define GridwiseGemmCTransposeTemplateParameters \ ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ BElementwiseOp, AElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \ - NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, MPerXDL, NXdlPerWave, MXdlPerWave, \ + NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, MPerXDL, NXdlPerWave_, MXdlPerWave, \ BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ @@ -463,17 +469,26 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType - using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle; - using GridwiseGemmCTranspose = std::conditional_t< - CTranspose, - GridwiseGemmMultipleD_xdl_cshuffle, - GridwiseGemm>; + template + using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle; + template + using GridwiseGemmCTransposeBase = + GridwiseGemmMultipleD_xdl_cshuffle; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; + + using GridwiseGemmCTranspose64 = + std::conditional_t, + GridwiseGemm64>; + using GridwiseGemmCTranspose32 = + std::conditional_t, GridwiseGemm32>; template static auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N e_grid_desc_m_n) { - return GridwiseGemmCTranspose::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + return GridwiseGemmCTranspose64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( e_grid_desc_m_n); } @@ -504,14 +519,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{})); using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + 
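The HasMainKBlockLoopInAllGemm / NoMainKBlockLoopInAllGemm branching just above avoids any runtime branch when every group agrees on the flag at compile time, and falls back to the per-group HasMainKBlockLoop_ field only for mixed grids. A toy sketch of that three-way dispatch:

    #include <cstdio>

    template <bool HasMainKBlockLoop>
    void RunToy() { std::printf("main k loop: %d\n", int(HasMainKBlockLoop)); }

    template <bool AllHaveMainLoop, bool NoneHaveMainLoop>
    void Dispatch(bool group_has_loop)
    {
        if constexpr(AllHaveMainLoop || NoneHaveMainLoop)
        {
            RunToy<AllHaveMainLoop>(); // one instantiation, no runtime branch
        }
        else
        {
            if(group_has_loop) // mixed grid: the per-group flag decides
                RunToy<true>();
            else
                RunToy<false>();
        }
    }

    int main()
    {
        Dispatch<true, false>(true);   // compile-time: every group loops
        Dispatch<false, false>(false); // runtime: per-group flag decides
    }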
decltype(GridwiseGemmCTranspose64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{})); using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{})); // block-to-e-tile map using Block2ETileMap = - decltype(GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(EGridDesc_M_N{})); + decltype(GridwiseGemmCTranspose64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{})); using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap; @@ -917,14 +932,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 const auto GemmK = a_grid_desc_m_k.GetLength(I1); const bool HasMainKBlockLoop = - GridwiseGemmCTranspose::CalculateHasMainKBlockLoop(GemmK, k_batch_); + GridwiseGemmCTranspose64::CalculateHasMainKBlockLoop(GemmK, k_batch_); gemm_kernel_args_[gemms_count_ / MaxGroupedGemmGroupsNum][gemms_count_ % MaxGroupedGemmGroupsNum] = GemmArgs{a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, - GridwiseGemmCTranspose:: + GridwiseGemmCTranspose64:: MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_grid_desc_m_n), MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -1069,7 +1084,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // pointers const ADataType* p_a_grid_; const BDataType* p_b_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; + typename GridwiseGemm64::DsGridPointer p_ds_grid_; EDataType* p_e_grid_; // tensor descriptor for problem definition @@ -1122,7 +1137,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 { using Argument = DeviceOp::Argument; - template + template float RunMultiDGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float ave_time = 0; @@ -1281,7 +1298,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 return ave_time; } - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float ave_time = 0; @@ -1372,12 +1390,16 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 if constexpr(IsSplitKSupported) { ave_time += - RunMultiDGemm(arg, stream_config); + RunMultiDGemm(arg, stream_config); } } else { - ave_time += RunMultiDGemm(arg, stream_config); + ave_time += RunMultiDGemm(arg, stream_config); } // Transpose from NHWGC to NGCHW @@ -1427,6 +1449,31 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 return ave_time; } + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return RunImp(arg, stream_config); + } + else + { + return 0; + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return RunImp(arg, stream_config); + } + else + { + return 0; + } + } + } float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -1437,11 +1484,15 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + // gfx11 doesn't support float atomic + if(ck::is_gfx11_supported() && arg.k_batch_ > 1) + { + return false; + } + if(!ck::is_xdl_wmma_supported()) { return false; } - if(!is_bf16_atomic_supported() && std::is_same_v && arg.k_batch_ > 1) { @@ -1598,18 +1649,46 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } // Gridwise GEMM size + bool isWave64 = get_warp_size() == 64; for(std::size_t i = 0; i < arg.a_grid_desc_m_k_container_.size(); 
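The new k_batch_ guard above encodes a hardware limit: split-K combines partial K results with global atomic adds, and, per the comment in the patch, gfx11 lacks the float atomic path. A minimal sketch of the check, with is_gfx11() as a hypothetical stand-in for ck::is_gfx11_supported():

    #include <cstdio>

    bool is_gfx11() { return false; } // stand-in for ck::is_gfx11_supported()

    bool SplitKSupported(int k_batch)
    {
        // k_batch > 1 needs atomic accumulation of partial K results,
        // which the gfx11 float path cannot provide.
        if(is_gfx11() && k_batch > 1)
            return false;
        return true;
    }

    int main()
    {
        std::printf("%d %d\n", int(SplitKSupported(1)), int(SplitKSupported(4)));
    }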
i++) { - if(!GridwiseGemmCTranspose::CheckValidity( - arg.a_grid_desc_m_k_container_[i], - arg.b_grid_desc_n_k_container_[i], - arg.ds_grid_desc_m_n_container_[i], - arg.e_grid_desc_m_n_container_[i], - arg.gemm_kernel_args_[i / MaxGroupedGemmGroupsNum][i % MaxGroupedGemmGroupsNum] - .block_2_ctile_map_, - arg.k_batch_)) + bool valid = true; + if(isWave64) { - return false; + if constexpr(NXdlPerWave64 > 0) + { + if(!GridwiseGemmCTranspose64::CheckValidity( + arg.a_grid_desc_m_k_container_[i], + arg.b_grid_desc_n_k_container_[i], + arg.ds_grid_desc_m_n_container_[i], + arg.e_grid_desc_m_n_container_[i], + arg.gemm_kernel_args_[i / MaxGroupedGemmGroupsNum] + [i % MaxGroupedGemmGroupsNum] + .block_2_ctile_map_, + arg.k_batch_)) + { + valid = false; + } + } + else + { + if(!GridwiseGemmCTranspose32::CheckValidity( + arg.a_grid_desc_m_k_container_[i], + arg.b_grid_desc_n_k_container_[i], + arg.ds_grid_desc_m_n_container_[i], + arg.e_grid_desc_m_n_container_[i], + arg.gemm_kernel_args_[i / MaxGroupedGemmGroupsNum] + [i % MaxGroupedGemmGroupsNum] + .block_2_ctile_map_, + arg.k_batch_)) + { + valid = false; + } + } + if(!valid) + { + return false; + } } } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index b761939642..934dc7ee8e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -61,31 +61,35 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if defined(__gfx9__) - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); - const long_index_t c_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t c_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); - __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; + __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_c_grid + c_batch_offset, - p_shared, - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_grid_desc_mblock_mperblock_nblock_nperblock, - a_element_op, - b_element_op, - c_element_op, - 
block_2_ctile_map); + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -125,8 +129,8 @@ template { using DeviceOp = DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); using ADataType = OutDataType; using BDataType = InDataType; @@ -298,7 +305,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; - using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + template + using GridwiseGemmBase = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, ADataType, BDataType, @@ -314,11 +322,11 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle MPerBlock, NPerBlock, K0PerBlock, - MPerXdl, - NPerXdl, + MPerXDL, + NPerXDL, K1, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -351,6 +359,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle PipelineVersion::v1, ComputeTypeA, ComputeTypeB>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; static constexpr auto MakeElementwiseInputSequence() { @@ -539,14 +549,15 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); + decltype(GridwiseGemm64::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); using Block2CTileMap = - decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + decltype(GridwiseGemm64::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); struct ActiveWorkgroupsPerCU { - ActiveWorkgroupsPerCU() + template + int GetMaxOccupancy() { constexpr int dynamic_smem_size = 0; int max_occupancy = 0; @@ -568,7 +579,26 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle true>, BlockSize, dynamic_smem_size)); - max_occupancy_ = std::max(1, max_occupancy); + return std::max(1, max_occupancy); + } + + ActiveWorkgroupsPerCU() + { + max_occupancy_ = 1; + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + max_occupancy_ = GetMaxOccupancy(); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + max_occupancy_ = GetMaxOccupancy(); + } + } } int max_occupancy_; }; @@ -605,7 +635,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle a_grid_desc_kbatch_k0_m_k1_{}, b_grid_desc_kbatch_k0_n_k1_{}, ce_grid_desc_m_n_{}, - c_grid_desc_mblock_mperblock_nblock_nperblock_{}, block_2_ctile_map_{}, compute_ptr_offset_of_batch_{}, M01_{M01}, @@ -695,7 +724,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle MakeDsGridDescriptor_M_N(ds_g_k_c_xs_lengths, ds_g_k_c_xs_strides); block_2_ctile_map_ = - GridwiseGemm::MakeCBlockClusterAdaptor(ce_grid_desc_m_n_, M01, N01, k_batch_); + GridwiseGemm64::MakeCBlockClusterAdaptor(ce_grid_desc_m_n_, M01, N01, k_batch_); elementwise_block_2_ctile_map_ = Block2TileMapElementwise{ ce_grid_desc_m_n_.GetLength(I0), ce_grid_desc_m_n_.GetLength(I1)}; @@ -708,16 +737,6 @@ struct 
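ActiveWorkgroupsPerCU now routes the occupancy query through a templated GetMaxOccupancy so it can be issued for whichever gridwise GEMM instantiation the runtime wave size selects. A stand-alone sketch of such a query against the HIP occupancy API; the kernel and block size here are placeholders, not the real gridwise GEMM kernel:

    #include <hip/hip_runtime.h>
    #include <algorithm>
    #include <cstdio>

    __global__ void dummy_kernel() {} // placeholder kernel

    template <int BlockSize>
    int GetMaxOccupancy()
    {
        constexpr int dynamic_smem_size = 0;
        int max_occupancy = 0;
        (void)hipOccupancyMaxActiveBlocksPerMultiprocessor(
            &max_occupancy,
            reinterpret_cast<const void*>(dummy_kernel),
            BlockSize,
            dynamic_smem_size);
        return std::max(1, max_occupancy); // never report fewer than 1 WG per CU
    }

    int main() { std::printf("%d\n", GetMaxOccupancy<256>()); }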
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle end(filter_spatial_lengths_), index_t{1}, std::multiplies<>{}); - - if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, - b_grid_desc_kbatch_k0_n_k1_, - ce_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock( - ce_grid_desc_m_n_); - } } std::size_t GetWorkspaceSizeBytes() const @@ -733,7 +752,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; CGridDesc_M_N ce_grid_desc_m_n_; - CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; DsGridDesc_M_N ds_grid_descs_tuple_; Block2CTileMap block_2_ctile_map_; @@ -786,7 +804,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle << arg.ce_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, @@ -797,6 +816,9 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); } + auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock( + arg.ce_grid_desc_m_n_); const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); @@ -843,7 +865,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle arg.Conv_G_, arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + c_grid_desc_mblock_mperblock_nblock_nperblock, arg.block_2_ctile_map_, arg.compute_ptr_offset_of_batch_); }; @@ -900,6 +922,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle return avg_time; } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -915,7 +939,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } @@ -977,10 +1001,27 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle } // Gridwise GEMM size - return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.ce_grid_desc_m_n_, - arg.block_2_ctile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.ce_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.ce_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + return false; } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 4565074b3e..e38768b2fa 100644 --- 
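Note the deletion above: c_grid_desc_mblock_mperblock_nblock_nperblock_ is no longer cached in Argument, since a cached value would bake in one gridwise GEMM instantiation; RunImp now rebuilds it from ce_grid_desc_m_n_ for the instantiation actually chosen. A toy sketch of that move, with all types as stand-ins:

    #include <cstdio>

    struct DescMN { int M, N; };                // toy M x N descriptor
    struct DescBlocked { int MBlock, NBlock; }; // toy blocked descriptor

    template <int MPerBlock, int NPerBlock>
    struct GridwiseGemmToy
    {
        static DescBlocked MakeBlockedDesc(const DescMN& d)
        {
            return {d.M / MPerBlock, d.N / NPerBlock};
        }
    };

    template <class GridwiseGemm>
    void RunImp(const DescMN& ce_grid_desc_m_n)
    {
        // built here, for the selected instantiation, not cached in Argument
        const DescBlocked desc = GridwiseGemm::MakeBlockedDesc(ce_grid_desc_m_n);
        std::printf("MBlock=%d NBlock=%d\n", desc.MBlock, desc.NBlock);
    }

    int main() { RunImp<GridwiseGemmToy<128, 64>>({1024, 512}); }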
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -57,33 +57,36 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, [[maybe_unused]] const index_t num_k_per_block) { -#if defined(__gfx9__) - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); - const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); - const long_index_t a_batch_offset = amd_wave_read_first_lane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = amd_wave_read_first_lane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = amd_wave_read_first_lane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset, - p_shared, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx); + GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -112,38 +115,41 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, [[maybe_unused]] const index_t num_k_per_block) { -#if defined(__gfx9__) - // offset base pointer for each work-group - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); - const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // offset base pointer for each work-group + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); - const long_index_t a_batch_offset = amd_wave_read_first_lane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = amd_wave_read_first_lane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = 
amd_wave_read_first_lane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset, - p_shared_0, - p_shared_1, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx); + GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -166,8 +172,8 @@ template ); using DeviceOp = DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); using ADataType = OutDataType; using BDataType = InDataType; @@ -393,7 +402,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle using CElementwiseGridDesc_M_N = remove_cvref_t())>; - using GridwiseGemm = GridwiseGemm_xdl_cshuffle_conv_v3< + template + using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_conv_v3< tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, tensor_layout::gemm::RowMajor, @@ -412,10 +422,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle KPerBlock, K1, K1, - MPerXdl, - NPerXdl, + MPerXDL, + NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -440,6 +450,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; @@ -503,12 +515,13 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemm64::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( CGridDesc_M_N{}, 1, 1)); struct ActiveWorkgroupsPerCU { - ActiveWorkgroupsPerCU() + template + int GetMaxOccupancy() { constexpr int dynamic_smem_size = 0; constexpr index_t minimum_occupancy = @@ -549,7 +562,26 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle BlockSize, dynamic_smem_size)); } - max_occupancy_ = 
std::max(1, max_occupancy); + return std::max(1, max_occupancy); + } + + ActiveWorkgroupsPerCU() + { + max_occupancy_ = 1; + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + max_occupancy_ = GetMaxOccupancy(); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + max_occupancy_ = GetMaxOccupancy(); + } + } } int max_occupancy_; }; @@ -704,10 +736,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0]; compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0]; c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + GridwiseGemm64::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ce_grid_desc_m_n_, - GridwiseGemm::CalculateMBlock(GemmM), - GridwiseGemm::CalculateNBlock(GemmN)); + GridwiseGemm64::CalculateMBlock(GemmM), + GridwiseGemm64::CalculateNBlock(GemmN)); if constexpr(is_NGCHW_NGKHW() || is_NGCDHW_NGKDHW()) @@ -853,6 +885,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle << arg.ce_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } + template float RunGemmV3(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); @@ -1506,7 +1539,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle return ave_time; } - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float avg_time = 0.f; auto launch_elementwise_kernel = [&]() { @@ -1626,11 +1660,13 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle grid_size_a); } - avg_time += RunGemmV3(arg, stream_config); + avg_time += RunGemmV3(arg, stream_config); avg_time += launch_elementwise_kernel(); return avg_time; } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -1651,13 +1687,44 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle const index_t GemmK = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); - typename GridwiseGemm::Argument gemm_arg{ - nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; - - const auto num_k_loop = gemm_arg.AK0 / (KPerBlock / K1); - if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + if(get_warp_size() == 64) { - if(num_k_loop <= GridwiseGemm::BlockwiseGemmPipe::PrefetchStages) + if constexpr(NXdlPerWave64 > 0) + { + typename GridwiseGemm64::Argument gemm_arg{ + nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; + + const auto num_k_loop = gemm_arg.AK0 / (KPerBlock / K1); + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + } + else + { + return false; + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + typename GridwiseGemm32::Argument gemm_arg{ + nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; + + const auto num_k_loop = gemm_arg.AK0 / (KPerBlock / K1); + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= GridwiseGemm32::BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + } + else { return false; } @@ -1676,7 +1743,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle } return false; } - if(!ck::is_xdl_supported()) + 
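The PrefetchStages checks above enforce that, for pipeline versions other than v1, the K loop runs longer than the pipeline's prefetch depth; a shorter loop never reaches steady state. A worked example of the arithmetic with assumed sizes (K = 128, K1 = 8, KPerBlock = 32; AK0 = K / K1 is an assumption about how the packed K split is formed):

    #include <cstdio>

    int main()
    {
        const int K              = 128; // assumed GEMM K extent
        const int K1             = 8;
        const int KPerBlock      = 32;
        const int PrefetchStages = 2;   // e.g. BlockwiseGemmPipe::PrefetchStages

        const int AK0        = K / K1;                 // 16
        const int num_k_loop = AK0 / (KPerBlock / K1); // 16 / 4 = 4
        const bool supported = num_k_loop > PrefetchStages;
        std::printf("num_k_loop=%d supported=%d\n", num_k_loop, int(supported));
    }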
if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 488dadf512..b361409e38 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -59,28 +59,33 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); - const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); - const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); - __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; + __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_c_grid + c_batch_offset, - p_shared, - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_grid_desc_mblock_mperblock_nblock_nperblock, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -119,8 +124,8 @@ template { using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); using ADataType = OutDataType; using BDataType = InDataType; @@ -360,7 +368,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle I1, I0>; - using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + template + using GridwiseGemmBase = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, ADataType, BDataType, @@ -376,11 +385,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle MPerBlock, NPerBlock, K0PerBlock, - MPerXdl, - NPerXdl, + MPerXDL, + NPerXDL, K1, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, 
ABlockTransferSrcAccessOrder, @@ -413,17 +422,20 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle PipelineVersion::v1, ComputeTypeA, ComputeTypeB>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); + decltype(GridwiseGemm64::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); using Block2CTileMap = - decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + decltype(GridwiseGemm64::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); struct ActiveWorkgroupsPerCU { - ActiveWorkgroupsPerCU() + template + static int GetMaxOccupancy() { constexpr int dynamic_smem_size = 0; int max_occupancy = 0; @@ -445,7 +457,25 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle false>, // Both true/false give the same occupancy. BlockSize, dynamic_smem_size)); - max_occupancy_ = std::max(1, max_occupancy); + return std::max(1, max_occupancy); + } + ActiveWorkgroupsPerCU() + { + max_occupancy_ = 1; + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + max_occupancy_ = GetMaxOccupancy(); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + max_occupancy_ = GetMaxOccupancy(); + } + } } int max_occupancy_; }; @@ -477,7 +507,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle a_grid_desc_kbatch_k0_m_k1_{}, b_grid_desc_kbatch_k0_n_k1_{}, c_grid_desc_m_n_{}, - c_grid_desc_mblock_mperblock_nblock_nperblock_{}, block_2_ctile_map_{}, compute_ptr_offset_of_batch_{}, M01_{M01}, @@ -563,22 +592,13 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle c_grid_desc_m_n_ = descs[I2]; block_2_ctile_map_ = - GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); + GridwiseGemm64::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); // A/B/C Batch Stride compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0]; compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0]; compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0]; - if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, - b_grid_desc_kbatch_k0_n_k1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_); - } - if constexpr(is_NGCHW_NGKHW() || is_NGCDHW_NGKDHW()) { @@ -671,7 +691,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; CGridDesc_M_N c_grid_desc_m_n_; - CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; Block2CTileMap block_2_ctile_map_; @@ -731,7 +750,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float avg_time = 0.f; @@ -739,6 +759,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const BDataType* p_b_grid = arg.p_b_grid_; CDataType* p_e_grid = arg.p_c_grid_; + auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm64::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock( + arg.c_grid_desc_m_n_); + if constexpr(is_NGCHW_GKCYX_NGKHW() || is_NGCDHW_GKCZYX_NGKDHW()) { @@ -845,7 +869,7 @@ struct 
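The recurring NXdlPerWave to NXdlPerWave_ substitution in these template-parameter lists leaves the per-wave tuning constant open, so a single parameter list can stamp out both instantiations. A toy sketch of the same re-parameterization using a template alias:

    #include <cstdio>

    template <int BlockSize, int NXdlPerWave_>
    struct GridwiseToy
    {
        static constexpr int value = BlockSize * NXdlPerWave_;
    };

    // the shared parameter list leaves NXdlPerWave_ open ...
    template <int NXdlPerWave_>
    using GridwiseBase = GridwiseToy<256, NXdlPerWave_>;

    // ... so both per-wave instantiations come from one definition
    using Gridwise64 = GridwiseBase<2>; // NXdlPerWave tuned for wave64
    using Gridwise32 = GridwiseBase<4>; // NXdlPerWave tuned for wave32

    int main() { std::printf("%d %d\n", Gridwise64::value, Gridwise32::value); }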
DeviceGroupedConvBwdWeight_Xdl_CShuffle arg.Conv_G_, arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + c_grid_desc_mblock_mperblock_nblock_nperblock, arg.block_2_ctile_map_, arg.compute_ptr_offset_of_batch_); }; @@ -893,6 +917,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle return avg_time; } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -908,7 +934,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } @@ -1023,10 +1049,27 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle } // Gridwise GEMM size - return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + } + return false; } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 0793285dbd..8bf188be2e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -55,32 +55,35 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const index_t num_k_per_block) { -#if defined(__gfx9__) - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); - const long_index_t a_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); - const long_index_t e_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset, - p_shared, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - 
c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx); + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx); + } #else ignore = karg; ignore = a_grid_desc_ak0_m_ak1; @@ -113,38 +116,41 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const index_t num_k_per_block) { -#if defined(__gfx9__) - // offset base pointer for each work-group - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // offset base pointer for each work-group + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); - const long_index_t a_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); - const long_index_t e_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset, - p_shared_0, - p_shared_1, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx); + GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx); + } #else ignore = karg; ignore = a_grid_desc_ak0_m_ak1; @@ -173,8 +179,8 @@ template ); using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffleV3; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); using ADataType = OutDataType; using BDataType = InDataType; @@ -330,7 +339,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; - using GridwiseGemm = 
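Run_2Lds above keeps the existing two-LDS trick: two distinct __shared__ buffers make it provable to the compiler that ds_read and ds_write of consecutive tiles never alias, so they can overlap. A runnable HIP sketch of the ping-pong; the tile size and the copy/compute bodies are simplified placeholders, not the real pipeline:

    #include <hip/hip_runtime.h>
    #include <cstdio>
    #include <vector>

    __global__ void two_lds_pingpong(const float* in, float* out, int tiles)
    {
        __shared__ float lds0[256];
        __shared__ float lds1[256];

        for(int t = 0; t < tiles; ++t)
        {
            float* stage = (t % 2 == 0) ? lds0 : lds1; // buffer being filled
            float* work  = (t % 2 == 0) ? lds1 : lds0; // buffer being consumed

            stage[threadIdx.x] = in[t * 256 + threadIdx.x]; // stage tile t
            if(t > 0)
                out[(t - 1) * 256 + threadIdx.x] = work[threadIdx.x] * 2.f;
            __syncthreads(); // in a real pipeline threads exchange tile data
        }
        // drain the last staged tile
        float* last = ((tiles - 1) % 2 == 0) ? lds0 : lds1;
        if(tiles > 0)
            out[(tiles - 1) * 256 + threadIdx.x] = last[threadIdx.x] * 2.f;
    }

    int main()
    {
        const int tiles = 4, n = tiles * 256;
        std::vector<float> h_in(n, 1.f), h_out(n, 0.f);
        float *d_in, *d_out;
        hipMalloc(&d_in, n * sizeof(float));
        hipMalloc(&d_out, n * sizeof(float));
        hipMemcpy(d_in, h_in.data(), n * sizeof(float), hipMemcpyHostToDevice);
        hipLaunchKernelGGL(two_lds_pingpong, dim3(1), dim3(256), 0, 0, d_in, d_out, tiles);
        hipMemcpy(h_out.data(), d_out, n * sizeof(float), hipMemcpyDeviceToHost);
        std::printf("out[0]=%f out[%d]=%f\n", h_out[0], n - 1, h_out[n - 1]);
        hipFree(d_in);
        hipFree(d_out);
    }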
GridwiseGemm_xdl_cshuffle_conv_v3< + template + using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_conv_v3< tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, tensor_layout::gemm::RowMajor, @@ -349,10 +359,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 K0PerBlock, K1, K1, - MPerXdl, - NPerXdl, + MPerXDL, + NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -377,15 +387,18 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemm64::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( CGridDesc_M_N{}, 1, 1)); struct ActiveWorkgroupsPerCU { - ActiveWorkgroupsPerCU() + template + static int GetMaxOccupancy() { constexpr int dynamic_smem_size = 0; constexpr index_t minimum_occupancy = @@ -424,7 +437,26 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 BlockSize, dynamic_smem_size)); } - max_occupancy_ = std::max(1, max_occupancy); + return std::max(1, max_occupancy); + } + + ActiveWorkgroupsPerCU() + { + max_occupancy_ = 1; + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + max_occupancy_ = GetMaxOccupancy(); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + max_occupancy_ = GetMaxOccupancy(); + } + } } int max_occupancy_; }; @@ -556,10 +588,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + GridwiseGemm64::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( c_grid_desc_m_n_, - GridwiseGemm::CalculateMBlock(GemmM), - GridwiseGemm::CalculateNBlock(GemmN)); + GridwiseGemm64::CalculateMBlock(GemmM), + GridwiseGemm64::CalculateNBlock(GemmN)); } const ADataType* p_a_grid_; @@ -617,7 +649,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); @@ -1225,6 +1258,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 return ave_time; } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -1245,23 +1280,57 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); - typename GridwiseGemm::Argument gemm_arg{ - nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; - - const auto num_k_loop = gemm_arg.AK0 / (K0PerBlock / K1); - if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + if(get_warp_size() == 64) { - if(num_k_loop <= GridwiseGemm::BlockwiseGemmPipe::PrefetchStages) + if constexpr(NXdlPerWave64 > 0) + { + typename GridwiseGemm64::Argument gemm_arg{ + nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; + + const auto num_k_loop = 
gemm_arg.AK0 / (K0PerBlock / K1); + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + } + else + { + return false; + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + typename GridwiseGemm32::Argument gemm_arg{ + nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; + + const auto num_k_loop = gemm_arg.AK0 / (K0PerBlock / K1); + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= GridwiseGemm32::BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + } + else + { + return false; + } + } - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) + { + return false; + } + if(is_gfx11_supported() && arg.k_batch_ > 1) { return false; } - if(!is_bf16_atomic_supported() && std::is_same_v && arg.k_batch_ > 1) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 1448914dd3..1412c960c7 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -101,106 +101,111 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const ComputePtrOffsetOfG compute_ptr_offset_of_groups, const ComputePtrOffsetOfN compute_ptr_offset_of_n) { -#if defined(__gfx9__) - - // offset base pointer for each work-group - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); - const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - - const long_index_t e_group_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); - const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); - const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); - - const long_index_t e_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); - - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - DsPointer p_ds_grid_grp; - - static constexpr index_t NumDTensor = - DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); - - static_for<0, NumDTensor, 1>{}( - [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i]; }); - - if constexpr(isMultiA || isMultiB) +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - AsPointer p_as_grid_grp; - BsPointer p_bs_grid_grp; + // offset base pointer for each work-group + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - const auto& as_group_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx); + const long_index_t e_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); + const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); + const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); - // compute_ptr_offset_of_n_ not need BatchStrideB so - // in case of MultiA is false but isMultiB is true - // BatchStrideA_ is not tuple.
- if constexpr(isMultiA) + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i]; }); + + if constexpr(isMultiA || isMultiB) { - const auto& as_n_offset = compute_ptr_offset_of_n.GetAsPtrOffset(n_idx); + AsPointer p_as_grid_grp; + BsPointer p_bs_grid_grp; - static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size(); - static_for<0, NumATensor, 1>{}([&](auto i) { - p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + as_n_offset[i]; - }); + const auto& as_group_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx); + + // compute_ptr_offset_of_n_ not need BatchStrideB so + // in case of MultiA is false but isMultiB is true + // BatchStrideA_ is not tuple. + if constexpr(isMultiA) + { + const auto& as_n_offset = compute_ptr_offset_of_n.GetAsPtrOffset(n_idx); + + static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size(); + static_for<0, NumATensor, 1>{}([&](auto i) { + p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + as_n_offset[i]; + }); + } + else + { + const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx); + static_for<0, 1, 1>{}([&](auto i) { + p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + a_n_offset; + }); + } + + const auto& bs_group_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx); + + static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size(); + static_for<0, NumBTensor, 1>{}( + [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_group_offset[i]; }); + + GridwiseGemm::template Run( + p_as_grid_grp, + p_bs_grid_grp, + p_ds_grid_grp, + p_e_grid + e_group_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); } else { - const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx); - static_for<0, 1, 1>{}( - [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + a_n_offset; }); + const long_index_t b_group_offset = + CTranspose + ? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)) + : amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + const long_index_t a_group_offset = + CTranspose + ? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)) + : amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); + const long_index_t b_n_offset = + CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) + : 0; + const long_index_t a_n_offset = + CTranspose ? 
0 + : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + + GridwiseGemm::template Run( + p_as_grid + a_group_offset + a_n_offset, + p_bs_grid + b_group_offset + b_n_offset, + p_ds_grid_grp, + p_e_grid + e_group_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); } - - const auto& bs_group_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx); - - static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size(); - static_for<0, NumBTensor, 1>{}( - [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_group_offset[i]; }); - - GridwiseGemm::template Run( - p_as_grid_grp, - p_bs_grid_grp, - p_ds_grid_grp, - p_e_grid + e_group_offset + e_n_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock_, - block_2_ctile_map); - } - else - { - const long_index_t b_group_offset = - CTranspose - ? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)) - : amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); - const long_index_t a_group_offset = - CTranspose - ? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)) - : amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); - const long_index_t b_n_offset = - CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0; - const long_index_t a_n_offset = - CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); - - GridwiseGemm::template Run( - p_as_grid + a_group_offset + a_n_offset, - p_bs_grid + b_group_offset + b_n_offset, - p_ds_grid_grp, - p_e_grid + e_group_offset + e_n_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock_, - block_2_ctile_map); } #else ignore = p_as_grid; @@ -316,6 +321,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle BComputeDataType> { using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static_assert(NumGroupsToMerge >= 1); @@ -478,7 +486,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ - KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ + KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, \ ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ @@ -495,7 +503,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ NumGemmKPrefetchStage, 
BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \ - NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ @@ -511,7 +519,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle GemmBDataType, GemmADataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, BElementwiseOperation, AElementwiseOperation, CDEElementwiseOperation, \ NumGemmKPrefetchStage, BlockSize, NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, \ - MPerXDL, NXdlPerWave, MXdlPerWave, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + MPerXDL, NXdlPerWave_, MXdlPerWave, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ @@ -524,14 +532,32 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle BComputeDataType, DoElementwiseBeforeCShuffle // Use appropriate gridwise gemm - using GridwiseGemm = std::conditional_t< - isMultiA || isMultiB, - GridwiseGemmMultipleABD_xdl_cshuffle, - GridwiseGemmMultipleD_xdl_cshuffle>; - using GridwiseGemmCTranspose = std::conditional_t< - CTranspose, - GridwiseGemmMultipleD_xdl_cshuffle, - GridwiseGemm>; + template + using GridwiseGemmMultipleABDBase = + GridwiseGemmMultipleABD_xdl_cshuffle; + template + using GridwiseGemmMultipleDBase = + GridwiseGemmMultipleD_xdl_cshuffle; + template + using GridwiseGemmMultipleDCTransposeBase = + GridwiseGemmMultipleD_xdl_cshuffle; + + using GridwiseGemm64 = + std::conditional_t, + GridwiseGemmMultipleDBase>; + using GridwiseGemm32 = std::conditional_t, + GridwiseGemmMultipleDBase>; + + using GridwiseGemmCTranspose64 = + std::conditional_t, + GridwiseGemm64>; + using GridwiseGemmCTranspose32 = + std::conditional_t, + GridwiseGemm32>; // If ADataTypes or BDataTypes is tuple, user has to pass std::array with pointers. using APointers = @@ -541,27 +567,27 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not // in initializer list what is required for single const pointer). 
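// Every device op in this patch applies the same alias-template recipe seen just
// above: the old single `using GridwiseGemm = ...` becomes a `GridwiseGemmBase`
// alias template parameterized only by NXdlPerWave_, from which the wave64 and
// wave32 variants are instantiated. A reduced, compilable sketch of the recipe;
// the stand-in template, block size, and NXdlPerWave values are illustrative.
#include <cstdint>
using index_t = std::int32_t;

template <index_t BlockSize, index_t NXdlPerWave>
struct GridwiseGemmImpl // stand-in for the real gridwise-GEMM class template
{
    static constexpr index_t n_xdl_per_wave = NXdlPerWave;
};

// shared tuning parameters are spelled once, in the alias template...
template <index_t NXdlPerWave_>
using GridwiseGemmBase = GridwiseGemmImpl<256, NXdlPerWave_>;

// ...and the two wave-size variants differ only in NXdlPerWave
using GridwiseGemm64 = GridwiseGemmBase<4>; // wave64 path (gfx9 XDL)
using GridwiseGemm32 = GridwiseGemmBase<8>; // wave32 path (gfx11/gfx12 WMMA)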
using AGridPointer = remove_cvref_t< - decltype(GetAGridPointer < isMultiA || isMultiB, GridwiseGemm, ADataType > ())>; + decltype(GetAGridPointer < isMultiA || isMultiB, GridwiseGemm64, ADataType > ())>; using BGridPointer = remove_cvref_t< - decltype(GetBGridPointer < isMultiA || isMultiB, GridwiseGemm, BDataType > ())>; + decltype(GetBGridPointer < isMultiA || isMultiB, GridwiseGemm64, BDataType > ())>; // desc for blockwise copy using AGridDesc_AK0_M_AK1 = - remove_cvref_t; using BGridDesc_BK0_N_BK1 = - remove_cvref_t; using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemmCTranspose64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))>; using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemmCTranspose::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemmCTranspose64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( EGridDesc_M_N{}))>; // block-to-e-tile map using Block2ETileMap = - remove_cvref_t; using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; @@ -643,6 +669,62 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // Argument struct Argument : public BaseArgument { + template + void InitGridDesc() + { + // populate desc for Ds/E + if constexpr(isMultiA || isMultiB) + { + const auto as_grid_desc_ak0_m_ak1 = + generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number{}); + const auto bs_grid_desc_bk0_n_bk1 = + generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number{}); + + if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + } + else + { + bool valid = false; + if constexpr(CTranspose) + { + valid = GridwiseGemmCTranspose::CheckValidity(b_grid_desc_n_k_, + a_grid_desc_m_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_); + } + else + { + valid = GridwiseGemmCTranspose::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_); + } + if(valid) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose:: + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose:: + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n_); + } + } + }; + Argument(APointers p_as, BPointers p_bs, const std::array& p_ds, @@ -709,13 +791,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle e_grid_desc_m_n_{ DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, a_grid_desc_ak0_m_ak1_{ - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, b_grid_desc_bk0_n_bk1_{ - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{}, block_2_etile_map_{ - 
GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + GridwiseGemmCTranspose64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, compute_ptr_offset_of_groups_{}, compute_ptr_offset_of_n_{}, a_element_op_{a_element_op}, @@ -822,58 +904,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle e_g_n_k_wos_strides_[0] * NumGroupsToMerge; compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_; - // populate desc for Ds/E - if constexpr(isMultiA || isMultiB) + if(get_warp_size() == 64) { - const auto as_grid_desc_ak0_m_ak1 = - generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number{}); - const auto bs_grid_desc_bk0_n_bk1 = - generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number{}); - - if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1, - bs_grid_desc_bk0_n_bk1, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_)) + if constexpr(NXdlPerWave64 > 0) { - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); - - ds_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n_); + InitGridDesc(); } } else { - bool valid = false; - if constexpr(CTranspose) + if constexpr(NXdlPerWave32 > 0) { - valid = GridwiseGemmCTranspose::CheckValidity(b_grid_desc_n_k_, - a_grid_desc_m_k_, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_); - } - else - { - valid = GridwiseGemmCTranspose::CheckValidity(a_grid_desc_m_k_, - b_grid_desc_n_k_, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_); - } - if(valid) - { - e_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose:: - MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); - - ds_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose:: - MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n_); + InitGridDesc(); } } - if constexpr(NeedTransposeKernel) { // Use not modified base strides @@ -970,7 +1014,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // pointers (tuple if multi AB, pointer if no) AGridPointer p_as_grid_; BGridPointer p_bs_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; + typename GridwiseGemm64::DsGridPointer p_ds_grid_; EDataType* p_e_grid_; // for checking IsSupportedArgument() @@ -1032,6 +1076,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { using Argument = DeviceOp::Argument; + template float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) @@ -1232,7 +1277,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } } - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float avg_time = 0.f; @@ -1285,7 +1331,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle a_grid_size); } - avg_time += RunGemm(arg, stream_config); + avg_time += RunGemm(arg, stream_config); if constexpr(NeedTransposeKernel) { @@ -1324,6 +1370,31 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle return avg_time; } + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return RunImp(arg, stream_config); + } + else + { + return 0; + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return RunImp(arg, 
stream_config); + } + else + { + return 0; + } + } + } float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -1350,7 +1421,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle return false; } } - if(!ck::is_xdl_supported() + + if(!ck::is_xdl_wmma_supported()) { return false; } @@ -1616,38 +1688,85 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } // check Gridwise GEMM - if constexpr(isMultiA || isMultiB) + if(get_warp_size() == 64) { - // Genarate tuples with the same descriptors - const auto as_grid_desc_ak0_m_ak1 = - generate_tuple([&](auto) { return arg.a_grid_desc_m_k_; }, Number{}); - const auto bs_grid_desc_bk0_n_bk1 = - generate_tuple([&](auto) { return arg.b_grid_desc_n_k_; }, Number{}); - return GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1, - bs_grid_desc_bk0_n_bk1, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); + if constexpr(NXdlPerWave64 > 0) + { + if constexpr(isMultiA || isMultiB) + { + // Generate tuples with the same descriptors + const auto as_grid_desc_ak0_m_ak1 = generate_tuple( + [&](auto) { return arg.a_grid_desc_m_k_; }, Number{}); + const auto bs_grid_desc_bk0_n_bk1 = generate_tuple( + [&](auto) { return arg.b_grid_desc_n_k_; }, Number{}); + return GridwiseGemm64::CheckValidity(as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + else + { + if constexpr(CTranspose) + { + return GridwiseGemmCTranspose64::CheckValidity(arg.b_grid_desc_n_k_, + arg.a_grid_desc_m_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + else + { + return GridwiseGemmCTranspose64::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + } } else { - if constexpr(CTranspose) + + if constexpr(NXdlPerWave32 > 0) { - return GridwiseGemmCTranspose::CheckValidity(arg.b_grid_desc_n_k_, - arg.a_grid_desc_m_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); - } - else - { - return GridwiseGemmCTranspose::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); + if constexpr(isMultiA || isMultiB) + { + // Generate tuples with the same descriptors + const auto as_grid_desc_ak0_m_ak1 = generate_tuple( + [&](auto) { return arg.a_grid_desc_m_k_; }, Number{}); + const auto bs_grid_desc_bk0_n_bk1 = generate_tuple( + [&](auto) { return arg.b_grid_desc_n_k_; }, Number{}); + return GridwiseGemm32::CheckValidity(as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + else + { + if constexpr(CTranspose) + { + return GridwiseGemmCTranspose32::CheckValidity(arg.b_grid_desc_n_k_, + arg.a_grid_desc_m_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + else + { + return GridwiseGemmCTranspose32::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } } } + + return false; } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index bb31d64a93..dd2e429a01 100644 ---
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -82,54 +82,57 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const ComputePtrOffset compute_ptr_offset_of_groups, const ComputePtrOffset compute_ptr_offset_of_n) { -#if defined(__gfx9__) - // offset base pointer for each work-group - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); - const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // offset base pointer for each work-group + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); - const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); + const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); + const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); - static constexpr index_t NumDTensor = GridwiseGemm::NumDTensor; - using DsGridPointer = typename GridwiseGemm::DsGridPointer; - DsGridPointer p_ds_grid_grp{}; + static constexpr index_t NumDTensor = GridwiseGemm::NumDTensor; + using DsGridPointer = typename GridwiseGemm::DsGridPointer; + DsGridPointer p_ds_grid_grp{}; - static_for<0, NumDTensor, 1>{}([&](auto i) { - p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i]; - }); + static_for<0, NumDTensor, 1>{}([&](auto i) { + p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i]; + }); - const long_index_t a_group_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); - const long_index_t b_group_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); - const long_index_t e_group_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); + const long_index_t a_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); + const long_index_t b_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + const long_index_t e_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); - const long_index_t a_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); - const long_index_t e_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + const long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - using Block2CTileMap = typename GridwiseGemm::Block2CTileMapDefault; - const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4}; + using Block2CTileMap = typename GridwiseGemm::Block2CTileMapDefault; + const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4}; - GridwiseGemm::template Run( - karg.p_a_grid + a_group_offset + a_n_offset, - karg.p_b_grid + b_group_offset, - p_ds_grid_grp, - karg.p_c_grid + e_group_offset + e_n_offset, 
- p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op, - block_2_ctile_map, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_m_n, - c_grid_desc_m_n); + GridwiseGemm::template Run( + karg.p_a_grid + a_group_offset + a_n_offset, + karg.p_b_grid + b_group_offset, + p_ds_grid_grp, + karg.p_c_grid + e_group_offset + e_n_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op, + block_2_ctile_map, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n, + c_grid_desc_m_n); + } #else ignore = karg; ignore = a_grid_desc_ak0_m_ak1; @@ -164,58 +167,61 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const ComputePtrOffset compute_ptr_offset_of_groups, const ComputePtrOffset compute_ptr_offset_of_n) { -#if defined(__gfx9__) - // offset base pointer for each work-group - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); - const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // offset base pointer for each work-group + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); - const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); + const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); + const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); - static constexpr index_t NumDTensor = GridwiseGemm::NumDTensor; - using DsGridPointer = typename GridwiseGemm::DsGridPointer; - DsGridPointer p_ds_grid_grp{}; + static constexpr index_t NumDTensor = GridwiseGemm::NumDTensor; + using DsGridPointer = typename GridwiseGemm::DsGridPointer; + DsGridPointer p_ds_grid_grp{}; - static_for<0, NumDTensor, 1>{}([&](auto i) { - p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i]; - }); + static_for<0, NumDTensor, 1>{}([&](auto i) { + p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i]; + }); - const long_index_t a_group_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); - const long_index_t b_group_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); - const long_index_t e_group_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); + const long_index_t a_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); + const long_index_t b_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + const long_index_t e_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); - const long_index_t a_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); - const long_index_t e_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + const long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char 
p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + // Passing two LDS pointers is the key to telling the compiler that ds_read/ds_write + // can operate on different LDS chunks at the same time without an ordering dependency + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - using Block2CTileMap = typename GridwiseGemm::Block2CTileMapDefault; - const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4}; + using Block2CTileMap = typename GridwiseGemm::Block2CTileMapDefault; + const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4}; - GridwiseGemm::template Run_2Lds( - karg.p_a_grid + a_group_offset + a_n_offset, - karg.p_b_grid + b_group_offset, - p_ds_grid_grp, - karg.p_c_grid + e_group_offset + e_n_offset, - p_shared_0, - p_shared_1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op, - block_2_ctile_map, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_m_n, - c_grid_desc_m_n); + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + a_group_offset + a_n_offset, + karg.p_b_grid + b_group_offset, + p_ds_grid_grp, + karg.p_c_grid + e_group_offset + e_n_offset, + p_shared_0, + p_shared_1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op, + block_2_ctile_map, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n, + c_grid_desc_m_n); + } #else ignore = karg; ignore = a_grid_desc_ak0_m_ak1; @@ -318,6 +324,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 BComputeDataType> { using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr bool isMultiA = is_detected::value; static constexpr bool isMultiB = is_detected::value; @@ -469,7 +478,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 remove_cvref_t; // Use appropriate gridwise gemm - using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3< + template <index_t NXdlPerWave_> + using GridwiseGemmBase = GridwiseGemmMultiD_xdl_cshuffle_v3< tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, DsLayout, @@ -493,7 +503,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -521,6 +531,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ADataType, BDataType, DoElementwiseBeforeCShuffle>; + using GridwiseGemm64 = GridwiseGemmBase<NXdlPerWave64>; + using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>; // #undef GridwiseGemmV3TemplateParams @@ -860,6 +872,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { using Argument = DeviceOp::Argument; + template <typename GridwiseGemm> float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) @@ -1232,7 +1245,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 return ave_time; } - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template <typename GridwiseGemm> + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float avg_time = 0.f; if constexpr(!isMultiABD) @@ -1288,7 +1302,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 a_grid_size); } - avg_time += RunGemm(arg, stream_config); + avg_time += RunGemm<GridwiseGemm>(arg, stream_config); // 
Transpose result back to NGCHW if constexpr(is_NGCHW_GKCYX_NGKHW() || @@ -1330,6 +1344,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 return avg_time; } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -1373,7 +1389,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 } } - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { @@ -1383,7 +1399,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 } return false; } - // check ConvolutionForwardSpecialization if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) @@ -1628,23 +1643,52 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); - typename GridwiseGemm::Argument gemm_arg{nullptr, - nullptr, - {}, - nullptr, - GemmM, - GemmN, - GemmK, - I0, - I0, - {}, - I0, - I1 /*KBatch*/, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_}; + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + typename GridwiseGemm64::Argument gemm_arg{nullptr, + nullptr, + {}, + nullptr, + GemmM, + GemmN, + GemmK, + I0, + I0, + {}, + I0, + I1 /*KBatch*/, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_}; + return GridwiseGemm64::CheckValidity(gemm_arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + typename GridwiseGemm32::Argument gemm_arg{nullptr, + nullptr, + {}, + nullptr, + GemmM, + GemmN, + GemmK, + I0, + I0, + {}, + I0, + I1 /*KBatch*/, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_}; + return GridwiseGemm32::CheckValidity(gemm_arg); + } + } - return GridwiseGemm::CheckValidity(gemm_arg); + return false; } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp index d7859dbc46..adcda93720 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
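// Kernel-side, the patch makes the same two changes everywhere: the architecture
// guard widens from gfx9-only to gfx9/gfx11/gfx12, and the body is wrapped in
// `if constexpr(GridwiseGemm::IsValidCompilationParameter...)` so that instances
// which are invalid for the compile target become empty kernels instead of hard
// compilation errors. A hedged sketch of that shape; IsValidCompilationParameter
// is the name used in this diff, everything else below is illustrative.
template <typename GridwiseGemm>
__global__ void kernel_sketch(typename GridwiseGemm::Argument karg)
{
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
    {
        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
        GridwiseGemm::Run(karg, p_shared); // real body runs only for valid instances
    }
#else
    (void)karg; // keep the parameter referenced on non-target architectures
#endif
}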
#pragma once @@ -155,55 +155,60 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const Block2ETileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if defined(__gfx9__) - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = amd_wave_read_first_lane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = amd_wave_read_first_lane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = amd_wave_read_first_lane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); - const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); - const auto rs_batch_offset = compute_ptr_offset_of_batch.GetRsPtrOffset(g_idx); + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + const auto rs_batch_offset = compute_ptr_offset_of_batch.GetRsPtrOffset(g_idx); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - DsPointer p_ds_grid_grp; + DsPointer p_ds_grid_grp; - static constexpr index_t NumDTensor = - DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); - static_for<0, NumDTensor, 1>{}( - [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - RsPointer p_rs_grid_grp; + RsPointer p_rs_grid_grp; - static constexpr index_t NumRTensor = RsGridDescriptor_MBlock_MPerBlock::Size(); + static constexpr index_t NumRTensor = RsGridDescriptor_MBlock_MPerBlock::Size(); - static_for<0, NumRTensor, 1>{}( - [&](auto i) { p_rs_grid_grp(i) = p_rs_grid[i] + rs_batch_offset[i]; }); + static_for<0, NumRTensor, 1>{}( + [&](auto i) { p_rs_grid_grp(i) = p_rs_grid[i] + rs_batch_offset[i]; }); - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_ds_grid_grp, - p_e_grid + e_batch_offset, - p_rs_grid_grp, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - qs_element_op, - rs_element_op, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock_, - rs_grid_desc_mblock_mperblock, - block_2_ctile_map); + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_rs_grid_grp, + p_shared, + 
a_element_op, + b_element_op, + cde_element_op, + qs_element_op, + rs_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + rs_grid_desc_mblock_mperblock, + block_2_ctile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -299,6 +304,9 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle QsElementwiseOperation> { using DeviceOp = DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumRTensor = RsDataType::Size(); @@ -429,7 +437,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle using RGridDesc_M = remove_cvref_t({}, {}))>; // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + template + using GridwiseGemmBase = GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, // TODO: distinguish A/B datatype AccDataType, CShuffleDataType, @@ -459,7 +468,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -482,15 +491,17 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle CDEBlockTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using AGridDesc_AK0_M_AK1 = - remove_cvref_t; using BGridDesc_BK0_N_BK1 = - remove_cvref_t; - using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; + using Block2ETileMap = typename GridwiseGemm64::DefaultBlock2ETileMap; // Argument struct Argument : public BaseArgument @@ -546,13 +557,10 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle r_grid_desc_m_{ DeviceOp::MakeRGridDescriptor_M(r_g_n_wos_lengths, r_g_n_wos_strides)}, a_grid_desc_ak0_m_ak1_{ - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, b_grid_desc_bk0_n_bk1_{ - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, - ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, - e_grid_desc_mblock_mperblock_nblock_nperblock_{}, - rs_grid_desc_mblock_mperblock_{}, - block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, compute_ptr_offset_of_batch_{}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, @@ -577,58 +585,39 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; - // populate desc for Ds/E - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, - b_grid_desc_n_k_, - e_grid_desc_m_n_, - r_grid_desc_m_, - block_2_etile_map_)) - { - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; - // 
populate pointer, batch stride, desc for Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t>; + // D pointer + p_ds_grid_(i) = static_cast(p_ds[i]); - // D pointer - p_ds_grid_(i) = static_cast(p_ds[i]); + // D batch stride + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; - // D batch stride - compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; + ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths[i], + ds_g_n_k_wos_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; - ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - ds_g_n_k_wos_lengths[i], - ds_g_n_k_wos_strides[i], - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads}; + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_d); + }); - // D desc - ds_grid_desc_m_n_(i) = - DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_d); + // populate pointer for Rs + static_for<0, NumRTensor, 1>{}([&](auto i) { + using RDataType = remove_cvref_t>; - ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n_(i)); - }); - - // populate pointer for Rs - static_for<0, NumRTensor, 1>{}([&](auto i) { - using RDataType = remove_cvref_t>; - - // R pointer - p_rs_grid_(i) = static_cast(p_rs[i]); - - rs_grid_desc_mblock_mperblock_(i) = - GridwiseGemm::MakeRGridDescriptor_MBlock_MPerBlock(r_grid_desc_m_); - }); - } + // R pointer + p_rs_grid_(i) = static_cast(p_rs[i]); + }); } void Print() const @@ -644,9 +633,9 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle // pointers const ADataType* p_a_grid_; const BDataType* p_b_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; + typename GridwiseGemm64::DsGridPointer p_ds_grid_; EDataType* p_e_grid_; - typename GridwiseGemm::RsGridPointer p_rs_grid_; + typename GridwiseGemm64::RsGridPointer p_rs_grid_; ConvToGemmFwdTransformer conv_to_gemm_transformer_; @@ -660,16 +649,6 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle // tensor descriptors for block/thread-wise copy AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; - StaticallyIndexedArray< - typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, - NumDTensor> - ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different - // type from E - typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock_; - - StaticallyIndexedArray - rs_grid_desc_mblock_mperblock_; // block-to-e-tile map Block2ETileMap block_2_etile_map_; @@ -703,7 +682,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, arg.b_grid_desc_n_k_, @@ -715,6 +695,32 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle "wrong! 
GridwiseGemmMultipleD_xdl_cshuffle has invalid setting"); } + StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor> + ds_grid_desc_mblock_mperblock_nblock_nperblock = {}; + + StaticallyIndexedArray + rs_grid_desc_mblock_mperblock = {}; + + auto e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.e_grid_desc_m_n_); + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(i) = + GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + arg.ds_grid_desc_m_n_(i)); + }); + + // populate pointer for Rs + static_for<0, NumRTensor, 1>{}([&](auto i) { + rs_grid_desc_mblock_mperblock(i) = + GridwiseGemm64::MakeRGridDescriptor_MBlock_MPerBlock(arg.r_grid_desc_m_); + }); + const index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.a_g_n_c_wis_lengths_[0]; // Group count @@ -767,9 +773,9 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle arg.a_g_n_c_wis_lengths_[0], // Group count arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.rs_grid_desc_mblock_mperblock_, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + rs_grid_desc_mblock_mperblock, arg.block_2_etile_map_, arg.compute_ptr_offset_of_batch_); }; @@ -784,6 +790,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle } } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -794,7 +802,10 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { namespace ctc = tensor_layout::convolution; - + if(!is_xdl_wmma_supported()) + { + return false; + } // check device if(get_device_name() == "gfx908") { @@ -812,9 +823,16 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle return false; } } + else if(ck::is_gfx12_supported() || ck::is_gfx11_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } else { - return false; + // return false; } // check ConvolutionForwardSpecialization @@ -952,11 +970,29 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle } // check Gridwise GEMM - return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.e_grid_desc_m_n_, - arg.r_grid_desc_m_, - arg.block_2_etile_map_); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.e_grid_desc_m_n_, + arg.r_grid_desc_m_, + arg.block_2_etile_map_); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.e_grid_desc_m_n_, + arg.r_grid_desc_m_, + arg.block_2_etile_map_); + } + } + return false; } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index 8f3feee1c1..25afe46690 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp 
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -52,67 +52,70 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const ComputePtrOffset compute_ptr_offset_of_groups, const ComputePtrOffset compute_ptr_offset_of_n) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - const index_t block_id_x = __builtin_amdgcn_readfirstlane(blockIdx.x); - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); - const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - - const long_index_t a_group_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); - const long_index_t b_group_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); - const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); - const long_index_t e_group_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); - - const long_index_t a_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); - const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); - const long_index_t e_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); - - index_t left = 0; - index_t right = gemms_count; - index_t group_id = index_t((left + right) / 2); - while((!(block_id_x >= gemm_desc_kernel_args[group_id].BlockStart_ && - block_id_x < gemm_desc_kernel_args[group_id].BlockEnd_)) && - left <= right) +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - if(block_id_x < gemm_desc_kernel_args[group_id].BlockStart_) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id_x = __builtin_amdgcn_readfirstlane(blockIdx.x); + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); + + const long_index_t a_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); + const long_index_t b_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); + const long_index_t e_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); + + const long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + index_t left = 0; + index_t right = gemms_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_id_x >= gemm_desc_kernel_args[group_id].BlockStart_ && + block_id_x < gemm_desc_kernel_args[group_id].BlockEnd_)) && + left <= right) { - right = group_id; + if(block_id_x < gemm_desc_kernel_args[group_id].BlockStart_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); } - else - { - left = group_id; - } - group_id = index_t((left + right) / 2); + + using DsPointer = decltype(gemm_desc_kernel_args[Number<0>{}].ds_ptr_); + DsPointer p_ds_grid_grp; + static constexpr index_t NumDTensor = DsPointer::Size(); + static_for<0, NumDTensor, 1>{}([&](auto i) { + 
p_ds_grid_grp(i) = + gemm_desc_kernel_args[group_id].ds_ptr_[i] + ds_group_offset[i] + ds_n_offset[i]; + }); + + GridwiseGemm::template Run( + gemm_desc_kernel_args[group_id].a_ptr_ + a_group_offset + a_n_offset, + gemm_desc_kernel_args[group_id].b_ptr_ + b_group_offset, + p_ds_grid_grp, + gemm_desc_kernel_args[group_id].e_ptr_ + e_group_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + c_element_op, + gemm_desc_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_desc_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_desc_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_desc_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_desc_kernel_args[group_id].block_2_etile_map_); } - - using DsPointer = decltype(gemm_desc_kernel_args[Number<0>{}].ds_ptr_); - DsPointer p_ds_grid_grp; - static constexpr index_t NumDTensor = DsPointer::Size(); - static_for<0, NumDTensor, 1>{}([&](auto i) { - p_ds_grid_grp(i) = - gemm_desc_kernel_args[group_id].ds_ptr_[i] + ds_group_offset[i] + ds_n_offset[i]; - }); - - GridwiseGemm::template Run( - gemm_desc_kernel_args[group_id].a_ptr_ + a_group_offset + a_n_offset, - gemm_desc_kernel_args[group_id].b_ptr_ + b_group_offset, - p_ds_grid_grp, - gemm_desc_kernel_args[group_id].e_ptr_ + e_group_offset + e_n_offset, - p_shared, - a_element_op, - b_element_op, - c_element_op, - gemm_desc_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, - gemm_desc_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, - gemm_desc_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, - gemm_desc_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, - gemm_desc_kernel_args[group_id].block_2_etile_map_); #else ignore = gemm_desc_kernel_args; ignore = gemms_count; @@ -199,6 +202,9 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor BComputeDataType> { using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t MaxGemmsNum = 32; @@ -412,25 +418,28 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ AComputeDataType, DoElementwiseBeforeCShuffle // Use appropriate gridwise gemm - using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle; + template + using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // desc for blockwise copy using AGridDesc_AK0_M_AK1 = - remove_cvref_t; using BGridDesc_BK0_N_BK1 = - remove_cvref_t; using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))>; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; // block-to-e-tile map using Block2ETileMap = - remove_cvref_t; + remove_cvref_t; // Structure for each gemm(conv) struct GemmArgs { @@ -455,6 +464,43 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor // Argument struct Argument : public BaseArgument { + template + void init_gemm_args(const ADataType* a_ptr, + const BDataType* b_ptr, + DsPointer ds_ptr, + EDataType* e_ptr, + const AGridDesc_M_K& a_grid_desc_m_k, + const 
BGridDesc_N_K& b_grid_desc_n_k, + const DsGridDesc_M_N_& ds_grid_desc_m_n, + const EGridDescriptor_M_N_& e_grid_desc_m_n, + const Block2ETileMap& block_2_etile_map, + index_t BlockStart, + index_t BlockEnd) + { + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map)) + { + gemm_desc_kernel_args_(valid_gemms_count_) = + GemmArgs{a_ptr, + b_ptr, + ds_ptr, + e_ptr, + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k), + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k), + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n), + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n), + block_2_etile_map, + BlockStart, + BlockEnd}; + + valid_gemms_count_++; + } + } Argument(const void* p_a, const void* p_b, const std::array& p_ds, @@ -543,7 +589,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor generate_tuple([&](auto) { return e_grid_desc_m_n; }, Number{}); const auto block_2_etile_map = - GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); + GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); const index_t grid_size_grp = block_2_etile_map.CalculateGridSize(e_grid_desc_m_n); @@ -553,28 +599,39 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor grid_size_ += grid_size_grp; - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, - b_grid_desc_n_k, - ds_grid_desc_m_n, - e_grid_desc_m_n, - block_2_etile_map)) + if(get_warp_size() == 64) { - gemm_desc_kernel_args_(valid_gemms_count_) = GemmArgs{ - a_grid_ptrs[i], - static_cast(p_b), - ds_grid_ptrs[i], - c_grid_ptrs[i], - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k), - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k), - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n), - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n), - block_2_etile_map, - BlockStart, - BlockEnd}; - - valid_gemms_count_++; + if constexpr(NXdlPerWave64 > 0) + { + init_gemm_args(a_grid_ptrs[i], + static_cast(p_b), + ds_grid_ptrs[i], + c_grid_ptrs[i], + a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map, + BlockStart, + BlockEnd); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + init_gemm_args(a_grid_ptrs[i], + static_cast(p_b), + ds_grid_ptrs[i], + c_grid_ptrs[i], + a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map, + BlockStart, + BlockEnd); + } } } // N is the same for all convs @@ -649,7 +706,10 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor // Invoker struct Invoker : public BaseInvoker { - float Run(const DeviceOp::Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + + using Argument = DeviceOp::Argument; + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -703,6 +763,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor } } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -754,11 +816,10 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor return false; } } - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - // check ConvolutionForwardSpecialization if constexpr(ConvForwardSpecialization == 
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp index 764daf1750..f6ec0908eb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -45,94 +45,97 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - const index_t KBatch = 1; - - const index_t block_id = get_block_1d_id(); - - const auto gemm_desc_ptr = - reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); - - const index_t group_id = block_id / grid_size_grp; - - if(group_id >= group_count) - return; - - const index_t M = gemm_desc_ptr[group_id].M; - const index_t N = gemm_desc_ptr[group_id].N; - const index_t K = gemm_desc_ptr[group_id].K; - - if(M == 0 || N == 0 || K == 0) - return; - - const auto StrideAs = gemm_desc_ptr[group_id].StrideAs; - const auto StrideBs = gemm_desc_ptr[group_id].StrideBs; - const auto StrideDs = gemm_desc_ptr[group_id].StrideDs; - const auto StrideE = gemm_desc_ptr[group_id].StrideE; - - const auto e_grid_desc_m_n = - GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); - - const index_t BlockStart = group_id * grid_size_grp; - - const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch}; - - const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n); - - constexpr auto NumATensor = GridwiseGemm::AsGridPointer::Size(); - constexpr auto NumBTensor = GridwiseGemm::BsGridPointer::Size(); - constexpr auto NumDTensor = GridwiseGemm::DsGridPointer::Size(); - - typename GridwiseGemm::AsGridPointer p_as_grid_; - typename GridwiseGemm::BsGridPointer p_bs_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; - - static_for<0, NumATensor, 1>{}([&](auto i) { - using ADataType = remove_cvref_t; - p_as_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_as_grid[i]); - }); - - static_for<0, NumBTensor, 1>{}([&](auto i) { - using BDataType = remove_cvref_t; - p_bs_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_bs_grid[i]); - }); - - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t; - p_ds_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); - }); - - index_t id_off = 0; - index_t id_local = get_block_1d_id() - BlockStart; - - while(id_local < local_grid_size) +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - const auto block_2_etile_map = - GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off); + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm:: - template Run( - p_as_grid_, - p_bs_grid_, - p_ds_grid_, - gemm_desc_ptr[group_id].p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - M, - N, - K, - StrideAs, - StrideBs, - StrideDs, - StrideE, - block_2_etile_map); + const index_t KBatch = 1; - id_off 
+= grid_size_grp; - id_local += grid_size_grp; + const index_t block_id = get_block_1d_id(); + + const auto gemm_desc_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(gemm_descs_const)); + + const index_t group_id = block_id / grid_size_grp; + + if(group_id >= group_count) + return; + + const index_t M = gemm_desc_ptr[group_id].M; + const index_t N = gemm_desc_ptr[group_id].N; + const index_t K = gemm_desc_ptr[group_id].K; + + if(M == 0 || N == 0 || K == 0) + return; + + const auto StrideAs = gemm_desc_ptr[group_id].StrideAs; + const auto StrideBs = gemm_desc_ptr[group_id].StrideBs; + const auto StrideDs = gemm_desc_ptr[group_id].StrideDs; + const auto StrideE = gemm_desc_ptr[group_id].StrideE; + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); + + const index_t BlockStart = group_id * grid_size_grp; + + const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch}; + + const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n); + + constexpr auto NumATensor = GridwiseGemm::AsGridPointer::Size(); + constexpr auto NumBTensor = GridwiseGemm::BsGridPointer::Size(); + constexpr auto NumDTensor = GridwiseGemm::DsGridPointer::Size(); + + typename GridwiseGemm::AsGridPointer p_as_grid_; + typename GridwiseGemm::BsGridPointer p_bs_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t; + p_as_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_as_grid[i]); + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t; + p_bs_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_bs_grid[i]); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t; + p_ds_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); + }); + + index_t id_off = 0; + index_t id_local = get_block_1d_id() - BlockStart; + + while(id_local < local_grid_size) + { + const auto block_2_etile_map = + GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off); + + GridwiseGemm:: + template Run( + p_as_grid_, + p_bs_grid_, + p_ds_grid_, + gemm_desc_ptr[group_id].p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + M, + N, + K, + StrideAs, + StrideBs, + StrideDs, + StrideE, + block_2_etile_map); + + id_off += grid_size_grp; + id_local += grid_size_grp; + } } #else ignore = gemm_descs_const; @@ -203,6 +206,9 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK CDEElementwiseOperation> { using DeviceOp = DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumATensor = AsDataType::Size(); static constexpr index_t NumBTensor = BsDataType::Size(); @@ -215,7 +221,8 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK static constexpr index_t NumGemmKPrefetchStage = 1; // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleABD_xdl_cshuffle< + template + using GridwiseGemmBase = GridwiseGemmMultipleABD_xdl_cshuffle< AsDataType, BsDataType, ComputeType, @@ -237,7 +244,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -259,7 +266,8 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK 
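// A minimal sketch of the IsValidCompilationParameter guard that now wraps the
// grouped multi-ABD kernel above (and every other kernel body in this patch),
// assuming a stand-in validity rule: one binary can then carry both wave64 and
// wave32 instantiations, and the variant that cannot work on the compile
// target collapses to an empty kernel instead of a hard compile error.
// GridwiseStub and kernel_sketch are hypothetical names for illustration only.
#include <hip/hip_runtime.h>

template <int NXdlPerWave>
struct GridwiseStub
{
    template <int = 0>
    __host__ __device__ static constexpr bool IsValidCompilationParameter()
    {
        return NXdlPerWave > 0; // the real check inspects tile/wave geometry
    }
    __device__ static void Run(float* out) { out[0] = NXdlPerWave; }
};

template <typename GridwiseGemm>
__global__ void kernel_sketch(float* out)
{
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
    {
        GridwiseGemm::Run(out); // skipped entirely for invalid configurations
    }
}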
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched>; - + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; template struct OffsettedBlockToCTileMapMLoops { @@ -508,7 +516,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK [&](auto j) { StrideDs[j] = gemm_descs[i].stride_Ds_[j]; }); const auto e_grid_desc_m_n = - GridwiseGemm::template MakeEGridDescriptor_M_N( + GridwiseGemm64::template MakeEGridDescriptor_M_N( AverM, N, StrideE); // block-to-e-tile map @@ -547,7 +555,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK } const auto e_grid_desc_sum_m_n = - GridwiseGemm::template MakeEGridDescriptor_M_N( + GridwiseGemm64::template MakeEGridDescriptor_M_N( sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_); const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1}; @@ -581,7 +589,8 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { bool has_main_k_block_loop = true; @@ -667,6 +676,8 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK return ave_time; } + INVOKER_RUN_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp index 7b5dd55a8f..4188bf537d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp @@ -91,6 +91,9 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage CDEElementwiseOperation> { using DeviceOp = DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); @@ -105,7 +108,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage using WorkspaceDataType = float; // First stage GridwiseGEMM kernel. 
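// The first-stage gridwise GEMM alias declared next follows the same scheme as
// every device op in this patch: the gridwise GEMM becomes a template over
// NXdlPerWave_, a wave64 and a wave32 variant are instantiated, and the invoker
// picks one at run time. The GET_NXDL_PER_WAVE_IMPL / INVOKER_RUN_IMPL macro
// bodies are not shown in this diff, so the following is only a plausible,
// compilable sketch of the dispatch they enable (stand-in names and values):
#include <cstdio>

template <int NXdlPerWave_>
struct GridwiseGemmSketch
{
    static float Run() { return static_cast<float>(NXdlPerWave_); }
};

constexpr int NXdlPerWave64 = 4; // assumed: valid wave64 tiling
constexpr int NXdlPerWave32 = 2; // assumed: valid wave32 tiling
using GridwiseGemm64 = GridwiseGemmSketch<NXdlPerWave64>;
using GridwiseGemm32 = GridwiseGemmSketch<NXdlPerWave32>;

int get_warp_size() { return 64; } // stand-in for the runtime device query

float Run()
{
    if(get_warp_size() == 64)
    {
        if constexpr(NXdlPerWave64 > 0) // keeps the invalid variant uninstantiated
            return GridwiseGemm64::Run();
    }
    else
    {
        if constexpr(NXdlPerWave32 > 0)
            return GridwiseGemm32::Run();
    }
    return 0.0f; // unsupported wave size for this configuration
}

int main() { std::printf("NXdlPerWave used: %.0f\n", Run()); }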
- using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< + template + using GridwiseGemmBase = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, @@ -126,7 +130,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage NPerXDL, AK1, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -150,7 +154,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage LoopSched, PipelineVer, ComputeDataType>; - + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; template static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE) { @@ -220,8 +225,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage Number{}); } - using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; - using EGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; + using CGridDesc_M_N = typename GridwiseGemm64::CGridDesc_M_N; + using EGridDesc_M_N = typename GridwiseGemm64::CGridDesc_M_N; using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {})); using DsGridPointer = decltype(MakeDsGridPointer()); using CDGridDesc_M_N = decltype(concat_tuple(ck::Tuple{}, DsGridDesc_M_N{})); @@ -258,7 +263,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage // Block2CTileMap configuration parameter. static constexpr index_t B2E_M01 = 8; using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap; - using GemmKernelArgument = typename GridwiseGemm::Argument; + using GemmKernelArgument = typename GridwiseGemm64::Argument; struct GemmTransKernelArg { @@ -355,12 +360,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage const index_t stride_b = gemm_descs[i].stride_B_; const index_t stride_e = gemm_descs[i].stride_C_; - const index_t m_padded = GridwiseGemm::CalculateMPadded(M); - const index_t n_padded = GridwiseGemm::CalculateNPadded(N); - const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH); - const index_t k0_padded = GridwiseGemm::CalculateK0Padded(K, K_BATCH); + const index_t m_padded = GridwiseGemm64::CalculateMPadded(M); + const index_t n_padded = GridwiseGemm64::CalculateNPadded(N); + const index_t k_padded = GridwiseGemm64::CalculateKPadded(K, K_BATCH); + const index_t k0_padded = GridwiseGemm64::CalculateK0Padded(K, K_BATCH); - const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_e); + const auto c_grid_desc_m_n = + GridwiseGemm64::MakeCGridDescriptor_M_N(M, N, stride_e); DsGridDesc_M_N ds_grid_desc_m_n; DsGridPointer p_ds_grid; @@ -441,11 +447,11 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage { auto& karg = gemm_kernel_args_[i].karg_; - const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH); - const index_t k0_padded = GridwiseGemm::CalculateK0Padded(karg.K, K_BATCH); + const index_t k_padded = GridwiseGemm64::CalculateKPadded(karg.K, K_BATCH); + const index_t k0_padded = GridwiseGemm64::CalculateK0Padded(karg.K, K_BATCH); const auto c_grid_desc_m_n = - GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); + GridwiseGemm64::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); const auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH}; @@ -565,13 +571,14 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage /// /// @return The average kernel execution time (if time measurement is enabled.) 
/// + template float Run(const Argument& arg, void* dev_gemm_args, void* dev_gemm_workspace, const StreamConfig& stream_config = StreamConfig{}) { auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] = - CheckArgument(arg, stream_config); + CheckArgument(arg, stream_config); if(dev_gemm_args == nullptr) { @@ -593,13 +600,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage if(all_have_main_k_block_loop) { - ave_time = - DispatchKernel(arg, dev_gemm_args, dev_gemm_workspace, stream_config); + ave_time = DispatchKernel( + arg, dev_gemm_args, dev_gemm_workspace, stream_config); } else { - ave_time = - DispatchKernel(arg, dev_gemm_args, dev_gemm_workspace, stream_config); + ave_time = DispatchKernel( + arg, dev_gemm_args, dev_gemm_workspace, stream_config); } return ave_time; @@ -619,7 +626,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage /// /// @return The average kernel execution time (if time measurement is enabled.) /// - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(arg.p_dev_gemm_kargs_ == nullptr) { @@ -637,9 +645,11 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage throw std::runtime_error(err.str()); } - return Run(arg, arg.p_dev_gemm_kargs_, arg.p_workspace_, stream_config); + return Run(arg, arg.p_dev_gemm_kargs_, arg.p_workspace_, stream_config); } + INVOKER_RUN_IMPL + float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override { @@ -647,6 +657,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage } private: + template auto CheckArgument(const Argument& arg, const StreamConfig& stream_config) const { bool all_have_kbatch_gt_one, all_have_main_k_block_loop; @@ -670,7 +681,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) { - const auto& gemm_arg = arg.gemm_kernel_args_[i].karg_; + const auto& gemm_arg = reinterpret_cast( + arg.gemm_kernel_args_[i].karg_); if(stream_config.log_level_ > 0) { gemm_arg.Print(); @@ -721,7 +733,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage return std::make_tuple(all_have_kbatch_gt_one, all_have_main_k_block_loop); } - template + template float DispatchKernel(const Argument& arg, void* dev_gemm_kargs, void* dev_gemm_workspace, @@ -818,11 +830,10 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - if((ck::type_convert(arg.gemm_kernel_args_.size()) + arg.skipped_group_count_) != arg.group_count_) { @@ -836,11 +847,27 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage } bool supported = true; + bool isWave64 = get_warp_size() == 64; for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) { const auto& gemm_arg = arg.gemm_kernel_args_[i].karg_; + bool group_arg_valid = false; + if(isWave64) + { + if constexpr(NXdlPerWave64 > 0) + { + group_arg_valid = GridwiseGemm64::CheckValidity(gemm_arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + group_arg_valid = GridwiseGemm32::CheckValidity( + reinterpret_cast(gemm_arg)); + } + } - bool group_arg_valid = GridwiseGemm::CheckValidity(gemm_arg); if(not group_arg_valid) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) diff --git 
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp index 70a395f2f7..d8d688aa06 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -68,126 +68,91 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) { -#if defined(__gfx9__) - - constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); - __shared__ uint8_t p_shared[shared_size]; - __shared__ uint8_t p_shared1[shared_size]; - - const auto gemm_desc_ptr = - reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); - - constexpr auto NumDTensor = DsDataType::Size(); - index_t tile_id = get_block_1d_id(); - index_t tile_offset = 0; - index_t group_id = -1; - index_t group_offset = 0; - index_t grid_size_grp = 0; - - index_t gemm_tile_id_start = 0; - index_t gemm_tile_id_end = 0; - - index_t M = 0, N = 0, K = 0; - - auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1); - - do +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - // Find corresponding GEMM group for our tile - while(!(tile_id >= gemm_tile_id_start && tile_id < gemm_tile_id_end) && - group_id < group_count) + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + __shared__ uint8_t p_shared[shared_size]; + __shared__ uint8_t p_shared1[shared_size]; + + const auto gemm_desc_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(gemm_descs_const)); + + constexpr auto NumDTensor = DsDataType::Size(); + index_t tile_id = get_block_1d_id(); + index_t tile_offset = 0; + index_t group_id = -1; + index_t group_offset = 0; + index_t grid_size_grp = 0; + + index_t gemm_tile_id_start = 0; + index_t gemm_tile_id_end = 0; + + index_t M = 0, N = 0, K = 0; + + auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1); + + do { - group_offset += grid_size_grp; - group_id++; - - if(group_id >= group_count) - return; - - M = gemm_desc_ptr[group_id].M; - N = gemm_desc_ptr[group_id].N; - K = gemm_desc_ptr[group_id].K; - - if(M == 0 || N == 0 || K == 0) + // Find corresponding GEMM group for our tile + while(!(tile_id >= gemm_tile_id_start && tile_id < gemm_tile_id_end) && + group_id < group_count) { - grid_size_grp = 0; - continue; - } + group_offset += grid_size_grp; + group_id++; - b2c_tile_map = - OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N, 4), group_offset, tile_offset); - grid_size_grp = b2c_tile_map.CalculateGridSize(M, N); + if(group_id >= group_count) + return; - gemm_tile_id_start = group_offset; - gemm_tile_id_end = group_offset + grid_size_grp; - } + M = gemm_desc_ptr[group_id].M; + N = gemm_desc_ptr[group_id].N; + K = gemm_desc_ptr[group_id].K; - using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); - DsGridPointer p_ds_grid; - - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t>; - p_ds_grid(i) = 
static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); - }); - - static constexpr index_t kbatch = 1; - static constexpr index_t k_grain = kbatch * KPerBlock; - index_t K_split = (K + k_grain - 1) / k_grain * KPerBlock; - - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - - // Update tile offset if we have moved within group - b2c_tile_map.UpdateTileOffset(tile_offset); - - using Problem = typename GridwiseGemm::Problem; - auto problem = Problem(gemm_desc_ptr[group_id].M, - gemm_desc_ptr[group_id].N, - gemm_desc_ptr[group_id].K, - gemm_desc_ptr[group_id].StrideA, - gemm_desc_ptr[group_id].StrideB, - gemm_desc_ptr[group_id].StrideDs, - gemm_desc_ptr[group_id].StrideE, - kbatch); - - if(has_main_k_block_loop) - { - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || - BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - GridwiseGemm::template Run( - static_cast(gemm_desc_ptr[group_id].p_a_grid), - static_cast(gemm_desc_ptr[group_id].p_b_grid), - p_ds_grid, - static_cast(gemm_desc_ptr[group_id].p_e_grid), - static_cast(p_shared), - problem, - a_element_op, - b_element_op, - cde_element_op, - b2c_tile_map); - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + if(M == 0 || N == 0 || K == 0) { - GridwiseGemm::template Run( - static_cast(gemm_desc_ptr[group_id].p_a_grid), - static_cast(gemm_desc_ptr[group_id].p_b_grid), - p_ds_grid, - static_cast(gemm_desc_ptr[group_id].p_e_grid), - static_cast(p_shared), - problem, - a_element_op, - b_element_op, - cde_element_op, - b2c_tile_map); + grid_size_grp = 0; + continue; } - else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) + + b2c_tile_map = OffsettedBlockToCTileMap( + LocalBlock2ETileMap(M, N, 4), group_offset, tile_offset); + grid_size_grp = b2c_tile_map.CalculateGridSize(M, N); + + gemm_tile_id_start = group_offset; + gemm_tile_id_end = group_offset + grid_size_grp; + } + + using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); + DsGridPointer p_ds_grid; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + p_ds_grid(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); + }); + + static constexpr index_t kbatch = 1; + static constexpr index_t k_grain = kbatch * KPerBlock; + index_t K_split = (K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + // Update tile offset if we have moved within group + b2c_tile_map.UpdateTileOffset(tile_offset); + + using Problem = typename GridwiseGemm::Problem; + auto problem = Problem(gemm_desc_ptr[group_id].M, + gemm_desc_ptr[group_id].N, + gemm_desc_ptr[group_id].K, + gemm_desc_ptr[group_id].StrideA, + gemm_desc_ptr[group_id].StrideB, + gemm_desc_ptr[group_id].StrideDs, + gemm_desc_ptr[group_id].StrideE, + kbatch); + + if(has_main_k_block_loop) + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) { GridwiseGemm::template Run 2) + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) { GridwiseGemm::template Run( + TailNumber::One>( static_cast(gemm_desc_ptr[group_id].p_a_grid), static_cast(gemm_desc_ptr[group_id].p_b_grid), p_ds_grid, @@ -224,16 +188,12 @@ 
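// The tail-number ladder that continues below is driven by the K-shape
// arithmetic the kernel computed above. A small worked model of that rounding
// (KPerBlock is an illustrative value; the prefetch-depth rules that map
// num_loop to a TailNumber live in the GridwiseGemm helpers):
#include <cstdio>

int main()
{
    constexpr int KPerBlock = 32;
    constexpr int kbatch    = 1; // fixed to 1 in this kernel
    constexpr int k_grain   = kbatch * KPerBlock;

    for(int K : {32, 70, 96})
    {
        // Round K up to a whole number of block-tiles, exactly as in the kernel.
        const int K_split  = (K + k_grain - 1) / k_grain * KPerBlock;
        const int num_loop = K_split / KPerBlock;
        std::printf("K=%d -> K_split=%d (%d K-tiles)\n", K, K_split, num_loop);
    }
    // e.g. K=70 -> K_split=96: three K-tiles, and the pipeline's prefetch depth
    // then decides whether the main loop runs and which TailNumber branch fires.
}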
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) cde_element_op, b2c_tile_map); } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three) + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) { GridwiseGemm::template Run( + TailNumber::Full>( static_cast(gemm_desc_ptr[group_id].p_a_grid), static_cast(gemm_desc_ptr[group_id].p_b_grid), p_ds_grid, @@ -245,84 +205,166 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) cde_element_op, b2c_tile_map); } - } - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four) + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) { - GridwiseGemm::template Run( + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + } + // Tail number could be Odd or 
Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + GridwiseGemm::template Run_2Lds( static_cast(gemm_desc_ptr[group_id].p_a_grid), static_cast(gemm_desc_ptr[group_id].p_b_grid), p_ds_grid, static_cast(gemm_desc_ptr[group_id].p_e_grid), static_cast(p_shared), + static_cast(p_shared1), problem, a_element_op, b_element_op, cde_element_op, b2c_tile_map); } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five) + else { - GridwiseGemm::template Run( - static_cast(gemm_desc_ptr[group_id].p_a_grid), - static_cast(gemm_desc_ptr[group_id].p_b_grid), - p_ds_grid, - static_cast(gemm_desc_ptr[group_id].p_e_grid), - static_cast(p_shared), - problem, - a_element_op, - b_element_op, - cde_element_op, - b2c_tile_map); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) - { - GridwiseGemm::template Run( - static_cast(gemm_desc_ptr[group_id].p_a_grid), - static_cast(gemm_desc_ptr[group_id].p_b_grid), - p_ds_grid, - static_cast(gemm_desc_ptr[group_id].p_e_grid), - static_cast(p_shared), - problem, - a_element_op, - b_element_op, - cde_element_op, - b2c_tile_map); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven) - { - GridwiseGemm::template Run( + GridwiseGemm::template Run_2Lds( static_cast(gemm_desc_ptr[group_id].p_a_grid), static_cast(gemm_desc_ptr[group_id].p_b_grid), p_ds_grid, static_cast(gemm_desc_ptr[group_id].p_e_grid), static_cast(p_shared), + static_cast(p_shared1), problem, a_element_op, b_element_op, @@ -331,39 +373,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) } } } - // Tail number could be Odd or Even - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + else { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { - GridwiseGemm::template Run_2Lds( + GridwiseGemm::template Run( static_cast(gemm_desc_ptr[group_id].p_a_grid), static_cast(gemm_desc_ptr[group_id].p_b_grid), p_ds_grid, static_cast(gemm_desc_ptr[group_id].p_e_grid), static_cast(p_shared), - static_cast(p_shared1), - problem, - a_element_op, - b_element_op, - cde_element_op, - b2c_tile_map); - } - else - { - GridwiseGemm::template Run_2Lds( - static_cast(gemm_desc_ptr[group_id].p_a_grid), - static_cast(gemm_desc_ptr[group_id].p_b_grid), - p_ds_grid, - static_cast(gemm_desc_ptr[group_id].p_e_grid), - static_cast(p_shared), - static_cast(p_shared1), problem, a_element_op, b_element_op, @@ -371,32 +393,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) b2c_tile_map); } } - } - else - { - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - GridwiseGemm::template Run( - static_cast(gemm_desc_ptr[group_id].p_a_grid), - static_cast(gemm_desc_ptr[group_id].p_b_grid), - p_ds_grid, - static_cast(gemm_desc_ptr[group_id].p_e_grid), - static_cast(p_shared), - problem, - a_element_op, - b_element_op, - cde_element_op, - b2c_tile_map); - } - } - tile_id += get_grid_size(); - tile_offset += get_grid_size(); + tile_id += get_grid_size(); + tile_offset += get_grid_size(); - } while(group_id < group_count); + } while(group_id < group_count); + } #else ignore = 
gemm_descs_const; ignore = group_count; @@ -467,10 +469,14 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop BElementwiseOperation, CDEElementwiseOperation> { - using DeviceOp = DeviceGroupedGemmMultipleDXdlCShuffleTileLoop; + using DeviceOp = DeviceGroupedGemmMultipleDXdlCShuffleTileLoop; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); - using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3< + template + using GridwiseGemmBase = GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, @@ -494,7 +500,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -519,6 +525,8 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using KernelArguments = GroupedGemmKernelArgument; using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; @@ -571,10 +579,15 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop // The oversubscription factor for the number of blocks that can simultaneously reside on // GPU. static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1; - static constexpr int BLOCK_WAVES = BlockSize / get_warp_size(); - static constexpr int CU_SIMDS = 4; + // static constexpr int BLOCK_WAVES = BlockSize / get_warp_size(); + static constexpr int CU_SIMDS = 4; // Assume we want to have at most 2 waves per SIMD - static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES); + // static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES); + static int GetCuBlocks() + { + int BLOCK_WAVES = BlockSize / get_warp_size(); + return math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES); + } }; // Invoker @@ -593,6 +606,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop /// /// @return The average kernel execution time (if time measurement is enabled.) /// + template float Run(const Argument& arg, const void* dev_gemm_args, const StreamConfig& stream_config = StreamConfig{}) @@ -606,7 +620,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop } float ave_time = 0; - ave_time = DispatchKernel(arg, dev_gemm_args, stream_config); + ave_time = DispatchKernel(arg, dev_gemm_args, stream_config); return ave_time; } @@ -624,7 +638,8 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop /// /// @return The average kernel execution time (if time measurement is enabled.) 
    ///
-    float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+    template
+    float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
     {
         if(arg.p_dev_gemm_args_ == nullptr)
         {
@@ -634,9 +649,11 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
             throw std::runtime_error(err.str());
         }

-        return Run(arg, arg.p_dev_gemm_args_, stream_config);
+        return Run(arg, arg.p_dev_gemm_args_, stream_config);
     }

+    INVOKER_RUN_IMPL
+
     float Run(const BaseArgument* p_arg,
               const StreamConfig& stream_config = StreamConfig{}) override
     {
@@ -644,6 +661,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
     }

     private:
+    template
     float DispatchKernel(const Argument& arg,
                          const void* dev_gemm_args,
                          const StreamConfig& stream_config) const
@@ -686,11 +704,11 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
         {
             std::cout << "MaxActiveBlocksPerCU: " << occ_num_blocks
                       << ", available CUs count: " << cu_count << ", occup. grid size: "
-                      << ck::math::min(occ_num_blocks, KernelConfig::CU_BLOCKS) * cu_count
+                      << ck::math::min(occ_num_blocks, KernelConfig::GetCuBlocks()) * cu_count
                       << std::endl;
         }

-        return cu_count * ck::math::min(occ_num_blocks, KernelConfig::CU_BLOCKS);
+        return cu_count * ck::math::min(occ_num_blocks, KernelConfig::GetCuBlocks());
     }

     template
@@ -730,36 +748,19 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(!ck::is_xdl_supported())
+        if(!ck::is_xdl_wmma_supported())
         {
             return false;
         }
-
         bool supported = true;
         constexpr index_t k_batch = 1;
+        bool isWave64 = get_warp_size() == 64;

         for(index_t i = 0; i < arg.group_count_; ++i)
         {
             std::array placeholder_p_ds_grid{};
             std::array stride_Ds;
             std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin());

-            using GridArg = typename GridwiseGemm::Argument;
-            GridArg gridwise_arg(nullptr,               // p_a_grid,
-                                 nullptr,               // p_b_grid,
-                                 placeholder_p_ds_grid, // p_ds_grid,
-                                 nullptr,               // p_e_grid ,
-                                 arg.gemm_descs_[i].M_,
-                                 arg.gemm_descs_[i].N_,
-                                 arg.gemm_descs_[i].K_,
-                                 arg.gemm_descs_[i].stride_A_,
-                                 arg.gemm_descs_[i].stride_B_,
-                                 stride_Ds,
-                                 arg.gemm_descs_[i].stride_C_,
-                                 k_batch,
-                                 arg.a_element_op_,
-                                 arg.b_element_op_,
-                                 arg.cde_element_op_);
-
             if((arg.gemm_descs_[i].K_ % AK1 != 0 || arg.gemm_descs_[i].K_ % BK1 != 0) &&
                !(GemmSpec == GemmSpecialization::MKPadding ||
                  GemmSpec == GemmSpecialization::NKPadding ||
@@ -768,8 +769,62 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
             {
                 return false;
             }
+            if(isWave64)
+            {
+                if constexpr(NXdlPerWave64 > 0)
+                {
+                    using GridArg = typename GridwiseGemm64::Argument;
+                    GridArg gridwise_arg(nullptr,               // p_a_grid,
+                                         nullptr,               // p_b_grid,
+                                         placeholder_p_ds_grid, // p_ds_grid,
+                                         nullptr,               // p_e_grid ,
+                                         arg.gemm_descs_[i].M_,
+                                         arg.gemm_descs_[i].N_,
+                                         arg.gemm_descs_[i].K_,
+                                         arg.gemm_descs_[i].stride_A_,
+
arg.gemm_descs_[i].stride_B_, + stride_Ds, + arg.gemm_descs_[i].stride_C_, + k_batch, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_); + + supported = supported && GridwiseGemm32::CheckValidity(gridwise_arg); + } + else + { + supported = false; + } + } } return supported; @@ -780,6 +835,67 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop return IsSupportedArgument(*dynamic_cast(p_arg)); } + static int GetKernelOccupancy() + { + int occupancy = 0; + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + const auto kernel = kernel_grouped_gemm_multiple_d_xdl; + hip_check_error( + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); + } + } + else + { + + if constexpr(NXdlPerWave32 > 0) + { + const auto kernel = kernel_grouped_gemm_multiple_d_xdl; + hip_check_error( + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); + } + } + return occupancy; + } + static auto MakeArgument(std::vector& p_As, std::vector& p_Bs, std::vector>& p_Ds, @@ -789,28 +905,8 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop BElementwiseOperation b_elementwise_op, CDEElementwiseOperation cde_elementwise_op) { - const auto kernel = kernel_grouped_gemm_multiple_d_xdl; - int occupancy, num_cu; - hip_check_error( - hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); + int occupancy = GetKernelOccupancy(); + int num_cu; hipDeviceProp_t dev_prop; hipDevice_t dev; @@ -840,28 +936,8 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop BElementwiseOperation b_elementwise_op, CDEElementwiseOperation cde_elementwise_op) override { - const auto kernel = kernel_grouped_gemm_multiple_d_xdl; - int occupancy, num_cu; - hip_check_error( - hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); + int occupancy = GetKernelOccupancy(); + int num_cu; hipDeviceProp_t dev_prop; hipDevice_t dev; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index 784b2fd401..62dcbfb83b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
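// How the tile-loop op above sizes its persistent grid: the measured occupancy
// is capped at two waves per SIMD (KernelConfig::GetCuBlocks(), which replaces
// the old compile-time CU_BLOCKS now that warp size is a runtime property) and
// multiplied by the CU count. A standalone sketch; kernel_stub and BlockSize
// are placeholders:
#include <hip/hip_runtime.h>
#include <algorithm>
#include <cstdio>

__global__ void kernel_stub() {}

int main()
{
    constexpr int BlockSize = 256;

    int occ_num_blocks = 0;
    hipOccupancyMaxActiveBlocksPerMultiprocessor(&occ_num_blocks, kernel_stub, BlockSize, 0);

    hipDeviceProp_t dev_prop{};
    hipGetDeviceProperties(&dev_prop, 0);

    const int cu_count    = dev_prop.multiProcessorCount;
    const int block_waves = BlockSize / dev_prop.warpSize; // 64 on gfx9, 32 on gfx11/12
    const int cu_blocks   = (2 * 4) / block_waves;         // 2 waves per SIMD, 4 SIMDs per CU

    // Persistent-kernel grid: enough blocks to fill the GPU, no more.
    const int grid_size = cu_count * std::min(occ_num_blocks, cu_blocks);
    std::printf("occupancy grid size: %d blocks\n", grid_size);
}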
#pragma once @@ -43,63 +43,70 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const B1ElementwiseOperation b1_element_op, const CElementwiseOperation c_element_op) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - const index_t block_id = get_block_1d_id(); - - const auto arg_ptr = reinterpret_cast( - cast_pointer_to_generic_address_space(group_kernel_args)); - - index_t left = 0; - index_t right = group_count; - index_t group_id = index_t((left + right) / 2); - - while( - (!(block_id >= arg_ptr[group_id].block_start_ && block_id < arg_ptr[group_id].block_end_))) +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - if(block_id < arg_ptr[group_id].block_start_) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id = get_block_1d_id(); + + const auto arg_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(group_kernel_args)); + + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + + while((!(block_id >= arg_ptr[group_id].block_start_ && + block_id < arg_ptr[group_id].block_end_))) { - right = group_id; + if(block_id < arg_ptr[group_id].block_start_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); } - else - { - left = group_id; - } - group_id = index_t((left + right) / 2); + + // per-group batch offset + const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_; + const index_t g_idx = __builtin_amdgcn_readfirstlane( + (block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch); + + const long_index_t a_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast( + arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast( + arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(g_idx))); + const long_index_t b1_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast( + arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast( + arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx))); + + GridwiseGemm::template Run( + arg_ptr[group_id].p_a_grid_ + a_batch_offset, + arg_ptr[group_id].p_b_grid_ + b_batch_offset, + arg_ptr[group_id].p_b1_grid_ + b1_batch_offset, + arg_ptr[group_id].p_c_grid_ + c_batch_offset, + p_shared, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op, + arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, + arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, + arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, + arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg_ptr[group_id].block_2_ctile_map_, + arg_ptr[group_id].c0_matrix_mask_); } - - // per-group batch offset - const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_; - const index_t g_idx = __builtin_amdgcn_readfirstlane( - (block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch); - - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(g_idx))); - const long_index_t b1_batch_offset = 
__builtin_amdgcn_readfirstlane(static_cast(
-        arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx)));
-    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
-
-    GridwiseGemm::template Run(
-        arg_ptr[group_id].p_a_grid_ + a_batch_offset,
-        arg_ptr[group_id].p_b_grid_ + b_batch_offset,
-        arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
-        arg_ptr[group_id].p_c_grid_ + c_batch_offset,
-        p_shared,
-        a_element_op,
-        b_element_op,
-        acc_element_op,
-        b1_element_op,
-        c_element_op,
-        arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
-        arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
-        arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
-        arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
-        arg_ptr[group_id].block_2_ctile_map_,
-        arg_ptr[group_id].c0_matrix_mask_);
 #else
     ignore = group_kernel_args;
     ignore = group_count;
@@ -198,6 +205,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                                               CElementwiseOperation,
                                               MaskingSpec>
 {
+    static constexpr auto MXdlPerWave64 =
+        GetNXdlPerWave2();
+    static constexpr auto MXdlPerWave32 =
+        GetNXdlPerWave2();
+
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
                   "Number of dimension must be greater than 0");
@@ -338,7 +350,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
     };

     // GridwiseGemm
-    using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
+    template
+    using GridwiseGemmBase = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
         ADataType, // TODO: distinguish A/B datatype
         GemmAccDataType,
         CShuffleDataType,
@@ -365,7 +378,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
         B1K1,
         MPerXDL,
         NPerXDL,
-        MXdlPerWave,
+        MXdlPerWave_,
         NXdlPerWave,
         Gemm1NXdlPerWave,
         ABlockTransferThreadClusterLengths_AK0_M_AK1,
@@ -399,8 +412,10 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
         LoopSched,
         Transform::matrix_padder.PadN,
         MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
+    using GridwiseGemm64 = GridwiseGemmBase;
+    using GridwiseGemm32 = GridwiseGemmBase;

-    using Block2CTileMap = OffsettedBlockToCTileMap;
+    using Block2CTileMap = OffsettedBlockToCTileMap;

     struct GroupKernelArg
     {
@@ -414,7 +429,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
         BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
         B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
-        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+        typename GridwiseGemm64::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_desc_mblock_mperblock_nblock_nperblock_;

         // batch & stride
@@ -511,7 +526,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                     problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
                 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
-                    GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                    GridwiseGemm64::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                         c_grid_desc_m_n);

                 const index_t BlockStart = grid_size_;
@@ -592,7 +607,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
     {
         using Argument = DeviceOp::Argument;

-        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        template
+        float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
             if(!DeviceOp::IsSupportedArgument(arg))
             {
@@ -664,6 +680,25 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
             return ave_time;
         }

+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        {
+            if(get_warp_size() == 64)
+            {
+                if constexpr(MXdlPerWave64 > 0)
+                {
+                    return RunImp(arg, stream_config);
+                }
+            }
+            else
+            {
+                if constexpr(MXdlPerWave32 > 0)
+                {
+                    return RunImp(arg, stream_config);
+                }
+            }
+            return 0;
+        }
+
         // polymorphic
         float Run(const BaseArgument* p_arg,
                   const StreamConfig& stream_config = StreamConfig{}) override
@@ -680,11 +715,10 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(!ck::is_xdl_supported())
+        if(!ck::is_xdl_wmma_supported())
         {
             return false;
         }
-
         // TODO ANT: Check if tensor specialization & strides mismatch
         bool all_has_main_k_block_loop = true;
@@ -708,7 +742,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
             // Check if having main loop
             const auto K = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) *
                            kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
-            const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K);
+            const bool y = GridwiseGemm64::CalculateHasMainKBlockLoop(K);
             all_has_main_k_block_loop &= y;
             some_has_main_k_block_loop |= y;
@@ -753,14 +787,31 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 return false;
             }

-            if(!GridwiseGemm::CheckValidity(kernel_arg.a_grid_desc_ak0_m_ak1_,
-                                            kernel_arg.b_grid_desc_bk0_n_bk1_,
-                                            kernel_arg.b1_grid_desc_bk0_n_bk1_,
-                                            device_arg.c_grid_desc_m_n_,
-                                            kernel_arg.block_2_ctile_map_))
+            bool valid = false;
+            if(get_warp_size() == 64)
             {
-                return false;
+                if constexpr(MXdlPerWave64 > 0)
+                {
+                    valid = GridwiseGemm64::CheckValidity(kernel_arg.a_grid_desc_ak0_m_ak1_,
+                                                          kernel_arg.b_grid_desc_bk0_n_bk1_,
+                                                          kernel_arg.b1_grid_desc_bk0_n_bk1_,
+                                                          device_arg.c_grid_desc_m_n_,
+                                                          kernel_arg.block_2_ctile_map_);
+                }
             }
+            else
+            {
+                if constexpr(MXdlPerWave32 > 0)
+                {
+                    valid = GridwiseGemm32::CheckValidity(kernel_arg.a_grid_desc_ak0_m_ak1_,
+                                                          kernel_arg.b_grid_desc_bk0_n_bk1_,
+                                                          kernel_arg.b1_grid_desc_bk0_n_bk1_,
+                                                          device_arg.c_grid_desc_m_n_,
+                                                          kernel_arg.block_2_ctile_map_);
+                }
+            }
+            if(!valid)
+                return false;
         }

         // all gemm problems have to simultaneously meet has_main_k_block_loop or
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
index 2c5d1dd134..7a1944cc68 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
@@ -39,46 +39,49 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
                 const BElementwiseOperation b_element_op,
                 const CDEElementwiseOperation c_element_op)
 {
-#if defined(__gfx9__)
-    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
-
-    const index_t block_id = get_block_1d_id();
-
-    const auto gemm_desc_ptr =
-        reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const));
-
-    index_t left     = 0;
-    index_t right    = group_count;
-    index_t group_id = index_t((left + right) / 2);
-    while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ &&
-             block_id < gemm_desc_ptr[group_id].BlockEnd_)) &&
-          left <= right)
+#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
+    if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
     {
-        if(block_id < gemm_desc_ptr[group_id].BlockStart_)
-        {
-            right = group_id;
-        }
-        else
-        {
-            left = group_id;
-        }
-        group_id = index_t((left + right) / 2);
-    }
+        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

-    GridwiseGemm::template Run(
-
gemm_desc_ptr[group_id].a_ptr_, - gemm_desc_ptr[group_id].b_ptr_, - gemm_desc_ptr[group_id].ds_ptr_, - gemm_desc_ptr[group_id].e_ptr_, - p_shared, - a_element_op, - b_element_op, - c_element_op, - gemm_desc_ptr[group_id].a_grid_desc_ak0_m_ak1_, - gemm_desc_ptr[group_id].b_grid_desc_bk0_n_bk1_, - gemm_desc_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, - gemm_desc_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, - gemm_desc_ptr[group_id].block_2_etile_map_); + const index_t block_id = get_block_1d_id(); + + const auto gemm_desc_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(gemm_descs_const)); + + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ && + block_id < gemm_desc_ptr[group_id].BlockEnd_)) && + left <= right) + { + if(block_id < gemm_desc_ptr[group_id].BlockStart_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + GridwiseGemm::template Run( + gemm_desc_ptr[group_id].a_ptr_, + gemm_desc_ptr[group_id].b_ptr_, + gemm_desc_ptr[group_id].ds_ptr_, + gemm_desc_ptr[group_id].e_ptr_, + p_shared, + a_element_op, + b_element_op, + c_element_op, + gemm_desc_ptr[group_id].a_grid_desc_ak0_m_ak1_, + gemm_desc_ptr[group_id].b_grid_desc_bk0_n_bk1_, + gemm_desc_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_desc_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_desc_ptr[group_id].block_2_etile_map_); + } #else ignore = gemm_descs_const; ignore = group_count; @@ -145,7 +148,9 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm { using DeviceOp = DeviceGroupedGemm_Xdl; - + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; @@ -231,7 +236,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm + using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype BDataType, ComputeDataType, @@ -252,7 +258,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using AGridDesc_AK0_M_AK1 = - remove_cvref_t; using BGridDesc_BK0_N_BK1 = - remove_cvref_t; using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))>; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; struct GroupedGemmBlock2ETileMap { using Block2ETileMap = - remove_cvref_t; + remove_cvref_t; GroupedGemmBlock2ETileMap() { - block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}); + block_2_etile_map_ = GridwiseGemm64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}); BlockStart_ = -1; } GroupedGemmBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, ck::index_t BlockStart) { - block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); + block_2_etile_map_ = GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); BlockStart_ = BlockStart; } @@ -334,7 +342,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm + void init_gridwise_gemm_desc(const ADataType* a_ptr, + const BDataType* b_ptr, + DsPointer ds_ptr, + 
EDataType* e_ptr, + const AGridDesc_M_K& a_grid_desc_m_k, + const BGridDesc_N_K& b_grid_desc_n_k, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2ETileMap& block_2_etile_map, + index_t BlockStart, + index_t BlockEnd) + { + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_ak0_m_ak1 = + GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); + + const auto b_grid_desc_bk0_n_bk1 = + GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k); + + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map)) + { + // tensor descriptors for block/thread-wise copy + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n); + + gemm_desc_kernel_arg_.push_back( + GemmBiasTransKernelArg{a_ptr, + b_ptr, + ds_ptr, + e_ptr, + a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map, + BlockStart, + BlockEnd}); + } + }; Argument(std::vector& p_As, std::vector& p_Bs, std::vector>& p_Ds, @@ -403,7 +469,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm{}([&](auto j) { using DDataType = remove_cvref_t>; @@ -427,13 +493,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm(M, N, StrideC); - // tensor descriptors for block/thread-wise copy - const auto a_grid_desc_ak0_m_ak1 = - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); - - const auto b_grid_desc_bk0_n_bk1 = - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k); - const index_t grid_size_grp = GroupedGemmBlock2ETileMap(e_grid_desc_m_n, 0) .block_2_etile_map_.CalculateGridSize(e_grid_desc_m_n); @@ -447,42 +506,41 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm{}([&](auto j) { - ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n[j]); - }); - - const auto e_grid_desc_mblock_mperblock_nblock_nperblock = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n); - - gemm_desc_kernel_arg_.push_back( - GemmBiasTransKernelArg{static_cast(p_As[i]), - static_cast(p_Bs[i]), - p_ds_grid, - static_cast(p_Es[i]), - a_grid_desc_m_k, - b_grid_desc_n_k, - ds_grid_desc_m_n, - e_grid_desc_m_n, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map, - BlockStart, - BlockEnd}); + if constexpr(NXdlPerWave64 > 0) + { + init_gridwise_gemm_desc( + static_cast(p_As[i]), + static_cast(p_Bs[i]), + p_ds_grid, + static_cast(p_Es[i]), + a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map, + BlockStart, + BlockEnd); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + init_gridwise_gemm_desc( + static_cast(p_As[i]), + static_cast(p_Bs[i]), + p_ds_grid, + static_cast(p_Es[i]), + a_grid_desc_m_k, + 
b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map, + BlockStart, + BlockEnd); + } } } } @@ -508,10 +566,11 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm + float RunImp(const Argument& arg, + const StreamConfig& stream_config = StreamConfig{}, + hipStream_t cpy_stream = nullptr, + hipEvent_t cpy_event = nullptr) { bool has_main_k_block_loop = true; @@ -626,6 +685,28 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm 0) + { + return RunImp(arg, stream_config, cpy_stream, cpy_event); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return RunImp(arg, stream_config, cpy_stream, cpy_event); + } + } + return 0; + } + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -636,11 +717,10 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm()) { return false; } - if((ck::type_convert(arg.gemm_desc_kernel_arg_.size()) + arg.skipped_group_count_) != arg.group_count_) { @@ -649,12 +729,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm(cast_pointer_to_generic_address_space(gemm_descs_const)); - - const index_t group_id = block_id / grid_size_grp; - - if(group_id >= group_count) - return; - - const index_t M = gemm_desc_ptr[group_id].M; - const index_t N = gemm_desc_ptr[group_id].N; - const index_t K = gemm_desc_ptr[group_id].K; - - if(M == 0 || N == 0 || K == 0) - return; - - const auto StrideA = gemm_desc_ptr[group_id].StrideA; - const auto StrideB = gemm_desc_ptr[group_id].StrideB; - const auto StrideDs = gemm_desc_ptr[group_id].StrideDs; - const auto StrideE = gemm_desc_ptr[group_id].StrideE; - - const auto e_grid_desc_m_n = - GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); - - const index_t BlockStart = group_id * grid_size_grp; - - const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch}; - - const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n); - - constexpr auto NumDTensor = DsDataType::Size(); - - using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); - - DsGridPointer p_ds_grid_; - - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t>; - // D pointer - p_ds_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); - }); - - index_t id_off = 0; - index_t id_local = get_block_1d_id() - BlockStart; - - const index_t mn_blocks = local_grid_size / KBatch; - - while(id_local < local_grid_size) +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - const auto block_2_etile_map = - GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off); + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - if constexpr(Zeroing) - { - auto barrier_count_finished = - barrier_count + group_id * barrier_size_grp + id_local % mn_blocks; - GridwiseGemm::template RunWithZeroing(gemm_desc_ptr[group_id].p_a_grid, - gemm_desc_ptr[group_id].p_b_grid, - p_ds_grid_, - gemm_desc_ptr[group_id].p_e_grid, - p_shared, - barrier_count_finished, - a_element_op, - b_element_op, - c_element_op, - M, - N, - K, - StrideA, - StrideB, - StrideDs, - StrideE, - KBatch, - block_2_etile_map); - } - else + const index_t block_id = get_block_1d_id(); + + const auto gemm_desc_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(gemm_descs_const)); + + const index_t group_id = block_id / grid_size_grp; + + if(group_id >= group_count) + return; + + const index_t M = 
gemm_desc_ptr[group_id].M; + const index_t N = gemm_desc_ptr[group_id].N; + const index_t K = gemm_desc_ptr[group_id].K; + + if(M == 0 || N == 0 || K == 0) + return; + + const auto StrideA = gemm_desc_ptr[group_id].StrideA; + const auto StrideB = gemm_desc_ptr[group_id].StrideB; + const auto StrideDs = gemm_desc_ptr[group_id].StrideDs; + const auto StrideE = gemm_desc_ptr[group_id].StrideE; + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); + + const index_t BlockStart = group_id * grid_size_grp; + + const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch}; + + const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n); + + constexpr auto NumDTensor = DsDataType::Size(); + + using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); + + DsGridPointer p_ds_grid_; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + // D pointer + p_ds_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); + }); + + index_t id_off = 0; + index_t id_local = get_block_1d_id() - BlockStart; + + const index_t mn_blocks = local_grid_size / KBatch; + + while(id_local < local_grid_size) { + const auto block_2_etile_map = + GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off); - GridwiseGemm::template Run(gemm_desc_ptr[group_id].p_a_grid, - gemm_desc_ptr[group_id].p_b_grid, - p_ds_grid_, - gemm_desc_ptr[group_id].p_e_grid, - p_shared, - nullptr, - a_element_op, - b_element_op, - c_element_op, - M, - N, - K, - StrideA, - StrideB, - StrideDs, - StrideE, - KBatch, - block_2_etile_map); - } + if constexpr(Zeroing) + { + auto barrier_count_finished = + barrier_count + group_id * barrier_size_grp + id_local % mn_blocks; + GridwiseGemm::template RunWithZeroing(gemm_desc_ptr[group_id].p_a_grid, + gemm_desc_ptr[group_id].p_b_grid, + p_ds_grid_, + gemm_desc_ptr[group_id].p_e_grid, + p_shared, + barrier_count_finished, + a_element_op, + b_element_op, + c_element_op, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + block_2_etile_map); + } + else + { - id_off += grid_size_grp; - id_local += grid_size_grp; + GridwiseGemm::template Run(gemm_desc_ptr[group_id].p_a_grid, + gemm_desc_ptr[group_id].p_b_grid, + p_ds_grid_, + gemm_desc_ptr[group_id].p_e_grid, + p_shared, + nullptr, + a_element_op, + b_element_op, + c_element_op, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + block_2_etile_map); + } + + id_off += grid_size_grp; + id_local += grid_size_grp; + } } #else ignore = gemm_descs_const; @@ -241,6 +244,9 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK { using DeviceOp = DeviceGroupedGemm_Xdl_Fixed_NK; + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); @@ -252,7 +258,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK + using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_splitk_cshuffle< ADataType, // TODO: distinguish A/B datatype BDataType, AComputeType, @@ -274,7 +281,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; template struct OffsettedBlockToCTileMapMLoops @@ -479,7 +488,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK( + GridwiseGemm64::template MakeEGridDescriptor_M_N( AverM, N, StrideE); 
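// --- Editor's sketch (illustrative only, not part of the patch) -----------
// The grouped kernels in this change map a flat block id onto a GEMM group
// in two ways: the generic path binary-searches per-group block ranges
// [BlockStart_, BlockEnd_), while the fixed-NK path gives every group an
// equal grid slice so a plain division suffices. The standalone host-side
// sketch below models both lookups; GroupRange and the function names are
// hypothetical, only the loop logic mirrors the kernels above.
#include <cstdint>
#include <vector>

struct GroupRange
{
    std::int32_t block_start_; // first block owned by the group
    std::int32_t block_end_;   // one past the last owned block
};

// Mirrors the binary search in the grouped-GEMM kernel entry points; it
// assumes the ranges are sorted and cover every launched block.
inline std::int32_t find_group(const std::vector<GroupRange>& groups, std::int32_t block_id)
{
    std::int32_t left     = 0;
    std::int32_t right    = static_cast<std::int32_t>(groups.size());
    std::int32_t group_id = (left + right) / 2;
    while(!(block_id >= groups[group_id].block_start_ &&
            block_id < groups[group_id].block_end_) &&
          left <= right)
    {
        if(block_id < groups[group_id].block_start_)
            right = group_id;
        else
            left = group_id;
        group_id = (left + right) / 2;
    }
    return group_id;
}

// Mirrors the fixed-NK path: uniform per-group grid slices reduce the
// lookup to one integer division.
inline std::int32_t find_group_fixed(std::int32_t block_id, std::int32_t grid_size_grp)
{
    return block_id / grid_size_grp;
}
// ---------------------------------------------------------------------------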
const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; @@ -550,7 +559,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK( + GridwiseGemm64::template MakeEGridDescriptor_M_N( AverM, N, StrideE); // block-to-e-tile map @@ -566,17 +575,55 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK( - AverM, N, K, StrideA, StrideB, StrideDs, StrideE, 1)) + if(get_warp_size() == 64) { - throw std::runtime_error( - "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + if constexpr(NXdlPerWave64 > 0) + { + if(!GridwiseGemm64::template CheckValidity( + AverM, N, K, StrideA, StrideB, StrideDs, StrideE, 1)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid " + "setting"); + } + } + else + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + if(!GridwiseGemm32::template CheckValidity( + AverM, N, K, StrideA, StrideB, StrideDs, StrideE, 1)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid " + "setting"); + } + } + else + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } } gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{ @@ -597,7 +644,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK( + GridwiseGemm64::template MakeEGridDescriptor_M_N( sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_); const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1}; @@ -631,7 +678,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { bool has_main_k_block_loop = true; @@ -783,6 +831,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK(cast_pointer_to_generic_address_space(gemm_descs_const)); - - index_t left = 0; - index_t right = group_count; - index_t group_id = index_t((left + right) / 2); - while((!(block_id >= gemm_desc_ptr[group_id].block_start_ && - block_id < gemm_desc_ptr[group_id].block_end_)) && - left <= right) +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - if(block_id < gemm_desc_ptr[group_id].block_start_) - { - right = group_id; - } - else - { - left = group_id; - } - group_id = index_t((left + right) / 2); - } + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + __shared__ uint8_t p_shared[shared_size]; - GridwiseGemm::template Run( - gemm_desc_ptr[group_id].karg_, - static_cast(p_shared), - gemm_desc_ptr[group_id].block_2_ctile_map_, - a_element_op, - b_element_op, - c_element_op); + const index_t block_id = get_block_1d_id(); + const auto gemm_desc_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(gemm_descs_const)); + + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_id >= gemm_desc_ptr[group_id].block_start_ && + block_id < gemm_desc_ptr[group_id].block_end_)) && + left <= right) + { + if(block_id < gemm_desc_ptr[group_id].block_start_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + GridwiseGemm::template Run( + gemm_desc_ptr[group_id].karg_, + static_cast(p_shared), + 
gemm_desc_ptr[group_id].block_2_ctile_map_, + a_element_op, + b_element_op, + c_element_op); + } #else ignore = gemm_descs_const; ignore = group_count; @@ -144,6 +147,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; @@ -153,7 +159,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK + using GridwiseGemmBase = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, @@ -174,7 +181,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; + using CGridDesc_M_N = typename GridwiseGemm64::CGridDesc_M_N; using Block2ETileMapKSplit = BlockToCTileMap_KSplit_M00_N0_M01Adapt; // Block2CTileMap configuration parameter. static constexpr index_t B2E_M01 = 8; using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap; - using KernelArgument = typename GridwiseGemm::Argument; + using KernelArgument = typename GridwiseGemm64::Argument; using PassThrough = ck::tensor_operation::element_wise::PassThrough; - struct GemmTransKernelArg + template + struct GemmTransKernelArgBase { - KernelArgument karg_; + KernelArgument_ karg_; GroupedGemmBlock2ETileMap block_2_ctile_map_; index_t block_start_, block_end_; - GemmTransKernelArg() = default; - GemmTransKernelArg(KernelArgument&& karg, - GroupedGemmBlock2ETileMap&& b2c_map, - index_t block_start, - index_t block_end) + GemmTransKernelArgBase() = default; + GemmTransKernelArgBase(KernelArgument_&& karg, + GroupedGemmBlock2ETileMap&& b2c_map, + index_t block_start, + index_t block_end) : karg_{karg}, block_2_ctile_map_{b2c_map}, block_start_{block_start}, @@ -224,6 +234,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK; static constexpr index_t DefaultKBatch = 1; @@ -277,12 +288,13 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK + float RunImp(const Argument& arg, + const StreamConfig& stream_config = StreamConfig{}, + hipStream_t cpy_stream = nullptr, + hipEvent_t cpy_event = nullptr) { + using GemmTransKernelArg_ = GemmTransKernelArgBase; + static_assert(sizeof(GemmTransKernelArg_) == sizeof(GemmTransKernelArg)); + static_assert(sizeof(typename GridwiseGemm::Argument) == + sizeof(typename GridwiseGemm64::Argument)); + index_t K0 = arg.gemm_kernel_args_[0].karg_.K0Padded; bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.k_batch > 1; bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) { - const auto& karg = arg.gemm_kernel_args_[i].karg_; + const auto& karg = reinterpret_cast( + arg.gemm_kernel_args_[i].karg_); if(stream_config.log_level_ > 0) { karg.Print(); @@ -439,7 +458,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK; @@ -500,7 +519,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK; @@ -513,7 +532,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK; @@ -523,7 +542,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK; @@ -534,6 +553,28 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public 
DeviceGroupedGemmSplitK 0) + { + return RunImp(arg, stream_config, cpy_stream, cpy_event); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return RunImp(arg, stream_config, cpy_stream, cpy_event); + } + } + return 0; + } + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -550,11 +591,14 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK()) + { + return false; + } + if(is_gfx11_supported() && arg.K_BATCH > 1) { return false; } - if((ck::type_convert(arg.gemm_kernel_args_.size()) + arg.skipped_group_count_) != arg.group_count_) { @@ -573,11 +617,27 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK 0) + { + group_arg_valid = GridwiseGemm64::CheckValidity(a); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + group_arg_valid = GridwiseGemm32::CheckValidity( + reinterpret_cast(a)); + } + } - bool group_arg_valid = GridwiseGemm::CheckValidity(a); if(not group_arg_valid) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp index 27d3c378ac..748ae28a50 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp @@ -87,8 +87,12 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); - using GridwiseGemm = + template + using GridwiseGemmBase = GridwiseMoeGemm; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; static constexpr index_t APackedSize = []() { if constexpr(is_same_v, pk_i4_t>) @@ -167,7 +173,9 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -195,7 +203,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle DsSize; - Argument arg_ = arg; + auto arg_ = arg; const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); @@ -214,8 +222,13 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle>; DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); }); - ck::utility::RotatingMemWrapperMultiD rotating_mem( - arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + ck::utility::RotatingMemWrapperMultiD + rotating_mem(arg_, + stream_config.rotating_count, + size_a_buffer, + size_b_buffer, + DsSize); rotating_mem.Print(); auto run_flush_cache = [&]() { @@ -376,6 +389,8 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle()) { return false; } - if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) { return false; @@ -418,8 +432,22 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -553,7 +581,7 @@ 
struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); - using GridwiseGemm = GridwiseMoeGemmBlockScale< - ALayout, - BLayout, - DsLayout, - CLayout, - ADataType, - BDataType, - GemmAccDataType, - CShuffleDataType, - DsDataType, - CDataType, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - GemmSpec, - BlockSize, - ScaleBlockM, - ScaleBlockN, - ScaleBlockK, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEShuffleBlockTransferScalarPerVectors, - BlkGemmPipeSched, - BlkGemmPipelineVer, - ActivationOP, - NSwizzle, - IsInputGemm, - MulRoutedWeight, - IndexType, - ComputeTypeA, - ComputeTypeB, - LDSTypeA, - LDSTypeB>; + template + using GridwiseGemmBase = GridwiseMoeGemmBlockScale< + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockM, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave_, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_), + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ActivationOP, + NSwizzle, + IsInputGemm, + MulRoutedWeight, + IndexType, + ComputeTypeA, + ComputeTypeB, + LDSTypeA, + LDSTypeB>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; static constexpr index_t APackedSize = []() { if constexpr(is_same_v, pk_i4_t>) @@ -180,7 +186,9 @@ struct DeviceMoeGemmBlockScale // Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -207,7 
+215,7 @@ struct DeviceMoeGemmBlockScale std::array DsSize; - Argument arg_ = arg; + auto arg_ = arg; const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); @@ -226,8 +234,13 @@ struct DeviceMoeGemmBlockScale using DDataType = remove_cvref_t>; DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); }); - ck::utility::RotatingMemWrapperMultiD rotating_mem( - arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + ck::utility::RotatingMemWrapperMultiD + rotating_mem(arg_, + stream_config.rotating_count, + size_a_buffer, + size_b_buffer, + DsSize); rotating_mem.Print(); auto run_flush_cache = [&]() { @@ -385,6 +398,8 @@ struct DeviceMoeGemmBlockScale return ave_time; } + INVOKER_RUN3_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -406,11 +421,10 @@ struct DeviceMoeGemmBlockScale { return false; } - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) { return false; @@ -428,7 +442,22 @@ struct DeviceMoeGemmBlockScale return false; } - return GridwiseGemm::CheckValidity(arg); + if(get_warp_size() == 64) + { + if constexpr(NXdlPerWave64 > 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -572,7 +601,7 @@ struct DeviceMoeGemmBlockScale << "BlkGemmPipelineVersion: " << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " << "BlkGemmPipelinePrefetchStages: " - << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages; + << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages; // clang-format on return str.str(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp index e7be94242b..9c14106033 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp @@ -90,8 +90,12 @@ struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); - using GridwiseGemm = + template + using GridwiseGemmBase = GridwiseMoeGemmMX; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; - + using Argument = typename GridwiseGemm64::Argument; static constexpr index_t APackedSize = packed_size_v; static constexpr index_t BPackedSize = packed_size_v; @@ -159,7 +164,9 @@ struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -187,7 +194,7 @@ struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle DsSize; - Argument arg_ = arg; + auto arg_ = arg; const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); @@ -206,8 +213,13 @@ struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle>; DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * 
sizeof(DDataType); }); - ck::utility::RotatingMemWrapperMultiD rotating_mem( - arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + ck::utility::RotatingMemWrapperMultiD + rotating_mem(arg_, + stream_config.rotating_count, + size_a_buffer, + size_b_buffer, + DsSize); rotating_mem.Print(); auto run_flush_cache = [&]() { @@ -333,6 +345,8 @@ struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle()) { return false; } - if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) { return false; @@ -376,7 +389,22 @@ struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -530,7 +558,7 @@ struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); - using GridwiseGemm = + template + using GridwiseGemmBase = GridwiseMoeGemmMXBNS; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; static constexpr index_t APackedSize = packed_size_v; static constexpr index_t BPackedSize = packed_size_v; @@ -159,7 +165,9 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -187,7 +195,7 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle DsSize; - Argument arg_ = arg; + auto arg_ = arg; const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); @@ -206,8 +214,13 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle>; DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); }); - ck::utility::RotatingMemWrapperMultiD rotating_mem( - arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + ck::utility::RotatingMemWrapperMultiD + rotating_mem(arg_, + stream_config.rotating_count, + size_a_buffer, + size_b_buffer, + DsSize); rotating_mem.Print(); auto run_flush_cache = [&]() { @@ -331,6 +344,8 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle()) { return false; } - if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) { return false; @@ -374,7 +388,22 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -528,7 +557,7 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle { + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); - using GridwiseGemm = GridwiseMoeGemmMX_BPreshuffle< - ALayout, - BLayout, - DsLayout, - CLayout, - ADataType, - AScaleDataType, - BDataType, - BScaleDataType, - GemmAccDataType, - CShuffleDataType, - DsDataType, - CDataType, - AElementwiseOperation, - BElementwiseOperation, - 
CElementwiseOperation, - GemmSpec, - ScaleBlockSize, - BlockSize, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEShuffleBlockTransferScalarPerVectors, - BlkGemmPipeSched, - BlkGemmPipelineVer, - ActivationOP, - NSwizzle, - IsInputGemm, - MulRoutedWeight, - IndexType, - ComputeTypeA, - ComputeTypeB>; + template + using GridwiseGemmBase = GridwiseMoeGemmMX_BPreshuffle< + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + ScaleBlockSize, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave_, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ActivationOP, + NSwizzle, + IsInputGemm, + MulRoutedWeight, + IndexType, + ComputeTypeA, + ComputeTypeB>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - using Argument = typename GridwiseGemm::Argument; + using Argument = typename GridwiseGemm64::Argument; static constexpr index_t APackedSize = packed_size_v; static constexpr index_t BPackedSize = packed_size_v; @@ -159,7 +165,9 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle + float RunImp(const typename GridwiseGemm::Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { @@ -187,7 +195,7 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle DsSize; - Argument arg_ = arg; + auto arg_ = arg; const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); @@ -206,8 +214,13 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle>; DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); }); - ck::utility::RotatingMemWrapperMultiD rotating_mem( - arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + ck::utility::RotatingMemWrapperMultiD + rotating_mem(arg_, + stream_config.rotating_count, 
+ size_a_buffer, + size_b_buffer, + DsSize); rotating_mem.Print(); auto run_flush_cache = [&]() { @@ -358,6 +371,8 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle()) { return false; } - if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) { return false; @@ -401,7 +415,22 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle 0) + { + return GridwiseGemm64::CheckValidity(arg); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + return GridwiseGemm32::CheckValidity( + reinterpret_cast(arg)); + } + } + return false; } // polymorphic @@ -555,7 +584,7 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); - const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); - FloatDsPointer p_ds_grid_grp; + FloatDsPointer p_ds_grid_grp; - static constexpr index_t NumDTensor = - DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); - static_for<0, NumDTensor, 1>{}( - [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_ds_grid_grp, - p_e_grid + e_batch_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_akb_ak0_m_ak1, - b_grid_desc_bkb_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_akb_ak0_m_ak1, + b_grid_desc_bkb_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -190,7 +195,9 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle CDEElementwiseOperation> { using DeviceOp = 
DeviceSplitKContractionMultipleD_Xdl_CShuffle; - + GET_NXDL_PER_WAVE_IMPL + static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; @@ -521,7 +528,8 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle }; // GridwiseGemm - using GridwiseGemm = GridwiseGemmSplitKMultipleD_xdl_cshuffle< + template + using GridwiseGemmBase = GridwiseGemmSplitKMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype AccDataType, CShuffleDataType, @@ -545,7 +553,7 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -567,9 +575,12 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // GridwiseGemm - using GridwiseGemmAtomicAdd = GridwiseGemmSplitKMultipleD_xdl_cshuffle< + template + using GridwiseGemmAtomicAddBase = GridwiseGemmSplitKMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype AccDataType, CShuffleDataType, @@ -593,7 +604,7 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle MPerXDL, NPerXDL, MXdlPerWave, - NXdlPerWave, + NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -615,19 +626,39 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched>; + using GridwiseGemmAtomicAdd64 = GridwiseGemmAtomicAddBase; + using GridwiseGemmAtomicAdd32 = GridwiseGemmAtomicAddBase; using AGridDesc_AKB_AK0_M_AK1 = - remove_cvref_t; using BGridDesc_BKB_BK0_N_BK1 = - remove_cvref_t; - using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; + using Block2ETileMap = typename GridwiseGemm64::DefaultBlock2ETileMap; // Argument struct Argument : public BaseArgument { + template + void init_ds_e_grid_desc() + { + if(GridwiseGemm::CheckValidity(a_grid_desc_akb_ak0_m_ak1_, + b_grid_desc_bkb_bk0_n_bk1_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + } Argument(const void* p_a_grid, const void* p_b_grid, std::array p_ds_grid, @@ -659,14 +690,14 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)}, e_grid_desc_g_m_n_{ DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)}, - a_grid_desc_akb_ak0_m_ak1_{GridwiseGemm::MakeDefaultAGridDescriptor_AKB_AK0_M_AK1( + a_grid_desc_akb_ak0_m_ak1_{GridwiseGemm64::MakeDefaultAGridDescriptor_AKB_AK0_M_AK1( a_grid_desc_m_k_, split_k)}, - b_grid_desc_bkb_bk0_n_bk1_{GridwiseGemm::MakeDefaultBGridDescriptor_BKB_BK0_N_BK1( + b_grid_desc_bkb_bk0_n_bk1_{GridwiseGemm64::MakeDefaultBGridDescriptor_BKB_BK0_N_BK1( b_grid_desc_n_k_, split_k)}, ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, 
e_grid_desc_mblock_mperblock_nblock_nperblock_{}, block_2_etile_map_{ - GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_, split_k)}, + GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_, split_k)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, cde_element_op_{cde_element_op}, @@ -697,19 +728,19 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle }); // populate desc for Ds/E - if(GridwiseGemm::CheckValidity(a_grid_desc_akb_ak0_m_ak1_, - b_grid_desc_bkb_bk0_n_bk1_, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_)) + if(get_warp_size() == 64) { - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); - - ds_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n_); + if constexpr(NXdlPerWave64 > 0) + { + init_ds_e_grid_desc(); + } + } + else + { + if constexpr(NXdlPerWave32 > 0) + { + init_ds_e_grid_desc(); + } } // for sanity check of vector memory access @@ -755,7 +786,7 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle // pointers const ADataType* p_a_grid_; const BDataType* p_b_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; + typename GridwiseGemm64::DsGridPointer p_ds_grid_; EDataType* p_e_grid_; // tensor descriptors for problem definiton @@ -770,9 +801,9 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle // tensor descriptors for block/thread-wise copy AGridDesc_AKB_AK0_M_AK1 a_grid_desc_akb_ak0_m_ak1_; BGridDesc_BKB_BK0_N_BK1 b_grid_desc_bkb_bk0_n_bk1_; - typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + typename GridwiseGemm64::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + typename GridwiseGemm64::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; // block-to-e-tile map @@ -806,7 +837,8 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_akb_ak0_m_ak1_, arg.b_grid_desc_bkb_bk0_n_bk1_, @@ -818,6 +850,11 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle "wrong! 
GridwiseGemmMultipleD_xdl_cshuffle has invalid setting"); } + using GridwiseGemmAtomicAdd = + std::conditional_t, + GridwiseGemmAtomicAdd64, + GridwiseGemmAtomicAdd32>; + const index_t G = arg.e_grid_desc_g_m_n_.GetLength(I0); const index_t grid_size = @@ -931,6 +968,8 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle } } + INVOKER_RUN_IMPL + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -941,19 +980,35 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } - - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_akb_ak0_m_ak1_, - arg.b_grid_desc_bkb_bk0_n_bk1_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_)) + bool valid = false; + if(get_warp_size() == 64) { - return false; + if constexpr(NXdlPerWave64 > 0) + { + valid = GridwiseGemm64::CheckValidity(arg.a_grid_desc_akb_ak0_m_ak1_, + arg.b_grid_desc_bkb_bk0_n_bk1_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } } + else + { + if constexpr(NXdlPerWave32 > 0) + { + valid = GridwiseGemm32::CheckValidity(arg.a_grid_desc_akb_ak0_m_ak1_, + arg.b_grid_desc_bkb_bk0_n_bk1_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + if(!valid) + return false; // check vector access static_assert((ABlockTransferSrcVectorDim == 2 || ABlockTransferSrcVectorDim == 3) && diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index 7eca68bbf8..b6bc634d74 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -169,7 +169,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt } template - __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + __host__ __device__ constexpr bool + CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } diff --git a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp index 36dc8aa6ba..7a09b84a63 100644 --- a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
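// --- Editor's sketch (illustrative only, not part of the patch) -----------
// The recurring dispatch this PR introduces: both a wave64 and a wave32
// gridwise specialization are instantiated at compile time, the runtime
// warp size selects one, and "if constexpr" prunes a specialization whose
// NXdlPerWave collapsed to zero. Every name below is a made-up stand-in
// for GET_NXDL_PER_WAVE_IMPL / GetNXdlPerWave<...>() / get_warp_size().
#include <cstdio>

template <int NXdlPerWave>
struct GridwiseImpl
{
    static bool CheckValidity() { return NXdlPerWave > 0; }
};

constexpr int NXdlPerWave64 = 2; // assumed values; the real ones are derived
constexpr int NXdlPerWave32 = 1; // from the device op's tuning parameters

int runtime_warp_size() { return 64; } // stand-in for get_warp_size()

bool is_supported()
{
    if(runtime_warp_size() == 64)
    {
        if constexpr(NXdlPerWave64 > 0)
            return GridwiseImpl<NXdlPerWave64>::CheckValidity();
    }
    else
    {
        if constexpr(NXdlPerWave32 > 0)
            return GridwiseImpl<NXdlPerWave32>::CheckValidity();
    }
    return false; // no usable instantiation for this warp size
}

int main() { std::printf("supported: %d\n", is_supported()); }
// ---------------------------------------------------------------------------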
#pragma once @@ -267,6 +267,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle e_grid_desc_m_n); } + IS_VALID_COMPILATION_PARAMETER_IMPL(CShuffleDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp index 70c641531b..a15f11a93f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp @@ -195,6 +195,28 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle return math::max(gemm0_bytes_end, gemm1_bytes_end, c_block_bytes_end); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + constexpr bool valid = ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + if constexpr(!valid) + { + return false; + } + + return true; + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp index 84d7b04495..b8f5a545aa 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -255,18 +255,63 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle return math::max(gemm0_bytes_end, gemm1_bytes_end, c1_block_bytes_end); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + Gemm0MPerBlock, + Gemm0NPerBlock, + Gemm0MPerXdl, + Gemm0NPerXdl, + Gemm0MXdlPerWave, + Gemm0NXdlPerWave, + E1DataType, + CGlobalMemoryDataOperation_>() && + ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + Gemm0MPerBlock, + Gemm1NPerBlock, + Gemm0MPerXdl, + Gemm0NPerXdl, + Gemm0MXdlPerWave, + Gemm1NXdlPerWave, + E1DataType, + CGlobalMemoryDataOperation_>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template - __host__ __device__ static constexpr bool - CheckValidity(const A0GridDesc_M_K& a0_grid_desc_m_k, - const B0GridDesc_N_K& b0_grid_desc_n_k, - const B1GridDesc_N_K& b1_grid_desc_n_k, - const E1GridDesc_M_N& e1_grid_desc_m_n, - const Block2E1TileMap& block_2_e1tile_map) + __host__ static constexpr bool CheckValidity(const A0GridDesc_M_K& a0_grid_desc_m_k, + const B0GridDesc_N_K& b0_grid_desc_n_k, + const B1GridDesc_N_K& b1_grid_desc_n_k, + const E1GridDesc_M_N& e1_grid_desc_m_n, + const Block2E1TileMap& block_2_e1tile_map) { static_assert((Gemm0MPerBlock % (Gemm0MPerXdl * Gemm0MXdlPerWave) == 0) && (Gemm0NPerBlock % (Gemm0NXdlPerWave * Gemm0NPerXdl)) == 0, "Invalid tuning param!"); 
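// --- Editor's sketch (illustrative only, not part of the patch) -----------
// What the IsValidCompilationParameter guards buy: an instantiation whose
// tile configuration is illegal for the current warp width must still
// compile, it just compiles to an empty kernel body. Because the condition
// is evaluated with "if constexpr" inside a template, the discarded branch
// (including its shared-memory declaration and the Run call) is never
// instantiated. MiniGridwise and mini_kernel are hypothetical stand-ins.
#include <hip/hip_runtime.h>

template <int MXdlPerWave, int NXdlPerWave>
struct MiniGridwise
{
    __host__ __device__ static constexpr bool IsValidCompilationParameter()
    {
        // Simplified stand-in: both per-wave tile counts must be positive.
        return MXdlPerWave > 0 && NXdlPerWave > 0;
    }
    __device__ static void Run(float* p_c) { p_c[0] = 1.0f; }
};

template <typename GridwiseGemm>
__global__ void mini_kernel(float* p_c)
{
    if constexpr(GridwiseGemm::IsValidCompilationParameter())
    {
        __shared__ float lds[64]; // only declared when the body is live
        lds[0] = 0.0f;
        p_c[0] = lds[0];
        GridwiseGemm::Run(p_c);
    }
    // Otherwise the kernel is an intentional no-op, letting wave64-only
    // tiles coexist with wave32 builds in one binary.
}
// ---------------------------------------------------------------------------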
+ if constexpr((Gemm0MPerXdl * Gemm0MXdlPerWave) == 0 || + (Gemm0NXdlPerWave * Gemm0NPerXdl) == 0) + { + return false; + } + else + { + if constexpr((Gemm0MPerBlock % (Gemm0MPerXdl * Gemm0MXdlPerWave) != 0) || + (Gemm0NPerBlock % (Gemm0NXdlPerWave * Gemm0NPerXdl) != 0)) + { + return false; + } + else + { + if(WaveSize != get_warp_size()) + { + return false; + } + } + } const auto M = a0_grid_desc_m_k.GetLength(I0); const auto N = b0_grid_desc_n_k.GetLength(I0); @@ -527,8 +572,7 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle const CDE1ElementwiseOperation& cde1_element_op, const A0GridDesc_AK0_M_AK1& a0_grid_desc_ak0_m_ak1, const B0GridDesc_BK0_N_BK1& b0_grid_desc_bk0_n_bk1, - const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5& - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, + const D0sGridDesc_M_N& d0s_griddesc_m_n, const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& d1s_grid_desc_mblock_mperblock_nblock_nperblock, @@ -536,6 +580,8 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle e1_grid_desc_mblock_mperblock_nblock_nperblock, const Block2E1TileMap& block_2_e1tile_map) { + const auto d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = + MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(d0s_griddesc_m_n); const auto a0_grid_buf = make_dynamic_buffer( p_a0_grid, a0_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b0_grid_buf = make_dynamic_buffer( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp index 222cb3894c..0e8d003071 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp @@ -210,6 +210,22 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool @@ -445,11 +461,12 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock, - const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5& - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, + const D0sGridDesc_M_N& d0s_griddesc_m_n, const Block2CTileMap& block_2_ctile_map, const C0MatrixMask& c0_matrix_mask) { + const auto d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = + MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(d0s_griddesc_m_n); const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( @@ -1066,6 +1083,7 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle // main body if constexpr(num_gemm1_k_block_inner_loop > 1) { + static_for<0, num_gemm1_k_block_inner_loop - 
1, 1>{}([&](auto i) { a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1, make_tuple(Number{}, I0, I0), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index 2d00daf7f6..e0cf12e429 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -209,6 +209,22 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp index 96b737385a..0a4691b509 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
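// --- Editor's sketch (illustrative only, not part of the patch) -----------
// The kernel entry points in this change use a two-level guard. Level 1 is
// the preprocessor (is this object being compiled for a supported gfx
// target at all, now widened to gfx11/gfx12); level 2 is "if constexpr"
// (is this particular tile configuration legal there). The #else branch
// references every parameter so the host pass does not warn; CK spells
// that "ignore = ...;". AlwaysValid and guarded_kernel are hypothetical.
#include <hip/hip_runtime.h>

struct AlwaysValid
{
    __host__ __device__ static constexpr bool IsValidCompilationParameter()
    {
        return true;
    }
};

template <typename GridwiseGemm>
__global__ void guarded_kernel(float* p_c)
{
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
    // Level 2: prune configurations that are illegal on this target.
    if constexpr(GridwiseGemm::IsValidCompilationParameter())
    {
        __shared__ char p_shared[256];
        p_shared[0] = 0;
        p_c[0]      = static_cast<float>(p_shared[0]);
    }
#else
    // Host pass / unsupported target: keep the parameter referenced.
    (void)p_c; // the CK sources write "ignore = p_c;"
#endif
}
// ---------------------------------------------------------------------------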
#pragma once @@ -66,29 +66,34 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, const Block2CTileMap block_2_ctile_map) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_bias_grid, - p_d0_grid, - p_reduces_grid, - p_shared, - a_element_op, - b_element_op, - c_element_op, - c1_element_op, - reduce_in_element_ops, - reduce_out_element_ops, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c0_grid_desc_mblock_mperblock_nblock_nperblock, - c1_grid_desc_mblock_mperblock_nblock_nperblock, - reduce_grid_desc_mblock_mperblock, - block_2_ctile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_bias_grid, + p_d0_grid, + p_reduces_grid, + p_shared, + a_element_op, + b_element_op, + c_element_op, + c1_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c1_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, + block_2_ctile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -252,6 +257,22 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 c_block_size * sizeof(FloatCShuffle)); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index c37ffb6263..c198711dbb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -298,6 +298,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle e_grid_desc_m_n); } + IS_VALID_COMPILATION_PARAMETER_IMPL(EDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template {}, Sequence<1>{})); } + IS_VALID_COMPILATION_PARAMETER_IMPL(FloatE) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index 318ff59383..59d7f357ec 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -322,6 
+322,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle return true; } + IS_VALID_COMPILATION_PARAMETER_IMPL(EDataType) + template ()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -424,6 +428,8 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad using Block2ETileMap = remove_cvref_t; + IS_VALID_COMPILATION_PARAMETER_IMPL(EDataType) + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k, const BGridDesc_N_K& b_grid_desc_n_k, const DsGridDesc_M_N& ds_grid_desc_m_n, @@ -433,6 +439,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); + static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0, "KPerBlock must be divisible by AK1Value and BK1Value!"); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp index 85b5b5faab..872e1271e1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
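// --- Editor's note (illustrative, all values assumed) ----------------------
// The new static_assert in the LdsDirectLoad gridwise op above encodes that
// the K dimension of a block tile is consumed in AK1/BK1-wide chunks, so
// KPerBlock must divide evenly into both. A concrete instantiation that
// satisfies the contract:
constexpr int KPerBlock = 32;
constexpr int AK1Value  = 8; // A is read as K0 * K1 with K0 = 32 / 8 = 4
constexpr int BK1Value  = 4; // B is read as K0 * K1 with K0 = 32 / 4 = 8
static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
              "KPerBlock must be divisible by AK1Value and BK1Value!");
// ---------------------------------------------------------------------------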
#pragma once @@ -351,6 +351,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle e_grid_desc_m_n); } + IS_VALID_COMPILATION_PARAMETER_IMPL(EDataType) + template ()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_reduces_grid, - p_shared, - a_element_op, - b_element_op, - c_element_op, - reduce_in_element_ops, - reduce_out_element_ops, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - reduce_grid_desc_mblock_mperblock, - block_2_ctile_map); + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_reduces_grid, + p_shared, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, + block_2_ctile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -225,6 +229,22 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 c_block_size * sizeof(FloatCShuffle)); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp index b4848c7077..8f7aac0171 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
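Most gridwise structs pick up the check through the one-line IS_VALID_COMPILATION_PARAMETER_IMPL(...) macro rather than a hand-written member. The macro definition is not part of this diff; judging from the hand-written variants added elsewhere in the patch, it presumably expands to roughly the following, with the argument naming the struct's output element type:

#define IS_VALID_COMPILATION_PARAMETER_IMPL(OutDataType)                       \
    template <InMemoryDataOperationEnum CGlobalMemoryDataOperation_ =          \
                  InMemoryDataOperationEnum::Set>                              \
    __device__ static bool constexpr IsValidCompilationParameter()             \
    {                                                                          \
        return ck::tensor_operation::device::IsValidGemmCompilationParameter<  \
            BlockSize,                                                         \
            MPerBlock,                                                         \
            NPerBlock,                                                         \
            MPerXdl,                                                           \
            NPerXdl,                                                           \
            MXdlPerWave,                                                       \
            NXdlPerWave,                                                       \
            OutDataType,                                                       \
            CGlobalMemoryDataOperation_>();                                    \
    }

The macro can only work where the enclosing struct already defines BlockSize and the tile and repeat counts under exactly these names, which is why later hunks add a BlockSize member to the wavelet-model struct and rename MPerXDL/NPerXDL to MPerXdl/NPerXdl in the bwd-weight and skip-b-lds gemms.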
#pragma once @@ -283,6 +283,8 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle e_grid_desc_m_n, 8, split_k); } + IS_VALID_COMPILATION_PARAMETER_IMPL(EDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template ()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg, karg.p_workspace_); + GridwiseGemm::template Run( + karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg, karg.p_workspace_); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -58,20 +61,23 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run_2Lds( - karg.p_a_grid, - karg.p_b_grid, - karg.p_c_grid, - p_shared_0, - p_shared_1, - karg, - karg.p_workspace_); + GridwiseGemm::template Run_2Lds( + karg.p_a_grid, + karg.p_b_grid, + karg.p_c_grid, + p_shared_0, + p_shared_1, + karg, + karg.p_workspace_); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -1012,6 +1018,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 c_block_size * sizeof(CShuffleDataType)); } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index 906bfe0912..4cd1a587e9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -24,11 +24,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg); + GridwiseGemm::template Run( + karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg); + } #else ignore = karg; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) @@ -48,10 +52,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) FloatC* __restrict__ p_c_grid, typename GridwiseGemm::Problem problem) { -#if(defined(__gfx908__) || 
defined(__gfx90a__) || defined(__gfx94__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, p_b_grid, p_c_grid, p_shared, problem); + GridwiseGemm::template Run( + p_a_grid, p_b_grid, p_c_grid, p_shared, problem); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -542,6 +551,22 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 c_block_size * sizeof(FloatCShuffle)); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Problem& problem) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp index 57624b218c..ccba4d4a94 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
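Further down, gridwise_gemm_xdl_cshuffle_v3.hpp deletes its open-coded target checks in favor of the shared ck::tensor_operation::device::IsValidGemmCompilationParameter helper. The helper's definition is outside this diff; reconstructed from exactly those deleted checks, it plausibly has this shape:

template <index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t MPerXdl,
          index_t NPerXdl,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename CDataType,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__device__ constexpr bool IsValidGemmCompilationParameter()
{
#if defined(__gfx11__) || defined(__gfx12__)
    if constexpr(MPerXdl != 16 || NPerXdl != 16) // only 16x16 tiles exist here
    {
        return false;
    }
#endif
#if defined(__gfx11__)
    // gfx11 supports plain stores only.
    constexpr bool SupportMemOp = CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set;
#else
    // Elsewhere, sub-16-bit outputs cannot be written atomically.
    constexpr bool SupportMemOp = sizeof(CDataType) >= 2 ||
                                  CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set;
#endif
    if constexpr(!SupportMemOp)
    {
        return false;
    }
    // The wave layout implied by the tile parameters must match the hardware
    // warp size; nesting keeps the divisions out of discarded instantiations.
    if constexpr(MXdlPerWave > 0 && NXdlPerWave > 0)
    {
        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
        if constexpr(MWaves > 0 && NWaves > 0)
        {
            return BlockSize / (MWaves * NWaves) == get_warp_size();
        }
        else
        {
            return false;
        }
    }
    else
    {
        return false;
    }
}

For illustration, take BlockSize = 256, MPerBlock = NPerBlock = 128, MPerXdl = NPerXdl = 16, MXdlPerWave = 2 and NXdlPerWave = 8: MWaves = 128/32 = 4 and NWaves = 128/128 = 1, so the implied wave size is 256/4 = 64. That passes on wave64 gfx9 targets; on a wave32 target the comparison fails and the guarded kernel body compiles to nothing.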
#pragma once @@ -25,14 +25,18 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v2(typename GridwiseGemm::Argument karg) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg); + GridwiseGemm::template Run( + karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg); + } #else ignore = karg; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) @@ -52,12 +56,16 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) FloatC* p_c_grid, typename GridwiseGemm::Problem problem) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - p_a_grid, p_b_grid, p_c_grid, p_shared_0, p_shared_1, problem); + GridwiseGemm::template Run( + p_a_grid, p_b_grid, p_c_grid, p_shared_0, p_shared_1, problem); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -581,6 +589,22 @@ struct GridwiseGemm_xdl_cshuffle_v2 c_block_size * sizeof(FloatCShuffle)); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Problem& problem) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index 5545192e3c..a6e4870ac7 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -1187,74 +1187,23 @@ struct GridwiseGemm_xdl_cshuffle_v3 return false; } - // Check tile size -#if defined(__gfx11__) || defined(__gfx12__) - if constexpr(MPerXdl != 16 || NPerXdl != 16) - { - return false; - } -#endif - // Check atomic caps -#if defined(__gfx11__) - constexpr bool 
SupportMemOp = CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set; -#else - constexpr bool SupportMemOp = sizeof(CDataType) >= 2 || (CGlobalMemoryDataOperation == - InMemoryDataOperationEnum::Set); -#endif - if constexpr(SupportMemOp == false) - { - return false; - } - - // Check tile size - if constexpr(MXdlPerWave > 0 && NXdlPerWave > 0) - { - constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); - if constexpr(MWaves > 0 && NWaves > 0) - { - constexpr index_t WaveSize = BlockSize / (MWaves * NWaves); - if constexpr(WaveSize == get_warp_size()) - { - return true; - } - else - { - return false; - } - } - else - { - return false; - } - } - else - { - return false; - } + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + CDataType, + CGlobalMemoryDataOperation>(); } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { - if constexpr((MPerXdl * MXdlPerWave) == 0 || (NXdlPerWave * NPerXdl) == 0) - { - return false; - } - else - { - if constexpr((MPerBlock % (MPerXdl * MXdlPerWave) != 0) || - (NPerBlock % (NXdlPerWave * NPerXdl) != 0)) - { - return false; - } - else - { - if(BlockwiseGemmPipe::WaveSize != get_warp_size()) - { - return false; - } - } - } + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index b99113ef16..78546c4f99 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -35,17 +35,20 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_b_preshuffle(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); - GridwiseGemm::template Run( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_c_grid + splitk_batch_offset.c_reduce_offset, - p_shared, - karg); + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared, + karg); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -63,21 +66,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - // Pass two lds 
pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); - GridwiseGemm::template Run_2Lds( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_c_grid + splitk_batch_offset.c_reduce_offset, - p_shared_0, - p_shared_1, - karg); + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared_0, + p_shared_1, + karg); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -885,6 +891,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle c_block_size * sizeof(CShuffleDataType)); } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp index a9c7556130..36141bc96f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp @@ -35,19 +35,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - GridwiseGemm::template Run( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_c_grid + splitk_batch_offset.c_reduce_offset, - karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, - p_shared, - karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, + p_shared, + karg); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -65,23 +67,25 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) 
kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - GridwiseGemm::template Run_2Lds( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_c_grid + splitk_batch_offset.c_reduce_offset, - karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, - p_shared_0, - p_shared_1, - karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, + p_shared_0, + p_shared_1, + karg); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -1049,6 +1053,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 c_block_size * sizeof(CShuffleDataType)); } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp index 8ca8b7a2b9..35c8c6c3b4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp @@ -34,19 +34,22 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - karg.p_as_grid, - karg.p_bs_grid, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run( + karg.p_as_grid, + karg.p_bs_grid, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -64,23 +67,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - // Pass two lds 
pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run_2Lds( - karg.p_as_grid, - karg.p_bs_grid, - karg.p_ds_grid, - karg.p_c_grid, - p_shared_0, - p_shared_1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run_2Lds( + karg.p_as_grid, + karg.p_bs_grid, + karg.p_ds_grid, + karg.p_c_grid, + p_shared_0, + p_shared_1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -1112,6 +1118,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 c_block_size * sizeof(CShuffleDataType)); } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index 676da3e925..5b19ff8542 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -38,21 +38,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -70,25 +73,28 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d_2lds(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char 
p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run_2Lds( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared_0, - p_shared_1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared_0, + p_shared_1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -1036,6 +1042,46 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 c_block_size * sizeof(CShuffleDataType)); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + constexpr bool valid = ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + CDataType, + CGlobalMemoryDataOperation_>(); + if constexpr(!valid) + { + return false; + } + + using MfmaInst = MfmaSelector; + + constexpr index_t KPerThread = + KPerBlock / (MfmaInst::GetKPerXdlops() / MfmaInst::GetK1PerXdlops()); + if constexpr(KPerThread % KPack != 0) + { + return false; + } + + if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0) + { + return false; + } + return true; + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { @@ -1043,6 +1090,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); + if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0) + { + return false; + } + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp index be3c6ebb35..8119cace3b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp @@ -38,21 +38,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char
p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - karg.p_a_grid, - karg.p_b_grid, - karg.p_ds_grid, - karg.p_c_grid, - karg.p_a_scale_grid, - karg.p_b_scale_grid, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run( + karg.p_a_grid, + karg.p_b_grid, + karg.p_ds_grid, + karg.p_c_grid, + karg.p_a_scale_grid, + karg.p_b_scale_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) @@ -443,7 +446,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 __host__ __device__ static constexpr auto MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) { - constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t NWaves = + NXdlPerWave * NPerXdl == 0 ? 1 : NPerBlock / (NXdlPerWave * NPerXdl); return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); } @@ -984,6 +988,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 c_block_size * sizeof(CShuffleDataType)); } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index dfcc20b3c2..2e95ec0d52 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -38,21 +38,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -70,23 +73,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char 
p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run_2Lds( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - p_shared1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -895,6 +901,46 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle c_block_size * sizeof(CShuffleDataType)); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + constexpr bool valid = ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + CDataType, + CGlobalMemoryDataOperation_>(); + if constexpr(!valid) + { + return false; + } + + using MfmaInst = MfmaSelector; + + constexpr index_t KPerThread = + KPerBlock / (MfmaInst::GetKPerXdlops() / MfmaInst::GetK1PerXdlops()); + if constexpr(KPerThread % KPack != 0) + { + return false; + } + + if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0) + { + return false; + } + return true; + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { @@ -902,6 +948,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); + if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0) + { + return false; + } + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp index d832bef2da..bf7ae1c6e8 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp @@ -39,21 +39,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle( typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if 
constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - karg.p_a_grid, - karg.p_b_grid, - karg.p_ds_grid, - karg.p_c_grid, - karg.p_a_scale_grid, - karg.p_b_scale_grid, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run( + karg.p_a_grid, + karg.p_b_grid, + karg.p_ds_grid, + karg.p_c_grid, + karg.p_a_scale_grid, + karg.p_b_scale_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) @@ -72,23 +75,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle_2lds( typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run_2Lds( - karg.p_a_grid, - karg.p_b_grid, - karg.p_ds_grid, - karg.p_c_grid, - karg.p_a_scale_grid, - karg.p_b_scale_grid, - p_shared, - p_shared1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run_2Lds( + karg.p_a_grid, + karg.p_b_grid, + karg.p_ds_grid, + karg.p_c_grid, + karg.p_a_scale_grid, + karg.p_b_scale_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) @@ -860,6 +866,8 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle c_block_size * sizeof(CShuffleDataType)); } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp index cb9c354701..1a356c372d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp @@ -39,20 +39,22 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) { -#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx950__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - GridwiseGemm::template Run( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, - karg.p_c_grid + splitk_batch_offset.c_reduce_offset, - p_shared, - karg); + auto 
splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared, + karg); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -71,24 +73,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) { -#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx950__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - GridwiseGemm::template Run_2Lds( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, - karg.p_c_grid + splitk_batch_offset.c_reduce_offset, - p_shared_0, - p_shared_1, - karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared_0, + p_shared_1, + karg); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -1113,6 +1117,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 c_block_size * sizeof(CShuffleDataType)); } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp index 3ac9845b66..3d2ef9b6c4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp @@ -39,20 +39,22 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) { -#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx950__) + if constexpr(GridwiseGemm::template 
IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - GridwiseGemm::template Run( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, - karg.p_c_grid + splitk_batch_offset.c_reduce_offset, - p_shared, - karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared, + karg); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -71,24 +73,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) { -#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx950__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - GridwiseGemm::template Run_2Lds( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, - karg.p_c_grid + splitk_batch_offset.c_reduce_offset, - p_shared_0, - p_shared_1, - karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared_0, + p_shared_1, + karg); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -1027,6 +1031,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle c_block_size * sizeof(CShuffleDataType)); } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp index e90239b70a..e4152e0427 100644 --- 
a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -57,28 +57,31 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock, const Block2CTileMap block_2_ctile_map) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - // TODO ANT: separate into MMA + Epilogue - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_c0_bias_grid, - p_c0_add_grid, - p_c0_gamma_grid, - p_c0_beta_grid, - p_shared, - a_element_op, - b_element_op, - acc_element_op, - c_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c0_grid_desc_nblock_nperblock, - block_2_ctile_map); +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + // TODO ANT: separate into MMA + Epilogue + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_c0_bias_grid, + p_c0_add_grid, + p_c0_gamma_grid, + p_c0_beta_grid, + p_shared, + a_element_op, + b_element_op, + acc_element_op, + c_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_desc_nblock_nperblock, + block_2_ctile_map); + } // TODO ANT: Run layernorm epilogue here #else ignore = p_a_grid; @@ -243,6 +246,22 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 c_lds_workspace_size * sizeof(FloatReduceAcc)); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp index 50363d832e..67f18de12f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
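The 2-LDS kernels above all carry the comment that passing two LDS pointers tells the compiler the ds_read/ds_write streams touch different chunks and need no ordering dependency. That is ping-pong double buffering: with two distinct __shared__ arrays, the prefetch of the next tile can be scheduled past the consumption of the current one. A stripped-down illustration, hypothetical and far simpler than the library's pipelines (assumes a 256-thread block and 256-float tiles):

__global__ void pingpong_sketch(const float* a, float* out, int num_tiles)
{
    __shared__ float lds0[256];
    __shared__ float lds1[256];

    lds0[threadIdx.x] = a[threadIdx.x]; // prefetch tile 0 into buffer 0
    __syncthreads();

    for(int t = 0; t < num_tiles; ++t)
    {
        float* cur  = (t % 2 == 0) ? lds0 : lds1;
        float* next = (t % 2 == 0) ? lds1 : lds0;

        // The write to `next` may overlap reads of `cur` because the two
        // LDS chunks provably do not alias.
        if(t + 1 < num_tiles)
        {
            next[threadIdx.x] = a[(t + 1) * 256 + threadIdx.x];
        }

        out[t * 256 + threadIdx.x] = 2.0f * cur[threadIdx.x];

        __syncthreads(); // `cur` is overwritten in the next iteration
    }
}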
#pragma once @@ -77,6 +77,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle static constexpr auto BK1 = Number{}; static constexpr auto AK0PerBlock = Number{}; static constexpr auto BK0PerBlock = Number{}; + static constexpr auto BlockSize = math::max(TileLoadThreadGroupSize, TileMathThreadGroupSize); struct TileLoadThreadGroup { @@ -171,6 +172,22 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle c_block_size * sizeof(EDataTypeShuffle)); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + EDataType, + CGlobalMemoryDataOperation>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 344c7d6528..abb8c52e0f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -166,20 +166,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CElementwiseOperation c_element_op, const CBlockClusterAdaptor c_block_cluster_adaptor) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_grid_desc_mblock_mperblock_nblock_nperblock, - a_element_op, - b_element_op, - c_element_op, - c_block_cluster_adaptor); + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + c_block_cluster_adaptor); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -209,8 +213,8 @@ template + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MRepeat, + NRepeat, + FloatC, + CGlobalMemoryDataOperation>(); + } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool @@ -530,8 +549,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight static_assert(is_known_at_compile_time>::value, "wrong! 
K1 need to be known at compile-time"); - static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && - (NPerBlock % (NRepeat * NPerXDL)) == 0, + static_assert((MPerBlock % (MPerXdl * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerXdl)) == 0, "Invalid tuning param!"); const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2); @@ -604,14 +623,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight __host__ __device__ static constexpr auto GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl); return make_naive_tensor_descriptor_packed( make_tuple(I1, - Number{}, + Number{}, I1, - Number{})); + Number{})); } using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = @@ -751,8 +770,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(K1, MfmaSelector::selected_mfma.k_per_blk); @@ -764,8 +783,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight FloatAcc, decltype(a_k0_m_k1_block_desc), decltype(b_k0_n_k1_block_desc), - MPerXDL, - NPerXDL, + MPerXdl, + NPerXdl, MRepeat, NRepeat, KPack>{}; @@ -807,8 +826,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight // output: register to global memory { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl); constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); @@ -834,8 +853,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight static_assert(M1 == MWave, ""); static_assert(N1 == NWave, ""); - static_assert(M2 * M3 * M4 == MPerXDL, ""); - static_assert(N2 == NPerXDL, ""); + static_assert(M2 * M3 * M4 == MPerXdl, ""); + static_assert(N2 == NPerXdl, ""); constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( c_block_desc_mblock_mperblock_nblock_nperblock, @@ -845,11 +864,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight M1, M2, M3, - M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl make_freeze_transform(I0), // freeze nblock make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, N1, - N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL + N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple( Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); @@ -920,9 +939,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight CElementwiseOperation, // ElementwiseOperation, CGlobalMemoryDataOperation, // DstInMemOp, Sequence<1, - CShuffleMRepeatPerShuffle * MWave * MPerXDL, + CShuffleMRepeatPerShuffle * MWave * MPerXdl, 1, - CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, FloatC, // typename SrcData, @@ -941,11 +960,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight c_element_op}; constexpr auto mxdlperwave_forward_step = - make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0); + 
make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXdl, 0, 0); constexpr auto nxdlperwave_forward_step = - make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXdl); constexpr auto nxdlperwave_backward_step = - make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXdl); static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { constexpr auto mxdlperwave = mxdlperwave_iter; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp index 0dbdac85bf..9e524c5a23 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp @@ -44,20 +44,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -86,8 +90,8 @@ template static constexpr auto K1 = Number{}; - static constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXDL); - static constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXDL); + static constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + static constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); static constexpr index_t WaveSize = BlockSize / (MWaves * NWaves); - static constexpr auto xdlops_gemm = XdlopsGemm{}; + static constexpr auto xdlops_gemm = XdlopsGemm{}; static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops; using ThisThreadBlock = ThisThreadBlock; @@ -164,6 +168,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 return (a_block_space_size_aligned) * sizeof(FloatAB); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, @@ -175,8 +194,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 static_assert(is_known_at_compile_time>::value, "wrong! 
K1 need to be known at compile-time"); - static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) && - (NPerBlock % (NXdlPerWave * NPerXDL)) == 0, + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); const auto M = a_grid_desc_k0_m_k1.GetLength(I1); @@ -242,7 +261,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 make_tuple(make_unmerge_transform( make_tuple(K0 / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread)), make_unmerge_transform(make_tuple( - N / (NXdlPerWave * NWaves * NPerXDL), NXdlPerWave, NWaves, NPerXDL)), + N / (NXdlPerWave * NWaves * NPerXdl), NXdlPerWave, NWaves, NPerXdl)), make_pass_through_transform(K1)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}, Sequence<7>{})); @@ -264,7 +283,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 __device__ static auto GetWaveKNIdx(const index_t thread_id) { constexpr auto wave_threadid_to_nk_idx_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(xdlops_gemm.K0PerXdlops, NPerXDL))), + make_tuple(make_merge_transform(make_tuple(xdlops_gemm.K0PerXdlops, NPerXdl))), make_tuple(Sequence<0, 1>{}), make_tuple(Sequence<0>{})); @@ -311,8 +330,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 MPerBlock, NPerBlock, K0PerBlock, - MPerXDL, - NPerXDL, + MPerXdl, + NPerXdl, MXdlPerWave, NXdlPerWave, K1>; @@ -510,8 +529,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 MPerBlock, NPerBlock, K0PerBlock, - MPerXDL, - NPerXDL, + MPerXdl, + NPerXdl, MXdlPerWave, NXdlPerWave, K1>{}; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp index a13ce732e6..c5f60a7413 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
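// ---- editorial example ----------------------------------------------------
// The recurring change in the kernel entry points above is a compile-time
// validity gate: instead of unconditionally allocating LDS and calling
// GridwiseGemm::Run, the body is wrapped in
// if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()), so
// parameter sets that are illegal for the target compile to empty kernels
// rather than hard errors. A minimal self-contained sketch of the pattern
// (ToyGemm, kernel_toy, and the checks inside the predicate are illustrative
// assumptions, not the CK implementation):

#include <cstdio>

template <int BlockSize, int MPerBlock, int MPerXdl, int MXdlPerWave>
struct ToyGemm
{
    // Same shape as CK's IsValidCompilationParameter<>(): a constexpr
    // predicate over the tile/tuning parameters.
    static constexpr bool IsValidCompilationParameter()
    {
        return (MPerBlock % (MPerXdl * MXdlPerWave)) == 0 && BlockSize >= 64;
    }

    static void Run() { std::printf("valid instantiation ran\n"); }
};

template <typename Gemm>
void kernel_toy()
{
    if constexpr(Gemm::IsValidCompilationParameter())
    {
        Gemm::Run(); // body is only instantiated for valid parameter sets
    }
    // invalid instantiations still compile, but do nothing
}

int main()
{
    kernel_toy<ToyGemm<256, 128, 16, 2>>(); // 128 % (16 * 2) == 0: runs
    kernel_toy<ToyGemm<256, 120, 16, 2>>(); // invalid: body compiled out
}
// ---- end editorial example -------------------------------------------------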
#pragma once @@ -39,12 +39,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CElementwiseOperation c_element_op) { #if defined(__gfx9__) - constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); - __shared__ uint8_t p_shared[shared_size]; + __shared__ uint8_t p_shared[shared_size]; - GridwiseGemm::template Run( - karg, static_cast(p_shared), b2c_map, a_element_op, b_element_op, c_element_op); + GridwiseGemm::template Run( + karg, static_cast(p_shared), b2c_map, a_element_op, b_element_op, c_element_op); + } #else ignore = karg; ignore = b2c_map; @@ -70,8 +73,8 @@ template {}, + Number{}, I1, - Number{})); + Number{})); } // return block_id to C matrix tile idx (m0, n0, k_split) mapping @@ -711,8 +718,8 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load FloatAcc, decltype(a_k0_m_k1_block_desc), decltype(b_k0_n_k1_block_desc), - MPerXDL, - NPerXDL, + MPerXdl, + NPerXdl, MRepeat, NRepeat, K1, @@ -766,8 +773,8 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load // output: register to global memory { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl); constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); @@ -799,11 +806,11 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load M1, M2, M3, - M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl make_freeze_transform(I0), // freeze nblock make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, N1, - N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL + N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple( Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); @@ -874,9 +881,9 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load CElementwiseOperation, // ElementwiseOperation, CGlobalMemoryDataOperation, // DstInMemOp, Sequence<1, - CShuffleMRepeatPerShuffle * MWave * MPerXDL, + CShuffleMRepeatPerShuffle * MWave * MPerXdl, 1, - CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, FloatC, // typename SrcData, @@ -895,11 +902,11 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load c_element_op}; constexpr auto mxdlperwave_forward_step = - make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0); + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXdl, 0, 0); constexpr auto nxdlperwave_forward_step = - make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXdl); constexpr auto nxdlperwave_backward_step = - make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXdl); static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { constexpr auto mxdlperwave = mxdlperwave_iter; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp 
b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp index 6aa61fcd38..a040409a6d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -37,23 +37,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) index_t StrideC, typename GridwiseGemm::Block2CTileMap block_mapping) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); - __shared__ uint8_t p_shared[shared_size]; + __shared__ uint8_t p_shared[shared_size]; - GridwiseGemm::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_workspace, - M, - N, - K, - StrideA, - StrideB, - StrideC, - block_mapping, - static_cast(p_shared)); + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_workspace, + M, + N, + K, + StrideA, + StrideB, + StrideC, + block_mapping, + static_cast(p_shared)); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -83,8 +86,8 @@ template ::value) @@ -380,27 +387,27 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk __host__ __device__ static constexpr auto GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle() { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl); + constexpr index_t NWave = NRepeat * NPerXdl == 0 ? 1 : NPerBlock / (NRepeat * NPerXdl); return make_naive_tensor_descriptor_packed( make_tuple(I1, - Number{}, + Number{}, I1, - Number{})); + Number{})); } __host__ __device__ static constexpr auto GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle() { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl); + constexpr index_t NWave = NRepeat * NPerXdl == 0 ? 
1 : NPerBlock / (NRepeat * NPerXdl); return make_naive_tensor_descriptor_packed( make_tuple(Number{}, - Number{}, + Number{}, Number{}, - Number{})); + Number{})); } __host__ __device__ static constexpr auto GetClusterLengthReduction() @@ -490,8 +497,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk FloatAcc, decltype(a_block_desc_k0_m_k1), decltype(b_block_desc_k0_n_k1), - MPerXDL, - NPerXDL, + MPerXdl, + NPerXdl, MRepeat, NRepeat, K1>{}; @@ -829,8 +836,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk // output: register to global memory { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl); constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); @@ -871,12 +878,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk M1, M2, M3, - M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl make_freeze_transform(I0), // freeze nblock make_unmerge_transform( make_tuple(CShuffleNRepeatPerShuffle, N1, - N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL + N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, @@ -948,9 +955,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk CElementwiseOperation, // ElementwiseOperation, // InMemoryDataOperationEnum::Set, // DstInMemOp, Sequence<1, - CShuffleMRepeatPerShuffle * MWave * MPerXDL, + CShuffleMRepeatPerShuffle * MWave * MPerXdl, 1, - CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, FloatCShuffle, // typename SrcData, @@ -977,9 +984,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk CElementwiseOperation, // ElementwiseOperation, // InMemoryDataOperationEnum::Set, // DstInMemOp, Sequence<1, - CShuffleMRepeatPerShuffle * MWave * MPerXDL, + CShuffleMRepeatPerShuffle * MWave * MPerXdl, 1, - CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, FloatCShuffle, // typename SrcData, @@ -1000,11 +1007,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk c_element_op}; constexpr auto mxdlperwave_forward_step = - make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0); + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXdl, 0, 0); constexpr auto nxdlperwave_forward_step = - make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXdl); constexpr auto nxdlperwave_backward_step = - make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXdl); static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { constexpr auto mxdlperwave = mxdlperwave_iter; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index 
ae9a8af813..aa7ce1f5b6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -38,16 +38,20 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const CGridDesc_M_N c_grid_desc_m_n) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_grid_desc_m_n); + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m_n); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -68,25 +72,29 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - const auto a_grid_desc_k0_m_k1 = - amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1( - karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA)); - const auto b_grid_desc_k0_n_k1 = - amd_wave_read_first_lane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1( - karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB)); - const auto c_grid_desc_m_n = amd_wave_read_first_lane(GridwiseGemm::MakeCGridDescriptor_M_N( - karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC)); + const auto a_grid_desc_k0_m_k1 = + amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1( + karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA)); + const auto b_grid_desc_k0_n_k1 = + amd_wave_read_first_lane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1( + karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB)); + const auto c_grid_desc_m_n = amd_wave_read_first_lane(GridwiseGemm::MakeCGridDescriptor_M_N( + karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC)); - GridwiseGemm::template Run(karg.p_a_grid, - karg.p_b_grid, - karg.p_c_grid, - p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_grid_desc_m_n); + GridwiseGemm::template Run(karg.p_a_grid, + karg.p_b_grid, + karg.p_c_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m_n); + } #else ignore = karg; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) @@ -103,8 +111,8 @@ template {}; // K1 should be Number<...> - static constexpr auto K1 = Number{}; + static constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && K1Value <= 4) || + (is_same::value && K1Value <= 8) || + ((is_same::value || is_same::value) && K1Value < 32)) + ? 
true + : false; + static constexpr auto is_scale_mfma = false; + static constexpr auto K1 = Number:: + selected_mfma.k_per_blk)>{}; using ThisThreadBlock = ThisThreadBlock; @@ -314,6 +332,22 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } + template __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, @@ -323,8 +357,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 static_assert(is_known_at_compile_time>::value, "wrong! K1 need to be known at compile-time"); - static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) && - (NPerBlock % (NXdlPerWave * NPerXDL)) == 0, + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); const auto M = a_grid_desc_k0_m_k1.GetLength(I1); @@ -356,8 +390,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 static_assert(is_known_at_compile_time>::value, "wrong! K1 need to be known at compile-time"); - static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) && - (NPerBlock % (NXdlPerWave * NPerXDL)) == 0, + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); // check gridwise gemm pipeline @@ -421,8 +455,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 FloatAcc, decltype(a_block_desc_k0_m_k1), decltype(b_block_desc_k0_n_k1), - MPerXDL, - NPerXDL, + MPerXdl, + NPerXdl, MXdlPerWave, NXdlPerWave, K1>; @@ -566,8 +600,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 FloatAcc, decltype(a_block_desc_k0_m_k1), decltype(b_block_desc_k0_n_k1), - MPerXDL, - NPerXDL, + MPerXdl, + NPerXdl, MXdlPerWave, NXdlPerWave, K1, @@ -704,8 +738,8 @@ template >::value, "wrong! K1 need to be known at compile-time"); - static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) && - (NPerBlock % (NXdlPerWave * NPerXDL)) == 0, + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp index f779e63752..a9a463e2c1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
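// ---- editorial example ----------------------------------------------------
// The v2r3 hunk above introduces is_single_rate_mfma and widens K1 to the
// selected MFMA instruction's k_per_blk. The intent: small K1 values for
// half/bf8-style inputs cannot feed the double-rate MFMA variants, so the
// selector falls back to the single-rate instruction. The type list in the
// patch text is partially garbled, so the types, thresholds, and k_per_blk
// values below are assumptions for demonstration, not CK's MfmaSelector:

#include <algorithm>
#include <cstdio>
#include <type_traits>

struct half_t {};  // stand-ins for CK's element types
struct bhalf_t {};
struct f8_t {};

// Whether K1 is too small to feed a double-rate MFMA for this element type.
template <typename T, int K1Value>
constexpr bool is_single_rate_mfma =
    ((std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t>) && K1Value <= 4) ||
    (std::is_same_v<T, f8_t> && K1Value <= 8);

// Toy stand-in for MfmaSelector<...>::selected_mfma.k_per_blk.
constexpr int k_per_blk(bool single_rate) { return single_rate ? 4 : 8; }

// K1 is widened to the selected instruction's k_per_blk, as in the hunk above.
template <typename T, int K1Value>
constexpr int effective_k1 =
    std::max(K1Value, k_per_blk(is_single_rate_mfma<T, K1Value>));

int main()
{
    std::printf("%d\n", effective_k1<half_t, 2>); // widened to 4
    std::printf("%d\n", effective_k1<half_t, 8>); // double-rate path, stays 8
}
// ---- end editorial example -------------------------------------------------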
#pragma once @@ -42,23 +42,27 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CElementwiseOperation c_element_op, const CBlockClusterAdaptor c_block_cluster_adaptor) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - __shared__ FloatAB p_shared_block[shared_block_size]; + __shared__ FloatAB p_shared_block[shared_block_size]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - a_element_op, - b_element_op, - c_element_op, - c_block_cluster_adaptor); + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + a_element_op, + b_element_op, + c_element_op, + c_block_cluster_adaptor); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -171,6 +175,22 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index 595a597318..d8f22b682d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -36,13 +36,16 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); - __shared__ uint8_t p_shared[shared_size]; + __shared__ uint8_t p_shared[shared_size]; - GridwiseGemm::template Run( - karg, static_cast(p_shared), b2c_map, a_element_op, b_element_op, c_element_op); + GridwiseGemm::template Run( + karg, static_cast(p_shared), b2c_map, a_element_op, b_element_op, c_element_op); + } #else ignore = karg; ignore = b2c_map; @@ -68,8 +71,8 @@ template {}, + Number{}, I1, - Number{})); + Number{})); } // return block_id to C matrix tile idx (m0, n0, k_split) mapping @@ -848,8 +855,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 FloatAcc, decltype(a_k0_m_k1_block_desc), decltype(b_k0_n_k1_block_desc), - MPerXDL, - NPerXDL, + MPerXdl, +
NPerXdl, MRepeat, NRepeat, K1, @@ -899,8 +906,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 // output: register to global memory { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl); constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); @@ -932,11 +939,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 M1, M2, M3, - M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl make_freeze_transform(I0), // freeze nblock make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, N1, - N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL + N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple( Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); @@ -1007,9 +1014,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 CElementwiseOperation, // ElementwiseOperation, CGlobalMemoryDataOperation, // DstInMemOp, Sequence<1, - CShuffleMRepeatPerShuffle * MWave * MPerXDL, + CShuffleMRepeatPerShuffle * MWave * MPerXdl, 1, - CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, FloatC, // typename SrcData, @@ -1028,11 +1035,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 c_element_op}; constexpr auto mxdlperwave_forward_step = - make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0); + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXdl, 0, 0); constexpr auto nxdlperwave_forward_step = - make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXdl); constexpr auto nxdlperwave_backward_step = - make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXdl); static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { constexpr auto mxdlperwave = mxdlperwave_iter; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp index 8822778b52..7d5a8da60f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
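// ---- editorial example ----------------------------------------------------
// The renamed MPerXdl/NPerXdl constants feed the CShuffle epilogue stepping
// above: mxdlperwave_forward_step advances along M, while the paired
// nxdlperwave forward/backward steps snake along N so the destination
// coordinate only ever moves by one step and never resets. A host-side
// sketch of that traversal order (parameter values are illustrative, not a
// specific CK tuning):

#include <cstdio>

int main()
{
    constexpr int MRepeat = 4, NRepeat = 4;
    constexpr int MShuffle = 2, NShuffle = 2; // CShuffle{M,N}RepeatPerShuffle
    constexpr int MWave = 2, NWave = 2, MPerXdl = 16, NPerXdl = 16;

    constexpr int m_step = MShuffle * MWave * MPerXdl; // forward M step
    constexpr int n_step = NShuffle * NWave * NPerXdl; // forward/backward N step

    int m = 0, n = 0;
    for(int mi = 0; mi < MRepeat; mi += MShuffle)
    {
        const bool forward = (mi / MShuffle) % 2 == 0;
        for(int ni = 0; ni < NRepeat; ni += NShuffle)
        {
            std::printf("write shuffle tile at (m=%d, n=%d)\n", m, n);
            if(ni + NShuffle < NRepeat)
                n += forward ? n_step : -n_step; // snake along N
        }
        if(mi + MShuffle < MRepeat)
            m += m_step; // advance along M once per row of tiles
    }
}
// ---- end editorial example -------------------------------------------------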
#pragma once @@ -46,21 +46,25 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -230,6 +234,22 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 c_block_size * sizeof(FloatCShuffle)); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp index c3bbece33c..7c559d1f85 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
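// ---- editorial example ----------------------------------------------------
// Each IsValidCompilationParameter member added above is a thin forwarder:
// the per-struct tile constants go to one shared device-level predicate
// rather than duplicating the checks in every gridwise GEMM. The trait body
// below is a simplified stand-in for CK's IsValidGemmCompilationParameter;
// the specific checks shown are assumptions for illustration:

#include <type_traits>

enum class InMemoryDataOperationEnum { Set, AtomicAdd };

template <int BlockSize, int MPerBlock, int NPerBlock, int MPerXdl, int NPerXdl,
          int MXdlPerWave, int NXdlPerWave, typename FloatC,
          InMemoryDataOperationEnum Op>
constexpr bool IsValidGemmCompilationParameter()
{
    constexpr bool op_ok = Op == InMemoryDataOperationEnum::Set ||
                           std::is_same_v<FloatC, float>; // e.g. atomics only on float
    return op_ok && MPerBlock % (MPerXdl * MXdlPerWave) == 0 &&
           NPerBlock % (NPerXdl * NXdlPerWave) == 0;
}

template <typename FloatC>
struct GridwiseGemmSketch
{
    // Same shape as the members added above: a defaulted memory-op template
    // parameter, forwarding the struct's tile constants to the shared trait.
    template <InMemoryDataOperationEnum Op = InMemoryDataOperationEnum::Set>
    static constexpr bool IsValidCompilationParameter()
    {
        return IsValidGemmCompilationParameter<256, 128, 128, 16, 16, 2, 2,
                                               FloatC, Op>();
    }
};

static_assert(GridwiseGemmSketch<float>::IsValidCompilationParameter<
              InMemoryDataOperationEnum::AtomicAdd>());
// ---- end editorial example -------------------------------------------------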
#pragma once @@ -49,23 +49,27 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - p_a_grid, - p_b_grid, - p_c_grid, - p_c0_grid, - p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_c0_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -235,6 +239,22 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 c_block_size * sizeof(FloatC)); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool @@ -298,7 +318,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 const auto NBlock = N / NPerBlock; constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t NWave = + NXdlPerWave * NPerXdl == 0 ? 1 : NPerBlock / (NXdlPerWave * NPerXdl); const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = transform_tensor_descriptor( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp index 2e288efee2..83f8773a08 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
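// ---- editorial example ----------------------------------------------------
// v3r2 and v3r3 also change the NWave computation from a plain division to
// NXdlPerWave * NPerXdl == 0 ? 1 : NPerBlock / (NXdlPerWave * NPerXdl).
// Presumably this keeps degenerate parameter combinations well-formed: with
// the new IsValidCompilationParameter gate, an invalid instantiation is
// still type-checked even though its body is discarded, and only the taken
// branch of a conditional is evaluated in a constant expression, so the
// zero divisor never produces a constant-expression error. Minimal sketch:

template <int NPerBlock, int NXdlPerWave, int NPerXdl>
constexpr int NWaveGuarded =
    (NXdlPerWave * NPerXdl == 0) ? 1 : NPerBlock / (NXdlPerWave * NPerXdl);

static_assert(NWaveGuarded<128, 2, 16> == 4); // normal case: 128 / 32
static_assert(NWaveGuarded<128, 0, 16> == 1); // degenerate case stays well-formed
// ---- end editorial example -------------------------------------------------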
#pragma once @@ -53,25 +53,29 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - p_a_grid, - p_b_grid, - p_c_grid, - p_c0_grid, - p_c1_grid, - p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_c0_grid, + p_c1_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } #else ignore = p_a_grid; ignore = p_b_grid; @@ -244,6 +248,22 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 c_block_size * sizeof(FloatC)); } + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation>(); + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ __device__ static constexpr bool @@ -307,7 +327,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 const auto NBlock = N / NPerBlock; constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t NWave = + NXdlPerWave * NPerXdl == 0 ? 
1 : NPerBlock / (NXdlPerWave * NPerXdl); const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = transform_tensor_descriptor( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index b0a606cf38..b9c0d671db 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -45,24 +45,27 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_gemm(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -80,26 +83,29 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run_2Lds( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - p_shared1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run_2Lds( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -940,6 +946,8 @@ struct GridwiseMoeGemm c_block_size * 
sizeof(CShuffleDataType)); } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp index a8b759da38..5f5e24fb9f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp @@ -45,26 +45,29 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_gemm(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - karg.p_a_scale_grid, - karg.p_b_scale_grid, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + karg.p_a_scale_grid, + karg.p_b_scale_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -82,28 +85,31 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg) { -#if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run_2Lds( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - karg.p_a_scale_grid, - karg.p_b_scale_grid, - p_shared, - p_shared1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run_2Lds( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + 
karg.p_c_grid, + karg.p_a_scale_grid, + karg.p_b_scale_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -956,6 +962,8 @@ struct GridwiseMoeGemmBlockScale c_block_size * sizeof(CShuffleDataType)); } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp index 34fcf0e935..9066decc0a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp @@ -49,6 +49,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) kernel_moe_mxgemm(typename GridwiseGemm::Argument karg) { #if defined(__gfx9__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -68,6 +70,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) karg.a_element_op, karg.b_element_op, karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -87,27 +90,30 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg) { #if defined(__gfx9__) - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run_2Lds( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared_0, - p_shared_1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run_2Lds( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared_0, + p_shared_1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -1167,6 +1173,8 @@ struct GridwiseMoeGemmMX } } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp 
b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp index 3a7b35683d..6854e64092 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp @@ -48,25 +48,28 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) kernel_moe_mxgemm(typename GridwiseGemm::Argument karg) { #if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -86,6 +89,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg) { #if defined(__gfx9__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -107,6 +112,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) karg.a_element_op, karg.b_element_op, karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -1099,6 +1105,8 @@ struct GridwiseMoeGemmMXBNS } } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp index 3c4f7a24c7..c367079aab 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp @@ -48,25 +48,28 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) kernel_moe_mxgemm(typename GridwiseGemm::Argument karg) { #if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - 
GridwiseGemm::template Run( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -85,27 +88,30 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg) { #if defined(__gfx9__) - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run_2Lds( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared_0, - p_shared_1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run_2Lds( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared_0, + p_shared_1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + } #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -1059,6 +1065,8 @@ struct GridwiseMoeGemmMX_BPreshuffle c_block_size * sizeof(CShuffleDataType)); } + IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) {
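// ---- editorial example ----------------------------------------------------
// Unlike the plain GEMM headers, the MoE GEMM structs above add the hook via
// a macro, IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType), instead of
// spelling the member template out per struct. One plausible expansion is
// sketched below; the macro stamps out the member and picks up the enclosing
// struct's tile constants by name. This is an assumption about the macro's
// shape, not CK's actual definition, and the trait body is the same
// simplified stand-in used earlier:

#include <type_traits>

enum class InMemoryDataOperationEnum { Set, AtomicAdd };

template <int MPerBlock, int MPerXdl, int MXdlPerWave, typename CDataType,
          InMemoryDataOperationEnum Op>
constexpr bool IsValidGemmCompilationParameter()
{
    return MPerBlock % (MPerXdl * MXdlPerWave) == 0 &&
           (Op == InMemoryDataOperationEnum::Set ||
            std::is_same_v<CDataType, float>);
}

// Hypothetical expansion of the boilerplate-folding macro.
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType)                        \
    template <InMemoryDataOperationEnum Op = InMemoryDataOperationEnum::Set>  \
    static constexpr bool IsValidCompilationParameter()                       \
    {                                                                         \
        return IsValidGemmCompilationParameter<MPerBlock,                     \
                                               MPerXdl,                       \
                                               MXdlPerWave,                   \
                                               CDataType,                     \
                                               Op>();                         \
    }

struct MoeGemmSketch
{
    static constexpr int MPerBlock = 128, MPerXdl = 16, MXdlPerWave = 2;
    using CDataType = float;
    IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType)
};

static_assert(MoeGemmSketch::IsValidCompilationParameter<>());
// ---- end editorial example -------------------------------------------------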