mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Integrate universal gemm with conv bwd data and add SplitK (#1315)
* Integrate universal gemm with conv bwd data * Fix multi d kernel * Add splitK support * instances refactor * instances refactor * refactor * fixeS * fixes * 16x16 instnaces * Fixes * Fix * Fix * Fix * Fix * Fix * Fixes * fix * fix
This commit is contained in:
@@ -34,7 +34,8 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
const ck::utils::conv::ConvParam& conv_param)
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
ck::index_t split_k = 1)
|
||||
{
|
||||
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
@@ -88,6 +89,7 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
// reset input to zero
|
||||
in_device_buf.SetZero();
|
||||
|
||||
float max_accumulated_value = 0;
|
||||
if(do_verification)
|
||||
{
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<NDimSpatial,
|
||||
@@ -114,17 +116,19 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
in_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
max_accumulated_value = *std::max_element(in_host.mData.begin(), in_host.mData.end());
|
||||
}
|
||||
|
||||
std::string best_op_name;
|
||||
float best_avg_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
float best_avg_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
ck::index_t best_split_k = 1;
|
||||
|
||||
// profile device op instances
|
||||
bool pass = true;
|
||||
|
||||
auto run_impl = [&](auto& op_ptr, auto& argument_ptr) {
|
||||
auto run_impl = [&](auto& op_ptr, auto& argument_ptr, const index_t& split_k_for_run) {
|
||||
// workspace_sz will be equal to 0 for other layout than NGCHW
|
||||
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
|
||||
DeviceMem workspace_dev(workspace_sz);
|
||||
@@ -150,7 +154,8 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
float gb_per_sec = num_btype / 1.E6 / avg_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
<< gb_per_sec << " GB/s, " << op_name << ", SplitK " << split_k_for_run
|
||||
<< std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
@@ -158,13 +163,39 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
best_tflops = tflops;
|
||||
best_avg_time = avg_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
best_split_k = split_k_for_run;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
in_device_buf.FromDevice(in_device.mData.data());
|
||||
|
||||
pass = pass & ck::utils::check_err(in_device, in_host);
|
||||
using ComputeType = std::conditional_t<sizeof(OutDataType) < sizeof(WeiDataType),
|
||||
OutDataType,
|
||||
WeiDataType>;
|
||||
using AccDataType =
|
||||
std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
|
||||
const index_t num_accums = conv_param.K_;
|
||||
// Calculate thresholds
|
||||
auto rtol = ck::utils::get_relative_threshold<ComputeType, InDataType, AccDataType>(
|
||||
num_accums / split_k_for_run);
|
||||
auto atol = ck::utils::get_absolute_threshold<ComputeType, InDataType, AccDataType>(
|
||||
max_accumulated_value / split_k_for_run, num_accums / split_k_for_run);
|
||||
// Calculate error due to split_k accumulation
|
||||
auto rtol_split_k =
|
||||
ck::utils::get_relative_threshold<InDataType, InDataType, InDataType>(
|
||||
split_k_for_run);
|
||||
auto atol_split_k =
|
||||
ck::utils::get_absolute_threshold<InDataType, InDataType, InDataType>(
|
||||
max_accumulated_value, split_k_for_run);
|
||||
// Use higher threshold
|
||||
rtol = std::max(rtol, rtol_split_k);
|
||||
atol = std::max(atol, atol_split_k);
|
||||
|
||||
pass = pass & ck::utils::check_err(
|
||||
in_device, in_host, "Error: Incorrect results!", rtol, atol);
|
||||
std::cout << "Relative error threshold: " << rtol
|
||||
<< " Absolute error threshold: " << atol << std::endl;
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
@@ -225,35 +256,47 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
copy(conv_param.input_left_pads_, input_left_pads);
|
||||
copy(conv_param.input_right_pads_, input_right_pads);
|
||||
|
||||
std::vector<ck::index_t> split_k_list = {1, 2, 4, 8, 16, 32, 64, 128};
|
||||
|
||||
if(split_k > 0)
|
||||
{
|
||||
split_k_list = {split_k};
|
||||
}
|
||||
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
auto argument_ptr =
|
||||
op_ptr->MakeArgumentPointer(static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
{},
|
||||
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
out_element_op,
|
||||
wei_element_op,
|
||||
in_element_op);
|
||||
for(std::size_t split_k_id = 0; split_k_id < split_k_list.size(); split_k_id++)
|
||||
{
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
{},
|
||||
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
out_element_op,
|
||||
wei_element_op,
|
||||
in_element_op,
|
||||
split_k_list[split_k_id]);
|
||||
|
||||
run_impl(op_ptr, argument_ptr);
|
||||
run_impl(op_ptr, argument_ptr, split_k_list[split_k_id]);
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Best configuration parameters:"
|
||||
<< "\nname: " << best_op_name << "\navg_time: " << best_avg_time
|
||||
<< "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
|
||||
<< "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << ", SplitK "
|
||||
<< best_split_k << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user