Clarify the using of kSubQKHeaddim and kQKHeaddim so that less regular hdim (eg. 96, 160) can be efficiently supported

This commit is contained in:
Qianfeng Zhang
2025-09-09 09:51:36 +00:00
parent f8dea2bc86
commit 72eb4e95d8
2 changed files with 33 additions and 20 deletions

View File

@@ -208,9 +208,10 @@ struct HstuAttentionFwdPipelineQRKSVS
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_write_window = make_tile_window(
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
auto q_lds_read_window =
make_tile_window(q_lds,
Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
{0, 0},
Policy::template MakeQRegSingleRepMTileDistribution<Problem>());
@@ -218,18 +219,30 @@ struct HstuAttentionFwdPipelineQRKSVS
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_window = make_tile_window(
auto k_lds_write_window = make_tile_window(
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
using k_lds_window_type =
decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
auto k_lds_read_window =
make_tile_window(k_lds, make_tuple(number<kK1>{}, number<kQKHeaddim>{}), {0, 0});
statically_indexed_array<k_lds_window_type, NumKVLdsBuffers> k_lds_windows;
using k_lds_write_window_type = decltype(get_slice_tile(
k_lds_write_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
using k_lds_read_window_type = decltype(get_slice_tile(
k_lds_read_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
k_lds_windows[i_buf] = get_slice_tile(k_lds_window,
sequence<i_buf * kK1, 0>{},
sequence<(i_buf + 1) * kK1, kQKHeaddim>{});
k_lds_write_windows[i_buf] =
get_slice_tile(k_lds_write_window,
sequence<i_buf * kK1, 0>{},
sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{});
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_read_window,
sequence<i_buf * kK1, 0>{},
sequence<(i_buf + 1) * kK1, kQKHeaddim>{});
});
// V tile in LDS
@@ -356,7 +369,7 @@ struct HstuAttentionFwdPipelineQRKSVS
v_tile_type v_tile;
store_tile(k_lds_windows[number<0>{}], tile_elementwise_in(k_element_func, k_tile));
store_tile(k_lds_write_windows[number<0>{}], tile_elementwise_in(k_element_func, k_tile));
do
{
@@ -380,7 +393,7 @@ struct HstuAttentionFwdPipelineQRKSVS
block_sync_lds();
// execute current unroll of gemm_0
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_k1 % NumKVLdsBuffers>{}]);
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
if constexpr(V_VMEM_LOAD_INST + K_VMEM_LOAD_INST < K_LDS_READ_INST)
{
@@ -448,7 +461,7 @@ struct HstuAttentionFwdPipelineQRKSVS
block_sync_lds();
// execute current unroll of gemm_0
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_k1 % NumKVLdsBuffers>{}]);
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
};
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
@@ -533,7 +546,7 @@ struct HstuAttentionFwdPipelineQRKSVS
__builtin_amdgcn_s_barrier();
};
store_tile(k_lds_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}],
store_tile(k_lds_write_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tile));
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -547,7 +560,7 @@ struct HstuAttentionFwdPipelineQRKSVS
__builtin_amdgcn_s_barrier();
};
store_tile(k_lds_windows[number<0>{}],
store_tile(k_lds_write_windows[number<0>{}],
tile_elementwise_in(k_element_func, k_tile));
}
});

View File

@@ -103,7 +103,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(QDataType);
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
@@ -127,7 +127,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
@@ -170,7 +170,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize()
{
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
constexpr index_t kKVector = GetAlignmentK<Problem>();
@@ -212,7 +212,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
{
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
constexpr index_t kKVector = GetAlignmentQ<Problem>();
@@ -301,7 +301,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType);
@@ -331,7 +331,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
{
constexpr index_t NumKLdsBuffers = GetNumKVLdsBuffers<Problem>();
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
constexpr index_t kKVector = GetAlignmentK<Problem>();
@@ -438,7 +438,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType);
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;