fix the lds alignment caused performance regression

This commit is contained in:
aska-0096
2025-07-25 07:10:01 +00:00
parent af28123cec
commit 13bcc913de
3 changed files with 49 additions and 16 deletions

View File

@@ -263,8 +263,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
// S tile in LDS
auto s_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<SaccDataType*>(reinterpret_cast<char*>(smem_ptr) +
max(Policy::template GetSmemSizeQ<Problem>(),
Policy::template GetSmemSizeK<Problem>())),
Policy::template GetSmemSizeK<Problem>()),
Policy::template MakeSLdsBlockDescriptor<Problem>());
auto s_write_lds_window = make_tile_window(
s_lds, Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
@@ -280,8 +279,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
max(Policy::template GetSmemSizeQ<Problem>(),
Policy::template GetSmemSizeK<Problem>()) +
Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeS<Problem>()),
Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_write_window = make_tile_window(
@@ -324,6 +322,12 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
do
{
block_sync_lds();
async_load_tile(v_lds_write_window, v_dram_window); // prefetch load v tile
// move V tile windows
move_tile_window(v_dram_window, {kK1, 0});
// STAGE 1, QK gemm
clear_tile(s_acc); // initialize C
@@ -338,16 +342,42 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
// __builtin_amdgcn_sched_barrier(
// 0); // prevent from messing up the order of global loads
// }
if constexpr(1 < k0_loops)
{
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
if constexpr(i_k0 == 0){
block_sync_lds_direct_load<v_vmem_insts>();
}
else{
block_sync_lds_direct_load<0>();
}
block_sync_lds();
async_load_tile(v_lds_write_window, v_dram_window); // prefetch load v tile
auto k_tile = load_tile(k_lds_read_window);
gemm_0(s_acc,
get_slice_tile(q_tile,
sequence<0, i_k0 * kK0>{},
sequence<kM0, (i_k0 + 1) * kK0>{}),
k_tile);
// loop over along the [K]ey head dimension
move_tile_window(k_dram_window, {0, kK0});
block_sync_lds();
async_load_tile(k_lds_write_window, k_dram_window);
});
// move back to the origin
move_tile_window(k_dram_window, {0, -kK0 * (k0_loops - 1)});
}
// move V tile windows
move_tile_window(v_dram_window, {kK1, 0});
if constexpr(k0_loops==1){
block_sync_lds_direct_load<v_vmem_insts>();
}
else{
block_sync_lds_direct_load<0>();
}
block_sync_lds_direct_load<v_vmem_insts>();
auto k_tile = load_tile(k_lds_read_window);
gemm_0(s_acc,
get_slice_tile(q_tile,
sequence<0, (k0_loops - 1) * kK0>{},
@@ -436,6 +466,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
}
// move K tile windows after current status checked
// prefetch next-tile along [K]ey sequence length dimension
move_tile_window(k_dram_window, {kN0, 0});
block_sync_lds();

View File

@@ -113,7 +113,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
@@ -144,7 +144,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
@@ -194,7 +194,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
@@ -529,8 +529,10 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return max(GetSmemSizeQ<Problem>(), GetSmemSizeK<Problem>()) + GetSmemSizeS<Problem>() +
GetSmemSizeV<Problem>();
// Alignment on gfx950 is 1280 Bytes
// Alignment before gfx950 is 512 Bytes.
return max(GetSmemSizeQ<Problem>(),
GetSmemSizeK<Problem>() + GetSmemSizeS<Problem>() + GetSmemSizeV<Problem>());
}
};