change to BLOCK_M in shape definitions

This commit is contained in:
Juuso Korhonen
2025-10-23 08:11:55 +00:00
parent f72b994b00
commit e144872308
3 changed files with 10 additions and 8 deletions

View File

@@ -58,7 +58,7 @@ struct unified_attention_kernel_traits
static constexpr auto date_type = DataType;
static constexpr bool is_masking = IsMasking;
// BLOCK_Q BLOCK_SIZE HEAD_SIZE
// BLOCK_M BLOCK_SIZE HEAD_SIZE
using unified_attention_block_tile = sequence<128, 128, 128>;
using unified_attention_warp_gemm_shape = sequence<32, 32, 16>;
using unified_attention_block_warps = sequence<8, 1, 1>;
@@ -109,6 +109,7 @@ struct unified_attention_kernel_traits
template <typename Kernel>
float unified_attention_kernel_launch(const unified_attention_args& args, const stream_config& config)
{
index_t total_num_q_blocks = args.num_tokens / Kernel::BLOCK_Q + args.num_seqs;
auto kargs = Kernel::MakeKargs(args.q_ptr,

View File

@@ -47,7 +47,7 @@ struct TileUnifiedAttentionShape
static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
static constexpr index_t BLOCK_Q = BlockTile::at(number<0>{}); // tile size along q seqlen
static constexpr index_t BLOCK_M = BlockTile::at(number<0>{}); // tile size along q seqlen * num_queries_per_kv (q_head // kv_head)
// static constexpr index_t BLOCK_M = BlockTile::at(number<1>{}); // tile size along q seqlen * num_queries_per_kv (q_head//kv_head)
static constexpr index_t BLOCK_SIZE = BlockTile::at(number<1>{}); // BLOCK size for K seqlen
static constexpr index_t HEAD_SIZE = BlockTile::at(number<2>{}); // tile size along the head dimension

View File

@@ -268,7 +268,8 @@ struct UnifiedAttentionPipeline
static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize;
static constexpr ck_tile::index_t BLOCK_Q = UnifiedAttentionShape::BLOCK_Q;
static constexpr ck_tile::index_t BLOCK_M = UnifiedAttentionShape::BLOCK_M;
static constexpr ck_tile::index_t BLOCK_SIZE = UnifiedAttentionShape::BLOCK_SIZE;
static constexpr ck_tile::index_t HEAD_SIZE = UnifiedAttentionShape::HEAD_SIZE;
static constexpr ck_tile::index_t HEAD_SIZE_PADDED = UnifiedAttentionShape::HEAD_SIZE_PADDED;
@@ -302,12 +303,12 @@ struct UnifiedAttentionPipeline
}
}();
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize(index_t num_queries_per_kv)
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
// create another LDS buffer for p
return ck_tile::max(BLOCK_Q * num_queries_per_kv * HEAD_SIZE_PADDED * sizeof(PDataType),
return ck_tile::max(BLOCK_M * HEAD_SIZE_PADDED * sizeof(PDataType),
Policy::template GetSmemSize<Problem>() +
BLOCK_Q * num_queries_per_kv * BLOCK_SIZE * sizeof(PDataType));
BLOCK_M * BLOCK_SIZE * sizeof(PDataType));
}
// for debug only
@@ -394,7 +395,7 @@ struct UnifiedAttentionPipeline
void* smem_ptr) const
{
using namespace ck_tile;
index_t BLOCK_Q = BLOCK_M / num_queries_per_kv;
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -409,7 +410,7 @@ struct UnifiedAttentionPipeline
BLOCK_SIZE == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
static_assert(sizeof(SaccDataType) * BLOCK_Q * BLOCK_SIZE <= GetSmemSize(num_queries_per_kv));
static_assert(sizeof(SaccDataType) * BLOCK_Q * BLOCK_SIZE <= GetSmemSize());
auto s_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<SaccDataType*>(static_cast<char*>(smem_ptr)),
MakeSimpleLdsDesc<BLOCK_Q, BLOCK_SIZE>());