[CK_TILE] Split-K autodeduction (#3351)

* First version of split-K autodeduction.

* Fix circular dependency and kernel construction.

* Fix tolerance calculation for bwd weight example.

* Simplify kernel construction.

* Fix kernel launching bug for split-K autodeduce.

* Add split-K autodeduction support for the two stage example.

* Fix a corner case.

* Fix clang-format.

* Fix clang-format for inc files.

* Add missing header.

* Prevent too large split-K values.

* Fix formatting.

* Add unit tests for IsSupportedArgument in grouped bwd conv.

* clang-format.

* Fix merge conflicts.

* Address feedback from code review.

* clang-format.

* Fix new tests after merge.

---------

Co-authored-by: Ville Pietilä <>
This commit is contained in:
Ville Pietilä
2025-12-10 09:30:30 +02:00
committed by GitHub
parent 1aa93ef551
commit fc22320d78
11 changed files with 485 additions and 51 deletions

View File

@@ -14,6 +14,8 @@
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp"
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
#include "ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp"
#endif
@@ -62,8 +64,6 @@ struct GroupedConvBwdWeightKernelArgs
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
k_batch = args.k_batch;
in_ptr = args.in_ptr;
wei_ptr = args.wei_ptr;
for(index_t d = 0; d < NumDTensor; d++)
@@ -104,11 +104,14 @@ struct GroupedConvBwdWeightKernelArgs
GemmK = a_grid_desc_k_m.get_length(number<0>{});
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
k_batch = args.k_batch;
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
<< ", GemmBatch: " << GemmBatch
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch
<< std::endl;
}
}
@@ -147,8 +150,6 @@ struct GroupedConvBwdWeightKernelArgs
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
static_cast<index_t>(args.input_right_pads_[1])};
k_batch = args.k_batch;
in_ptr = args.in_ptr;
wei_ptr = args.wei_ptr;
for(index_t d = 0; d < NumDTensor; d++)
@@ -189,11 +190,14 @@ struct GroupedConvBwdWeightKernelArgs
GemmK = a_grid_desc_k_m.get_length(number<0>{});
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
k_batch = args.k_batch;
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
<< ", GemmBatch: " << GemmBatch
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch
<< std::endl;
}
}
@@ -239,8 +243,6 @@ struct GroupedConvBwdWeightKernelArgs
static_cast<index_t>(args.input_right_pads_[1]),
static_cast<index_t>(args.input_right_pads_[2])};
k_batch = args.k_batch;
in_ptr = args.in_ptr;
wei_ptr = args.wei_ptr;
for(index_t d = 0; d < NumDTensor; d++)
@@ -281,11 +283,14 @@ struct GroupedConvBwdWeightKernelArgs
GemmK = a_grid_desc_k_m.get_length(number<0>{});
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
k_batch = args.k_batch;
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
<< ", GemmBatch: " << GemmBatch
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch
<< std::endl;
}
}
@@ -398,7 +403,6 @@ struct GroupedConvolutionBackwardWeightKernel
using GroupedConvBwdWeightKernelArgsSpecialized =
GroupedConvBwdWeightKernelArgs<GroupedConvTraitsType_>;
// TODO: Enable this
static constexpr bool IsSplitKSupported = true;
static constexpr auto I0 = number<0>();
@@ -476,7 +480,24 @@ struct GroupedConvolutionBackwardWeightKernel
std::cout << "NPerBlock: " << number<TilePartitioner::NPerBlock>{} << std::endl;
std::cout << "KPerBlock: " << number<TilePartitioner::KPerBlock>{} << std::endl;
}
return GroupedConvBwdWeightKernelArgsSpecialized(hostArgs);
auto kernel_args = GroupedConvBwdWeightKernelArgsSpecialized(hostArgs);
using KernelImpl = GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType_,
TilePartitioner_,
GemmPipeline_,
EpiloguePipeline_>;
// Negative k_batch value: split-K autodeduction.
if(kernel_args.k_batch < 0)
{
const auto optimal_split_k =
calculate_optimal_k_batch<GemmPipeline_::BlockSize, KernelImpl, TilePartitioner_>(
kernel_args);
kernel_args.k_batch = optimal_split_k;
}
return kernel_args;
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
@@ -514,15 +535,54 @@ struct GroupedConvolutionBackwardWeightKernel
CK_TILE_HOST static bool
IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
{
if(kargs.k_batch < 1)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"k_batch must be at least one. Ensure argument is created via MakeKernelArgs.");
}
return false;
}
if constexpr(EpiloguePipeline_::MemoryOperation == memory_operation_enum::atomic_add)
{
if(kargs.k_batch == 1)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Atomic add epilogue only supports k_batch > 1.");
}
return false;
}
}
if constexpr(!std::is_same_v<typename EpiloguePipeline::ODataType, float> &&
!std::is_same_v<typename EpiloguePipeline::ODataType, double>)
{
// The epilogue performs atomic add related to split-K using the ODataType.
// If the type is less accurate than float, large split-K values may lead to
// accuracy issues. Hence, we limit the maximum split-K value to 128 in such cases.
if(kargs.k_batch > 128)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"For epilogue output data type that is not float/double, we must have "
"k_batch <= 128.");
}
return false;
}
}
if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<WeiDataType, fp16_t, bf16_t>::value) ||
!IsSplitKSupported)
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
{
if(kargs.k_batch != 1)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
CK_TILE_ERROR("Conditions not met for K_batch > 1!");
}
return false;
}