Fix splitk ab scale

This commit is contained in:
Enrico Degregori
2025-12-15 08:19:21 +00:00
parent e1694a9547
commit b5ccc070a8

View File

@@ -728,7 +728,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
EpilogueArgument& epilogue_args,
const index_t k_id = 0)
const index_t A_k_id = 0,
const index_t B_k_id = 0)
{
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
@@ -798,7 +799,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
a_scale_struct,
b_scale_struct,
epilogue_args,
k_id);
A_k_id,
B_k_id);
}
// NOTE: Wrapper function to have __global__ function in common
@@ -811,7 +813,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
const SplitKBatchOffset& splitk_batch_offset,
Argument& karg,
EpilogueArgument& epilogue_args,
const index_t k_id = 0)
const index_t A_k_id = 0,
const index_t B_k_id = 0)
{
// shift A matrices pointer for splitk
AsGridPointer p_as_grid_splitk;
@@ -862,7 +865,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
karg.b_element_op,
karg.cde_element_op,
epilogue_args,
k_id);
A_k_id,
B_k_id);
}
};