Change in HstuBlockMasking and kernel/reference codes for using masking

This commit is contained in:
Qianfeng Zhang
2025-04-03 14:46:12 +00:00
parent 733734553b
commit 10e72d3362
5 changed files with 66 additions and 72 deletions

View File

@@ -80,6 +80,7 @@ struct HstuAttentionFwdKernel
ck_tile::index_t nhead_stride_o;
const int32_t* num_targets_ptr;
ck_tile::index_t contextual_seqlen;
};
struct HstuAttentionFwdCommonBiasKargs
@@ -97,7 +98,6 @@ struct HstuAttentionFwdKernel
struct HstuAttentionFwdMaskKargs
{
ck_tile::index_t window_size;
ck_tile::index_t contextual_seqlen;
ck_tile::index_t min_full_attn_seqlen;
};
@@ -184,8 +184,8 @@ struct HstuAttentionFwdKernel
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_o,
const void* num_targets_ptr,
ck_tile::index_t window_size,
ck_tile::index_t contextual_seqlen,
ck_tile::index_t window_size,
ck_tile::index_t min_full_attn_seqlen,
float p_drop,
const std::pair<uint64_t, uint64_t>& drop_seed_offset)
@@ -207,10 +207,11 @@ struct HstuAttentionFwdKernel
nhead_stride_k,
nhead_stride_v,
nhead_stride_o,
reinterpret_cast<const int32_t*>(num_targets_ptr)}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for dropout
reinterpret_cast<const int32_t*>(num_targets_ptr),
contextual_seqlen}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for dropout
batch_stride_q,
batch_stride_k,
batch_stride_v,
@@ -226,7 +227,6 @@ struct HstuAttentionFwdKernel
if constexpr(kHasMask)
{
kargs.window_size = window_size;
kargs.contextual_seqlen = contextual_seqlen;
kargs.min_full_attn_seqlen = min_full_attn_seqlen;
}
if constexpr(kHasDropout)
@@ -267,8 +267,8 @@ struct HstuAttentionFwdKernel
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_o,
const void* num_targets_ptr,
ck_tile::index_t window_size,
ck_tile::index_t contextual_seqlen,
ck_tile::index_t window_size,
ck_tile::index_t min_full_attn_seqlen,
float p_drop,
uint64_t philox_seed,
@@ -300,8 +300,8 @@ struct HstuAttentionFwdKernel
batch_stride_bias,
batch_stride_o,
num_targets_ptr,
window_size,
contextual_seqlen,
window_size,
min_full_attn_seqlen,
p_drop,
std::make_pair(philox_seed, philox_offset));
@@ -330,8 +330,8 @@ struct HstuAttentionFwdKernel
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_o,
const void* num_targets_ptr,
ck_tile::index_t window_size,
ck_tile::index_t contextual_seqlen,
ck_tile::index_t window_size,
ck_tile::index_t min_full_attn_seqlen,
float p_drop,
const std::pair<uint64_t, uint64_t>& drop_seed_offset)
@@ -353,10 +353,11 @@ struct HstuAttentionFwdKernel
nhead_stride_k,
nhead_stride_v,
nhead_stride_o,
reinterpret_cast<const int32_t*>(num_targets_ptr)}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for dropout
reinterpret_cast<const int32_t*>(num_targets_ptr),
contextual_seqlen}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for dropout
reinterpret_cast<const int32_t*>(seq_offsets_ptr)};
if constexpr(kHasBias)
@@ -368,7 +369,6 @@ struct HstuAttentionFwdKernel
if constexpr(kHasMask)
{
kargs.window_size = window_size;
kargs.contextual_seqlen = contextual_seqlen;
kargs.min_full_attn_seqlen = min_full_attn_seqlen;
}
if constexpr(kHasDropout)
@@ -404,8 +404,8 @@ struct HstuAttentionFwdKernel
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_o,
const void* num_targets_ptr,
ck_tile::index_t window_size,
ck_tile::index_t contextual_seqlen,
ck_tile::index_t window_size,
ck_tile::index_t min_full_attn_seqlen,
float p_drop,
uint64_t philox_seed,
@@ -432,8 +432,8 @@ struct HstuAttentionFwdKernel
nhead_stride_bias,
nhead_stride_o,
num_targets_ptr,
window_size,
contextual_seqlen,
window_size,
min_full_attn_seqlen,
p_drop,
std::make_pair(philox_seed, philox_offset));
@@ -539,30 +539,17 @@ struct HstuAttentionFwdKernel
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
}
int max_uih_len = kargs.seqlen;
if constexpr(kHasMask)
{
if(kargs.contextual_seqlen > 0)
max_uih_len -= kargs.contextual_seqlen - 1;
};
if(kargs.num_targets_ptr != nullptr)
{
if constexpr(kIsJagged)
max_uih_len -= kargs.num_targets_ptr[i_batch];
else
max_uih_len -= kargs.num_targets_ptr[0];
};
int num_target = (kargs.num_targets_ptr == nullptr) ? 0 : kargs.num_targets_ptr[i_batch];
HstuMask mask = [&]() {
if constexpr(kHasMask)
return HstuMask{kargs.window_size,
kargs.contextual_seqlen,
kargs.min_full_attn_seqlen,
max_uih_len};
kargs.seqlen,
num_target};
else
return HstuMask{0, 0, 0, 0};
return HstuMask{0, kargs.contextual_seqlen, 0, kargs.seqlen, num_target};
}();
// for simplicity, batch stride we just modify the pointer

