Add macro option to enable BUFFER_LOAD_LDS

This commit is contained in:
Feng Shijie
2025-09-11 05:16:29 +00:00
parent e7c1c77120
commit f4fdaedf4c
2 changed files with 44 additions and 27 deletions

View File

@@ -184,7 +184,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
CK_TILE_HOST_DEVICE static constexpr auto
SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
{
#if defined(__gfx950__)
#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE && defined(__gfx950__)
// GFX950 uses BUFFER_LOAD_LDS to fill lds_buffer_A,
// so there is no separate DS_WRITE instruction at all.
dswrite_perM = 0;
@@ -658,10 +658,17 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
using ABlockTile = decltype(load_tile(a_copy_dram_window));
ABlockTile a_block_tile;
#if defined(__gfx950__)
auto prefill_lds_a_stage1 = [&](auto lds_tile_a, auto dram_tile_a) {
enum
{
PrefillBeforeGemm = 1,
PrefillAfterGemm = 2,
PrefillAlways = PrefillBeforeGemm | PrefillAfterGemm,
};
#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE && defined(__gfx950__)
auto prefill_lds_a_stage1 = [&](auto lds_tile_a, auto dram_tile_a, auto prefill_location) {
// global -> lds
async_load_tile(lds_tile_a, dram_tile_a);
if constexpr(prefill_location & PrefillAfterGemm)
async_load_tile(lds_tile_a, dram_tile_a);
};
auto prefill_lds_a_stage2 = [&](auto lds_tile_a) {
// Data has already been stored in LDS; no further operation is needed.
@@ -669,9 +676,10 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
"buffer_load_lds don't support element func fot A before mfma");
};
#else
auto prefill_lds_a_stage1 = [&](auto lds_tile_a, auto dram_tile_a) {
auto prefill_lds_a_stage1 = [&](auto lds_tile_a, auto dram_tile_a, auto prefill_location) {
// global -> vgpr
a_block_tile = load_tile(dram_tile_a);
if constexpr(prefill_location & PrefillBeforeGemm)
a_block_tile = load_tile(dram_tile_a);
};
auto prefill_lds_a_stage2 = [&](auto lds_tile_a) {
// vgpr -> lds
@@ -682,7 +690,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
// HEAD
// Prefetch A0
prefill_lds_a_stage1(a_copy_lds_window_ping, a_copy_dram_window);
prefill_lds_a_stage1(a_copy_lds_window_ping, a_copy_dram_window, number<PrefillAlways>{});
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
@@ -725,7 +733,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
__builtin_amdgcn_sched_barrier(0);
// Prefetch A1
prefill_lds_a_stage1(a_copy_lds_window_pong, a_copy_dram_window);
prefill_lds_a_stage1(a_copy_lds_window_pong, a_copy_dram_window, number<PrefillAlways>{});
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
@@ -858,6 +866,12 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
});
});
// Prefill A(2i+1)
prefill_lds_a_stage2(a_copy_lds_window_pong);
// Prefetch A(2i+2)
prefill_lds_a_stage1(
a_copy_lds_window_ping, a_copy_dram_window, number<PrefillBeforeGemm>{});
// GEMM 2i
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
@@ -904,19 +918,16 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
}
});
});
prefill_lds_a_stage1(
a_copy_lds_window_ping, a_copy_dram_window, number<PrefillAfterGemm>{});
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, MXFP4KPerWarp * KFlatPerBlockPerIter});
move_tile_window(scale_b_flat_dram_window, {0, ScaleKPerWarp * ScaleKFlatPerWarp});
// Prefill A(2i+1)
prefill_lds_a_stage2(a_copy_lds_window_pong);
// Prefetch A(2i+2)
prefill_lds_a_stage1(a_copy_lds_window_ping, a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
@@ -960,6 +971,14 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
b_warp_tensor_ping(nIter)(kIter) = ub.u;
});
});
// Prefill A(2i+2)
prefill_lds_a_stage2(a_copy_lds_window_ping);
// Prefetch A(2i+3)
prefill_lds_a_stage1(
a_copy_lds_window_pong, a_copy_dram_window, number<PrefillBeforeGemm>{});
// GEMM 2i+1
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
@@ -1005,15 +1024,11 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
}
});
});
prefill_lds_a_stage1(
a_copy_lds_window_pong, a_copy_dram_window, number<PrefillAfterGemm>{});
// Prefill A(2i+2)
prefill_lds_a_stage2(a_copy_lds_window_ping);
// Prefetch A(2i+3)
prefill_lds_a_stage1(a_copy_lds_window_pong, a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, MXFP4KPerWarp * KFlatPerBlockPerIter});
move_tile_window(scale_b_flat_dram_window, {0, ScaleKPerWarp * ScaleKFlatPerWarp});
@@ -1075,6 +1090,9 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
}
});
// Prefill A(loopK)
prefill_lds_a_stage2(a_copy_lds_window_pong);
// GEMM loopK-1
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
@@ -1122,9 +1140,6 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
});
});
// Prefill A(loopK)
prefill_lds_a_stage2(a_copy_lds_window_pong);
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;

View File

@@ -7,6 +7,8 @@
namespace ck_tile {
#define CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE 1
struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
{
static constexpr auto I0 = number<0>{};
@@ -21,7 +23,7 @@ struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
CK_TILE_HOST_DEVICE static constexpr auto
TransformF16xF4_ATensorView(const NativeADramTensorView& a_dram_view)
{
#if defined(__gfx950__) //|| defined(__gfx942__)
#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE && defined(__gfx950__)
constexpr int DynamicTileOffsetFlag = 0;
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
@@ -118,7 +120,7 @@ struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_WriteALdsBlockDescriptor()
{
#if defined(__gfx950__) //|| defined(__gfx942__)
#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE && defined(__gfx950__)
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackA<Problem>();