mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Automatic deduction of split-K value for grouped convolution (#2491)
* Split-K autodeduction for DeviceGroupedConvBwdWeight_Xdl_CShuffle and DeviceGroupedConvBwdWeight_Xdl_CShuffleV3. * Split-K autodeduction for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle. * Use simple best occupancy model to calculate the split-K. * Handle split-K autodeduction in explicit gemm conv. * Add unit tests for split-K autodeduction. * Remove oversubscription. * Small fixes. * Added split-K autodeduction for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle. * Run clang formatting. * Fix error handling in the conv profiler. * Add missing documentation for the autodeducted split-K values. * Add split-K autodeduction to DeviceGroupedConvBwdWeight_Explicit_Xdl solver. * Fix clang formatting and split-K profiler documentation. * Rename max_occupancy value variable. * Calculate grid size for split-K autodeduction directly from input array shapes and template params. --------- Co-authored-by: Ville Pietilä <>
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp"
|
||||
@@ -40,7 +41,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
ck::index_t split_k)
|
||||
const std::string& split_k)
|
||||
{
|
||||
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
@@ -138,10 +139,10 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
std::string best_op_name;
|
||||
float best_avg_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
ck::index_t best_split_k = 1;
|
||||
float best_avg_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
std::string best_split_k("1");
|
||||
|
||||
// profile device Conv instances
|
||||
bool all_pass = true;
|
||||
@@ -170,11 +171,20 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
range_copy(conv_param.input_left_pads_, begin(input_left_pads));
|
||||
range_copy(conv_param.input_right_pads_, begin(input_right_pads));
|
||||
|
||||
std::vector<ck::index_t> split_k_list = {1, 2, 4, 8, 16, 32, 64, 128};
|
||||
std::vector<ck::index_t> split_k_list = {/*auto deduce value*/ -1, 1, 2, 4, 8, 16, 32, 64, 128};
|
||||
|
||||
if(split_k > 0)
|
||||
if(split_k != "all")
|
||||
{
|
||||
split_k_list = {split_k};
|
||||
try
|
||||
{
|
||||
ck::index_t split_k_value = std::stoi(split_k);
|
||||
split_k_list = {split_k_value};
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << e.what() << '\n';
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
@@ -200,6 +210,16 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
out_element_op,
|
||||
split_k_list[split_k_id]);
|
||||
|
||||
auto split_k_value = split_k_list[split_k_id];
|
||||
auto split_k_param_str = std::to_string(split_k_value);
|
||||
auto* split_k_arg =
|
||||
dynamic_cast<ck::tensor_operation::device::ArgumentSplitK*>(argument_ptr.get());
|
||||
if(split_k_arg && split_k_value < 0)
|
||||
{
|
||||
split_k_value = split_k_arg->k_batch_;
|
||||
split_k_param_str = std::to_string(split_k_value) + " (best occupancy)";
|
||||
}
|
||||
|
||||
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
|
||||
DeviceMem workspace_dev(workspace_sz);
|
||||
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
|
||||
@@ -222,7 +242,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops
|
||||
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", SplitK "
|
||||
<< split_k_list[split_k_id] << std::endl;
|
||||
<< split_k_param_str << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
@@ -230,7 +250,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
best_tflops = tflops;
|
||||
best_avg_time = avg_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
best_split_k = split_k_list[split_k_id];
|
||||
best_split_k = split_k_param_str;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
@@ -244,7 +264,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
using AccDataType =
|
||||
std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
|
||||
const index_t num_accums = output.GetElementSize() / conv_param.K_;
|
||||
const index_t num_accums_split_k = split_k_list[split_k_id];
|
||||
const index_t num_accums_split_k = split_k_value;
|
||||
// Calculate thresholds
|
||||
auto rtol =
|
||||
ck::utils::get_relative_threshold<ComputeType, WeiDataType, AccDataType>(
|
||||
|
||||
Reference in New Issue
Block a user