Add support of loading QK tiles of hdim96 without padding to hdim128

This commit is contained in:
Qianfeng Zhang
2025-12-14 04:20:05 +00:00
parent 588f573ee1
commit 1cf868026b
10 changed files with 252 additions and 168 deletions

View File

@@ -59,7 +59,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
constexpr ck_tile::index_t occupancy = -1;
const bool pad_seqlen_k = !(param.seqlen_kv % HstuAttentionTileSetting::kN0 == 0);
const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionTileSetting::kSubQKHeaddim == 0);
const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionTileSetting::kQKHeaddim == 0);
const bool pad_headdim_v = !(param.hdim_v % HstuAttentionTileSetting::kN1 == 0);
// no need to check seqlen_q since it is not used as fastest dim,

View File

@@ -573,7 +573,7 @@ struct HstuAttentionFwdKernel
number<1>{});
return pad_tensor_view(q_dram_naive,
make_tuple(number<HstuAttentionPipeline::kM0>{},
number<HstuAttentionPipeline::kSubQKHeaddim>{}),
number<HstuAttentionPipeline::kQKHeaddim>{}),
sequence<false, kPadHeadDimQK>{});
}();
const auto k_dram = [&]() {
@@ -586,7 +586,7 @@ struct HstuAttentionFwdKernel
return pad_tensor_view(k_dram_naive,
make_tuple(number<HstuAttentionPipeline::kN0>{},
number<HstuAttentionPipeline::kSubQKHeaddim>{}),
number<HstuAttentionPipeline::kQKHeaddim>{}),
sequence<false, kPadHeadDimQK>{});
}();
const auto v_dram = [&]() {
@@ -624,14 +624,14 @@ struct HstuAttentionFwdKernel
make_tile_window(q_dram,
[&]() {
return make_tuple(number<HstuAttentionPipeline::kM0>{},
number<HstuAttentionPipeline::kSubQKHeaddim>{});
number<HstuAttentionPipeline::kQKHeaddim>{});
}(),
{i_m0, 0});
auto k_dram_window =
make_tile_window(k_dram,
make_tuple(number<HstuAttentionPipeline::kN0>{},
number<HstuAttentionPipeline::kSubQKHeaddim>{}),
number<HstuAttentionPipeline::kQKHeaddim>{}),
{0, 0});
auto v_dram_window = make_tile_window(

View File

@@ -126,7 +126,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
return Problem::template GetDramTileAccessMaxVectorSize<QDataType,
kBlockSize,
@@ -193,11 +193,16 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize()
{
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN0Sub;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
constexpr index_t kKVector = GetAlignmentK<Problem>();
if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
// for hdim96 and hdim160
if constexpr(kKPerBlock < Problem::HstuAttentionTileSetting::kSubQKHeaddim)
{
return kKPerBlock * kNPerBlock;
}
else if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
{
static_assert(kKVector == kKPack);
@@ -244,11 +249,20 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
constexpr index_t kMPerBlock = Problem::kLoadWholeQTileOnceThroughLds
? Problem::HstuAttentionTileSetting::kM0
: GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
constexpr index_t kKVector = GetAlignmentQ<Problem>();
if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
// for hdim96 and hdim160, use simplest layout
if constexpr(kKPerBlock < Problem::HstuAttentionTileSetting::kSubQKHeaddim)
{
return make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
make_tuple(number<kKPerBlock>{}, number<1>{}),
number<kKVector>{},
number<1>{});
}
else if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
{
static_assert(kKVector == kKPack);
@@ -331,25 +345,56 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
constexpr index_t kKVector = GetAlignmentQ<Problem>();
constexpr index_t OtherK = kKPerBlock / kKVector;
constexpr index_t KPerThread = kKVector;
constexpr index_t KThreads = kKPerBlock / KPerThread;
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
if constexpr(kKPerBlock == Problem::HstuAttentionTileSetting::kSubQKHeaddim)
// for kKPerBlock=32,64,128,256
{
static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!");
// for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E]
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
constexpr index_t KPerThread = kKVector;
constexpr index_t KThreads = OtherK;
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
// for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E]
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
else // for kKPerBlock=96,160
{
static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!");
// ToDo: need more consideration for hdim72
constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 3 : 5;
constexpr index_t KThreads = OtherK / KRepPerThread;
static_assert((KThreads & (KThreads - 1)) == 0, "Check failed!");
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
sequence<KRepPerThread, KThreads, kKVector>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 1>>,
sequence<1, 2, 2>,
sequence<1, 0, 2>>{});
};
}
template <typename Problem>
@@ -357,25 +402,56 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::HstuAttentionTileSetting::kM0;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
constexpr index_t kKVector = GetAlignmentQ<Problem>();
constexpr index_t OtherK = kKPerBlock / kKVector;
constexpr index_t KPerThread = kKVector;
constexpr index_t KThreads = kKPerBlock / KPerThread;
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
if constexpr(kKPerBlock == Problem::HstuAttentionTileSetting::kSubQKHeaddim)
// for kKPerBlock=32,64,128,256
{
static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!");
// for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E]
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
constexpr index_t KPerThread = kKVector;
constexpr index_t KThreads = OtherK;
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
// for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E]
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
else // for kKPerBlock=96,160
{
static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!");
// ToDo: need more consideration for hdim72
constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 3 : 5;
constexpr index_t KThreads = OtherK / KRepPerThread;
static_assert((KThreads & (KThreads - 1)) == 0, "Check failed!");
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
sequence<KRepPerThread, KThreads, kKVector>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 1>>,
sequence<1, 2, 2>,
sequence<1, 0, 2>>{});
};
}
template <typename Problem>
@@ -383,11 +459,36 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
{
constexpr index_t NumKLdsBuffers = GetNumKVLdsBuffers<Problem>();
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN0Sub;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
constexpr index_t kKVector = GetAlignmentK<Problem>();
if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
// for hdim96 and hdim160, use simplest layout
if constexpr(kKPerBlock < Problem::HstuAttentionTileSetting::kSubQKHeaddim)
{
constexpr index_t KSingleSmemElementSpaceSize = kNPerBlock * kKPerBlock;
static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize<Problem>());
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumKLdsBuffers>{}, number<kNPerBlock>{}, number<kKPerBlock>{}),
make_tuple(number<SingleSmemElementSpaceSize>{}, number<kKPerBlock>{}, number<1>{}),
number<kKVector>{},
number<1>{});
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(make_merge_transform(
make_tuple(number<NumKLdsBuffers>{}, number<kNPerBlock>{})),
make_pass_through_transform(number<kKPerBlock>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return k_lds_block_desc;
}
else if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
{
static_assert(kKVector == kKPack);
@@ -500,24 +601,52 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN0Sub;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
constexpr index_t kKVector = GetAlignmentK<Problem>();
constexpr index_t OtherK = kKPerBlock / kKVector;
constexpr index_t KPerThread = kKVector;
constexpr index_t KThreads = kKPerBlock / KPerThread;
constexpr index_t NThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps);
if constexpr(kKPerBlock == Problem::HstuAttentionTileSetting::kSubQKHeaddim)
// for kKPerBlock=32,64,128,256
{
static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
constexpr index_t KPerThread = kKVector;
constexpr index_t KThreads = OtherK;
constexpr index_t NThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
else // for kKPerBlock=96,160
{
static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!");
constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 3 : 5;
constexpr index_t KThreads = OtherK / KRepPerThread;
constexpr index_t NThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
sequence<KRepPerThread, KThreads, kKVector>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{});
};
}
template <typename Problem>

