Replace BLOCK_Q with BLOCK_M and drop the num_queries_per_kv parameter

This commit is contained in:
Tianxing Wu
2025-10-23 12:02:18 +00:00
parent d18f8e46bf
commit 89cfdb35e0

View File

@@ -384,7 +384,6 @@ struct UnifiedAttentionPipeline
[[maybe_unused]] const KElementFunction& k_element_func,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
[[maybe_unused]] const VElementFunction& v_element_func,
index_t num_queries_per_kv,
const void* block_tables_ptr,
index_t block_table_offset,
[[maybe_unused]] const SAccElementFunction& s_acc_element_func,
@@ -395,16 +394,13 @@ struct UnifiedAttentionPipeline
void* smem_ptr) const
{
using namespace ck_tile;
// TODO: should num_queries_per_kv and num_head be made constexpr?
const index_t BLOCK_Q = BLOCK_M / num_queries_per_kv;
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(BLOCK_Q == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
static_assert(BLOCK_M == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
BLOCK_SIZE == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
HEAD_SIZE_PADDED == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
HEAD_SIZE_PADDED == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
@@ -414,29 +410,29 @@ struct UnifiedAttentionPipeline
static_assert(sizeof(SaccDataType) * BLOCK_SIZE <= GetSmemSize());
auto s_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<SaccDataType*>(static_cast<char*>(smem_ptr)),
MakeSimpleLdsDesc<BLOCK_Q, BLOCK_SIZE>());
MakeSimpleLdsDesc<BLOCK_M, BLOCK_SIZE>());
[[maybe_unused]] auto s_lds_window =
make_tile_window(s_lds, make_tuple(number<BLOCK_Q>{}, number<BLOCK_SIZE>{}), {0, 0});
make_tile_window(s_lds, make_tuple(number<BLOCK_M>{}, number<BLOCK_SIZE>{}), {0, 0});
auto p_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr) +
Policy::template GetSmemSize<Problem>()),
MakeSimpleLdsDesc<BLOCK_Q, BLOCK_SIZE>());
MakeSimpleLdsDesc<BLOCK_M, BLOCK_SIZE>());
[[maybe_unused]] auto p_lds_window =
make_tile_window(p_lds, make_tuple(number<BLOCK_Q>{}, number<BLOCK_SIZE>{}), {0, 0});
make_tile_window(p_lds, make_tuple(number<BLOCK_M>{}, number<BLOCK_SIZE>{}), {0, 0});
auto o_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr)),
MakeSimpleLdsDesc<BLOCK_Q, BLOCK_SIZE>());
MakeSimpleLdsDesc<BLOCK_M, BLOCK_SIZE>());
[[maybe_unused]] auto o_lds_window =
make_tile_window(o_lds, make_tuple(number<BLOCK_Q>{}, number<BLOCK_SIZE>{}), {0, 0});
make_tile_window(o_lds, make_tuple(number<BLOCK_M>{}, number<BLOCK_SIZE>{}), {0, 0});
auto m_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<SMPLComputeDataType*>(static_cast<char*>(smem_ptr) +
Policy::template GetSmemSize<Problem>()),
MakeSimpleLdsDesc1D<BLOCK_Q>());
MakeSimpleLdsDesc1D<BLOCK_M>());
[[maybe_unused]] auto m_lds_window =
make_tile_window(m_lds, make_tuple(number<BLOCK_Q>{}), {0});
make_tile_window(m_lds, make_tuple(number<BLOCK_M>{}), {0});
const index_t warp_group_id = get_warp_id() / 4;
@@ -543,7 +539,7 @@ struct UnifiedAttentionPipeline
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<BLOCK_Q>{}, number<BLOCK_SIZE>{});
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<BLOCK_M>{}, number<BLOCK_SIZE>{});
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, BLOCK_SIZE);
index_t kv_token_start = seqlen_k_start;
@@ -812,7 +808,7 @@ struct UnifiedAttentionPipeline
gemm_0(sp(sp_reg_idx).sp_compute,
get_slice_tile(q_tile,
sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{},
sequence<BLOCK_Q, k0_loops * HEAD_SIZE_PADDED>{}),
sequence<BLOCK_M, k0_loops * HEAD_SIZE_PADDED>{}),
get_slice_tile(kv_tile.k_tile,
sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{},
sequence<BLOCK_SIZE, k0_loops * HEAD_SIZE_PADDED>{}));
@@ -822,7 +818,7 @@ struct UnifiedAttentionPipeline
gemm_1(o_acc,
get_slice_tile(sp(sp_reg_idx).p,
sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{},
sequence<BLOCK_Q, k1_loops * HEAD_SIZE_PADDED>{}),
sequence<BLOCK_M, k1_loops * HEAD_SIZE_PADDED>{}),
get_slice_tile(kv_tile.v_tile,
sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{},
sequence<BLOCK_SIZE, k1_loops * HEAD_SIZE_PADDED>{}));
@@ -836,7 +832,7 @@ struct UnifiedAttentionPipeline
gemm_0(sp(sp_reg_idx).sp_compute,
get_slice_tile(q_tile,
sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{},
sequence<BLOCK_Q, k0_loops * HEAD_SIZE_PADDED>{}),
sequence<BLOCK_M, k0_loops * HEAD_SIZE_PADDED>{}),
get_slice_tile(kv_tile.k_tile,
sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{},
sequence<BLOCK_SIZE, k0_loops * HEAD_SIZE_PADDED>{}));
@@ -846,7 +842,7 @@ struct UnifiedAttentionPipeline
gemm_1(o_acc,
get_slice_tile(sp(sp_reg_idx).p,
sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{},
sequence<BLOCK_Q, k1_loops * HEAD_SIZE_PADDED>{}),
sequence<BLOCK_M, k1_loops * HEAD_SIZE_PADDED>{}),
get_slice_tile(kv_tile.v_tile,
sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{},
sequence<BLOCK_SIZE, k1_loops * HEAD_SIZE_PADDED>{}));
@@ -894,7 +890,7 @@ struct UnifiedAttentionPipeline
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
bool need_perpixel_check = mask.IsEdgeTile(
q_origin.at(number<0>{}), kv_token_start, number<BLOCK_Q>{}, number<BLOCK_SIZE>{});
q_origin.at(number<0>{}), kv_token_start, number<BLOCK_M>{}, number<BLOCK_SIZE>{});
if(need_perpixel_check)
{
set_tile_if(sp(sp_reg_idx).sp_compute,
@@ -1209,7 +1205,6 @@ struct UnifiedAttentionPipeline
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
index_t num_queries_per_kv,
const void* block_tables_ptr,
index_t block_table_offset,
FmhaMask mask,
@@ -1224,7 +1219,6 @@ struct UnifiedAttentionPipeline
identity{},
v_dram_block_window_tmp,
identity{},
num_queries_per_kv,
block_tables_ptr,
block_table_offset,
identity{},