mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
sparse_attn: R25 Step 1 A1 — per-warp PV-skip (paper Algorithm 1) + V0 instantiation
Preserve the R25 Step 1 "A1 / redesign D" state before redesigning toward "B"
(per-CTA PV-skip matching upstream shipped reference). This snapshot lets us
restore A1 if the B redesign fails.
A1 redesign D pipeline (per-warp, arithmetic-only PV-skip, wrapped in
`if constexpr (kEnablePVSkip)`):
- include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_sparge.hpp
- include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_sparge_kernel.hpp
V0 instantiation wiring (per gino_tmp/R25/programmer/v0_instance/REPORT.md):
- example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_sparge.py
- example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp
- example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp
- example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp
- example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py
- example/ck_tile/50_sparse_attn/CMakeLists.txt
- example/ck_tile/01_fmha/CMakeLists.txt
- example/ck_tile/50_sparse_attn/test_sparge.cpp (-pv_skip_compile=0|1 CLI)
This commit excludes all *_REVIEW.{hpp,cpp} mirror files (left untracked) and
all build artefacts. _vsa.hpp / _jenga.hpp are not modified.
Tag: R25-step1-A1-paper-aligned points at this commit.
This commit is contained in:
@@ -200,6 +200,7 @@ message(DEBUG "adding example ${EXAMPLE_FMHA_FWD}")
|
||||
add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL example_fmha_fwd.cpp)
|
||||
target_link_libraries(${EXAMPLE_FMHA_FWD} ${FMHA_FWD_INSTANCES})
|
||||
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
set_property(TARGET ${EXAMPLE_FMHA_FWD} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
|
||||
|
||||
message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}")
|
||||
# not using add_example_executable() to add this target, since we don't want this to be included in
|
||||
@@ -207,6 +208,7 @@ message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}")
|
||||
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp)
|
||||
target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES})
|
||||
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
set_property(TARGET ${EXAMPLE_FMHA_BWD} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
|
||||
|
||||
@@ -83,6 +83,7 @@ message(DEBUG "adding example ${EXAMPLE_JENGA_SPARSE_ATTN}")
|
||||
add_executable(${EXAMPLE_JENGA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_jenga_sparse_attn.cpp)
|
||||
target_link_libraries(${EXAMPLE_JENGA_SPARSE_ATTN} ${SPARSE_ATTN_JENGA_INSTANCES})
|
||||
target_include_directories(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
set_property(TARGET ${EXAMPLE_JENGA_SPARSE_ATTN} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
|
||||
target_compile_options(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
@@ -148,11 +149,64 @@ message(DEBUG "adding example ${EXAMPLE_VSA_SPARSE_ATTN}")
|
||||
add_executable(${EXAMPLE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_vsa_sparse_attn.cpp)
|
||||
target_link_libraries(${EXAMPLE_VSA_SPARSE_ATTN} ${SPARSE_ATTN_VSA_INSTANCES})
|
||||
target_include_directories(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
set_property(TARGET ${EXAMPLE_VSA_SPARSE_ATTN} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
|
||||
target_compile_options(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
# Sparge Sparse Attention (PV-skip enabled, derived from VSA)
|
||||
# ============================================================================
|
||||
set(SPARSE_ATTN_SPARGE_CODE_GEN_ARGS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api fwd_sparge
|
||||
--receipt 600
|
||||
)
|
||||
|
||||
# Generate list of Sparge kernels (at configure time, only list)
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_SPARGE_CODE_GEN_ARGS}
|
||||
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sparge_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message(FATAL_ERROR "Failed to generate Sparge kernel list")
|
||||
endif()
|
||||
|
||||
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sparge_blob_list.txt SPARSE_ATTN_SPARGE_GEN_BLOBS)
|
||||
|
||||
# Generate Sparge kernel source files at build time
|
||||
add_custom_command(
|
||||
OUTPUT ${SPARSE_ATTN_SPARGE_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_SPARGE_CODE_GEN_ARGS}
|
||||
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
DEPENDS ${CODE_GEN_SCRIPTS}
|
||||
COMMENT "Generate CK Tile Sparge Sparse Attention kernels"
|
||||
)
|
||||
|
||||
message(STATUS "Sparge kernel files to be generated: ${SPARSE_ATTN_SPARGE_GEN_BLOBS}")
|
||||
|
||||
# Sparge Instances
|
||||
set(SPARSE_ATTN_SPARGE_INSTANCES "tile_sparse_attn_sparge_instances")
|
||||
|
||||
add_library(${SPARSE_ATTN_SPARGE_INSTANCES} OBJECT EXCLUDE_FROM_ALL
|
||||
${SPARSE_ATTN_SPARGE_GEN_BLOBS}
|
||||
)
|
||||
target_include_directories(${SPARSE_ATTN_SPARGE_INSTANCES} PRIVATE
|
||||
${CMAKE_CURRENT_LIST_DIR}
|
||||
${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn
|
||||
)
|
||||
set_source_files_properties(${SPARSE_ATTN_SPARGE_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
|
||||
set_property(TARGET ${SPARSE_ATTN_SPARGE_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
|
||||
|
||||
target_compile_options(${SPARSE_ATTN_SPARGE_INSTANCES} PRIVATE
|
||||
-DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
-DCK_TILE_FMHA_FWD_FAST_EXP2
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
# Sparge BlockMap GPU Kernel (hand-written instantiation, no codegen)
|
||||
# ============================================================================
|
||||
@@ -188,9 +242,11 @@ add_executable(${EXAMPLE_SPARGE} EXCLUDE_FROM_ALL test_sparge.cpp)
|
||||
target_link_libraries(${EXAMPLE_SPARGE}
|
||||
${SPARSE_ATTN_JENGA_INSTANCES}
|
||||
${SPARSE_ATTN_VSA_INSTANCES}
|
||||
${SPARSE_ATTN_SPARGE_INSTANCES}
|
||||
${SPARGE_BLOCKMAP_INSTANCES}
|
||||
)
|
||||
target_include_directories(${EXAMPLE_SPARGE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
set_property(TARGET ${EXAMPLE_SPARGE} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
|
||||
target_compile_options(${EXAMPLE_SPARGE} PRIVATE
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
|
||||
@@ -58,11 +58,13 @@ LAYOUT_MAP = {"row": "true", "col": "false"}
|
||||
PIPELINE_MAP = {
|
||||
"qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsyncJenga",
|
||||
"qr_async_vsa": "ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA",
|
||||
"qr_async_sparge": "ck_tile::BlockFmhaPipelineQRKSVSAsyncSparge",
|
||||
}
|
||||
|
||||
PIPELINE_ENUM_MAP = {
|
||||
"qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
|
||||
"qr_async_vsa": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
|
||||
"qr_async_sparge": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
|
||||
}
|
||||
|
||||
BOOL_MAP = {
|
||||
|
||||
1041
example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_sparge.py
Normal file
1041
example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_sparge.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -334,3 +334,141 @@ template <typename Traits_>
|
||||
void fmha_vsa_fwd_oneshot_(const ck_tile::stream_config&, fmha_vsa_fwd_args);
|
||||
|
||||
void fmha_vsa_fwd_oneshot(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
// sparge: same args as vsa plus a scalar PV-skip threshold (Step 1).
|
||||
struct fmha_sparge_fwd_args
|
||||
{
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* lut_ptr; // delta-encoded K-block indices per Q-block, int32 [B,H,Q_blk,K_blk]
|
||||
const void* valid_block_num_ptr; // valid K-block count per Q-block, int32 [B,H,Q_blk]
|
||||
void* o_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
|
||||
float scale_s;
|
||||
float pv_threshold; // SpargeAttn §4.4 PV-skip per-Q-tile threshold
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
|
||||
// R25 V0: select between kEnablePVSkip=true / =false template instantiations
|
||||
// at host dispatch time. Default true preserves existing behaviour (binary
|
||||
// shipped pre-R25-V0 only had the true instance). Profiler can flip this to
|
||||
// false to measure the source-equivalent-to-VSA baseline (`if constexpr`
|
||||
// removes the entire PV-skip AST).
|
||||
bool pv_skip_compile = true;
|
||||
};
|
||||
|
||||
template <typename FmhaKernel>
|
||||
auto fmha_fwd_create_kargs_and_grids(fmha_sparge_fwd_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = FmhaKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.lut_ptr,
|
||||
args.valid_block_num_ptr,
|
||||
args.o_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.pv_threshold,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type);
|
||||
|
||||
dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
ck_tile::index_t kM0_,
|
||||
ck_tile::index_t kN0_,
|
||||
ck_tile::index_t kK0_,
|
||||
ck_tile::index_t kN1_,
|
||||
ck_tile::index_t kK1_,
|
||||
ck_tile::index_t kK0BlockLength_,
|
||||
bool kIsVLayoutRowMajor_,
|
||||
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
|
||||
bool kHasLogitsSoftCap_,
|
||||
typename FmhaMask_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
bool kUseTrLoad_>
|
||||
using fmha_sparge_fwd_traits_ = fmha_jenga_fwd_traits_<HDim_,
|
||||
DataType_,
|
||||
kM0_,
|
||||
kN0_,
|
||||
kK0_,
|
||||
kN1_,
|
||||
kK1_,
|
||||
kK0BlockLength_,
|
||||
kIsVLayoutRowMajor_,
|
||||
FmhaPipelineEnum_,
|
||||
kHasLogitsSoftCap_,
|
||||
FmhaMask_,
|
||||
kPadS_,
|
||||
kPadSK_,
|
||||
kPadD_,
|
||||
kPadDv_,
|
||||
kUseTrLoad_>;
|
||||
|
||||
using fmha_sparge_fwd_traits = fmha_jenga_fwd_traits;
|
||||
|
||||
float fmha_sparge_fwd(fmha_sparge_fwd_traits, fmha_sparge_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
// R25 V0: kEnablePVSkip is now a template non-type param so the codegen can
|
||||
// emit both true / false instantiations from the same source tree. The host
|
||||
// dispatch (fmha_sparge_fwd_api.cpp) selects the right specialization based
|
||||
// on fmha_sparge_fwd_args::pv_skip_compile at runtime.
|
||||
template <typename Traits_, bool kEnablePVSkip>
|
||||
float fmha_sparge_fwd_(const ck_tile::stream_config&, fmha_sparge_fwd_args);
|
||||
|
||||
template <typename Traits_, bool kEnablePVSkip>
|
||||
void fmha_sparge_fwd_oneshot_(const ck_tile::stream_config&, fmha_sparge_fwd_args);
|
||||
|
||||
void fmha_sparge_fwd_oneshot(fmha_sparge_fwd_traits,
|
||||
fmha_sparge_fwd_args,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
@@ -264,3 +264,24 @@ float sparge_vsa_fwd_combined(sparge_blockmap_traits bmap_t,
|
||||
},
|
||||
[=](const ck_tile::stream_config& s_) { fmha_vsa_fwd_oneshot(attn_t, attn_a, s_); });
|
||||
}
|
||||
|
||||
float sparge_sparge_fwd_combined(sparge_blockmap_traits bmap_t,
|
||||
sparge_blockmap_args bmap_a,
|
||||
fmha_sparge_fwd_traits attn_t,
|
||||
fmha_sparge_fwd_args attn_a,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", sparge_kstats_" << bmap_t.data_type << "_d" << bmap_t.hdim_q
|
||||
<< ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q
|
||||
<< ", fmha_sparge_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q
|
||||
<< std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
[=](const ck_tile::stream_config& s_) { sparge_kstats_fwd_oneshot(bmap_t, bmap_a, s_); },
|
||||
[=](const ck_tile::stream_config& s_) {
|
||||
sparge_blockmap_only_fwd_oneshot(bmap_t, bmap_a, s_);
|
||||
},
|
||||
[=](const ck_tile::stream_config& s_) { fmha_sparge_fwd_oneshot(attn_t, attn_a, s_); });
|
||||
}
|
||||
|
||||
@@ -169,3 +169,9 @@ float sparge_vsa_fwd_combined(sparge_blockmap_traits,
|
||||
fmha_vsa_fwd_traits,
|
||||
fmha_vsa_fwd_args,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
float sparge_sparge_fwd_combined(sparge_blockmap_traits,
|
||||
sparge_blockmap_args,
|
||||
fmha_sparge_fwd_traits,
|
||||
fmha_sparge_fwd_args,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
@@ -100,7 +102,19 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("seed", "42", "random seed")
|
||||
.insert("warmup", "5", "warmup iterations")
|
||||
.insert("repeat", "20", "benchmark iterations")
|
||||
.insert("kname", "0", "print kernel name");
|
||||
.insert("kname", "0", "print kernel name")
|
||||
.insert("dump_o",
|
||||
"",
|
||||
"if non-empty, dump raw output buffer bytes to this path (for bit-identical "
|
||||
"baseline comparison)")
|
||||
.insert("pv_threshold",
|
||||
"1e30",
|
||||
"SpargeAttn PV-skip per-Q-tile threshold; default +1e30 disables skip")
|
||||
.insert("pv_skip_compile",
|
||||
"1",
|
||||
"R25 V0: 1=use kEnablePVSkip=true template instance (existing path); 0=use "
|
||||
"kEnablePVSkip=false instance (PV-skip AST removed at compile time, equivalent to "
|
||||
"VSA baseline)");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -130,6 +144,9 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
std::string dump_o_path = arg_parser.get_str("dump_o");
|
||||
float pv_threshold = arg_parser.get_float("pv_threshold");
|
||||
int pv_skip_compile = arg_parser.get_int("pv_skip_compile");
|
||||
|
||||
if(nhead_k < 0)
|
||||
nhead_k = nhead;
|
||||
@@ -309,7 +326,9 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
else if(pipeline == "vsa")
|
||||
{
|
||||
fmha_vsa_fwd_traits attn_traits;
|
||||
// R25: -pipeline=vsa now dispatches to the sparge pipeline family that adds
|
||||
// SpargeAttn §4.4 PV-skip; pass pv_threshold (+1e30 disables skip, matches old vsa).
|
||||
fmha_sparge_fwd_traits attn_traits;
|
||||
attn_traits.hdim_q = hdim_q;
|
||||
attn_traits.hdim_v = hdim_v;
|
||||
attn_traits.data_type = std::is_same_v<T, ck_tile::half_t> ? "fp16" : "bf16";
|
||||
@@ -317,7 +336,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
attn_traits.mask_type = mask_enum::no_mask;
|
||||
attn_traits.bm0 = BLKQ;
|
||||
|
||||
fmha_vsa_fwd_args attn_args;
|
||||
fmha_sparge_fwd_args attn_args;
|
||||
attn_args.q_ptr = q_dev.GetDeviceBuffer();
|
||||
attn_args.k_ptr = k_dev.GetDeviceBuffer();
|
||||
attn_args.v_ptr = v_dev.GetDeviceBuffer();
|
||||
@@ -333,6 +352,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
attn_args.nhead_q = nhead;
|
||||
attn_args.nhead_k = nhead_k;
|
||||
attn_args.scale_s = scale_s;
|
||||
attn_args.pv_threshold = pv_threshold;
|
||||
attn_args.pv_skip_compile = (pv_skip_compile != 0);
|
||||
attn_args.stride_q = q_strides[i_perm ? 2 : 1];
|
||||
attn_args.stride_k = k_strides[i_perm ? 2 : 1];
|
||||
attn_args.stride_v = v_strides[i_perm ? 2 : 1];
|
||||
@@ -350,7 +371,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
attn_args.mask_type = 0;
|
||||
|
||||
avg_ms =
|
||||
sparge_vsa_fwd_combined(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg);
|
||||
sparge_sparge_fwd_combined(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -374,6 +395,23 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
o_dev.FromDevice(output_host.data());
|
||||
block_map_dev.FromDevice(block_map_host.data());
|
||||
|
||||
// ---- optional raw output dump (for bit-identical baseline comparison) ----
|
||||
if(!dump_o_path.empty())
|
||||
{
|
||||
std::ofstream ofs(dump_o_path, std::ios::binary | std::ios::trunc);
|
||||
if(!ofs)
|
||||
{
|
||||
std::cerr << "\n [dump_o] failed to open " << dump_o_path << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
ofs.write(reinterpret_cast<const char*>(output_host.data()),
|
||||
static_cast<std::streamsize>(output_host.get_element_space_size_in_bytes()));
|
||||
std::cout << "\n [dump_o] wrote " << output_host.get_element_space_size_in_bytes()
|
||||
<< " bytes to " << dump_o_path;
|
||||
}
|
||||
}
|
||||
|
||||
// ---- count active blocks ----
|
||||
ck_tile::index_t total_blocks = batch * nhead * num_q_blocks * num_k_blocks;
|
||||
ck_tile::index_t active_blocks = 0;
|
||||
|
||||
@@ -0,0 +1,442 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
|
||||
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
|
||||
// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
|
||||
// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
|
||||
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename FmhaPipeline_, typename EpiloguePipeline_, bool kEnablePVSkip_ = true>
|
||||
struct FmhaFwdSpargeKernel
|
||||
{
|
||||
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
|
||||
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
|
||||
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
|
||||
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
|
||||
static_assert(kBlockPerCu > 0);
|
||||
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
|
||||
static constexpr bool kEnablePVSkip = kEnablePVSkip_;
|
||||
|
||||
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
|
||||
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
|
||||
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
|
||||
using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
|
||||
using RandValOutputDataType =
|
||||
ck_tile::remove_cvref_t<typename FmhaPipeline::RandValOutputDataType>;
|
||||
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
|
||||
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
|
||||
|
||||
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
|
||||
|
||||
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
|
||||
static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
|
||||
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
|
||||
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
|
||||
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
|
||||
static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
|
||||
static constexpr bool kDoFp8StaticQuant =
|
||||
(QScaleEnum != ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE);
|
||||
static_assert(!FmhaPipeline::kIsGroupMode, "Sparge sparse attention supports batch mode only.");
|
||||
static_assert(BiasEnum == BlockAttentionBiasEnum::NO_BIAS,
|
||||
"Sparge sparse attention does not support bias.");
|
||||
static_assert(!kStoreLSE, "Sparge sparse attention does not support LSE output.");
|
||||
static_assert(!kHasDropout, "Sparge sparse attention does not support dropout.");
|
||||
static_assert(!kHasLogitsSoftCap, "Sparge sparse attention does not support logits soft-cap.");
|
||||
static_assert(!kDoFp8StaticQuant,
|
||||
"Sparge sparse attention does not support FP8 static quantization yet.");
|
||||
|
||||
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
static constexpr bool kHasMask = FmhaMask::IsMasking;
|
||||
|
||||
static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
|
||||
|
||||
template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
|
||||
// arg
|
||||
struct FmhaFwdEmptyKargs
|
||||
{
|
||||
};
|
||||
|
||||
// kargs use aggregate initializer, so no constructor will provided
|
||||
// use inheritance to minimize karg size
|
||||
// user need to use MakeKargs() function to create kargs.
|
||||
struct FmhaFwdCommonKargs
|
||||
{
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* lut_ptr;
|
||||
const void* valid_block_num_ptr;
|
||||
void* o_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
|
||||
ck_tile::index_t num_head_q;
|
||||
// for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
|
||||
// if this param is larger than 1, indicate MQA/GQA case
|
||||
ck_tile::index_t nhead_ratio_qk;
|
||||
float scale_s;
|
||||
float pv_threshold;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_o;
|
||||
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
};
|
||||
|
||||
struct FmhaFwdMaskKargs
|
||||
{
|
||||
ck_tile::index_t window_size_left, window_size_right;
|
||||
ck_tile::GenericAttentionMaskEnum mask_type;
|
||||
};
|
||||
|
||||
struct FmhaFwdBatchModeKargs
|
||||
: FmhaFwdCommonKargs,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>
|
||||
{
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
};
|
||||
|
||||
using Kargs = FmhaFwdBatchModeKargs;
|
||||
|
||||
struct BlockIndices
|
||||
{
|
||||
ck_tile::index_t batch_idx;
|
||||
ck_tile::index_t qo_head_idx;
|
||||
ck_tile::index_t kv_head_idx;
|
||||
};
|
||||
|
||||
// std::variant<> can't take in a list initializer, overload for backward compatibility
|
||||
CK_TILE_HOST static constexpr Kargs MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* lut_ptr,
|
||||
const void* valid_block_num_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
float pv_threshold,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
lut_ptr,
|
||||
valid_block_num_ptr,
|
||||
o_ptr,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static_cast<float>(scale_s * ck_tile::log2e_v<>),
|
||||
#else
|
||||
scale_s,
|
||||
#endif
|
||||
pv_threshold,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_o}, // FmhaFwdCommonKargs
|
||||
{}, // FmhaFwdMaskKargs or FmhaFwdEmptyKargs<1>
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
batch_stride_o};
|
||||
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
kargs.window_size_left = window_size_left;
|
||||
kargs.window_size_right = window_size_right;
|
||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||
}
|
||||
return kargs;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_q_,
|
||||
ck_tile::index_t hdim_v_)
|
||||
{
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
|
||||
nhead_,
|
||||
batch_size_);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
|
||||
{
|
||||
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
|
||||
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
const auto f = [](index_t dividend, index_t divisor) {
|
||||
index_t quotient = dividend / divisor;
|
||||
index_t modulus = dividend - quotient * divisor;
|
||||
return ck_tile::make_tuple(quotient, modulus);
|
||||
};
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// divide problem
|
||||
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
|
||||
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
|
||||
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
|
||||
|
||||
long_index_t batch_offset_q = 0;
|
||||
long_index_t batch_offset_k = 0;
|
||||
long_index_t batch_offset_v = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
|
||||
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
|
||||
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
|
||||
batch_offset_q;
|
||||
const KDataType* k_ptr =
|
||||
reinterpret_cast<const KDataType*>(kargs.k_ptr) +
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
|
||||
batch_offset_k;
|
||||
const VDataType* v_ptr =
|
||||
reinterpret_cast<const VDataType*>(kargs.v_ptr) +
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
|
||||
batch_offset_v;
|
||||
|
||||
// sparse mask
|
||||
const int* lut_ptr =
|
||||
reinterpret_cast<const int*>(kargs.lut_ptr) +
|
||||
static_cast<long_index_t>(i_batch * kargs.num_head_q + i_nhead) *
|
||||
ck_tile::integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) *
|
||||
ck_tile::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0) +
|
||||
i_tile_m * ck_tile::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0);
|
||||
const int* valid_block_num_ptr =
|
||||
reinterpret_cast<const int*>(kargs.valid_block_num_ptr) +
|
||||
static_cast<long_index_t>(i_batch * kargs.num_head_q + i_nhead) *
|
||||
ck_tile::integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) +
|
||||
i_tile_m;
|
||||
const int valid_block_num_value = valid_block_num_ptr[0];
|
||||
|
||||
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
|
||||
batch_offset_o;
|
||||
|
||||
// Q/K/V DRAM and DRAM window
|
||||
const auto q_dram = [&]() {
|
||||
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
q_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_q, 1),
|
||||
number<FmhaPipeline::kAlignmentQ>{},
|
||||
number<1>{});
|
||||
if constexpr(FmhaPipeline::kQLoadOnce)
|
||||
{
|
||||
return pad_tensor_view(
|
||||
q_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(
|
||||
q_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
}
|
||||
}();
|
||||
const auto k_dram = [&]() {
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
make_tuple(kargs.seqlen_k, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_k, 1),
|
||||
number<FmhaPipeline::kAlignmentK>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
|
||||
return pad_tensor_view(
|
||||
k_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
|
||||
}();
|
||||
const auto v_dram = [&]() {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.seqlen_k, kargs.hdim_v),
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
const auto v_dram_transposed =
|
||||
transform_tensor_view(v_dram_naive,
|
||||
make_tuple(make_pass_through_transform(kargs.hdim_v),
|
||||
make_pass_through_transform(kargs.seqlen_k)),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
|
||||
return pad_tensor_view(
|
||||
v_dram_transposed,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV, kPadSeqLenK_>{});
|
||||
}
|
||||
}();
|
||||
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram,
|
||||
[&]() {
|
||||
if constexpr(FmhaPipeline::kQLoadOnce)
|
||||
return make_tuple(number<FmhaPipeline::kM0>{},
|
||||
number<FmhaPipeline::kSubQKHeaddim>{});
|
||||
else
|
||||
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
|
||||
}(),
|
||||
{i_m0, 0});
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram, make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), {0, 0});
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
{i_n1, 0});
|
||||
|
||||
FmhaMask mask = [&]() {
|
||||
if constexpr(kHasMask)
|
||||
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
||||
kargs.window_size_left,
|
||||
kargs.window_size_right,
|
||||
kargs.seqlen_q,
|
||||
kargs.seqlen_k,
|
||||
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
|
||||
else
|
||||
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
|
||||
}();
|
||||
|
||||
AttentionVariant variant;
|
||||
const auto variant_params = ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
|
||||
|
||||
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
|
||||
|
||||
auto o_acc_tile = FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
lut_ptr,
|
||||
valid_block_num_value,
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
kargs.pv_threshold,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr);
|
||||
|
||||
// O DRAM and O DRAM window
|
||||
auto o_dram = [&]() {
|
||||
const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
o_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.hdim_v),
|
||||
make_tuple(kargs.stride_o, 1),
|
||||
number<FmhaPipeline::kAlignmentO>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
o_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimV>{});
|
||||
}();
|
||||
|
||||
auto o_dram_window =
|
||||
make_tile_window(o_dram,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
|
||||
{i_m0, i_n1});
|
||||
|
||||
EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,698 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Sparge variant of qr/ks/vs/async pipeline. Cloned from BlockFmhaPipelineQRKSVSAsyncVSA;
|
||||
// adds PV-skip per Q-tile (SpargeAttn paper 4.4). Kept as a separate file so the original
|
||||
// _vsa.hpp can remain frozen as an A/B baseline.
|
||||
//
|
||||
// QUANT-HOOK: future int8/sage variant will add QScaleEnum template arg + per-tile descale Kargs;
|
||||
// _sparge_sage.hpp will live alongside this file and reuse the PV-skip path verbatim.
|
||||
template <typename Problem_,
|
||||
typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy,
|
||||
bool kEnablePVSkip_ = true>
|
||||
struct BlockFmhaPipelineQRKSVSAsyncSparge
|
||||
{
|
||||
static constexpr bool kEnablePVSkip = kEnablePVSkip_;
|
||||
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
|
||||
static_assert(kQLoadOnce == Policy::QLoadOnce);
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kN1 = BlockFmhaShape::kN1;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
|
||||
// only need special care about seq_k padding (oob need set -INF of p instead of zero)
|
||||
static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true &&
|
||||
Problem::kPadHeadDimV == true);
|
||||
static constexpr bool kPadSeqLenQ = true;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
|
||||
static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
|
||||
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
static_assert(BiasEnum == BlockAttentionBiasEnum::NO_BIAS,
|
||||
"VSA sparse attention does not support bias.");
|
||||
static_assert(!kHasDropout, "VSA sparse attention does not support dropout.");
|
||||
static_assert(!kStoreLSE, "VSA sparse attention does not support LSE output.");
|
||||
static_assert(!kHasLogitsSoftCap, "VSA sparse attention does not support logits soft-cap.");
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV = []() {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
return Policy::template GetAlignmentV<Problem>();
|
||||
else
|
||||
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
}();
|
||||
static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>();
|
||||
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
|
||||
#endif
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::kBlockPerCu != -1)
|
||||
return Problem::kBlockPerCu;
|
||||
else
|
||||
{
|
||||
// minimize occupancy
|
||||
if constexpr(kQKHeaddim <= 32)
|
||||
{
|
||||
if constexpr(kPadSeqLenK && FmhaMask::IsMasking)
|
||||
return 1;
|
||||
else
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 64)
|
||||
{
|
||||
if constexpr(kPadSeqLenK)
|
||||
return 2;
|
||||
else
|
||||
return 3;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 128)
|
||||
{
|
||||
if constexpr(kPadSeqLenK)
|
||||
return 1;
|
||||
else
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 192)
|
||||
{
|
||||
if constexpr(kPadSeqLenK)
|
||||
return 1;
|
||||
else
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 256)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
};
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr const char* name = "qr_async";
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const int* kv_block_idx_ptr,
|
||||
int kv_blocks,
|
||||
FmhaMask mask,
|
||||
float scale_s,
|
||||
float pv_threshold, // SpargeAttn PV-skip threshold; see §2 of pv_skip plan
|
||||
const AttentionVariant& variant,
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
if constexpr(!kEnablePVSkip)
|
||||
{
|
||||
(void)pv_threshold; // silence unused-param when PV-skip is compiled out
|
||||
}
|
||||
// R25 Step 1 redesign D: PV-skip control is a compile-time gate
|
||||
// (kEnablePVSkip). The entire PV-skip logic block below is wrapped in
|
||||
// `if constexpr (kEnablePVSkip)`, so when this template parameter is
|
||||
// false the AST contains no vote, no scalar gate, no extra LDS, and
|
||||
// codegen converges with _vsa.hpp's FmhaFwdVSAKernel.
|
||||
//
|
||||
// Runtime fast-path (C3-lite): pv_threshold == +1e30 sentinel disables
|
||||
// the skip at runtime via one scalar branch (sgpr); kept inside the
|
||||
// `if constexpr` so the OFF instantiation pays zero cost.
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
constexpr auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>();
|
||||
|
||||
// K tile in LDS
|
||||
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
|
||||
auto k_lds_store = generate_tuple(
|
||||
[&](auto i_buf) {
|
||||
return make_tile_window(
|
||||
make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf)),
|
||||
Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf).get_lengths(),
|
||||
{0, 0, 0});
|
||||
},
|
||||
number<Policy::NumKVLdsBuffers>{});
|
||||
|
||||
auto k_lds_Load_view = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>());
|
||||
|
||||
auto k_lds_load =
|
||||
make_tile_window(k_lds_Load_view,
|
||||
Policy::template MakeKLdsLoadBlockDescriptor<Problem>().get_lengths(),
|
||||
{0, 0});
|
||||
|
||||
// V tile in LDS
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType*>(smem_ptr),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>());
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
|
||||
int seqlen_k_start = kv_block_idx_ptr[0] * kN0;
|
||||
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQRegTileDistribution<Problem>());
|
||||
q_dram_window.init_raw();
|
||||
|
||||
// TODO: we use async Copy for K, which is inline asm
|
||||
// a side effect is we have to use inline asm for q as well
|
||||
auto q = decltype(load_tile(q_dram_window)){};
|
||||
// TODO: start from rocm-6.2, compiler will have problem if manually set clear of q.
|
||||
// however, q would be cleared in the constructor of static distributed tensor
|
||||
// set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
|
||||
load_tile_raw(q, q_dram_window);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
auto s_acc = SaccBlockTileType{};
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
|
||||
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
|
||||
|
||||
// infer Sacc, S, P, M, L, Oacc type
|
||||
using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));
|
||||
|
||||
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
|
||||
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
|
||||
|
||||
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
|
||||
|
||||
// init Oacc, M, L
|
||||
auto o_acc = OaccBlockTileType{};
|
||||
auto m = MLBlockTileType{};
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto num_total_loop = kv_blocks;
|
||||
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0)
|
||||
// otherwise will have compute error(maybe compiler bug?)
|
||||
|
||||
// Note: here occ are all cleard, return it
|
||||
return o_acc;
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check
|
||||
}
|
||||
|
||||
auto k_dram_block_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_k_start, 0});
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window.get_bottom_tensor_view(),
|
||||
k_dram_block_window.get_window_lengths(),
|
||||
k_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
|
||||
// load
|
||||
k_dram_window.init_raw();
|
||||
constexpr auto k_oob_ck = bool_constant<true>{};
|
||||
constexpr auto k_pre_np = bool_constant<false>{};
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_k_start}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
// prefetch K tile
|
||||
async_load_tile_raw(
|
||||
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer());
|
||||
buffer_load_fence(k_dram_window.get_num_of_access());
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
static_assert(1 <= k0_loops);
|
||||
static_assert(1 <= k1_loops);
|
||||
// main loop
|
||||
do
|
||||
{
|
||||
// STAGE 1, QK gemm
|
||||
clear_tile(s_acc); // initialize C
|
||||
if constexpr(k0_loops > 1)
|
||||
{
|
||||
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
|
||||
async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
|
||||
k_dram_window,
|
||||
number<-1>{},
|
||||
k_oob_ck,
|
||||
k_pre_np);
|
||||
if constexpr(i_k0 < k0_loops - 1)
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
|
||||
async_load_fence(k_dram_window.get_num_of_access());
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
gemm_0(s_acc,
|
||||
get_slice_tile(
|
||||
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
|
||||
get_slice_tile(k_lds_load,
|
||||
sequence<(LdsSeq.at(number<i_k0>{})) * kN0, 0>{},
|
||||
sequence<(LdsSeq.at(number<i_k0>{}) + 1) * kN0, kK0>{}));
|
||||
});
|
||||
}
|
||||
|
||||
// TODO: this to fix a bug when loop smaller than 2,
|
||||
// the following fence/barrier will be scheduled inside 1st loop
|
||||
if constexpr(k0_loops <= 2)
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
async_load_fence();
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
int block_idx = kv_block_idx_ptr[i_total_loops + 1];
|
||||
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
{ // tail
|
||||
gemm_0(
|
||||
s_acc,
|
||||
get_slice_tile(
|
||||
q, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
|
||||
get_slice_tile(k_lds_load,
|
||||
sequence<(LdsSeq.at(number<k0_loops - 1>{})) * kN0, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops - 1>{}) + 1) * kN0, kK0>{}));
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(1);
|
||||
|
||||
// STAGE 2, scale_s, mask, softmax (no bias/soft-cap)
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
#endif
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}),
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !variant.LogitsMask(variant_params,
|
||||
block_indices.batch_idx,
|
||||
row,
|
||||
col,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
|
||||
auto m_local = block_tile_reduce<SMPLComputeDataType>(
|
||||
s,
|
||||
sequence<1>{},
|
||||
f_max,
|
||||
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m; // m{j-1}
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
|
||||
|
||||
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
|
||||
s.get_tile_distribution()); // Pcompute{j}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x7F);
|
||||
// Ensure gemm_0's LDS reads (K tile) from all threads are completed before V store
|
||||
// Only needed when K tail and V use the same LDS buffer
|
||||
if constexpr(LdsSeq.at(number<k0_loops - 1>{}) == LdsSeq.at(number<k0_loops>{}))
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
// store & prefetch next v, after the max reduction.
|
||||
// R25 Step 1 redesign D: V→LDS store and the next-V DRAM load are
|
||||
// UNCONDITIONAL — per-warp PV-skip cannot gate them (cross-warp
|
||||
// shared LDS state; see Researcher audit A.3/A.4/A.5).
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_buf);
|
||||
|
||||
auto v_lds_window_tmp =
|
||||
get_slice_tile(v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
|
||||
|
||||
store_tile(v_lds_window_tmp, v_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto v_lds_window_tmp =
|
||||
get_slice_tile(v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
|
||||
store_tile(v_lds_window_tmp, v_buf);
|
||||
}
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
move_tile_window(
|
||||
v_dram_window,
|
||||
{0, kK1}); // will have scratch if move this right after load_tile(v_dram)...
|
||||
v_buf = load_tile(
|
||||
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// ================================================================
|
||||
// PV-SKIP per Q-tile (SpargeAttn paper §4.4)
|
||||
// R25 Step 1 redesign D — per-warp arithmetic-only:
|
||||
// Compile-time `if constexpr (kEnablePVSkip)` wraps the entire
|
||||
// block. When kEnablePVSkip=false the AST has zero PV-skip
|
||||
// artifacts → codegen converges with _vsa.hpp.
|
||||
//
|
||||
// When enabled, a per-warp predicate gates ONLY the per-row,
|
||||
// VGPR-private work (exp2 → p_compute, rowsum, `l += rowsum_p`).
|
||||
// V load / V→LDS store / gemm_1 / every `s_barrier` /
|
||||
// `block_sync_lds` stay unconditional (cross-warp LDS dep — see
|
||||
// Researcher audit A.7).
|
||||
//
|
||||
// On warp_skip, this warp's owned rows of p_compute are zeroed
|
||||
// so the unconditional gemm_1 contributes 0 to o_acc (audit
|
||||
// A.7 "simplest realisation"). The alpha-rescale `l *= tmp` and
|
||||
// `o *= tmp` still apply.
|
||||
//
|
||||
// pv_threshold semantics shift: now per-warp max diff (slightly
|
||||
// more aggressive than per-block at the same threshold; matches
|
||||
// upstream SpargeAttn `kPerWarp` mode default).
|
||||
//
|
||||
// Skip iff: scale_s * (m_local - m_old) + pv_threshold <= 0
|
||||
// (where m_local/m_old are warp-uniform after block_tile_reduce_sync)
|
||||
// ================================================================
|
||||
// Per-warp PV-skip predicate. Only declared when kEnablePVSkip;
|
||||
// wrapped in a lambda so the false instantiation contains nothing.
|
||||
auto compute_warp_skip = [&]() {
|
||||
if constexpr(kEnablePVSkip)
|
||||
{
|
||||
// C3-lite scalar fast-path: pv_threshold == +1e30 sentinel
|
||||
// disables skip; runtime cost is a single sgpr branch.
|
||||
if(pv_threshold >= 1e29f)
|
||||
return false;
|
||||
// Per-row predicate: warp-AND over rows this warp owns.
|
||||
int warp_skip_int = 1;
|
||||
constexpr auto m_spans = decltype(m_local)::get_distributed_spans();
|
||||
sweep_tile_span(m_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
const float diff = scale_s * (static_cast<float>(m_local[i_idx]) -
|
||||
static_cast<float>(m_old[i_idx]));
|
||||
if(!(diff + pv_threshold <= 0.0f))
|
||||
warp_skip_int = 0;
|
||||
});
|
||||
// Warp-level AND reduce (wave=64 on gfx942; xor butterfly).
|
||||
// No LDS, no s_barrier, no cross-warp dependency.
|
||||
warp_skip_int &= __shfl_xor(warp_skip_int, 32);
|
||||
warp_skip_int &= __shfl_xor(warp_skip_int, 16);
|
||||
warp_skip_int &= __shfl_xor(warp_skip_int, 8);
|
||||
warp_skip_int &= __shfl_xor(warp_skip_int, 4);
|
||||
warp_skip_int &= __shfl_xor(warp_skip_int, 2);
|
||||
warp_skip_int &= __shfl_xor(warp_skip_int, 1);
|
||||
return warp_skip_int != 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
};
|
||||
const bool warp_skip = compute_warp_skip();
|
||||
|
||||
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_m == -numeric<SMPLComputeDataType>::infinity()
|
||||
? type_convert<SMPLComputeDataType>(0.f)
|
||||
: raw_m;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_m;
|
||||
}
|
||||
};
|
||||
|
||||
// exp2 → p_compute and rowsum_p.
|
||||
// R25 redesign D: when kEnablePVSkip + warp_skip, we zero this
|
||||
// warp's owned rows of p_compute so the unconditional gemm_1
|
||||
// contributes zero to o_acc, and skip the rowsum.
|
||||
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
if constexpr(kEnablePVSkip)
|
||||
{
|
||||
if(warp_skip)
|
||||
{
|
||||
p_compute(i_j_idx) = SMPLComputeDataType{0};
|
||||
return;
|
||||
}
|
||||
}
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
|
||||
#else
|
||||
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
|
||||
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
|
||||
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
|
||||
// l{j}, Oacc{j}: alpha rescale of l / o always runs.
|
||||
// When warp_skip, rowsum_p is already 0 for this
|
||||
// warp's owned rows (p_compute zeroed above), so
|
||||
// `l += rowsum_p` is a no-op — no extra branch needed.
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
const auto tmp = [&]() {
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
return exp2(scale_s * m_old[i_idx] - row_max);
|
||||
}();
|
||||
#else
|
||||
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
#endif
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// FIXME: this use different equation from FA v2 paper,
|
||||
// but produce correc result.
|
||||
// Is the equation wrong?
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
|
||||
const auto p = [&]() {
|
||||
if constexpr(std::is_same_v<PDataType, fp16_t>)
|
||||
return impl::cast_tile_pkrtz_fp16_fp32<PDataType>(p_compute);
|
||||
else
|
||||
return cast_tile<PDataType>(p_compute);
|
||||
}();
|
||||
|
||||
// STAGE 3, KV gemm — always runs (block-wide LDS dep; per-warp
|
||||
// skipping has been absorbed by zeroing p_compute rows above).
|
||||
{
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
|
||||
{
|
||||
v_buf = load_tile(v_dram_window,
|
||||
number<-1>{},
|
||||
bool_constant<false>{}); // load next v_buf
|
||||
}
|
||||
block_sync_lds();
|
||||
gemm_1(
|
||||
o_acc,
|
||||
get_slice_tile(
|
||||
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
get_slice_tile(
|
||||
v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{}));
|
||||
|
||||
if constexpr(std::is_same_v<VLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_buf);
|
||||
auto v_lds_window_tmp = get_slice_tile(
|
||||
v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1,
|
||||
kK1>{});
|
||||
store_tile(v_lds_window_tmp, v_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto v_lds_window_tmp = get_slice_tile(
|
||||
v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1,
|
||||
kK1>{});
|
||||
store_tile(v_lds_window_tmp, v_buf);
|
||||
}
|
||||
if constexpr(i_k1 < k1_loops - 1)
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
});
|
||||
}
|
||||
}
|
||||
i_total_loops++;
|
||||
if(i_total_loops < num_total_loop)
|
||||
{
|
||||
// V load runs unconditionally under redesign D, so no skip
|
||||
// compensation needed (same offset arithmetic as _vsa.hpp).
|
||||
move_tile_window(v_dram_window, {0, kN0 * (block_idx - 1)});
|
||||
move_tile_window(k_dram_block_window, {kN0 * block_idx, 0});
|
||||
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
|
||||
|
||||
if constexpr(k1_loops >= 2 &&
|
||||
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
|
||||
__builtin_amdgcn_s_barrier();
|
||||
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})),
|
||||
k_dram_window,
|
||||
number<-1>{},
|
||||
k_oob_ck,
|
||||
k_pre_np);
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
}
|
||||
// tail — gemm_1 runs unconditionally under redesign D.
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_1(
|
||||
o_acc,
|
||||
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
|
||||
get_slice_tile(
|
||||
v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
|
||||
}
|
||||
} while(i_total_loops < num_total_loop);
|
||||
|
||||
// finally, O
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
{
|
||||
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
|
||||
}
|
||||
else
|
||||
return 1 / l[i_idx];
|
||||
}();
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user