[rocm-libraries] ROCm/rocm-libraries#6479 (commit 0705c2d)

CK][fmha] Add StreamLLM sink support to batch_prefill
 pipeline (#6479)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

The existing paged-KV attention pipelines (pagedkv, splitkv) support
  StreamLLM-style sink tokens — a fixed set of initial tokens kept in
  attention alongside the sliding window. The `batch_prefill` pipeline
  (chunked-prefill with VLLM-style block tables) previously hardcoded
  `kHasSink = false`, making it incompatible with sink-based attention
  patterns in LLM serving scenarios.

  This PR extends `batch_prefill` to support `kHasSink` and wires it
into `fmha_fwd_runner` for validation against the existing CPU
reference.

## Technical Details

 **Pipeline** (`block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp`):
- When `kHasSink`, the K/V loop splits into a sink phase [0,
sink_seq_end)
and a window phase [seqlen_k_start, seqlen_k_end), mirroring pagedkv.
  - K advance at the sink→window transition jumps
    `seqlen_k_start - sink_seq_end + kN0` to bridge the gap.
- V scatter-gather offsets are re-initialized at the transition to fix a
window mismatch bug: V was lagging kN0 behind K after the large jump,
    loading from the wrong sequence position.
- Bias window, dropout seq_offset, and mask type (LogitsSinkMask)
updated
    for sink-awareness.

**Traits / codegen** (`tile_fmha_traits.hpp`, `fmha_fwd.hpp`,
`fmha_batch_prefill.py`):
- `TileFmhaBatchPrefillTraits` gains `kHasSink_` (was hardcoded
`false`).
- Codegen adds `F_sink` field; skips batch-mode kernels (group mode
required).
  - CMake test filter broadened from 9 → 33 instances covering
    fp16/bf16 × mask/nmask × lse/nlse × sink/nsink.

  **Runner** (`fmha_fwd_runner.hpp`, `CMakeLists.txt`):
  - `fmha_batch_prefill()` dispatched from `run_fwd` when:
    group mode + paged KV + num_splits == 1.
- K/V strides corrected for runner's [num_pages, nhead_k,
page_block_size, hdim] layout.
  - `page_block_size % 128` check relaxed: batch_prefill supports ps=16.
  - CPU reference paged-KV reordering guards extended with
    `CK_TILE_FMHA_FWD_BATCH_PREFILL_API`.

## Test Plan

Build with `-DFMHA_FWD_ENABLE_APIS="fwd;batch_prefill"`, run
  `tile_example_fmha_fwd` in group mode with page_block_size=16.

  Test matrix:
  - Mask: no-mask, causal, sliding window
  - Sink: nsink, sink=1..128
  - dtype: fp16, bf16
  - LSE output: on/off
  - seqlen ∈ {512,1024,2048,4096} × window ∈ {32,256,512,1024}
  - GQA, chunked prefill, large batch×seqlen
  - page_block_size: 16, 32

## Test Result

171 test cases, all valid:y:
  - nmask + nsink: ✓
  - causal + nsink: ✓
  - causal + sink=8: ✓
  - sliding window + sink=8 (d=128, d=256): ✓
  - bf16, LSE output, GQA: ✓

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Linjun-AMD
2026-04-21 11:05:12 +00:00
committed by assistant-librarian[bot]
parent b75afb4274
commit d22aafb48b
7 changed files with 261 additions and 59 deletions

View File

@@ -10,7 +10,7 @@ if(NOT INST_TARGETS)
endif()
# validate user-specified fmha_fwd API list
set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill")
set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill;batch_prefill")
set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
"semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
if(BUILD_TESTING)
@@ -48,7 +48,6 @@ set(FMHA_FWD_CODE_GEN_COMMON_ARGS
--targets ${FMHA_TARGETS_ARG}
--api ${FMHA_FWD_APIS}
--optdim 32,64,80,128,256
# --filter fmha_fwd...
)
set(FMHA_BWD_CODE_GEN_COMMON_ARGS
${CMAKE_CURRENT_LIST_DIR}/generate.py
@@ -174,6 +173,13 @@ else()
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0)
endif()
# conditionally enable call to the batch_prefill API in fmha_fwd example and tests
if("batch_prefill" IN_LIST FMHA_FWD_ENABLE_APIS)
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_BATCH_PREFILL_API=1)
else()
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_BATCH_PREFILL_API=0)
endif()
# conditionally specify the use of OCP_FP8
if(CK_USE_OCP_FP8)
list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)

