mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
add transform grid
This commit is contained in:
@@ -118,43 +118,101 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
[[maybe_unused]] const auto num_k_per_block =
|
||||
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_.GetLength(Number<0>{}) / KBatch;
|
||||
|
||||
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
|
||||
if constexpr(GridwiseGemm::DirectLoadEnabled)
|
||||
{
|
||||
GridwiseGemm::template Run<true, EGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + a_batch_offset + a_n_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
p_ds_grid_grp,
|
||||
karg.p_c_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op,
|
||||
gemm_kernel_args[group_id].block_2_ctile_map_,
|
||||
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
|
||||
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
gemm_kernel_args[group_id].ds_grid_desc_m_n_,
|
||||
gemm_kernel_args[group_id].e_grid_desc_m_n_,
|
||||
KBatch,
|
||||
k_idx);
|
||||
} else {
|
||||
GridwiseGemm::template Run<false, EGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + a_batch_offset + a_n_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
p_ds_grid_grp,
|
||||
karg.p_c_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op,
|
||||
gemm_kernel_args[group_id].block_2_ctile_map_,
|
||||
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
|
||||
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
gemm_kernel_args[group_id].ds_grid_desc_m_n_,
|
||||
gemm_kernel_args[group_id].e_grid_desc_m_n_,
|
||||
KBatch,
|
||||
k_idx);
|
||||
#if defined(__gfx950__)
|
||||
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
|
||||
{
|
||||
GridwiseGemm::template Run<true, EGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + a_batch_offset + a_n_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
p_ds_grid_grp,
|
||||
karg.p_c_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op,
|
||||
gemm_kernel_args[group_id].block_2_ctile_map_,
|
||||
GridwiseGemm::template TransformGrid<decltype(gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_),
|
||||
GridwiseGemm::AK0Number,
|
||||
GridwiseGemm::AK1Number>(
|
||||
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_),
|
||||
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
gemm_kernel_args[group_id].ds_grid_desc_m_n_,
|
||||
gemm_kernel_args[group_id].e_grid_desc_m_n_,
|
||||
KBatch,
|
||||
k_idx);
|
||||
} else {
|
||||
GridwiseGemm::template Run<false, EGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + a_batch_offset + a_n_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
p_ds_grid_grp,
|
||||
karg.p_c_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op,
|
||||
gemm_kernel_args[group_id].block_2_ctile_map_,
|
||||
GridwiseGemm::template TransformGrid<decltype(gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_),
|
||||
GridwiseGemm::AK0Number,
|
||||
GridwiseGemm::AK1Number>(
|
||||
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_),
|
||||
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
gemm_kernel_args[group_id].ds_grid_desc_m_n_,
|
||||
gemm_kernel_args[group_id].e_grid_desc_m_n_,
|
||||
KBatch,
|
||||
k_idx);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
|
||||
{
|
||||
GridwiseGemm::template Run<true, EGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + a_batch_offset + a_n_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
p_ds_grid_grp,
|
||||
karg.p_c_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op,
|
||||
gemm_kernel_args[group_id].block_2_ctile_map_,
|
||||
GridwiseGemm::template TransformGrid<decltype(gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_),
|
||||
GridwiseGemm::AK0Number,
|
||||
GridwiseGemm::AK1Number>(
|
||||
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_),
|
||||
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
gemm_kernel_args[group_id].ds_grid_desc_m_n_,
|
||||
gemm_kernel_args[group_id].e_grid_desc_m_n_,
|
||||
KBatch,
|
||||
k_idx);
|
||||
} else {
|
||||
GridwiseGemm::template Run<false, EGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + a_batch_offset + a_n_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
p_ds_grid_grp,
|
||||
karg.p_c_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op,
|
||||
gemm_kernel_args[group_id].block_2_ctile_map_,
|
||||
GridwiseGemm::template TransformGrid<decltype(gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_),
|
||||
GridwiseGemm::AK0Number,
|
||||
GridwiseGemm::AK1Number>(
|
||||
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_),
|
||||
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
gemm_kernel_args[group_id].ds_grid_desc_m_n_,
|
||||
gemm_kernel_args[group_id].e_grid_desc_m_n_,
|
||||
KBatch,
|
||||
k_idx);
|
||||
}
|
||||
}
|
||||
#else
|
||||
ignore = karg;
|
||||
|
||||
@@ -272,7 +272,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
}
|
||||
|
||||
template <typename GridDesc_K0_MN_K1_T, index_t K0Number, index_t K1Value>
|
||||
__host__ __device__ static auto TransformGrid(GridDesc_K0_MN_K1_T& desc)
|
||||
__host__ __device__ static auto TransformGrid(const GridDesc_K0_MN_K1_T& desc)
|
||||
{
|
||||
|
||||
if constexpr(!DirectLoad)
|
||||
|
||||
Reference in New Issue
Block a user