Joye/revise wp pipeline (#3493)

* [CK_TILE] unify double and single lds implementation (#108)

Unify LDS buffer management API for single and double buffering modes

This change consolidates the Local Data Store (LDS) buffer management by:

Merging single and double LDS buffer APIs into a unified interface
Implementing ping-pong address calculation in pipeline when double LDS is enabled
Computing pong buffer addresses dynamically using base address offsets

---------

Co-authored-by: joye <joye@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* update wp_pipeline

* fix a c++17 issue

* update for ci errors

* fix ci issues

* include a header to fix ci errors

* fix some rebase issues

* update with rebase

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
joyeamd
2026-01-06 05:49:26 +08:00
committed by GitHub
parent 1224bc0a82
commit 2b563ad048
13 changed files with 766 additions and 929 deletions

View File

@@ -29,6 +29,48 @@ struct GemmWPQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipelin
return GemmBQuantPipelineAgBgCrDefaultPolicy::MakeBQDramTileDistribution<Problem>();
}
// as UniversalWeightPreshufflePipelineAgBgCrPolicy's MakeBFlatDramTileDistribution is changed;
// move original UniversalWeightPreshufflePipelineAgBgCrPolicy's implementation to here
// temporarily
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
#if defined(__gfx11__)
constexpr index_t KRepeatInWave = 2;
#else
constexpr index_t KRepeatInWave = 1;
#endif
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = 1;
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
constexpr index_t NBPerLoad = 1;
constexpr index_t NThdPerWave = 1;
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
constexpr index_t NRepeat = 1;
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<WaveRepeat, KRepeatInWave>, // ?
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
// wave in blk, // thd in wave
// <M, K> // <M, K>
tuple<sequence<0, 1, 2>, sequence<0, 1, 2>>, // which direction
tuple<sequence<0, 1, 1>, sequence<1, 2, 2>>, // which index
// <repeat, vec_load>
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffleBQuant()
{

View File

@@ -184,8 +184,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
index_t n,
index_t num_loop,
void* p_smem_ping,
void* p_smem_pong) const
void* p_smem) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
@@ -210,8 +209,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
__builtin_amdgcn_sched_barrier(0);
// A tile in LDS
ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem_ping);
ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
constexpr index_t smem_size = PipelinePolicy::template GetSmemSize<Problem>();
ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem);
ADataType* p_a_lds_pong =
reinterpret_cast<ADataType*>(static_cast<char*>(p_smem) + smem_size);
constexpr auto a_lds_block_desc =
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
@@ -561,9 +562,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
index_t num_loop,
void* p_smem_ping,
void* p_smem_pong,
index_t n = 0) const // Default value for non-preshuffle case
void* p_smem,
index_t n = 0) const
{
return operator()<TailNum>(
a_dram_block_window_tmp,
@@ -572,8 +572,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
bq_dram_block_window_tmp,
n,
num_loop,
p_smem_ping,
p_smem_pong);
p_smem);
}
template <typename ADramBlockWindowTmp,
@@ -584,8 +583,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
index_t num_loop,
TailNumber tail_number,
void* p_smem_ping,
void* p_smem_pong,
void* p_smem,
index_t n = 0) const
{
const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
@@ -598,8 +596,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
bq_dram_block_window_tmp,
n, // dummy value, won't be used
num_loop,
p_smem_ping,
p_smem_pong);
p_smem);
};
return Base::TailHandler(RunPipeline, true, tail_number);
}