mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 20:21:23 +00:00
[CK_TILE] Split-K autodeduction (#3351)
* First version of split-K autodeduction. * Fix circular dependency and kernel construction. * Fix tolerance calculation for bwd weight example. * Simplify kernel construction. * Fix kernel launching bug for split-K autodeduce. * Add split-K autodeduction support for the two stage example. * Fix a corner case. * Fix clang-format. * Fix clang-format for inc files. * Add missing header. * Prevent too large split-K values. * Fix formatting. * Add unit tests for IsSupportedArgument in grouped bwd conv. * clang-format. * Fix merge conflicts. * Address feedback from code review. * clang-format * Fix new tests after merge. --------- Co-authored-by: Ville Pietilä <>
This commit is contained in:
@@ -17,8 +17,8 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
typename DsDataType = ck_tile::tuple<>,
|
||||
typename DsLayout = ck_tile::tuple<>,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
static InvokerResult grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
// Implicit GEMM Traits
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
@@ -105,9 +105,9 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
@@ -130,7 +130,7 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
}
|
||||
|
||||
auto preprocess = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
if(kargs.k_batch > 1)
|
||||
{
|
||||
ck_tile::hip_check_error(
|
||||
hipMemsetAsync(kargs.wei_ptr,
|
||||
@@ -140,10 +140,14 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
}
|
||||
};
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
const auto ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
const auto split_k = kargs.k_batch;
|
||||
|
||||
return InvokerResult{ave_time, split_k};
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
|
||||
@@ -17,8 +17,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
typename DsDataType = ck_tile::tuple<>,
|
||||
typename DsLayout = ck_tile::tuple<>,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
static InvokerResult grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using WorkspaceDataType = float;
|
||||
|
||||
@@ -118,9 +118,9 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
sizeof(WorkspaceDataType));
|
||||
ck_tile::GroupedConvBwdWeightHostArgs ws_args =
|
||||
ck_tile::GroupedConvBwdWeightHostArgs(args);
|
||||
auto c_ptr = ws_args.wei_ptr;
|
||||
ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
|
||||
auto kargs = Kernel::MakeKernelArgs(ws_args);
|
||||
auto c_ptr = ws_args.wei_ptr;
|
||||
ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
|
||||
const auto kargs = Kernel::MakeKernelArgs(ws_args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
@@ -184,7 +184,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
}
|
||||
|
||||
auto preprocess = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
if(kargs.k_batch > 1)
|
||||
ck_tile::hip_check_error(
|
||||
hipMemsetAsync(ws_args.wei_ptr,
|
||||
0,
|
||||
@@ -192,7 +192,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
s.stream_id_));
|
||||
};
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
const auto ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
|
||||
@@ -206,6 +206,10 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
ck_tile::make_tuple(shape[1], 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<WeiDataType*>(c_ptr)));
|
||||
|
||||
const auto split_k = kargs.k_batch;
|
||||
|
||||
return InvokerResult{ave_time, split_k};
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
|
||||
@@ -132,3 +132,9 @@ auto create_args(int argc, char* argv[])
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
struct InvokerResult
|
||||
{
|
||||
float ave_time;
|
||||
ck_tile::index_t split_k;
|
||||
};
|
||||
|
||||
@@ -14,22 +14,22 @@ template <ck_tile::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
InvokerResult invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
float ave_time = Invoker::template grouped_conv_bwd_weight<NDimSpatial,
|
||||
ConvConfig,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(
|
||||
auto res = Invoker::template grouped_conv_bwd_weight<NDimSpatial,
|
||||
ConvConfig,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
return ave_time;
|
||||
return res;
|
||||
}
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
@@ -132,16 +132,17 @@ int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_par
|
||||
std::cout << "weight: " << weight.mDesc << std::endl;
|
||||
std::cout << "output: " << output.mDesc << std::endl;
|
||||
|
||||
float ave_time = invoke_grouped_conv_bwd_weight<NDimSpatial,
|
||||
ConvConfig,
|
||||
Invoker,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(args, n_warmup, n_repeat);
|
||||
auto res = invoke_grouped_conv_bwd_weight<NDimSpatial,
|
||||
ConvConfig,
|
||||
Invoker,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(args, n_warmup, n_repeat);
|
||||
const float ave_time = res.ave_time;
|
||||
|
||||
weight_dev_buf.FromDevice(weight.data());
|
||||
|
||||
@@ -172,9 +173,11 @@ int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_par
|
||||
const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(weight_host_ref.mData.begin(), weight_host_ref.mData.end());
|
||||
|
||||
const ck_tile::index_t split_k = res.split_k;
|
||||
const auto rtol_atol =
|
||||
calculate_rtol_atol<InDataType, WeiDataType, AccDataType, OutDataType>(
|
||||
GemmK, kbatch, max_accumulated_value);
|
||||
GemmK, split_k, max_accumulated_value);
|
||||
pass = ck_tile::check_err(weight,
|
||||
weight_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
|
||||
Reference in New Issue
Block a user