From f5345174e445d67ec694f93dee7908d5a7b04d3e Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Mon, 23 Jun 2025 14:02:03 +0000 Subject: [PATCH] tmp --- .../grouped_convolution_forward.cpp | 2 +- .../grouped_convolution_utils.hpp | 3 +- .../run_grouped_convolution_fwd_example.inc | 4 +- ...ped_convolution_backward_weight_kernel.hpp | 64 ++++++++++++++++--- .../grouped_convolution_forward_kernel.hpp | 11 ++-- .../utils/grouped_convolution_utils.hpp | 39 +++-------- .../transform_conv_bwd_weight_to_gemm.hpp | 22 ++++--- 7 files changed, 89 insertions(+), 56 deletions(-) diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp index bcaea024c7..258851ef57 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp @@ -23,7 +23,7 @@ template , typename DsLayout = ck_tile::tuple<>, typename CDEElementWise = ck_tile::element_wise::PassThrough> -float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s) +float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_tile::stream_config& s) { constexpr int kBlockPerCu = 1; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp index 84a29771e8..3a4cad1ad0 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp @@ -105,4 +105,5 @@ auto create_args(int argc, char* argv[]) } // host API -float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s); +using GroupedConvHostArgs = ck_tile::GroupedConvHostArgs; +float grouped_conv_fwd(const GroupedConvHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_fwd_example.inc b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_fwd_example.inc index ed72eb354d..e3b50b85f5 100644 --- a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_fwd_example.inc +++ b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_fwd_example.inc @@ -32,7 +32,7 @@ template -float invoke_grouped_conv_fwd(ck_tile::GroupedConvHostArgs& args, int n_warmup, int n_repeat) +float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, int n_warmup, int n_repeat) { float ave_time = grouped_conv_fwd args(conv_param, input_dev_buf.GetDeviceBuffer(), weight_dev_buf.GetDeviceBuffer(), {}, diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 446df6e146..b990b648df 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -435,6 +435,47 @@ struct GroupedConvolutionBackwardWeightKernel return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); } + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z) + { + constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); + const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1); + const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.GemmK + K_t - 1) / K_t * K1); + + if constexpr(std::is_same_v) + { + a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); + } + else if constexpr(std::is_same_v) + { + // not supported + } + + if constexpr(std::is_same_v) + { + // not supported + } + else if constexpr(std::is_same_v) + { + b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); + } + + if(k_id < static_cast(kargs.k_batch - 1)) + { + splitted_k = __builtin_amdgcn_readfirstlane(KRead); + } + else + { + splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1)); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t splitted_k; + }; + CK_TILE_HOST static bool IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs) { @@ -566,7 +607,8 @@ struct GroupedConvolutionBackwardWeightKernel const InDataType* b_ptr, const std::array& ds_ptr, WeiDataType* c_ptr, - const GroupedConvBwdWeightKernelArgsSpecialized& kargs) + const GroupedConvBwdWeightKernelArgsSpecialized& kargs, + const SplitKBatchOffset& splitk_batch_offset) { static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!"); @@ -700,6 +742,7 @@ struct GroupedConvolutionBackwardWeightKernel WeiDataType* c_ptr, void* smem_ptr_0, const GroupedConvBwdWeightKernelArgsSpecialized& kargs, + const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n) { @@ -711,8 +754,8 @@ struct GroupedConvolutionBackwardWeightKernel const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const index_t num_loop = - __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.GemmK)); + const index_t num_loop = __builtin_amdgcn_readfirstlane( + TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); // Run GEMM cooperatively by whole workgroup. const auto& a_block_window = gemm_tile_windows.at(I0); @@ -751,6 +794,7 @@ struct GroupedConvolutionBackwardWeightKernel void* __restrict__ smem_ptr_0, void* __restrict__ smem_ptr_1, const GroupedConvBwdWeightKernelArgsSpecialized& kargs, + const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n) { @@ -761,8 +805,8 @@ struct GroupedConvolutionBackwardWeightKernel const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const index_t num_loop = - __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.GemmK)); + const index_t num_loop = __builtin_amdgcn_readfirstlane( + TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); // Run GEMM cooperatively by whole workgroup. const auto& a_block_window = gemm_tile_windows.at(I0); @@ -787,6 +831,8 @@ struct GroupedConvolutionBackwardWeightKernel const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + const SplitKBatchOffset splitk_batch_offset(kargs); + const auto blockIdY = __builtin_amdgcn_readfirstlane(blockIdx.y); const auto group_offset_a = __builtin_amdgcn_readfirstlane(kargs.group_stride_a * blockIdY); const auto group_offset_b = __builtin_amdgcn_readfirstlane(kargs.group_stride_b * blockIdY); @@ -794,8 +840,8 @@ struct GroupedConvolutionBackwardWeightKernel // options // conv_bwd_weight = Out * In = Weight - const OutDataType* a_ptr = static_cast(kargs.out_ptr) + group_offset_a; - const InDataType* b_ptr = static_cast(kargs.in_ptr) + group_offset_b; + const OutDataType* a_ptr = static_cast(kargs.out_ptr) + group_offset_a + splitk_batch_offset.a_k_split_offset; + const InDataType* b_ptr = static_cast(kargs.in_ptr) + group_offset_b + splitk_batch_offset.b_k_split_offset; WeiDataType* c_ptr = static_cast(kargs.wei_ptr) + group_offset_c; // allocate LDS @@ -809,7 +855,7 @@ struct GroupedConvolutionBackwardWeightKernel is_any_of::value)) { RunGemm2LDS( - a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, i_m, i_n); + a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, splitk_batch_offset, i_m, i_n); } } else @@ -818,7 +864,7 @@ struct GroupedConvolutionBackwardWeightKernel EpiloguePipeline::GetVectorSizeC() % 2 != 0 && is_any_of::value)) { - RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n); + RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); } } } diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index 8b73656771..29d871211f 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -34,7 +34,7 @@ struct GroupedConvFwdKernelArgs std::is_same_v && std::is_same_v, bool>::type = false> - CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args) + CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args) { in_g_n_c_wis_lengths = {static_cast(args.G_), static_cast(args.N_), @@ -103,7 +103,7 @@ struct GroupedConvFwdKernelArgs std::is_same_v && std::is_same_v, bool>::type = false> - CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args) + CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args) { in_g_n_c_wis_lengths = {static_cast(args.G_), static_cast(args.N_), @@ -179,7 +179,7 @@ struct GroupedConvFwdKernelArgs std::is_same_v && std::is_same_v, bool>::type = false> - CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args) + CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args) { in_g_n_c_wis_lengths = {static_cast(args.G_), static_cast(args.N_), @@ -366,6 +366,7 @@ struct GroupedConvolutionForwardKernel using OutDataType = remove_cvref_t; using GroupedConvFwdKernelArgsSpecialized = GroupedConvFwdKernelArgs; + using GroupedConvFwdHostArgs = GroupedConvHostArgs; // TODO: Enable this static constexpr bool IsSplitKSupported = false; @@ -388,7 +389,7 @@ struct GroupedConvolutionForwardKernel // clang-format on } - CK_TILE_HOST static constexpr auto GridSize(const GroupedConvHostArgs& args) + CK_TILE_HOST static constexpr auto GridSize(const GroupedConvFwdHostArgs& args) { const index_t GemmM = args.N_ * std::accumulate(args.output_spatial_lengths_.begin(), args.output_spatial_lengths_.end(), @@ -401,7 +402,7 @@ struct GroupedConvolutionForwardKernel CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized - MakeKernelArgs(const GroupedConvHostArgs& hostArgs) + MakeKernelArgs(const GroupedConvFwdHostArgs& hostArgs) { return GroupedConvFwdKernelArgsSpecialized(hostArgs); } diff --git a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp index 81670b453b..698d58ceb2 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp @@ -14,14 +14,15 @@ namespace ck_tile { /// This structure is passed to Grouped Convolution Kernels when creating kernel /// arguments object. It contain all necessary information required to /// build proper kernel argument and launch kernel on GPU. +template struct GroupedConvHostArgs : public conv::ConvParam { CK_TILE_HOST GroupedConvHostArgs() = delete; CK_TILE_HOST GroupedConvHostArgs(ConvParam conv_param, - const void* in_ptr_, - const void* wei_ptr_, + InPtr in_ptr_, + WeiPtr wei_ptr_, const std::vector ds_ptr_, - void* out_ptr_, + OutPtr out_ptr_, index_t k_batch_) : conv::ConvParam(conv_param), in_ptr(in_ptr_), @@ -32,37 +33,15 @@ struct GroupedConvHostArgs : public conv::ConvParam { } - const void* in_ptr; - const void* wei_ptr; + InPtr in_ptr; + WeiPtr wei_ptr; const std::vector ds_ptr; - void* out_ptr; + OutPtr out_ptr; index_t k_batch; }; -struct GroupedConvBwdWeightHostArgs : public conv::ConvParam -{ - CK_TILE_HOST GroupedConvBwdWeightHostArgs() = delete; - CK_TILE_HOST GroupedConvBwdWeightHostArgs(ConvParam conv_param, - const void* in_ptr_, - void* wei_ptr_, - const std::vector ds_ptr_, - const void* out_ptr_, - index_t k_batch_) - : conv::ConvParam(conv_param), - in_ptr(in_ptr_), - wei_ptr(wei_ptr_), - ds_ptr(ds_ptr_), - out_ptr(out_ptr_), - k_batch(k_batch_) - { - } - - const void* in_ptr; - void* wei_ptr; - const std::vector ds_ptr; - const void* out_ptr; - index_t k_batch; -}; +using GroupedConvFwdHostArgs = GroupedConvHostArgs; +using GroupedConvBwdWeightHostArgs = GroupedConvHostArgs; template ::type = false> - CK_TILE_HOST auto make_out_grid_desc() const + CK_TILE_HOST auto make_out_grid_desc(const index_t GemmKBatch) const { // NWGK const index_t NDoHoWoStride = G_ * K_; @@ -423,7 +423,7 @@ struct TransformConvBwdWeightToGemm // TODO Add support for NumGroupsToMerge > 1 - return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_), + return make_naive_tensor_descriptor(make_tuple(N_ * Wo / GemmKBatch_, K_), make_tuple(NDoHoWoStride, KStride)); } @@ -538,23 +538,22 @@ struct TransformConvBwdWeightToGemm // properties template ::type = false> - CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N() const + CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(const index_t GemmKBatch) const { // Assume NumGroupsToMerge == 1 for now - const index_t GemmKTotal = N_ * Wo_; + const index_t GemmKTotal = N_ * Wo_ / KBatch; // tmp const index_t GemmM = K_ * NumGroupsToMerge; const index_t GemmN = C_ * X_ * NumGroupsToMerge; const auto PadGemmM = MPerBlock - GemmM % MPerBlock; const auto PadGemmN = NPerBlock - GemmN % NPerBlock; - const index_t GemmKBatch = 1; const index_t GemmK0 = integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; - const auto out_grid_desc = make_out_grid_desc(); - const auto in_grid_desc = make_in_grid_desc(); + const auto out_grid_desc = make_out_grid_desc(GemmKBatch); + const auto in_grid_desc = make_in_grid_desc(GemmKBatch); const auto wei_grid_desc = make_wei_grid_desc(); // A: output tensor comes in K_M @@ -597,6 +596,13 @@ struct TransformConvBwdWeightToGemm make_tuple(sequence<0>{}, sequence<1>{}), make_tuple(sequence<1>{}, sequence<0>{})); + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + const auto wei_gemmm_gemmn_pad_grid_desc = transform_tensor_descriptor(wei_grid_desc, make_tuple(make_right_pad_transform(GemmM, PadGemmM), @@ -604,7 +610,7 @@ struct TransformConvBwdWeightToGemm make_tuple(sequence<0>{}, sequence<1>{}), make_tuple(sequence<0>{}, sequence<1>{})); - return make_tuple(out_gemmkpad_gemmm_grid_desc, + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkpad_gemmn_grid_desc, wei_gemmm_gemmn_pad_grid_desc); }