Integrate universal gemm with conv bwd data and add SplitK (#1315)

* Integrate universal gemm with conv bwd data

* Fix multi d kernel

* Add splitK support

* instances refactor

* instances refactor

* refactor

* fixes

* fixes

* 16x16 instances

* Fixes

* Fix

* Fix

* Fix

* Fix

* Fix

* Fixes

* fix

* fix
This commit is contained in:
Bartłomiej Kocot
2025-04-28 23:54:49 +02:00
committed by GitHub
parent d9786f3363
commit 4094ad158a
69 changed files with 2262 additions and 349 deletions

View File

@@ -34,7 +34,8 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param)
const ck::utils::conv::ConvParam& conv_param,
ck::index_t split_k = 1)
{
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
@@ -88,6 +89,7 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
// reset input to zero
in_device_buf.SetZero();
float max_accumulated_value = 0;
if(do_verification)
{
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<NDimSpatial,
@@ -114,17 +116,19 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
in_element_op);
ref_invoker.Run(ref_argument);
max_accumulated_value = *std::max_element(in_host.mData.begin(), in_host.mData.end());
}
std::string best_op_name;
float best_avg_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
float best_avg_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
ck::index_t best_split_k = 1;
// profile device op instances
bool pass = true;
auto run_impl = [&](auto& op_ptr, auto& argument_ptr) {
auto run_impl = [&](auto& op_ptr, auto& argument_ptr, const index_t& split_k_for_run) {
// workspace_sz will be equal to 0 for layouts other than NGCHW
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
@@ -150,7 +154,8 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
<< gb_per_sec << " GB/s, " << op_name << ", SplitK " << split_k_for_run
<< std::endl;
if(tflops > best_tflops)
{
@@ -158,13 +163,39 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
best_tflops = tflops;
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
best_split_k = split_k_for_run;
}
if(do_verification)
{
in_device_buf.FromDevice(in_device.mData.data());
pass = pass & ck::utils::check_err(in_device, in_host);
using ComputeType = std::conditional_t<sizeof(OutDataType) < sizeof(WeiDataType),
OutDataType,
WeiDataType>;
using AccDataType =
std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
const index_t num_accums = conv_param.K_;
// Calculate thresholds
auto rtol = ck::utils::get_relative_threshold<ComputeType, InDataType, AccDataType>(
num_accums / split_k_for_run);
auto atol = ck::utils::get_absolute_threshold<ComputeType, InDataType, AccDataType>(
max_accumulated_value / split_k_for_run, num_accums / split_k_for_run);
// Calculate error due to split_k accumulation
auto rtol_split_k =
ck::utils::get_relative_threshold<InDataType, InDataType, InDataType>(
split_k_for_run);
auto atol_split_k =
ck::utils::get_absolute_threshold<InDataType, InDataType, InDataType>(
max_accumulated_value, split_k_for_run);
// Use higher threshold
rtol = std::max(rtol, rtol_split_k);
atol = std::max(atol, atol_split_k);
pass = pass & ck::utils::check_err(
in_device, in_host, "Error: Incorrect results!", rtol, atol);
std::cout << "Relative error threshold: " << rtol
<< " Absolute error threshold: " << atol << std::endl;
if(do_log)
{
@@ -225,35 +256,47 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
copy(conv_param.input_left_pads_, input_left_pads);
copy(conv_param.input_right_pads_, input_right_pads);
std::vector<ck::index_t> split_k_list = {1, 2, 4, 8, 16, 32, 64, 128};
if(split_k > 0)
{
split_k_list = {split_k};
}
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr =
op_ptr->MakeArgumentPointer(static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
{},
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
out_lengths,
out_strides,
wei_lengths,
wei_strides,
{},
{},
in_lengths,
in_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
out_element_op,
wei_element_op,
in_element_op);
for(std::size_t split_k_id = 0; split_k_id < split_k_list.size(); split_k_id++)
{
auto argument_ptr = op_ptr->MakeArgumentPointer(
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
{},
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
out_lengths,
out_strides,
wei_lengths,
wei_strides,
{},
{},
in_lengths,
in_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
out_element_op,
wei_element_op,
in_element_op,
split_k_list[split_k_id]);
run_impl(op_ptr, argument_ptr);
run_impl(op_ptr, argument_ptr, split_k_list[split_k_id]);
}
}
std::cout << "Best configuration parameters:"
<< "\nname: " << best_op_name << "\navg_time: " << best_avg_time
<< "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
<< "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << ", SplitK "
<< best_split_k << std::endl;
return pass;
}