[CK_TILE] Split-K autodeduction (#3351)

* First version of split-K autodeduction.

* Fix circular dependency and kernel construction.

* Fix tolerance calculation for bwd weight example.

* Simplify kernel construction.

* Fix kernel launching bug for split-K autodeduce.

* Add split-K autodeduction support for the two stage example.

* Fix a corner case.

* Fix clang-format.

* Fix clang-format for inc files.

* Add missing header.

* Prevent too large split-K values.

* Fix formatting.

* Add unit tests for IsSupportedArgument in grouped bwd conv.

* clang-format.

* Fix merge conflicts.

* Address feedback from code review.

* clang-format

* Fix new tests after merge.

---------

Co-authored-by: Ville Pietilä <>

[ROCm/composable_kernel commit: fc22320d78]
This commit is contained in:
Ville Pietilä
2025-12-10 09:30:30 +02:00
committed by GitHub
parent 822da5d3a7
commit d719c09343
11 changed files with 485 additions and 51 deletions

View File

@@ -17,8 +17,8 @@ struct GroupedConvolutionBackwardWeightInvoker
typename DsDataType = ck_tile::tuple<>,
typename DsLayout = ck_tile::tuple<>,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
const ck_tile::stream_config& s)
static InvokerResult grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
const ck_tile::stream_config& s)
{
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
@@ -105,9 +105,9 @@ struct GroupedConvolutionBackwardWeightInvoker
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args);
const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
@@ -130,7 +130,7 @@ struct GroupedConvolutionBackwardWeightInvoker
}
auto preprocess = [&]() {
if(args.k_batch > 1)
if(kargs.k_batch > 1)
{
ck_tile::hip_check_error(
hipMemsetAsync(kargs.wei_ptr,
@@ -140,10 +140,14 @@ struct GroupedConvolutionBackwardWeightInvoker
}
};
return ck_tile::launch_kernel_time_mask(
const auto ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
const auto split_k = kargs.k_batch;
return InvokerResult{ave_time, split_k};
};
if(args.k_batch == 1)

View File

@@ -17,8 +17,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
typename DsDataType = ck_tile::tuple<>,
typename DsLayout = ck_tile::tuple<>,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
const ck_tile::stream_config& s)
static InvokerResult grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
const ck_tile::stream_config& s)
{
using WorkspaceDataType = float;
@@ -118,9 +118,9 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
sizeof(WorkspaceDataType));
ck_tile::GroupedConvBwdWeightHostArgs ws_args =
ck_tile::GroupedConvBwdWeightHostArgs(args);
auto c_ptr = ws_args.wei_ptr;
ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
auto kargs = Kernel::MakeKernelArgs(ws_args);
auto c_ptr = ws_args.wei_ptr;
ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
const auto kargs = Kernel::MakeKernelArgs(ws_args);
const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();
@@ -184,7 +184,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
}
auto preprocess = [&]() {
if(args.k_batch > 1)
if(kargs.k_batch > 1)
ck_tile::hip_check_error(
hipMemsetAsync(ws_args.wei_ptr,
0,
@@ -192,7 +192,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
s.stream_id_));
};
return ck_tile::launch_kernel_time_mask(
const auto ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
@@ -206,6 +206,10 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
ck_tile::make_tuple(shape[1], 1), // Output Stride
input_tensors,
static_cast<WeiDataType*>(c_ptr)));
const auto split_k = kargs.k_batch;
return InvokerResult{ave_time, split_k};
};
if(args.k_batch == 1)

View File

@@ -132,3 +132,9 @@ auto create_args(int argc, char* argv[])
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
/// Result of a single grouped-conv bwd-weight invoker run: the measured
/// kernel time and the split-K (k_batch) value the kernel actually ran with
/// (relevant when split-K is autodeduced from a negative k_batch request).
/// Default member initializers prevent indeterminate values on default
/// construction; aggregate initialization `InvokerResult{t, k}` still works
/// in C++17.
struct InvokerResult
{
    // Average kernel time as reported by the launch helper
    // (presumably milliseconds — TODO confirm against launch_kernel_time_mask).
    float ave_time{0.f};
    // k_batch used for the launch; >= 1 after autodeduction.
    ck_tile::index_t split_k{1};
};