View File

@@ -373,21 +373,21 @@ struct HstuAttentionFwdPipelineQRKSVS
if constexpr(HstuMask::IsMasking)
{
const auto k_origin = k_dram_block_window.get_window_origin();
set_tile_if(s_acc, -numeric<CompDataType>::infinity(), [&](auto tile_idx) {
set_tile_if(s_acc, type_convert<GemmAccDataType>(0), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsTokenPairInsideMask(row, col);
return !mask.IsTokenPairInsideMask(row, col);
});
}
else if constexpr(kPadSeqLenK)
{
const auto k_origin = k_dram_block_window.get_window_origin();
set_tile_if(s_acc, -numeric<CompDataType>::infinity(), [&](auto tile_idx) {
if(i_loop < num_loops)
set_tile_if(s_acc, type_convert<GemmAccDataType>(0), [&](auto tile_idx) {
if(i_loop < num_loops - 1)
return false;
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsTokenPairInsideMask(row, col);
return !mask.IsTokenPairInsideMask(row, col);
});
};

View File

@@ -20,12 +20,17 @@ struct HstuBlockMasking
CK_TILE_HOST_DEVICE HstuBlockMasking(int max_attn_len_,
int contextual_seqlen_,
int min_full_attn_seqlen_,
int max_uih_len_)
int seqlen_,
int num_target)
{
max_attn_len = max_attn_len_;
contextual_seqlen = contextual_seqlen_;
min_full_attn_seqlen = min_full_attn_seqlen_;
max_uih_len = max_uih_len_;
max_uih_len = seqlen_;
max_uih_len -= contextual_seqlen - 1;
max_uih_len -= num_target;
};
// to get the loop length along X axis, return index:[start, end), end-start=length
@@ -82,27 +87,34 @@ struct HstuBlockMasking
CK_TILE_HOST_DEVICE constexpr bool IsTokenPairInsideMask(int row, int col)
{
if(row >= max_uih_len || col >= max_uih_len)
return false;
if(row < contextual_seqlen)
return true;
bool result = false;
if constexpr(kUseLocal)
if constexpr(IsMasking)
{
if constexpr(kUseCausal)
result = (row >= col) && (row - col <= max_attn_len);
bool result = false;
if constexpr(kUseLocal)
{
if constexpr(kUseCausal)
result = (row >= col) && (row - col <= max_attn_len);
else
result = std::abs(row - col) <= max_attn_len;
if(min_full_attn_seqlen > 0)
result = result || (row >= max_uih_len - min_full_attn_seqlen);
}
else
result = std::abs(row - col) <= max_attn_len;
if(min_full_attn_seqlen > 0)
result = result || (row >= max_uih_len - min_full_attn_seqlen);
}
else
{
if constexpr(kUseCausal)
{
result = (row >= col);
};
};
return result;
return result;
}
return true;
};
};

View File

@@ -83,6 +83,9 @@ struct reference_hstu_attention
assert(hdim_qk == k_batch_seq_nhead_hdim.get_lengths()[3]);
assert(hdim_v == o_batch_seq_nhead_hdim.get_lengths()[3]);
// check num_targets
assert(num_targets.empty() || num_targets.size() == num_batch);
auto silu = [](CompDataType x) {
auto one = ck_tile::type_convert<CompDataType>(1.0f);
@@ -91,33 +94,22 @@ struct reference_hstu_attention
return sigmod_val * x;
};
bool has_target = !num_targets.empty();
if(has_target)
assert(num_targets.size() == num_batch);
auto f = [&](auto i_batch, auto i_head) {
int seqlen = kIsJagged ? (seq_offsets[i_batch + 1] - seq_offsets[i_batch])
: q_batch_seq_nhead_hdim.get_lengths()[1];
int max_uih_len = seqlen;
if(contextual_seqlen > 0)
max_uih_len -= contextual_seqlen - 1;
if(has_target)
max_uih_len -= num_targets[i_batch];
int num_target = num_targets.empty() ? 0 : num_targets[i_batch];
HstuBlockMasking<kUseCausal, kUseLocal> mask{
max_attn_len, contextual_seqlen, min_full_attn_seqlen, max_uih_len};
max_attn_len, contextual_seqlen, min_full_attn_seqlen, seqlen, num_target};
// for all rows in the batch
for(int sq = 0; sq < max_uih_len; sq++)
for(int sq = 0; sq < seqlen; sq++)
{
std::vector<CompDataType> locals;
// for all cols in the batch
for(int sk = 0; sk < max_uih_len; sk++)
for(int sk = 0; sk < seqlen; sk++)
{
if(mask.IsTokenPairInsideMask(sq, sk))
{
@@ -153,14 +145,14 @@ struct reference_hstu_attention
// SiLu element-wise
for(CompDataType& elem : locals)
elem = silu(elem) / ck_tile::type_convert<CompDataType>(seqlen);
elem = silu(elem);
// second Gemm
for(int k = 0; k < hdim_v; k++)
{
GemmAccDataType dot_prod = 0.f;
for(int sk = 0; sk < max_uih_len; sk++)
for(int sk = 0; sk < seqlen; sk++)
{
if constexpr(kIsJagged)
{

View File

@@ -0,0 +1,3 @@
#!/bin/bash
bin/tile_example_hstu_attention -v=1 -prec=bf16 -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=750,730,733 -causal=1 -local_len=5 -context_len=6 -minfull_len=6