Automatic deduction of split-K value for grouped convolution (#2491)

* Split-K autodeduction for DeviceGroupedConvBwdWeight_Xdl_CShuffle and DeviceGroupedConvBwdWeight_Xdl_CShuffleV3.

* Split-K autodeduction for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle.

* Use simple best occupancy model to calculate the split-K.

* Handle split-K autodeduction in explicit gemm conv.

* Add unit tests for split-K autodeduction.

* Remove oversubscription.

* Small fixes.

* Added split-K autodeduction for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle.

* Run clang formatting.

* Fix error handling in the conv profiler.

* Add missing documentation for the autodeducted split-K values.

* Add split-K autodeduction to DeviceGroupedConvBwdWeight_Explicit_Xdl solver.

* Fix clang formatting and split-K profiler documentation.

* Rename max_occupancy value variable.

* Calculate grid size for split-K autodeduction directly from input array shapes and template params.

---------

Co-authored-by: Ville Pietilä <>
This commit is contained in:
Ville Pietilä
2025-07-31 13:08:45 +03:00
committed by GitHub
parent 7b074249f4
commit e962a41638
14 changed files with 544 additions and 72 deletions

View File

@@ -11,6 +11,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp"
@@ -40,7 +41,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
bool do_log,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
ck::index_t split_k)
const std::string& split_k)
{
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
@@ -138,10 +139,10 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
float best_avg_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
ck::index_t best_split_k = 1;
float best_avg_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
std::string best_split_k("1");
// profile device Conv instances
bool all_pass = true;
@@ -170,11 +171,20 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
range_copy(conv_param.input_left_pads_, begin(input_left_pads));
range_copy(conv_param.input_right_pads_, begin(input_right_pads));
std::vector<ck::index_t> split_k_list = {1, 2, 4, 8, 16, 32, 64, 128};
std::vector<ck::index_t> split_k_list = {/*auto deduce value*/ -1, 1, 2, 4, 8, 16, 32, 64, 128};
if(split_k > 0)
if(split_k != "all")
{
split_k_list = {split_k};
try
{
ck::index_t split_k_value = std::stoi(split_k);
split_k_list = {split_k_value};
}
catch(const std::exception& e)
{
std::cerr << e.what() << '\n';
exit(EXIT_FAILURE);
}
}
for(auto& op_ptr : op_ptrs)
@@ -200,6 +210,16 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
out_element_op,
split_k_list[split_k_id]);
auto split_k_value = split_k_list[split_k_id];
auto split_k_param_str = std::to_string(split_k_value);
auto* split_k_arg =
dynamic_cast<ck::tensor_operation::device::ArgumentSplitK*>(argument_ptr.get());
if(split_k_arg && split_k_value < 0)
{
split_k_value = split_k_arg->k_batch_;
split_k_param_str = std::to_string(split_k_value) + " (best occupancy)";
}
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
@@ -222,7 +242,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", SplitK "
<< split_k_list[split_k_id] << std::endl;
<< split_k_param_str << std::endl;
if(tflops > best_tflops)
{
@@ -230,7 +250,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
best_tflops = tflops;
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
best_split_k = split_k_list[split_k_id];
best_split_k = split_k_param_str;
}
if(do_verification)
@@ -244,7 +264,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
using AccDataType =
std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
const index_t num_accums = output.GetElementSize() / conv_param.K_;
const index_t num_accums_split_k = split_k_list[split_k_id];
const index_t num_accums_split_k = split_k_value;
// Calculate thresholds
auto rtol =
ck::utils::get_relative_threshold<ComputeType, WeiDataType, AccDataType>(