diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index f5b3523f60..c504a51ad0 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -362,7 +362,7 @@ struct UniversalGemmPipelineAgBgCrPolicy constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( a_lds_block_desc_permuted, make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), + make_tuple(number{}, number{})), make_pass_through_transform(number{}), make_pass_through_transform(number{})), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), @@ -374,7 +374,7 @@ struct UniversalGemmPipelineAgBgCrPolicy make_tuple(number{}, number{})), make_merge_transform_v3_division_mod( make_tuple(number{}, number{}))), - make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); return a_lds_block_desc; @@ -421,7 +421,7 @@ struct UniversalGemmPipelineAgBgCrPolicy constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(BK0, number{})), + make_tuple(make_unmerge_transform(make_tuple(number{}, BK0)), make_pass_through_transform(number{}), make_pass_through_transform(number{})), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), @@ -432,7 +432,7 @@ struct UniversalGemmPipelineAgBgCrPolicy make_tuple(make_merge_transform_v3_division_mod( make_tuple(number{}, number{})), make_merge_transform_v3_division_mod(make_tuple(BK0, number{}))), - make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); return b_lds_block_desc; }