View File

@@ -84,6 +84,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad},
{F_qscale},
{F_occupancy},
false,
{F_sink},
{F_page_size},
{F_kv_memory_layout},
{F_kv_lookup_table}>;
@@ -124,7 +125,7 @@ using fmha_kernel_{F_idx} =
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
#include <iostream>
@@ -201,9 +202,9 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v
}}
"""
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) &&
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.has_sink == {F_sink}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
return fmha_batch_prefill_<trait_>(s, a);
}}
"""
@@ -247,6 +248,7 @@ class FmhaFwdApiTrait:
skpad: str
dpad: str
dvpad: str
sink: str # t/f
constraint: CppConstraint
kv_memory_layout: str
kv_lookup_table: str
@@ -343,6 +345,7 @@ class FmhaFwdPipeline:
F_dropout: str #
F_qscale: str # no/pertensor
F_mask: str # value from MASK_MAP
F_sink: str # t/f (StreamLLM sink tokens)
F_kv_memory_layout: str #
F_kv_lookup_table: str #
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
@@ -406,6 +409,11 @@ class FmhaFwdPipeline:
else:
n += "_nqscale"
if self.F_sink == "t":
n += "_sink"
else:
n += "_nsink"
n += "_" + self.F_kv_memory_layout + "_" + self.F_kv_lookup_table
return n
@@ -472,6 +480,7 @@ class FmhaFwdApiPool:
trait.kv_lookup_table
],
F_page_size=trait.page_size,
F_sink=BOOL_MAP[trait.sink],
)
if_j = "if" if j == 0 else "else if"
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
@@ -578,6 +587,7 @@ class FmhaFwdKernel:
F_mode=MODE_MAP[self.F_mode],
F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
F_page_size=self.F_page_size,
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
)
@property
@@ -617,6 +627,7 @@ class FmhaFwdKernel:
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
sink=self.F_pipeline.F_sink,
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
kv_memory_layout=self.F_pipeline.F_kv_memory_layout,
kv_lookup_table=self.F_pipeline.F_kv_lookup_table,
@@ -655,6 +666,7 @@ class KernelComponentFactory:
bias,
lse,
dropout,
sink,
kv_memory_layout,
kv_lookup_table,
) in itertools.product(
@@ -663,12 +675,13 @@ class KernelComponentFactory:
BIAS_MAP.keys(),
["t", "f"],
["t", "f"],
["t", "f"],
SUPPORTED_KV_MEMORY_LAYOUT,
SUPPORTED_KV_LOOKUP_TABLE,
):
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, sink, kv_memory_layout, kv_lookup_table)) # fmt: skip
elif dtype in ["fp8bf16"]:
# no need lse/dropout kernels
# no need lse/dropout/sink kernels
for (
logits,
qscale,
@@ -684,7 +697,7 @@ class KernelComponentFactory:
SUPPORTED_KV_MEMORY_LAYOUT,
SUPPORTED_KV_LOOKUP_TABLE,
):
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", kv_memory_layout, kv_lookup_table)) # fmt: skip
else:
assert False
return pipelines
@@ -701,20 +714,34 @@ class CustomFactory(KernelComponentFactory):
def get_fwd_blobs(
kernel_filter: Optional[str], receipt, optdim_list, mask_impl
kernel_filter: Optional[str], receipt, optdim_list, mask_impl,
targets: Optional[List[str]] = None
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# batch_prefill pipeline uses gfx9-specific async scatter-gather buffer addressing
# (amd_buffer_addressing.hpp raw buffer loads) that is not compatible with
# non-gfx9 architectures (gfx11/gfx12/gfx10 are wave32 and use different
# buffer instruction formats). Skip all batch_prefill kernels for non-gfx9 targets.
has_non_gfx9 = targets is not None and any(
not t.startswith("gfx9") for t in targets
)
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
gen = list()
api_pool = FmhaFwdApiPool(mask_impl)
if has_non_gfx9:
return api_pool, gen
for dtype in FWD_DTYPE_MAP.keys():
d = CustomFactory.get_hdim_tile_size_dict(dtype)
if d is None:
continue
# for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for (hdim, tiles), mode in itertools.product(d.items(), MODE_MAP.keys()):
# batch_prefill pipeline requires group mode (static_assert in pipeline problem)
if mode != "group":
continue
for tile, pipeline in itertools.product(
tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)
):
@@ -829,7 +856,7 @@ def write_blobs(
optdim_list,
mask_impl,
) -> None:
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets)
for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir)
write_fwd_api(api_pool, output_dir)
@@ -844,7 +871,7 @@ def list_blobs(
mask_impl,
) -> None:
with file_path.open("a") as f:
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets)
for kernel in kernels:
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n")

