sparse_attn: R25 Step 1 A1 — per-warp PV-skip (paper Algorithm 1) + V0 instantiation

Preserve the R25 Step 1 "A1 / redesign D" state before redesigning toward "B"
(per-CTA PV-skip matching upstream shipped reference). This snapshot lets us
restore A1 if the B redesign fails.

A1 redesign D pipeline (per-warp, arithmetic-only PV-skip, wrapped in
`if constexpr (kEnablePVSkip)`):
  - include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_sparge.hpp
  - include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_sparge_kernel.hpp

V0 instantiation wiring (per gino_tmp/R25/programmer/v0_instance/REPORT.md):
  - example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_sparge.py
  - example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp
  - example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp
  - example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp
  - example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py
  - example/ck_tile/50_sparse_attn/CMakeLists.txt
  - example/ck_tile/01_fmha/CMakeLists.txt
  - example/ck_tile/50_sparse_attn/test_sparge.cpp (-pv_skip_compile=0|1 CLI)

This commit excludes all *_REVIEW.{hpp,cpp} mirror files (left untracked) and
all build artefacts. _vsa.hpp / _jenga.hpp are not modified.

Tag: R25-step1-A1-paper-aligned points at this commit.
This commit is contained in:
Gino Lu
2026-05-18 06:13:38 -04:00
parent 840b8a37d9
commit 0f8b58ac88
10 changed files with 2448 additions and 4 deletions

View File

