mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
Pass partition_index to get_x_indices_from_distributed_indices() to reduce calls of __builtin_amdgcn_readfirstlane()
This commit is contained in:
@@ -519,7 +519,8 @@ struct HstuAttentionFwdKernel
|
||||
seqlen_in_first_split = kargs.seqlen_q - num_target - kargs.min_full_attn_seqlen;
|
||||
|
||||
index_t num_tile_in_first_split =
|
||||
ck_tile::integer_divide_ceil(seqlen_in_first_split, HstuAttentionPipeline::kM0);
|
||||
__builtin_amdgcn_readfirstlane(ck_tile::integer_divide_ceil(
|
||||
seqlen_in_first_split, HstuAttentionPipeline::kM0));
|
||||
|
||||
is_tile_in_first_split = (i_tile_m < num_tile_in_first_split);
|
||||
|
||||
|
||||
@@ -361,7 +361,9 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
pcomp_tile.get_tile_distribution(),
|
||||
make_tuple(idx0, idx1),
|
||||
partition_index);
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
|
||||
@@ -363,7 +363,9 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
pcomp_tile.get_tile_distribution(),
|
||||
make_tuple(idx0, idx1),
|
||||
partition_index);
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
|
||||
@@ -392,7 +392,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
pcomp_tile.get_tile_distribution(),
|
||||
make_tuple(idx0, idx1),
|
||||
partition_index);
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
|
||||
@@ -394,7 +394,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
pcomp_tile.get_tile_distribution(),
|
||||
make_tuple(idx0, idx1),
|
||||
partition_index);
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
|
||||
Reference in New Issue
Block a user