mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
fix the lds alignment caused performance regression
This commit is contained in:
@@ -263,8 +263,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
// S tile in LDS
|
||||
auto s_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<SaccDataType*>(reinterpret_cast<char*>(smem_ptr) +
|
||||
max(Policy::template GetSmemSizeQ<Problem>(),
|
||||
Policy::template GetSmemSizeK<Problem>())),
|
||||
Policy::template GetSmemSizeK<Problem>()),
|
||||
Policy::template MakeSLdsBlockDescriptor<Problem>());
|
||||
auto s_write_lds_window = make_tile_window(
|
||||
s_lds, Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
@@ -280,8 +279,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
|
||||
max(Policy::template GetSmemSizeQ<Problem>(),
|
||||
Policy::template GetSmemSizeK<Problem>()) +
|
||||
Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeS<Problem>()),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>());
|
||||
auto v_lds_write_window = make_tile_window(
|
||||
@@ -324,6 +322,12 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
|
||||
do
|
||||
{
|
||||
block_sync_lds();
|
||||
async_load_tile(v_lds_write_window, v_dram_window); // prefetch load v tile
|
||||
|
||||
// move V tile windows
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
|
||||
// STAGE 1, QK gemm
|
||||
clear_tile(s_acc); // initialize C
|
||||
|
||||
@@ -338,16 +342,42 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
// __builtin_amdgcn_sched_barrier(
|
||||
// 0); // prevent from messing up the order of global loads
|
||||
// }
|
||||
if constexpr(1 < k0_loops)
|
||||
{
|
||||
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
|
||||
if constexpr(i_k0 == 0){
|
||||
block_sync_lds_direct_load<v_vmem_insts>();
|
||||
}
|
||||
else{
|
||||
block_sync_lds_direct_load<0>();
|
||||
}
|
||||
|
||||
block_sync_lds();
|
||||
async_load_tile(v_lds_write_window, v_dram_window); // prefetch load v tile
|
||||
auto k_tile = load_tile(k_lds_read_window);
|
||||
|
||||
gemm_0(s_acc,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kM0, (i_k0 + 1) * kK0>{}),
|
||||
k_tile);
|
||||
|
||||
// loop over along the [K]ey head dimension
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
block_sync_lds();
|
||||
async_load_tile(k_lds_write_window, k_dram_window);
|
||||
});
|
||||
// move back to the origin
|
||||
move_tile_window(k_dram_window, {0, -kK0 * (k0_loops - 1)});
|
||||
}
|
||||
|
||||
// move V tile windows
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
if constexpr(k0_loops==1){
|
||||
block_sync_lds_direct_load<v_vmem_insts>();
|
||||
}
|
||||
else{
|
||||
block_sync_lds_direct_load<0>();
|
||||
}
|
||||
|
||||
block_sync_lds_direct_load<v_vmem_insts>();
|
||||
auto k_tile = load_tile(k_lds_read_window);
|
||||
|
||||
|
||||
gemm_0(s_acc,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
@@ -436,6 +466,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
}
|
||||
|
||||
// move K tile windows after current status checked
|
||||
// prefetch next-tile along [K]ey sequence length dimension
|
||||
move_tile_window(k_dram_window, {kN0, 0});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
@@ -113,7 +113,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
@@ -144,7 +144,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
@@ -194,7 +194,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
|
||||
@@ -529,8 +529,10 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return max(GetSmemSizeQ<Problem>(), GetSmemSizeK<Problem>()) + GetSmemSizeS<Problem>() +
|
||||
GetSmemSizeV<Problem>();
|
||||
// Alignment on gfx950 is 1280 Bytes
|
||||
// Alignment before gfx950 is 512 Bytes.
|
||||
return max(GetSmemSizeQ<Problem>(),
|
||||
GetSmemSizeK<Problem>() + GetSmemSizeS<Problem>() + GetSmemSizeV<Problem>());
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user