Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-06-29 11:16:59 +00:00)
Introduces a decode-aligned hybrid KV cache layout in which K stays 5D vectorized (matching VECTORIZED_LAYOUT) and V is 4D ColumnMajor [NumBlocks, NumHeads, HeadDim, PageSize]. This matches the layout produced by aiter's reshape_and_cache_kernel and consumed by the decode paged-attention kernel, so the prefill kernel can ingest the live KV cache without an intermediate reshape.

Changes:
- block_attention_kvcache_layout_enum.hpp: add VEC_K_COL_V_LAYOUT = 2.
- fmha_batch_prefill_kernel.hpp: extend the vectorized K dram branch to cover VEC_K_COL_V (the K layout is identical); add a new V dram branch that builds the (NumPages, HeadDim, PageSize) view with strides (batch_stride_v, page_block_size, 1) and merges it to the logical (D, TotalSeqK) shape. stride_k_for_pipeline covers both vectorized layouts; for VEC_K_COL_V, stride_v_for_pipeline falls through the LINEAR else branch and uses kargs.stride_v (= 1 from the wrapper).
- block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp: for VEC_K_COL_V, kAlignmentV keeps Policy::GetAlignmentV<Problem>() even though kPadSeqLenK=true. Pages are always fully populated, so vector loads along the contiguous PageSize dimension never cross page boundaries.
- block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp: add a kUseVectorizedVPolicy<Problem>() predicate and route all V-side specializations (GetAlignmentV, GetSmemKPackV, GetSingleSmemElementSpaceSize, MakeVLdsBlockDescriptor, MakeVDramTileDistribution) through it; VEC_K_COL_V shares the VECTORIZED V tile distribution, LDS layout, SmemKPack, and alignment.
- block_fmha_pipeline_problem.hpp: introduce a kIsKVectorized predicate; relax the IsVLayoutRowMajor static_assert to accept VEC_K_COL_V_LAYOUT (the only layout in which V is ColumnMajor).
- tile_fmha_traits.hpp: extend the batch-prefill KV layout static_assert to accept VEC_K_COL_V_LAYOUT.
- fmha_fwd.hpp (example): add an is_v_rowmajor field (default true) to fmha_batch_prefill_args so the auto-generated dispatcher can pick a ColumnMajor V kernel variant when the wrapper requests one.
- codegen/ops/fmha_batch_prefill.py: emit fp8bf16 PER_TOKEN_HEAD vlayout="col" variants gated to kv_memory_layout="vec_k_col_v" only; relax the receipt 200 filter so vlayout="col" passes through only for vec_k_col_v.

Co-authored-by: Cursor <cursoragent@cursor.com>
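The ColumnMajor V layout [NumBlocks, NumHeads, HeadDim, PageSize], and the (NumPages, HeadDim, PageSize) view the new V dram branch builds over it, can be illustrated with a small address computation. This is a sketch, not the kernel code. It assumes a densely packed buffer, a per-sequence page table holding block indices, and a per-page stride (the role batch_stride_v plays in the view) equal to NumHeads * HeadDim * PageSize. `VCacheShape`, `v_physical_offset`, and `v_logical_offset` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical shape descriptor for a ColumnMajor V cache laid out as
// [NumBlocks, NumHeads, HeadDim, PageSize]; field names are illustrative.
struct VCacheShape {
    std::int64_t num_blocks;
    std::int64_t num_heads;
    std::int64_t head_dim;
    std::int64_t page_size;
};

// Offset of element (block, head, d, p) in a densely packed
// [NumBlocks, NumHeads, HeadDim, PageSize] buffer. PageSize is the
// stride-1 dimension and HeadDim has stride PageSize.
std::int64_t v_physical_offset(const VCacheShape& s, std::int64_t block,
                               std::int64_t head, std::int64_t d, std::int64_t p) {
    return ((block * s.num_heads + head) * s.head_dim + d) * s.page_size + p;
}

// Logical (d, token) -> physical offset for one head. The token is resolved
// through the page table, then addressed with strides
// (page_stride, page_size, 1) over (page, d, p), assuming dense packing.
std::int64_t v_logical_offset(const VCacheShape& s,
                              const std::vector<std::int64_t>& page_table,
                              std::int64_t head, std::int64_t d, std::int64_t token) {
    const std::int64_t page_stride = s.num_heads * s.head_dim * s.page_size;
    const std::int64_t page = page_table.at(token / s.page_size);
    const std::int64_t p = token % s.page_size;
    const std::int64_t head_base = head * s.head_dim * s.page_size;
    return page * page_stride + head_base + d * s.page_size + p;
}
```

Because PageSize is the stride-1 dimension, consecutive tokens within one page are adjacent in memory, so V can be loaded in vectors along the sequence axis.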
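The kAlignmentV rationale, that vector loads along PageSize never cross a page boundary, comes down to a divisibility condition. Here is a minimal sketch under two assumptions not stated in the commit: the vector width divides PageSize, and loads start at width-aligned token offsets. The helper names are hypothetical. The sketch models only the page-boundary geometry, not kPadSeqLenK handling.

```cpp
#include <cassert>
#include <cstdint>

// True if a `width`-element vector load starting at logical token `start`
// stays inside one page of `page_size` tokens, i.e. every element resolves
// through the same page-table entry.
bool vector_load_stays_in_page(std::int64_t start, std::int64_t width,
                               std::int64_t page_size) {
    return start / page_size == (start + width - 1) / page_size;
}

// Exhaustively checks every width-aligned load over `num_pages` fully
// populated pages. Returns true whenever page_size % width == 0.
bool all_aligned_loads_stay_in_page(std::int64_t num_pages, std::int64_t width,
                                    std::int64_t page_size) {
    for (std::int64_t start = 0; start + width <= num_pages * page_size; start += width) {
        if (!vector_load_stays_in_page(start, width, page_size)) {
            return false;
        }
    }
    return true;
}
```

If the width does not divide the page size (for example, width 8 with 12-token pages), some aligned load straddles two pages. The full vector alignment is then no longer safe.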
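The predicate routing described for the pipeline problem and default policy can be sketched as plain constexpr functions. This is a standalone illustration, not CK's API. `KVLayoutSketch` and the function names are hypothetical. Only `VEC_K_COL_V_LAYOUT = 2` comes from this change, and the values of the other two enumerators are assumed. The relaxed V check accepts any row-major V plus ColumnMajor V in the hybrid layout. This is one reading of the relaxed static_assert.

```cpp
#include <cassert>

// Hypothetical mirror of the KV cache layout enum.
enum class KVLayoutSketch : int {
    LINEAR_LAYOUT = 0,       // assumed value
    VECTORIZED_LAYOUT = 1,   // assumed value
    VEC_K_COL_V_LAYOUT = 2,  // K as in VECTORIZED_LAYOUT, V ColumnMajor
};

// K is 5D vectorized in both the vectorized and the hybrid layout.
constexpr bool is_k_vectorized(KVLayoutSketch l) {
    return l == KVLayoutSketch::VECTORIZED_LAYOUT ||
           l == KVLayoutSketch::VEC_K_COL_V_LAYOUT;
}

// V-side choices (alignment, SmemKPack, LDS descriptor, tile distribution)
// take the VECTORIZED path for both layouts. It is kept as a separate name
// from is_k_vectorized because it gates a different set of decisions.
constexpr bool use_vectorized_v_policy(KVLayoutSketch l) {
    return l == KVLayoutSketch::VECTORIZED_LAYOUT ||
           l == KVLayoutSketch::VEC_K_COL_V_LAYOUT;
}

// Relaxed layout check: ColumnMajor V is accepted only in the hybrid layout.
constexpr bool v_layout_is_valid(KVLayoutSketch l, bool v_is_row_major) {
    return v_is_row_major || l == KVLayoutSketch::VEC_K_COL_V_LAYOUT;
}

static_assert(use_vectorized_v_policy(KVLayoutSketch::VEC_K_COL_V_LAYOUT),
              "hybrid layout shares the vectorized V policy");
static_assert(!v_layout_is_valid(KVLayoutSketch::VECTORIZED_LAYOUT, false),
              "ColumnMajor V is rejected outside the hybrid layout");
```

Evaluating these checks at compile time mirrors how the real predicates select specializations. An unsupported layout/V-majorness combination then fails to build instead of misbehaving at run time.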