Remove if statements from the K/V tile prefetch lambdas and reformat the touched files

This commit is contained in:
Tianxing Wu
2025-12-11 09:21:55 +00:00
parent 345758971e
commit 73aed1b57c
5 changed files with 106 additions and 128 deletions

View File

@@ -471,9 +471,12 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
long kv_lens = has_varlen_k ? eff_kv_vec[b] : problem.seqlen_k;
long valid_out_elements = 0;
if(problem.mask.type == mask_enum::no_mask) {
if(problem.mask.type == mask_enum::no_mask)
{
valid_out_elements = kv_lens * query_lens;
} else {
}
else
{
if(query_lens > kv_lens)
{
valid_out_elements = (kv_lens * kv_lens + kv_lens) / 2;
@@ -483,7 +486,6 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
valid_out_elements =
query_lens * kv_lens - ((query_lens * query_lens - query_lens) / 2);
}
}
// Causal logic for valid output elements

View File

@@ -34,7 +34,10 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
arg_parser
.insert("prec", "bf16", "data type. fp16/bf16")
// .insert("b", "3", "batch size")
.insert("h_k", "8", "num head for k/v. num head for q is " + std::to_string(num_queries_per_kv) + " times this")
.insert("h_k",
"8",
"num head for k/v. num head for q is " + std::to_string(num_queries_per_kv) +
" times this")
.insert("s", "3328", "max seqlen_q")
.insert("s_k", "-1", "max seqlen_k, -1 means equal to s")
.insert("nb", "1024", "num_blks")
@@ -76,11 +79,11 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
}
auto seqlen_preprocess(ck_tile::index_t batch,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t max_seqlen_kv,
const std::vector<int>& query_lens_input,
const std::vector<int>& kv_lens_input,
bool varlen) -> std::pair<std::vector<int>, std::vector<int>>
ck_tile::index_t max_seqlen_q,
ck_tile::index_t max_seqlen_kv,
const std::vector<int>& query_lens_input,
const std::vector<int>& kv_lens_input,
bool varlen) -> std::pair<std::vector<int>, std::vector<int>>
{
// If both query_lens and kv_lens are provided, return them directly
if(!query_lens_input.empty() && !kv_lens_input.empty())
@@ -107,11 +110,11 @@ auto seqlen_preprocess(ck_tile::index_t batch,
query_lens.resize(batch);
kv_lens.resize(batch);
for(ck_tile::index_t i = 0; i < batch; ++i)
{
query_lens[i] = q_dist(gen);
kv_lens[i] = kv_dist(gen);
kv_lens[i] = kv_dist(gen);
}
}
@@ -131,31 +134,27 @@ struct Problem
nhead_q = nhead_kv * num_queries_per_kv;
ck_tile::index_t max_seqlen_q = args.get_int("s");
ck_tile::index_t max_seqlen_kv = args.get_int("s_k");
ck_tile::index_t max_seqlen_kv = args.get_int("s_k");
if (max_seqlen_kv == -1) {
if(max_seqlen_kv == -1)
{
max_seqlen_kv = max_seqlen_q;
}
hdim = args.get_int("d");
query_lens = args.get_int_vec("query_lens");
kv_lens = args.get_int_vec("kv_lens");
assert(query_lens.size() == kv_lens.size() && "query_lens and kv_lens must have the same length b");
batch = args.get_int("b");
assert(query_lens.size() == kv_lens.size() &&
"query_lens and kv_lens must have the same length b");
batch = args.get_int("b");
page_blk_size = args.get_int("page_blk_size");
bool varlen = args.get_bool("varlen");
auto [query_lens_, kv_lens_] = seqlen_preprocess(
batch,
max_seqlen_q,
max_seqlen_kv,
query_lens,
kv_lens,
varlen);
auto [query_lens_, kv_lens_] =
seqlen_preprocess(batch, max_seqlen_q, max_seqlen_kv, query_lens, kv_lens, varlen);
query_lens = query_lens_;
kv_lens = kv_lens_;
kv_lens = kv_lens_;
batch = query_lens.size();
// Calculate scale_s
@@ -164,9 +163,9 @@ struct Problem
scale_s = 1.0f / ck_tile::sqrt(static_cast<float>(hdim));
// Initialize other scales
scale = args.get_float("scale");
scale_k = args.get_float("scale_k");
scale_v = args.get_float("scale_v");
scale = args.get_float("scale");
scale_k = args.get_float("scale_k");
scale_v = args.get_float("scale_v");
num_tokens = 0;
for(const auto& len : query_lens)
{
@@ -300,17 +299,12 @@ CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<UnifiedAttentionMasks::CausalMask>(
-1,
0,
seqlen_q,
seqlen_kv,
1,
false));
-1, 0, seqlen_q, seqlen_kv, 1, false));
ck_tile::reference_batched_softmax<AccDataType, AccDataType>(
s_host_ref, p_host_ref, ck_tile::identity{});
ck_tile::reference_batched_gemm<PDataType, VDataType, AccDataType>(
p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op);
// copy resulting per-batch data to the output tensor
o_host_ref.ForEach(
[&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); });
@@ -342,7 +336,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
args.num_seqs = problem.batch;
args.num_head_q = problem.nhead_q;
args.num_queries_per_kv = num_queries_per_kv;
args.page_blk_size = problem.page_blk_size;
args.page_blk_size = problem.page_blk_size;
args.mask_type = 2;
args.hdim = problem.hdim;
@@ -428,7 +422,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
ck_tile::index_t max_kv_len = max_element(eff_kv_lens);
ck_tile::index_t max_num_blocks_per_seq = (max_kv_len + problem.page_blk_size - 1) / problem.page_blk_size;
ck_tile::index_t max_num_blocks_per_seq =
(max_kv_len + problem.page_blk_size - 1) / problem.page_blk_size;
// Create block_tables
ck_tile::DeviceMem block_tables_buf(problem.batch * max_num_blocks_per_seq *
@@ -506,20 +501,22 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
std::cout << "[" << problem.data_type << "|";
std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv
<< ", d:" << problem.hdim << ", scale_s:" << problem.scale_s
<< ", query_lens:[";
for (size_t i = 0; i < problem.query_lens.size(); ++i) {
<< ", d:" << problem.hdim << ", scale_s:" << problem.scale_s << ", query_lens:[";
for(size_t i = 0; i < problem.query_lens.size(); ++i)
{
std::cout << problem.query_lens[i];
if (i < problem.query_lens.size() - 1) std::cout << ",";
if(i < problem.query_lens.size() - 1)
std::cout << ",";
}
std::cout << "], kv_lens:[";
for (size_t i = 0; i < problem.kv_lens.size(); ++i) {
for(size_t i = 0; i < problem.kv_lens.size(); ++i)
{
std::cout << problem.kv_lens[i];
if (i < problem.kv_lens.size() - 1) std::cout << ",";
if(i < problem.kv_lens.size() - 1)
std::cout << ",";
}
std::cout << "], mask:" << "causal mask" << std::fixed << ", "
<< std::setprecision(8) << time << " ms, " << std::setprecision(2) << tflops
<< " TFlops, " << std::setprecision(2)
std::cout << "], mask:" << "causal mask" << std::fixed << ", " << std::setprecision(8) << time
<< " ms, " << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2)
<< (static_cast<double>(mem) / 1e12 / (time / 1e3)) << " TB/s" << std::endl;
if(!run_config.verify)
@@ -597,37 +594,37 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
ck_tile::HostTensor<DataType> o(problem.get_output_shape());
o_buf.FromDevice(o.data());
const auto [rtol, atol] = [&] {
if constexpr(std::is_same_v<DataType, ck_tile::fp16_t>)
return std::make_tuple(1e-3, 1e-3);
else
return std::make_tuple(1e-2, 1e-2);
}();
size_t total = static_cast<size_t>(problem.num_tokens) *
static_cast<size_t>(problem.nhead_q) *
size_t total = static_cast<size_t>(problem.num_tokens) * static_cast<size_t>(problem.nhead_q) *
static_cast<size_t>(problem.hdim);
size_t nonzero = 0;
for (int tok = 0; tok < problem.num_tokens; ++tok) {
for (int h = 0; h < problem.nhead_q; ++h) {
for (int d = 0; d < problem.hdim; ++d) {
if (static_cast<float>(o(tok, h, d)) != 0.0f) {
nonzero++;
for(int tok = 0; tok < problem.num_tokens; ++tok)
{
for(int h = 0; h < problem.nhead_q; ++h)
{
for(int d = 0; d < problem.hdim; ++d)
{
if(static_cast<float>(o(tok, h, d)) != 0.0f)
{
nonzero++;
}
}
}
}
float percent = (total > 0)
? (100.0f * static_cast<float>(nonzero) / static_cast<float>(total))
: 0.0f;
float percent =
(total > 0) ? (100.0f * static_cast<float>(nonzero) / static_cast<float>(total)) : 0.0f;
std::cout << "\nNon-zero elements in output tensor o: "
<< nonzero << " / " << total
<< " (" << percent << "%)\n";
std::cout << "\nNon-zero elements in output tensor o: " << nonzero << " / " << total << " ("
<< percent << "%)\n";
// std::cout << "\n=== Complete Output Tensor (o) ===\n";
// for (int tok = 0; tok < problem.num_tokens; ++tok) {
@@ -652,7 +649,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
// std::cout << "\n";
// }
// }
return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol);
return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol);
}
int main(int argc, char* argv[])

View File

@@ -30,7 +30,7 @@ struct unified_attention_args
index_t num_head_q;
index_t num_queries_per_kv;
index_t page_blk_size;
//index_t BLOCK_SIZE;
// index_t BLOCK_SIZE;
index_t hdim;
// TODO window

View File

@@ -204,7 +204,6 @@ struct UnifiedAttentionKernel
return left - 1;
}
CK_TILE_DEVICE static constexpr auto GetTileIndex(const ck_tile::index_t pid,
const Kargs& kargs)
{
@@ -259,13 +258,14 @@ struct UnifiedAttentionKernel
const index_t q_block_start_idx = kargs.query_start_len_ptr[seq_idx] / BLOCK_Q + seq_idx;
const index_t q_block_local_idx = amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx);
const index_t q_block_local_idx =
amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx);
const index_t cur_batch_in_all_start_index = kargs.query_start_len_ptr[seq_idx];
const index_t cur_batch_in_all_stop_index = kargs.query_start_len_ptr[seq_idx + 1];
const index_t cur_batch_in_all_stop_index = kargs.query_start_len_ptr[seq_idx + 1];
const index_t cur_batch_query_len =
amd_wave_read_first_lane(cur_batch_in_all_stop_index - cur_batch_in_all_start_index);
amd_wave_read_first_lane(cur_batch_in_all_stop_index - cur_batch_in_all_start_index);
// TODO check if we get the block size info from pipeline
if(q_block_local_idx * BLOCK_Q >= cur_batch_query_len)
@@ -276,11 +276,10 @@ struct UnifiedAttentionKernel
const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * BLOCK_Q);
const index_t seq_len = kargs.seq_lens_ptr[seq_idx];
const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len);
const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len);
index_t _max_seq_prefix_len =
amd_wave_read_first_lane((context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1)
+ 1));
index_t _max_seq_prefix_len = amd_wave_read_first_lane(
(context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) + 1));
if(seq_len < _max_seq_prefix_len)
{
@@ -288,7 +287,8 @@ struct UnifiedAttentionKernel
}
const auto max_seq_prefix_len = _max_seq_prefix_len;
const index_t num_blocks = amd_wave_read_first_lane((max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE);
const index_t num_blocks =
amd_wave_read_first_lane((max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE);
// TODO sliding window
const index_t num_blocks_start = 0;
@@ -315,7 +315,8 @@ struct UnifiedAttentionKernel
const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) + kv_head_offset;
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) + o_ptr_offset;
index_t query_len_padded = amd_wave_read_first_lane(integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q);
index_t query_len_padded =
amd_wave_read_first_lane(integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q);
// const bool is_query_len_padded = (cur_batch_query_len % BLOCK_Q == 0);
// Q/K/V DRAM and DRAM window
@@ -397,21 +398,20 @@ struct UnifiedAttentionKernel
FmhaMask mask = [&]() {
if constexpr(kHasMask)
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
-1,
-1,
0,
cur_batch_query_len, // y_total
seq_len, // x_total
num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv
// times along x dim of the tile
false
);
seq_len, // x_total
num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv
// times along x dim of the tile
false);
else
return FmhaMask{cur_batch_query_len, seq_len};
}();
const index_t kv_page_size_in_blocks = kargs.page_blk_size / BLOCK_SIZE;
assert(kv_page_size_in_blocks >= 1); // BLOCK_SIZE <= page_blk_size
auto o_acc_tile = [&]() {
return UnifiedAttentionPipeline{}(q_dram_window,
k_dram_window,

View File

@@ -545,9 +545,8 @@ struct UnifiedAttentionPipeline
const auto q_origin = q_dram_window.get_window_origin();
const auto num_total_loop = num_blocks;
index_t k_block_table_off = num_blocks_start;
index_t v_block_table_off = num_blocks_start;
index_t k_block_idx = 0;
index_t v_block_idx = 0;
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking)
@@ -562,12 +561,13 @@ struct UnifiedAttentionPipeline
}
// TODO check correctness of this
index_t i_total_loops = num_blocks_start;
index_t i_total_loops = num_blocks_start;
const index_t PAGE_BLOCK_SIZE = kv_page_size_in_blocks * BLOCK_SIZE;
const ck_tile::index_t* block_tables_ptr_ =
reinterpret_cast<const ck_tile::index_t*>(block_tables_ptr);
assert(k_block_table_off == v_block_table_off); // because of the following line
index_t kv_blk_idx_initial = block_tables_ptr_[block_table_offset + k_block_table_off];
assert(k_block_idx == v_block_idx); // because of the following line
block_table_offset += num_blocks_start;
index_t kv_blk_idx_initial = block_tables_ptr_[block_table_offset + k_block_idx];
auto k_dram_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
@@ -672,56 +672,35 @@ struct UnifiedAttentionPipeline
constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access();
// Page block index tracking
// const index_t kv_page_size_in_blocks =
// const index_t kv_page_size_in_blocks =
// PAGE_BLOCK_SIZE / BLOCK_SIZE;
index_t k_block_i_inside_page = 0;
index_t v_block_i_inside_page = 0;
// index_t kv_block_idx = 0;
// only for block 0 and thread 0
if(blockIdx.x == 0 && threadIdx.x == 0) {}
auto K_mem_load = [&](auto k_lds_write_idx) {
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
// prefetch next K tile (only if not at the end of loop)
if (k_block_table_off * kv_page_size_in_blocks + k_block_i_inside_page + 1 >= num_total_loop)
{
return;
}
// Update block index inside the page
++k_block_i_inside_page;
if(k_block_i_inside_page < kv_page_size_in_blocks)
{
// Staying inside the page, just move the window
move_tile_window(k_dram_window, {BLOCK_SIZE, 0});
}
else
{
// Moving outside the page, fetch new physical page index
k_block_table_off++;
index_t k_page_blk_idx = amd_wave_read_first_lane(block_tables_ptr_[block_table_offset + k_block_table_off]);
k_dram_window.set_window_origin({k_page_blk_idx * PAGE_BLOCK_SIZE, 0});
k_block_i_inside_page = 0;
}
k_block_idx++;
index_t k_page_blk_idx =
block_tables_ptr_[block_table_offset + (k_block_idx / kv_page_size_in_blocks)];
k_dram_window.set_window_origin(
{k_page_blk_idx * PAGE_BLOCK_SIZE +
(k_block_idx % kv_page_size_in_blocks) * BLOCK_SIZE,
0});
};
auto V_mem_load = [&](auto v_lds_write_idx) {
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
// prefetch next V tile (only if not at the end of loop)
if (v_block_table_off * kv_page_size_in_blocks + v_block_i_inside_page + 1 >= num_total_loop)
{
return;
}
// Update the block index inside the page
++v_block_i_inside_page;
if(v_block_i_inside_page < kv_page_size_in_blocks)
{
// Staying inside the page, just move the window
move_tile_window(v_dram_window, {BLOCK_SIZE, 0});
}
else
{
// Moving outside the page, fetch new physical page index
v_block_table_off++;
index_t v_page_blk_idx = amd_wave_read_first_lane(block_tables_ptr_[block_table_offset + v_block_table_off]);
v_dram_window.set_window_origin({v_page_blk_idx * PAGE_BLOCK_SIZE, 0});
v_block_i_inside_page = 0;
}
v_block_idx++;
index_t v_page_blk_idx =
block_tables_ptr_[block_table_offset + (v_block_idx / kv_page_size_in_blocks)];
v_dram_window.set_window_origin(
{v_page_blk_idx * PAGE_BLOCK_SIZE +
(v_block_idx % kv_page_size_in_blocks) * BLOCK_SIZE,
0});
// we assume that v load is always after k
};
auto K_lds_load = [&](auto k_lds_read_idx) {