[CK_TILE] Add permuteN optimization to remove lds operation in c_shuffle (#2764)

* PermuteN optimization to remove the LDS (local data share) round-trip in c_shuffle

* add the change log

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
lalala-sh
2025-09-09 13:02:48 +08:00
committed by GitHub
parent 92b07380d3
commit 75570d0fa8
5 changed files with 189 additions and 4 deletions

View File

@@ -31,7 +31,8 @@ template <typename ADataType_,
memory_operation_enum MemoryOperation_,
index_t kNumWaveGroups_ = 1,
bool FixedVectorSize_ = false,
index_t VectorSizeC_ = 1>
index_t VectorSizeC_ = 1,
bool TiledMMAPermuteN_ = false>
struct CShuffleEpilogueProblem
{
using ADataType = remove_cvref_t<ADataType_>;
@@ -54,6 +55,7 @@ struct CShuffleEpilogueProblem
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
static constexpr bool FixedVectorSize = FixedVectorSize_;
static constexpr index_t VectorSizeC = VectorSizeC_;
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
static constexpr index_t NumDTensor = DsDataType::size();
@@ -89,10 +91,13 @@ struct CShuffleEpilogue
static constexpr index_t KPerXdl = Problem::KPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t MPerIteration = MPerXdl * MWave;
static constexpr index_t NPerIteration = NPerXdl * NWave;
static constexpr index_t NumDTensor = Problem::NumDTensor;
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");
@@ -367,11 +372,152 @@ struct CShuffleEpilogue
// Tag type meaning "no scale supplied": used as the default for the
// ScaleM/ScaleN epilogue parameters and detected below via
// std::is_same to disable the scale path at compile time.
struct EmptyScale
{
};
// Epilogue operator() — TiledMMAPermuteN fast path. Converts (and optionally
// scales) the accumulator tile and writes it straight from registers to the
// output DRAM window; p_smem is unused, i.e. this overload skips the LDS
// shuffle used by the generic overload below.
// NOTE(review): this span is a rendered diff, not plain source — the first
// "typename ScaleM/ScaleN ... >" pair below is the pre-change template
// parameter list and is duplicated by the replacement lines that follow;
// only one of the two variants belongs in the real file.
// NOTE(review): "EnablePermuateN_" looks like a typo for "EnablePermuteN_";
// if renaming, rename the matching parameter of the other overload too.
template <typename ODramWindow,
typename OAccTile,
typename DsDramWindows,
typename ScaleM = EmptyScale,
typename ScaleN = EmptyScale>
typename ScaleM = EmptyScale,
typename ScaleN = EmptyScale,
int EnablePermuateN_ = TiledMMAPermuteN,
// SFINAE: this overload participates only when TiledMMAPermuteN is true.
std::enable_if_t<EnablePermuateN_, int> = 0>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* /*p_smem*/, // unused: permuted path needs no shared memory
const ScaleM& scale_m = {},
const ScaleN& scale_n = {})
{
// Per-thread shape factors of the permuted distribution.
// kM2 = 4 rows packed per lane (matches the four emit() calls below);
// assumes MPerXdl is divisible by 4 — TODO confirm.
constexpr int kM0 = MWave;
constexpr int kM2 = 4;
constexpr int kM1 = MPerXdl / kM2;
constexpr int kN0 = NWave;
constexpr int kN1 = NPerXdl;
constexpr int kN2 = NRepeat;
// Thread-distribution encoding that realizes the N-permute directly in the
// register-to-DRAM mapping, removing the need for an LDS shuffle.
using IntrThreadShuffleEncode =
tile_distribution_encoding<sequence<>,
tuple<sequence<kM0, kM1, kM2>, sequence<kN0, kN1, kN2>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1, 2>,
sequence<2, 2>>;
constexpr auto dram_tile_distribution =
make_static_tile_distribution(IntrThreadShuffleEncode{});
// D-tensor windows share the output distribution so per-thread element
// order lines up.
// NOTE(review): d_dram_windows are created and advanced but never loaded
// or fused into c_out_tensor in this overload — D-tensor application
// appears unimplemented on the permuted path; confirm intended.
auto d_dram_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
},
number<NumDTensor>{});
// Y-dimension lengths / zero-index of one C warp tile, used to slice the
// accumulator per M repeat.
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// Staging tensors: raw accumulator slice and converted output, both on the
// permuted distribution.
auto shuffle_acc = make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
// Optional scales (must share the same distribution to match per-thread indexing)
constexpr bool has_scales =
!std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;
// Tiles to hold row/col scales when present (float placeholders otherwise)
using SMType =
std::conditional_t<has_scales, remove_cvref_t<typename ScaleM::DataType>, float>;
using SNType =
std::conditional_t<has_scales, remove_cvref_t<typename ScaleN::DataType>, float>;
auto sm_tile = make_static_distributed_tensor<SMType>(dram_tile_distribution);
auto sn_tile = make_static_distributed_tensor<SNType>(dram_tile_distribution);
// Build windows only if scales are provided
auto scale_m_window = [&]() {
if constexpr(has_scales)
{
return make_tile_window(scale_m, dram_tile_distribution);
}
else
{
return EmptyScale{};
}
}();
auto scale_n_window = [&]() {
if constexpr(has_scales)
{
return make_tile_window(scale_n, dram_tile_distribution);
}
else
{
return EmptyScale{};
}
}();
// Main loop: one iteration per M repeat. Each iteration slices, converts
// (optionally scaling) and stores one MPerXdl*MWave chunk of output.
static_for<0, MRepeat, 1>{}([&](auto mIter) {
// Slice accumulators for this M repeat into the permuted layout
shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
// If scales provided, load them with identical distribution
// NOTE(review): loaded inside the MRepeat loop but the windows are
// never advanced — same values are re-read each iteration; confirm.
if constexpr(has_scales)
{
sm_tile = load_tile(scale_m_window); // row scales in permuted layout
sn_tile = load_tile(scale_n_window); // col scales in permuted layout
}
// Pack the kM2 (= 4) rows this lane owns into the output tile.
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
// source indices in shuffle_acc: (n_idx * product(Y) + row)
const index_t base = n_idx * c_warp_y_lengths.product();
// local lambda to fuse scale (if present) and convert
auto emit = [&](index_t out_idx, index_t src_row) {
AccDataType v = shuffle_acc.get_thread_buffer()[base + src_row];
if constexpr(has_scales)
{
// same linear index mapping on the permuted distribution
const auto s_m = static_cast<float>(sm_tile.get_thread_buffer()[out_idx]);
const auto s_n = static_cast<float>(sn_tile.get_thread_buffer()[out_idx]);
v = static_cast<AccDataType>(v * s_m * s_n);
}
c_out_tensor.get_thread_buffer()[out_idx] = type_convert<ODataType>(v);
};
// Packing pattern: rows 0..3 of this lane land NRepeat apart in
// the output buffer.
emit(n_idx + 0 * NRepeat, 0);
emit(n_idx + 1 * NRepeat, 1);
emit(n_idx + 2 * NRepeat, 2);
emit(n_idx + 3 * NRepeat, 3);
});
// store (set) or read-modify-write (atomic/accumulate) per the
// compile-time memory operation
if constexpr(MemoryOperation == memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
}
else
{
update_tile(out_dram_window, c_out_tensor);
}
// advance output (and any D-tensors) by one MPerXdl*MWave chunk
move_tile_window(out_dram_window, {number<MPerXdl * MWave>{}, number<0>{}});
static_for<0, NumDTensor, 1>{}([&](auto idx) {
move_tile_window(d_dram_windows[idx], {number<MPerXdl * MWave>{}, number<0>{}});
});
});
}
template <typename ODramWindow,
typename OAccTile,
typename DsDramWindows,
typename ScaleM = EmptyScale,
typename ScaleN = EmptyScale,
int EnablePermuateN_ = TiledMMAPermuteN,
std::enable_if_t<!EnablePermuateN_, int> = 0>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,