View File

@@ -14,22 +14,22 @@ template <ck_tile::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout>
float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args,
int n_warmup,
int n_repeat)
InvokerResult invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args,
int n_warmup,
int n_repeat)
{
float ave_time = Invoker::template grouped_conv_bwd_weight<NDimSpatial,
ConvConfig,
InDataType,
WeiDataType,
AccDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(
auto res = Invoker::template grouped_conv_bwd_weight<NDimSpatial,
ConvConfig,
InDataType,
WeiDataType,
AccDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
return ave_time;
return res;
}
template <ck_tile::index_t NDimSpatial,
@@ -132,16 +132,17 @@ int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_par
std::cout << "weight: " << weight.mDesc << std::endl;
std::cout << "output: " << output.mDesc << std::endl;
float ave_time = invoke_grouped_conv_bwd_weight<NDimSpatial,
ConvConfig,
Invoker,
InDataType,
WeiDataType,
AccDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(args, n_warmup, n_repeat);
auto res = invoke_grouped_conv_bwd_weight<NDimSpatial,
ConvConfig,
Invoker,
InDataType,
WeiDataType,
AccDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(args, n_warmup, n_repeat);
const float ave_time = res.ave_time;
weight_dev_buf.FromDevice(weight.data());
@@ -172,9 +173,11 @@ int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_par
const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_);
const float max_accumulated_value =
*std::max_element(weight_host_ref.mData.begin(), weight_host_ref.mData.end());
const ck_tile::index_t split_k = res.split_k;
const auto rtol_atol =
calculate_rtol_atol<InDataType, WeiDataType, AccDataType, OutDataType>(
GemmK, kbatch, max_accumulated_value);
GemmK, split_k, max_accumulated_value);
pass = ck_tile::check_err(weight,
weight_host_ref,
"Error: Incorrect results!",

View File

@@ -70,6 +70,24 @@ inline bool is_load_tr_supported()
// Check if load transpose is supported.
return get_device_name() == "gfx950";
}
// Returns the number of compute units (multiprocessors) on the currently
// selected HIP device, or 0 if either runtime query fails. Callers must be
// prepared for the 0 fallback (e.g. occupancy-based heuristics degrade to
// their defaults).
inline size_t get_num_cus()
{
    int device_id = 0;
    if(hipGetDevice(&device_id) != hipSuccess)
    {
        return 0;
    }

    hipDeviceProp_t device_props{};
    if(hipGetDeviceProperties(&device_props, device_id) != hipSuccess)
    {
        return 0;
    }

    return static_cast<size_t>(device_props.multiProcessorCount);
}
} // namespace ck_tile
#endif

View File

@@ -8,6 +8,7 @@
#include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp"
#include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp"
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
#include "ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp"
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp"
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp"
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp"

View File

