From a3698dab8d73c9468e5821c2b239b003c64f709c Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Fri, 3 Oct 2025 17:11:09 +0000 Subject: [PATCH] Merge commit 'b4a4aa2b64a7a94ab04126545a3dc4f6d3eba847' into develop --- .../ops/epilogue/cshuffle_epilogue.hpp | 34 +++++++++---------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index e0a39a5aea..5918ec806b 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -433,8 +433,13 @@ struct CShuffleEpilogue const ScaleM& scale_m = {}, const ScaleN& scale_n = {}) { + static constexpr int RowsPerLane = CWarpTensor::get_thread_buffer_size(); + + static_assert(MPerXdl % RowsPerLane == 0, + "CShuffle (permuteN): MPerXdl must be divisible by per-lane row count."); + constexpr int kM0 = MWave; - constexpr int kM2 = 4; + constexpr int kM2 = RowsPerLane; constexpr int kM1 = MPerXdl / kM2; constexpr int kN0 = NWave; @@ -515,32 +520,25 @@ struct CShuffleEpilogue // Pack 4 “rows per lane” as you already do static_for<0, NRepeat, 1>{}([&](auto n_idx) { // source indices in shuffle_acc: (n_idx * product(Y) + row) - const index_t base = n_idx * c_warp_y_lengths.product(); + const index_t plane = c_warp_y_lengths.product(); // local lambda to fuse scale (if present) and convert - auto emit = [&](index_t out_idx, index_t src_row) { - AccDataType v = shuffle_acc.get_thread_buffer()[base + src_row]; - + static_for<0, kM2, 1>{}([&](auto m_lane) { + const int src = n_idx * plane + m_lane; // source row in this N-plane + const int dst = n_idx + m_lane * NRepeat; // permuted N layout in output + AccDataType v = shuffle_acc.get_thread_buffer()[src]; if constexpr(has_scalar_scales) { v = static_cast(v * scale_m * scale_n); } else if constexpr(has_scales && !has_scalar_scales) { - // same linear index mapping on the permuted distribution - const auto s_m = static_cast(sm_tile.get_thread_buffer()[out_idx]); - const auto s_n = static_cast(sn_tile.get_thread_buffer()[out_idx]); - v = static_cast(v * s_m * s_n); + const auto sm = static_cast(sm_tile.get_thread_buffer()[dst]); + const auto sn = static_cast(sn_tile.get_thread_buffer()[dst]); + v = static_cast(v * sm * sn); } - - c_out_tensor.get_thread_buffer()[out_idx] = type_convert(v); - }; - - // Your current packing pattern (rows 0..3, spaced by NRepeat) - emit(n_idx + 0 * NRepeat, 0); - emit(n_idx + 1 * NRepeat, 1); - emit(n_idx + 2 * NRepeat, 2); - emit(n_idx + 3 * NRepeat, 3); + c_out_tensor.get_thread_buffer()[dst] = type_convert(v); + }); }); // store/update