mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Add macro option to enable BUFFER_LOAD_LDS
This commit is contained in:
@@ -184,7 +184,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE && defined(__gfx950__)
|
||||
// GFX950 use BUFFER_LOAD_LDS to fill lds_buffer_A.
|
||||
// There is no separate DS_WRITE instruction at all.
|
||||
dswrite_perM = 0;
|
||||
@@ -658,10 +658,17 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
using ABlockTile = decltype(load_tile(a_copy_dram_window));
|
||||
ABlockTile a_block_tile;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
auto prefill_lds_a_stage1 = [&](auto lds_tile_a, auto dram_tile_a) {
|
||||
enum
|
||||
{
|
||||
PrefillBeforeGemm = 1,
|
||||
PrefillAfterGemm = 2,
|
||||
PrefillAlways = PrefillBeforeGemm | PrefillAfterGemm,
|
||||
};
|
||||
#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE && defined(__gfx950__)
|
||||
auto prefill_lds_a_stage1 = [&](auto lds_tile_a, auto dram_tile_a, auto prefill_location) {
|
||||
// global -> lds
|
||||
async_load_tile(lds_tile_a, dram_tile_a);
|
||||
if constexpr(prefill_location & PrefillAfterGemm)
|
||||
async_load_tile(lds_tile_a, dram_tile_a);
|
||||
};
|
||||
auto prefill_lds_a_stage2 = [&](auto lds_tile_a) {
|
||||
// data has been stored in lds, no need more operation.
|
||||
@@ -669,9 +676,10 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
"buffer_load_lds don't support element func fot A before mfma");
|
||||
};
|
||||
#else
|
||||
auto prefill_lds_a_stage1 = [&](auto lds_tile_a, auto dram_tile_a) {
|
||||
auto prefill_lds_a_stage1 = [&](auto lds_tile_a, auto dram_tile_a, auto prefill_location) {
|
||||
// global -> vgpr
|
||||
a_block_tile = load_tile(dram_tile_a);
|
||||
if constexpr(prefill_location & PrefillBeforeGemm)
|
||||
a_block_tile = load_tile(dram_tile_a);
|
||||
};
|
||||
auto prefill_lds_a_stage2 = [&](auto lds_tile_a) {
|
||||
// vgpr -> lds
|
||||
@@ -682,7 +690,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
// HEAD
|
||||
// Prefetch A0
|
||||
prefill_lds_a_stage1(a_copy_lds_window_ping, a_copy_dram_window);
|
||||
prefill_lds_a_stage1(a_copy_lds_window_ping, a_copy_dram_window, number<PrefillAlways>{});
|
||||
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
@@ -725,7 +733,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Prefetch A1
|
||||
prefill_lds_a_stage1(a_copy_lds_window_pong, a_copy_dram_window);
|
||||
prefill_lds_a_stage1(a_copy_lds_window_pong, a_copy_dram_window, number<PrefillAlways>{});
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
@@ -858,6 +866,12 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
});
|
||||
});
|
||||
|
||||
// Prefill A(2i+1)
|
||||
prefill_lds_a_stage2(a_copy_lds_window_pong);
|
||||
|
||||
// Prefetch A(2i+2)
|
||||
prefill_lds_a_stage1(
|
||||
a_copy_lds_window_ping, a_copy_dram_window, number<PrefillBeforeGemm>{});
|
||||
// GEMM 2i
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
@@ -904,19 +918,16 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
}
|
||||
});
|
||||
});
|
||||
prefill_lds_a_stage1(
|
||||
a_copy_lds_window_ping, a_copy_dram_window, number<PrefillAfterGemm>{});
|
||||
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, MXFP4KPerWarp * KFlatPerBlockPerIter});
|
||||
move_tile_window(scale_b_flat_dram_window, {0, ScaleKPerWarp * ScaleKFlatPerWarp});
|
||||
|
||||
// Prefill A(2i+1)
|
||||
prefill_lds_a_stage2(a_copy_lds_window_pong);
|
||||
|
||||
// Prefetch A(2i+2)
|
||||
prefill_lds_a_stage1(a_copy_lds_window_ping, a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
@@ -960,6 +971,14 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
b_warp_tensor_ping(nIter)(kIter) = ub.u;
|
||||
});
|
||||
});
|
||||
|
||||
// Prefill A(2i+2)
|
||||
prefill_lds_a_stage2(a_copy_lds_window_ping);
|
||||
|
||||
// Prefetch A(2i+3)
|
||||
prefill_lds_a_stage1(
|
||||
a_copy_lds_window_pong, a_copy_dram_window, number<PrefillBeforeGemm>{});
|
||||
|
||||
// GEMM 2i+1
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
@@ -1005,15 +1024,11 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
}
|
||||
});
|
||||
});
|
||||
prefill_lds_a_stage1(
|
||||
a_copy_lds_window_pong, a_copy_dram_window, number<PrefillAfterGemm>{});
|
||||
|
||||
// Prefill A(2i+2)
|
||||
prefill_lds_a_stage2(a_copy_lds_window_ping);
|
||||
|
||||
// Prefetch A(2i+3)
|
||||
prefill_lds_a_stage1(a_copy_lds_window_pong, a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, MXFP4KPerWarp * KFlatPerBlockPerIter});
|
||||
move_tile_window(scale_b_flat_dram_window, {0, ScaleKPerWarp * ScaleKFlatPerWarp});
|
||||
@@ -1075,6 +1090,9 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
}
|
||||
});
|
||||
|
||||
// Prefill A(loopK)
|
||||
prefill_lds_a_stage2(a_copy_lds_window_pong);
|
||||
|
||||
// GEMM loopK-1
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
@@ -1122,9 +1140,6 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
});
|
||||
});
|
||||
|
||||
// Prefill A(loopK)
|
||||
prefill_lds_a_stage2(a_copy_lds_window_pong);
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
#define CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE 1
|
||||
|
||||
struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
{
|
||||
static constexpr auto I0 = number<0>{};
|
||||
@@ -21,7 +23,7 @@ struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
TransformF16xF4_ATensorView(const NativeADramTensorView& a_dram_view)
|
||||
{
|
||||
#if defined(__gfx950__) //|| defined(__gfx942__)
|
||||
#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE && defined(__gfx950__)
|
||||
constexpr int DynamicTileOffsetFlag = 0;
|
||||
|
||||
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
|
||||
@@ -118,7 +120,7 @@ struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_WriteALdsBlockDescriptor()
|
||||
{
|
||||
#if defined(__gfx950__) //|| defined(__gfx942__)
|
||||
#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE && defined(__gfx950__)
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
|
||||
Reference in New Issue
Block a user