From fe7e2b3de9b44a6f460a02a88948fc209dad585a Mon Sep 17 00:00:00 2001 From: solin Date: Tue, 22 Apr 2025 02:27:17 +0000 Subject: [PATCH] add xor to optimize bank conflict --- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 44 ++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 584a3240cc..f73b0e3cc6 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -23,7 +23,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t kKPack = GetSmemPackA(); - +#if 1 constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number{}, number<1>{}), @@ -36,7 +36,49 @@ struct UniversalFlatmmPipelineAgBgCrPolicy make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), make_tuple(sequence<1>{}, sequence<0, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); +#endif +/*xor*/ +#if 0 + using ADataType = remove_cvref_t; + constexpr auto DataTypeSize = sizeof(ADataType); + constexpr auto MLdsLayer = + (16 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (16 * 4 / kKPerBlock / DataTypeSize); + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple(make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform( + make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); +#endif return a_lds_block_desc; }