From d514c8fca8b3b0e749b8c7ced06debbcb93c9235 Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 3 Feb 2026 13:04:31 +0000 Subject: [PATCH] Revert "Reverted unused device impl and updated macros" This reverts commit 845e14d7305c7acb8dda5b54ea0b6d11fafbafa1. --- ...v_bwd_data_multiple_d_wmma_cshuffle_v3.hpp | 2 +- ...bwd_weight_multiple_d_wmma_cshuffle_v3.hpp | 2 +- ...ouped_conv_bwd_weight_wmma_cshuffle_v3.hpp | 181 ++++++++++++------ ...d_data_wmma_v3_wave_transfer_instances.hpp | 2 +- ..._weight_v3_wmma_wave_transfer_instance.hpp | 2 +- .../grouped_convolution_backward_weight.hpp | 4 +- ...ouped_convolution_backward_weight_wmma.inc | 2 +- 7 files changed, 126 insertions(+), 69 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp index 6b635b6a23..f4f64f6e77 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -282,7 +282,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 static_assert(NDimSpatial == 2 || NDimSpatial == 3, "wrong! only implemented for 2D and 3D now"); -#ifdef USE_WAVE_TRANSFER_BWD_DATA +#ifdef USE_WAVE_TRANSFER static_assert(UseThreadTileTransfer == false && (ConvBackwardDataSpecialization == diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index 88cdef7548..208855149e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -179,7 +179,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 ComputeTypeB> { -#if defined USE_WAVE_TRANSFER_BWD_WEI +#if defined USE_WAVE static_assert(UseThreadTileTransfer == false && (ConvBackwardWeightSpecialization == diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index b2ae092c27..e60f95d9cb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -41,8 +41,8 @@ namespace tensor_operation { namespace device { template (); + using EpilogueType = + typename std::conditional::type; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = EpilogueType{}; - GridwiseGemm::template Run() || is_NGCDHW_NGKDHW()) || @@ -293,6 +314,33 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 batch); } + template + static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1) + { + const auto grid_desc_m_k = transform_tensor_descriptor( + desc_k0_m_k1, + make_tuple(make_pass_through_transform(desc_k0_m_k1.GetLength(I1)), + make_merge_transform( + make_tuple(desc_k0_m_k1.GetLength(I0), desc_k0_m_k1.GetLength(I2)))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return grid_desc_m_k; + } + template + static auto transform_k0_m_k1_to_n_k(const Desc_K0_N_K1& desc_k0_n_k1) + { + const auto grid_desc_n_k = transform_tensor_descriptor( + desc_k0_n_k1, + make_tuple(make_pass_through_transform(desc_k0_n_k1.GetLength(I1)), + make_merge_transform( + make_tuple(desc_k0_n_k1.GetLength(I0), desc_k0_n_k1.GetLength(I2)))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return grid_desc_n_k; + } + using NGCHWTransposeDescType = remove_cvref_t({}, {}))>; @@ -308,9 +356,12 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 using ABCGridDescs = decltype(GetABCGridDesc()); - using AGridDesc_K0_M_K1 = remove_cvref_t; - using BGridDesc_K0_N_K1 = remove_cvref_t; - using CGridDesc_M_N = remove_cvref_t; + using AGridDesc_M_K_ = remove_cvref_t; + using BGridDesc_N_K_ = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_M_K_{})); + using BGridDesc_N_K = decltype(transform_k0_m_k1_to_n_k(BGridDesc_N_K_{})); using Block2TileMapTranspose = BlockToCTileMap_M00_N0_M01Adapt; @@ -401,10 +452,10 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, - false, // PermuteA - false, // permuteB - false, // IsBPreshuffle - true>; // ForceThreadTileTransfer + false, // PermuteA + false, // permuteB + false, // IsBPreshuffle + UseThreadTileTransfer>; // ForceThreadTileTransfer // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = @@ -434,8 +485,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 &max_occupancy, kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t, ComputePtrOffsetOfStridedBatch, true, @@ -473,8 +524,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 : p_a_grid_{p_out_grid}, p_b_grid_{p_in_grid}, p_c_grid_{p_wei_grid}, - a_grid_desc_kbatch_k0_m_k1_{}, - b_grid_desc_kbatch_k0_n_k1_{}, + a_grid_desc_kbatch_m_k_{}, + b_grid_desc_kbatch_n_k_{}, c_grid_desc_m_n_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{}, compute_ptr_offset_of_batch_{}, @@ -572,16 +623,16 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 input_right_pads, k_batch_); - a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; - b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; - c_grid_desc_m_n_ = descs[I2]; + a_grid_desc_kbatch_m_k_ = transform_k0_m_k1_to_m_k(descs[I0]); + b_grid_desc_kbatch_n_k_ = transform_k0_m_k1_to_n_k(descs[I1]); + c_grid_desc_m_n_ = descs[I2]; // A/B/C Batch Stride compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0]; compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0]; compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0]; - const index_t GemmM = a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + const index_t GemmM = a_grid_desc_kbatch_m_k_.GetLength(I0); + const index_t GemmN = b_grid_desc_kbatch_n_k_.GetLength(I0); c_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -678,8 +729,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 const ADataType* p_a_grid_; const BDataType* p_b_grid_; CDataType* p_c_grid_; - AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + AGridDesc_M_K a_grid_desc_kbatch_m_k_; + BGridDesc_N_K b_grid_desc_kbatch_n_k_; CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; @@ -724,17 +775,15 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 void ShowInfo(const Argument& arg) { - std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + std::cout << "arg.a_grid_desc_kbatch_m_k_{" << arg.a_grid_desc_kbatch_m_k_.GetLength(I0) + << ", " << arg.a_grid_desc_kbatch_m_k_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_m_k_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_m_k_.GetLength(I3) << "}" << std::endl; - std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + std::cout << "arg.b_grid_desc_kbatch_n_k_{" << arg.b_grid_desc_kbatch_n_k_.GetLength(I0) + << ", " << arg.b_grid_desc_kbatch_n_k_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_n_k_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_n_k_.GetLength(I3) << "}" << std::endl; std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; @@ -744,10 +793,9 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { float ave_time = 0; - const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); - const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * - arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); + const index_t GemmM = arg.a_grid_desc_kbatch_m_k_.GetLength(I0); + const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0); + const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1); const ADataType* p_a_grid = arg.p_a_grid_; const BDataType* p_b_grid = arg.p_b_grid_; @@ -839,10 +887,14 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * KPerBlock; const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - const auto num_k_per_block = - arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; + std::cout << "K0 value is:" + << (GridwiseGemm::CalculateAK0Padded( + arg.a_grid_desc_kbatch_m_k_.GetLength(Number<1>{}), arg.k_batch_)) + << std::endl; - const auto clear_workspace = [&]() { + const index_t num_k_per_block = (GridwiseGemm::CalculateAK0Padded( + arg.a_grid_desc_kbatch_m_k_.GetLength(Number<1>{}), arg.k_batch_)); + const auto clear_workspace = [&]() { hip_check_error( hipMemsetAsync(p_e_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_)); }; @@ -855,11 +907,11 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; std::array size_as_buffers; - size_as_buffers[0] = arg.a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize() * + size_as_buffers[0] = arg.a_grid_desc_kbatch_m_k_.GetElementSpaceSize() * sizeof(ADataType) / GridwiseGemm::APackedSize; std::array size_bs_buffers; - size_bs_buffers[0] = arg.b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize() * + size_bs_buffers[0] = arg.b_grid_desc_kbatch_n_k_.GetElementSpaceSize() * sizeof(BDataType) / GridwiseGemm::BPackedSize; std::array size_ds_buffers; @@ -889,8 +941,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 dim3(BlockSize), 0, gemm_arg_, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, + arg.a_grid_desc_kbatch_m_k_, + arg.b_grid_desc_kbatch_n_k_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, num_k_per_block); @@ -905,8 +957,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 dim3(BlockSize), 0, gemm_arg, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, + arg.a_grid_desc_kbatch_m_k_, + arg.b_grid_desc_kbatch_n_k_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, num_k_per_block); @@ -926,8 +978,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -940,8 +992,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -965,8 +1017,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -979,8 +1031,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -1042,10 +1094,15 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { - const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); - const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * - arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); +#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(arg.k_batch_ < 0) + { + return false; + } +#endif + const index_t GemmM = arg.a_grid_desc_kbatch_m_k_.GetLength(I0); + const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0); + const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1); typename GridwiseGemm::Argument gemm_arg{std::array{nullptr}, // p_as_grid std::array{nullptr}, // p_bs_grid diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp index 6bc0ff8b4f..4c0f9e5276 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp @@ -14,7 +14,7 @@ namespace tensor_operation { namespace device { namespace instance { -#define USE_WAVE_TRANSFER_BWD_DATA +#define USE_WAVE_TRANSFER using BF16 = ck::bhalf_t; using F16 = ck::half_t; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp index 4984ac1cec..5ca4f4d83d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp @@ -1,6 +1,6 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#define USE_WAVE_TRANSFER_BWD_WEI +#define USE_WAVE #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index 56f511bc89..c31279d002 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -12,7 +12,7 @@ #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#if defined USE_WAVE_TRANSFER_BWD_WEI +#if defined USE_WAVE #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight_multiple_d.hpp" #endif @@ -962,7 +962,7 @@ struct DeviceOperationInstanceFactory>>& instances); #endif -#if defined USE_WAVE_TRANSFER_BWD_WEI +#if defined USE_WAVE #ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instances(