From fb7ec3a7aa13fab216ac2d0b7d71701074d84036 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 19 Dec 2025 23:58:51 +0100 Subject: [PATCH] Improve XDL to WMMA porting for grouped conv fwd (#3456) Refactors the way the number of XDL (matrix multiply-accumulate) instructions per wave is calculated and used in the grouped convolution forward implementations, especially to better support WMMA (Wave Matrix Multiply-Accumulate) instructions and 16x16 tiles. The changes use MXdlPerWave instead of NXdlPerWave to increase the number of waves per M dim. [ROCm/composable_kernel commit: cbc83359649b1b56cd745c4102e9556112f942c2] --- ...est_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp | 2 +- .../conv/ck/test_ckb_conv_fwd_3d_fp16.cpp | 2 +- .../builder/test/test_conv_description.cpp | 6 +- .../test/utils/ckb_conv_test_configs.hpp | 2 +- .../gpu/device/device_base.hpp | 56 ++++++--- .../device_batched_gemm_gemm_xdl_cshuffle.hpp | 4 +- ...ultiple_d_gemm_multiple_d_xdl_cshuffle.hpp | 28 ++--- ...gemm_softmax_gemm_permute_xdl_cshuffle.hpp | 4 +- ...batched_gemm_softmax_gemm_xdl_cshuffle.hpp | 4 +- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 114 ++++++++++-------- ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 60 ++++++--- ...d_multiple_d_xdl_large_tensor_cshuffle.hpp | 73 +++++++---- ...gemm_softmax_gemm_permute_xdl_cshuffle.hpp | 4 +- 13 files changed, 226 insertions(+), 133 deletions(-) diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp index e3dc261fe3..0d9563e05a 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp @@ -33,7 +33,7 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} .with_thread_block(FwdThreadBlock_64_64x32x32) - 
.with_gemm_config(FwdGemmParams_Xdl_2x2_per_wave) + .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) .with_transfer(FwdTransfer_4x16x1) .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT); diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp index 11c8172533..b30f958bc4 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp @@ -28,7 +28,7 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} .with_thread_block(FwdThreadBlock_256_128x128x32) - .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) + .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) .with_transfer(FwdTransfer_4x64x1) .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index 158cb2668f..dca0e858eb 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -111,8 +111,8 @@ struct DefaultAlgorithm .bk1 = 8, .m_per_xdl = 16, .n_per_xdl = 16, - .m_xdl_per_wave = 4, - .n_xdl_per_wave = 4}; + .m_xdl_per_wave = 8, + .n_xdl_per_wave = 8}; ckb::test::TransferABC transfer{ .a = @@ -188,7 +188,7 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription) " ├─ Pipeline scheduler: INTRAWAVE\n" " ├─ Warp Gemm parameters: \n" " │ ├─ subtile size: 16×16\n" - " │ └─ Number of warp gemm iterations: 4×4\n" + " │ └─ Number of warp gemm iterations: 8×8\n" " └─ Memory access:\n" " ├─ A Tile transfer: \n" " │ ├─ Tile dimensions: 4×256×8×\n" diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp 
b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index 403c2ffd79..ad5a5f4f6f 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -68,7 +68,7 @@ constexpr TransferABC FwdTransfer_4x64x1{ {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, .epilogue = {.m_xdl_per_wave_per_shuffle = 1, .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, + .scalar_per_vector = 4}, }, }; diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index 9179a279c5..361b116782 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -60,7 +60,7 @@ template -static constexpr auto GetNXdlPerWave2() +static constexpr auto GetXdlPerWave2() { constexpr index_t Waves = IsWave64 ? BlockSize_ / 64 : BlockSize_ / 32; constexpr index_t MWaves = MPerBlock_ / (MXdlPerWave_ * MPerXDL_); @@ -84,17 +84,33 @@ static constexpr auto GetNXdlPerWave2() } } -#define GET_NXDL_PER_WAVE_IMPL \ - template \ - static constexpr auto GetNXdlPerWave() \ - { \ - return GetNXdlPerWave2(); \ +#define GET_NXDL_PER_WAVE_IMPL \ + template \ + static constexpr auto GetNXdlPerWave() \ + { \ + return GetXdlPerWave2(); \ + } + +#define GET_MXDL_PER_WAVE_IMPL \ + template \ + static constexpr auto GetMXdlPerWave() \ + { \ + return GetXdlPerWave2(); \ } template () - : GetNXdlPerWave2(); + ? 
GetXdlPerWave2() + : GetXdlPerWave2(); if constexpr(IsWave64 == false && NXdlPerWave != 0) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp index 6089e7e63f..b930c50e3a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp @@ -190,9 +190,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm(); + GetXdlPerWave2(); static constexpr auto MXdlPerWave32 = - GetNXdlPerWave2(); + GetXdlPerWave2(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp index 1fc7c8e523..4410871ac1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp @@ -235,20 +235,20 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle { using DeviceOp = DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle; - static constexpr auto Gemm0MXdlPerWave64 = GetNXdlPerWave2(); - static constexpr auto Gemm0MXdlPerWave32 = GetNXdlPerWave2(); + static constexpr auto Gemm0MXdlPerWave64 = GetXdlPerWave2(); + static constexpr auto Gemm0MXdlPerWave32 = GetXdlPerWave2(); static constexpr index_t NumD0Tensor = D0sDataType::Size(); static constexpr index_t NumD1Tensor = D1sDataType::Size(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index 9bacb3b661..9ece23985a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -223,9 +223,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle MaskingSpec> { static constexpr auto MXdlPerWave64 = - GetNXdlPerWave2(); + GetXdlPerWave2(); static constexpr auto MXdlPerWave32 = - GetNXdlPerWave2(); + GetXdlPerWave2(); static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, "Number of dimension must be greater than 0"); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp index d6a4f49be8..35b2f54f58 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -211,9 +211,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle using DeviceOp = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle; static constexpr auto MXdlPerWave64 = - GetNXdlPerWave2(); + GetXdlPerWave2(); static constexpr auto MXdlPerWave32 = - GetNXdlPerWave2(); + GetXdlPerWave2(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 5ed8da8d1b..6229362a7a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp 
@@ -325,9 +325,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle BComputeDataType> { using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + GET_MXDL_PER_WAVE_IMPL + // Force usage of 16x16 instruction for WMMA + static constexpr index_t Wave32MaxMNPerXDL = 16; + static constexpr auto MXdlPerWave64 = GetMXdlPerWave(); + static constexpr auto MXdlPerWave32 = + GetMXdlPerWave(); static_assert(NumGroupsToMerge >= 1); @@ -486,35 +492,36 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle using GemmADataType = std::conditional_t, ADataType>; using GemmBDataType = std::conditional_t, BDataType>; -#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \ - GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ - EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ - InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ - KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, \ - ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ - ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ - ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ - ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ - BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ - BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ - BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ - CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ - CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ +#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \ + GemmADataType, 
GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ + EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ + InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ + KPerBlock, AK1, BK1, MPerXDL_, NPerXDL_, MXdlPerWave_, NXdlPerWave*(NPerXDL / NPerXDL_), \ + ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ + ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ + ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ + BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ + BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ + CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle*(NPerXDL / NPerXDL_), \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ BComputeDataType #define CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \ GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ - NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \ - NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ - ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ - ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ - ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ - BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ - BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ - BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ - 
BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL_, \ + NPerXDL_, MXdlPerWave_, NXdlPerWave*(NPerXDL / NPerXDL_), \ + ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ + ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ + ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ + BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ + BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ + CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle*(NPerXDL / NPerXDL_), \ CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ BComputeDataType, DoElementwiseBeforeCShuffle @@ -523,7 +530,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle GemmBDataType, GemmADataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, BElementwiseOperation, AElementwiseOperation, CDEElementwiseOperation, \ NumGemmKPrefetchStage, BlockSize, NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, \ - MPerXDL, NXdlPerWave_, MXdlPerWave, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + MPerXDL, NXdlPerWave, MXdlPerWave_, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ @@ -536,34 +543,35 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle BComputeDataType, DoElementwiseBeforeCShuffle // Use appropriate gridwise gemm - template + template using GridwiseGemmMultipleABDBase = GridwiseGemmMultipleABD_xdl_cshuffle< 
CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>; - template + template using GridwiseGemmMultipleDBase = GridwiseGemmMultipleD_xdl_cshuffle< CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>; - template + template using GridwiseGemmMultipleDCTransposeBase = GridwiseGemmMultipleD_xdl_cshuffle< CK_GRIDWISE_GEMM_FWD_CTRANSPOSE_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>; #undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS #undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS #undef CK_GRIDWISE_GEMM_FWD_CTRANSPOSE_XDL_CSHUFFLE_TEMPLATE_PARAMETERS - using GridwiseGemm64 = - std::conditional_t, - GridwiseGemmMultipleDBase>; - using GridwiseGemm32 = std::conditional_t, - GridwiseGemmMultipleDBase>; + using GridwiseGemm64 = std::conditional_t< + isMultiA || isMultiB, + GridwiseGemmMultipleABDBase, + GridwiseGemmMultipleDBase>; + using GridwiseGemm32 = std::conditional_t< + isMultiA || isMultiB, + GridwiseGemmMultipleABDBase, + GridwiseGemmMultipleDBase>; using GridwiseGemmCTranspose64 = std::conditional_t, + GridwiseGemmMultipleDCTransposeBase, GridwiseGemm64>; using GridwiseGemmCTranspose32 = std::conditional_t, + GridwiseGemmMultipleDCTransposeBase, GridwiseGemm32>; // If ADataTypes or BDataTypes is tuple, user has to pass std::array with pointers. 
@@ -913,14 +921,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle if(get_warp_size() == 64) { - if constexpr(NXdlPerWave64 > 0) + if constexpr(MXdlPerWave64 > 0) { InitGridDesc(); } } else { - if constexpr(NXdlPerWave32 > 0) + if constexpr(MXdlPerWave32 > 0) { InitGridDesc(); } @@ -1388,7 +1396,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { if(get_warp_size() == 64) { - if constexpr(NXdlPerWave64 > 0) + if constexpr(MXdlPerWave64 > 0) { return RunImp(arg, stream_config); } @@ -1399,7 +1407,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } else { - if constexpr(NXdlPerWave32 > 0) + if constexpr(MXdlPerWave32 > 0) { return RunImp(arg, stream_config); } @@ -1436,7 +1444,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } } - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } @@ -1720,7 +1731,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // check Gridwise GEMM if(get_warp_size() == 64) { - if constexpr(NXdlPerWave64 > 0) + if constexpr(MXdlPerWave64 > 0) { if constexpr(isMultiA || isMultiB) { @@ -1759,7 +1770,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle else { - if constexpr(NXdlPerWave32 > 0) + if constexpr(MXdlPerWave32 > 0) { if constexpr(isMultiA || isMultiB) { @@ -2047,8 +2058,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle auto str = std::stringstream(); // clang-format off - str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle" - << "<" + str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"; + + if(get_warp_size() != 64) { + str << "_WmmaPorted"; + } + + str << "<" << BlockSize << ", " << MPerBlock << ", " << NPerBlock << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index e69a9caa9c..0a4ca23582 100644 --- 
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -400,9 +400,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 BComputeDataType> { using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + GET_MXDL_PER_WAVE_IMPL + // Force usage of 16x16 instruction for WMMA + static constexpr index_t Wave32MaxMNPerXDL = 16; + static constexpr auto MXdlPerWave64 = GetMXdlPerWave(); + static constexpr auto MXdlPerWave32 = + GetMXdlPerWave(); static constexpr bool isMultiA = is_detected::value; static constexpr bool isMultiB = is_detected::value; @@ -563,7 +569,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 : BBlockTransferSrcScalarPerVector; // Use appropriate gridwise gemm - template + template using GridwiseGemmBase = GridwiseGemmMultiD_xdl_cshuffle_v3< tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, @@ -585,10 +591,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 KPerBlock, AK1, BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + MPerXDL_, + NPerXDL_, + MXdlPerWave_, + NXdlPerWave*(NPerXDL / NPerXDL_), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -606,7 +612,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle*(NPerXDL / NPerXDL_), CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVectors, BlkGemmPipeSched, @@ -617,8 +623,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 BDataType, DoElementwiseBeforeCShuffle, DirectLoad>; - using GridwiseGemm64 = 
GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // #undef GridwiseGemmV3TemplateParams @@ -1430,7 +1436,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 return avg_time; } - INVOKER_RUN_IMPL + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(get_warp_size() == 64) + { + if constexpr(MXdlPerWave64 > 0) + { + return RunImp(arg, stream_config); + } + } + else + { + if constexpr(MXdlPerWave32 > 0) + { + return RunImp(arg, stream_config); + } + } + return 0; + } float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -1483,7 +1506,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 } } - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { @@ -1758,7 +1784,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 if(get_warp_size() == 64) { - if constexpr(NXdlPerWave64 > 0) + if constexpr(MXdlPerWave64 > 0) { typename GridwiseGemm64::Argument gemm_arg{nullptr, nullptr, @@ -1780,7 +1806,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 } else { - if constexpr(NXdlPerWave32 > 0) + if constexpr(MXdlPerWave32 > 0) { typename GridwiseGemm32::Argument gemm_arg{nullptr, nullptr, @@ -2064,6 +2090,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 // clang-format off str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"; + if(get_warp_size() != 64) { + str << "_WmmaPorted"; + } + if constexpr(DirectLoad) { str << "_DirectLoad"; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index 7c121f1482..ac0b4b663d 100644 --- 
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -206,9 +206,15 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor BComputeDataType> { using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + GET_MXDL_PER_WAVE_IMPL + // Force usage of 16x16 instruction for WMMA + static constexpr index_t Wave32MaxMNPerXDL = 16; + static constexpr auto MXdlPerWave64 = GetMXdlPerWave(); + static constexpr auto MXdlPerWave32 = + GetMXdlPerWave(); static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t MaxGemmsNum = 32; @@ -409,25 +415,26 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor #define CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_LARGE_TENSOR_TEMPLATE_PARAMETERS \ ADataType, BDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ - NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \ - NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ - ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ - ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ - ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ - BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ - BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ - BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ - BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, 
MPerXDL_, \ + NPerXDL_, MXdlPerWave_, NXdlPerWave*(NPerXDL / NPerXDL_), \ + ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ + ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ + ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ + BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ + BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ + CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle*(NPerXDL / NPerXDL_), \ CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ AComputeDataType, DoElementwiseBeforeCShuffle // Use appropriate gridwise gemm - template + template using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_LARGE_TENSOR_TEMPLATE_PARAMETERS>; #undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_LARGE_TENSOR_TEMPLATE_PARAMETERS - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // desc for blockwise copy using AGridDesc_AK0_M_AK1 = @@ -607,7 +614,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor if(get_warp_size() == 64) { - if constexpr(NXdlPerWave64 > 0) + if constexpr(MXdlPerWave64 > 0) { init_gemm_args(a_grid_ptrs[i], static_cast(p_b), @@ -624,7 +631,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor } else { - if constexpr(NXdlPerWave32 > 0) + if constexpr(MXdlPerWave32 > 0) { init_gemm_args(a_grid_ptrs[i], static_cast(p_b), @@ -769,7 +776,24 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor } } - INVOKER_RUN_IMPL + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + 
if(get_warp_size() == 64) + { + if constexpr(MXdlPerWave64 > 0) + { + return RunImp(arg, stream_config); + } + } + else + { + if constexpr(MXdlPerWave32 > 0) + { + return RunImp(arg, stream_config); + } + } + return 0; + } float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -822,7 +846,10 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor return false; } } - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } @@ -1205,8 +1232,12 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor auto str = std::stringstream(); // clang-format off - str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor" - << "<" + str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"; + if(get_warp_size() != 64) { + str << "_WmmaPorted"; + } + + str << "<" << BlockSize << ", " << MPerBlock << ", " << NPerBlock << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index c09e526526..b6c2030dee 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -206,9 +206,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle MaskingSpec> { static constexpr auto MXdlPerWave64 = - GetNXdlPerWave2(); + GetXdlPerWave2(); static constexpr auto MXdlPerWave32 = - GetNXdlPerWave2(); + GetXdlPerWave2(); static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, "Number of dimension must be greater than 0");