Joye/revise wp pipeline (#3493)

* [CK_TILE] unify double and single lds implementation (#108)

Unify LDS buffer management API for single and double buffering modes

This change consolidates the Local Data Store (LDS) buffer management by:

Merging single and double LDS buffer APIs into a unified interface
Implementing ping-pong address calculation in pipeline when double LDS is enabled
Computing pong buffer addresses dynamically using base address offsets

---------

Co-authored-by: joye <joye@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* update wp_pipeline

* fix a c++17 issue

* update for ci errors

* fix ci issues

* include a header to fix ci errors

* fix some rebase issues

* update with rebase

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
joyeamd
2026-01-06 05:49:26 +08:00
committed by GitHub
parent 1224bc0a82
commit 2b563ad048
13 changed files with 766 additions and 929 deletions

View File

@@ -1723,7 +1723,7 @@ struct QuantGemmKernel
* @param aq_ptr input AQ pointer
* @param bq_ptr input BQ pointer
* @param c_ptr output C pointer
* @param smem_ptr_0 The start memory pointer of the shared memory block.
* @param smem_ptr The start memory pointer of the shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset Utility structure used to calculate k batch.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
@@ -1735,7 +1735,7 @@ struct QuantGemmKernel
const AQDataType* aq_ptr,
const BQDataType* bq_ptr,
CDataType* c_ptr,
void* smem_ptr_0,
void* smem_ptr,
const QuantGemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
@@ -1762,7 +1762,7 @@ struct QuantGemmKernel
m = kargs.M;
}
return GemmPipeline{}.template operator()(
a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr_0, m);
a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr, m);
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
@@ -1772,7 +1772,7 @@ struct QuantGemmKernel
n = kargs.N;
}
return GemmPipeline{}.template operator()(
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0, n);
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr, n);
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
@@ -1788,7 +1788,7 @@ struct QuantGemmKernel
aq_block_window,
bq_block_window,
num_loop,
smem_ptr_0,
smem_ptr,
m,
n);
}
@@ -1796,7 +1796,7 @@ struct QuantGemmKernel
kQuantType == QuantType::TensorQuant)
{
return GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, smem_ptr_0);
a_block_window, b_block_window, num_loop, smem_ptr);
}
}();
@@ -1812,14 +1812,14 @@ struct QuantGemmKernel
kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::BQuantGrouped)
{
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
EpiloguePipeline{}(c_block_window,
c_block_tile,
c_block_window,
smem_ptr_0,
smem_ptr,
aq_block_window,
bq_block_window);
}
@@ -1828,7 +1828,7 @@ struct QuantGemmKernel
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
EpiloguePipeline{}(
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale);
}
}
else
@@ -1840,14 +1840,14 @@ struct QuantGemmKernel
kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::BQuantGrouped)
{
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
EpiloguePipeline{}(c_block_window,
c_block_tile,
c_block_window,
smem_ptr_0,
smem_ptr,
aq_block_window,
bq_block_window);
}
@@ -1856,89 +1856,7 @@ struct QuantGemmKernel
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
EpiloguePipeline{}(
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
}
}
}
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*
* @note RunGemm2LDS in with two shared memory buffers using the ping pong buffer mechanism.
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param aq_ptr input AQ pointer
* @param bq_ptr input BQ pointer
* @param c_ptr output C pointer
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset Utility structure used to calculate k batch.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
*/
CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr,
const BDataType* b_ptr,
[[maybe_unused]] const AQDataType* aq_ptr,
const BQDataType* bq_ptr,
CDataType* c_ptr,
void* __restrict__ smem_ptr_0,
void* __restrict__ smem_ptr_1,
const QuantGemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create block windows using specialized methods
const auto& a_block_window =
MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
const auto& b_block_window =
MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
const index_t num_loop =
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
// Run GEMM cooperatively by whole workgroup.
const auto& c_block_tile = [&]() {
if constexpr(kQuantType == QuantType::BQuantGrouped)
{
index_t n = 0;
if constexpr(PreshuffleQuant)
{
n = kargs.N;
}
return GemmPipeline{}.template operator()(a_block_window,
b_block_window,
bq_block_window,
num_loop,
smem_ptr_0,
smem_ptr_1,
n);
}
else
{
return nullptr;
}
}();
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
// Run Epilogue Pipeline with k_batch dispatch
if constexpr(kQuantType == QuantType::BQuantGrouped)
{
if(k_batch == 1)
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
}
else
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale);
}
}
}
@@ -1961,37 +1879,10 @@ struct QuantGemmKernel
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
__shared__ char smem_ptr[GetSmemSize()];
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
RunGemm2LDS(a_ptr,
b_ptr,
aq_ptr,
bq_ptr,
c_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
else
{
RunGemm(a_ptr,
b_ptr,
aq_ptr,
bq_ptr,
c_ptr,
smem_ptr_0,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
RunGemm(
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
}
};

View File

@@ -318,21 +318,18 @@ struct QuantGroupedGemmKernel
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
__shared__ char smem_ptr[GetSmemSize()];
// Only for BQuantGrouped DoubleSmemBuffer is supported
if constexpr(GemmPipeline::DoubleSmemBuffer == true &&
kQuantType == QuantType::BQuantGrouped)
{
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
RunGemmWithPipelineSelection2LDS(a_ptr,
b_ptr,
aq_ptr,
bq_ptr,
c_ptr,
smem_ptr_0,
smem_ptr_1,
smem_ptr,
kargs,
splitk_batch_offset,
i_m,
@@ -348,7 +345,7 @@ struct QuantGroupedGemmKernel
aq_ptr,
bq_ptr,
c_ptr,
smem_ptr_0,
smem_ptr,
kargs,
splitk_batch_offset,
i_m,
@@ -361,7 +358,7 @@ struct QuantGroupedGemmKernel
aq_ptr,
bq_ptr,
c_ptr,
smem_ptr_0,
smem_ptr,
kargs,
splitk_batch_offset,
i_m,
@@ -377,8 +374,7 @@ struct QuantGroupedGemmKernel
[[maybe_unused]] const AQDataType* aq_ptr,
const BQDataType* bq_ptr,
CDataType* c_ptr,
void* smem_ptr_0,
void* smem_ptr_1,
void* smem_ptr,
const QuantGroupedGemmKernelArgs& kargs,
const typename Base::SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
@@ -399,27 +395,22 @@ struct QuantGroupedGemmKernel
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
// Run GEMM cooperatively by whole workgroup
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
b_block_window,
bq_block_window,
num_loop,
tail_num,
smem_ptr_0,
smem_ptr_1);
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, bq_block_window, num_loop, tail_num, smem_ptr);
// Run Epilogue Pipeline with split_k dispatch
if(kargs.k_batch == 1)
{
auto c_block_window = Base::template MakeCBlockWindow<memory_operation_enum::set>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
}
else
{
auto c_block_window =
Base::template MakeCBlockWindow<memory_operation_enum::atomic_add>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
}
}
@@ -435,7 +426,7 @@ struct QuantGroupedGemmKernel
* @param aq_ptr input AQ pointer
* @param bq_ptr input BQ pointer
* @param c_ptr output C pointer
* @param smem_ptr_0 The start memory pointer of the shared memory block.
* @param smem_ptr The start memory pointer of the shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset Utility structure used to calculate k
* batch.
@@ -449,7 +440,7 @@ struct QuantGroupedGemmKernel
const AQDataType* aq_ptr,
const BQDataType* bq_ptr,
CDataType* c_ptr,
void* smem_ptr_0,
void* smem_ptr,
const QuantGroupedGemmKernelArgs& kargs,
const typename Base::SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
@@ -481,7 +472,7 @@ struct QuantGroupedGemmKernel
num_loop,
has_hot_loop,
tail_num,
smem_ptr_0);
smem_ptr);
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
@@ -491,13 +482,13 @@ struct QuantGroupedGemmKernel
num_loop,
has_hot_loop,
tail_num,
smem_ptr_0);
smem_ptr);
}
else if constexpr(kQuantType == QuantType::RowColQuant ||
kQuantType == QuantType::TensorQuant)
{
return GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr);
}
}();
@@ -510,14 +501,14 @@ struct QuantGroupedGemmKernel
if constexpr(kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::BQuantGrouped)
{
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
EpiloguePipeline{}(c_block_window,
c_block_tile,
c_block_window,
smem_ptr_0,
smem_ptr,
aq_block_window,
bq_block_window);
}
@@ -526,7 +517,7 @@ struct QuantGroupedGemmKernel
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
EpiloguePipeline{}(
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale);
}
}
else
@@ -538,14 +529,14 @@ struct QuantGroupedGemmKernel
if constexpr(kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::BQuantGrouped)
{
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
EpiloguePipeline{}(c_block_window,
c_block_tile,
c_block_window,
smem_ptr_0,
smem_ptr,
aq_block_window,
bq_block_window);
}
@@ -554,7 +545,7 @@ struct QuantGroupedGemmKernel
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
EpiloguePipeline{}(
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale);
}
}
}