@@ -14,6 +14,8 @@
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp"
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
#include "ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp"
#endif
@@ -62,8 +64,6 @@ struct GroupedConvBwdWeightKernelArgs
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
k_batch = args.k_batch;
in_ptr = args.in_ptr;
wei_ptr = args.wei_ptr;
for(index_t d = 0; d < NumDTensor; d++)
@@ -104,11 +104,14 @@ struct GroupedConvBwdWeightKernelArgs
GemmK = a_grid_desc_k_m.get_length(number<0>{});
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
k_batch = args.k_batch;
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
<< ", GemmBatch: " << GemmBatch
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch
<< std::endl;
}
}
@@ -147,8 +150,6 @@ struct GroupedConvBwdWeightKernelArgs
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
static_cast<index_t>(args.input_right_pads_[1])};
k_batch = args.k_batch;
in_ptr = args.in_ptr;
wei_ptr = args.wei_ptr;
for(index_t d = 0; d < NumDTensor; d++)
@@ -189,11 +190,14 @@ struct GroupedConvBwdWeightKernelArgs
GemmK = a_grid_desc_k_m.get_length(number<0>{});
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
k_batch = args.k_batch;
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
<< ", GemmBatch: " << GemmBatch
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch
<< std::endl;
}
}
@@ -239,8 +243,6 @@ struct GroupedConvBwdWeightKernelArgs
static_cast<index_t>(args.input_right_pads_[1]),
static_cast<index_t>(args.input_right_pads_[2])};
k_batch = args.k_batch;
in_ptr = args.in_ptr;
wei_ptr = args.wei_ptr;
for(index_t d = 0; d < NumDTensor; d++)
@@ -281,11 +283,14 @@ struct GroupedConvBwdWeightKernelArgs
GemmK = a_grid_desc_k_m.get_length(number<0>{});
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
k_batch = args.k_batch;
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
<< ", GemmBatch: " << GemmBatch
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch
<< std::endl;
}
}
@@ -398,7 +403,6 @@ struct GroupedConvolutionBackwardWeightKernel
using GroupedConvBwdWeightKernelArgsSpecialized =
GroupedConvBwdWeightKernelArgs<GroupedConvTraitsType_>;
// TODO: Enable this
static constexpr bool IsSplitKSupported = true;
static constexpr auto I0 = number<0>();
@@ -476,7 +480,24 @@ struct GroupedConvolutionBackwardWeightKernel
std::cout << "NPerBlock: " << number<TilePartitioner::NPerBlock>{} << std::endl;
std::cout << "KPerBlock: " << number<TilePartitioner::KPerBlock>{} << std::endl;
}
return GroupedConvBwdWeightKernelArgsSpecialized(hostArgs);
auto kernel_args = GroupedConvBwdWeightKernelArgsSpecialized(hostArgs);
using KernelImpl = GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType_,
TilePartitioner_,
GemmPipeline_,
EpiloguePipeline_>;
// Negative k_batch value: split-K autodeduction.
if(kernel_args.k_batch < 0)
{
const auto optimal_split_k =
calculate_optimal_k_batch<GemmPipeline_::BlockSize, KernelImpl, TilePartitioner_>(
kernel_args);
kernel_args.k_batch = optimal_split_k;
}
return kernel_args;
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
@@ -514,15 +535,54 @@ struct GroupedConvolutionBackwardWeightKernel
CK_TILE_HOST static bool
IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
{
if(kargs.k_batch < 1)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"k_batch must be at least one. Ensure argument is created via MakeKernelArgs.");
}
return false;
}
if constexpr(EpiloguePipeline_::MemoryOperation == memory_operation_enum::atomic_add)
{
if(kargs.k_batch == 1)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Atomic add epilogue only supports k_batch > 1.");
}
return false;
}
}
if constexpr(!std::is_same_v<typename EpiloguePipeline::ODataType, float> &&
!std::is_same_v<typename EpiloguePipeline::ODataType, double>)
{
// The epilogue performs atomic add related to split-K using the ODataType.
// If the type is less accurate than float, large split-K values may lead to
// accuracy issues. Hence, we limit the maximum split-K value to 128 in such cases.
if(kargs.k_batch > 128)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"For epilogue output data type that is not float/double, we must have "
"k_batch <= 128.");
}
return false;
}
}
if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<WeiDataType, fp16_t, bf16_t>::value) ||
!IsSplitKSupported)
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
{
if(kargs.k_batch != 1)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
CK_TILE_ERROR("Conditions not met for K_batch > 1!");
}
return false;
}

View File

