From f728087c61604b0c76dddf557c8e0bc0da97eebb Mon Sep 17 00:00:00 2001 From: mtgu0705 Date: Wed, 25 Dec 2024 23:26:17 +0800 Subject: [PATCH] Modify the a_thread offset since the A data load is different from B. --- .../grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp index 356113733b..a806003297 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp @@ -1368,8 +1368,10 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 make_tuple(Number{}, Number{})); constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); - auto a_thread_offset = - get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) % MWaves * MPerXdl; + // auto a_thread_offset = + // get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) % MWaves * MPerXdl; + + auto a_thread_offset = get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 128) * MPerXdl; auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2