View File

@@ -29,6 +29,48 @@ struct GemmWPQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipelin
return GemmBQuantPipelineAgBgCrDefaultPolicy::MakeBQDramTileDistribution<Problem>();
}
// Because UniversalWeightPreshufflePipelineAgBgCrPolicy's MakeBFlatDramTileDistribution has changed,
// the original UniversalWeightPreshufflePipelineAgBgCrPolicy implementation is kept here
// temporarily.
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
#if defined(__gfx11__)
constexpr index_t KRepeatInWave = 2;
#else
constexpr index_t KRepeatInWave = 1;
#endif
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = 1;
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
constexpr index_t NBPerLoad = 1;
constexpr index_t NThdPerWave = 1;
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
constexpr index_t NRepeat = 1;
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<WaveRepeat, KRepeatInWave>, // ?
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
// wave in blk, // thd in wave
// <M, K> // <M, K>
tuple<sequence<0, 1, 2>, sequence<0, 1, 2>>, // which direction
tuple<sequence<0, 1, 1>, sequence<1, 2, 2>>, // which index
// <repeat, vec_load>
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffleBQuant()
{

View File

@@ -184,8 +184,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
index_t n,
index_t num_loop,
void* p_smem_ping,
void* p_smem_pong) const
void* p_smem) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
@@ -210,8 +209,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
__builtin_amdgcn_sched_barrier(0);
// A tile in LDS
ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem_ping);
ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
constexpr index_t smem_size = PipelinePolicy::template GetSmemSize<Problem>();
ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem);
ADataType* p_a_lds_pong =
reinterpret_cast<ADataType*>(static_cast<char*>(p_smem) + smem_size);
constexpr auto a_lds_block_desc =
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
@@ -561,9 +562,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
index_t num_loop,
void* p_smem_ping,
void* p_smem_pong,
index_t n = 0) const // Default value for non-preshuffle case
void* p_smem,
index_t n = 0) const
{
return operator()<TailNum>(
a_dram_block_window_tmp,
@@ -572,8 +572,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
bq_dram_block_window_tmp,
n,
num_loop,
p_smem_ping,
p_smem_pong);
p_smem);
}
template <typename ADramBlockWindowTmp,
@@ -584,8 +583,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
index_t num_loop,
TailNumber tail_number,
void* p_smem_ping,
void* p_smem_pong,
void* p_smem,
index_t n = 0) const
{
const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
@@ -598,8 +596,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
bq_dram_block_window_tmp,
n, // dummy value, won't be used
num_loop,
p_smem_ping,
p_smem_pong);
p_smem);
};
return Base::TailHandler(RunPipeline, true, tail_number);
}