@@ -200,6 +200,7 @@ message(DEBUG "adding example ${EXAMPLE_FMHA_FWD}")
add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL example_fmha_fwd.cpp)
target_link_libraries(${EXAMPLE_FMHA_FWD} ${FMHA_FWD_INSTANCES})
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
set_property(TARGET ${EXAMPLE_FMHA_FWD} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}")
# not using add_example_executable() to add this target, since we don't want this to be included in
@@ -207,6 +208,7 @@ message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}")
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp)
target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES})
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
set_property(TARGET ${EXAMPLE_FMHA_BWD} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long

View File

@@ -83,6 +83,7 @@ message(DEBUG "adding example ${EXAMPLE_JENGA_SPARSE_ATTN}")
add_executable(${EXAMPLE_JENGA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_jenga_sparse_attn.cpp)
target_link_libraries(${EXAMPLE_JENGA_SPARSE_ATTN} ${SPARSE_ATTN_JENGA_INSTANCES})
target_include_directories(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
set_property(TARGET ${EXAMPLE_JENGA_SPARSE_ATTN} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
target_compile_options(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE
-Wno-undefined-func-template
-Wno-float-equal
@@ -148,11 +149,64 @@ message(DEBUG "adding example ${EXAMPLE_VSA_SPARSE_ATTN}")
add_executable(${EXAMPLE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_vsa_sparse_attn.cpp)
target_link_libraries(${EXAMPLE_VSA_SPARSE_ATTN} ${SPARSE_ATTN_VSA_INSTANCES})
target_include_directories(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
set_property(TARGET ${EXAMPLE_VSA_SPARSE_ATTN} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
target_compile_options(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE
-Wno-undefined-func-template
-Wno-float-equal
)
# ============================================================================
# Sparge Sparse Attention (PV-skip enabled, derived from VSA)
# ============================================================================
set(SPARSE_ATTN_SPARGE_CODE_GEN_ARGS
${CMAKE_CURRENT_LIST_DIR}/generate.py
--api fwd_sparge
--receipt 600
)
# Generate list of Sparge kernels (at configure time, only list)
execute_process(
COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_SPARGE_CODE_GEN_ARGS}
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sparge_blob_list.txt
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message(FATAL_ERROR "Failed to generate Sparge kernel list")
endif()
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sparge_blob_list.txt SPARSE_ATTN_SPARGE_GEN_BLOBS)
# Generate Sparge kernel source files at build time
add_custom_command(
OUTPUT ${SPARSE_ATTN_SPARGE_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_SPARGE_CODE_GEN_ARGS}
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${CODE_GEN_SCRIPTS}
COMMENT "Generate CK Tile Sparge Sparse Attention kernels"
)
message(STATUS "Sparge kernel files to be generated: ${SPARSE_ATTN_SPARGE_GEN_BLOBS}")
# Sparge Instances
set(SPARSE_ATTN_SPARGE_INSTANCES "tile_sparse_attn_sparge_instances")
add_library(${SPARSE_ATTN_SPARGE_INSTANCES} OBJECT EXCLUDE_FROM_ALL
${SPARSE_ATTN_SPARGE_GEN_BLOBS}
)
target_include_directories(${SPARSE_ATTN_SPARGE_INSTANCES} PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn
)
set_source_files_properties(${SPARSE_ATTN_SPARGE_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
set_property(TARGET ${SPARSE_ATTN_SPARGE_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
target_compile_options(${SPARSE_ATTN_SPARGE_INSTANCES} PRIVATE
-DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
-DCK_TILE_FMHA_FWD_FAST_EXP2
-Wno-undefined-func-template
-Wno-float-equal
)
# ============================================================================
# Sparge BlockMap GPU Kernel (hand-written instantiation, no codegen)
# ============================================================================
@@ -188,9 +242,11 @@ add_executable(${EXAMPLE_SPARGE} EXCLUDE_FROM_ALL test_sparge.cpp)
target_link_libraries(${EXAMPLE_SPARGE}
${SPARSE_ATTN_JENGA_INSTANCES}
${SPARSE_ATTN_VSA_INSTANCES}
${SPARSE_ATTN_SPARGE_INSTANCES}
${SPARGE_BLOCKMAP_INSTANCES}
)
target_include_directories(${EXAMPLE_SPARGE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
set_property(TARGET ${EXAMPLE_SPARGE} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
target_compile_options(${EXAMPLE_SPARGE} PRIVATE
-Wno-undefined-func-template
-Wno-float-equal

View File

@@ -58,11 +58,13 @@ LAYOUT_MAP = {"row": "true", "col": "false"}
PIPELINE_MAP = {
"qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsyncJenga",
"qr_async_vsa": "ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA",
"qr_async_sparge": "ck_tile::BlockFmhaPipelineQRKSVSAsyncSparge",
}
PIPELINE_ENUM_MAP = {
"qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
"qr_async_vsa": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
"qr_async_sparge": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
}
BOOL_MAP = {

File diff suppressed because it is too large Load Diff

View File

@@ -334,3 +334,141 @@ template <typename Traits_>
void fmha_vsa_fwd_oneshot_(const ck_tile::stream_config&, fmha_vsa_fwd_args);
void fmha_vsa_fwd_oneshot(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&);
// sparge: same args as vsa plus a scalar PV-skip threshold (Step 1).
struct fmha_sparge_fwd_args
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* lut_ptr; // delta-encoded K-block indices per Q-block, int32 [B,H,Q_blk,K_blk]
const void* valid_block_num_ptr; // valid K-block count per Q-block, int32 [B,H,Q_blk]
void* o_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
float scale_s;
float pv_threshold; // SpargeAttn §4.4 PV-skip per-Q-tile threshold
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_o;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
// R25 V0: select between kEnablePVSkip=true / =false template instantiations
// at host dispatch time. Default true preserves existing behaviour (binary
// shipped pre-R25-V0 only had the true instance). Profiler can flip this to
// false to measure the source-equivalent-to-VSA baseline (`if constexpr`
// removes the entire PV-skip AST).
bool pv_skip_compile = true;
};
template <typename FmhaKernel>
auto fmha_fwd_create_kargs_and_grids(fmha_sparge_fwd_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = FmhaKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.lut_ptr,
args.valid_block_num_ptr,
args.o_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale_s,
args.pv_threshold,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_o,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_o,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_o,
args.window_size_left,
args.window_size_right,
args.mask_type);
dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
return ck_tile::make_tuple(kargs, grids);
}
template <ck_tile::index_t HDim_,
typename DataType_,
ck_tile::index_t kM0_,
ck_tile::index_t kN0_,
ck_tile::index_t kK0_,
ck_tile::index_t kN1_,
ck_tile::index_t kK1_,
ck_tile::index_t kK0BlockLength_,
bool kIsVLayoutRowMajor_,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
bool kHasLogitsSoftCap_,
typename FmhaMask_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_,
bool kUseTrLoad_>
using fmha_sparge_fwd_traits_ = fmha_jenga_fwd_traits_<HDim_,
DataType_,
kM0_,
kN0_,
kK0_,
kN1_,
kK1_,
kK0BlockLength_,
kIsVLayoutRowMajor_,
FmhaPipelineEnum_,
kHasLogitsSoftCap_,
FmhaMask_,
kPadS_,
kPadSK_,
kPadD_,
kPadDv_,
kUseTrLoad_>;
using fmha_sparge_fwd_traits = fmha_jenga_fwd_traits;
float fmha_sparge_fwd(fmha_sparge_fwd_traits, fmha_sparge_fwd_args, const ck_tile::stream_config&);
// R25 V0: kEnablePVSkip is now a template non-type param so the codegen can
// emit both true / false instantiations from the same source tree. The host
// dispatch (fmha_sparge_fwd_api.cpp) selects the right specialization based
// on fmha_sparge_fwd_args::pv_skip_compile at runtime.
template <typename Traits_, bool kEnablePVSkip>
float fmha_sparge_fwd_(const ck_tile::stream_config&, fmha_sparge_fwd_args);
template <typename Traits_, bool kEnablePVSkip>
void fmha_sparge_fwd_oneshot_(const ck_tile::stream_config&, fmha_sparge_fwd_args);
void fmha_sparge_fwd_oneshot(fmha_sparge_fwd_traits,
fmha_sparge_fwd_args,
const ck_tile::stream_config&);

View File

@@ -264,3 +264,24 @@ float sparge_vsa_fwd_combined(sparge_blockmap_traits bmap_t,
},
[=](const ck_tile::stream_config& s_) { fmha_vsa_fwd_oneshot(attn_t, attn_a, s_); });
}
float sparge_sparge_fwd_combined(sparge_blockmap_traits bmap_t,
sparge_blockmap_args bmap_a,
fmha_sparge_fwd_traits attn_t,
fmha_sparge_fwd_args attn_a,
const ck_tile::stream_config& s)
{
if(s.log_level_ > 0)
std::cout << ", sparge_kstats_" << bmap_t.data_type << "_d" << bmap_t.hdim_q
<< ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q
<< ", fmha_sparge_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q
<< std::flush;
return ck_tile::launch_kernel(
s,
[=](const ck_tile::stream_config& s_) { sparge_kstats_fwd_oneshot(bmap_t, bmap_a, s_); },
[=](const ck_tile::stream_config& s_) {
sparge_blockmap_only_fwd_oneshot(bmap_t, bmap_a, s_);
},
[=](const ck_tile::stream_config& s_) { fmha_sparge_fwd_oneshot(attn_t, attn_a, s_); });
}

View File

@@ -169,3 +169,9 @@ float sparge_vsa_fwd_combined(sparge_blockmap_traits,
fmha_vsa_fwd_traits,
fmha_vsa_fwd_args,
const ck_tile::stream_config&);
float sparge_sparge_fwd_combined(sparge_blockmap_traits,
sparge_blockmap_args,
fmha_sparge_fwd_traits,
fmha_sparge_fwd_args,
const ck_tile::stream_config&);

View File

@@ -5,6 +5,8 @@
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <random>
@@ -100,7 +102,19 @@ auto create_args(int argc, char* argv[])
.insert("seed", "42", "random seed")
.insert("warmup", "5", "warmup iterations")
.insert("repeat", "20", "benchmark iterations")
.insert("kname", "0", "print kernel name");
.insert("kname", "0", "print kernel name")
.insert("dump_o",
"",
"if non-empty, dump raw output buffer bytes to this path (for bit-identical "
"baseline comparison)")
.insert("pv_threshold",
"1e30",
"SpargeAttn PV-skip per-Q-tile threshold; default +1e30 disables skip")
.insert("pv_skip_compile",
"1",
"R25 V0: 1=use kEnablePVSkip=true template instance (existing path); 0=use "
"kEnablePVSkip=false instance (PV-skip AST removed at compile time, equivalent to "
"VSA baseline)");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -130,6 +144,9 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int kname = arg_parser.get_int("kname");
std::string dump_o_path = arg_parser.get_str("dump_o");
float pv_threshold = arg_parser.get_float("pv_threshold");
int pv_skip_compile = arg_parser.get_int("pv_skip_compile");
if(nhead_k < 0)
nhead_k = nhead;
@@ -309,7 +326,9 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
}
else if(pipeline == "vsa")
{
fmha_vsa_fwd_traits attn_traits;
// R25: -pipeline=vsa now dispatches to the sparge pipeline family that adds
// SpargeAttn §4.4 PV-skip; pass pv_threshold (+1e30 disables skip, matches old vsa).
fmha_sparge_fwd_traits attn_traits;
attn_traits.hdim_q = hdim_q;
attn_traits.hdim_v = hdim_v;
attn_traits.data_type = std::is_same_v<T, ck_tile::half_t> ? "fp16" : "bf16";
@@ -317,7 +336,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
attn_traits.mask_type = mask_enum::no_mask;
attn_traits.bm0 = BLKQ;
fmha_vsa_fwd_args attn_args;
fmha_sparge_fwd_args attn_args;
attn_args.q_ptr = q_dev.GetDeviceBuffer();
attn_args.k_ptr = k_dev.GetDeviceBuffer();
attn_args.v_ptr = v_dev.GetDeviceBuffer();
@@ -333,6 +352,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
attn_args.nhead_q = nhead;
attn_args.nhead_k = nhead_k;
attn_args.scale_s = scale_s;
attn_args.pv_threshold = pv_threshold;
attn_args.pv_skip_compile = (pv_skip_compile != 0);
attn_args.stride_q = q_strides[i_perm ? 2 : 1];
attn_args.stride_k = k_strides[i_perm ? 2 : 1];
attn_args.stride_v = v_strides[i_perm ? 2 : 1];
@@ -350,7 +371,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
attn_args.mask_type = 0;
avg_ms =
sparge_vsa_fwd_combined(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg);
sparge_sparge_fwd_combined(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg);
}
else
{
@@ -374,6 +395,23 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
o_dev.FromDevice(output_host.data());
block_map_dev.FromDevice(block_map_host.data());
// ---- optional raw output dump (for bit-identical baseline comparison) ----
if(!dump_o_path.empty())
{
std::ofstream ofs(dump_o_path, std::ios::binary | std::ios::trunc);
if(!ofs)
{
std::cerr << "\n [dump_o] failed to open " << dump_o_path << std::endl;
}
else
{
ofs.write(reinterpret_cast<const char*>(output_host.data()),
static_cast<std::streamsize>(output_host.get_element_space_size_in_bytes()));
std::cout << "\n [dump_o] wrote " << output_host.get_element_space_size_in_bytes()
<< " bytes to " << dump_o_path;
}
}
// ---- count active blocks ----
ck_tile::index_t total_blocks = batch * nhead * num_q_blocks * num_k_blocks;
ck_tile::index_t active_blocks = 0;

View File

@@ -0,0 +1,442 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
namespace ck_tile {
template <typename FmhaPipeline_, typename EpiloguePipeline_, bool kEnablePVSkip_ = true>
struct FmhaFwdSpargeKernel
{
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0);
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
static constexpr bool kEnablePVSkip = kEnablePVSkip_;
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
using RandValOutputDataType =
ck_tile::remove_cvref_t<typename FmhaPipeline::RandValOutputDataType>;
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
static constexpr bool kDoFp8StaticQuant =
(QScaleEnum != ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE);
static_assert(!FmhaPipeline::kIsGroupMode, "Sparge sparse attention supports batch mode only.");
static_assert(BiasEnum == BlockAttentionBiasEnum::NO_BIAS,
"Sparge sparse attention does not support bias.");
static_assert(!kStoreLSE, "Sparge sparse attention does not support LSE output.");
static_assert(!kHasDropout, "Sparge sparse attention does not support dropout.");
static_assert(!kHasLogitsSoftCap, "Sparge sparse attention does not support logits soft-cap.");
static_assert(!kDoFp8StaticQuant,
"Sparge sparse attention does not support FP8 static quantization yet.");
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
// arg
struct FmhaFwdEmptyKargs
{
};
// kargs use aggregate initializer, so no constructor will provided
// use inheritance to minimize karg size
// user need to use MakeKargs() function to create kargs.
struct FmhaFwdCommonKargs
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* lut_ptr;
const void* valid_block_num_ptr;
void* o_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t num_head_q;
// for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
// if this param is larger than 1, indicate MQA/GQA case
ck_tile::index_t nhead_ratio_qk;
float scale_s;
float pv_threshold;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_o;
};
struct FmhaFwdMaskKargs
{
ck_tile::index_t window_size_left, window_size_right;
ck_tile::GenericAttentionMaskEnum mask_type;
};
struct FmhaFwdBatchModeKargs
: FmhaFwdCommonKargs,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>
{
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_o;
};
using Kargs = FmhaFwdBatchModeKargs;
struct BlockIndices
{
ck_tile::index_t batch_idx;
ck_tile::index_t qo_head_idx;
ck_tile::index_t kv_head_idx;
};
// std::variant<> can't take in a list initializer, overload for backward compatibility
CK_TILE_HOST static constexpr Kargs MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* lut_ptr,
const void* valid_block_num_ptr,
void* o_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float pv_threshold,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type)
{
Kargs kargs{{q_ptr,
k_ptr,
v_ptr,
lut_ptr,
valid_block_num_ptr,
o_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
scale_s,
#endif
pv_threshold,
stride_q,
stride_k,
stride_v,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_o}, // FmhaFwdCommonKargs
{}, // FmhaFwdMaskKargs or FmhaFwdEmptyKargs<1>
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_o};
if constexpr(kHasMask)
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
return kargs;
}
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
{
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
nhead_,
batch_size_);
}
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
{
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return ck_tile::make_tuple(quotient, modulus);
};
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
if constexpr(kHasMask)
{
return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
}
else
{
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
// divide problem
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
long_index_t batch_offset_q = 0;
long_index_t batch_offset_k = 0;
long_index_t batch_offset_v = 0;
long_index_t batch_offset_o = 0;
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
// for simplicity, batch stride we just modify the pointer
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
batch_offset_q;
const KDataType* k_ptr =
reinterpret_cast<const KDataType*>(kargs.k_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
batch_offset_k;
const VDataType* v_ptr =
reinterpret_cast<const VDataType*>(kargs.v_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
batch_offset_v;
// sparse mask
const int* lut_ptr =
reinterpret_cast<const int*>(kargs.lut_ptr) +
static_cast<long_index_t>(i_batch * kargs.num_head_q + i_nhead) *
ck_tile::integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0) +
i_tile_m * ck_tile::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0);
const int* valid_block_num_ptr =
reinterpret_cast<const int*>(kargs.valid_block_num_ptr) +
static_cast<long_index_t>(i_batch * kargs.num_head_q + i_nhead) *
ck_tile::integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) +
i_tile_m;
const int valid_block_num_value = valid_block_num_ptr[0];
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
batch_offset_o;
// Q/K/V DRAM and DRAM window
const auto q_dram = [&]() {
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
q_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_q, 1),
number<FmhaPipeline::kAlignmentQ>{},
number<1>{});
if constexpr(FmhaPipeline::kQLoadOnce)
{
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}
else
{
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}
}();
const auto k_dram = [&]() {
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
make_tuple(kargs.seqlen_k, kargs.hdim_q),
make_tuple(kargs.stride_k, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
}();
const auto v_dram = [&]() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.seqlen_k, kargs.hdim_v),
make_tuple(kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
const auto v_dram_transposed =
transform_tensor_view(v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.seqlen_k)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
return pad_tensor_view(
v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK_>{});
}
}();
auto q_dram_window = make_tile_window(
q_dram,
[&]() {
if constexpr(FmhaPipeline::kQLoadOnce)
return make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kSubQKHeaddim>{});
else
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
}(),
{i_m0, 0});
auto k_dram_window = make_tile_window(
k_dram, make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), {0, 0});
auto v_dram_window =
make_tile_window(v_dram,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
{i_n1, 0});
FmhaMask mask = [&]() {
if constexpr(kHasMask)
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
else
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
}();
AttentionVariant variant;
const auto variant_params = ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
auto o_acc_tile = FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
lut_ptr,
valid_block_num_value,
mask,
kargs.scale_s,
kargs.pv_threshold,
variant,
variant_params,
block_indices,
smem_ptr);
// O DRAM and O DRAM window
auto o_dram = [&]() {
const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
o_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.stride_o, 1),
number<FmhaPipeline::kAlignmentO>{},
number<1>{});
return pad_tensor_view(
o_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<kPadSeqLenQ, kPadHeadDimV>{});
}();
auto o_dram_window =
make_tile_window(o_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
{i_m0, i_n1});
EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,698 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
// Sparge variant of qr/ks/vs/async pipeline. Cloned from BlockFmhaPipelineQRKSVSAsyncVSA;
// adds PV-skip per Q-tile (SpargeAttn paper 4.4). Kept as a separate file so the original
// _vsa.hpp can remain frozen as an A/B baseline.
//
// QUANT-HOOK: future int8/sage variant will add QScaleEnum template arg + per-tile descale Kargs;
// _sparge_sage.hpp will live alongside this file and reuse the PV-skip path verbatim.
template <typename Problem_,
typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy,
bool kEnablePVSkip_ = true>
struct BlockFmhaPipelineQRKSVSAsyncSparge
{
static constexpr bool kEnablePVSkip = kEnablePVSkip_;
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
// only need special care about seq_k padding (oob need set -INF of p instead of zero)
static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true &&
Problem::kPadHeadDimV == true);
static constexpr bool kPadSeqLenQ = true;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static_assert(BiasEnum == BlockAttentionBiasEnum::NO_BIAS,
"VSA sparse attention does not support bias.");
static_assert(!kHasDropout, "VSA sparse attention does not support dropout.");
static_assert(!kStoreLSE, "VSA sparse attention does not support LSE output.");
static_assert(!kHasLogitsSoftCap, "VSA sparse attention does not support logits soft-cap.");
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>();
#if CK_TILE_FMHA_FWD_FAST_EXP2
static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
#endif
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
// minimize occupancy
if constexpr(kQKHeaddim <= 32)
{
if constexpr(kPadSeqLenK && FmhaMask::IsMasking)
return 1;
else
return 2;
}
else if constexpr(kQKHeaddim <= 64)
{
if constexpr(kPadSeqLenK)
return 2;
else
return 3;
}
else if constexpr(kQKHeaddim <= 128)
{
if constexpr(kPadSeqLenK)
return 1;
else
return 2;
}
else if constexpr(kQKHeaddim <= 192)
{
if constexpr(kPadSeqLenK)
return 1;
else
return 2;
}
else if constexpr(kQKHeaddim <= 256)
{
return 1;
}
else
{
return 1;
};
}
}();
static constexpr const char* name = "qr_async";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const int* kv_block_idx_ptr,
int kv_blocks,
FmhaMask mask,
float scale_s,
float pv_threshold, // SpargeAttn PV-skip threshold; see §2 of pv_skip plan
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr) const
{
if constexpr(!kEnablePVSkip)
{
(void)pv_threshold; // silence unused-param when PV-skip is compiled out
}
// R25 Step 1 redesign D: PV-skip control is a compile-time gate
// (kEnablePVSkip). The entire PV-skip logic block below is wrapped in
// `if constexpr (kEnablePVSkip)`, so when this template parameter is
// false the AST contains no vote, no scalar gate, no extra LDS, and
// codegen converges with _vsa.hpp's FmhaFwdVSAKernel.
//
// Runtime fast-path (C3-lite): pv_threshold == +1e30 sentinel disables
// the skip at runtime via one scalar branch (sgpr); kept inside the
// `if constexpr` so the OFF instantiation pays zero cost.
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>();
// K tile in LDS
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
auto k_lds_store = generate_tuple(
[&](auto i_buf) {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf)),
Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf).get_lengths(),
{0, 0, 0});
},
number<Policy::NumKVLdsBuffers>{});
auto k_lds_Load_view = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>());
auto k_lds_load =
make_tile_window(k_lds_Load_view,
Policy::template MakeKLdsLoadBlockDescriptor<Problem>().get_lengths(),
{0, 0});
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(smem_ptr),
Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
int seqlen_k_start = kv_block_idx_ptr[0] * kN0;
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQRegTileDistribution<Problem>());
q_dram_window.init_raw();
// TODO: we use async Copy for K, which is inline asm
// a side effect is we have to use inline asm for q as well
auto q = decltype(load_tile(q_dram_window)){};
// TODO: start from rocm-6.2, compiler will have problem if manually set clear of q.
// however, q would be cleared in the constructor of static distributed tensor
// set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
load_tile_raw(q, q_dram_window);
__builtin_amdgcn_sched_barrier(0);
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
auto s_acc = SaccBlockTileType{};
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
// infer Sacc, S, P, M, L, Oacc type
using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
// init Oacc, M, L
auto o_acc = OaccBlockTileType{};
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
__builtin_amdgcn_sched_barrier(0);
const auto q_origin = q_dram_window.get_window_origin();
const auto num_total_loop = kv_blocks;
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
{
if(num_total_loop <= 0)
{
buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0)
// otherwise will have compute error(maybe compiler bug?)
// Note: here occ are all cleard, return it
return o_acc;
}
__builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check
}
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0});
auto k_dram_window = make_tile_window(
k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
k_dram_window.init_raw();
constexpr auto k_oob_ck = bool_constant<true>{};
constexpr auto k_pre_np = bool_constant<false>{};
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
// prefetch K tile
async_load_tile_raw(
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0);
// buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer());
buffer_load_fence(k_dram_window.get_num_of_access());
index_t i_total_loops = 0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
static_assert(1 <= k0_loops);
static_assert(1 <= k1_loops);
// main loop
do
{
// STAGE 1, QK gemm
clear_tile(s_acc); // initialize C
if constexpr(k0_loops > 1)
{
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
k_dram_window,
number<-1>{},
k_oob_ck,
k_pre_np);
if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0});
async_load_fence(k_dram_window.get_num_of_access());
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
gemm_0(s_acc,
get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
get_slice_tile(k_lds_load,
sequence<(LdsSeq.at(number<i_k0>{})) * kN0, 0>{},
sequence<(LdsSeq.at(number<i_k0>{}) + 1) * kN0, kK0>{}));
});
}
// TODO: this to fix a bug when loop smaller than 2,
// the following fence/barrier will be scheduled inside 1st loop
if constexpr(k0_loops <= 2)
__builtin_amdgcn_sched_barrier(0);
async_load_fence();
__builtin_amdgcn_s_barrier();
int block_idx = kv_block_idx_ptr[i_total_loops + 1];
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
__builtin_amdgcn_sched_barrier(0);
{ // tail
gemm_0(
s_acc,
get_slice_tile(
q, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
get_slice_tile(k_lds_load,
sequence<(LdsSeq.at(number<k0_loops - 1>{})) * kN0, 0>{},
sequence<(LdsSeq.at(number<k0_loops - 1>{}) + 1) * kN0, kK0>{}));
}
__builtin_amdgcn_sched_barrier(1);
// STAGE 2, scale_s, mask, softmax (no bias/soft-cap)
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#endif
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
const auto k_origin = k_dram_block_window.get_window_origin();
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !variant.LogitsMask(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
}
}
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s,
sequence<1>{},
f_max,
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
s.get_tile_distribution()); // Pcompute{j}
__builtin_amdgcn_sched_barrier(0x7F);
// Ensure gemm_0's LDS reads (K tile) from all threads are completed before V store
// Only needed when K tail and V use the same LDS buffer
if constexpr(LdsSeq.at(number<k0_loops - 1>{}) == LdsSeq.at(number<k0_loops>{}))
{
__builtin_amdgcn_s_barrier();
}
// store & prefetch next v, after the max reduction.
// R25 Step 1 redesign D: V→LDS store and the next-V DRAM load are
// UNCONDITIONAL — per-warp PV-skip cannot gate them (cross-warp
// shared LDS state; see Researcher audit A.3/A.4/A.5).
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp, v_shuffle_tmp);
}
else
{
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp, v_buf);
}
if constexpr(k1_loops > 1)
{
move_tile_window(
v_dram_window,
{0, kK1}); // will have scratch if move this right after load_tile(v_dram)...
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
}
__builtin_amdgcn_sched_barrier(0);
// ================================================================
// PV-SKIP per Q-tile (SpargeAttn paper §4.4)
// R25 Step 1 redesign D — per-warp arithmetic-only:
// Compile-time `if constexpr (kEnablePVSkip)` wraps the entire
// block. When kEnablePVSkip=false the AST has zero PV-skip
// artifacts → codegen converges with _vsa.hpp.
//
// When enabled, a per-warp predicate gates ONLY the per-row,
// VGPR-private work (exp2 → p_compute, rowsum, `l += rowsum_p`).
// V load / V→LDS store / gemm_1 / every `s_barrier` /
// `block_sync_lds` stay unconditional (cross-warp LDS dep — see
// Researcher audit A.7).
//
// On warp_skip, this warp's owned rows of p_compute are zeroed
// so the unconditional gemm_1 contributes 0 to o_acc (audit
// A.7 "simplest realisation"). The alpha-rescale `l *= tmp` and
// `o *= tmp` still apply.
//
// pv_threshold semantics shift: now per-warp max diff (slightly
// more aggressive than per-block at the same threshold; matches
// upstream SpargeAttn `kPerWarp` mode default).
//
// Skip iff: scale_s * (m_local - m_old) + pv_threshold <= 0
// (where m_local/m_old are warp-uniform after block_tile_reduce_sync)
// ================================================================
// Per-warp PV-skip predicate. Only declared when kEnablePVSkip;
// wrapped in a lambda so the false instantiation contains nothing.
auto compute_warp_skip = [&]() {
if constexpr(kEnablePVSkip)
{
// C3-lite scalar fast-path: pv_threshold == +1e30 sentinel
// disables skip; runtime cost is a single sgpr branch.
if(pv_threshold >= 1e29f)
return false;
// Per-row predicate: warp-AND over rows this warp owns.
int warp_skip_int = 1;
constexpr auto m_spans = decltype(m_local)::get_distributed_spans();
sweep_tile_span(m_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const float diff = scale_s * (static_cast<float>(m_local[i_idx]) -
static_cast<float>(m_old[i_idx]));
if(!(diff + pv_threshold <= 0.0f))
warp_skip_int = 0;
});
// Warp-level AND reduce (wave=64 on gfx942; xor butterfly).
// No LDS, no s_barrier, no cross-warp dependency.
warp_skip_int &= __shfl_xor(warp_skip_int, 32);
warp_skip_int &= __shfl_xor(warp_skip_int, 16);
warp_skip_int &= __shfl_xor(warp_skip_int, 8);
warp_skip_int &= __shfl_xor(warp_skip_int, 4);
warp_skip_int &= __shfl_xor(warp_skip_int, 2);
warp_skip_int &= __shfl_xor(warp_skip_int, 1);
return warp_skip_int != 0;
}
else
{
return false;
}
};
const bool warp_skip = compute_warp_skip();
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
if constexpr(FmhaMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}
else
{
return raw_m;
}
};
// exp2 → p_compute and rowsum_p.
// R25 redesign D: when kEnablePVSkip + warp_skip, we zero this
// warp's owned rows of p_compute so the unconditional gemm_1
// contributes zero to o_acc, and skip the rowsum.
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_max = scale_s * get_validated_m(m[i_idx]);
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
if constexpr(kEnablePVSkip)
{
if(warp_skip)
{
p_compute(i_j_idx) = SMPLComputeDataType{0};
return;
}
}
#if CK_TILE_FMHA_FWD_FAST_EXP2
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
#else
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
#endif
});
});
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
// l{j}, Oacc{j}: alpha rescale of l / o always runs.
// When warp_skip, rowsum_p is already 0 for this
// warp's owned rows (p_compute zeroed above), so
// `l += rowsum_p` is a no-op — no extra branch needed.
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
}();
#else
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
#endif
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// FIXME: this use different equation from FA v2 paper,
// but produce correc result.
// Is the equation wrong?
o_acc(i_j_idx) *= tmp;
});
});
const auto p = [&]() {
if constexpr(std::is_same_v<PDataType, fp16_t>)
return impl::cast_tile_pkrtz_fp16_fp32<PDataType>(p_compute);
else
return cast_tile<PDataType>(p_compute);
}();
// STAGE 3, KV gemm — always runs (block-wide LDS dep; per-warp
// skipping has been absorbed by zeroing p_compute rows above).
{
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
{
v_buf = load_tile(v_dram_window,
number<-1>{},
bool_constant<false>{}); // load next v_buf
}
block_sync_lds();
gemm_1(
o_acc,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{}));
if constexpr(std::is_same_v<VLayout,
ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1,
kK1>{});
store_tile(v_lds_window_tmp, v_shuffle_tmp);
}
else
{
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1,
kK1>{});
store_tile(v_lds_window_tmp, v_buf);
}
if constexpr(i_k1 < k1_loops - 1)
move_tile_window(v_dram_window, {0, kK1});
});
}
}
i_total_loops++;
if(i_total_loops < num_total_loop)
{
// V load runs unconditionally under redesign D, so no skip
// compensation needed (same offset arithmetic as _vsa.hpp).
move_tile_window(v_dram_window, {0, kN0 * (block_idx - 1)});
move_tile_window(k_dram_block_window, {kN0 * block_idx, 0});
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
if constexpr(k1_loops >= 2 &&
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
__builtin_amdgcn_s_barrier();
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})),
k_dram_window,
number<-1>{},
k_oob_ck,
k_pre_np);
move_tile_window(k_dram_window, {0, kK0});
}
// tail — gemm_1 runs unconditionally under redesign D.
{
block_sync_lds();
gemm_1(
o_acc,
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
}
} while(i_total_loops < num_total_loop);
// finally, O
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = [&]() {
if constexpr(FmhaMask::IsMasking)
{
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
}
else
return 1 / l[i_idx];
}();
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
});
return o_acc;
}
};
} // namespace ck_tile