mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
[CK_TILE] Add permuteN optimization to remove lds operation in c_shuffle (#2764)
* permuteN optimization to remove lds operation in c_shuffle * add the change log --------- Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -31,7 +31,8 @@ template <typename ADataType_,
|
||||
memory_operation_enum MemoryOperation_,
|
||||
index_t kNumWaveGroups_ = 1,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeC_ = 1>
|
||||
index_t VectorSizeC_ = 1,
|
||||
bool TiledMMAPermuteN_ = false>
|
||||
struct CShuffleEpilogueProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
@@ -54,6 +55,7 @@ struct CShuffleEpilogueProblem
|
||||
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
|
||||
static constexpr bool FixedVectorSize = FixedVectorSize_;
|
||||
static constexpr index_t VectorSizeC = VectorSizeC_;
|
||||
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
|
||||
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
@@ -89,10 +91,13 @@ struct CShuffleEpilogue
|
||||
static constexpr index_t KPerXdl = Problem::KPerXdl;
|
||||
static constexpr index_t isCTransposed = Problem::isCTransposed;
|
||||
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
|
||||
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
|
||||
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
|
||||
static constexpr index_t MPerIteration = MPerXdl * MWave;
|
||||
static constexpr index_t NPerIteration = NPerXdl * NWave;
|
||||
static constexpr index_t NumDTensor = Problem::NumDTensor;
|
||||
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
|
||||
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
|
||||
|
||||
static_assert(NumDTensor == DsLayout::size(),
|
||||
"The size of DsDataType and DsLayout should be the same");
|
||||
@@ -367,11 +372,152 @@ struct CShuffleEpilogue
|
||||
struct EmptyScale
|
||||
{
|
||||
};
|
||||
|
||||
template <typename ODramWindow,
|
||||
typename OAccTile,
|
||||
typename DsDramWindows,
|
||||
typename ScaleM = EmptyScale,
|
||||
typename ScaleN = EmptyScale>
|
||||
typename ScaleM = EmptyScale,
|
||||
typename ScaleN = EmptyScale,
|
||||
int EnablePermuateN_ = TiledMMAPermuteN,
|
||||
std::enable_if_t<EnablePermuateN_, int> = 0>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* /*p_smem*/,
|
||||
const ScaleM& scale_m = {},
|
||||
const ScaleN& scale_n = {})
|
||||
{
|
||||
constexpr int kM0 = MWave;
|
||||
constexpr int kM2 = 4;
|
||||
constexpr int kM1 = MPerXdl / kM2;
|
||||
|
||||
constexpr int kN0 = NWave;
|
||||
constexpr int kN1 = NPerXdl;
|
||||
constexpr int kN2 = NRepeat;
|
||||
|
||||
using IntrThreadShuffleEncode =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kM0, kM1, kM2>, sequence<kN0, kN1, kN2>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 2>>;
|
||||
constexpr auto dram_tile_distribution =
|
||||
make_static_tile_distribution(IntrThreadShuffleEncode{});
|
||||
|
||||
auto d_dram_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
auto shuffle_acc = make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
|
||||
auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
|
||||
|
||||
// Optional scales (must share the same distribution to match per-thread indexing)
|
||||
constexpr bool has_scales =
|
||||
!std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;
|
||||
|
||||
// Tiles to hold row/col scales when present
|
||||
using SMType =
|
||||
std::conditional_t<has_scales, remove_cvref_t<typename ScaleM::DataType>, float>;
|
||||
using SNType =
|
||||
std::conditional_t<has_scales, remove_cvref_t<typename ScaleN::DataType>, float>;
|
||||
|
||||
auto sm_tile = make_static_distributed_tensor<SMType>(dram_tile_distribution);
|
||||
auto sn_tile = make_static_distributed_tensor<SNType>(dram_tile_distribution);
|
||||
|
||||
// Build windows only if scales are provided
|
||||
auto scale_m_window = [&]() {
|
||||
if constexpr(has_scales)
|
||||
{
|
||||
return make_tile_window(scale_m, dram_tile_distribution);
|
||||
}
|
||||
else
|
||||
{
|
||||
return EmptyScale{};
|
||||
}
|
||||
}();
|
||||
auto scale_n_window = [&]() {
|
||||
if constexpr(has_scales)
|
||||
{
|
||||
return make_tile_window(scale_n, dram_tile_distribution);
|
||||
}
|
||||
else
|
||||
{
|
||||
return EmptyScale{};
|
||||
}
|
||||
}();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto mIter) {
|
||||
// Slice accumulators for this M repeat into the permuted layout
|
||||
shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
|
||||
|
||||
// If scales provided, load them with identical distribution
|
||||
if constexpr(has_scales)
|
||||
{
|
||||
sm_tile = load_tile(scale_m_window); // row scales in permuted layout
|
||||
sn_tile = load_tile(scale_n_window); // col scales in permuted layout
|
||||
}
|
||||
|
||||
// Pack 4 “rows per lane” as you already do
|
||||
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
|
||||
// source indices in shuffle_acc: (n_idx * product(Y) + row)
|
||||
const index_t base = n_idx * c_warp_y_lengths.product();
|
||||
|
||||
// local lambda to fuse scale (if present) and convert
|
||||
auto emit = [&](index_t out_idx, index_t src_row) {
|
||||
AccDataType v = shuffle_acc.get_thread_buffer()[base + src_row];
|
||||
|
||||
if constexpr(has_scales)
|
||||
{
|
||||
// same linear index mapping on the permuted distribution
|
||||
const auto s_m = static_cast<float>(sm_tile.get_thread_buffer()[out_idx]);
|
||||
const auto s_n = static_cast<float>(sn_tile.get_thread_buffer()[out_idx]);
|
||||
v = static_cast<AccDataType>(v * s_m * s_n);
|
||||
}
|
||||
|
||||
c_out_tensor.get_thread_buffer()[out_idx] = type_convert<ODataType>(v);
|
||||
};
|
||||
|
||||
// Your current packing pattern (rows 0..3, spaced by NRepeat)
|
||||
emit(n_idx + 0 * NRepeat, 0);
|
||||
emit(n_idx + 1 * NRepeat, 1);
|
||||
emit(n_idx + 2 * NRepeat, 2);
|
||||
emit(n_idx + 3 * NRepeat, 3);
|
||||
});
|
||||
|
||||
// store/update
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
|
||||
// advance output (and any D-tensors) by one MPerXdl*MWave chunk
|
||||
move_tile_window(out_dram_window, {number<MPerXdl * MWave>{}, number<0>{}});
|
||||
static_for<0, NumDTensor, 1>{}([&](auto idx) {
|
||||
move_tile_window(d_dram_windows[idx], {number<MPerXdl * MWave>{}, number<0>{}});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename ODramWindow,
|
||||
typename OAccTile,
|
||||
typename DsDramWindows,
|
||||
typename ScaleM = EmptyScale,
|
||||
typename ScaleN = EmptyScale,
|
||||
int EnablePermuateN_ = TiledMMAPermuteN,
|
||||
std::enable_if_t<!EnablePermuateN_, int> = 0>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
|
||||
Reference in New Issue
Block a user