mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Merge commit '45904b8fd7cde71dfc3741970325b3d552b06d27' into develop
This commit is contained in:
@@ -1122,7 +1122,8 @@ struct FmhaFwdPagedKVKernel
|
||||
const index_t num_blocks =
|
||||
integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
|
||||
|
||||
const long_index_t fixed_offset = i_nhead_ * kargs.nhead_stride_k;
|
||||
const long_index_t fixed_offset =
|
||||
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_k;
|
||||
|
||||
return make_page_block_navigator<const KDataType, 0>(
|
||||
kargs.k_ptr,
|
||||
@@ -1152,7 +1153,8 @@ struct FmhaFwdPagedKVKernel
|
||||
const index_t num_blocks =
|
||||
integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
|
||||
|
||||
const long_index_t fixed_offset = i_nhead_ * kargs.nhead_stride_v;
|
||||
const long_index_t fixed_offset =
|
||||
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_v;
|
||||
|
||||
return make_page_block_navigator<const VDataType, 1>(
|
||||
kargs.v_ptr,
|
||||
|
||||
@@ -441,28 +441,46 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
}
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
|
||||
{
|
||||
const auto k_origin = k_page_block_navigator.to_global_window_origin(
|
||||
i_page_block_k, k_dram_block_window.get_window_origin());
|
||||
// mask accept only logical coordinates, do conversion here
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}) - kv_l2p_offset,
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !variant.LogitsMask(variant_params,
|
||||
block_indices.batch_idx,
|
||||
row,
|
||||
col - kv_l2p_offset,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
});
|
||||
// check columns in [aligned_physical_seqlen_k_start, physical_seqlen_k_end)
|
||||
if(kv_l2p_offset > 0)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc,
|
||||
-numeric<SMPLComputeDataType>::infinity(),
|
||||
[&, physical_seqlen_k_start_ = physical_seqlen_k_start](auto tile_idx) {
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return col < physical_seqlen_k_start_;
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
// mask accept only logical coordinates, do conversion here
|
||||
bool need_perpixel_check =
|
||||
mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}) - kv_l2p_offset,
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col - kv_l2p_offset);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user