mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
Change in HstBlockMasking and kernel/reference codes for using masking
This commit is contained in:
@@ -80,6 +80,7 @@ struct HstuAttentionFwdKernel
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
|
||||
const int32_t* num_targets_ptr;
|
||||
ck_tile::index_t contextual_seqlen;
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdCommonBiasKargs
|
||||
@@ -97,7 +98,6 @@ struct HstuAttentionFwdKernel
|
||||
struct HstuAttentionFwdMaskKargs
|
||||
{
|
||||
ck_tile::index_t window_size;
|
||||
ck_tile::index_t contextual_seqlen;
|
||||
ck_tile::index_t min_full_attn_seqlen;
|
||||
};
|
||||
|
||||
@@ -184,8 +184,8 @@ struct HstuAttentionFwdKernel
|
||||
ck_tile::index_t batch_stride_bias,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
const void* num_targets_ptr,
|
||||
ck_tile::index_t window_size,
|
||||
ck_tile::index_t contextual_seqlen,
|
||||
ck_tile::index_t window_size,
|
||||
ck_tile::index_t min_full_attn_seqlen,
|
||||
float p_drop,
|
||||
const std::pair<uint64_t, uint64_t>& drop_seed_offset)
|
||||
@@ -207,10 +207,11 @@ struct HstuAttentionFwdKernel
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_o,
|
||||
reinterpret_cast<const int32_t*>(num_targets_ptr)}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for dropout
|
||||
reinterpret_cast<const int32_t*>(num_targets_ptr),
|
||||
contextual_seqlen}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for dropout
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
@@ -226,7 +227,6 @@ struct HstuAttentionFwdKernel
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
kargs.window_size = window_size;
|
||||
kargs.contextual_seqlen = contextual_seqlen;
|
||||
kargs.min_full_attn_seqlen = min_full_attn_seqlen;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
@@ -267,8 +267,8 @@ struct HstuAttentionFwdKernel
|
||||
ck_tile::index_t batch_stride_bias,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
const void* num_targets_ptr,
|
||||
ck_tile::index_t window_size,
|
||||
ck_tile::index_t contextual_seqlen,
|
||||
ck_tile::index_t window_size,
|
||||
ck_tile::index_t min_full_attn_seqlen,
|
||||
float p_drop,
|
||||
uint64_t philox_seed,
|
||||
@@ -300,8 +300,8 @@ struct HstuAttentionFwdKernel
|
||||
batch_stride_bias,
|
||||
batch_stride_o,
|
||||
num_targets_ptr,
|
||||
window_size,
|
||||
contextual_seqlen,
|
||||
window_size,
|
||||
min_full_attn_seqlen,
|
||||
p_drop,
|
||||
std::make_pair(philox_seed, philox_offset));
|
||||
@@ -330,8 +330,8 @@ struct HstuAttentionFwdKernel
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
const void* num_targets_ptr,
|
||||
ck_tile::index_t window_size,
|
||||
ck_tile::index_t contextual_seqlen,
|
||||
ck_tile::index_t window_size,
|
||||
ck_tile::index_t min_full_attn_seqlen,
|
||||
float p_drop,
|
||||
const std::pair<uint64_t, uint64_t>& drop_seed_offset)
|
||||
@@ -353,10 +353,11 @@ struct HstuAttentionFwdKernel
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_o,
|
||||
reinterpret_cast<const int32_t*>(num_targets_ptr)}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for dropout
|
||||
reinterpret_cast<const int32_t*>(num_targets_ptr),
|
||||
contextual_seqlen}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for dropout
|
||||
reinterpret_cast<const int32_t*>(seq_offsets_ptr)};
|
||||
|
||||
if constexpr(kHasBias)
|
||||
@@ -368,7 +369,6 @@ struct HstuAttentionFwdKernel
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
kargs.window_size = window_size;
|
||||
kargs.contextual_seqlen = contextual_seqlen;
|
||||
kargs.min_full_attn_seqlen = min_full_attn_seqlen;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
@@ -404,8 +404,8 @@ struct HstuAttentionFwdKernel
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
const void* num_targets_ptr,
|
||||
ck_tile::index_t window_size,
|
||||
ck_tile::index_t contextual_seqlen,
|
||||
ck_tile::index_t window_size,
|
||||
ck_tile::index_t min_full_attn_seqlen,
|
||||
float p_drop,
|
||||
uint64_t philox_seed,
|
||||
@@ -432,8 +432,8 @@ struct HstuAttentionFwdKernel
|
||||
nhead_stride_bias,
|
||||
nhead_stride_o,
|
||||
num_targets_ptr,
|
||||
window_size,
|
||||
contextual_seqlen,
|
||||
window_size,
|
||||
min_full_attn_seqlen,
|
||||
p_drop,
|
||||
std::make_pair(philox_seed, philox_offset));
|
||||
@@ -539,30 +539,17 @@ struct HstuAttentionFwdKernel
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
}
|
||||
|
||||
int max_uih_len = kargs.seqlen;
|
||||
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
if(kargs.contextual_seqlen > 0)
|
||||
max_uih_len -= kargs.contextual_seqlen - 1;
|
||||
};
|
||||
|
||||
if(kargs.num_targets_ptr != nullptr)
|
||||
{
|
||||
if constexpr(kIsJagged)
|
||||
max_uih_len -= kargs.num_targets_ptr[i_batch];
|
||||
else
|
||||
max_uih_len -= kargs.num_targets_ptr[0];
|
||||
};
|
||||
int num_target = (kargs.num_targets_ptr == nullptr) ? 0 : kargs.num_targets_ptr[i_batch];
|
||||
|
||||
HstuMask mask = [&]() {
|
||||
if constexpr(kHasMask)
|
||||
return HstuMask{kargs.window_size,
|
||||
kargs.contextual_seqlen,
|
||||
kargs.min_full_attn_seqlen,
|
||||
max_uih_len};
|
||||
kargs.seqlen,
|
||||
num_target};
|
||||
else
|
||||
return HstuMask{0, 0, 0, 0};
|
||||
return HstuMask{0, kargs.contextual_seqlen, 0, kargs.seqlen, num_target};
|
||||
}();
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
|
||||
@@ -373,21 +373,21 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
if constexpr(HstuMask::IsMasking)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
set_tile_if(s_acc, -numeric<CompDataType>::infinity(), [&](auto tile_idx) {
|
||||
set_tile_if(s_acc, type_convert<GemmAccDataType>(0), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsTokenPairInsideMask(row, col);
|
||||
return !mask.IsTokenPairInsideMask(row, col);
|
||||
});
|
||||
}
|
||||
else if constexpr(kPadSeqLenK)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
set_tile_if(s_acc, -numeric<CompDataType>::infinity(), [&](auto tile_idx) {
|
||||
if(i_loop < num_loops)
|
||||
set_tile_if(s_acc, type_convert<GemmAccDataType>(0), [&](auto tile_idx) {
|
||||
if(i_loop < num_loops - 1)
|
||||
return false;
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsTokenPairInsideMask(row, col);
|
||||
return !mask.IsTokenPairInsideMask(row, col);
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
@@ -20,12 +20,17 @@ struct HstuBlockMasking
|
||||
CK_TILE_HOST_DEVICE HstuBlockMasking(int max_attn_len_,
|
||||
int contextual_seqlen_,
|
||||
int min_full_attn_seqlen_,
|
||||
int max_uih_len_)
|
||||
int seqlen_,
|
||||
int num_target)
|
||||
{
|
||||
max_attn_len = max_attn_len_;
|
||||
contextual_seqlen = contextual_seqlen_;
|
||||
min_full_attn_seqlen = min_full_attn_seqlen_;
|
||||
max_uih_len = max_uih_len_;
|
||||
|
||||
max_uih_len = seqlen_;
|
||||
|
||||
max_uih_len -= contextual_seqlen - 1;
|
||||
max_uih_len -= num_target;
|
||||
};
|
||||
|
||||
// to get the loop length along X axis, return index:[start, end), end-start=length
|
||||
@@ -82,27 +87,34 @@ struct HstuBlockMasking
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr bool IsTokenPairInsideMask(int row, int col)
|
||||
{
|
||||
if(row >= max_uih_len || col >= max_uih_len)
|
||||
return false;
|
||||
|
||||
if(row < contextual_seqlen)
|
||||
return true;
|
||||
|
||||
bool result = false;
|
||||
if constexpr(kUseLocal)
|
||||
if constexpr(IsMasking)
|
||||
{
|
||||
if constexpr(kUseCausal)
|
||||
result = (row >= col) && (row - col <= max_attn_len);
|
||||
bool result = false;
|
||||
if constexpr(kUseLocal)
|
||||
{
|
||||
if constexpr(kUseCausal)
|
||||
result = (row >= col) && (row - col <= max_attn_len);
|
||||
else
|
||||
result = std::abs(row - col) <= max_attn_len;
|
||||
|
||||
if(min_full_attn_seqlen > 0)
|
||||
result = result || (row >= max_uih_len - min_full_attn_seqlen);
|
||||
}
|
||||
else
|
||||
result = std::abs(row - col) <= max_attn_len;
|
||||
|
||||
if(min_full_attn_seqlen > 0)
|
||||
result = result || (row >= max_uih_len - min_full_attn_seqlen);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
result = (row >= col);
|
||||
};
|
||||
};
|
||||
|
||||
return result;
|
||||
return result;
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
@@ -83,6 +83,9 @@ struct reference_hstu_attention
|
||||
assert(hdim_qk == k_batch_seq_nhead_hdim.get_lengths()[3]);
|
||||
assert(hdim_v == o_batch_seq_nhead_hdim.get_lengths()[3]);
|
||||
|
||||
// check num_tagets
|
||||
assert(num_tagets.empty() || num_targets.size() == num_batch);
|
||||
|
||||
auto silu = [](CompDataType x) {
|
||||
auto one = ck_tile::type_convert<CompDataType>(1.0f);
|
||||
|
||||
@@ -91,33 +94,22 @@ struct reference_hstu_attention
|
||||
return sigmod_val * x;
|
||||
};
|
||||
|
||||
bool has_target = !num_targets.empty();
|
||||
|
||||
if(has_target)
|
||||
assert(num_targets.size() == num_batch);
|
||||
|
||||
auto f = [&](auto i_batch, auto i_head) {
|
||||
int seqlen = kIsJagged ? (seq_offsets[i_batch + 1] - seq_offsets[i_batch])
|
||||
: q_batch_seq_nhead_hdim.get_lengths()[1];
|
||||
|
||||
int max_uih_len = seqlen;
|
||||
|
||||
if(contextual_seqlen > 0)
|
||||
max_uih_len -= contextual_seqlen - 1;
|
||||
|
||||
if(has_target)
|
||||
max_uih_len -= num_targets[i_batch];
|
||||
int num_target = num_targets.empty() ? 0 : num_targets[i_batch];
|
||||
|
||||
HstuBlockMasking<kUseCausal, kUseLocal> mask{
|
||||
max_attn_len, contextual_seqlen, min_full_attn_seqlen, max_uih_len};
|
||||
max_attn_len, contextual_seqlen, min_full_attn_seqlen, seqlen, num_target};
|
||||
|
||||
// for all rows in the batch
|
||||
for(int sq = 0; sq < max_uih_len; sq++)
|
||||
for(int sq = 0; sq < seqlen; sq++)
|
||||
{
|
||||
std::vector<CompDataType> locals;
|
||||
|
||||
// for all cols in the batch
|
||||
for(int sk = 0; sk < max_uih_len; sk++)
|
||||
for(int sk = 0; sk < seqlen; sk++)
|
||||
{
|
||||
if(mask.IsTokenPairInsideMask(sq, sk))
|
||||
{
|
||||
@@ -153,14 +145,14 @@ struct reference_hstu_attention
|
||||
|
||||
// SiLu element-wise
|
||||
for(CompDataType& elem : locals)
|
||||
elem = silu(elem) / ck_tile::type_convert<CompDataType>(seqlen);
|
||||
elem = silu(elem);
|
||||
|
||||
// second Gemm
|
||||
for(int k = 0; k < hdim_v; k++)
|
||||
{
|
||||
GemmAccDataType dot_prod = 0.f;
|
||||
|
||||
for(int sk = 0; sk < max_uih_len; sk++)
|
||||
for(int sk = 0; sk < seqlen; sk++)
|
||||
{
|
||||
if constexpr(kIsJagged)
|
||||
{
|
||||
|
||||
3
example/ck_tile/18_hstu_attention/test_hstu_attention.sh
Normal file
3
example/ck_tile/18_hstu_attention/test_hstu_attention.sh
Normal file
@@ -0,0 +1,3 @@
|
||||
#!/bin/bash
|
||||
|
||||
bin/tile_example_hstu_attention -v=1 -prec=bf16 -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=750,730,733 -causal=1 -local_len=5 -context_len=6 -minfull_len=6
|
||||
Reference in New Issue
Block a user