diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp index 0891e8c20b..afe43cd1c0 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp @@ -17,8 +17,8 @@ struct GroupedConvolutionBackwardWeightInvoker typename DsDataType = ck_tile::tuple<>, typename DsLayout = ck_tile::tuple<>, typename CDEElementWise = ck_tile::element_wise::PassThrough> - static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, - const ck_tile::stream_config& s) + static InvokerResult grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, + const ck_tile::stream_config& s) { // Implicit GEMM Traits using GemmShape = ck_tile::TileGemmShape< @@ -105,9 +105,9 @@ struct GroupedConvolutionBackwardWeightInvoker TilePartitioner, GemmPipeline, ConvEpilogue>; - auto kargs = Kernel::MakeKernelArgs(args); + const auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args); + const dim3 grids = Kernel::GridSize(kargs); const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) @@ -130,7 +130,7 @@ struct GroupedConvolutionBackwardWeightInvoker } auto preprocess = [&]() { - if(args.k_batch > 1) + if(kargs.k_batch > 1) { ck_tile::hip_check_error( hipMemsetAsync(kargs.wei_ptr, @@ -140,10 +140,14 @@ struct GroupedConvolutionBackwardWeightInvoker } }; - return ck_tile::launch_kernel_time_mask( + const auto ave_time = ck_tile::launch_kernel_time_mask( s, preprocess, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + const auto split_k = kargs.k_batch; + + return InvokerResult{ave_time, split_k}; }; if(args.k_batch == 1) diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp index 50c0ce4f87..9221746560 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp @@ -17,8 +17,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker typename DsDataType = ck_tile::tuple<>, typename DsLayout = ck_tile::tuple<>, typename CDEElementWise = ck_tile::element_wise::PassThrough> - static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, - const ck_tile::stream_config& s) + static InvokerResult grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, + const ck_tile::stream_config& s) { using WorkspaceDataType = float; @@ -118,9 +118,9 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker sizeof(WorkspaceDataType)); ck_tile::GroupedConvBwdWeightHostArgs ws_args = ck_tile::GroupedConvBwdWeightHostArgs(args); - auto c_ptr = ws_args.wei_ptr; - ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); - auto kargs = Kernel::MakeKernelArgs(ws_args); + auto c_ptr = ws_args.wei_ptr; + ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); + const auto kargs = Kernel::MakeKernelArgs(ws_args); const dim3 grids = Kernel::GridSize(kargs); const dim3 blocks = Kernel::BlockSize(); @@ -184,7 +184,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker } auto preprocess = [&]() { - if(args.k_batch > 1) + if(kargs.k_batch > 1) ck_tile::hip_check_error( hipMemsetAsync(ws_args.wei_ptr, 0, @@ -192,7 +192,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker s.stream_id_)); }; - return ck_tile::launch_kernel_time_mask( + const auto ave_time = ck_tile::launch_kernel_time_mask( s, preprocess, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs), @@ -206,6 +206,10 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker ck_tile::make_tuple(shape[1], 1), // Output Stride input_tensors, static_cast(c_ptr))); + + const auto split_k = kargs.k_batch; + + return InvokerResult{ave_time, split_k}; }; if(args.k_batch == 1) diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp index b687e0a660..63dd54dcae 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp @@ -132,3 +132,9 @@ auto create_args(int argc, char* argv[]) bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); } + +struct InvokerResult +{ + float ave_time; + ck_tile::index_t split_k; +}; diff --git a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc index 2496a1b0d2..b0a140993a 100644 --- a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc +++ b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc @@ -14,22 +14,22 @@ template -float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args, - int n_warmup, - int n_repeat) +InvokerResult invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args, + int n_warmup, + int n_repeat) { - float ave_time = Invoker::template grouped_conv_bwd_weight( + auto res = Invoker::template grouped_conv_bwd_weight( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); - return ave_time; + return res; } template (args, n_warmup, n_repeat); + auto res = invoke_grouped_conv_bwd_weight(args, n_warmup, n_repeat); + const float ave_time = res.ave_time; weight_dev_buf.FromDevice(weight.data()); @@ -172,9 +173,11 @@ int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_par const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_); const float max_accumulated_value = *std::max_element(weight_host_ref.mData.begin(), weight_host_ref.mData.end()); + + const ck_tile::index_t split_k = res.split_k; const auto rtol_atol = calculate_rtol_atol( - GemmK, kbatch, max_accumulated_value); + GemmK, split_k, max_accumulated_value); pass = ck_tile::check_err(weight, weight_host_ref, "Error: Incorrect results!", diff --git a/include/ck_tile/host/device_prop.hpp b/include/ck_tile/host/device_prop.hpp index 2d7dc7dd18..e95ccfcfb4 100644 --- a/include/ck_tile/host/device_prop.hpp +++ b/include/ck_tile/host/device_prop.hpp @@ -70,6 +70,24 @@ inline bool is_load_tr_supported() // Check if load transpose is supported. return get_device_name() == "gfx950"; } + +inline size_t get_num_cus() +{ + hipDeviceProp_t props{}; + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) + { + return 0; + } + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) + { + return 0; + } + return static_cast(props.multiProcessorCount); +} + } // namespace ck_tile #endif diff --git a/include/ck_tile/ops/grouped_convolution.hpp b/include/ck_tile/ops/grouped_convolution.hpp index 1dd13b6246..23a72d79e9 100644 --- a/include/ck_tile/ops/grouped_convolution.hpp +++ b/include/ck_tile/ops/grouped_convolution.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp" #include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp" #include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" +#include "ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp" #include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp" diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index f43bfdacac..c9e81d4744 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -14,6 +14,8 @@ #include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" +#include "ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp" + #ifdef CK_EXPERIMENTAL_BUILDER #include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp" #endif @@ -62,8 +64,6 @@ struct GroupedConvBwdWeightKernelArgs input_left_pads = {static_cast(args.input_left_pads_[0])}; input_right_pads = {static_cast(args.input_right_pads_[0])}; - k_batch = args.k_batch; - in_ptr = args.in_ptr; wei_ptr = args.wei_ptr; for(index_t d = 0; d < NumDTensor; d++) @@ -104,11 +104,14 @@ struct GroupedConvBwdWeightKernelArgs GemmK = a_grid_desc_k_m.get_length(number<0>{}); GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch); + k_batch = args.k_batch; + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK << ", GemmBatch: " << GemmBatch - << ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl; + << ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch + << std::endl; } } @@ -147,8 +150,6 @@ struct GroupedConvBwdWeightKernelArgs input_right_pads = {static_cast(args.input_right_pads_[0]), static_cast(args.input_right_pads_[1])}; - k_batch = args.k_batch; - in_ptr = args.in_ptr; wei_ptr = args.wei_ptr; for(index_t d = 0; d < NumDTensor; d++) @@ -189,11 +190,14 @@ struct GroupedConvBwdWeightKernelArgs GemmK = a_grid_desc_k_m.get_length(number<0>{}); GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch); + k_batch = args.k_batch; + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK << ", GemmBatch: " << GemmBatch - << ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl; + << ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch + << std::endl; } } @@ -239,8 +243,6 @@ struct GroupedConvBwdWeightKernelArgs static_cast(args.input_right_pads_[1]), static_cast(args.input_right_pads_[2])}; - k_batch = args.k_batch; - in_ptr = args.in_ptr; wei_ptr = args.wei_ptr; for(index_t d = 0; d < NumDTensor; d++) @@ -281,11 +283,14 @@ struct GroupedConvBwdWeightKernelArgs GemmK = a_grid_desc_k_m.get_length(number<0>{}); GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch); + k_batch = args.k_batch; + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK << ", GemmBatch: " << GemmBatch - << ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl; + << ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch + << std::endl; } } @@ -398,7 +403,6 @@ struct GroupedConvolutionBackwardWeightKernel using GroupedConvBwdWeightKernelArgsSpecialized = GroupedConvBwdWeightKernelArgs; - // TODO: Enable this static constexpr bool IsSplitKSupported = true; static constexpr auto I0 = number<0>(); @@ -476,7 +480,24 @@ struct GroupedConvolutionBackwardWeightKernel std::cout << "NPerBlock: " << number{} << std::endl; std::cout << "KPerBlock: " << number{} << std::endl; } - return GroupedConvBwdWeightKernelArgsSpecialized(hostArgs); + + auto kernel_args = GroupedConvBwdWeightKernelArgsSpecialized(hostArgs); + + using KernelImpl = GroupedConvolutionBackwardWeightKernel; + + // Negative k_batch value: split-K autodeduction. + if(kernel_args.k_batch < 0) + { + const auto optimal_split_k = + calculate_optimal_k_batch( + kernel_args); + kernel_args.k_batch = optimal_split_k; + } + + return kernel_args; } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() @@ -514,15 +535,54 @@ struct GroupedConvolutionBackwardWeightKernel CK_TILE_HOST static bool IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs) { + if(kargs.k_batch < 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "k_batch must be at least one. Ensure argument is created via MakeKernelArgs."); + } + return false; + } + + if constexpr(EpiloguePipeline_::MemoryOperation == memory_operation_enum::atomic_add) + { + if(kargs.k_batch == 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Atomic add epilogue only supports k_batch > 1."); + } + return false; + } + } + + if constexpr(!std::is_same_v && + !std::is_same_v) + { + // The epilogue performs atomic add related to split-K using the ODataType. + // If the type is less accurate than float, large split-K values may lead to + // accuracy issues. Hence, we limit the maximum split-K value to 128 in such cases. + if(kargs.k_batch > 128) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "For epilogue output data type that is not float/double, we must have " + "k_batch <= 128."); + } + return false; + } + } + if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 && - is_any_of::value) || - !IsSplitKSupported) + is_any_of::value)) { if(kargs.k_batch != 1) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { - CK_TILE_ERROR("Conditions not met for Kbatch >1 !"); + CK_TILE_ERROR("Conditions not met for K_batch > 1!"); } return false; } diff --git a/include/ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp new file mode 100644 index 0000000000..072134dbe7 --- /dev/null +++ b/include/ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp @@ -0,0 +1,81 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once +#include + +#include "ck_tile/core/utility/env.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/host/device_prop.hpp" +#include "ck_tile/host/kernel_launch.hpp" + +namespace ck_tile { + +template +CK_TILE_HOST index_t get_max_occupancy_for_kernel() +{ + constexpr int dynamic_smem_size = 0; + constexpr int min_blocks_per_cu = 1; + + const auto kernel_ptr = kentry; + + int max_occupancy = 0; + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, kernel_ptr, BlockSize, dynamic_smem_size)); + + return static_cast(max_occupancy); +} + +CK_TILE_HOST index_t get_best_occupancy_k_batch_value(index_t max_occupancy, index_t grid_size) +{ + static const index_t num_cus = get_num_cus(); + const index_t max_capacity = max_occupancy * num_cus; + + index_t k_batch = 1; + const auto optimal_split = static_cast(std::floor((1.0 * max_capacity) / grid_size)); + if(optimal_split > 1) + { + k_batch = optimal_split; + } + + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + std::cout << "[SPLIT-K AUTODEDUCE] Max active thread blocks per CU for GEMM kernel: " + << max_occupancy << std::endl; + std::cout << "[SPLIT-K AUTODEDUCE] Output grid size: " << grid_size << std::endl; + std::cout << "[SPLIT-K AUTODEDUCE] Optimal split-k value " << k_batch << std::endl; + } + return k_batch; +} + +template +struct ActiveWorkgroupsPerCU +{ + CK_TILE_HOST ActiveWorkgroupsPerCU() + { + max_occupancy_ = get_max_occupancy_for_kernel(); + } + index_t max_occupancy_{1}; +}; + +template +CK_TILE_HOST index_t calculate_optimal_k_batch(const KernelArgs& kargs) +{ + static ActiveWorkgroupsPerCU active_workgroups_per_cu; + + const auto grid_size = TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN) * kargs.GemmBatch; + auto optimal_k_batch = + get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, grid_size); + + const auto max_allowed_k_batch = kargs.GemmK; + optimal_k_batch = std::min(optimal_k_batch, max_allowed_k_batch); + + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " << optimal_k_batch << std::endl; + } + + return optimal_k_batch; +} + +} // namespace ck_tile diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 6378bb8e43..197c9d6e1d 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -38,3 +38,4 @@ add_subdirectory(atomic_add_op) add_subdirectory(fmha) add_subdirectory(gemm_tile_engine) add_subdirectory(pooling) +add_subdirectory(grouped_conv) diff --git a/test/ck_tile/grouped_conv/CMakeLists.txt b/test/ck_tile/grouped_conv/CMakeLists.txt new file mode 100644 index 0000000000..5bc10ffddd --- /dev/null +++ b/test/ck_tile/grouped_conv/CMakeLists.txt @@ -0,0 +1,7 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# Currently ck_tile is only built on gfx9 +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") + add_gtest_executable(test_ck_tile_grouped_conv_bwd_weight test_ck_tile_grouped_conv_bwd_weight.cpp) +endif() diff --git a/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp b/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp new file mode 100644 index 0000000000..f37065f7c7 --- /dev/null +++ b/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp @@ -0,0 +1,249 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "gtest/gtest.h" +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" +#include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp" + +using namespace ck_tile; + +struct TestConvConfig +{ + static constexpr index_t VectorSizeA = 4; + static constexpr index_t VectorSizeB = 8; + static constexpr index_t VectorSizeC = 8; + + static constexpr index_t M_Tile = 128; + static constexpr index_t N_Tile = 128; + static constexpr index_t K_Tile = 32; + + static constexpr index_t M_Warp = 2; + static constexpr index_t N_Warp = 2; + static constexpr index_t K_Warp = 1; + + static constexpr index_t M_Warp_Tile = 16; + static constexpr index_t N_Warp_Tile = 16; + static constexpr index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr GemmPipeline Pipeline = GemmPipeline::COMPUTE_V3; + static constexpr index_t NumWaveGroups = 1; + static constexpr index_t NumGroupsToMerge = 1; + static constexpr auto Scheduler = GemmPipelineScheduler::Intrawave; +}; + +// Helper to build full kernel type +template +struct BuildKernel +{ + using GemmShape = TileGemmShape< + sequence, + sequence, + sequence>; + + using ConvTraits = GroupedConvTraits, + OutLayout, + ConvConfig::VectorSizeA, + ConvConfig::VectorSizeB, + ConvConfig::VectorSizeC, + ConvConfig::NumGroupsToMerge>; + + using TilePartitioner = GemmSpatiallyLocalTilePartitioner; + + using GemmUniversalTraits = + TileGemmUniversalTraits; + + using GemmPipelineProblem = + GemmPipelineProblem, + element_wise::PassThrough, + element_wise::PassThrough, + PrecType, // WeiDataType (C in bwd weight) + ConvTraits::FixedGemmParams::FixedVectorSize, + ConvTraits::VectorSizeA, + ConvTraits::VectorSizeB>; + + using UniversalGemmProblem = + UniversalGemmPipelineProblem; + + using GemmPipeline = GemmPipelineAgBgCrCompV3; + + using EpilogueProblem = CShuffleEpilogueProblem, + float, + PrecType, + typename ConvTraits::ImplicitGemmDsLayout, + typename ConvTraits::FixedGemmParams::ELayout, + element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + ConvConfig::M_Warp, + ConvConfig::N_Warp, + ConvConfig::M_Warp_Tile, + ConvConfig::N_Warp_Tile, + ConvConfig::K_Warp_Tile, + ConvTraits::FixedGemmParams::TransposeC, + MemOp, + ConvConfig::NumWaveGroups, + ConvTraits::FixedGemmParams::FixedVectorSize, + ConvTraits::VectorSizeC>; + + using Epilogue = CShuffleEpilogue; + + using type = + GroupedConvolutionBackwardWeightKernel; +}; + +// Helper to create 2D host args +static GroupedConvBwdWeightHostArgs create_2d_host_args(index_t G, + index_t N, + index_t K, + index_t C, + index_t Y, + index_t X, + index_t Hi, + index_t Wi, + index_t stride_y, + index_t stride_x, + index_t dilation_y, + index_t dilation_x, + index_t left_pad_y, + index_t left_pad_x, + index_t right_pad_y, + index_t right_pad_x, + index_t k_batch = 1) +{ + auto conv_param = conv::ConvParam{2, + G, + N, + K, + C, + {Y, X}, + {Hi, Wi}, + {stride_y, stride_x}, + {dilation_y, dilation_x}, + {left_pad_y, left_pad_x}, + {right_pad_y, right_pad_x}}; + + return GroupedConvBwdWeightHostArgs{conv_param, nullptr, nullptr, {}, nullptr, k_batch}; +} + +static GroupedConvBwdWeightHostArgs create_2d_host_args(index_t k_batch) +{ + return create_2d_host_args(2, 2, 8, 8, 3, 3, 7, 7, 1, 1, 1, 1, 1, 1, 1, 1, k_batch); +} + +class GroupedConvBwdWeightIsSupportedArgumentTest : public ::testing::Test +{ +}; + +TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, ValidKBatch) +{ + using Kernel = typename BuildKernel::type; + + auto host_args_kbatch_1 = create_2d_host_args(1); + auto kargs_1 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_1); + EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_1)); + + auto host_args_kbatch_4 = create_2d_host_args(4); + auto kargs_4 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_4); + EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_4)); +} + +TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, InvalidKBatchLessThanOne) +{ + using Kernel = typename BuildKernel::type; + + auto host_args_kbatch_0 = create_2d_host_args(0); + auto kargs = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_0); + EXPECT_FALSE(Kernel::IsSupportedArgument(kargs)); +} + +TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, AtomicAddRequiresKBatchGreaterThanOne) +{ + using Kernel = typename BuildKernel::type; + + // k_batch = 1 should fail with atomic_add + auto host_args_kbatch_1 = create_2d_host_args(1); + auto kargs_1 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_1); + EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_1)); + + // k_batch = 2 should pass + auto host_args_kbatch_2 = create_2d_host_args(2); + auto kargs_2 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_2); + EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_2)); +} + +TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, NonFloatDoubleOutputLimitsKBatch) +{ + using Kernel = typename BuildKernel::type; + + // k_batch = 128 should pass + auto host_args_kbatch_128 = create_2d_host_args(128); + auto kargs_128 = + typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_128); + EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_128)); + + // k_batch = 129 should fail for half_t output + auto host_args_kbatch_129 = create_2d_host_args(129); + auto kargs_129 = + typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_129); + EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_129)); +}