Remove the comparing of row/col to max_uih_len in masking

This commit is contained in:
Qianfeng Zhang
2025-04-16 04:35:42 +00:00
parent d1749b3aae
commit 226a254723
3 changed files with 57 additions and 29 deletions

View File

@@ -575,7 +575,7 @@ struct HstuAttentionFwdKernel
const auto q_dram = [&]() {
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
q_ptr,
make_tuple(kargs.seqlen, kargs.hdim_qk),
make_tuple(mask.max_uih_len, kargs.hdim_qk),
make_tuple(kargs.seq_stride_q, 1),
number<HstuAttentionPipeline::kAlignmentQ>{},
number<1>{});
@@ -584,20 +584,20 @@ struct HstuAttentionFwdKernel
return pad_tensor_view(q_dram_naive,
make_tuple(number<HstuAttentionPipeline::kM0>{},
number<HstuAttentionPipeline::kSubQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQK>{});
sequence<false, kPadHeadDimQK>{});
}
else
{
return pad_tensor_view(q_dram_naive,
make_tuple(number<HstuAttentionPipeline::kM0>{},
number<HstuAttentionPipeline::kK0>{}),
sequence<kPadSeqLenQ, kPadHeadDimQK>{});
sequence<false, kPadHeadDimQK>{});
}
}();
const auto k_dram = [&]() {
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
make_tuple(kargs.seqlen, kargs.hdim_qk),
make_tuple(mask.max_uih_len, kargs.hdim_qk),
make_tuple(kargs.seq_stride_k, 1),
number<HstuAttentionPipeline::kAlignmentK>{},
number<1>{});
@@ -612,7 +612,7 @@ struct HstuAttentionFwdKernel
{
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.seqlen, kargs.hdim_v),
make_tuple(mask.max_uih_len, kargs.hdim_v),
make_tuple(kargs.seq_stride_v, 1),
number<HstuAttentionPipeline::kAlignmentV>{},
number<1>{});
@@ -732,7 +732,7 @@ struct HstuAttentionFwdKernel
auto o_dram = [&]() {
const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
o_ptr,
make_tuple(kargs.seqlen, kargs.hdim_v),
make_tuple(mask.max_uih_len, kargs.hdim_v),
make_tuple(kargs.seq_stride_o, 1),
number<HstuAttentionPipeline::kAlignmentO>{},
number<1>{});

View File

@@ -388,20 +388,8 @@ struct HstuAttentionFwdPipelineQRKSVS
return !mask.IsTokenPairInsideMask(row, col);
});
}
else
{
if(q_origin.at(number<0>{}) + kM0 > mask.max_uih_len)
{
const auto k_origin = k_dram_block_window.get_window_origin();
set_tile_if(s_acc, type_convert<GemmAccDataType>(0), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !mask.IsTokenPairInsideMask(row, col);
});
};
};
auto s = cast_tile<CompDataType>(s_acc); // S{j}
auto s = cast_tile<CompDataType>(s_acc);
tile_elementwise_inout(f_silu, s);

View File

@@ -68,7 +68,7 @@ struct HstuBlockMaskWithLocal
};
}
CK_TILE_HOST_DEVICE constexpr bool IsTokenPairInsideMask(int row, int col)
CK_TILE_HOST constexpr bool IsTokenPairInsideMask(int row, int col)
{
if(row >= max_uih_len || col >= max_uih_len)
return false;
@@ -87,6 +87,25 @@ struct HstuBlockMaskWithLocal
return result;
};
// masking codes in device don't have to compare row/col with max_uih_len, since
// buffer_load_xxx instruction is able to return zero for out-of-boundary access
CK_TILE_DEVICE constexpr bool IsTokenPairInsideMask(int row, int col)
{
if(row < contextual_seqlen)
return true;
bool result = false;
if constexpr(kUseCausal)
result = (row >= col) && (row - col <= max_attn_len);
else
result = abs(row - col) <= max_attn_len;
if(min_full_attn_seqlen > 0)
result = result || (row >= max_uih_len - min_full_attn_seqlen);
return result;
};
};
template <bool kUseCausal>
@@ -123,7 +142,7 @@ struct HstuBlockMaskNoLocal
};
}
CK_TILE_HOST_DEVICE constexpr bool IsTokenPairInsideMask(int row, int col)
CK_TILE_HOST constexpr bool IsTokenPairInsideMask(int row, int col)
{
if(row >= max_uih_len || col >= max_uih_len)
return false;
@@ -140,6 +159,23 @@ struct HstuBlockMaskNoLocal
return true;
};
// masking codes in device don't have to compare row/col with max_uih_len, since
// buffer_load_xxx instruction is able to return zero for out-of-boundary access
CK_TILE_DEVICE constexpr bool IsTokenPairInsideMask(int row, int col)
{
if(row < contextual_seqlen)
return true;
if constexpr(IsMasking)
{
bool result = (row >= col);
return result;
}
return true;
};
};
template <bool kUseCausal, bool kUseLocal>
@@ -157,10 +193,12 @@ CK_TILE_HOST_DEVICE constexpr auto make_hstu_block_mask_with_local(int seqlen_,
int max_attn_len_,
int min_full_attn_seqlen_)
{
auto max_uih_len_ = seqlen_;
max_uih_len_ -= contextual_seqlen_ > 0 ? contextual_seqlen_ - 1 : 0;
max_uih_len_ -= num_target;
auto max_uih_len_ = [&]() {
if(contextual_seqlen_ > 0)
return seqlen_ - (contextual_seqlen_ - 1) - num_target;
else
return seqlen_ - num_target;
}();
return HstuBlockMaskType{
contextual_seqlen_, max_uih_len_, max_attn_len_, min_full_attn_seqlen_};
@@ -170,10 +208,12 @@ template <typename HstuBlockMaskType>
CK_TILE_HOST_DEVICE constexpr auto
make_hstu_block_mask_without_local(int seqlen_, int contextual_seqlen_, int num_target)
{
auto max_uih_len_ = seqlen_;
max_uih_len_ -= contextual_seqlen_ > 0 ? contextual_seqlen_ - 1 : 0;
max_uih_len_ -= num_target;
auto max_uih_len_ = [&]() {
if(contextual_seqlen_ > 0)
return seqlen_ - (contextual_seqlen_ - 1) - num_target;
else
return seqlen_ - num_target;
}();
return HstuBlockMaskType{contextual_seqlen_, max_uih_len_};
};