mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
Assert block_size num_queries_per_kv
This commit is contained in:
@@ -25,6 +25,9 @@
|
||||
#include "unified_attention.hpp"
|
||||
#include "mask.hpp"
|
||||
|
||||
const ck_tile::index_t BLOCK_SIZE = 32;
|
||||
const ck_tile::index_t num_queries_per_kv = 4;
|
||||
|
||||
auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParser>
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
@@ -37,7 +40,6 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
|
||||
// "if not equal to h, then this is GQA/MQA case")
|
||||
|
||||
.insert("nb", "1024", "num_blks")
|
||||
.insert("bs", "128", "BLOCK_SIZE for kv")
|
||||
.insert("d", "128", "head dim for q & k")
|
||||
.insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)")
|
||||
// TODO scale factors
|
||||
@@ -83,15 +85,14 @@ struct Problem
|
||||
{
|
||||
explicit Problem(const ck_tile::ArgParser& args)
|
||||
{
|
||||
data_type = args.get_str("prec") == "fp16"
|
||||
? ck_tile::unified_attention_args::data_type_enum::fp16
|
||||
: ck_tile::unified_attention_args::data_type_enum::bf16;
|
||||
batch = args.get_int("b");
|
||||
num_blks = args.get_int("nb");
|
||||
BLOCK_SIZE = args.get_int("bs");
|
||||
nhead_kv = args.get_int("h_k");
|
||||
data_type = args.get_str("prec") == "fp16"
|
||||
? ck_tile::unified_attention_args::data_type_enum::fp16
|
||||
: ck_tile::unified_attention_args::data_type_enum::bf16;
|
||||
batch = args.get_int("b");
|
||||
num_blks = args.get_int("nb");
|
||||
nhead_kv = args.get_int("h_k");
|
||||
// TODO: support other GQA/MQA cases than just 4x
|
||||
nhead_q = nhead_kv * 4;
|
||||
nhead_q = nhead_kv * num_queries_per_kv;
|
||||
|
||||
hdim = args.get_int("d");
|
||||
query_lens = args.get_int_vec("query_lens");
|
||||
@@ -130,7 +131,6 @@ struct Problem
|
||||
ck_tile::unified_attention_args::data_type_enum data_type;
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t num_blks;
|
||||
ck_tile::index_t BLOCK_SIZE;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_kv;
|
||||
ck_tile::index_t hdim;
|
||||
@@ -302,12 +302,11 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
|
||||
ck_tile::unified_attention_args args{};
|
||||
|
||||
args.data_type = problem.data_type;
|
||||
args.num_seqs = problem.batch;
|
||||
// args.seqlen_q = problem.seqlen_q;
|
||||
// args.seqlen_k = problem.seqlen_k;
|
||||
args.data_type = problem.data_type;
|
||||
args.num_seqs = problem.batch;
|
||||
args.num_head_q = problem.nhead_q;
|
||||
args.num_queries_per_kv = problem.nhead_q / problem.nhead_kv;
|
||||
args.num_queries_per_kv = num_queries_per_kv;
|
||||
args.BLOCK_SIZE = BLOCK_SIZE;
|
||||
args.mask_type = 2;
|
||||
args.hdim = problem.hdim;
|
||||
|
||||
@@ -319,7 +318,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
|
||||
args.k_ptr = k_buf.GetDeviceBuffer();
|
||||
|
||||
args.stride_k_cache_0 = problem.hdim * problem.nhead_kv * problem.BLOCK_SIZE;
|
||||
args.stride_k_cache_0 = problem.hdim * problem.nhead_kv * BLOCK_SIZE;
|
||||
args.stride_k_cache_1 = problem.hdim * problem.nhead_kv;
|
||||
args.stride_k_cache_2 = problem.hdim;
|
||||
args.stride_k_cache_3 = 1;
|
||||
@@ -393,8 +392,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
|
||||
ck_tile::index_t max_kv_len = max_element(eff_kv_lens);
|
||||
|
||||
ck_tile::index_t max_num_blocks_per_seq =
|
||||
(max_kv_len + problem.BLOCK_SIZE - 1) / problem.BLOCK_SIZE;
|
||||
ck_tile::index_t max_num_blocks_per_seq = (max_kv_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
|
||||
// Create block_tables
|
||||
ck_tile::DeviceMem block_tables_buf(problem.batch * max_num_blocks_per_seq *
|
||||
@@ -426,6 +424,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
run_config.kernel_repeat};
|
||||
|
||||
auto [result, time] = ck_tile::unified_attention(args, stream_config);
|
||||
|
||||
if(!result)
|
||||
{
|
||||
std::cerr << "faild to run fmha_fwd_v3()" << std::endl;
|
||||
@@ -485,18 +484,18 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
});
|
||||
k_b.ForEach([&](auto& self, auto idx) {
|
||||
// kv cache is paged
|
||||
ck_tile::index_t table_col = int(idx[1] / problem.BLOCK_SIZE);
|
||||
ck_tile::index_t table_col = int(idx[1] / BLOCK_SIZE);
|
||||
ck_tile::index_t block_table_offset = b * max_num_blocks_per_seq + table_col;
|
||||
ck_tile::index_t block_idx = block_tables_host[block_table_offset];
|
||||
|
||||
self(idx) = k(block_idx, idx[1] % problem.BLOCK_SIZE, idx[2], idx[3]);
|
||||
self(idx) = k(block_idx, idx[1] % BLOCK_SIZE, idx[2], idx[3]);
|
||||
});
|
||||
v_b.ForEach([&](auto& self, auto idx) {
|
||||
ck_tile::index_t table_col = int(idx[1] / problem.BLOCK_SIZE);
|
||||
ck_tile::index_t table_col = int(idx[1] / BLOCK_SIZE);
|
||||
ck_tile::index_t block_table_offset = b * max_num_blocks_per_seq + table_col;
|
||||
ck_tile::index_t block_idx = block_tables_host[block_table_offset];
|
||||
|
||||
self(idx) = v(block_idx, idx[1] % problem.BLOCK_SIZE, idx[2], idx[3]);
|
||||
self(idx) = v(block_idx, idx[1] % BLOCK_SIZE, idx[2], idx[3]);
|
||||
});
|
||||
// v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); });
|
||||
|
||||
@@ -534,12 +533,13 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
return std::make_tuple(1e-2, 1e-2);
|
||||
}();
|
||||
return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol);
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
|
||||
auto [parse_result, args] = parse_cmd_args(argc, argv);
|
||||
|
||||
if(!parse_result)
|
||||
{
|
||||
std::cerr << "failed to parse command line arguments" << std::endl;
|
||||
|
||||
@@ -28,6 +28,7 @@ struct unified_attention_args
|
||||
index_t num_blks;
|
||||
index_t num_head_q;
|
||||
index_t num_queries_per_kv;
|
||||
index_t BLOCK_SIZE;
|
||||
|
||||
index_t hdim;
|
||||
// TODO window
|
||||
|
||||
@@ -62,9 +62,9 @@ struct unified_attention_kernel_traits
|
||||
static constexpr index_t BLOCK_SIZE = 32;
|
||||
static constexpr index_t HEAD_SIZE = 128;
|
||||
|
||||
// TODO please fix this to support also other num_qhead_per_kvhead
|
||||
static constexpr index_t num_qhead_per_kvhead = 4;
|
||||
static constexpr index_t BLOCK_Q = BLOCK_M / num_qhead_per_kvhead;
|
||||
// TODO please fix this to support also other num_queries_per_kv
|
||||
static constexpr index_t num_queries_per_kv = 4;
|
||||
static constexpr index_t BLOCK_Q = BLOCK_M / num_queries_per_kv;
|
||||
|
||||
// BLOCK_M BLOCK_Q BLOCK_SIZE HEAD_SIZE
|
||||
using unified_attention_block_tile = sequence<BLOCK_M, BLOCK_Q, BLOCK_SIZE, HEAD_SIZE>;
|
||||
@@ -120,6 +120,10 @@ float unified_attention_kernel_launch(const unified_attention_args& args,
|
||||
const stream_config& config)
|
||||
{
|
||||
index_t BLOCK_Q = Kernel::BLOCK_Q;
|
||||
assert(args.num_queries_per_kv == Kernel::num_queries_per_kv &&
|
||||
"argument num_queries_per_kv must equal compiled num_queries_per_kv");
|
||||
assert(args.BLOCK_SIZE == Kernel::BLOCK_SIZE &&
|
||||
"argument BLOCK_SIZE must equal compiled BLOCK_SIZE");
|
||||
assert(BLOCK_Q == args.num_head_q / args.num_queries_per_kv &&
|
||||
"BLOCK_Q must equal BLOCK_M / num_queries_per_kv");
|
||||
index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs;
|
||||
|
||||
Reference in New Issue
Block a user