mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Clarify the use of kSubQKHeaddim and kQKHeaddim so that less regular hdim values (e.g. 96, 160) can be supported efficiently
This commit is contained in:
@@ -208,9 +208,10 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
auto q_lds_write_window = make_tile_window(
|
||||
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
// When kSubQKHeaddim > kQKHeaddim, the read window (kQKHeaddim wide) is narrower than the write window (kSubQKHeaddim wide)
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds,
|
||||
Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
|
||||
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeQRegSingleRepMTileDistribution<Problem>());
|
||||
|
||||
@@ -218,18 +219,30 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_window = make_tile_window(
|
||||
auto k_lds_write_window = make_tile_window(
|
||||
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
using k_lds_window_type =
|
||||
decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
|
||||
// When kSubQKHeaddim > kQKHeaddim, the read window (kQKHeaddim wide) is narrower than the write window (kSubQKHeaddim wide)
|
||||
auto k_lds_read_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kK1>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
statically_indexed_array<k_lds_window_type, NumKVLdsBuffers> k_lds_windows;
|
||||
using k_lds_write_window_type = decltype(get_slice_tile(
|
||||
k_lds_write_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
|
||||
|
||||
using k_lds_read_window_type = decltype(get_slice_tile(
|
||||
k_lds_read_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
|
||||
|
||||
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
|
||||
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
|
||||
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
k_lds_windows[i_buf] = get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kQKHeaddim>{});
|
||||
k_lds_write_windows[i_buf] =
|
||||
get_slice_tile(k_lds_write_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{});
|
||||
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_read_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kQKHeaddim>{});
|
||||
});
|
||||
|
||||
// V tile in LDS
|
||||
@@ -356,7 +369,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
v_tile_type v_tile;
|
||||
|
||||
store_tile(k_lds_windows[number<0>{}], tile_elementwise_in(k_element_func, k_tile));
|
||||
store_tile(k_lds_write_windows[number<0>{}], tile_elementwise_in(k_element_func, k_tile));
|
||||
|
||||
do
|
||||
{
|
||||
@@ -380,7 +393,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
block_sync_lds();
|
||||
// execute current unroll of gemm_0
|
||||
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_k1 % NumKVLdsBuffers>{}]);
|
||||
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
|
||||
|
||||
if constexpr(V_VMEM_LOAD_INST + K_VMEM_LOAD_INST < K_LDS_READ_INST)
|
||||
{
|
||||
@@ -448,7 +461,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
block_sync_lds();
|
||||
// execute current unroll of gemm_0
|
||||
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_k1 % NumKVLdsBuffers>{}]);
|
||||
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
|
||||
};
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
@@ -533,7 +546,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
store_tile(k_lds_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}],
|
||||
store_tile(k_lds_write_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tile));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
@@ -547,7 +560,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
store_tile(k_lds_windows[number<0>{}],
|
||||
store_tile(k_lds_write_windows[number<0>{}],
|
||||
tile_elementwise_in(k_element_func, k_tile));
|
||||
}
|
||||
});
|
||||
|
||||
@@ -103,7 +103,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(QDataType);
|
||||
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
|
||||
@@ -127,7 +127,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
@@ -170,7 +170,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentK<Problem>();
|
||||
|
||||
@@ -212,7 +212,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentQ<Problem>();
|
||||
|
||||
@@ -301,7 +301,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType);
|
||||
|
||||
@@ -331,7 +331,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
{
|
||||
constexpr index_t NumKLdsBuffers = GetNumKVLdsBuffers<Problem>();
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentK<Problem>();
|
||||
|
||||
@@ -438,7 +438,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType);
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
|
||||
Reference in New Issue
Block a user