add transform grid

This commit is contained in:
Jakub Piasecki
2026-01-28 16:35:06 +00:00
parent eb3eacebce
commit a9fcb27ded
2 changed files with 95 additions and 37 deletions

View File

@@ -118,43 +118,101 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
[[maybe_unused]] const auto num_k_per_block =
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_.GetLength(Number<0>{}) / KBatch;
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
if constexpr(GridwiseGemm::DirectLoadEnabled)
{
GridwiseGemm::template Run<true, EGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + a_batch_offset + a_n_offset,
karg.p_b_grid + b_batch_offset,
p_ds_grid_grp,
karg.p_c_grid + e_batch_offset + e_n_offset,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op,
gemm_kernel_args[group_id].block_2_ctile_map_,
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].ds_grid_desc_m_n_,
gemm_kernel_args[group_id].e_grid_desc_m_n_,
KBatch,
k_idx);
} else {
GridwiseGemm::template Run<false, EGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + a_batch_offset + a_n_offset,
karg.p_b_grid + b_batch_offset,
p_ds_grid_grp,
karg.p_c_grid + e_batch_offset + e_n_offset,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op,
gemm_kernel_args[group_id].block_2_ctile_map_,
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].ds_grid_desc_m_n_,
gemm_kernel_args[group_id].e_grid_desc_m_n_,
KBatch,
k_idx);
#if defined(__gfx950__)
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
{
GridwiseGemm::template Run<true, EGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + a_batch_offset + a_n_offset,
karg.p_b_grid + b_batch_offset,
p_ds_grid_grp,
karg.p_c_grid + e_batch_offset + e_n_offset,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op,
gemm_kernel_args[group_id].block_2_ctile_map_,
GridwiseGemm::template TransformGrid<decltype(gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_),
GridwiseGemm::AK0Number,
GridwiseGemm::AK1Number>(
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_),
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].ds_grid_desc_m_n_,
gemm_kernel_args[group_id].e_grid_desc_m_n_,
KBatch,
k_idx);
} else {
GridwiseGemm::template Run<false, EGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + a_batch_offset + a_n_offset,
karg.p_b_grid + b_batch_offset,
p_ds_grid_grp,
karg.p_c_grid + e_batch_offset + e_n_offset,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op,
gemm_kernel_args[group_id].block_2_ctile_map_,
GridwiseGemm::template TransformGrid<decltype(gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_),
GridwiseGemm::AK0Number,
GridwiseGemm::AK1Number>(
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_),
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].ds_grid_desc_m_n_,
gemm_kernel_args[group_id].e_grid_desc_m_n_,
KBatch,
k_idx);
}
#endif
}
else
{
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
{
GridwiseGemm::template Run<true, EGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + a_batch_offset + a_n_offset,
karg.p_b_grid + b_batch_offset,
p_ds_grid_grp,
karg.p_c_grid + e_batch_offset + e_n_offset,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op,
gemm_kernel_args[group_id].block_2_ctile_map_,
GridwiseGemm::template TransformGrid<decltype(gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_),
GridwiseGemm::AK0Number,
GridwiseGemm::AK1Number>(
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_),
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].ds_grid_desc_m_n_,
gemm_kernel_args[group_id].e_grid_desc_m_n_,
KBatch,
k_idx);
} else {
GridwiseGemm::template Run<false, EGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + a_batch_offset + a_n_offset,
karg.p_b_grid + b_batch_offset,
p_ds_grid_grp,
karg.p_c_grid + e_batch_offset + e_n_offset,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op,
gemm_kernel_args[group_id].block_2_ctile_map_,
GridwiseGemm::template TransformGrid<decltype(gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_),
GridwiseGemm::AK0Number,
GridwiseGemm::AK1Number>(
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_),
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].ds_grid_desc_m_n_,
gemm_kernel_args[group_id].e_grid_desc_m_n_,
KBatch,
k_idx);
}
}
#else
ignore = karg;

View File

@@ -272,7 +272,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
}
template <typename GridDesc_K0_MN_K1_T, index_t K0Number, index_t K1Value>
__host__ __device__ static auto TransformGrid(GridDesc_K0_MN_K1_T& desc)
__host__ __device__ static auto TransformGrid(const GridDesc_K0_MN_K1_T& desc)
{
if constexpr(!DirectLoad)