mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
Fixed block Q with M
This commit is contained in:
@@ -384,7 +384,6 @@ struct UnifiedAttentionPipeline
|
||||
[[maybe_unused]] const KElementFunction& k_element_func,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
[[maybe_unused]] const VElementFunction& v_element_func,
|
||||
index_t num_queries_per_kv,
|
||||
const void* block_tables_ptr,
|
||||
index_t block_table_offset,
|
||||
[[maybe_unused]] const SAccElementFunction& s_acc_element_func,
|
||||
@@ -395,16 +394,13 @@ struct UnifiedAttentionPipeline
|
||||
void* smem_ptr) const
|
||||
{
|
||||
using namespace ck_tile;
|
||||
// TODO do we make the num_queries_per_kv and num_head conexpr???
|
||||
const index_t BLOCK_Q = BLOCK_M / num_queries_per_kv;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(BLOCK_Q == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
static_assert(BLOCK_M == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
BLOCK_SIZE == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
HEAD_SIZE_PADDED == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
HEAD_SIZE_PADDED == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
@@ -414,29 +410,29 @@ struct UnifiedAttentionPipeline
|
||||
static_assert(sizeof(SaccDataType) * BLOCK_SIZE <= GetSmemSize());
|
||||
auto s_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<SaccDataType*>(static_cast<char*>(smem_ptr)),
|
||||
MakeSimpleLdsDesc<BLOCK_Q, BLOCK_SIZE>());
|
||||
MakeSimpleLdsDesc<BLOCK_M, BLOCK_SIZE>());
|
||||
[[maybe_unused]] auto s_lds_window =
|
||||
make_tile_window(s_lds, make_tuple(number<BLOCK_Q>{}, number<BLOCK_SIZE>{}), {0, 0});
|
||||
make_tile_window(s_lds, make_tuple(number<BLOCK_M>{}, number<BLOCK_SIZE>{}), {0, 0});
|
||||
|
||||
auto p_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr) +
|
||||
Policy::template GetSmemSize<Problem>()),
|
||||
MakeSimpleLdsDesc<BLOCK_Q, BLOCK_SIZE>());
|
||||
MakeSimpleLdsDesc<BLOCK_M, BLOCK_SIZE>());
|
||||
[[maybe_unused]] auto p_lds_window =
|
||||
make_tile_window(p_lds, make_tuple(number<BLOCK_Q>{}, number<BLOCK_SIZE>{}), {0, 0});
|
||||
make_tile_window(p_lds, make_tuple(number<BLOCK_M>{}, number<BLOCK_SIZE>{}), {0, 0});
|
||||
|
||||
auto o_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr)),
|
||||
MakeSimpleLdsDesc<BLOCK_Q, BLOCK_SIZE>());
|
||||
MakeSimpleLdsDesc<BLOCK_M, BLOCK_SIZE>());
|
||||
[[maybe_unused]] auto o_lds_window =
|
||||
make_tile_window(o_lds, make_tuple(number<BLOCK_Q>{}, number<BLOCK_SIZE>{}), {0, 0});
|
||||
make_tile_window(o_lds, make_tuple(number<BLOCK_M>{}, number<BLOCK_SIZE>{}), {0, 0});
|
||||
|
||||
auto m_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<SMPLComputeDataType*>(static_cast<char*>(smem_ptr) +
|
||||
Policy::template GetSmemSize<Problem>()),
|
||||
MakeSimpleLdsDesc1D<BLOCK_Q>());
|
||||
MakeSimpleLdsDesc1D<BLOCK_M>());
|
||||
[[maybe_unused]] auto m_lds_window =
|
||||
make_tile_window(m_lds, make_tuple(number<BLOCK_Q>{}), {0});
|
||||
make_tile_window(m_lds, make_tuple(number<BLOCK_M>{}), {0});
|
||||
|
||||
const index_t warp_group_id = get_warp_id() / 4;
|
||||
|
||||
@@ -543,7 +539,7 @@ struct UnifiedAttentionPipeline
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<BLOCK_Q>{}, number<BLOCK_SIZE>{});
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<BLOCK_M>{}, number<BLOCK_SIZE>{});
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, BLOCK_SIZE);
|
||||
index_t kv_token_start = seqlen_k_start;
|
||||
@@ -812,7 +808,7 @@ struct UnifiedAttentionPipeline
|
||||
gemm_0(sp(sp_reg_idx).sp_compute,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{},
|
||||
sequence<BLOCK_Q, k0_loops * HEAD_SIZE_PADDED>{}),
|
||||
sequence<BLOCK_M, k0_loops * HEAD_SIZE_PADDED>{}),
|
||||
get_slice_tile(kv_tile.k_tile,
|
||||
sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{},
|
||||
sequence<BLOCK_SIZE, k0_loops * HEAD_SIZE_PADDED>{}));
|
||||
@@ -822,7 +818,7 @@ struct UnifiedAttentionPipeline
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(sp(sp_reg_idx).p,
|
||||
sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{},
|
||||
sequence<BLOCK_Q, k1_loops * HEAD_SIZE_PADDED>{}),
|
||||
sequence<BLOCK_M, k1_loops * HEAD_SIZE_PADDED>{}),
|
||||
get_slice_tile(kv_tile.v_tile,
|
||||
sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{},
|
||||
sequence<BLOCK_SIZE, k1_loops * HEAD_SIZE_PADDED>{}));
|
||||
@@ -836,7 +832,7 @@ struct UnifiedAttentionPipeline
|
||||
gemm_0(sp(sp_reg_idx).sp_compute,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{},
|
||||
sequence<BLOCK_Q, k0_loops * HEAD_SIZE_PADDED>{}),
|
||||
sequence<BLOCK_M, k0_loops * HEAD_SIZE_PADDED>{}),
|
||||
get_slice_tile(kv_tile.k_tile,
|
||||
sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{},
|
||||
sequence<BLOCK_SIZE, k0_loops * HEAD_SIZE_PADDED>{}));
|
||||
@@ -846,7 +842,7 @@ struct UnifiedAttentionPipeline
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(sp(sp_reg_idx).p,
|
||||
sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{},
|
||||
sequence<BLOCK_Q, k1_loops * HEAD_SIZE_PADDED>{}),
|
||||
sequence<BLOCK_M, k1_loops * HEAD_SIZE_PADDED>{}),
|
||||
get_slice_tile(kv_tile.v_tile,
|
||||
sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{},
|
||||
sequence<BLOCK_SIZE, k1_loops * HEAD_SIZE_PADDED>{}));
|
||||
@@ -894,7 +890,7 @@ struct UnifiedAttentionPipeline
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
bool need_perpixel_check = mask.IsEdgeTile(
|
||||
q_origin.at(number<0>{}), kv_token_start, number<BLOCK_Q>{}, number<BLOCK_SIZE>{});
|
||||
q_origin.at(number<0>{}), kv_token_start, number<BLOCK_M>{}, number<BLOCK_SIZE>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(sp(sp_reg_idx).sp_compute,
|
||||
@@ -1209,7 +1205,6 @@ struct UnifiedAttentionPipeline
|
||||
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
index_t num_queries_per_kv,
|
||||
const void* block_tables_ptr,
|
||||
index_t block_table_offset,
|
||||
FmhaMask mask,
|
||||
@@ -1224,7 +1219,6 @@ struct UnifiedAttentionPipeline
|
||||
identity{},
|
||||
v_dram_block_window_tmp,
|
||||
identity{},
|
||||
num_queries_per_kv,
|
||||
block_tables_ptr,
|
||||
block_table_offset,
|
||||
identity{},
|
||||
|
||||
Reference in New Issue
Block a user