View File

@@ -58,7 +58,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
{
constexpr ck_tile::index_t occupancy = -1;
const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionTileSetting::kSubQKHeaddim == 0);
const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionTileSetting::kQKHeaddim == 0);
const bool pad_headdim_v = !(param.hdim_v % HstuAttentionTileSetting::kN1 == 0);
// no need to check seqlen_q since it is not used as fastest dim,

View File

@@ -37,7 +37,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
static constexpr index_t kQKHeaddim = HstuAttentionTileSetting::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = HstuAttentionTileSetting::kSubQKHeaddim;
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static_assert(kQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static_assert(Problem::kUseSoftmax == false, "This pipeline only works with not-using softmax");
@@ -53,7 +53,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQK = Traits::kPadHeadDimQK;
static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? true : Traits::kPadHeadDimV;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should be able to overwrite this
@@ -127,9 +127,9 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
typename OAccElementFunction,
typename HstuMask>
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
@@ -151,8 +151,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kSubQKHeaddim ==
KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
@@ -184,7 +183,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kGemmSingleRepM>{}, number<kSubQKHeaddim>{}),
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramSingleRepMTileDistribution<Problem>());
@@ -194,7 +193,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
auto k_dram_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kN0Sub>{}, number<kSubQKHeaddim>{}),
make_tuple(number<kN0Sub>{}, number<kQKHeaddim>{}),
{seqlen_k_start, 0},
Policy::template MakeKDramTileDistribution<Problem>());
@@ -226,7 +225,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_write_window = make_tile_window(
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
auto q_lds_read_window =
make_tile_window(q_lds,
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
@@ -241,25 +239,15 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
auto k_lds_window = make_tile_window(
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
using k_lds_write_window_type = decltype(get_slice_tile(
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kSubQKHeaddim>{}));
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
using k_lds_read_window_type = decltype(get_slice_tile(
using k_lds_window_type = decltype(get_slice_tile(
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kQKHeaddim>{}));
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
statically_indexed_array<k_lds_window_type, NumKVLdsBuffers> k_lds_windows;
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
k_lds_write_windows[i_buf] =
get_slice_tile(k_lds_window,
sequence<i_buf * kN0Sub, 0>{},
sequence<(i_buf + 1) * kN0Sub, kSubQKHeaddim>{});
k_lds_read_windows[i_buf] =
get_slice_tile(k_lds_window,
sequence<i_buf * kN0Sub, 0>{},
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
k_lds_windows[i_buf] = get_slice_tile(k_lds_window,
sequence<i_buf * kN0Sub, 0>{},
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
});
// V tile in LDS
@@ -382,7 +370,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
{
// STAGE 1, Gemm_0 ( S = Q@K )
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
store_tile(k_lds_write_windows[i_n0], k_tiles[i_n0], partition_index);
store_tile(k_lds_windows[i_n0], k_tiles[i_n0], partition_index);
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -395,7 +383,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
block_sync_lds();
// execute current unroll of gemm_0
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
@@ -535,7 +523,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
typename BiasDramBlockWindowTmp,
typename HstuMask>
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile

View File

@@ -55,7 +55,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQK = Traits::kPadHeadDimQK;
static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? true : Traits::kPadHeadDimV;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should be able to overwrite this
@@ -126,9 +126,9 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
typename OAccElementFunction,
typename HstuMask>
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
@@ -150,8 +150,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kSubQKHeaddim ==
KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
@@ -183,7 +182,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kSubQKHeaddim>{}),
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem>());
@@ -193,7 +192,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
auto k_dram_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kN0Sub>{}, number<kSubQKHeaddim>{}),
make_tuple(number<kN0Sub>{}, number<kQKHeaddim>{}),
{seqlen_k_start, 0},
Policy::template MakeKDramTileDistribution<Problem>());
@@ -219,7 +218,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_write_window = make_tile_window(
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
auto q_lds_read_window =
make_tile_window(q_lds,
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
@@ -234,25 +232,15 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
auto k_lds_window = make_tile_window(
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
using k_lds_write_window_type = decltype(get_slice_tile(
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kSubQKHeaddim>{}));
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
using k_lds_read_window_type = decltype(get_slice_tile(
using k_lds_window_type = decltype(get_slice_tile(
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kQKHeaddim>{}));
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
statically_indexed_array<k_lds_window_type, NumKVLdsBuffers> k_lds_windows;
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
k_lds_write_windows[i_buf] =
get_slice_tile(k_lds_window,
sequence<i_buf * kN0Sub, 0>{},
sequence<(i_buf + 1) * kN0Sub, kSubQKHeaddim>{});
k_lds_read_windows[i_buf] =
get_slice_tile(k_lds_window,
sequence<i_buf * kN0Sub, 0>{},
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
k_lds_windows[i_buf] = get_slice_tile(k_lds_window,
sequence<i_buf * kN0Sub, 0>{},
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
});
// V tile in LDS
@@ -351,7 +339,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
{
// STAGE 1, Gemm_0 ( S = Q@K )
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
store_tile(k_lds_write_windows[number<i_n0 % NumKVLdsBuffers>{}],
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
k_tiles[i_n0],
partition_index);
@@ -366,7 +354,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
block_sync_lds();
// execute current unroll of gemm_0
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
@@ -481,8 +469,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
typename BiasDramBlockWindowTmp,
typename HstuMask>
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
HstuMask mask,

View File

@@ -100,7 +100,7 @@ struct HstuAttentionFwdPipelineProblem
CK_TILE_HOST_DEVICE static constexpr auto GetQDramTileAccessMaxVectorSize()
{
constexpr index_t kMPerBlock = HstuAttentionTileSetting::kM0;
constexpr index_t kKPerBlock = HstuAttentionTileSetting::kSubQKHeaddim;
constexpr index_t kKPerBlock = HstuAttentionTileSetting::kQKHeaddim;
return GetDramTileAccessMaxVectorSize<QKVDataType, kBlockSize, kMPerBlock, kKPerBlock>();
}
@@ -108,7 +108,7 @@ struct HstuAttentionFwdPipelineProblem
CK_TILE_HOST_DEVICE static constexpr auto GetKDramTileAccessMaxVectorSize()
{
constexpr index_t kNPerBlock = HstuAttentionTileSetting::kN0Sub;
constexpr index_t kKPerBlock = HstuAttentionTileSetting::kSubQKHeaddim;
constexpr index_t kKPerBlock = HstuAttentionTileSetting::kQKHeaddim;
return GetDramTileAccessMaxVectorSize<QKVDataType, kBlockSize, kNPerBlock, kKPerBlock>();
}

View File

@@ -52,7 +52,7 @@ struct HstuAttentionFwdTileSettingClass
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
static constexpr index_t kN0Sub = BlockTile::at(number<2>{}); // tile size along k seqlen
static constexpr index_t kN0Sub = BlockTile::at(number<2>{}); // tile size for dividing kN0
static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
static constexpr index_t kQKHeaddim =
@@ -60,6 +60,8 @@ struct HstuAttentionFwdTileSettingClass
// once (or repeatedly load Q as a whole tile)
static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim);
static_assert(kSubQKHeaddim % kN1 == 0, "Check failed!");
};
} // namespace ck_tile

View File

@@ -37,7 +37,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
static constexpr index_t kQKHeaddim = HstuAttentionTileSetting::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = HstuAttentionTileSetting::kSubQKHeaddim;
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static_assert(kQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static_assert(Problem::kUseSoftmax == true, "This pipeline only works with using softmax");
@@ -53,7 +53,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQK = Traits::kPadHeadDimQK;
static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? true : Traits::kPadHeadDimV;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should be able to overwrite this
@@ -127,9 +127,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
typename OAccElementFunction,
typename HstuMask>
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
@@ -153,8 +153,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kSubQKHeaddim ==
KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
@@ -198,7 +197,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kGemmSingleRepM>{}, number<kSubQKHeaddim>{}),
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramSingleRepMTileDistribution<Problem>());
@@ -208,7 +207,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
auto k_dram_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kN0Sub>{}, number<kSubQKHeaddim>{}),
make_tuple(number<kN0Sub>{}, number<kQKHeaddim>{}),
{seqlen_k_start, 0},
Policy::template MakeKDramTileDistribution<Problem>());
@@ -245,7 +244,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_write_window = make_tile_window(
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
// the read window below spans [kGemmSingleRepM, kQKHeaddim] of the Q tile in LDS
// NOTE(review): presumably the write window now has the same kQKHeaddim width -- confirm
auto q_lds_read_window =
make_tile_window(q_lds,
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
@@ -260,25 +259,15 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
auto k_lds_window = make_tile_window(
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
using k_lds_write_window_type = decltype(get_slice_tile(
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kSubQKHeaddim>{}));
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
using k_lds_read_window_type = decltype(get_slice_tile(
using k_lds_window_type = decltype(get_slice_tile(
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kQKHeaddim>{}));
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
statically_indexed_array<k_lds_window_type, NumKVLdsBuffers> k_lds_windows;
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
k_lds_write_windows[i_buf] =
get_slice_tile(k_lds_window,
sequence<i_buf * kN0Sub, 0>{},
sequence<(i_buf + 1) * kN0Sub, kSubQKHeaddim>{});
k_lds_read_windows[i_buf] =
get_slice_tile(k_lds_window,
sequence<i_buf * kN0Sub, 0>{},
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
k_lds_windows[i_buf] = get_slice_tile(k_lds_window,
sequence<i_buf * kN0Sub, 0>{},
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
});
// V tile in LDS
@@ -401,7 +390,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
{
// STAGE 1, Gemm_0 ( S = Q@K )
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
store_tile(k_lds_write_windows[number<i_n0 % NumKVLdsBuffers>{}],
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
k_tiles[number<i_n0 % NumPrefetchK>{}],
partition_index);
@@ -428,7 +417,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
block_sync_lds();
// execute current unroll of gemm_0
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
@@ -673,7 +662,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
typename BiasDramBlockWindowTmp,
typename HstuMask>
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile

View File

@@ -55,7 +55,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQK = Traits::kPadHeadDimQK;
static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? true : Traits::kPadHeadDimV;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
@@ -126,9 +126,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
typename OAccElementFunction,
typename HstuMask>
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
@@ -152,8 +152,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kSubQKHeaddim ==
KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
@@ -197,7 +196,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kSubQKHeaddim>{}),
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem>());
@@ -207,7 +206,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
auto k_dram_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kN0Sub>{}, number<kSubQKHeaddim>{}),
make_tuple(number<kN0Sub>{}, number<kQKHeaddim>{}),
{seqlen_k_start, 0},
Policy::template MakeKDramTileDistribution<Problem>());
@@ -241,7 +240,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_write_window = make_tile_window(
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
auto q_lds_read_window =
make_tile_window(q_lds,
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
@@ -255,25 +253,15 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
auto k_lds_window = make_tile_window(
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
using k_lds_write_window_type = decltype(get_slice_tile(
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kSubQKHeaddim>{}));
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
using k_lds_read_window_type = decltype(get_slice_tile(
using k_lds_window_type = decltype(get_slice_tile(
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kQKHeaddim>{}));
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
statically_indexed_array<k_lds_window_type, NumKVLdsBuffers> k_lds_windows;
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
k_lds_write_windows[i_buf] =
get_slice_tile(k_lds_window,
sequence<i_buf * kN0Sub, 0>{},
sequence<(i_buf + 1) * kN0Sub, kSubQKHeaddim>{});
k_lds_read_windows[i_buf] =
get_slice_tile(k_lds_window,
sequence<i_buf * kN0Sub, 0>{},
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
k_lds_windows[i_buf] = get_slice_tile(k_lds_window,
sequence<i_buf * kN0Sub, 0>{},
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
});
// V tile in LDS
@@ -372,7 +360,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
{
// STAGE 1, Gemm_0 ( S = Q@K )
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
store_tile(k_lds_write_windows[number<i_n0 % NumKVLdsBuffers>{}],
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
k_tiles[number<i_n0 % NumPrefetchK>{}],
partition_index);
@@ -399,7 +387,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
block_sync_lds();
// execute current unroll of gemm_0
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
@@ -638,8 +626,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
typename BiasDramBlockWindowTmp,
typename HstuMask>
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
HstuMask mask,