Clarify the using of kSubQKHeaddim and kQKHeaddim

This commit is contained in:
Qianfeng Zhang
2025-12-03 08:18:13 +00:00
parent 7234b2fc1a
commit 2549bc1fee
5 changed files with 14 additions and 12 deletions

View File

@@ -586,7 +586,7 @@ struct HstuAttentionFwdKernel
return pad_tensor_view(k_dram_naive,
make_tuple(number<HstuAttentionPipeline::kN0>{},
number<HstuAttentionPipeline::kQKHeaddim>{}),
number<HstuAttentionPipeline::kSubQKHeaddim>{}),
sequence<false, kPadHeadDimQK>{});
}();
const auto v_dram = [&]() {
@@ -631,7 +631,7 @@ struct HstuAttentionFwdKernel
auto k_dram_window =
make_tile_window(k_dram,
make_tuple(number<HstuAttentionPipeline::kN0>{},
number<HstuAttentionPipeline::kQKHeaddim>{}),
number<HstuAttentionPipeline::kSubQKHeaddim>{}),
{0, 0});
auto v_dram_window = make_tile_window(

View File

@@ -150,7 +150,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kSubQKHeaddim ==
KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
@@ -179,7 +180,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
make_tuple(number<kGemmSingleRepM>{}, number<kSubQKHeaddim>{}),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramSingleRepMTileDistribution<Problem>());
@@ -189,7 +190,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
auto k_dram_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kK1>{}, number<kQKHeaddim>{}),
make_tuple(number<kK1>{}, number<kSubQKHeaddim>{}),
{seqlen_k_start, 0},
Policy::template MakeKDramTileDistribution<Problem>());

View File

@@ -178,7 +178,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
make_tuple(number<kM0>{}, number<kSubQKHeaddim>{}),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem>());
@@ -188,7 +188,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
auto k_dram_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kK1>{}, number<kQKHeaddim>{}),
make_tuple(number<kK1>{}, number<kSubQKHeaddim>{}),
{seqlen_k_start, 0},
Policy::template MakeKDramTileDistribution<Problem>());

View File

@@ -152,7 +152,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kSubQKHeaddim ==
KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
@@ -194,7 +195,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
make_tuple(number<kGemmSingleRepM>{}, number<kSubQKHeaddim>{}),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramSingleRepMTileDistribution<Problem>());
@@ -204,7 +205,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
auto k_dram_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kK1>{}, number<kQKHeaddim>{}),
make_tuple(number<kK1>{}, number<kSubQKHeaddim>{}),
{seqlen_k_start, 0},
Policy::template MakeKDramTileDistribution<Problem>());

View File

@@ -193,7 +193,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
make_tuple(number<kM0>{}, number<kSubQKHeaddim>{}),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem>());
@@ -203,7 +203,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
auto k_dram_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kK1>{}, number<kQKHeaddim>{}),
make_tuple(number<kK1>{}, number<kSubQKHeaddim>{}),
{seqlen_k_start, 0},
Policy::template MakeKDramTileDistribution<Problem>());