@@ -0,0 +1,81 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <numeric>
#include "ck_tile/core/utility/env.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "ck_tile/host/kernel_launch.hpp"
namespace ck_tile {
/// Queries the HIP runtime for the maximum number of active thread blocks
/// per CU achievable by the kernel entry point instantiated for
/// (KernelImpl, KernelArgs) at the given BlockSize.
/// Aborts via hip_check_error if the occupancy query fails.
template <index_t BlockSize, typename KernelArgs, typename KernelImpl>
CK_TILE_HOST index_t get_max_occupancy_for_kernel()
{
    // The kernel uses no dynamic shared memory and we ask for the
    // unconstrained occupancy (min 1 block per CU at compile time).
    constexpr int dynamic_smem_size = 0;
    constexpr int min_blocks_per_cu = 1;

    // kentry is the templated kernel entry point; taking its address pins the
    // exact instantiation the occupancy query applies to.
    const auto kernel_ptr = kentry<min_blocks_per_cu, KernelImpl, KernelArgs>;

    int max_occupancy = 0;
    hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_occupancy, kernel_ptr, BlockSize, dynamic_smem_size));
    return static_cast<index_t>(max_occupancy);
}
/// Picks a split-K (k_batch) value that fills the GPU: if the output grid
/// occupies fewer blocks than the device can run concurrently
/// (max_occupancy * num_CUs), split the K dimension so the extra capacity is
/// used. Returns at least 1.
///
/// @param max_occupancy  max active thread blocks per CU for the kernel
/// @param grid_size      number of output tiles (blocks) without split-K
CK_TILE_HOST index_t get_best_occupancy_k_batch_value(index_t max_occupancy, index_t grid_size)
{
    // CU count is queried once and cached; the device is assumed not to
    // change between calls. get_num_cus() returns 0 on failure, which
    // harmlessly degrades the heuristic to k_batch == 1.
    static const index_t num_cus = get_num_cus();
    const index_t max_capacity  = max_occupancy * num_cus;

    index_t k_batch = 1;
    // Guard against a degenerate grid_size: the original float expression
    // (1.0 * capacity / 0 -> inf, floor(inf) cast to integer) is undefined
    // behavior. Integer division equals floor() for non-negative operands,
    // so no float round-trip is needed.
    if(grid_size > 0)
    {
        const index_t optimal_split = max_capacity / grid_size;
        if(optimal_split > 1)
        {
            k_batch = optimal_split;
        }
    }

    if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
    {
        std::cout << "[SPLIT-K AUTODEDUCE] Max active thread blocks per CU for GEMM kernel: "
                  << max_occupancy << std::endl;
        std::cout << "[SPLIT-K AUTODEDUCE] Output grid size: " << grid_size << std::endl;
        std::cout << "[SPLIT-K AUTODEDUCE] Optimal split-k value " << k_batch << std::endl;
    }
    return k_batch;
}
/// Small holder that performs the (relatively expensive) HIP occupancy query
/// in its constructor. Intended to be instantiated as a function-local
/// static so the query runs once per (BlockSize, KernelArgs, KernelImpl)
/// instantiation and the result is cached for subsequent calls.
template <index_t BlockSize, typename KernelArgs, typename KernelImpl>
struct ActiveWorkgroupsPerCU
{
    CK_TILE_HOST ActiveWorkgroupsPerCU()
    {
        max_occupancy_ = get_max_occupancy_for_kernel<BlockSize, KernelArgs, KernelImpl>();
    }

    // Max active thread blocks per CU; defaults to 1 before the query runs.
    index_t max_occupancy_{1};
};
/// Computes an optimal split-K (k_batch) value for the given kernel
/// arguments: fill the device based on occupancy, then clamp to GemmK so we
/// never split K into more pieces than there are K elements.
///
/// @param kargs  kernel arguments providing GemmM/GemmN/GemmK/GemmBatch
/// @return k_batch in [1, kargs.GemmK]
template <index_t BlockSize, typename KernelImpl, typename TilePartitioner, typename KernelArgs>
CK_TILE_HOST index_t calculate_optimal_k_batch(const KernelArgs& kargs)
{
    // Function-local static: the occupancy query runs once per template
    // instantiation (thread-safe init) and is reused on later calls.
    static ActiveWorkgroupsPerCU<BlockSize, KernelArgs, KernelImpl> active_workgroups_per_cu;

    // Grid size of the non-split launch: tiles over (M, N) times the number
    // of GEMM batches.
    const auto grid_size = TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN) * kargs.GemmBatch;

    auto optimal_k_batch =
        get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, grid_size);

    // Splitting K more finely than GemmK would create empty splits.
    const auto max_allowed_k_batch = kargs.GemmK;
    optimal_k_batch = std::min(optimal_k_batch, max_allowed_k_batch);

    if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
    {
        std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " << optimal_k_batch << std::endl;
    }
    return optimal_k_batch;
}
} // namespace ck_tile

