msaffari-amd f347c1324d Add VEC_K_COL_V KV cache layout for FP8 batch_prefill
Introduces a decode-aligned hybrid KV cache layout where K stays 5D
vectorized (matching VECTORIZED_LAYOUT) and V is 4D ColumnMajor
[NumBlocks, NumHeads, HeadDim, PageSize]. This matches the layout
produced by aiter's reshape_and_cache_kernel and consumed by the decode
paged-attention kernel, so the prefill kernel can ingest the live KV
cache without an intermediate reshape.
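For reference, here is a minimal Python sketch of the V indexing implied by the 4D ColumnMajor shape above. It assumes a dense, contiguous tensor with PageSize innermost. The helper name v_col_offset and the example shape are hypothetical and are not part of the kernel or aiter:

```python
def v_col_offset(shape, block, head, d, p):
    """Flat element offset of V[block, head, d, p] in a contiguous
    [NumBlocks, NumHeads, HeadDim, PageSize] tensor (PageSize innermost)."""
    num_blocks, num_heads, head_dim, page_size = shape
    assert 0 <= block < num_blocks and 0 <= head < num_heads
    assert 0 <= d < head_dim and 0 <= p < page_size
    return ((block * num_heads + head) * head_dim + d) * page_size + p


shape = (4, 8, 128, 16)  # hypothetical NumBlocks, NumHeads, HeadDim, PageSize
# Consecutive tokens of one head-dim channel are adjacent in memory...
print(v_col_offset(shape, 0, 0, 0, 1) - v_col_offset(shape, 0, 0, 0, 0))  # 1
# ...while stepping one channel along HeadDim jumps a full page.
print(v_col_offset(shape, 0, 0, 1, 0) - v_col_offset(shape, 0, 0, 0, 0))  # 16
```

Because PageSize is the contiguous dimension, vector loads of V naturally run along the token axis within one page.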

Changes:
- block_attention_kvcache_layout_enum.hpp: add VEC_K_COL_V_LAYOUT = 2.
- fmha_batch_prefill_kernel.hpp: extend the vectorized K dram branch to
  cover VEC_K_COL_V (K layout identical); add a new V dram branch that
  builds the (NumPages, HeadDim, PageSize) view with strides
  (batch_stride_v, page_block_size, 1) and merges to logical
  (D, TotalSeqK). stride_k_for_pipeline covers both vectorized layouts;
  stride_v_for_pipeline routes through kargs.stride_v (= 1 from the
  wrapper) for VEC_K_COL_V via the LINEAR else branch.
- block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp: kAlignmentV
  keeps Policy::GetAlignmentV<Problem>() for VEC_K_COL_V despite
  kPadSeqLenK=true. Pages are always fully populated, so vec loads along
  the contiguous PageSize never cross page boundaries.
- block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp:
  add kUseVectorizedVPolicy<Problem>() predicate and route all V-side
  specializations (GetAlignmentV, GetSmemKPackV,
  GetSingleSmemElementSpaceSize, MakeVLdsBlockDescriptor,
  MakeVDramTileDistribution) through it; VEC_K_COL_V shares the
  VECTORIZED V tile distribution / LDS layout / SmemKPack / alignment.
- block_fmha_pipeline_problem.hpp: introduce kIsKVectorized predicate;
  relax IsVLayoutRowMajor static_assert to accept VEC_K_COL_V_LAYOUT
  (the only layout in which V is ColumnMajor).
- tile_fmha_traits.hpp: extend the batch-prefill KV layout static_assert
  to accept VEC_K_COL_V_LAYOUT.
- fmha_fwd.hpp (example): add is_v_rowmajor field (default true) to
  fmha_batch_prefill_args so the auto-generated dispatcher can pick a
  ColumnMajor V kernel variant when the wrapper requests one.
- codegen/ops/fmha_batch_prefill.py: emit fp8bf16 PER_TOKEN_HEAD
  vlayout="col" variants gated to kv_memory_layout="vec_k_col_v" only;
  relax the receipt 200 filter so vlayout="col" passes through for
  vec_k_col_v only.
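Here is a minimal sketch of the paged V view described for fmha_batch_prefill_kernel.hpp, together with the alignment argument from the pipeline change. It rests on three assumptions: the head offset is folded into a base offset, batch_stride_v equals NumHeads * HeadDim * PageSize for a contiguous cache, and page_table lists one physical block index per logical page. All helper names are hypothetical:

```python
def v_view_offset(d, s, page_table, batch_stride_v, page_block_size, head_offset=0):
    """Offset of logical V[d, s] (d in HeadDim, s = token index) through the
    (NumPages, HeadDim, PageSize) view with strides (batch_stride_v, page_block_size, 1)."""
    page = page_table[s // page_block_size]  # logical page -> physical block
    within = s % page_block_size             # token position inside the page
    return head_offset + page * batch_stride_v + d * page_block_size + within


def vector_load_stays_in_page(s, vec, page_block_size):
    """True if a vec-wide load along PageSize starting at token s stays in one page."""
    return s // page_block_size == (s + vec - 1) // page_block_size


# Cross-check against a dense [NumBlocks, NumHeads, HeadDim, PageSize] cache.
num_heads, head_dim, page_size, head = 8, 128, 16, 2
batch_stride_v = num_heads * head_dim * page_size  # assumed page stride
head_offset = head * head_dim * page_size          # assumed per-head base
page_table = [3, 1]                                # 32 tokens on blocks 3, 1
for s in range(32):
    for d in range(head_dim):
        block = page_table[s // page_size]
        dense = ((block * num_heads + head) * head_dim + d) * page_size + s % page_size
        assert v_view_offset(d, s, page_table, batch_stride_v, page_size, head_offset) == dense

# With PageSize a multiple of the vector width, aligned loads never straddle pages.
print(all(vector_load_stays_in_page(s, 8, page_size) for s in range(0, 32, 8)))  # True
print(vector_load_stays_in_page(8, 8, 12))  # False: 12 is not a multiple of 8
```

The last call shows the failure mode that the alignment argument rules out. If PageSize were not a multiple of the vector width, an aligned load could straddle two physical pages that need not be adjacent in memory.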
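The two gating rules in the list can be sketched as plain predicates: the relaxed problem-level static_assert and the codegen filter for vlayout="col". VEC_K_COL_V_LAYOUT = 2 comes from the enum change. The other layout value, the option lists, and the helper names are hypothetical placeholders, not the generator's actual data structures:

```python
from itertools import product

VEC_K_COL_V_LAYOUT = 2  # value from block_attention_kvcache_layout_enum.hpp


def v_layout_allowed(is_v_rowmajor, kv_layout):
    """Relaxed problem-level check: V must be row-major unless the KV cache
    uses VEC_K_COL_V_LAYOUT, the only layout with a ColumnMajor V."""
    return is_v_rowmajor or kv_layout == VEC_K_COL_V_LAYOUT


def emit_col_variant(dtype, quant, kv_memory_layout):
    """Codegen-side gate for vlayout="col" (hypothetical helper name)."""
    return (dtype == "fp8bf16" and quant == "PER_TOKEN_HEAD"
            and kv_memory_layout == "vec_k_col_v")


# Hypothetical option lists; the real generator enumerates its own traits.
dtypes = ["bf16", "fp8bf16"]
quants = ["NO_SCALE", "PER_TOKEN_HEAD"]
layouts = ["linear", "vectorized", "vec_k_col_v"]
col_variants = [c for c in product(dtypes, quants, layouts) if emit_col_variant(*c)]
print(col_variants)  # [('fp8bf16', 'PER_TOKEN_HEAD', 'vec_k_col_v')]
```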

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-29 15:33:28 +00:00