Non-critical updates to the example and block_masking code

This commit is contained in:
Qianfeng Zhang
2025-05-29 01:02:20 +00:00
parent 68a5ab8ff8
commit 36a0f2020c
2 changed files with 94 additions and 42 deletions

View File

@@ -391,6 +391,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
HstuAttentionFwdParams params;
float scale_s = 1.0f / std::sqrt(hdim_qk);
if(is_jagged)
{
params.is_jagged = true;
@@ -405,7 +407,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
params.hdim_qk = hdim_qk;
params.hdim_v = hdim_v;
params.num_head = num_head;
params.scale_s = 1.0f / std::sqrt(params.hdim_qk);
params.scale_s = scale_s;
params.seq_stride_q = q_host.get_strides()[1];
params.seq_stride_k = k_host.get_strides()[1];
params.seq_stride_v = v_host.get_strides()[1];
@@ -438,7 +440,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
params.hdim_qk = hdim_qk;
params.hdim_v = hdim_v;
params.num_head = num_head;
params.scale_s = 1.0f / std::sqrt(params.hdim_qk);
params.scale_s = scale_s;
params.seq_stride_q = q_host.get_strides()[1];
params.seq_stride_k = k_host.get_strides()[1];
params.seq_stride_v = v_host.get_strides()[1];
@@ -507,7 +509,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
o_host_ref,
mask_host,
num_batch,
1.0f / std::sqrt(params.hdim_qk),
scale_s,
max_seqlen,
seq_offsets,
num_targets,

View File

@@ -132,37 +132,66 @@ struct HstuBlockMaskWithLocal
// Host-side predicate: returns true when the (row, col) token pair is treated as inside the
// mask, false otherwise.
CK_TILE_HOST constexpr bool IsTokenPairInsideMask(int row, int col)
{
// row_id/col_id are derived from the physical row/col: shifted by contextual_seqlen (when it
// is positive) and clamped to max_id
int row_id = contextual_seqlen > 0 ? max(row - contextual_seqlen + 1, 0) : row;
int col_id = contextual_seqlen > 0 ? max(col - contextual_seqlen + 1, 0) : col;
row_id = min(row_id, max_id);
col_id = min(col_id, max_id);
if(contextual_seqlen > 0 && row_id == 0 && col_id < max_id)
return true;
// use row_id/col_id to check the distance between the q/k token pair; token pairs on the
// diagonal are always considered
if constexpr(kUseCausal)
if(contextual_seqlen > 0)
{
if(((row_id > col_id) && (row_id - col_id <= max_attn_len)) || (row == col))
// row_id/col_id are derived from the physical row/col: shifted by contextual_seqlen and
// clamped to [0, max_id]
int row_id = max(row - contextual_seqlen + 1, 0);
int col_id = max(col - contextual_seqlen + 1, 0);
row_id = min(row_id, max_id);
col_id = min(col_id, max_id);
if(row_id == 0 && col_id < max_id)
return true;
if((min_full_attn_seqlen > 0) && (row_id >= max_id - min_full_attn_seqlen))
return true;
// use row_id/col_id to check the distance between the q/k token pair; token pairs on the
// diagonal are always considered
if constexpr(kUseCausal)
{
if(((row_id > col_id) && (row_id - col_id <= max_attn_len)) || (row == col))
return true;
if((min_full_attn_seqlen > 0) && (row_id >= max_id - min_full_attn_seqlen))
return true;
}
else
{
if(((row_id != col_id && abs(row_id - col_id) <= max_attn_len)) || (row == col))
return true;
if((min_full_attn_seqlen > 0) && (row >= max_id - min_full_attn_seqlen))
return true;
}
return false;
}
else
{
if(((row_id != col_id && abs(row_id - col_id) <= max_attn_len)) || (row == col))
return true;
int row_id = min(row, max_id);
int col_id = min(col, max_id);
if((min_full_attn_seqlen > 0) && (row >= max_id - min_full_attn_seqlen))
return true;
// use row_id/col_id to check the distance between the q/k token pair; token pairs on the
// diagonal are always considered
if constexpr(kUseCausal)
{
if(((row_id > col_id) && (row_id - col_id <= max_attn_len)) || (row == col))
return true;
if((min_full_attn_seqlen > 0) && (row_id >= max_id - min_full_attn_seqlen))
return true;
}
else
{
if(((row_id != col_id && abs(row_id - col_id) <= max_attn_len)) || (row == col))
return true;
if((min_full_attn_seqlen > 0) && (row >= max_id - min_full_attn_seqlen))
return true;
}
return false;
}
return false;
};
CK_TILE_DEVICE constexpr int IsTokenPairInsideMask(int row, int col)
@@ -303,26 +332,47 @@ struct HstuBlockMaskNoLocal
// Host-side predicate: returns true when the (row, col) token pair is treated as inside the
// mask, false otherwise.
CK_TILE_HOST constexpr bool IsTokenPairInsideMask(int row, int col)
{
// row_id/col_id are derived from the physical row/col: shifted by contextual_seqlen (when it
// is positive) and clamped to max_id
int row_id = contextual_seqlen > 0 ? max(row - contextual_seqlen + 1, 0) : row;
int col_id = contextual_seqlen > 0 ? max(col - contextual_seqlen + 1, 0) : col;
row_id = min(row_id, max_id);
col_id = min(col_id, max_id);
if(contextual_seqlen > 0 && row_id == 0 && col_id < max_id)
return true;
// use row_id/col_id to compare the q/k token pair; token pairs on the diagonal are always
// considered
if constexpr(IsMasking)
if(contextual_seqlen > 0)
{
return (row_id > col_id) || (row == col);
// row_id/col_id are derived from the physical row/col: shifted by contextual_seqlen and
// clamped to [0, max_id]
int row_id = max(row - contextual_seqlen + 1, 0);
int col_id = max(col - contextual_seqlen + 1, 0);
row_id = min(row_id, max_id);
col_id = min(col_id, max_id);
if(row_id == 0 && col_id < max_id)
return true;
// use row_id/col_id to compare the q/k token pair; token pairs on the diagonal are always
// considered
if constexpr(IsMasking)
{
return (row_id > col_id) || (row == col);
}
else
{
return (row_id != col_id) || (row == col);
};
}
else
{
return (row_id != col_id) || (row == col);
// row_id/col_id are the physical row/col clamped to max_id (no contextual_seqlen shift is
// applied in this branch)
int row_id = min(row, max_id);
int col_id = min(col, max_id);
// use row_id/col_id to compare the q/k token pair; token pairs on the diagonal are always
// considered
if constexpr(IsMasking)
{
return (row_id > col_id) || (row == col);
}
else
{
return (row_id != col_id) || (row == col);
};
};
};