mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#5114 (commit 59b8cb5)
[CK][CK Tile] Improvements for grouped conv fwd tile profiling (#5114) ## Motivation Improve profiling for grouped convolution forward for better comparison between CK and CK Tile ## Technical Details - Include preprocessing time for ck tile - Add flush cache for conv fwd profiler - Switch configs to builder reflect - Add KPerXdl deduce - Add non-grouped ported instances ## Test Plan test_grouped_convnd_fwd_tile ## Test Result pass ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. AICK-786
This commit is contained in:
committed by
assistant-librarian[bot]
parent
c1f2d8166d
commit
2169367735
@@ -1158,26 +1158,52 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
isMultiB,
|
||||
CTranspose>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_as_grid_,
|
||||
arg.p_bs_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.p_e_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
as_grid_desc_ak0_m_ak1,
|
||||
bs_grid_desc_bk0_n_bk1,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
return launch_and_time_kernel_flush_cache(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_as_grid_,
|
||||
arg.p_bs_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.p_e_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
as_grid_desc_ak0_m_ak1,
|
||||
bs_grid_desc_bk0_n_bk1,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
}
|
||||
else
|
||||
{
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_as_grid_,
|
||||
arg.p_bs_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.p_e_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
as_grid_desc_ak0_m_ak1,
|
||||
bs_grid_desc_bk0_n_bk1,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1230,26 +1256,53 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
isMultiA,
|
||||
isMultiB,
|
||||
CTranspose>;
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_b_grid,
|
||||
p_a_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
arg.b_element_op_,
|
||||
arg.a_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
return launch_and_time_kernel_flush_cache(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_b_grid,
|
||||
p_a_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
arg.b_element_op_,
|
||||
arg.a_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
}
|
||||
else
|
||||
{
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_b_grid,
|
||||
p_a_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
arg.b_element_op_,
|
||||
arg.a_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1274,26 +1327,52 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
isMultiB,
|
||||
CTranspose>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
return launch_and_time_kernel_flush_cache(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
}
|
||||
else
|
||||
{
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -26,7 +26,6 @@
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/conv_describe.hpp"
|
||||
@@ -1049,35 +1048,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
typename GridwiseGemm::Argument gemm_arg_ = gemm_arg;
|
||||
ck::utility::RotatingMemWrapper<typename GridwiseGemm::Argument> rotating_mem(
|
||||
gemm_arg_,
|
||||
stream_config.rotating_count,
|
||||
gemm_arg_.M * gemm_arg_.K * sizeof(ADataType),
|
||||
gemm_arg_.K * gemm_arg_.N * sizeof(BDataType));
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
};
|
||||
|
||||
ave_time += ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gemm_arg_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
ave_time +=
|
||||
launch_and_time_kernel_flush_cache(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gemm_arg,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -759,19 +759,36 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
CDEElementwiseOperation,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
has_main_loop>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.gemm_desc_kernel_args_,
|
||||
arg.gemms_count_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
return launch_and_time_kernel_flush_cache(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.gemm_desc_kernel_args_,
|
||||
arg.gemms_count_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
}
|
||||
else
|
||||
{
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.gemm_desc_kernel_args_,
|
||||
arg.gemms_count_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
}
|
||||
};
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
|
||||
Reference in New Issue
Block a user