mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
remove if statements
This commit is contained in:
@@ -471,9 +471,12 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
long kv_lens = has_varlen_k ? eff_kv_vec[b] : problem.seqlen_k;
|
||||
long valid_out_elements = 0;
|
||||
|
||||
if(problem.mask.type == mask_enum::no_mask) {
|
||||
if(problem.mask.type == mask_enum::no_mask)
|
||||
{
|
||||
valid_out_elements = kv_lens * query_lens;
|
||||
} else {
|
||||
}
|
||||
else
|
||||
{
|
||||
if(query_lens > kv_lens)
|
||||
{
|
||||
valid_out_elements = (kv_lens * kv_lens + kv_lens) / 2;
|
||||
@@ -483,7 +486,6 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
valid_out_elements =
|
||||
query_lens * kv_lens - ((query_lens * query_lens - query_lens) / 2);
|
||||
}
|
||||
|
||||
}
|
||||
// Causal logic for valid output elements
|
||||
|
||||
|
||||
@@ -34,7 +34,10 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
|
||||
arg_parser
|
||||
.insert("prec", "bf16", "data type. fp16/bf16")
|
||||
// .insert("b", "3", "batch size")
|
||||
.insert("h_k", "8", "num head for k/v. num head for q is " + std::to_string(num_queries_per_kv) + " times this")
|
||||
.insert("h_k",
|
||||
"8",
|
||||
"num head for k/v. num head for q is " + std::to_string(num_queries_per_kv) +
|
||||
" times this")
|
||||
.insert("s", "3328", "max seqlen_q")
|
||||
.insert("s_k", "-1", "max seqlen_k, -1 means equal to s")
|
||||
.insert("nb", "1024", "num_blks")
|
||||
@@ -76,11 +79,11 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
|
||||
}
|
||||
|
||||
auto seqlen_preprocess(ck_tile::index_t batch,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t max_seqlen_kv,
|
||||
const std::vector<int>& query_lens_input,
|
||||
const std::vector<int>& kv_lens_input,
|
||||
bool varlen) -> std::pair<std::vector<int>, std::vector<int>>
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t max_seqlen_kv,
|
||||
const std::vector<int>& query_lens_input,
|
||||
const std::vector<int>& kv_lens_input,
|
||||
bool varlen) -> std::pair<std::vector<int>, std::vector<int>>
|
||||
{
|
||||
// If both query_lens and kv_lens are provided, return them directly
|
||||
if(!query_lens_input.empty() && !kv_lens_input.empty())
|
||||
@@ -107,11 +110,11 @@ auto seqlen_preprocess(ck_tile::index_t batch,
|
||||
|
||||
query_lens.resize(batch);
|
||||
kv_lens.resize(batch);
|
||||
|
||||
|
||||
for(ck_tile::index_t i = 0; i < batch; ++i)
|
||||
{
|
||||
query_lens[i] = q_dist(gen);
|
||||
kv_lens[i] = kv_dist(gen);
|
||||
kv_lens[i] = kv_dist(gen);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,31 +134,27 @@ struct Problem
|
||||
nhead_q = nhead_kv * num_queries_per_kv;
|
||||
|
||||
ck_tile::index_t max_seqlen_q = args.get_int("s");
|
||||
ck_tile::index_t max_seqlen_kv = args.get_int("s_k");
|
||||
ck_tile::index_t max_seqlen_kv = args.get_int("s_k");
|
||||
|
||||
if (max_seqlen_kv == -1) {
|
||||
if(max_seqlen_kv == -1)
|
||||
{
|
||||
max_seqlen_kv = max_seqlen_q;
|
||||
}
|
||||
|
||||
|
||||
hdim = args.get_int("d");
|
||||
query_lens = args.get_int_vec("query_lens");
|
||||
kv_lens = args.get_int_vec("kv_lens");
|
||||
assert(query_lens.size() == kv_lens.size() && "query_lens and kv_lens must have the same length b");
|
||||
batch = args.get_int("b");
|
||||
assert(query_lens.size() == kv_lens.size() &&
|
||||
"query_lens and kv_lens must have the same length b");
|
||||
batch = args.get_int("b");
|
||||
page_blk_size = args.get_int("page_blk_size");
|
||||
|
||||
|
||||
bool varlen = args.get_bool("varlen");
|
||||
auto [query_lens_, kv_lens_] = seqlen_preprocess(
|
||||
batch,
|
||||
max_seqlen_q,
|
||||
max_seqlen_kv,
|
||||
query_lens,
|
||||
kv_lens,
|
||||
varlen);
|
||||
auto [query_lens_, kv_lens_] =
|
||||
seqlen_preprocess(batch, max_seqlen_q, max_seqlen_kv, query_lens, kv_lens, varlen);
|
||||
|
||||
query_lens = query_lens_;
|
||||
kv_lens = kv_lens_;
|
||||
kv_lens = kv_lens_;
|
||||
batch = query_lens.size();
|
||||
|
||||
// Calculate scale_s
|
||||
@@ -164,9 +163,9 @@ struct Problem
|
||||
scale_s = 1.0f / ck_tile::sqrt(static_cast<float>(hdim));
|
||||
|
||||
// Initialize other scales
|
||||
scale = args.get_float("scale");
|
||||
scale_k = args.get_float("scale_k");
|
||||
scale_v = args.get_float("scale_v");
|
||||
scale = args.get_float("scale");
|
||||
scale_k = args.get_float("scale_k");
|
||||
scale_v = args.get_float("scale_v");
|
||||
num_tokens = 0;
|
||||
for(const auto& len : query_lens)
|
||||
{
|
||||
@@ -300,17 +299,12 @@ CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
|
||||
ck_tile::reference_batched_masking(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<UnifiedAttentionMasks::CausalMask>(
|
||||
-1,
|
||||
0,
|
||||
seqlen_q,
|
||||
seqlen_kv,
|
||||
1,
|
||||
false));
|
||||
-1, 0, seqlen_q, seqlen_kv, 1, false));
|
||||
ck_tile::reference_batched_softmax<AccDataType, AccDataType>(
|
||||
s_host_ref, p_host_ref, ck_tile::identity{});
|
||||
ck_tile::reference_batched_gemm<PDataType, VDataType, AccDataType>(
|
||||
p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op);
|
||||
|
||||
|
||||
// copy resulting per-batch data to the output tensor
|
||||
o_host_ref.ForEach(
|
||||
[&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); });
|
||||
@@ -342,7 +336,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
args.num_seqs = problem.batch;
|
||||
args.num_head_q = problem.nhead_q;
|
||||
args.num_queries_per_kv = num_queries_per_kv;
|
||||
args.page_blk_size = problem.page_blk_size;
|
||||
args.page_blk_size = problem.page_blk_size;
|
||||
args.mask_type = 2;
|
||||
args.hdim = problem.hdim;
|
||||
|
||||
@@ -428,7 +422,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
|
||||
ck_tile::index_t max_kv_len = max_element(eff_kv_lens);
|
||||
|
||||
ck_tile::index_t max_num_blocks_per_seq = (max_kv_len + problem.page_blk_size - 1) / problem.page_blk_size;
|
||||
ck_tile::index_t max_num_blocks_per_seq =
|
||||
(max_kv_len + problem.page_blk_size - 1) / problem.page_blk_size;
|
||||
|
||||
// Create block_tables
|
||||
ck_tile::DeviceMem block_tables_buf(problem.batch * max_num_blocks_per_seq *
|
||||
@@ -506,20 +501,22 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
|
||||
std::cout << "[" << problem.data_type << "|";
|
||||
std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv
|
||||
<< ", d:" << problem.hdim << ", scale_s:" << problem.scale_s
|
||||
<< ", query_lens:[";
|
||||
for (size_t i = 0; i < problem.query_lens.size(); ++i) {
|
||||
<< ", d:" << problem.hdim << ", scale_s:" << problem.scale_s << ", query_lens:[";
|
||||
for(size_t i = 0; i < problem.query_lens.size(); ++i)
|
||||
{
|
||||
std::cout << problem.query_lens[i];
|
||||
if (i < problem.query_lens.size() - 1) std::cout << ",";
|
||||
if(i < problem.query_lens.size() - 1)
|
||||
std::cout << ",";
|
||||
}
|
||||
std::cout << "], kv_lens:[";
|
||||
for (size_t i = 0; i < problem.kv_lens.size(); ++i) {
|
||||
for(size_t i = 0; i < problem.kv_lens.size(); ++i)
|
||||
{
|
||||
std::cout << problem.kv_lens[i];
|
||||
if (i < problem.kv_lens.size() - 1) std::cout << ",";
|
||||
if(i < problem.kv_lens.size() - 1)
|
||||
std::cout << ",";
|
||||
}
|
||||
std::cout << "], mask:" << "causal mask" << std::fixed << ", "
|
||||
<< std::setprecision(8) << time << " ms, " << std::setprecision(2) << tflops
|
||||
<< " TFlops, " << std::setprecision(2)
|
||||
std::cout << "], mask:" << "causal mask" << std::fixed << ", " << std::setprecision(8) << time
|
||||
<< " ms, " << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2)
|
||||
<< (static_cast<double>(mem) / 1e12 / (time / 1e3)) << " TB/s" << std::endl;
|
||||
|
||||
if(!run_config.verify)
|
||||
@@ -597,37 +594,37 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
ck_tile::HostTensor<DataType> o(problem.get_output_shape());
|
||||
o_buf.FromDevice(o.data());
|
||||
|
||||
|
||||
const auto [rtol, atol] = [&] {
|
||||
if constexpr(std::is_same_v<DataType, ck_tile::fp16_t>)
|
||||
return std::make_tuple(1e-3, 1e-3);
|
||||
else
|
||||
return std::make_tuple(1e-2, 1e-2);
|
||||
}();
|
||||
|
||||
size_t total = static_cast<size_t>(problem.num_tokens) *
|
||||
static_cast<size_t>(problem.nhead_q) *
|
||||
|
||||
size_t total = static_cast<size_t>(problem.num_tokens) * static_cast<size_t>(problem.nhead_q) *
|
||||
static_cast<size_t>(problem.hdim);
|
||||
|
||||
size_t nonzero = 0;
|
||||
|
||||
for (int tok = 0; tok < problem.num_tokens; ++tok) {
|
||||
for (int h = 0; h < problem.nhead_q; ++h) {
|
||||
for (int d = 0; d < problem.hdim; ++d) {
|
||||
if (static_cast<float>(o(tok, h, d)) != 0.0f) {
|
||||
nonzero++;
|
||||
for(int tok = 0; tok < problem.num_tokens; ++tok)
|
||||
{
|
||||
for(int h = 0; h < problem.nhead_q; ++h)
|
||||
{
|
||||
for(int d = 0; d < problem.hdim; ++d)
|
||||
{
|
||||
if(static_cast<float>(o(tok, h, d)) != 0.0f)
|
||||
{
|
||||
nonzero++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float percent = (total > 0)
|
||||
? (100.0f * static_cast<float>(nonzero) / static_cast<float>(total))
|
||||
: 0.0f;
|
||||
float percent =
|
||||
(total > 0) ? (100.0f * static_cast<float>(nonzero) / static_cast<float>(total)) : 0.0f;
|
||||
|
||||
std::cout << "\nNon-zero elements in output tensor o: "
|
||||
<< nonzero << " / " << total
|
||||
<< " (" << percent << "%)\n";
|
||||
std::cout << "\nNon-zero elements in output tensor o: " << nonzero << " / " << total << " ("
|
||||
<< percent << "%)\n";
|
||||
|
||||
// std::cout << "\n=== Complete Output Tensor (o) ===\n";
|
||||
// for (int tok = 0; tok < problem.num_tokens; ++tok) {
|
||||
@@ -652,7 +649,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
// std::cout << "\n";
|
||||
// }
|
||||
// }
|
||||
return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol);
|
||||
return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol);
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
|
||||
@@ -30,7 +30,7 @@ struct unified_attention_args
|
||||
index_t num_head_q;
|
||||
index_t num_queries_per_kv;
|
||||
index_t page_blk_size;
|
||||
//index_t BLOCK_SIZE;
|
||||
// index_t BLOCK_SIZE;
|
||||
|
||||
index_t hdim;
|
||||
// TODO window
|
||||
|
||||
@@ -204,7 +204,6 @@ struct UnifiedAttentionKernel
|
||||
return left - 1;
|
||||
}
|
||||
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto GetTileIndex(const ck_tile::index_t pid,
|
||||
const Kargs& kargs)
|
||||
{
|
||||
@@ -259,13 +258,14 @@ struct UnifiedAttentionKernel
|
||||
|
||||
const index_t q_block_start_idx = kargs.query_start_len_ptr[seq_idx] / BLOCK_Q + seq_idx;
|
||||
|
||||
const index_t q_block_local_idx = amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx);
|
||||
const index_t q_block_local_idx =
|
||||
amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx);
|
||||
|
||||
const index_t cur_batch_in_all_start_index = kargs.query_start_len_ptr[seq_idx];
|
||||
const index_t cur_batch_in_all_stop_index = kargs.query_start_len_ptr[seq_idx + 1];
|
||||
const index_t cur_batch_in_all_stop_index = kargs.query_start_len_ptr[seq_idx + 1];
|
||||
|
||||
const index_t cur_batch_query_len =
|
||||
amd_wave_read_first_lane(cur_batch_in_all_stop_index - cur_batch_in_all_start_index);
|
||||
amd_wave_read_first_lane(cur_batch_in_all_stop_index - cur_batch_in_all_start_index);
|
||||
|
||||
// TODO check if we get the block size info from pipeline
|
||||
if(q_block_local_idx * BLOCK_Q >= cur_batch_query_len)
|
||||
@@ -276,11 +276,10 @@ struct UnifiedAttentionKernel
|
||||
const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * BLOCK_Q);
|
||||
const index_t seq_len = kargs.seq_lens_ptr[seq_idx];
|
||||
|
||||
const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len);
|
||||
const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len);
|
||||
|
||||
index_t _max_seq_prefix_len =
|
||||
amd_wave_read_first_lane((context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1)
|
||||
+ 1));
|
||||
index_t _max_seq_prefix_len = amd_wave_read_first_lane(
|
||||
(context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) + 1));
|
||||
|
||||
if(seq_len < _max_seq_prefix_len)
|
||||
{
|
||||
@@ -288,7 +287,8 @@ struct UnifiedAttentionKernel
|
||||
}
|
||||
|
||||
const auto max_seq_prefix_len = _max_seq_prefix_len;
|
||||
const index_t num_blocks = amd_wave_read_first_lane((max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE);
|
||||
const index_t num_blocks =
|
||||
amd_wave_read_first_lane((max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE);
|
||||
|
||||
// TODO sliding window
|
||||
const index_t num_blocks_start = 0;
|
||||
@@ -315,7 +315,8 @@ struct UnifiedAttentionKernel
|
||||
const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) + kv_head_offset;
|
||||
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) + o_ptr_offset;
|
||||
|
||||
index_t query_len_padded = amd_wave_read_first_lane(integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q);
|
||||
index_t query_len_padded =
|
||||
amd_wave_read_first_lane(integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q);
|
||||
// const bool is_query_len_padded = (cur_batch_query_len % BLOCK_Q == 0);
|
||||
|
||||
// Q/K/V DRAM and DRAM window
|
||||
@@ -397,21 +398,20 @@ struct UnifiedAttentionKernel
|
||||
FmhaMask mask = [&]() {
|
||||
if constexpr(kHasMask)
|
||||
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
||||
-1,
|
||||
-1,
|
||||
0,
|
||||
cur_batch_query_len, // y_total
|
||||
seq_len, // x_total
|
||||
num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv
|
||||
// times along x dim of the tile
|
||||
false
|
||||
);
|
||||
seq_len, // x_total
|
||||
num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv
|
||||
// times along x dim of the tile
|
||||
false);
|
||||
else
|
||||
return FmhaMask{cur_batch_query_len, seq_len};
|
||||
}();
|
||||
|
||||
const index_t kv_page_size_in_blocks = kargs.page_blk_size / BLOCK_SIZE;
|
||||
assert(kv_page_size_in_blocks >= 1); // BLOCK_SIZE <= page_blk_size
|
||||
|
||||
|
||||
auto o_acc_tile = [&]() {
|
||||
return UnifiedAttentionPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
|
||||
@@ -545,9 +545,8 @@ struct UnifiedAttentionPipeline
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
|
||||
const auto num_total_loop = num_blocks;
|
||||
index_t k_block_table_off = num_blocks_start;
|
||||
index_t v_block_table_off = num_blocks_start;
|
||||
|
||||
index_t k_block_idx = 0;
|
||||
index_t v_block_idx = 0;
|
||||
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
@@ -562,12 +561,13 @@ struct UnifiedAttentionPipeline
|
||||
}
|
||||
|
||||
// TODO check correctness of this
|
||||
index_t i_total_loops = num_blocks_start;
|
||||
index_t i_total_loops = num_blocks_start;
|
||||
const index_t PAGE_BLOCK_SIZE = kv_page_size_in_blocks * BLOCK_SIZE;
|
||||
const ck_tile::index_t* block_tables_ptr_ =
|
||||
reinterpret_cast<const ck_tile::index_t*>(block_tables_ptr);
|
||||
assert(k_block_table_off == v_block_table_off); // because of the following line
|
||||
index_t kv_blk_idx_initial = block_tables_ptr_[block_table_offset + k_block_table_off];
|
||||
assert(k_block_idx == v_block_idx); // because of the following line
|
||||
block_table_offset += num_blocks_start;
|
||||
index_t kv_blk_idx_initial = block_tables_ptr_[block_table_offset + k_block_idx];
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
@@ -672,56 +672,35 @@ struct UnifiedAttentionPipeline
|
||||
constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access();
|
||||
|
||||
// Page block index tracking
|
||||
// const index_t kv_page_size_in_blocks =
|
||||
// const index_t kv_page_size_in_blocks =
|
||||
// PAGE_BLOCK_SIZE / BLOCK_SIZE;
|
||||
index_t k_block_i_inside_page = 0;
|
||||
index_t v_block_i_inside_page = 0;
|
||||
// index_t kv_block_idx = 0;
|
||||
// only for block 0 and thread
|
||||
if(blockIdx.x == 0 && threadIdx.x == 0) {}
|
||||
auto K_mem_load = [&](auto k_lds_write_idx) {
|
||||
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
|
||||
// prefetch next K tile (only if not at the end of loop)
|
||||
if (k_block_table_off * kv_page_size_in_blocks + k_block_i_inside_page + 1 >= num_total_loop)
|
||||
{
|
||||
return;
|
||||
}
|
||||
// Update block index inside the page
|
||||
++k_block_i_inside_page;
|
||||
if(k_block_i_inside_page < kv_page_size_in_blocks)
|
||||
{
|
||||
// Staying inside the page, just move the window
|
||||
move_tile_window(k_dram_window, {BLOCK_SIZE, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
// Moving outside the page, fetch new physical page index
|
||||
k_block_table_off++;
|
||||
index_t k_page_blk_idx = amd_wave_read_first_lane(block_tables_ptr_[block_table_offset + k_block_table_off]);
|
||||
k_dram_window.set_window_origin({k_page_blk_idx * PAGE_BLOCK_SIZE, 0});
|
||||
k_block_i_inside_page = 0;
|
||||
}
|
||||
k_block_idx++;
|
||||
|
||||
index_t k_page_blk_idx =
|
||||
block_tables_ptr_[block_table_offset + (k_block_idx / kv_page_size_in_blocks)];
|
||||
k_dram_window.set_window_origin(
|
||||
{k_page_blk_idx * PAGE_BLOCK_SIZE +
|
||||
(k_block_idx % kv_page_size_in_blocks) * BLOCK_SIZE,
|
||||
0});
|
||||
};
|
||||
|
||||
auto V_mem_load = [&](auto v_lds_write_idx) {
|
||||
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
|
||||
// prefetch next V tile (only if not at the end of loop)
|
||||
if (v_block_table_off * kv_page_size_in_blocks + v_block_i_inside_page + 1 >= num_total_loop)
|
||||
{
|
||||
return;
|
||||
}
|
||||
// Update the block index inside the page
|
||||
++v_block_i_inside_page;
|
||||
if(v_block_i_inside_page < kv_page_size_in_blocks)
|
||||
{
|
||||
// Staying inside the page, just move the window
|
||||
move_tile_window(v_dram_window, {BLOCK_SIZE, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
// Moving outside the page, fetch new physical page index
|
||||
v_block_table_off++;
|
||||
index_t v_page_blk_idx = amd_wave_read_first_lane(block_tables_ptr_[block_table_offset + v_block_table_off]);
|
||||
v_dram_window.set_window_origin({v_page_blk_idx * PAGE_BLOCK_SIZE, 0});
|
||||
v_block_i_inside_page = 0;
|
||||
}
|
||||
v_block_idx++;
|
||||
|
||||
index_t v_page_blk_idx =
|
||||
block_tables_ptr_[block_table_offset + (v_block_idx / kv_page_size_in_blocks)];
|
||||
v_dram_window.set_window_origin(
|
||||
{v_page_blk_idx * PAGE_BLOCK_SIZE +
|
||||
(v_block_idx % kv_page_size_in_blocks) * BLOCK_SIZE,
|
||||
0});
|
||||
// we assume that v load is always after k
|
||||
};
|
||||
|
||||
auto K_lds_load = [&](auto k_lds_read_idx) {
|
||||
|
||||
Reference in New Issue
Block a user