From 8f3b534d299d658ca169e42185b4b1004e153b91 Mon Sep 17 00:00:00 2001 From: mhYang Date: Fri, 21 Mar 2025 15:00:05 +0000 Subject: [PATCH] Fix xor transform dim. --- ...peline_agmem_bgmem_creg_default_policy.hpp | 24 +++++++------------ 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp index b9784c901f..9a9fa5a435 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp @@ -27,7 +27,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy constexpr index_t kKPack = 8; #if BANK_CONFLICT_K_FIRST -#pragma message ("BANK_CONFLICT: K_FIRST") constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), make_tuple(number{}, number{}, number<1>{}), @@ -42,7 +41,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy make_tuple(sequence<0>{}, sequence<1>{})); #elif PADDING_K_FIRST -#pragma message ("BANK_CONFLICT: PADDING_K_FIRST") constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number{}, number<1>{}), @@ -57,7 +55,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy make_tuple(sequence<0>{}, sequence<1>{})); #elif PADDING_MN_FIRST -#pragma message ("BANK_CONFLICT: PADDING_MN_FIRST") constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number{}, number<1>{}), @@ -72,7 +69,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy make_tuple(sequence<0>{}, sequence<1>{})); #elif XOR -#pragma message ("BANK_CONFLICT: XOR") using ADataType = remove_cvref_t; constexpr auto DataTypeSize = sizeof(ADataType); @@ -98,9 +94,9 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( a_lds_block_desc_permuted, make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); @@ -110,7 +106,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy make_tuple(number{}, number{})), make_merge_transform_v3_division_mod( make_tuple(number{}, number{}))), - make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); #endif @@ -126,7 +122,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy constexpr index_t kKPack = 8; #if BANK_CONFLICT_K_FIRST -#pragma message ("BANK_CONFLICT: K_FIRST") constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), make_tuple(number{}, number{}, number<1>{}), @@ -141,7 +136,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy make_tuple(sequence<0>{}, sequence<1>{})); #elif PADDING_K_FIRST -#pragma message ("BANK_CONFLICT: PADDING_K_FIRST") constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number{}, number<1>{}), @@ -156,7 +150,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy make_tuple(sequence<0>{}, sequence<1>{})); #elif PADDING_MN_FIRST -#pragma message ("BANK_CONFLICT: PADDING_MN_FIRST") constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number{}, number<1>{}), @@ -171,7 +164,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy make_tuple(sequence<0>{}, sequence<1>{})); #elif XOR -#pragma message ("BANK_CONFLICT: XOR") using BDataType = remove_cvref_t; constexpr auto DataTypeSize = sizeof(BDataType); @@ -197,9 +189,9 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( b_lds_block_desc_permuted, make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); @@ -209,7 +201,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy make_tuple(number{}, number{})), make_merge_transform_v3_division_mod( make_tuple(number{}, number{}))), - make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); #endif