View File

@@ -1452,6 +1452,7 @@ template <ck_tile::index_t HDim_,
bool kPadDv_,
bool kUseTrLoad_,
bool kSkipMinSeqlenQ_ = false,
bool kHasSink_ = false,
ck_tile::index_t kPageBlockSize_ = 1,
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
@@ -1480,7 +1481,7 @@ struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_<HDim_,
kPadDv_,
kUseTrLoad_,
kSkipMinSeqlenQ_,
false>
kHasSink_>
{
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
static constexpr auto kKVLookupTable = kKVLookupTable_;

View File

@@ -387,7 +387,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
}
#if(!(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || \
CK_TILE_FMHA_FWD_PAGEDKV_API))
CK_TILE_FMHA_FWD_PAGEDKV_API || CK_TILE_FMHA_FWD_BATCH_PREFILL_API))
if(0 < page_block_size)
{
std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option"
@@ -395,7 +395,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
page_block_size = 0;
}
#endif
if(!(page_block_size % 128 == 0))
// batch_prefill supports flexible page sizes (not just multiples of 128)
const bool need_128_aligned_page =
(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API ||
CK_TILE_FMHA_FWD_PAGEDKV_API);
if(need_128_aligned_page && 0 < page_block_size && !(page_block_size % 128 == 0))
{
std::cerr << "only paged-kvcache block size divisible by 128 are currently supported"
<< std::endl;
@@ -972,9 +976,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
ck_tile::DeviceMem seqlen_q_buf(has_group_q_padding ? seqlen_qs.size() * sizeof(int32_t) : 0);
// Buffers for key/value per-sequence logical (unpadded) lengths (used in batch mode with
// kvcache or group mode with padding enabled)
ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || has_group_k_padding
? seqlen_ks.size() * sizeof(int32_t)
: 0);
// batch_prefill (group+kvcache) also needs per-batch seqlen_k for VLLM_BLOCK_TABLE_2D
const bool need_seqlen_k_buf = (mode == mode_enum::batch && use_kvcache) ||
has_group_k_padding || (mode == mode_enum::group && use_kvcache);
ck_tile::DeviceMem seqlen_k_buf(need_seqlen_k_buf ? seqlen_ks.size() * sizeof(int32_t) : 0);
ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0
: cuq_cum.size() * sizeof(ck_tile::index_t));
ck_tile::DeviceMem cu_seqlen_kv_buf(
@@ -1013,9 +1018,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
cu_seqlen_q_buf.ToDevice(cuq_cum.empty() ? nullptr : cuq_cum.data());
cu_seqlen_kv_buf.ToDevice(cukv_cum.empty() ? nullptr : cukv_cum.data());
seqlen_q_buf.ToDevice(has_group_q_padding ? seqlen_qs.data() : nullptr);
seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || has_group_k_padding
? seqlen_ks.data()
: nullptr);
seqlen_k_buf.ToDevice(need_seqlen_k_buf ? seqlen_ks.data() : nullptr);
cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
rotary_cos_buf.ToDevice(rotary_cos_host.data());
rotary_sin_buf.ToDevice(rotary_sin_host.data());
@@ -1146,6 +1149,17 @@ fwd_result fmha_fwd_run(mode_enum mode,
{
traits.use_pagedkv = (0 < page_block_size);
}
else if constexpr(std::is_same_v<fmha_batch_prefill_traits,
std::decay_t<decltype(traits)>>)
{
traits.has_dropout = (p_drop > 0.0f);
traits.qscale_type = qscale.type;
traits.kv_memory_layout =
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT;
traits.kv_lookup_table =
ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D;
traits.page_size = page_block_size;
}
}
};
@@ -1498,6 +1512,67 @@ fwd_result fmha_fwd_run(mode_enum mode,
? seqlen_k_buf.GetDeviceBuffer()
: nullptr);
}
else if constexpr(std::is_same_v<fmha_batch_prefill_args, std::decay_t<decltype(args)>>)
{
// Fields already set by the outer else block above:
// bias_ptr, lse_ptr, o_ptr, seqlen_k, max_seqlen_q, scale_s,
// logits_soft_cap, stride_bias/o, nhead/batch stride for bias/lse/o,
// window_size_left/right, sink_size, mask_type.
// scale_p/scale_o: batch_prefill-specific fields absent from fmha_fwd_args.
args.scale_p = 1.f;
args.scale_o = 1.f;
// Dropout fields: the outer fmha_fwd_args branch sets these; set them here
// for batch_prefill since it takes a separate inner branch.
args.rand_val_ptr = randval_buf.GetDeviceBuffer();
args.stride_randval = stride_randval;
args.nhead_stride_randval = nhead_stride_randval;
args.batch_stride_randval = batch_stride_randval;
args.p_drop = p_drop;
args.s_randval = s_randval;
if(drop_prefs)
args.drop_seed_offset = std::make_pair(drop_seed_buf.GetDeviceBuffer(),
drop_offset_buf.GetDeviceBuffer());
else
args.drop_seed_offset = std::make_pair(drop_seed, drop_offset);
// Paged KV: LINEAR_LAYOUT + VLLM_BLOCK_TABLE_2D
// block_table_buf: [batch, max_blocks_per_seq] of physical page ids
// seqlen_k_buf: [batch] of per-batch seqlen_k values
args.num_total_pages = max_num_page_blocks;
args.page_block_size = page_block_size;
args.kv_memory_layout =
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT;
args.kv_lookup_table =
ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D;
args.kv_indptr = nullptr;
args.kv_page_indices = block_table_buf.GetDeviceBuffer();
args.kv_last_page_lens = nullptr;
args.seqlen_k_ptr = seqlen_k_buf.GetDeviceBuffer();
args.batch_stride_block_table = batch_stride_block_table;
// group mode required: seqstart_q is prefix-sum of per-batch seqlen_q
args.seqstart_q_ptr = seqstart_q_buf.GetDeviceBuffer();
// batch_prefill LINEAR_LAYOUT strides for runner's K layout
// [max_num_page_blocks, nhead_k, page_block_size, hdim]:
// stride_k = hdim_q (token stride within one head's page slice)
// nhead_stride_k = page_block_size * hdim_q (head stride)
// batch_stride_k = nhead_k * page_block_size * hdim_q (page stride, already set)
args.stride_k = hdim_q;
args.nhead_stride_k = page_block_size * hdim_q;
// V is row-major, same layout convention
args.stride_v = hdim_v;
args.nhead_stride_v = page_block_size * hdim_v;
// descale: not used for fp16/bf16
args.q_descale_ptr = nullptr;
args.k_descale_ptr = nullptr;
args.v_descale_ptr = nullptr;
args.nblock_stride_kv_block_descale = 0;
args.nhead_stride_kv_block_descale = 0;
}
}
};
@@ -1524,6 +1599,21 @@ fwd_result fmha_fwd_run(mode_enum mode,
}
auto run_fwd = [&](const ck_tile::stream_config& sc) {
#if CK_TILE_FMHA_FWD_BATCH_PREFILL_API
// batch_prefill: group mode + paged KV, tested against the same CPU reference
if(1 == num_splits && use_kvcache && mode == mode_enum::group)
{
fmha_batch_prefill_traits bp_traits;
init_traits(bp_traits);
fmha_batch_prefill_args bp_args;
init_args(bp_args);
const float ave_time = fmha_batch_prefill(bp_traits, bp_args, sc);
if(ave_time >= 0.0f)
return ave_time;
}
#endif // CK_TILE_FMHA_FWD_BATCH_PREFILL_API
#if CK_TILE_FMHA_FWD_PAGEDKV_API
if(1 == num_splits && use_kvcache)
{
@@ -1844,7 +1934,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); });
}
#endif
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API || \
CK_TILE_FMHA_FWD_BATCH_PREFILL_API
if(0 < page_block_size)
{
// clang-format off
@@ -1895,7 +1986,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
});
}
#endif
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API || \
CK_TILE_FMHA_FWD_BATCH_PREFILL_API
if(0 < page_block_size)
{
if(is_v_rowmajor)