mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Update to hstu masking to separate the implementation for cross-attention and self-attention
This commit is contained in:
@@ -697,43 +697,94 @@ struct HstuAttentionFwdKernel
|
||||
auto o_acc_tile = [&]() {
|
||||
if(kargs.window_size > 0)
|
||||
{
|
||||
using HstuMaskType = typename ck_tile::HstuBlockMasking<kHasCausalMask, true>::Type;
|
||||
const auto mask =
|
||||
make_hstu_block_mask_with_local<HstuMaskType>(is_tile_in_first_split,
|
||||
kargs.seqlen_q,
|
||||
kargs.seqlen_kv,
|
||||
kargs.contextual_seqlen,
|
||||
num_target,
|
||||
kargs.window_size,
|
||||
kargs.min_full_attn_seqlen);
|
||||
if(kargs.is_cross_attention)
|
||||
{
|
||||
|
||||
return HstuAttentionPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
kargs.scale_p,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
using HstuMaskType =
|
||||
typename ck_tile::HstuBlockMasking<true, kHasCausalMask, true>::Type;
|
||||
|
||||
auto mask = make_hstu_cross_attention_block_mask_with_local<HstuMaskType>(
|
||||
is_tile_in_first_split,
|
||||
kargs.seqlen_q,
|
||||
kargs.seqlen_kv,
|
||||
kargs.contextual_seqlen,
|
||||
num_target,
|
||||
kargs.window_size,
|
||||
kargs.min_full_attn_seqlen);
|
||||
|
||||
return HstuAttentionPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
kargs.scale_p,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
else
|
||||
{
|
||||
using HstuMaskType =
|
||||
typename ck_tile::HstuBlockMasking<false, kHasCausalMask, true>::Type;
|
||||
|
||||
auto mask = make_hstu_self_attention_block_mask_with_local<HstuMaskType>(
|
||||
is_tile_in_first_split,
|
||||
kargs.seqlen_q,
|
||||
kargs.contextual_seqlen,
|
||||
num_target,
|
||||
kargs.window_size,
|
||||
kargs.min_full_attn_seqlen);
|
||||
|
||||
return HstuAttentionPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
kargs.scale_p,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
using HstuMaskType =
|
||||
typename ck_tile::HstuBlockMasking<kHasCausalMask, false>::Type;
|
||||
const auto mask = make_hstu_block_mask_without_local<HstuMaskType>(
|
||||
kargs.seqlen_q, kargs.seqlen_kv, kargs.contextual_seqlen, num_target);
|
||||
if(kargs.is_cross_attention)
|
||||
{
|
||||
using HstuMaskType =
|
||||
typename ck_tile::HstuBlockMasking<true, kHasCausalMask, false>::Type;
|
||||
|
||||
return HstuAttentionPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
kargs.scale_p,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
};
|
||||
auto mask = make_hstu_cross_attention_block_mask_without_local<HstuMaskType>(
|
||||
kargs.seqlen_q, kargs.seqlen_kv, kargs.contextual_seqlen, num_target);
|
||||
|
||||
return HstuAttentionPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
kargs.scale_p,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
else
|
||||
{
|
||||
using HstuMaskType =
|
||||
typename ck_tile::HstuBlockMasking<false, kHasCausalMask, false>::Type;
|
||||
|
||||
auto mask = make_hstu_self_attention_block_mask_without_local<HstuMaskType>(
|
||||
kargs.seqlen_q, kargs.contextual_seqlen, num_target);
|
||||
|
||||
return HstuAttentionPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
kargs.scale_p,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
}
|
||||
}();
|
||||
|
||||
// O DRAM and O DRAM window
|
||||
|
||||
@@ -9,10 +9,11 @@
|
||||
namespace ck_tile {
|
||||
|
||||
template <bool kUseCausal>
|
||||
struct HstuBlockMaskWithLocal
|
||||
struct HstuCrossAttentionBlockMaskWithLocal
|
||||
{
|
||||
static constexpr bool kUseLocal = true;
|
||||
static constexpr bool IsMasking = true;
|
||||
static constexpr bool kUseLocal = true;
|
||||
static constexpr bool IsMasking = true;
|
||||
static constexpr bool kIsCrossAttention = true;
|
||||
|
||||
// is_tile_in_first_split is false only when min_full_attn_seqlen > 0 and the current
|
||||
// tile is inside scope [max_uih_len - min_full_attn_seqlen, seqlen_q); for other cases
|
||||
@@ -22,8 +23,8 @@ struct HstuBlockMaskWithLocal
|
||||
int seqlen_k;
|
||||
int contextual_seqlen;
|
||||
|
||||
int min_full_attn_seqlen;
|
||||
int max_attn_len;
|
||||
int min_full_attn_seqlen;
|
||||
|
||||
int max_q_uih_len;
|
||||
int max_k_uih_len;
|
||||
@@ -31,27 +32,28 @@ struct HstuBlockMaskWithLocal
|
||||
int max_col_id;
|
||||
int diff_q_kv_len;
|
||||
|
||||
CK_TILE_HOST_DEVICE HstuBlockMaskWithLocal(bool is_tile_in_first_split_,
|
||||
int seqlen_q_,
|
||||
int seqlen_k_,
|
||||
int contextual_seqlen_,
|
||||
int max_attn_len_,
|
||||
int min_full_attn_seqlen_,
|
||||
int num_target_)
|
||||
CK_TILE_HOST_DEVICE HstuCrossAttentionBlockMaskWithLocal(bool is_tile_in_first_split_,
|
||||
int seqlen_q_,
|
||||
int seqlen_k_,
|
||||
int contextual_seqlen_,
|
||||
int max_attn_len_,
|
||||
int min_full_attn_seqlen_,
|
||||
int num_target_)
|
||||
: is_tile_in_first_split(is_tile_in_first_split_),
|
||||
seqlen_q(seqlen_q_),
|
||||
seqlen_k(seqlen_k_),
|
||||
contextual_seqlen(contextual_seqlen_),
|
||||
max_attn_len(max_attn_len_),
|
||||
min_full_attn_seqlen(min_full_attn_seqlen_)
|
||||
{
|
||||
max_q_uih_len = seqlen_q - num_target_;
|
||||
max_k_uih_len = seqlen_k - num_target_;
|
||||
|
||||
// in case user provided max_attn_len_ could be bigger than max_uih_len
|
||||
max_attn_len = min(max_k_uih_len, min(max_q_uih_len, max_attn_len_));
|
||||
max_attn_len = min(max_k_uih_len, min(max_q_uih_len, max_attn_len));
|
||||
|
||||
// assuming min_full_attn_seqlen has higher priority, ensure contextual scope not collide
|
||||
// with min_full_attn_seqlen scope
|
||||
// assuming min_full_attn_seqlen has higher priority, ensure contextual scope not
|
||||
// collide with min_full_attn_seqlen scope
|
||||
contextual_seqlen = min(contextual_seqlen, max_q_uih_len - min_full_attn_seqlen);
|
||||
|
||||
if(contextual_seqlen > 0)
|
||||
@@ -181,7 +183,7 @@ struct HstuBlockMaskWithLocal
|
||||
};
|
||||
}
|
||||
|
||||
CK_TILE_HOST bool IsTokenPairInsideMask(int row, int col)
|
||||
CK_TILE_HOST_DEVICE bool IsTokenPairInsideMask(int row, int col)
|
||||
{
|
||||
int row_id;
|
||||
int col_id;
|
||||
@@ -229,56 +231,6 @@ struct HstuBlockMaskWithLocal
|
||||
}
|
||||
};
|
||||
|
||||
CK_TILE_DEVICE bool IsTokenPairInsideMask(int row, int col)
|
||||
{
|
||||
int row_id;
|
||||
int col_id;
|
||||
|
||||
row += diff_q_kv_len;
|
||||
|
||||
if(contextual_seqlen > 0)
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
row_id = max(row - contextual_seqlen + 1, diff_q_kv_len);
|
||||
col_id = max(col - contextual_seqlen + 1, 0);
|
||||
|
||||
row_id = min(row_id, max_row_id);
|
||||
col_id = min(col_id, max_col_id);
|
||||
|
||||
if(row_id == diff_q_kv_len && col_id < max_col_id)
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
row_id = min(row, max_row_id);
|
||||
col_id = min(col, max_col_id);
|
||||
};
|
||||
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
// diagonal line are always considerred
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
bool in_min_full_scope = !is_tile_in_first_split;
|
||||
|
||||
bool res = (((row_id > col_id) || (row == col)) &&
|
||||
((row_id - col_id <= max_attn_len) || in_min_full_scope));
|
||||
|
||||
return res;
|
||||
}
|
||||
else
|
||||
{
|
||||
bool in_min_full_scope = !is_tile_in_first_split;
|
||||
|
||||
bool res = (((row_id != col_id) || (row == col)) &&
|
||||
((abs(row_id - col_id) <= max_attn_len) || in_min_full_scope));
|
||||
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
// if the whole tile inside the masking area, no need for pixel-by-pixel checking
|
||||
template <index_t TileWidth, index_t TileHeight>
|
||||
CK_TILE_DEVICE constexpr bool IsFullTileInsideMask(index_t i_tile_top,
|
||||
@@ -312,10 +264,246 @@ struct HstuBlockMaskWithLocal
|
||||
};
|
||||
|
||||
template <bool kUseCausal>
|
||||
struct HstuBlockMaskNoLocal
|
||||
struct HstuSelfAttentionBlockMaskWithLocal
|
||||
{
|
||||
static constexpr bool kUseLocal = false;
|
||||
static constexpr bool IsMasking = kUseCausal;
|
||||
static constexpr bool kUseLocal = true;
|
||||
static constexpr bool IsMasking = true;
|
||||
static constexpr bool kIsCrossAttention = false;
|
||||
|
||||
// is_tile_in_first_split is false only when min_full_attn_seqlen > 0 and the current
|
||||
// tile is inside scope [max_uih_len - min_full_attn_seqlen, seqlen_q); for other cases
|
||||
// and tiles, is_tile_in_first_split is true
|
||||
bool is_tile_in_first_split;
|
||||
int seqlen;
|
||||
int contextual_seqlen;
|
||||
|
||||
int max_attn_len;
|
||||
int min_full_attn_seqlen;
|
||||
|
||||
int max_uih_len;
|
||||
int max_id;
|
||||
|
||||
CK_TILE_HOST_DEVICE HstuSelfAttentionBlockMaskWithLocal(bool is_tile_in_first_split_,
|
||||
int seqlen_,
|
||||
int contextual_seqlen_,
|
||||
int max_attn_len_,
|
||||
int min_full_attn_seqlen_,
|
||||
int num_target_)
|
||||
: is_tile_in_first_split(is_tile_in_first_split_),
|
||||
seqlen(seqlen_),
|
||||
contextual_seqlen(contextual_seqlen_),
|
||||
max_attn_len(max_attn_len_),
|
||||
min_full_attn_seqlen(min_full_attn_seqlen_)
|
||||
{
|
||||
max_uih_len = seqlen - num_target_;
|
||||
|
||||
// in case user provided max_attn_len_ could be bigger than max_uih_len
|
||||
max_attn_len = min(max_uih_len, max_attn_len);
|
||||
|
||||
// assuming min_full_attn_seqlen has higher priority, ensure contextual scope not
|
||||
// collide with min_full_attn_seqlen scope
|
||||
contextual_seqlen = min(contextual_seqlen, max_uih_len - min_full_attn_seqlen);
|
||||
|
||||
if(contextual_seqlen > 0)
|
||||
max_id = max_uih_len - (contextual_seqlen - 1);
|
||||
else
|
||||
max_id = max_uih_len;
|
||||
};
|
||||
|
||||
// to get the loop length along X axis, return index:[start, end), end-start=length
|
||||
// use this if need loop over X axis tile by tile (eg. seqlen_k loop-over)
|
||||
// i_y is the start offset of the current tile along the seqlen_q dimension
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
||||
{
|
||||
// handle two special cases first
|
||||
if(!is_tile_in_first_split)
|
||||
{
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
index_t x_end = min(i_y + YTile, seqlen);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
}
|
||||
else
|
||||
{
|
||||
// tile is partitially or completely in [max_uih_len-min_full_attn_seqlen,
|
||||
// max_uih_len)
|
||||
if(i_y < max_uih_len)
|
||||
{
|
||||
return ck_tile::make_tuple(0, seqlen);
|
||||
}
|
||||
else // tile is completely inside [max_uih_len, seqlen)
|
||||
{
|
||||
index_t x_end = min(i_y + YTile, seqlen);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
// is_tile_in_first_split is true, either min_full_attn_seqlen is 0 or tile is
|
||||
// in [0, max_uih_len-min_full_attn_seqlen)
|
||||
if constexpr(!kUseCausal)
|
||||
{
|
||||
if(i_y >= min(contextual_seqlen, 1) + max_attn_len)
|
||||
{
|
||||
// some row of the tile in [contextual_seqlen+max_attn_len, max_uih_len)
|
||||
if(i_y < max_uih_len)
|
||||
{
|
||||
index_t x_start = i_y - max_attn_len;
|
||||
index_t x_start_aligned = x_start - x_start % XTile;
|
||||
|
||||
// some rows of the tile in [max_uih_len -max_attn_len, max_uih_len)
|
||||
if(i_y + YTile > max_uih_len - max_attn_len)
|
||||
{
|
||||
return ck_tile::make_tuple(x_start_aligned, seqlen);
|
||||
}
|
||||
else // whole tile in [contextual_seqlen+max_attn_len, max_uih_len
|
||||
// -max_attn_len)
|
||||
{
|
||||
index_t x_end = i_y + YTile + max_attn_len;
|
||||
return ck_tile::make_tuple(x_start_aligned, x_end);
|
||||
};
|
||||
}
|
||||
else // whole tile in [max_uih_len, seqlen)
|
||||
{
|
||||
index_t x_start = max_uih_len - max_attn_len;
|
||||
index_t x_end = min(i_y + YTile, seqlen);
|
||||
|
||||
return ck_tile::make_tuple(x_start - x_start % XTile, x_end);
|
||||
}
|
||||
}
|
||||
else // for i_y < contextual_seqlen + max_attn_len
|
||||
{
|
||||
if(i_y < contextual_seqlen) // some row of the tile in [0, contextual_seqlen)
|
||||
{
|
||||
index_t x_end = min(max(i_y + YTile + max_attn_len, max_uih_len), seqlen);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
}
|
||||
else // whole tile in [contextual_seqlen, seqlen)
|
||||
{
|
||||
index_t x_end = min(i_y + YTile + max_attn_len, seqlen);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
}
|
||||
}
|
||||
}
|
||||
else // kUseCausal && kUseLocal
|
||||
{
|
||||
if(i_y >= min(contextual_seqlen, 1) + max_attn_len)
|
||||
{
|
||||
index_t x_end = min(i_y + YTile, seqlen);
|
||||
|
||||
// some row of the tile in [contextual_seqlen+max_attn_len, max_uih_len)
|
||||
if(i_y < max_uih_len)
|
||||
{
|
||||
index_t x_start = i_y - max_attn_len;
|
||||
return ck_tile::make_tuple(x_start - x_start % XTile, x_end);
|
||||
}
|
||||
else // whole tile in [max_uih_len, seqlen)
|
||||
{
|
||||
index_t x_start = max_uih_len - max_attn_len;
|
||||
return ck_tile::make_tuple(x_start - x_start % XTile, x_end);
|
||||
}
|
||||
}
|
||||
else // for i_y < contextual_seqlen + max_attn_len
|
||||
{
|
||||
if(i_y < contextual_seqlen) // some row of the tile in [0, contextual_seqlen)
|
||||
{
|
||||
index_t x_end = min(max(i_y + YTile, max_uih_len), seqlen);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
}
|
||||
else // whole tile in [contextual_seqlen, seqlen)
|
||||
{
|
||||
index_t x_end = min(i_y + YTile, seqlen);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE bool IsTokenPairInsideMask(int row, int col)
|
||||
{
|
||||
int row_id;
|
||||
int col_id;
|
||||
|
||||
if(contextual_seqlen > 0)
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
row_id = max(row - contextual_seqlen + 1, 0);
|
||||
col_id = max(col - contextual_seqlen + 1, 0);
|
||||
|
||||
row_id = min(row_id, max_id);
|
||||
col_id = min(col_id, max_id);
|
||||
|
||||
if(row_id == 0 && col_id < max_id)
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
row_id = min(row, max_id);
|
||||
col_id = min(col, max_id);
|
||||
};
|
||||
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
// diagonal line are always considerred
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
bool in_min_full_scope =
|
||||
(min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false;
|
||||
|
||||
return (((row_id > col_id) || (row == col)) &&
|
||||
((row_id - col_id <= max_attn_len) || in_min_full_scope));
|
||||
}
|
||||
else
|
||||
{
|
||||
bool in_min_full_scope =
|
||||
(min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false;
|
||||
|
||||
return (((row_id != col_id) || (row == col)) &&
|
||||
((abs(row_id - col_id) <= max_attn_len) || in_min_full_scope));
|
||||
}
|
||||
};
|
||||
|
||||
// if the whole tile inside the masking area, no need for pixel-by-pixel checking
|
||||
template <index_t TileWidth, index_t TileHeight>
|
||||
CK_TILE_DEVICE constexpr bool IsFullTileInsideMask(index_t i_tile_top,
|
||||
index_t i_tile_left,
|
||||
number<TileWidth>,
|
||||
number<TileHeight>) const
|
||||
{
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
index_t i_tile_right = i_tile_left + TileWidth;
|
||||
|
||||
if(!is_tile_in_first_split && i_tile_right <= min(i_tile_top + 1, max_uih_len))
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
index_t i_tile_right = i_tile_left + TileWidth;
|
||||
index_t i_tile_bottom = i_tile_top + TileHeight;
|
||||
|
||||
// 1) tile is completely in [max_uih_len-min_full_attn_seqlen, max_uih_len]
|
||||
// 2) some row of tile is in [max_uih_len, seqlen], requires i_tile_right <=
|
||||
// max_uih_len to return true
|
||||
if(!is_tile_in_first_split &&
|
||||
(i_tile_bottom <= max_uih_len || i_tile_right <= max_uih_len))
|
||||
return true;
|
||||
};
|
||||
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
template <bool kUseCausal>
|
||||
struct HstuCrossAttentionBlockMaskNoLocal
|
||||
{
|
||||
static constexpr bool kUseLocal = false;
|
||||
static constexpr bool IsMasking = kUseCausal;
|
||||
static constexpr bool kIsCrossAttention = true;
|
||||
|
||||
int seqlen_q;
|
||||
int seqlen_k;
|
||||
@@ -328,7 +516,10 @@ struct HstuBlockMaskNoLocal
|
||||
int diff_q_kv_len;
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
HstuBlockMaskNoLocal(int seqlen_q_, int seqlen_k_, int contextual_seqlen_, int num_target_)
|
||||
HstuCrossAttentionBlockMaskNoLocal(int seqlen_q_,
|
||||
int seqlen_k_,
|
||||
int contextual_seqlen_,
|
||||
int num_target_)
|
||||
: seqlen_q(seqlen_q_), seqlen_k(seqlen_k_), contextual_seqlen(contextual_seqlen_)
|
||||
{
|
||||
max_q_uih_len = seqlen_q - num_target_;
|
||||
@@ -391,8 +582,8 @@ struct HstuBlockMaskNoLocal
|
||||
|
||||
if(contextual_seqlen > 0)
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen
|
||||
// and max_uih_len
|
||||
row_id = max(row - contextual_seqlen + 1, diff_q_kv_len);
|
||||
col_id = max(col - contextual_seqlen + 1, 0);
|
||||
|
||||
@@ -404,14 +595,14 @@ struct HstuBlockMaskNoLocal
|
||||
}
|
||||
else
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen
|
||||
// and max_uih_len
|
||||
row_id = min(row, max_row_id);
|
||||
col_id = min(col, max_col_id);
|
||||
};
|
||||
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
// diagonal line are always considerred
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on
|
||||
// the diagonal line are always considerred
|
||||
if constexpr(IsMasking)
|
||||
{
|
||||
return (row_id > col_id) || (row == col);
|
||||
@@ -434,8 +625,9 @@ struct HstuBlockMaskNoLocal
|
||||
index_t i_tile_right = i_tile_left + (TileWidth - 1);
|
||||
index_t i_tile_bottom = i_tile_top + (TileHeight - 1);
|
||||
|
||||
// assume num_target > 0 with high probability, don't check whether num_target is 0;
|
||||
// so if num_target is 0, IsTokenPairInsideMask() will be called for the bottom tile
|
||||
// assume num_target > 0 with high probability, don't check whether num_target
|
||||
// is 0; so if num_target is 0, IsTokenPairInsideMask() will be called for the
|
||||
// bottom tile
|
||||
if(i_tile_bottom >= max_q_uih_len || i_tile_right > i_tile_top + diff_q_kv_len)
|
||||
return false;
|
||||
|
||||
@@ -446,8 +638,9 @@ struct HstuBlockMaskNoLocal
|
||||
index_t i_tile_right = i_tile_left + (TileWidth - 1);
|
||||
index_t i_tile_bottom = i_tile_top + (TileHeight - 1);
|
||||
|
||||
// assume num_target > 0 with high probability, don't check whether num_target is 0;
|
||||
// so if num_target is 0, IsTokenPairInsideMask() will be called for the bottom tile
|
||||
// assume num_target > 0 with high probability, don't check whether num_target
|
||||
// is 0; so if num_target is 0, IsTokenPairInsideMask() will be called for the
|
||||
// bottom tile
|
||||
if(i_tile_bottom >= max_q_uih_len || i_tile_right >= max_k_uih_len)
|
||||
return false;
|
||||
|
||||
@@ -456,23 +649,163 @@ struct HstuBlockMaskNoLocal
|
||||
};
|
||||
};
|
||||
|
||||
template <bool kUseCausal, bool kUseLocal>
|
||||
template <bool kUseCausal>
|
||||
struct HstuSelfAttentionBlockMaskNoLocal
|
||||
{
|
||||
static constexpr bool kUseLocal = false;
|
||||
static constexpr bool IsMasking = kUseCausal;
|
||||
static constexpr bool kIsCrossAttention = false;
|
||||
|
||||
int seqlen;
|
||||
int contextual_seqlen;
|
||||
|
||||
int max_uih_len;
|
||||
int max_id;
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
HstuSelfAttentionBlockMaskNoLocal(int seqlen_, int contextual_seqlen_, int num_target_)
|
||||
: seqlen(seqlen_), contextual_seqlen(contextual_seqlen_)
|
||||
{
|
||||
max_uih_len = seqlen - num_target_;
|
||||
|
||||
if(contextual_seqlen > 0)
|
||||
max_id = max_uih_len - (contextual_seqlen - 1);
|
||||
else
|
||||
max_id = max_uih_len;
|
||||
};
|
||||
|
||||
// to get the loop length along X axis, return index:[start, end), end-start=length
|
||||
// use this if need loop over X axis tile by tile (eg. seqlen_k loop-over)
|
||||
// i_y is the start offset of the current tile along the seqlen_q dimension
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
return ck_tile::make_tuple(0, seqlen);
|
||||
}
|
||||
else
|
||||
{
|
||||
index_t x_end = min(i_y + YTile, seqlen);
|
||||
|
||||
if(i_y < contextual_seqlen)
|
||||
{
|
||||
if(i_y + YTile > max_uih_len)
|
||||
{
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::make_tuple(0, max_uih_len);
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE bool IsTokenPairInsideMask(int row, int col)
|
||||
{
|
||||
int row_id;
|
||||
int col_id;
|
||||
|
||||
if(contextual_seqlen > 0)
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen
|
||||
// and max_uih_len
|
||||
row_id = max(row - contextual_seqlen + 1, 0);
|
||||
col_id = max(col - contextual_seqlen + 1, 0);
|
||||
|
||||
row_id = min(row_id, max_id);
|
||||
col_id = min(col_id, max_id);
|
||||
|
||||
if(row_id == 0 && col_id < max_id)
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen
|
||||
// and max_uih_len
|
||||
row_id = min(row, max_id);
|
||||
col_id = min(col, max_id);
|
||||
};
|
||||
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on
|
||||
// the diagonal line are always considerred
|
||||
if constexpr(IsMasking)
|
||||
{
|
||||
return (row_id > col_id) || (row == col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return (row_id != col_id) || (row == col);
|
||||
};
|
||||
};
|
||||
|
||||
// if the whole tile inside the masking area, no need for pixel-by-pixel checking
|
||||
template <index_t TileWidth, index_t TileHeight>
|
||||
CK_TILE_DEVICE constexpr bool IsFullTileInsideMask(index_t i_tile_top,
|
||||
index_t i_tile_left,
|
||||
number<TileWidth>,
|
||||
number<TileHeight>) const
|
||||
{
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
index_t i_tile_right = i_tile_left + (TileWidth - 1);
|
||||
index_t i_tile_bottom = i_tile_top + (TileHeight - 1);
|
||||
|
||||
// assume num_target > 0 with high probability, don't check whether num_target
|
||||
// is 0; so if num_target is 0, IsTokenPairInsideMask() will be called for the
|
||||
// bottom tile
|
||||
if(i_tile_bottom >= max_uih_len || i_tile_right > i_tile_top)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
index_t i_tile_right = i_tile_left + (TileWidth - 1);
|
||||
index_t i_tile_bottom = i_tile_top + (TileHeight - 1);
|
||||
|
||||
// assume num_target > 0 with high probability, don't check whether num_target
|
||||
// is 0; so if num_target is 0, IsTokenPairInsideMask() will be called for the
|
||||
// bottom tile
|
||||
if(i_tile_bottom >= max_uih_len || i_tile_right >= max_uih_len)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
template <bool kIsCrossAttention, bool kUseCausal, bool kUseLocal>
|
||||
struct HstuBlockMasking
|
||||
{
|
||||
using Type = std::conditional_t<kUseLocal,
|
||||
HstuBlockMaskWithLocal<kUseCausal>,
|
||||
HstuBlockMaskNoLocal<kUseCausal>>;
|
||||
using Type =
|
||||
std::conditional_t<kUseLocal,
|
||||
std::conditional_t<kIsCrossAttention,
|
||||
HstuCrossAttentionBlockMaskWithLocal<kUseCausal>,
|
||||
HstuSelfAttentionBlockMaskWithLocal<kUseCausal>>,
|
||||
std::conditional_t<kIsCrossAttention,
|
||||
HstuCrossAttentionBlockMaskNoLocal<kUseCausal>,
|
||||
HstuSelfAttentionBlockMaskNoLocal<kUseCausal>>>;
|
||||
};
|
||||
|
||||
template <typename HstuBlockMaskType>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_hstu_block_mask_with_local(bool is_tile_in_first_split_,
|
||||
int seqlen_q_,
|
||||
int seqlen_k_,
|
||||
int contextual_seqlen_,
|
||||
int num_target,
|
||||
int max_attn_len_,
|
||||
int min_full_attn_seqlen_)
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_hstu_cross_attention_block_mask_with_local(bool is_tile_in_first_split_,
|
||||
int seqlen_q_,
|
||||
int seqlen_k_,
|
||||
int contextual_seqlen_,
|
||||
int num_target,
|
||||
int max_attn_len_,
|
||||
int min_full_attn_seqlen_)
|
||||
{
|
||||
static_assert(HstuBlockMaskType::kIsCrossAttention == true);
|
||||
|
||||
return HstuBlockMaskType{is_tile_in_first_split_,
|
||||
seqlen_q_,
|
||||
seqlen_k_,
|
||||
@@ -483,12 +816,40 @@ CK_TILE_HOST_DEVICE constexpr auto make_hstu_block_mask_with_local(bool is_tile_
|
||||
};
|
||||
|
||||
template <typename HstuBlockMaskType>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_hstu_block_mask_without_local(int seqlen_q_,
|
||||
int seqlen_k_,
|
||||
int contextual_seqlen_,
|
||||
int num_target)
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_hstu_self_attention_block_mask_with_local(bool is_tile_in_first_split_,
|
||||
int seqlen_,
|
||||
int contextual_seqlen_,
|
||||
int num_target,
|
||||
int max_attn_len_,
|
||||
int min_full_attn_seqlen_)
|
||||
{
|
||||
static_assert(HstuBlockMaskType::kIsCrossAttention == false);
|
||||
|
||||
return HstuBlockMaskType{is_tile_in_first_split_,
|
||||
seqlen_,
|
||||
contextual_seqlen_,
|
||||
max_attn_len_,
|
||||
min_full_attn_seqlen_,
|
||||
num_target};
|
||||
};
|
||||
|
||||
template <typename HstuBlockMaskType>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_hstu_cross_attention_block_mask_without_local(
|
||||
int seqlen_q_, int seqlen_k_, int contextual_seqlen_, int num_target)
|
||||
{
|
||||
static_assert(HstuBlockMaskType::kIsCrossAttention == true);
|
||||
|
||||
return HstuBlockMaskType{seqlen_q_, seqlen_k_, contextual_seqlen_, num_target};
|
||||
};
|
||||
|
||||
template <typename HstuBlockMaskType>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_hstu_self_attention_block_mask_without_local(
|
||||
int seqlen_, int contextual_seqlen_, int num_target)
|
||||
{
|
||||
static_assert(HstuBlockMaskType::kIsCrossAttention == false);
|
||||
|
||||
return HstuBlockMaskType{seqlen_, contextual_seqlen_, num_target};
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -54,8 +54,6 @@ struct reference_hstu_attention
|
||||
int min_full_attn_seqlen) // define masking length at the end of query token
|
||||
// sequence which is included for full attention
|
||||
{
|
||||
ignore = is_cross_attention;
|
||||
|
||||
if constexpr(kIsJagged)
|
||||
{
|
||||
// check the number of batches
|
||||
@@ -122,34 +120,67 @@ struct reference_hstu_attention
|
||||
? attn_scale
|
||||
: 1.0f / static_cast<float>(max(max_seqlen_q, max_seqlen_kv));
|
||||
|
||||
BOOL_SWITCH(window_size > 0, kHasLocal, [&] {
|
||||
using HstuMaskType = typename HstuBlockMasking<kUseCausal, kHasLocal>::Type;
|
||||
BOOL_SWITCH_2(window_size > 0, kHasLocal, is_cross_attention, kIsCrossAttention, [&] {
|
||||
using HstuMaskType =
|
||||
typename HstuBlockMasking<kIsCrossAttention, kUseCausal, kHasLocal>::Type;
|
||||
|
||||
HstuMaskType mask = [&]() {
|
||||
if constexpr(kHasLocal)
|
||||
// need adjust the min_full_attn_seqlen passed to the HstuBlockMask() if the
|
||||
// user passed min_full_attn_seqlen is bigger than max_uih_len
|
||||
if(seqlen_q - num_target > min_full_attn_seqlen)
|
||||
return ck_tile::make_hstu_block_mask_with_local<HstuMaskType>(
|
||||
true,
|
||||
seqlen_q,
|
||||
seqlen_kv,
|
||||
contextual_seqlen,
|
||||
num_target,
|
||||
window_size,
|
||||
min_full_attn_seqlen);
|
||||
{
|
||||
if constexpr(kIsCrossAttention)
|
||||
{
|
||||
// need adjust the min_full_attn_seqlen passed to the HstuBlockMask() if
|
||||
// the user passed min_full_attn_seqlen is bigger than max_uih_len
|
||||
if(seqlen_q - num_target > min_full_attn_seqlen)
|
||||
return ck_tile::make_hstu_cross_attention_block_mask_with_local<
|
||||
HstuMaskType>(true,
|
||||
seqlen_q,
|
||||
seqlen_kv,
|
||||
contextual_seqlen,
|
||||
num_target,
|
||||
window_size,
|
||||
min_full_attn_seqlen);
|
||||
else
|
||||
return ck_tile::make_hstu_cross_attention_block_mask_with_local<
|
||||
HstuMaskType>(true,
|
||||
seqlen_q,
|
||||
seqlen_kv,
|
||||
contextual_seqlen,
|
||||
num_target,
|
||||
window_size,
|
||||
seqlen_q - num_target);
|
||||
}
|
||||
else
|
||||
return ck_tile::make_hstu_block_mask_with_local<HstuMaskType>(
|
||||
true,
|
||||
seqlen_q,
|
||||
seqlen_kv,
|
||||
contextual_seqlen,
|
||||
num_target,
|
||||
window_size,
|
||||
seqlen_q - num_target);
|
||||
{
|
||||
// need adjust the min_full_attn_seqlen passed to the HstuBlockMask() if
|
||||
// the user passed min_full_attn_seqlen is bigger than max_uih_len
|
||||
if(seqlen_q - num_target > min_full_attn_seqlen)
|
||||
return ck_tile::make_hstu_self_attention_block_mask_with_local<
|
||||
HstuMaskType>(true,
|
||||
seqlen_q,
|
||||
contextual_seqlen,
|
||||
num_target,
|
||||
window_size,
|
||||
min_full_attn_seqlen);
|
||||
else
|
||||
return ck_tile::make_hstu_self_attention_block_mask_with_local<
|
||||
HstuMaskType>(true,
|
||||
seqlen_q,
|
||||
contextual_seqlen,
|
||||
num_target,
|
||||
window_size,
|
||||
seqlen_q - num_target);
|
||||
}
|
||||
}
|
||||
else
|
||||
return ck_tile::make_hstu_block_mask_without_local<HstuMaskType>(
|
||||
seqlen_q, seqlen_kv, contextual_seqlen, num_target);
|
||||
{
|
||||
if constexpr(kIsCrossAttention)
|
||||
return ck_tile::make_hstu_cross_attention_block_mask_without_local<
|
||||
HstuMaskType>(seqlen_q, seqlen_kv, contextual_seqlen, num_target);
|
||||
else
|
||||
return ck_tile::make_hstu_self_attention_block_mask_without_local<
|
||||
HstuMaskType>(seqlen_q, contextual_seqlen, num_target);
|
||||
}
|
||||
}();
|
||||
|
||||
if(save_mask)
|
||||
|
||||
Reference in New Issue
Block a user