mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Non-critical updates to the example and block_masking code
This commit is contained in:
@@ -391,6 +391,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
HstuAttentionFwdParams params;
|
||||
|
||||
float scale_s = 1.0f / std::sqrt(hdim_qk);
|
||||
|
||||
if(is_jagged)
|
||||
{
|
||||
params.is_jagged = true;
|
||||
@@ -405,7 +407,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
params.hdim_qk = hdim_qk;
|
||||
params.hdim_v = hdim_v;
|
||||
params.num_head = num_head;
|
||||
params.scale_s = 1.0f / std::sqrt(params.hdim_qk);
|
||||
params.scale_s = scale_s;
|
||||
params.seq_stride_q = q_host.get_strides()[1];
|
||||
params.seq_stride_k = k_host.get_strides()[1];
|
||||
params.seq_stride_v = v_host.get_strides()[1];
|
||||
@@ -438,7 +440,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
params.hdim_qk = hdim_qk;
|
||||
params.hdim_v = hdim_v;
|
||||
params.num_head = num_head;
|
||||
params.scale_s = 1.0f / std::sqrt(params.hdim_qk);
|
||||
params.scale_s = scale_s;
|
||||
params.seq_stride_q = q_host.get_strides()[1];
|
||||
params.seq_stride_k = k_host.get_strides()[1];
|
||||
params.seq_stride_v = v_host.get_strides()[1];
|
||||
@@ -507,7 +509,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
o_host_ref,
|
||||
mask_host,
|
||||
num_batch,
|
||||
1.0f / std::sqrt(params.hdim_qk),
|
||||
scale_s,
|
||||
max_seqlen,
|
||||
seq_offsets,
|
||||
num_targets,
|
||||
|
||||
@@ -132,37 +132,66 @@ struct HstuBlockMaskWithLocal
|
||||
|
||||
CK_TILE_HOST constexpr bool IsTokenPairInsideMask(int row, int col)
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
int row_id = contextual_seqlen > 0 ? max(row - contextual_seqlen + 1, 0) : row;
|
||||
int col_id = contextual_seqlen > 0 ? max(col - contextual_seqlen + 1, 0) : col;
|
||||
|
||||
row_id = min(row_id, max_id);
|
||||
col_id = min(col_id, max_id);
|
||||
|
||||
if(contextual_seqlen > 0 && row_id == 0 && col_id < max_id)
|
||||
return true;
|
||||
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
// diagonal line are always considered
|
||||
if constexpr(kUseCausal)
|
||||
if(contextual_seqlen > 0)
|
||||
{
|
||||
if(((row_id > col_id) && (row_id - col_id <= max_attn_len)) || (row == col))
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
int row_id = max(row - contextual_seqlen + 1, 0);
|
||||
int col_id = max(col - contextual_seqlen + 1, 0);
|
||||
|
||||
row_id = min(row_id, max_id);
|
||||
col_id = min(col_id, max_id);
|
||||
|
||||
if(row_id == 0 && col_id < max_id)
|
||||
return true;
|
||||
|
||||
if((min_full_attn_seqlen > 0) && (row_id >= max_id - min_full_attn_seqlen))
|
||||
return true;
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
// diagonal line are always considered
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
if(((row_id > col_id) && (row_id - col_id <= max_attn_len)) || (row == col))
|
||||
return true;
|
||||
|
||||
if((min_full_attn_seqlen > 0) && (row_id >= max_id - min_full_attn_seqlen))
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(((row_id != col_id && abs(row_id - col_id) <= max_attn_len)) || (row == col))
|
||||
return true;
|
||||
|
||||
if((min_full_attn_seqlen > 0) && (row >= max_id - min_full_attn_seqlen))
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(((row_id != col_id && abs(row_id - col_id) <= max_attn_len)) || (row == col))
|
||||
return true;
|
||||
int row_id = min(row, max_id);
|
||||
int col_id = min(col, max_id);
|
||||
|
||||
if((min_full_attn_seqlen > 0) && (row >= max_id - min_full_attn_seqlen))
|
||||
return true;
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
// diagonal line are always considered
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
if(((row_id > col_id) && (row_id - col_id <= max_attn_len)) || (row == col))
|
||||
return true;
|
||||
|
||||
if((min_full_attn_seqlen > 0) && (row_id >= max_id - min_full_attn_seqlen))
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(((row_id != col_id && abs(row_id - col_id) <= max_attn_len)) || (row == col))
|
||||
return true;
|
||||
|
||||
if((min_full_attn_seqlen > 0) && (row >= max_id - min_full_attn_seqlen))
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
CK_TILE_DEVICE constexpr int IsTokenPairInsideMask(int row, int col)
|
||||
@@ -303,26 +332,47 @@ struct HstuBlockMaskNoLocal
|
||||
|
||||
CK_TILE_HOST constexpr bool IsTokenPairInsideMask(int row, int col)
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
int row_id = contextual_seqlen > 0 ? max(row - contextual_seqlen + 1, 0) : row;
|
||||
int col_id = contextual_seqlen > 0 ? max(col - contextual_seqlen + 1, 0) : col;
|
||||
|
||||
row_id = min(row_id, max_id);
|
||||
col_id = min(col_id, max_id);
|
||||
|
||||
if(contextual_seqlen > 0 && row_id == 0 && col_id < max_id)
|
||||
return true;
|
||||
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
// diagonal line are always considered
|
||||
if constexpr(IsMasking)
|
||||
if(contextual_seqlen > 0)
|
||||
{
|
||||
return (row_id > col_id) || (row == col);
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
int row_id = max(row - contextual_seqlen + 1, 0);
|
||||
int col_id = max(col - contextual_seqlen + 1, 0);
|
||||
|
||||
row_id = min(row_id, max_id);
|
||||
col_id = min(col_id, max_id);
|
||||
|
||||
if(row_id == 0 && col_id < max_id)
|
||||
return true;
|
||||
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
// diagonal line are always considered
|
||||
if constexpr(IsMasking)
|
||||
{
|
||||
return (row_id > col_id) || (row == col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return (row_id != col_id) || (row == col);
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
return (row_id != col_id) || (row == col);
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
int row_id = min(row, max_id);
|
||||
int col_id = min(col, max_id);
|
||||
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
// diagonal line are always considered
|
||||
if constexpr(IsMasking)
|
||||
{
|
||||
return (row_id > col_id) || (row == col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return (row_id != col_id) || (row == col);
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user