[CK Tile] CShuffle Tile Permute N all warp compatible (#2966)

* Fix the hard-coded kM2: derive it from the warp tensor's per-lane row count instead of assuming 4

* clang-format
Author: Thomas Ning (committed via GitHub)
Date: 2025-10-03 09:46:13 -07:00
Parent: 4c98535456
Commit: b4a4aa2b64

@@ -433,8 +433,13 @@ struct CShuffleEpilogue
                     const ScaleM& scale_m = {},
                     const ScaleN& scale_n = {})
     {
+        static constexpr int RowsPerLane = CWarpTensor::get_thread_buffer_size();
+        static_assert(MPerXdl % RowsPerLane == 0,
+                      "CShuffle (permuteN): MPerXdl must be divisible by per-lane row count.");
         constexpr int kM0 = MWave;
-        constexpr int kM2 = 4;
+        constexpr int kM2 = RowsPerLane;
         constexpr int kM1 = MPerXdl / kM2;
         constexpr int kN0 = NWave;
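
The hunk above derives kM2 from the warp tensor's per-lane row count and guards the division with a static_assert. A minimal standalone sketch of the same arithmetic (illustrative only, not CK Tile code; the sample MPerXdl and candidate row counts are assumptions):

// Standalone sketch, not part of the patch: shows why kM2 has to come from the
// warp tensor rather than the literal 4. MPerXdl and the candidate per-lane
// row counts below are illustrative assumptions.
#include <cassert>
#include <cstdio>
#include <initializer_list>

int main()
{
    const int m_per_xdl = 32; // stands in for MPerXdl

    for(int rows_per_lane : {4, 8, 16}) // stands in for get_thread_buffer_size()
    {
        if(m_per_xdl % rows_per_lane != 0)
        {
            // mirrors the new static_assert: incompatible shapes are rejected
            std::printf("RowsPerLane=%d -> rejected\n", rows_per_lane);
            continue;
        }
        const int kM2 = rows_per_lane;   // was: constexpr int kM2 = 4;
        const int kM1 = m_per_xdl / kM2; // stays consistent for any row count
        assert(kM1 * kM2 == m_per_xdl);
        std::printf("RowsPerLane=%d -> kM1=%d kM2=%d\n", rows_per_lane, kM1, kM2);
    }
    return 0;
}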
@@ -515,32 +520,25 @@ struct CShuffleEpilogue
-        // Pack 4 “rows per lane” as you already do
         static_for<0, NRepeat, 1>{}([&](auto n_idx) {
             // source indices in shuffle_acc: (n_idx * product(Y) + row)
-            const index_t base = n_idx * c_warp_y_lengths.product();
+            const index_t plane = c_warp_y_lengths.product();
-            // local lambda to fuse scale (if present) and convert
-            auto emit = [&](index_t out_idx, index_t src_row) {
-                AccDataType v = shuffle_acc.get_thread_buffer()[base + src_row];
+            static_for<0, kM2, 1>{}([&](auto m_lane) {
+                const int src = n_idx * plane + m_lane;   // source row in this N-plane
+                const int dst = n_idx + m_lane * NRepeat; // permuted N layout in output
+                AccDataType v = shuffle_acc.get_thread_buffer()[src];
                 if constexpr(has_scalar_scales)
                 {
                     v = static_cast<AccDataType>(v * scale_m * scale_n);
                 }
                 else if constexpr(has_scales && !has_scalar_scales)
                 {
                     // same linear index mapping on the permuted distribution
-                    const auto s_m = static_cast<float>(sm_tile.get_thread_buffer()[out_idx]);
-                    const auto s_n = static_cast<float>(sn_tile.get_thread_buffer()[out_idx]);
-                    v = static_cast<AccDataType>(v * s_m * s_n);
+                    const auto sm = static_cast<float>(sm_tile.get_thread_buffer()[dst]);
+                    const auto sn = static_cast<float>(sn_tile.get_thread_buffer()[dst]);
+                    v = static_cast<AccDataType>(v * sm * sn);
                 }
-                c_out_tensor.get_thread_buffer()[out_idx] = type_convert<ODataType>(v);
-            };
-            // Your current packing pattern (rows 0..3, spaced by NRepeat)
-            emit(n_idx + 0 * NRepeat, 0);
-            emit(n_idx + 1 * NRepeat, 1);
-            emit(n_idx + 2 * NRepeat, 2);
-            emit(n_idx + 3 * NRepeat, 3);
+                c_out_tensor.get_thread_buffer()[dst] = type_convert<ODataType>(v);
+            });
         });
         // store/update
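
The second hunk folds the four unrolled emit() calls into a static_for over kM2; the dst = n_idx + m_lane * NRepeat index reproduces the old rows-spaced-by-NRepeat packing for any per-lane row count. A standalone sketch checking that this mapping is a permutation of the output buffer (illustrative; NRepeat is an assumed value, and the sketch assumes each N-plane holds exactly kM2 per-lane rows, as in the patch):

// Standalone sketch, not part of the patch: the generalized mapping touches
// every output slot exactly once, and with kM2 == 4 it lands on the same
// indices as the old emit(n_idx + r * NRepeat, r) calls.
#include <cassert>
#include <vector>

int main()
{
    const int NRepeat = 3; // illustrative; any positive value works
    const int kM2     = 4; // per-lane rows; 4 matches the old hard-coded path

    std::vector<int> writes(NRepeat * kM2, 0);
    for(int n_idx = 0; n_idx < NRepeat; ++n_idx)
    {
        for(int m_lane = 0; m_lane < kM2; ++m_lane)
        {
            const int src = n_idx * kM2 + m_lane;     // row within this N-plane
            const int dst = n_idx + m_lane * NRepeat; // permuted N layout
            assert(src >= 0 && src < NRepeat * kM2);
            ++writes[dst];
        }
    }
    for(int w : writes)
        assert(w == 1); // bijection: no slot skipped, none written twice
    return 0;
}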