mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
[CK Tile] CShuffle Tile Permute N all warp compatible (#2966)
* solve the hard_code issue of kM2 * clang format
This commit is contained in:
@@ -433,8 +433,13 @@ struct CShuffleEpilogue
|
||||
const ScaleM& scale_m = {},
|
||||
const ScaleN& scale_n = {})
|
||||
{
|
||||
static constexpr int RowsPerLane = CWarpTensor::get_thread_buffer_size();
|
||||
|
||||
static_assert(MPerXdl % RowsPerLane == 0,
|
||||
"CShuffle (permuteN): MPerXdl must be divisible by per-lane row count.");
|
||||
|
||||
constexpr int kM0 = MWave;
|
||||
constexpr int kM2 = 4;
|
||||
constexpr int kM2 = RowsPerLane;
|
||||
constexpr int kM1 = MPerXdl / kM2;
|
||||
|
||||
constexpr int kN0 = NWave;
|
||||
@@ -515,32 +520,25 @@ struct CShuffleEpilogue
|
||||
// Pack 4 “rows per lane” as you already do
|
||||
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
|
||||
// source indices in shuffle_acc: (n_idx * product(Y) + row)
|
||||
const index_t base = n_idx * c_warp_y_lengths.product();
|
||||
const index_t plane = c_warp_y_lengths.product();
|
||||
|
||||
// local lambda to fuse scale (if present) and convert
|
||||
auto emit = [&](index_t out_idx, index_t src_row) {
|
||||
AccDataType v = shuffle_acc.get_thread_buffer()[base + src_row];
|
||||
|
||||
static_for<0, kM2, 1>{}([&](auto m_lane) {
|
||||
const int src = n_idx * plane + m_lane; // source row in this N-plane
|
||||
const int dst = n_idx + m_lane * NRepeat; // permuted N layout in output
|
||||
AccDataType v = shuffle_acc.get_thread_buffer()[src];
|
||||
if constexpr(has_scalar_scales)
|
||||
{
|
||||
v = static_cast<AccDataType>(v * scale_m * scale_n);
|
||||
}
|
||||
else if constexpr(has_scales && !has_scalar_scales)
|
||||
{
|
||||
// same linear index mapping on the permuted distribution
|
||||
const auto s_m = static_cast<float>(sm_tile.get_thread_buffer()[out_idx]);
|
||||
const auto s_n = static_cast<float>(sn_tile.get_thread_buffer()[out_idx]);
|
||||
v = static_cast<AccDataType>(v * s_m * s_n);
|
||||
const auto sm = static_cast<float>(sm_tile.get_thread_buffer()[dst]);
|
||||
const auto sn = static_cast<float>(sn_tile.get_thread_buffer()[dst]);
|
||||
v = static_cast<AccDataType>(v * sm * sn);
|
||||
}
|
||||
|
||||
c_out_tensor.get_thread_buffer()[out_idx] = type_convert<ODataType>(v);
|
||||
};
|
||||
|
||||
// Your current packing pattern (rows 0..3, spaced by NRepeat)
|
||||
emit(n_idx + 0 * NRepeat, 0);
|
||||
emit(n_idx + 1 * NRepeat, 1);
|
||||
emit(n_idx + 2 * NRepeat, 2);
|
||||
emit(n_idx + 3 * NRepeat, 3);
|
||||
c_out_tensor.get_thread_buffer()[dst] = type_convert<ODataType>(v);
|
||||
});
|
||||
});
|
||||
|
||||
// store/update
|
||||
|
||||
Reference in New Issue
Block a user