mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
Joye/revise wp pipeline (#3493)
* [CK_TILE] unify double and single lds implementation (#108) Unify LDS buffer management API for single and double buffering modes This change consolidates the Local Data Store (LDS) buffer management by: Merging single and double LDS buffer APIs into a unified interface Implementing ping-pong address calculation in pipeline when double LDS is enabled Computing pong buffer addresses dynamically using base address offsets --------- Co-authored-by: joye <joye@amd.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * update wp_pipeline * fix a c++17 issue * update for ci errors * fix ci issues * include a header to fix ci errors * fix some rebase issues * update with rebase --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -29,6 +29,48 @@ struct GemmWPQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipelin
|
||||
return GemmBQuantPipelineAgBgCrDefaultPolicy::MakeBQDramTileDistribution<Problem>();
|
||||
}
|
||||
|
||||
// as UniversalWeightPreshufflePipelineAgBgCrPolicy's MakeBFlatDramTileDistribution is changed;
|
||||
// move original UniversalWeightPreshufflePipelineAgBgCrPolicy's implementation to here
|
||||
// temporarily
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
|
||||
#if defined(__gfx11__)
|
||||
constexpr index_t KRepeatInWave = 2;
|
||||
#else
|
||||
constexpr index_t KRepeatInWave = 1;
|
||||
#endif
|
||||
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
constexpr index_t KRepeat = 1;
|
||||
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
|
||||
|
||||
constexpr index_t NBPerLoad = 1;
|
||||
constexpr index_t NThdPerWave = 1;
|
||||
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<WaveRepeat, KRepeatInWave>, // ?
|
||||
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
|
||||
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
tuple<sequence<0, 1, 2>, sequence<0, 1, 2>>, // which direction
|
||||
tuple<sequence<0, 1, 1>, sequence<1, 2, 2>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<1, 1, 2, 2>,
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffleBQuant()
|
||||
{
|
||||
|
||||
@@ -184,8 +184,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
|
||||
index_t n,
|
||||
index_t num_loop,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
@@ -210,8 +209,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem_ping);
|
||||
ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
|
||||
constexpr index_t smem_size = PipelinePolicy::template GetSmemSize<Problem>();
|
||||
ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem);
|
||||
ADataType* p_a_lds_pong =
|
||||
reinterpret_cast<ADataType*>(static_cast<char*>(p_smem) + smem_size);
|
||||
|
||||
constexpr auto a_lds_block_desc =
|
||||
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
|
||||
@@ -561,9 +562,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong,
|
||||
index_t n = 0) const // Default value for non-preshuffle case
|
||||
void* p_smem,
|
||||
index_t n = 0) const
|
||||
{
|
||||
return operator()<TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
@@ -572,8 +572,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
bq_dram_block_window_tmp,
|
||||
n,
|
||||
num_loop,
|
||||
p_smem_ping,
|
||||
p_smem_pong);
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
@@ -584,8 +583,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
TailNumber tail_number,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong,
|
||||
void* p_smem,
|
||||
index_t n = 0) const
|
||||
{
|
||||
const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
|
||||
@@ -598,8 +596,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
bq_dram_block_window_tmp,
|
||||
n, // dummy value, won't be used
|
||||
num_loop,
|
||||
p_smem_ping,
|
||||
p_smem_pong);
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, true, tail_number);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user