View File

@@ -38,3 +38,4 @@ add_subdirectory(atomic_add_op)
add_subdirectory(fmha)
add_subdirectory(gemm_tile_engine)
add_subdirectory(pooling)
add_subdirectory(grouped_conv)

View File

@@ -0,0 +1,7 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
add_gtest_executable(test_ck_tile_grouped_conv_bwd_weight test_ck_tile_grouped_conv_bwd_weight.cpp)
endif()

View File

@@ -0,0 +1,249 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "gtest/gtest.h"
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp"
using namespace ck_tile;
// Compile-time GEMM/conv configuration used by all kernels built in this
// test. Values mirror a typical example config; only the argument checks
// (not kernel execution) are exercised here.
struct TestConvConfig
{
    // Vectorized access widths for A (output grad), B (input), C (weight).
    static constexpr index_t VectorSizeA = 4;
    static constexpr index_t VectorSizeB = 8;
    static constexpr index_t VectorSizeC = 8;

    // Block tile sizes (per workgroup).
    static constexpr index_t M_Tile = 128;
    static constexpr index_t N_Tile = 128;
    static constexpr index_t K_Tile = 32;

    // Warp layout within a block tile.
    static constexpr index_t M_Warp = 2;
    static constexpr index_t N_Warp = 2;
    static constexpr index_t K_Warp = 1;

    // Per-warp MFMA tile sizes.
    static constexpr index_t M_Warp_Tile = 16;
    static constexpr index_t N_Warp_Tile = 16;
    static constexpr index_t K_Warp_Tile = 16;

    static constexpr bool DoubleSmemBuffer   = false;
    static constexpr GemmPipeline Pipeline   = GemmPipeline::COMPUTE_V3;
    static constexpr index_t NumWaveGroups   = 1;
    static constexpr index_t NumGroupsToMerge = 1;
    static constexpr auto Scheduler          = GemmPipelineScheduler::Intrawave;
};
// Helper to build full kernel type
// Helper to build full kernel type: assembles the complete
// GroupedConvolutionBackwardWeightKernel type from a precision, config and
// layout triple. Pure type computation — exposes the result as `type`.
template <typename PrecType,
          typename ConvConfig,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          memory_operation_enum MemOp = memory_operation_enum::set,
          index_t NDimSpatial         = 2>
struct BuildKernel
{
    // Block / warp / warp-tile shape of the implicit GEMM.
    using GemmShape = TileGemmShape<
        sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
        sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
        sequence<ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>>;

    // Convolution-specific traits (layouts, vector sizes, group merging).
    using ConvTraits = GroupedConvTraits<NDimSpatial,
                                         ConvolutionSpecialization::Default,
                                         InLayout,
                                         WeiLayout,
                                         tuple<>,
                                         OutLayout,
                                         ConvConfig::VectorSizeA,
                                         ConvConfig::VectorSizeB,
                                         ConvConfig::VectorSizeC,
                                         ConvConfig::NumGroupsToMerge>;

    using TilePartitioner = GemmSpatiallyLocalTilePartitioner<GemmShape, 8, 4>;

    // Universal GEMM traits derived from the conv traits' bwd-weight mapping.
    using GemmUniversalTraits =
        TileGemmUniversalTraits<ConvTraits::FixedGemmParams::kPadM,
                                ConvTraits::FixedGemmParams::kPadN,
                                ConvTraits::FixedGemmParams::kPadK,
                                ConvConfig::DoubleSmemBuffer,
                                typename ConvTraits::AsLayoutBwdWeight,
                                typename ConvTraits::BsLayoutBwdWeight,
                                typename ConvTraits::CLayoutBwdWeight,
                                ConvTraits::FixedGemmParams::TransposeC,
                                ConvTraits::FixedGemmParams::UseStructuredSparsity,
                                ConvTraits::FixedGemmParams::Persistent,
                                ConvConfig::NumWaveGroups>;

