[CK_TILE] Split-K autodeduction (#3351)

* First version of split-K autodeduction.

* Fix circular dependency and kernel construction.

* Fix tolerance calculation for bwd weight example.

* Simplify kernel construction.

* Fix kernel launching bug for split-K autodeduce.

* Add split-K autodeduction support for the two stage example.

* Fix a corner case.

* Fix clang-format.

* Fix clang-format for inc files.

* Add missing header.

* Prevent too large split-K values.

* Fix formatting.

* Add unit tests for IsSupportedArgument in grouped bwd conv.

* clang-format.

* Fix merge conflicts.

* Address feedback from code review.

* clang-format

* Fix new tests after merge.

---------

Co-authored-by: Ville Pietilä <>
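
Note on the change below: these hunks only show how the deduced split-K value is plumbed from the invokers back to the example code; the deduction heuristic itself lives in other files of this PR and is not reproduced here. As a rough, hypothetical sketch of the general idea (all names and the formula are illustrative assumptions, not code from this change), a split-K chooser typically splits the GEMM K dimension only when the M x N tile grid cannot fill the device, and caps the result, which is what the "Prevent too large split-K values" commit above refers to:

    // Hypothetical sketch of a split-K chooser. None of these names come from the PR;
    // they only illustrate the trade-off the autodeduction has to make.
    #include <algorithm>
    #include <cstdint>

    std::int32_t deduce_split_k(std::int32_t num_cus,     // compute units on the device
                                std::int32_t tiles_m,     // output tiles along GEMM M
                                std::int32_t tiles_n,     // output tiles along GEMM N
                                std::int32_t gemm_k,      // GEMM K extent
                                std::int32_t k_per_block, // K covered by one work-group
                                std::int32_t max_split_k = 128)
    {
        // If the M x N tile grid already saturates the GPU, splitting K only adds
        // workspace traffic and an extra reduction, so keep a single K batch.
        const std::int32_t mn_tiles = tiles_m * tiles_n;
        if(mn_tiles >= num_cus)
            return 1;

        // Otherwise split K just enough to fill the device, but never below one
        // K tile per batch and never above a hard cap.
        const std::int32_t wanted  = (num_cus + mn_tiles - 1) / mn_tiles;
        const std::int32_t k_tiles = std::max<std::int32_t>(1, gemm_k / k_per_block);
        return std::clamp<std::int32_t>(wanted, 1, std::min<std::int32_t>(k_tiles, max_split_k));
    }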

@@ -17,8 +17,8 @@ struct GroupedConvolutionBackwardWeightInvoker
               typename DsDataType = ck_tile::tuple<>,
               typename DsLayout = ck_tile::tuple<>,
               typename CDEElementWise = ck_tile::element_wise::PassThrough>
-    static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
-                                         const ck_tile::stream_config& s)
+    static InvokerResult grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
+                                                 const ck_tile::stream_config& s)
     {
         // Implicit GEMM Traits
         using GemmShape = ck_tile::TileGemmShape<
@@ -105,9 +105,9 @@ struct GroupedConvolutionBackwardWeightInvoker
                                              TilePartitioner,
                                              GemmPipeline,
                                              ConvEpilogue>;
-            auto kargs        = Kernel::MakeKernelArgs(args);
-            const dim3 grids  = Kernel::GridSize(args);
+            const auto kargs  = Kernel::MakeKernelArgs(args);
+            const dim3 grids  = Kernel::GridSize(kargs);
             const dim3 blocks = Kernel::BlockSize();

             if(!Kernel::IsSupportedArgument(kargs))
@@ -130,7 +130,7 @@ struct GroupedConvolutionBackwardWeightInvoker
            }

            auto preprocess = [&]() {
-                if(args.k_batch > 1)
+                if(kargs.k_batch > 1)
                 {
                     ck_tile::hip_check_error(
                         hipMemsetAsync(kargs.wei_ptr,
@@ -140,10 +140,14 @@ struct GroupedConvolutionBackwardWeightInvoker
                }
            };

-            return ck_tile::launch_kernel_time_mask(
+            const auto ave_time = ck_tile::launch_kernel_time_mask(
                 s,
                 preprocess,
                 ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
+
+            const auto split_k = kargs.k_batch;
+            return InvokerResult{ave_time, split_k};
         };

         if(args.k_batch == 1)


@@ -17,8 +17,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
               typename DsDataType = ck_tile::tuple<>,
               typename DsLayout = ck_tile::tuple<>,
               typename CDEElementWise = ck_tile::element_wise::PassThrough>
-    static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
-                                         const ck_tile::stream_config& s)
+    static InvokerResult grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
+                                                 const ck_tile::stream_config& s)
     {
         using WorkspaceDataType = float;
@@ -118,9 +118,9 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
                                                    sizeof(WorkspaceDataType));
            ck_tile::GroupedConvBwdWeightHostArgs ws_args =
                ck_tile::GroupedConvBwdWeightHostArgs(args);
-            auto c_ptr      = ws_args.wei_ptr;
-            ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
-            auto kargs      = Kernel::MakeKernelArgs(ws_args);
+            auto c_ptr       = ws_args.wei_ptr;
+            ws_args.wei_ptr  = ws_m_n_dev_buf.GetDeviceBuffer();
+            const auto kargs = Kernel::MakeKernelArgs(ws_args);
             const dim3 grids  = Kernel::GridSize(kargs);
             const dim3 blocks = Kernel::BlockSize();
@@ -184,7 +184,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
            }

            auto preprocess = [&]() {
-                if(args.k_batch > 1)
+                if(kargs.k_batch > 1)
                     ck_tile::hip_check_error(
                         hipMemsetAsync(ws_args.wei_ptr,
                                        0,
@@ -192,7 +192,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
                                        s.stream_id_));
            };

-            return ck_tile::launch_kernel_time_mask(
+            const auto ave_time = ck_tile::launch_kernel_time_mask(
                 s,
                 preprocess,
                 ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
@@ -206,6 +206,10 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
                    ck_tile::make_tuple(shape[1], 1), // Output Stride
                    input_tensors,
                    static_cast<WeiDataType*>(c_ptr)));
+
+            const auto split_k = kargs.k_batch;
+            return InvokerResult{ave_time, split_k};
         };

         if(args.k_batch == 1)


@@ -132,3 +132,9 @@ auto create_args(int argc, char* argv[])
     bool result = arg_parser.parse(argc, argv);
     return std::make_tuple(result, arg_parser);
 }
+
+struct InvokerResult
+{
+    float ave_time;
+    ck_tile::index_t split_k;
+};
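
This struct sits in the shared example header so that every invoker can hand back both the measured kernel time and the k_batch it actually launched with, which matters once the invoker is allowed to pick k_batch itself. A minimal, illustrative consumer is sketched below; the struct is re-declared with plain int instead of ck_tile::index_t only so the snippet stands alone:

    // Illustrative only: how a caller might consume InvokerResult.
    #include <iostream>

    struct InvokerResult
    {
        float ave_time; // averaged kernel time in ms
        int   split_k;  // the k_batch the invoker actually used (ck_tile::index_t in the PR)
    };

    void report(const InvokerResult& res)
    {
        std::cout << "avg time: " << res.ave_time << " ms, split-K used: " << res.split_k << '\n';
        // res.split_k also feeds the error-tolerance calculation further down,
        // since more K batches mean more rounded partial sums per output element.
    }

Packing both values into one struct keeps the previously float-returning call sites easy to migrate and avoids an out-parameter.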


@@ -14,22 +14,22 @@ template <ck_tile::index_t NDimSpatial,
           typename InLayout,
           typename WeiLayout,
           typename OutLayout>
-float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args,
-                                     int n_warmup,
-                                     int n_repeat)
+InvokerResult invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args,
+                                             int n_warmup,
+                                             int n_repeat)
 {
-    float ave_time = Invoker::template grouped_conv_bwd_weight<NDimSpatial,
-                                                               ConvConfig,
-                                                               InDataType,
-                                                               WeiDataType,
-                                                               AccDataType,
-                                                               OutDataType,
-                                                               InLayout,
-                                                               WeiLayout,
-                                                               OutLayout>(
+    auto res = Invoker::template grouped_conv_bwd_weight<NDimSpatial,
+                                                         ConvConfig,
+                                                         InDataType,
+                                                         WeiDataType,
+                                                         AccDataType,
+                                                         OutDataType,
+                                                         InLayout,
+                                                         WeiLayout,
+                                                         OutLayout>(
         args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
-    return ave_time;
+    return res;
 }

 template <ck_tile::index_t NDimSpatial,
@@ -132,16 +132,17 @@ int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_par
     std::cout << "weight: " << weight.mDesc << std::endl;
     std::cout << "output: " << output.mDesc << std::endl;

-    float ave_time = invoke_grouped_conv_bwd_weight<NDimSpatial,
-                                                    ConvConfig,
-                                                    Invoker,
-                                                    InDataType,
-                                                    WeiDataType,
-                                                    AccDataType,
-                                                    OutDataType,
-                                                    InLayout,
-                                                    WeiLayout,
-                                                    OutLayout>(args, n_warmup, n_repeat);
+    auto res = invoke_grouped_conv_bwd_weight<NDimSpatial,
+                                              ConvConfig,
+                                              Invoker,
+                                              InDataType,
+                                              WeiDataType,
+                                              AccDataType,
+                                              OutDataType,
+                                              InLayout,
+                                              WeiLayout,
+                                              OutLayout>(args, n_warmup, n_repeat);
+    const float ave_time = res.ave_time;

     weight_dev_buf.FromDevice(weight.data());
@@ -172,9 +173,11 @@ int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_par
     const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_);
     const float max_accumulated_value =
         *std::max_element(weight_host_ref.mData.begin(), weight_host_ref.mData.end());
+    const ck_tile::index_t split_k = res.split_k;
     const auto rtol_atol =
         calculate_rtol_atol<InDataType, WeiDataType, AccDataType, OutDataType>(
-            GemmK, kbatch, max_accumulated_value);
+            GemmK, split_k, max_accumulated_value);
     pass = ck_tile::check_err(weight,
                               weight_host_ref,
                               "Error: Incorrect results!",