mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
change to BLOCK_M in shape definitions
This commit is contained in:
@@ -58,7 +58,7 @@ struct unified_attention_kernel_traits
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
|
||||
// BLOCK_Q BLOCK_SIZE HEAD_SIZE
|
||||
// BLOCK_M BLOCK_SIZE HEAD_SIZE
|
||||
using unified_attention_block_tile = sequence<128, 128, 128>;
|
||||
using unified_attention_warp_gemm_shape = sequence<32, 32, 16>;
|
||||
using unified_attention_block_warps = sequence<8, 1, 1>;
|
||||
@@ -109,6 +109,7 @@ struct unified_attention_kernel_traits
|
||||
template <typename Kernel>
|
||||
float unified_attention_kernel_launch(const unified_attention_args& args, const stream_config& config)
|
||||
{
|
||||
|
||||
index_t total_num_q_blocks = args.num_tokens / Kernel::BLOCK_Q + args.num_seqs;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(args.q_ptr,
|
||||
|
||||
@@ -47,7 +47,7 @@ struct TileUnifiedAttentionShape
|
||||
|
||||
static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
|
||||
|
||||
static constexpr index_t BLOCK_Q = BlockTile::at(number<0>{}); // tile size along q seqlen
|
||||
static constexpr index_t BLOCK_M = BlockTile::at(number<0>{}); // tile size along q seqlen
|
||||
// static constexpr index_t BLOCK_M = BlockTile::at(number<1>{}); // tile size along q seqlen * num_queries_per_kv (q_head//kv_head)
|
||||
static constexpr index_t BLOCK_SIZE = BlockTile::at(number<1>{}); // BLOCK size for K seqlen
|
||||
static constexpr index_t HEAD_SIZE = BlockTile::at(number<2>{}); // BLOCK size for K seqlen
|
||||
|
||||
@@ -268,7 +268,8 @@ struct UnifiedAttentionPipeline
|
||||
|
||||
static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr ck_tile::index_t BLOCK_Q = UnifiedAttentionShape::BLOCK_Q;
|
||||
static constexpr ck_tile::index_t BLOCK_M = UnifiedAttentionShape::BLOCK_M;
|
||||
|
||||
static constexpr ck_tile::index_t BLOCK_SIZE = UnifiedAttentionShape::BLOCK_SIZE;
|
||||
static constexpr ck_tile::index_t HEAD_SIZE = UnifiedAttentionShape::HEAD_SIZE;
|
||||
static constexpr ck_tile::index_t HEAD_SIZE_PADDED = UnifiedAttentionShape::HEAD_SIZE_PADDED;
|
||||
@@ -302,12 +303,12 @@ struct UnifiedAttentionPipeline
|
||||
}
|
||||
}();
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize(index_t num_queries_per_kv)
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
// create another LDS buffer for p
|
||||
return ck_tile::max(BLOCK_Q * num_queries_per_kv * HEAD_SIZE_PADDED * sizeof(PDataType),
|
||||
return ck_tile::max(BLOCK_M * HEAD_SIZE_PADDED * sizeof(PDataType),
|
||||
Policy::template GetSmemSize<Problem>() +
|
||||
BLOCK_Q * num_queries_per_kv * BLOCK_SIZE * sizeof(PDataType));
|
||||
BLOCK_M * BLOCK_SIZE * sizeof(PDataType));
|
||||
}
|
||||
|
||||
// for debug only
|
||||
@@ -394,7 +395,7 @@ struct UnifiedAttentionPipeline
|
||||
void* smem_ptr) const
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
index_t BLOCK_Q = BLOCK_M / num_queries_per_kv;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -409,7 +410,7 @@ struct UnifiedAttentionPipeline
|
||||
BLOCK_SIZE == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
static_assert(sizeof(SaccDataType) * BLOCK_Q * BLOCK_SIZE <= GetSmemSize(num_queries_per_kv));
|
||||
static_assert(sizeof(SaccDataType) * BLOCK_Q * BLOCK_SIZE <= GetSmemSize());
|
||||
auto s_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<SaccDataType*>(static_cast<char*>(smem_ptr)),
|
||||
MakeSimpleLdsDesc<BLOCK_Q, BLOCK_SIZE>());
|
||||
|
||||
Reference in New Issue
Block a user