    using GemmPipelineProblem =
        GemmPipelineProblem<PrecType, // OutDataType (A in bwd weight)
                            PrecType, // InDataType (B in bwd weight)
                            float,    // AccDataType
                            GemmShape,
                            typename ConvTraits::template GroupedConvImplicitGemmTraitsBwdWeight<
                                ConvConfig::NumWaveGroups>,
                            element_wise::PassThrough,
                            element_wise::PassThrough,
                            PrecType, // WeiDataType (C in bwd weight)
                            ConvTraits::FixedGemmParams::FixedVectorSize,
                            ConvTraits::VectorSizeA,
                            ConvTraits::VectorSizeB>;

    using UniversalGemmProblem =
        UniversalGemmPipelineProblem<PrecType,
                                     PrecType,
                                     float,
                                     GemmShape,
                                     GemmUniversalTraits,
                                     ConvConfig::Scheduler,
                                     element_wise::PassThrough,
                                     element_wise::PassThrough,
                                     PrecType,
                                     ConvTraits::FixedGemmParams::FixedVectorSize,
                                     ConvTraits::VectorSizeA,
                                     ConvTraits::VectorSizeB>;

    using GemmPipeline = GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;

    // C-shuffle epilogue; MemOp selects set vs atomic_add (split-K) stores.
    using EpilogueProblem = CShuffleEpilogueProblem<PrecType,
                                                    PrecType,
                                                    tuple<>,
                                                    float,
                                                    PrecType,
                                                    typename ConvTraits::ImplicitGemmDsLayout,
                                                    typename ConvTraits::FixedGemmParams::ELayout,
                                                    element_wise::PassThrough,
                                                    TilePartitioner::MPerBlock,
                                                    TilePartitioner::NPerBlock,
                                                    ConvConfig::M_Warp,
                                                    ConvConfig::N_Warp,
                                                    ConvConfig::M_Warp_Tile,
                                                    ConvConfig::N_Warp_Tile,
                                                    ConvConfig::K_Warp_Tile,
                                                    ConvTraits::FixedGemmParams::TransposeC,
                                                    MemOp,
                                                    ConvConfig::NumWaveGroups,
                                                    ConvTraits::FixedGemmParams::FixedVectorSize,
                                                    ConvTraits::VectorSizeC>;

    using Epilogue = CShuffleEpilogue<EpilogueProblem>;

