diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp index ac5b7dd0c4..528a242344 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp @@ -728,7 +728,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, EpilogueArgument& epilogue_args, - const index_t k_id = 0) + const index_t A_k_id = 0, + const index_t B_k_id = 0) { const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0); @@ -798,7 +799,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale a_scale_struct, b_scale_struct, epilogue_args, - k_id); + A_k_id, + B_k_id); } // NOTE: Wrapper function to have __global__ function in common @@ -811,7 +813,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale const SplitKBatchOffset& splitk_batch_offset, Argument& karg, EpilogueArgument& epilogue_args, - const index_t k_id = 0) + const index_t A_k_id = 0, + const index_t B_k_id = 0) { // shift A matrices pointer for splitk AsGridPointer p_as_grid_splitk; @@ -862,7 +865,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale karg.b_element_op, karg.cde_element_op, epilogue_args, - k_id); + A_k_id, + B_k_id); } };