    // The fully-assembled kernel type under test.
    using type =
        GroupedConvolutionBackwardWeightKernel<ConvTraits, TilePartitioner, GemmPipeline, Epilogue>;
};
// Helper to create 2D host args
// Helper to create 2D host args for a G-grouped NCHW-style conv problem.
// Device pointers are left null — these args are only fed to
// MakeKernelArgs/IsSupportedArgument, which presumably never dereference
// them (TODO confirm; no kernel is launched in these tests).
static GroupedConvBwdWeightHostArgs create_2d_host_args(index_t G,
                                                        index_t N,
                                                        index_t K,
                                                        index_t C,
                                                        index_t Y,
                                                        index_t X,
                                                        index_t Hi,
                                                        index_t Wi,
                                                        index_t stride_y,
                                                        index_t stride_x,
                                                        index_t dilation_y,
                                                        index_t dilation_x,
                                                        index_t left_pad_y,
                                                        index_t left_pad_x,
                                                        index_t right_pad_y,
                                                        index_t right_pad_x,
                                                        index_t k_batch = 1)
{
    // First ctor argument is the number of spatial dims (2 for Y/X).
    auto conv_param = conv::ConvParam{2,
                                      G,
                                      N,
                                      K,
                                      C,
                                      {Y, X},
                                      {Hi, Wi},
                                      {stride_y, stride_x},
                                      {dilation_y, dilation_x},
                                      {left_pad_y, left_pad_x},
                                      {right_pad_y, right_pad_x}};

    return GroupedConvBwdWeightHostArgs{conv_param, nullptr, nullptr, {}, nullptr, k_batch};
}
// Convenience overload: fixed small 2-D problem (G=2, N=2, K=C=8, 3x3
// filter, 7x7 input, unit strides/dilations/pads), varying only k_batch.
static GroupedConvBwdWeightHostArgs create_2d_host_args(index_t k_batch)
{
    return create_2d_host_args(2, 2, 8, 8, 3, 3, 7, 7, 1, 1, 1, 1, 1, 1, 1, 1, k_batch);
}
// Test fixture for IsSupportedArgument checks; no shared state needed —
// exists only to group the tests under one suite name.
class GroupedConvBwdWeightIsSupportedArgumentTest : public ::testing::Test
{
};
// With the default (set) epilogue, both k_batch == 1 and k_batch > 1 must
// be accepted for an fp16 NHWGC/GKYXC/NHWGK problem.
TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, ValidKBatch)
{
    using Kernel = typename BuildKernel<half_t,
                                        TestConvConfig,
                                        tensor_layout::convolution::NHWGC,
                                        tensor_layout::convolution::GKYXC,
                                        tensor_layout::convolution::NHWGK>::type;

    auto host_args_kbatch_1 = create_2d_host_args(1);
    auto kargs_1 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_1);
    EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_1));

    auto host_args_kbatch_4 = create_2d_host_args(4);
    auto kargs_4 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_4);
    EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_4));
}
// k_batch < 1 must be rejected: such values are only legal in host args
// (negative triggers autodeduction via MakeKernelArgs), never in the final
// kernel arguments.
TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, InvalidKBatchLessThanOne)
{
    using Kernel = typename BuildKernel<half_t,
                                        TestConvConfig,
                                        tensor_layout::convolution::NHWGC,
                                        tensor_layout::convolution::GKYXC,
                                        tensor_layout::convolution::NHWGK>::type;

    auto host_args_kbatch_0 = create_2d_host_args(0);
    auto kargs = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_0);
    EXPECT_FALSE(Kernel::IsSupportedArgument(kargs));
}
// The atomic_add epilogue is only meaningful when partial results from
// multiple K-splits are accumulated, so k_batch == 1 must be rejected and
// k_batch > 1 accepted.
TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, AtomicAddRequiresKBatchGreaterThanOne)
{
    using Kernel = typename BuildKernel<half_t,
                                        TestConvConfig,
                                        tensor_layout::convolution::NHWGC,
                                        tensor_layout::convolution::GKYXC,
                                        tensor_layout::convolution::NHWGK,
                                        memory_operation_enum::atomic_add>::type;

    // k_batch = 1 should fail with atomic_add
    auto host_args_kbatch_1 = create_2d_host_args(1);
    auto kargs_1 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_1);
    EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_1));

    // k_batch = 2 should pass
    auto host_args_kbatch_2 = create_2d_host_args(2);
    auto kargs_2 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_2);
    EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_2));
}
// For epilogue output types other than float/double (here half_t), the
// kernel caps split-K at 128 to limit atomic-add accuracy loss; 128 passes,
// 129 is rejected.
TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, NonFloatDoubleOutputLimitsKBatch)
{
    using Kernel = typename BuildKernel<half_t,
                                        TestConvConfig,
                                        tensor_layout::convolution::NHWGC,
                                        tensor_layout::convolution::GKYXC,
                                        tensor_layout::convolution::NHWGK>::type;

    // k_batch = 128 should pass
    auto host_args_kbatch_128 = create_2d_host_args(128);
    auto kargs_128 =
        typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_128);
    EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_128));

    // k_batch = 129 should fail for half_t output
    auto host_args_kbatch_129 = create_2d_host_args(129);
    auto kargs_129 =
        typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_129);
    EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_129));
}