Refactor to combine two kernels

This commit is contained in:
Gino Lu
2026-04-22 13:13:37 -04:00
parent c7e6e4f616
commit ab44b83566
14 changed files with 896 additions and 2990 deletions

View File

@@ -88,68 +88,6 @@ target_compile_options(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE
-Wno-float-equal
)
# ============================================================================
# Sparge Jenga (64x128 tile)
# ============================================================================
# Two-phase codegen: at configure time generate.py is asked only for the list
# of kernel sources it would emit; at build time a custom command actually
# writes those sources into the binary dir.
set(SPARGE_JENGA_CODE_GEN_ARGS
${CMAKE_CURRENT_LIST_DIR}/generate.py
--api sparge_fwd_jenga
--receipt 600
)
# Configure-time: produce the blob list (no sources are written yet).
execute_process(
COMMAND ${Python3_EXECUTABLE} ${SPARGE_JENGA_CODE_GEN_ARGS}
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sparge_jenga_blob_list.txt
RESULT_VARIABLE ret
)
# `ret` may be a non-zero exit code or an error string; either aborts configure.
if(ret AND NOT ret EQUAL 0)
message(FATAL_ERROR "Failed to generate Sparge Jenga kernel list")
endif()
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sparge_jenga_blob_list.txt SPARGE_JENGA_GEN_BLOBS)
# Build-time: emit the kernel sources listed above.
# NOTE(review): DEPENDS assumes ${CODE_GEN_SCRIPTS} covers generate.py and its
# codegen modules -- confirm, otherwise generator edits will not trigger regen.
add_custom_command(
OUTPUT ${SPARGE_JENGA_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${SPARGE_JENGA_CODE_GEN_ARGS}
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${CODE_GEN_SCRIPTS}
COMMENT "Generate CK Tile Sparge Jenga kernels"
)
message(STATUS "Sparge Jenga kernel files to be generated: ${SPARGE_JENGA_GEN_BLOBS}")
# Object library bundling the generated kernels plus the hand-written wrapper.
set(SPARGE_JENGA_INSTANCES "tile_sparge_jenga_instances")
add_library(${SPARGE_JENGA_INSTANCES} OBJECT EXCLUDE_FROM_ALL
${SPARGE_JENGA_GEN_BLOBS}
${CMAKE_CURRENT_LIST_DIR}/jenga_sparge_attention.cpp
)
target_include_directories(${SPARGE_JENGA_INSTANCES} PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn
)
# The .cpp sources must be compiled as HIP device code.
set_source_files_properties(${SPARGE_JENGA_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/jenga_sparge_attention.cpp PROPERTIES LANGUAGE HIP)
set_property(TARGET ${SPARGE_JENGA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
target_compile_options(${SPARGE_JENGA_INSTANCES} PRIVATE
-DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
-DCK_TILE_FMHA_FWD_FAST_EXP2
-Wno-undefined-func-template
-Wno-float-equal
)
# Sparge + Jenga Example executable
set(EXAMPLE_SPARGE_JENGA_SPARSE_ATTN "tile_example_sparge_jenga_sparse_attn")
message(DEBUG "adding example ${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN}")
add_executable(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_jenga_sparse_attn.cpp)
target_link_libraries(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} ${SPARGE_JENGA_INSTANCES})
target_include_directories(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_compile_options(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE
-Wno-undefined-func-template
-Wno-float-equal
)
# ============================================================================
# VSA Sparse Attention
# ============================================================================
@@ -215,55 +153,6 @@ target_compile_options(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE
-Wno-float-equal
)
# ============================================================================
# Sparge VSA (64x128 tile)
# ============================================================================
# Same two-phase codegen pattern as the other sections: list blobs at
# configure time, generate the sources at build time.
set(SPARGE_VSA_CODE_GEN_ARGS
${CMAKE_CURRENT_LIST_DIR}/generate.py
--api sparge_fwd_vsa
--receipt 600
)
# Configure-time: produce the blob list only.
execute_process(
COMMAND ${Python3_EXECUTABLE} ${SPARGE_VSA_CODE_GEN_ARGS}
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sparge_vsa_blob_list.txt
RESULT_VARIABLE ret
)
# `ret` may be a non-zero exit code or an error string; either aborts configure.
if(ret AND NOT ret EQUAL 0)
message(FATAL_ERROR "Failed to generate Sparge VSA kernel list")
endif()
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sparge_vsa_blob_list.txt SPARGE_VSA_GEN_BLOBS)
# Build-time generation of the listed kernel sources.
# NOTE(review): DEPENDS assumes ${CODE_GEN_SCRIPTS} covers generate.py -- confirm.
add_custom_command(
OUTPUT ${SPARGE_VSA_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${SPARGE_VSA_CODE_GEN_ARGS}
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${CODE_GEN_SCRIPTS}
COMMENT "Generate CK Tile Sparge VSA kernels"
)
message(STATUS "Sparge VSA kernel files to be generated: ${SPARGE_VSA_GEN_BLOBS}")
# Object library of generated kernels only (no hand-written wrapper here,
# unlike the Jenga section above).
set(SPARGE_VSA_INSTANCES "tile_sparge_vsa_instances")
add_library(${SPARGE_VSA_INSTANCES} OBJECT EXCLUDE_FROM_ALL
${SPARGE_VSA_GEN_BLOBS}
)
target_include_directories(${SPARGE_VSA_INSTANCES} PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn
)
# Generated sources are compiled as HIP device code.
set_source_files_properties(${SPARGE_VSA_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
set_property(TARGET ${SPARGE_VSA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
target_compile_options(${SPARGE_VSA_INSTANCES} PRIVATE
-DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
-DCK_TILE_FMHA_FWD_FAST_EXP2
-Wno-undefined-func-template
-Wno-float-equal
)
# ============================================================================
# Sparge BlockMap GPU Kernel (hand-written instantiation, no codegen)
# ============================================================================
@@ -289,16 +178,20 @@ target_compile_options(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE
-Wno-float-equal
)
# Sparge + VSA Example executable (now links blockmap kernel too)
set(EXAMPLE_SPARGE_VSA_SPARSE_ATTN "tile_example_sparge_vsa_sparse_attn")
message(DEBUG "adding example ${EXAMPLE_SPARGE_VSA_SPARSE_ATTN}")
add_executable(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_vsa_sparse_attn.cpp)
target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN}
${SPARGE_VSA_INSTANCES}
# ----------------------------------------------------------------------------
# Build unified Sparge test: combines blockmap, Jenga, and VSA attention
# for end-to-end evaluation and timing in a single executable.
# ----------------------------------------------------------------------------
set(EXAMPLE_SPARGE "tile_example_sparge")
message(DEBUG "adding example ${EXAMPLE_SPARGE}")
add_executable(${EXAMPLE_SPARGE} EXCLUDE_FROM_ALL test_sparge.cpp)
target_link_libraries(${EXAMPLE_SPARGE}
${SPARSE_ATTN_JENGA_INSTANCES}
${SPARSE_ATTN_VSA_INSTANCES}
${SPARGE_BLOCKMAP_INSTANCES}
)
target_include_directories(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_compile_options(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE
target_include_directories(${EXAMPLE_SPARGE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_compile_options(${EXAMPLE_SPARGE} PRIVATE
-Wno-undefined-func-template
-Wno-float-equal
)

View File

@@ -141,6 +141,17 @@ float fmha_jenga_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_jenga
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
template<>
void fmha_jenga_fwd_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_jenga_fwd_args a)
{{
using k_ = fmha_kernel_{F_idx};
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
"""
FMHA_FWD_API_FILENAME = "fmha_jenga_fwd_api.cpp"
@@ -219,6 +230,45 @@ FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayo
}}
"""
FMHA_FWD_ONESHOT_API_FILENAME = "fmha_jenga_fwd_oneshot_api.cpp"
FMHA_FWD_ONESHOT_API = """
#include "fmha_fwd_trek.hpp"
#include <iostream>
void fmha_jenga_fwd_oneshot(fmha_jenga_fwd_traits t, fmha_jenga_fwd_args a, const ck_tile::stream_config& s){{
const bool has_load_tr = ck_tile::is_load_tr_supported();
{F_dispatch}
std::cerr << "fmha_jenga_fwd_oneshot: no matching dispatch (dtype=" << t.data_type
<< " hdim_q=" << t.hdim_q << " hdim_v=" << t.hdim_v
<< " seqlen_q=" << a.seqlen_q << " seqlen_k=" << a.seqlen_k
<< " mask=" << static_cast<int>(t.mask_type) << ")" << std::endl;
}}
"""
FMHA_FWD_ONESHOT_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{
{F_dtype_case}
}}
"""
FMHA_FWD_ONESHOT_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
FMHA_FWD_ONESHOT_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
}}
"""
FMHA_FWD_ONESHOT_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) &&
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
fmha_jenga_fwd_oneshot_<trait_>(s, a);
return;
}}
"""
@dataclass
class CppConstraint:
@@ -274,10 +324,7 @@ class FmhaFwdApiTrait:
@property
def seqtune(self) -> str:
if self.bm0 == 128:
return "true/*fall back to largest tile*/" # group mode only generate spad/skpad == true
else:
return f"a.seqlen_q <= {self.bm0}"
return "true"
@property
def skcheck(self) -> str:
@@ -447,6 +494,67 @@ class FmhaFwdApiPool:
per_tr_load += " (void)t ; (void)s ; (void)a;"
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load)
@property
def oneshot_api(self) -> str:
tr_load_cond_map = {"t": "has_load_tr", "f": "true"}
per_tr_load = str()
for tr_load in ["t", "f"]:
per_dtypes = str()
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case = str()
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
traits = [
t
for t in self.pool[dtype][(hdim, hdim_v)]
if tr_load == t.tr_load
]
inners = str()
for k, trait in enumerate(traits):
if_k = "if" if k == 0 else "else if"
inners = inners + FMHA_FWD_ONESHOT_API_INNER_DISPATCH.format(
F_if=if_k,
F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
F_trload=BOOL_MAP[trait.tr_load],
F_scheck=trait.scheck,
F_seqtune=trait.seqtune,
F_skcheck=trait.skcheck,
F_dcheck=trait.dcheck,
F_dvcheck=trait.dvcheck,
F_constraint=trait.constraint,
F_spad=BOOL_MAP[trait.spad],
F_skpad=BOOL_MAP[trait.skpad],
F_dpad=BOOL_MAP[trait.dpad],
F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0,
F_bn0=trait.bn0,
F_bk0=trait.bk0,
F_bn1=trait.bn1,
F_bk1=trait.bk1,
F_bk0max=trait.bk0max,
F_hdim=hdim,
F_dtype=FWD_DTYPE_MAP[dtype],
)
if_j = "if" if j == 0 else "else if"
per_hdim_case = per_hdim_case + FMHA_FWD_ONESHOT_API_PER_HDIM_CASE.format(
F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners
)
if_i = "if" if i == 0 else "else if"
per_dtypes = per_dtypes + FMHA_FWD_ONESHOT_API_PER_DTYPE.format(
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
)
per_tr_load += FMHA_FWD_ONESHOT_API_PER_TRLOAD.format(
F_if="if",
F_trload_cond=tr_load_cond_map[tr_load],
F_dtype_case=per_dtypes,
)
if not per_tr_load:
per_tr_load += " (void)t ; (void)s ; (void)a;"
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_ONESHOT_API.format(F_dispatch=per_tr_load)
@dataclass
class FmhaFwdTileSize:
@@ -582,6 +690,27 @@ class KernelComponentFactory:
# FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
# (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(128, 128): [
FmhaFwdTileSize( # fmt: skip -- 64x128 tile matching blockmap kM0=64, kN0=128
64,
128,
64,
128,
64,
128,
4,
1,
1,
4,
1,
1,
16,
16,
16,
16,
16,
16,
-1,
),
FmhaFwdTileSize( # fmt: skip
16,
32,
@@ -780,7 +909,7 @@ def get_fwd_blobs(
for tile, pipeline in itertools.product(
tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
):
if tile.F_bm0 != 128 or tile.F_bn0 != 128:
if tile.F_bm0 != 64 or tile.F_bn0 != 128:
continue
if pipeline.tag != "qr_async":
continue
@@ -846,6 +975,7 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api)
update_file(autogen_dir / FMHA_FWD_ONESHOT_API_FILENAME, api_pool.oneshot_api)
def write_blobs(
@@ -865,3 +995,4 @@ def list_blobs(
for kernel in kernels:
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n")
f.write((file_path.parent / GEN_DIR / FMHA_FWD_ONESHOT_API_FILENAME).as_posix() + "\n")

View File

@@ -141,6 +141,17 @@ float fmha_vsa_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_vsa_fwd
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
template<>
void fmha_vsa_fwd_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_vsa_fwd_args a)
{{
using k_ = fmha_kernel_{F_idx};
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
"""
FMHA_FWD_API_FILENAME = "fmha_vsa_fwd_api.cpp"
@@ -219,6 +230,45 @@ FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayo
}}
"""
FMHA_FWD_ONESHOT_API_FILENAME = "fmha_vsa_fwd_oneshot_api.cpp"
FMHA_FWD_ONESHOT_API = """
#include "fmha_fwd_trek.hpp"
#include <iostream>
void fmha_vsa_fwd_oneshot(fmha_vsa_fwd_traits t, fmha_vsa_fwd_args a, const ck_tile::stream_config& s){{
const bool has_load_tr = ck_tile::is_load_tr_supported();
{F_dispatch}
std::cerr << "fmha_vsa_fwd_oneshot: no matching dispatch (dtype=" << t.data_type
<< " hdim_q=" << t.hdim_q << " hdim_v=" << t.hdim_v
<< " seqlen_q=" << a.seqlen_q << " seqlen_k=" << a.seqlen_k
<< " mask=" << static_cast<int>(t.mask_type) << ")" << std::endl;
}}
"""
FMHA_FWD_ONESHOT_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{
{F_dtype_case}
}}
"""
FMHA_FWD_ONESHOT_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
FMHA_FWD_ONESHOT_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
}}
"""
FMHA_FWD_ONESHOT_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) &&
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
fmha_vsa_fwd_oneshot_<trait_>(s, a);
return;
}}
"""
@dataclass
class CppConstraint:
@@ -274,10 +324,7 @@ class FmhaFwdApiTrait:
@property
def seqtune(self) -> str:
if self.bm0 == 128:
return "true/*fall back to largest tile*/" # group mode only generate spad/skpad == true
else:
return f"a.seqlen_q <= {self.bm0}"
return "true"
@property
def skcheck(self) -> str:
@@ -447,6 +494,67 @@ class FmhaFwdApiPool:
per_tr_load += " (void)t ; (void)s ; (void)a;"
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load)
@property
def oneshot_api(self) -> str:
tr_load_cond_map = {"t": "has_load_tr", "f": "true"}
per_tr_load = str()
for tr_load in ["t", "f"]:
per_dtypes = str()
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case = str()
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
traits = [
t
for t in self.pool[dtype][(hdim, hdim_v)]
if tr_load == t.tr_load
]
inners = str()
for k, trait in enumerate(traits):
if_k = "if" if k == 0 else "else if"
inners = inners + FMHA_FWD_ONESHOT_API_INNER_DISPATCH.format(
F_if=if_k,
F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
F_trload=BOOL_MAP[trait.tr_load],
F_scheck=trait.scheck,
F_seqtune=trait.seqtune,
F_skcheck=trait.skcheck,
F_dcheck=trait.dcheck,
F_dvcheck=trait.dvcheck,
F_constraint=trait.constraint,
F_spad=BOOL_MAP[trait.spad],
F_skpad=BOOL_MAP[trait.skpad],
F_dpad=BOOL_MAP[trait.dpad],
F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0,
F_bn0=trait.bn0,
F_bk0=trait.bk0,
F_bn1=trait.bn1,
F_bk1=trait.bk1,
F_bk0max=trait.bk0max,
F_hdim=hdim,
F_dtype=FWD_DTYPE_MAP[dtype],
)
if_j = "if" if j == 0 else "else if"
per_hdim_case = per_hdim_case + FMHA_FWD_ONESHOT_API_PER_HDIM_CASE.format(
F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners
)
if_i = "if" if i == 0 else "else if"
per_dtypes = per_dtypes + FMHA_FWD_ONESHOT_API_PER_DTYPE.format(
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
)
per_tr_load += FMHA_FWD_ONESHOT_API_PER_TRLOAD.format(
F_if="if",
F_trload_cond=tr_load_cond_map[tr_load],
F_dtype_case=per_dtypes,
)
if not per_tr_load:
per_tr_load += " (void)t ; (void)s ; (void)a;"
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_ONESHOT_API.format(F_dispatch=per_tr_load)
@dataclass
class FmhaFwdTileSize:
@@ -582,6 +690,27 @@ class KernelComponentFactory:
# FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
# (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(128, 128): [
FmhaFwdTileSize( # fmt: skip -- 64x128 tile matching blockmap kM0=64, kN0=128
64,
128,
64,
128,
64,
128,
4,
1,
1,
4,
1,
1,
16,
16,
16,
16,
16,
16,
-1,
),
FmhaFwdTileSize( # fmt: skip
16,
32,
@@ -780,7 +909,7 @@ def get_fwd_blobs(
for tile, pipeline in itertools.product(
tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
):
if tile.F_bm0 != 128 or tile.F_bn0 != 128:
if tile.F_bm0 != 64 or tile.F_bn0 != 128:
continue
if pipeline.tag != "qr_async_vsa":
continue
@@ -846,6 +975,7 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api)
update_file(autogen_dir / FMHA_FWD_ONESHOT_API_FILENAME, api_pool.oneshot_api)
def write_blobs(
@@ -865,3 +995,4 @@ def list_blobs(
for kernel in kernels:
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n")
f.write((file_path.parent / GEN_DIR / FMHA_FWD_ONESHOT_API_FILENAME).as_posix() + "\n")

View File

@@ -1,799 +0,0 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass, field
import fnmatch
import itertools
import os
import os.path as path
from pathlib import Path
from typing import List, Optional, Tuple
from codegen.cpp_symbol_map import (
BOOL_MAP,
FWD_DTYPE_MAP,
LAYOUT_MAP,
MODE_MAP,
PIPELINE_ENUM_MAP,
PIPELINE_MAP,
get_mask_check_map,
get_mask_map,
)
GEN_DIR = ""
def update_file(file_path, content):
    """Write ``content`` to ``file_path`` only when it differs from what is
    already on disk.

    Leaving an unchanged file untouched preserves its timestamp, so the
    build system does not rebuild dependents of generated files needlessly.
    """
    current = ""
    if path.exists(file_path):
        with open(file_path, "r") as fh:
            current = fh.read()
    # Skip the write entirely when the on-disk content already matches.
    if current != content:
        with open(file_path, "w") as fh:
            fh.write(content)
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16}
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256}
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n
// auto generated by generate.py
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "fmha_fwd_trek.hpp"
#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp"
#include "kernel/fmha_fwd_jenga_kernel.hpp"
"""
# NOTE: Jenga sparse attention kernel has the following restrictions enforced by static_assert:
# - Group mode: NOT supported (batch mode only)
# - Bias: NOT supported (NO_BIAS only)
# - LSE output: NOT supported (false only)
# - Dropout: NOT supported (false only)
# - Logits soft-cap: NOT supported (false only)
# - FP8 static quantization: NOT supported (NO_SCALE only)
# The template below hardcodes these unsupported features accordingly.
FMHA_FWD_KERNEL_BODY = """
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
{F_vlayout}>;
// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum,
// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
false, // has_logits_soft_cap - NOT supported
ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported
false, // store_lse - NOT supported
false, // has_dropout - NOT supported
false, // has_randval - NOT supported
ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported
{F_occupancy},
false>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported)
using fmha_mask_{F_idx} = {F_mask};
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
fmha_shape_{F_idx},
{F_mode},
fmha_variant_{F_idx},
fmha_mask_{F_idx},
{F_trload},
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = {F_pipeline}<
fmha_pipeline_problem_{F_idx}>;
using fmha_epilogue_{F_idx} =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaSparseFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType,
{F_spad}, {F_dvpad}>>;
using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdJengaKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
#include <iostream>
template<>
float fmha_jenga_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_jenga_fwd_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << "{F_kernel_name}" << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""
FMHA_FWD_API_FILENAME = "sparge_jenga_fwd_api.cpp"
FMHA_FWD_API = """
#include <cstdio>
#include <hip/hip_runtime.h>
namespace {{
bool get_num_cus(unsigned& num_cus) {{
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device");
return false;
}}
hipDeviceProp_t props{{}};
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device properties");
return false;
}}
num_cus = props.multiProcessorCount;
return true;
}}
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
return batch * nheads * num_m_blocks * num_n_blocks;
}}
}} // namespace
float sparge_jenga_fwd(fmha_jenga_fwd_traits t, fmha_jenga_fwd_args a, const ck_tile::stream_config& s){{
float r = -1;
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
unsigned num_cus;
if (!get_num_cus(num_cus)) {{
return r;
}}
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
}};
const bool has_load_tr = ck_tile::is_load_tr_supported();
{F_dispatch}
return r;
}}
"""
FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{
{F_dtype_case}
}}
"""
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
}}
"""
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) &&
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
return fmha_jenga_fwd_<trait_>(s, a);
}}
"""
@dataclass
class CppConstraint:
    """A C++ boolean expression used to gate generated kernel dispatch.

    ``None`` represents the unconditional constraint and renders as the
    C++ literal ``"true"``.
    """

    bool_expr: str = None

    def __str__(self):
        # A missing expression is the always-true constraint.
        return "true" if self.bool_expr is None else f"{self.bool_expr}"

    def __and__(self, other):
        # Conjunction of two constraints; both sides are parenthesized so the
        # result can be embedded safely in larger expressions.
        return CppConstraint(f"({str(self)}) && ({str(other)})")
@dataclass
class FmhaFwdApiTrait:
    """One generated kernel instantiation's dispatch signature plus the C++
    runtime predicates used by the generated API to select it."""

    pipeline_tag: str
    # sync with fmha_fwd_traits<>, to generate fallback calls
    hdim: str
    dtype: str  # data type
    mode: str  # value from MODE_MAP
    bm0: int  # tile size along q seqlen (block size)
    bn0: int  # tile size along qk seqlen
    bk0: int  # tile size along qk gemm unroll
    bn1: int  # tile size along v head_dim
    bk1: int  # tile size along kv gemm unroll
    bk0max: int
    vlayout: str
    logits: str
    mask: str
    spad: str
    skpad: str
    dpad: str
    dvpad: str
    tr_load: str
    constraint: CppConstraint

    @property
    def name(self) -> str:
        # NOTE(review): {self.bn0} appears twice below; the slot between bk0
        # and bk1 presumably should be {self.bn1} -- confirm before relying on
        # these names being unique per tile shape.
        return (
            f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
            + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
        )

    @property
    def scheck(self) -> str:
        # C++ predicate on seqlen_q padding. Every branch currently returns
        # "true", i.e. seqlen_q never filters dispatch for this API variant.
        if self.mode == "group":
            return "true/*group mode spad always true*/"  # group mode only generate spad/skpad == true
        if self.spad == "t":
            return "true"  # always support
        return "true"

    @property
    def seqtune(self) -> str:
        # No seqlen-based tile tuning here; always dispatchable.
        return "true"

    @property
    def skcheck(self) -> str:
        # C++ predicate on seqlen_k: padded instances take the ragged case,
        # unpadded instances require an exact multiple of the k-tile (bn0).
        if self.mode == "group":
            return "true/*group mode skpad always true*/"  # group mode only generate spad/skpad == true
        if self.skpad == "t":
            return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0"
        return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0"

    @property
    def dcheck(self) -> str:
        # vec = elements per 128-bit load for this dtype.
        vec = int((32 * 4) / DTYPE_BITS[self.dtype])
        if self.dpad == "t":
            return f"a.hdim_q % {vec} == 0"
        assert False  # unreachable if only dpad == "t" instances are generated -- confirm

    @property
    def dvcheck(self) -> str:
        # vec = elements per 128-bit load for this dtype.
        vec = int((32 * 4) / DTYPE_BITS[self.dtype])
        if self.dvpad == "t":
            return f"a.hdim_v % {vec} == 0"
        assert False  # unreachable if only dvpad == "t" instances are generated -- confirm
@dataclass
class FmhaFwdPipeline:
    """Pipeline configuration flags for one generated kernel; `name` encodes
    every flag into a stable suffix used in kernel/file names."""

    tag: str
    F_vlayout: str  # row/col
    F_spad: str  # true/false
    F_skpad: str  #
    F_dpad: str  #
    F_dvpad: str  #
    F_logits: str  # t/f
    F_mask: str  # value from MASK_MAP
    F_trload: str  # true/false
    F_constraint: CppConstraint = field(default_factory=CppConstraint)

    @property
    def name(self) -> str:
        def pad_name() -> str:
            # "p" + one token per enabled padding dimension; "" when none.
            n = ""
            if self.F_spad == "t":
                n += "s"
            if self.F_skpad == "t":
                n += "sk"
            if self.F_dpad == "t":
                n += "d"
            if self.F_dvpad == "t":
                n += "dv"
            if n != "":
                n = "p" + n
            return n

        pn = pad_name()
        n = f"{self.tag}_v{self.F_vlayout[0]}"
        if pn != "":
            n += f"_{pn}"
        else:
            n += "_npad"
        if self.F_logits == "t":
            n += "_logits"
        else:
            n += "_nlogits"
        n += "_nbias"
        # NOTE(review): mask names prefixed "s_" appear to belong to a
        # separate (simplified) mask implementation -- confirm against MASK_MAP.
        if self.F_mask[0:2] == "s_":
            if self.F_mask == "s_mask":
                n += "_mask"
            else:
                n += "_nmask"
        else:
            if self.F_mask != "no":
                n += f"_m{self.F_mask[0]}"
            else:
                n += "_nmask"
        # Fixed tokens: skip / static-quant are never enabled in this variant.
        n += "_nskip"
        n += "_nsquant"
        if self.F_trload == "t":
            n += "_trload"
        else:
            n += "_ntrload"
        return n
class FmhaFwdApiPool:
    """Collects registered kernel traits and renders the generated C++
    dispatch function that routes runtime arguments to a kernel instance."""

    def __init__(self, mask_impl):
        self.pool = dict()  # dtype -> (hdim, bn1) -> list[FmhaFwdApiTrait]
        self.mask_impl = mask_impl

    def register_traits(self, trait: FmhaFwdApiTrait) -> None:
        # TODO: do we need to check duplication?
        if trait.dtype not in self.pool.keys():
            self.pool[trait.dtype] = dict()
        hdim = trait.hdim, trait.bn1  # keyed by (hdim bound, v head-dim tile)
        if hdim not in self.pool[trait.dtype].keys():
            self.pool[trait.dtype][hdim] = list()
        # Copy so later mutation of the caller's trait cannot affect the pool.
        self.pool[trait.dtype][hdim].append(copy.copy(trait))

    @property
    def api(self) -> str:
        # Render nested C++ if/else chains, outermost to innermost:
        # tr_load capability -> dtype -> hdim case -> per-instance predicates.
        tr_load_cond_map = {"t": "has_load_tr", "f": "true"}
        per_tr_load = str()
        for tr_load in ["t", "f"]:
            per_dtypes = str()
            for i, dtype in enumerate(self.pool.keys()):
                per_hdim_case = str()
                for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
                    # Keep only the instances built for this tr_load flavor.
                    traits = [
                        t
                        for t in self.pool[dtype][(hdim, hdim_v)]
                        if tr_load == t.tr_load
                    ]
                    inners = str()
                    for k, trait in enumerate(traits):
                        if_k = "if" if k == 0 else "else if"
                        inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
                            F_if=if_k,
                            F_vlayout=LAYOUT_MAP[trait.vlayout],
                            F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
                            # F_logits removed - hardcoded to false (NOT supported)
                            F_mask=get_mask_map(self.mask_impl)[trait.mask],
                            F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
                            F_trload=BOOL_MAP[trait.tr_load],
                            F_scheck=trait.scheck,
                            F_seqtune=trait.seqtune,
                            F_skcheck=trait.skcheck,
                            F_dcheck=trait.dcheck,
                            F_dvcheck=trait.dvcheck,
                            F_constraint=trait.constraint,
                            F_spad=BOOL_MAP[trait.spad],
                            F_skpad=BOOL_MAP[trait.skpad],
                            F_dpad=BOOL_MAP[trait.dpad],
                            F_dvpad=BOOL_MAP[trait.dvpad],
                            F_bm0=trait.bm0,
                            F_bn0=trait.bn0,
                            F_bk0=trait.bk0,
                            F_bn1=trait.bn1,
                            F_bk1=trait.bk1,
                            F_bk0max=trait.bk0max,
                            F_hdim=hdim,
                            F_dtype=FWD_DTYPE_MAP[dtype],
                        )
                    if_j = "if" if j == 0 else "else if"
                    per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
                        F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners
                    )
                if_i = "if" if i == 0 else "else if"
                per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
                    F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
                )
            per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(
                F_if="if",
                F_trload_cond=tr_load_cond_map[tr_load],
                F_dtype_case=per_dtypes,
            )
        # NOTE(review): the template above is appended unconditionally per
        # tr_load iteration, so per_tr_load looks never-empty here -- this
        # branch appears to be dead; confirm.
        if not per_tr_load:
            # empty string we add some ignore to suppress warning in api
            per_tr_load += " (void)t ; (void)s ; (void)a;"
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load)
@dataclass
class FmhaFwdTileSize:
    """Static tile and warp configuration of one FMHA forward kernel instance."""

    F_bm0: int  # tile size along q seqlen (block size)
    F_bn0: int  # tile size along k seqlen
    F_bk0: int  # tile size along qk gemm unroll
    F_bn1: int  # tile size along v head_dim
    F_bk1: int  # tile size along kv gemm unroll
    F_bk0max: int  # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
    F_rm0: int  # number of warps for gemm0 along q seqlen
    F_rn0: int  # number of warps for gemm0 along k seqlen
    F_rk0: int  # number of warps for gemm0 along head dim q (not used)
    F_rm1: int  # number of warps for gemm1 along q seqlen
    F_rn1: int  # number of warps for gemm1 along head dim v
    F_rk1: int  # number of warps for gemm1 along k seqlen (not used)
    F_wm0: int  # gemm0 warp size along m
    F_wn0: int  # gemm0 warp size along n
    F_wk0: int  # gemm0 warp size along k
    F_wm1: int  # gemm1 warp size along m
    F_wn1: int  # gemm1 warp size along n
    F_wk1: int  # gemm1 warp size along k
    F_occupancy: int  # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
    F_constraint: CppConstraint = field(default_factory=CppConstraint)

    @property
    def name(self) -> str:
        # Mnemonic encoding: b<block tile> _r<gemm0/1 warp counts>
        # _w<gemm0/1 warp sizes>, plus _o<occupancy> only when overridden.
        parts = [
            f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}",
            f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}",
            f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}",
        ]
        if self.F_occupancy != -1:
            parts.append(f"_o{self.F_occupancy}")
        return "".join(parts)
@dataclass
class FmhaFwdKernel:
    """One generated jenga forward-kernel instance and its rendering logic."""

    F_idx: int  # counter used only to keep generated symbols distinct
    F_hdim: int  # head dim
    F_dtype: str  # data type key
    F_mode: str  # value from MODE_MAP
    F_tile: FmhaFwdTileSize
    F_pipeline: FmhaFwdPipeline
    mask_impl: str

    @property
    def template(self) -> str:
        """Render the complete .cpp translation unit for this kernel."""
        tile = self.F_tile
        pipe = self.F_pipeline
        fields = dict(
            F_idx=self.F_idx,
            F_hdim=self.F_hdim,
            F_dtype=FWD_DTYPE_MAP[self.F_dtype],
            F_bm0=tile.F_bm0,
            F_bn0=tile.F_bn0,
            F_bk0=tile.F_bk0,
            F_bn1=tile.F_bn1,
            F_bk1=tile.F_bk1,
            F_bk0max=tile.F_bk0max,
            F_rm0=tile.F_rm0,
            F_rn0=tile.F_rn0,
            F_rk0=tile.F_rk0,
            F_rm1=tile.F_rm1,
            F_rn1=tile.F_rn1,
            F_rk1=tile.F_rk1,
            F_wm0=tile.F_wm0,
            F_wn0=tile.F_wn0,
            F_wk0=tile.F_wk0,
            F_wm1=tile.F_wm1,
            F_wn1=tile.F_wn1,
            F_wk1=tile.F_wk1,
            F_vlayout=LAYOUT_MAP[pipe.F_vlayout],
            F_spad=BOOL_MAP[pipe.F_spad],
            F_skpad=BOOL_MAP[pipe.F_skpad],
            F_dpad=BOOL_MAP[pipe.F_dpad],
            F_dvpad=BOOL_MAP[pipe.F_dvpad],
            # F_logits intentionally absent - hardcoded to false in the template
            F_occupancy=tile.F_occupancy,
            F_pipeline_enum=PIPELINE_ENUM_MAP[pipe.tag],
            F_mask=get_mask_map(self.mask_impl)[pipe.F_mask],
            F_mode=MODE_MAP[self.F_mode],
            F_pipeline=PIPELINE_MAP[pipe.tag],
            F_trload=BOOL_MAP[pipe.F_trload],
            F_kernel_name=self.name,
        )
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(**fields)

    @property
    def name(self) -> str:
        # NOTE: F_idx is deliberately not encoded in the name
        parts = (
            f"fmha_jenga_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}",
            self.F_tile.name,
            self.F_pipeline.name,
        )
        return "_".join(parts)

    @property
    def filename(self) -> str:
        """Basename of the generated source file."""
        return f"{self.name}.cpp"

    def api_trait(self) -> FmhaFwdApiTrait:
        """Project this kernel into the trait record consumed by the API pool."""
        tile = self.F_tile
        pipe = self.F_pipeline
        return FmhaFwdApiTrait(
            pipeline_tag=pipe.tag,
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
            mode=self.F_mode,
            bm0=tile.F_bm0,
            bn0=tile.F_bn0,
            bk0=tile.F_bk0,
            bn1=tile.F_bn1,
            bk1=tile.F_bk1,
            bk0max=tile.F_bk0max,
            vlayout=pipe.F_vlayout,
            mask=pipe.F_mask,
            logits=pipe.F_logits,
            spad=pipe.F_spad,
            skpad=pipe.F_skpad,
            dpad=pipe.F_dpad,
            dvpad=pipe.F_dvpad,
            tr_load=pipe.F_trload,
            constraint=tile.F_constraint & pipe.F_constraint,
        )
class KernelComponentFactory:
    """Supplies the tile sizes and pipelines used to enumerate jenga kernels."""

    # TODO: design a more practical way to do it
    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        """Supported tile sizes keyed by (hdim_q, hdim_v); None if dtype unsupported.

        Only the (128, 128) head-dim configuration is currently enabled;
        other head-dim configurations have been disabled pending support.
        """
        if dtype not in ("fp16", "bf16"):
            return None
        tile_128 = FmhaFwdTileSize(
            64, 128, 64, 128, 64, 128,
            4, 1, 1, 4, 1, 1,
            16, 16, 16, 16, 16, 16,
            -1,
        )
        return {(128, 128): [tile_128]}

    # TODO: tuning is not supported yet, so a single vlayout/pipeline/pad
    # choice is produced; support tuning in the future.
    @staticmethod
    def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
        """Enumerate candidate pipelines.

        NOTE: list order matters - later entries are checked later during
        dispatch. Logits soft-cap is NOT supported by Jenga sparse attention
        (enforced by static_assert), so F_logits is always "f".
        """
        if dtype not in ("fp16", "bf16"):
            assert False
        pipelines: List[FmhaFwdPipeline] = []
        for logits, mask in itertools.product(["f"], get_mask_map(mask_impl).keys()):
            if hdim == 256 and hdim_v == 256:
                # jenga fmha only supports dim <= 192 for now.
                continue
            # one variant without seqlen_k padding, one with it
            for skpad in ("f", "t"):
                pipelines.append(
                    FmhaFwdPipeline(
                        "qr_async", "row", "t", skpad, "t", "t", logits, mask, "f"
                    )
                )
        return pipelines
class CustomFactory(KernelComponentFactory):
    """Factory variant that prepends a CU-utilization-constrained tile choice."""

    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
        if dtype in ("fp16", "bf16") and (128, 128) in result.keys():
            # Prefer this tile when the launch would not fill enough CUs.
            constrained = FmhaFwdTileSize(
                64, 128, 64, 128, 64, 128,
                4, 1, 1, 4, 1, 1,
                16, 16, 16, 16, 16, 16,
                -1,
                CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"),
            )
            result[(128, 128)].insert(0, constrained)
        return result
def get_fwd_blobs(
    kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
    """Enumerate all jenga forward kernels and build the API dispatch pool.

    kernel_filter: optional fnmatch pattern; None or "" disables filtering.
    receipt: integration id selecting which kernel variants are enabled.
    optdim_list: hdim whitelist, or [-1] to allow all hdims.
    mask_impl: mask implementation key passed through to the mask maps.
    Returns (api_pool, kernels).
    """

    def receipt_allows(dtype, mode, pipeline) -> bool:
        # Per-integration enable conditions; unknown receipts add no restriction.
        if receipt in (2, 3):  # Flash attention integration
            return dtype in ["fp16", "bf16"] and pipeline.F_vlayout == "row"
        if receipt == 4:  # PyTorch integration
            return (
                dtype in ["fp16", "bf16"]
                and pipeline.F_vlayout == "row"
                and mode == "batch"
                and pipeline.F_logits == "f"
            )
        if receipt == 100:  # Aiter(mha_fwd) integration
            return (
                dtype in ["fp16", "bf16"]
                and mode == "batch"
                and pipeline.F_vlayout == "row"
            )
        if receipt == 200:  # Aiter(mha_varlen_fwd) integration
            return (
                dtype in ["fp16", "bf16"]
                and mode == "group"
                and pipeline.F_vlayout == "row"
            )
        if receipt == 600:  # aiter::mha_fwd C++ api integration
            return dtype in ["fp16", "bf16"] and pipeline.F_vlayout == "row"
        return True

    gen = list()
    api_pool = FmhaFwdApiPool(mask_impl)
    factory = (
        CustomFactory
        if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1"
        else KernelComponentFactory
    )
    # Only generate fp16/bf16 kernels for now.
    # NOTE: Jenga sparse attention only supports batch mode (group mode NOT
    # supported, enforced by static_assert)
    for dtype in ["fp16", "bf16"]:
        d = factory.get_hdim_tile_size_dict(dtype)
        if d is None:
            continue
        for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]):
            for tile, pipeline in itertools.product(
                tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
            ):
                if pipeline.tag != "qr_async":
                    continue
                k = FmhaFwdKernel(
                    F_idx=2,
                    F_hdim=hdim,
                    F_dtype=dtype,
                    F_mode=mode,
                    F_tile=tile,
                    F_pipeline=pipeline,
                    mask_impl=mask_impl,
                )
                # BUGFIX: kernel_filter is Optional; the previous `!= ""` check
                # passed None straight into fnmatch and raised TypeError.
                if kernel_filter:
                    if not fnmatch.fnmatch(k.name, kernel_filter):
                        continue
                if optdim_list != [-1] and hdim not in optdim_list:
                    continue
                if not receipt_allows(dtype, mode, pipeline):
                    continue
                api_pool.register_traits(k.api_trait())
                gen.append(k)
    return (api_pool, gen)
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
    """Write (or refresh) the generated .cpp source for one kernel instance."""
    target = autogen_dir / kernel.filename
    update_file(target, kernel.template)
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
    """Write (or refresh) the generated dispatch-API source file."""
    target = autogen_dir / FMHA_FWD_API_FILENAME
    update_file(target, api_pool.api)
def write_blobs(
    output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
) -> None:
    """Generate every kernel source plus the API dispatcher into output_dir."""
    api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
    for k in kernels:
        write_single_fwd_kernel(k, output_dir)
    write_fwd_api(api_pool, output_dir)
def list_blobs(
    file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
) -> None:
    """Append the paths of every file that will be generated to file_path."""
    _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
    gen_root = file_path.parent / GEN_DIR
    filenames = [k.filename for k in kernels] + [FMHA_FWD_API_FILENAME]
    with file_path.open("a") as f:
        for fname in filenames:
            f.write((gen_root / fname).as_posix() + "\n")

View File

@@ -1,799 +0,0 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass, field
import fnmatch
import itertools
import os
import os.path as path
from pathlib import Path
from typing import List, Optional, Tuple
from codegen.cpp_symbol_map import (
BOOL_MAP,
FWD_DTYPE_MAP,
LAYOUT_MAP,
MODE_MAP,
PIPELINE_ENUM_MAP,
PIPELINE_MAP,
get_mask_check_map,
get_mask_map,
)
# Subdirectory (relative to the blob-list file) where sources are generated;
# empty string means alongside the list file itself.
GEN_DIR = ""
def update_file(file_path, content):
    """Write `content` to `file_path` only if it differs from what is already
    on disk, so unchanged outputs do not touch the file and trigger rebuilds.
    """
    if path.exists(file_path):
        with open(file_path, "r") as fh:
            if fh.read() == content:
                return  # identical content: leave the file untouched
    with open(file_path, "w") as fh:
        fh.write(content)
# Element bit width per supported data type; used by the dcheck/dvcheck
# padding predicates to compute the vector access width.
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16}
# NOTE(review): not referenced anywhere in this module as far as visible;
# presumably maps a K0-max value to the pipeline's sub-max - confirm before
# relying on it or removing it.
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256}
# Common preamble emitted at the top of every generated .cpp file: license
# banner plus the VSA kernel/pipeline includes.
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n
// auto generated by generate.py
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "fmha_fwd_trek.hpp"
#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp"
#include "kernel/fmha_fwd_vsa_kernel.hpp"
"""
# NOTE: VSA sparse attention kernel has the following restrictions enforced by static_assert:
# - Group mode: NOT supported (batch mode only)
# - Bias: NOT supported (NO_BIAS only)
# - LSE output: NOT supported (false only)
# - Dropout: NOT supported (false only)
# - Logits soft-cap: NOT supported (false only)
# - FP8 static quantization: NOT supported (NO_SCALE only)
# The template below hardcodes these unsupported features accordingly.
FMHA_FWD_KERNEL_BODY = """
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
{F_vlayout}>;
// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum,
// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
false, // has_logits_soft_cap - NOT supported
ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported
false, // store_lse - NOT supported
false, // has_dropout - NOT supported
false, // has_randval - NOT supported
ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported
{F_occupancy},
false>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported)
using fmha_mask_{F_idx} = {F_mask};
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
fmha_shape_{F_idx},
{F_mode},
fmha_variant_{F_idx},
fmha_mask_{F_idx},
{F_trload},
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA<
fmha_pipeline_problem_{F_idx}>;
using fmha_epilogue_{F_idx} =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaSparseFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType,
{F_spad}, {F_dvpad}>>;
using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdVSAKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
#include <iostream>
template<>
float fmha_vsa_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_vsa_fwd_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << "{F_kernel_name}" << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""
FMHA_FWD_API_FILENAME = "sparge_vsa_fwd_api.cpp"
FMHA_FWD_API = """
#include <cstdio>
#include <hip/hip_runtime.h>
namespace {{
bool get_num_cus(unsigned& num_cus) {{
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device");
return false;
}}
hipDeviceProp_t props{{}};
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device properties");
return false;
}}
num_cus = props.multiProcessorCount;
return true;
}}
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
return batch * nheads * num_m_blocks * num_n_blocks;
}}
}} // namespace
float sparge_vsa_fwd(fmha_vsa_fwd_traits t, fmha_vsa_fwd_args a, const ck_tile::stream_config& s){{
float r = -1;
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
unsigned num_cus;
if (!get_num_cus(num_cus)) {{
return r;
}}
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
}};
const bool has_load_tr = ck_tile::is_load_tr_supported();
{F_dispatch}
return r;
}}
"""
FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{
{F_dtype_case}
}}
"""
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
}}
"""
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) &&
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
return fmha_vsa_fwd_<trait_>(s, a);
}}
"""
@dataclass
class CppConstraint:
    """A C++ boolean expression guarding kernel dispatch.

    Renders as the literal "true" when no expression is set, so it can be
    dropped into a generated `if(...)` unconditionally.
    """

    # The raw C++ expression, or None for the unconstrained case.
    # FIX: annotation was `str = None`, which mistypes the None default.
    bool_expr: Optional[str] = None

    def __str__(self):
        # None means "no constraint", which is always satisfied.
        return "true" if self.bool_expr is None else self.bool_expr

    def __and__(self, other):
        """Conjunction of two constraints, parenthesized to keep C++ precedence."""
        return CppConstraint(f"({str(self)}) && ({str(other)})")
@dataclass
class FmhaFwdApiTrait:
    """Per-kernel trait record used to emit runtime dispatch conditions.

    Kept in sync with the C++ fmha_fwd_traits<> so fallback calls can be
    generated.
    """

    pipeline_tag: str
    # sync with fmha_fwd_traits<>, to generate fallback calls
    hdim: str
    dtype: str  # data type
    mode: str  # value from MODE_MAP
    bm0: int  # tile size along q seqlen (block size)
    bn0: int  # tile size along qk seqlen
    bk0: int  # tile size along qk gemm unroll
    bn1: int  # tile size along v head_dim
    bk1: int  # tile size along kv gemm unroll
    bk0max: int
    vlayout: str
    logits: str
    mask: str
    spad: str
    skpad: str
    dpad: str
    dvpad: str
    tr_load: str
    constraint: CppConstraint

    @property
    def name(self) -> str:
        # BUGFIX: the bn1 slot previously repeated bn0, so two traits that
        # differed only in bn1 produced identical names.
        return (
            f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0max}-"
            + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
        )

    @property
    def scheck(self) -> str:
        """Runtime C++ condition for q-seqlen padding (always satisfiable)."""
        if self.mode == "group":
            return "true/*group mode spad always true*/"  # group mode only generates spad/skpad == true
        return "true"  # batch mode: every generated spad config is acceptable

    @property
    def seqtune(self) -> str:
        """Extra seqlen tuning condition; none is applied currently."""
        return "true"

    @property
    def skcheck(self) -> str:
        """Runtime C++ condition selecting padded vs. unpadded k-seqlen kernels."""
        if self.mode == "group":
            return "true/*group mode skpad always true*/"  # group mode only generates spad/skpad == true
        if self.skpad == "t":
            return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0"
        return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0"

    @property
    def dcheck(self) -> str:
        """Runtime C++ condition on hdim_q vector alignment."""
        vec = int((32 * 4) / DTYPE_BITS[self.dtype])
        if self.dpad == "t":
            return f"a.hdim_q % {vec} == 0"
        assert False  # only dpad == "t" kernels are generated

    @property
    def dvcheck(self) -> str:
        """Runtime C++ condition on hdim_v vector alignment."""
        vec = int((32 * 4) / DTYPE_BITS[self.dtype])
        if self.dvpad == "t":
            return f"a.hdim_v % {vec} == 0"
        assert False  # only dvpad == "t" kernels are generated
@dataclass
class FmhaFwdPipeline:
    """Pipeline configuration (layout, padding, mask, tr-load) for a kernel."""

    tag: str
    F_vlayout: str  # row/col
    F_spad: str  # t/f
    F_skpad: str  # t/f
    F_dpad: str  # t/f
    F_dvpad: str  # t/f
    F_logits: str  # t/f
    F_mask: str  # value from MASK_MAP
    F_trload: str  # t/f
    F_constraint: CppConstraint = field(default_factory=CppConstraint)

    @property
    def name(self) -> str:
        """Human-readable encoding of every pipeline knob."""
        pad_bits = "".join(
            bit
            for flag, bit in (
                (self.F_spad, "s"),
                (self.F_skpad, "sk"),
                (self.F_dpad, "d"),
                (self.F_dvpad, "dv"),
            )
            if flag == "t"
        )
        parts = [f"{self.tag}_v{self.F_vlayout[0]}"]
        parts.append(f"p{pad_bits}" if pad_bits else "npad")
        parts.append("logits" if self.F_logits == "t" else "nlogits")
        parts.append("nbias")
        if self.F_mask[0:2] == "s_":
            parts.append("mask" if self.F_mask == "s_mask" else "nmask")
        else:
            parts.append(f"m{self.F_mask[0]}" if self.F_mask != "no" else "nmask")
        parts.append("nskip")
        parts.append("nsquant")
        parts.append("trload" if self.F_trload == "t" else "ntrload")
        return "_".join(parts)
class FmhaFwdApiPool:
    """Collects FmhaFwdApiTrait records and renders the C++ dispatch API."""

    def __init__(self, mask_impl):
        # pool layout: dtype -> (hdim, bn1) -> [FmhaFwdApiTrait]
        self.pool = dict()
        self.mask_impl = mask_impl

    def register_traits(self, trait: FmhaFwdApiTrait) -> None:
        """Record one kernel's trait under its dtype/(hdim, bn1) bucket."""
        # TODO: do we need to check duplication?
        if trait.dtype not in self.pool.keys():
            self.pool[trait.dtype] = dict()
        hdim = trait.hdim, trait.bn1
        if hdim not in self.pool[trait.dtype].keys():
            self.pool[trait.dtype][hdim] = list()
        self.pool[trait.dtype][hdim].append(copy.copy(trait))

    @property
    def api(self) -> str:
        """Render the nested if/else dispatch: tr-load -> dtype -> hdim -> trait."""
        # "t" kernels only dispatch when hardware tr-load is available at runtime
        tr_load_cond_map = {"t": "has_load_tr", "f": "true"}
        per_tr_load = str()
        for tr_load in ["t", "f"]:
            per_dtypes = str()
            for i, dtype in enumerate(self.pool.keys()):
                per_hdim_case = str()
                for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
                    # keep only traits matching the current tr-load branch
                    traits = [
                        t
                        for t in self.pool[dtype][(hdim, hdim_v)]
                        if tr_load == t.tr_load
                    ]
                    inners = str()
                    for k, trait in enumerate(traits):
                        if_k = "if" if k == 0 else "else if"
                        inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
                            F_if=if_k,
                            F_vlayout=LAYOUT_MAP[trait.vlayout],
                            F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
                            # F_logits removed - hardcoded to false (NOT supported)
                            F_mask=get_mask_map(self.mask_impl)[trait.mask],
                            F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
                            F_trload=BOOL_MAP[trait.tr_load],
                            F_scheck=trait.scheck,
                            F_seqtune=trait.seqtune,
                            F_skcheck=trait.skcheck,
                            F_dcheck=trait.dcheck,
                            F_dvcheck=trait.dvcheck,
                            F_constraint=trait.constraint,
                            F_spad=BOOL_MAP[trait.spad],
                            F_skpad=BOOL_MAP[trait.skpad],
                            F_dpad=BOOL_MAP[trait.dpad],
                            F_dvpad=BOOL_MAP[trait.dvpad],
                            F_bm0=trait.bm0,
                            F_bn0=trait.bn0,
                            F_bk0=trait.bk0,
                            F_bn1=trait.bn1,
                            F_bk1=trait.bk1,
                            F_bk0max=trait.bk0max,
                            F_hdim=hdim,
                            F_dtype=FWD_DTYPE_MAP[dtype],
                        )
                    if_j = "if" if j == 0 else "else if"
                    per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
                        F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners
                    )
                if_i = "if" if i == 0 else "else if"
                per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
                    F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
                )
            per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(
                F_if="if",
                F_trload_cond=tr_load_cond_map[tr_load],
                F_dtype_case=per_dtypes,
            )
        if not per_tr_load:
            # empty string we add some ignore to suppress warning in api
            per_tr_load += " (void)t ; (void)s ; (void)a;"
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load)
@dataclass
class FmhaFwdTileSize:
    """Tile-size configuration for one forward FMHA kernel instance."""

    F_bm0: int  # block tile along q seqlen
    F_bn0: int  # block tile along k seqlen
    F_bk0: int  # qk-gemm unroll
    F_bn1: int  # block tile along v head_dim
    F_bk1: int  # kv-gemm unroll
    F_bk0max: int  # full K0 length, for pipelines that load Q as one whole tile
    F_rm0: int  # gemm0 warp count along q seqlen
    F_rn0: int  # gemm0 warp count along k seqlen
    F_rk0: int  # gemm0 warp count along q head dim (not used)
    F_rm1: int  # gemm1 warp count along q seqlen
    F_rn1: int  # gemm1 warp count along v head dim
    F_rk1: int  # gemm1 warp count along k seqlen (not used)
    F_wm0: int  # gemm0 warp tile m
    F_wn0: int  # gemm0 warp tile n
    F_wk0: int  # gemm0 warp tile k
    F_wm1: int  # gemm1 warp tile m
    F_wn1: int  # gemm1 warp tile n
    F_wk1: int  # gemm1 warp tile k
    F_occupancy: int  # -1 lets the pipeline decide occupancy; any other value overrides it
    F_constraint: CppConstraint = field(default_factory=CppConstraint)

    @property
    def name(self) -> str:
        """Compact identifier encoding every tile dimension."""
        block = "x".join(
            str(v)
            for v in (self.F_bm0, self.F_bn0, self.F_bk0, self.F_bn1, self.F_bk1, self.F_bk0max)
        )
        warps0 = f"{self.F_rm0}x{self.F_rn0}x{self.F_rk0}"
        warps1 = f"{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
        wtile0 = f"{self.F_wm0}x{self.F_wn0}x{self.F_wk0}"
        wtile1 = f"{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
        occ = "" if self.F_occupancy == -1 else f"_o{self.F_occupancy}"
        return f"b{block}_r{warps0}_r{warps1}_w{wtile0}_w{wtile1}{occ}"
@dataclass
class FmhaFwdKernel:
    """One generated VSA forward-kernel instance and its rendering logic."""

    F_idx: int  # counter used only to keep generated symbols distinct
    F_hdim: int  # head dim
    F_dtype: str  # data type key
    F_mode: str  # value from MODE_MAP
    F_tile: FmhaFwdTileSize
    F_pipeline: FmhaFwdPipeline
    mask_impl: str

    @property
    def template(self) -> str:
        """Render the complete .cpp translation unit for this kernel."""
        tile = self.F_tile
        pipe = self.F_pipeline
        fields = dict(
            F_idx=self.F_idx,
            F_hdim=self.F_hdim,
            F_dtype=FWD_DTYPE_MAP[self.F_dtype],
            F_bm0=tile.F_bm0,
            F_bn0=tile.F_bn0,
            F_bk0=tile.F_bk0,
            F_bn1=tile.F_bn1,
            F_bk1=tile.F_bk1,
            F_bk0max=tile.F_bk0max,
            F_rm0=tile.F_rm0,
            F_rn0=tile.F_rn0,
            F_rk0=tile.F_rk0,
            F_rm1=tile.F_rm1,
            F_rn1=tile.F_rn1,
            F_rk1=tile.F_rk1,
            F_wm0=tile.F_wm0,
            F_wn0=tile.F_wn0,
            F_wk0=tile.F_wk0,
            F_wm1=tile.F_wm1,
            F_wn1=tile.F_wn1,
            F_wk1=tile.F_wk1,
            F_vlayout=LAYOUT_MAP[pipe.F_vlayout],
            F_spad=BOOL_MAP[pipe.F_spad],
            F_skpad=BOOL_MAP[pipe.F_skpad],
            F_dpad=BOOL_MAP[pipe.F_dpad],
            F_dvpad=BOOL_MAP[pipe.F_dvpad],
            # F_logits intentionally absent - hardcoded to false in the template
            F_occupancy=tile.F_occupancy,
            F_pipeline_enum=PIPELINE_ENUM_MAP[pipe.tag],
            F_mask=get_mask_map(self.mask_impl)[pipe.F_mask],
            F_mode=MODE_MAP[self.F_mode],
            F_pipeline=PIPELINE_MAP[pipe.tag],
            F_trload=BOOL_MAP[pipe.F_trload],
            F_kernel_name=self.name,
        )
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(**fields)

    @property
    def name(self) -> str:
        # NOTE: F_idx is deliberately not encoded in the name
        parts = (
            f"fmha_vsa_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}",
            self.F_tile.name,
            self.F_pipeline.name,
        )
        return "_".join(parts)

    @property
    def filename(self) -> str:
        """Basename of the generated source file."""
        return f"{self.name}.cpp"

    def api_trait(self) -> FmhaFwdApiTrait:
        """Project this kernel into the trait record consumed by the API pool."""
        tile = self.F_tile
        pipe = self.F_pipeline
        return FmhaFwdApiTrait(
            pipeline_tag=pipe.tag,
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
            mode=self.F_mode,
            bm0=tile.F_bm0,
            bn0=tile.F_bn0,
            bk0=tile.F_bk0,
            bn1=tile.F_bn1,
            bk1=tile.F_bk1,
            bk0max=tile.F_bk0max,
            vlayout=pipe.F_vlayout,
            mask=pipe.F_mask,
            logits=pipe.F_logits,
            spad=pipe.F_spad,
            skpad=pipe.F_skpad,
            dpad=pipe.F_dpad,
            dvpad=pipe.F_dvpad,
            tr_load=pipe.F_trload,
            constraint=tile.F_constraint & pipe.F_constraint,
        )
class KernelComponentFactory:
    """Supplies the tile sizes and pipelines used to enumerate VSA kernels."""

    # TODO: design a more practical way to do it
    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        """Supported tile sizes keyed by (hdim_q, hdim_v); None if dtype unsupported.

        Only the (128, 128) head-dim configuration is currently enabled;
        other head-dim configurations have been disabled pending support.
        """
        if dtype not in ("fp16", "bf16"):
            return None
        tile_128 = FmhaFwdTileSize(
            64, 128, 64, 128, 64, 128,
            4, 1, 1, 4, 1, 1,
            16, 16, 16, 16, 16, 16,
            -1,
        )
        return {(128, 128): [tile_128]}

    # TODO: tuning is not supported yet, so a single vlayout/pipeline/pad
    # choice is produced; support tuning in the future.
    @staticmethod
    def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
        """Enumerate candidate pipelines.

        NOTE: list order matters - later entries are checked later during
        dispatch. Logits soft-cap is NOT supported by VSA sparse attention
        (enforced by static_assert), so F_logits is always "f".
        """
        if dtype not in ("fp16", "bf16"):
            assert False
        pipelines: List[FmhaFwdPipeline] = []
        for logits, mask in itertools.product(["f"], get_mask_map(mask_impl).keys()):
            if hdim == 256 and hdim_v == 256:
                # vsa fmha only supports dim <= 192 for now.
                continue
            # one variant without seqlen_k padding, one with it
            for skpad in ("f", "t"):
                pipelines.append(
                    FmhaFwdPipeline(
                        "qr_async_vsa", "row", "t", skpad, "t", "t", logits, mask, "f"
                    )
                )
        return pipelines
class CustomFactory(KernelComponentFactory):
    """Factory variant that prepends a CU-utilization-constrained tile choice."""

    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
        if dtype in ("fp16", "bf16") and (128, 128) in result.keys():
            # Prefer this tile when the launch would not fill enough CUs.
            constrained = FmhaFwdTileSize(
                64, 128, 64, 128, 64, 128,
                4, 1, 1, 4, 1, 1,
                16, 16, 16, 16, 16, 16,
                -1,
                CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"),
            )
            result[(128, 128)].insert(0, constrained)
        return result
def get_fwd_blobs(
    kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
    """Enumerate all VSA forward kernels and build the API dispatch pool.

    kernel_filter: optional fnmatch pattern; None or "" disables filtering.
    receipt: integration id selecting which kernel variants are enabled.
    optdim_list: hdim whitelist, or [-1] to allow all hdims.
    mask_impl: mask implementation key passed through to the mask maps.
    Returns (api_pool, kernels).
    """

    def receipt_allows(dtype, mode, pipeline) -> bool:
        # Per-integration enable conditions; unknown receipts add no restriction.
        if receipt in (2, 3):  # Flash attention integration
            return dtype in ["fp16", "bf16"] and pipeline.F_vlayout == "row"
        if receipt == 4:  # PyTorch integration
            return (
                dtype in ["fp16", "bf16"]
                and pipeline.F_vlayout == "row"
                and mode == "batch"
                and pipeline.F_logits == "f"
            )
        if receipt == 100:  # Aiter(mha_fwd) integration
            return (
                dtype in ["fp16", "bf16"]
                and mode == "batch"
                and pipeline.F_vlayout == "row"
            )
        if receipt == 200:  # Aiter(mha_varlen_fwd) integration
            return (
                dtype in ["fp16", "bf16"]
                and mode == "group"
                and pipeline.F_vlayout == "row"
            )
        if receipt == 600:  # aiter::mha_fwd C++ api integration
            return dtype in ["fp16", "bf16"] and pipeline.F_vlayout == "row"
        return True

    gen = list()
    api_pool = FmhaFwdApiPool(mask_impl)
    factory = (
        CustomFactory
        if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1"
        else KernelComponentFactory
    )
    # Only generate fp16/bf16 kernels for now.
    # NOTE: VSA sparse attention only supports batch mode (group mode NOT
    # supported, enforced by static_assert)
    for dtype in ["fp16", "bf16"]:
        d = factory.get_hdim_tile_size_dict(dtype)
        if d is None:
            continue
        for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]):
            for tile, pipeline in itertools.product(
                tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
            ):
                if pipeline.tag != "qr_async_vsa":
                    continue
                k = FmhaFwdKernel(
                    F_idx=1,
                    F_hdim=hdim,
                    F_dtype=dtype,
                    F_mode=mode,
                    F_tile=tile,
                    F_pipeline=pipeline,
                    mask_impl=mask_impl,
                )
                # BUGFIX: kernel_filter is Optional; the previous `!= ""` check
                # passed None straight into fnmatch and raised TypeError.
                if kernel_filter:
                    if not fnmatch.fnmatch(k.name, kernel_filter):
                        continue
                if optdim_list != [-1] and hdim not in optdim_list:
                    continue
                if not receipt_allows(dtype, mode, pipeline):
                    continue
                api_pool.register_traits(k.api_trait())
                gen.append(k)
    return (api_pool, gen)
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
    """Write (or refresh) the generated .cpp source for one kernel instance."""
    target = autogen_dir / kernel.filename
    update_file(target, kernel.template)
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
    """Write (or refresh) the generated dispatch-API source file."""
    target = autogen_dir / FMHA_FWD_API_FILENAME
    update_file(target, api_pool.api)
def write_blobs(
    output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
) -> None:
    """Generate every kernel source plus the API dispatcher into output_dir."""
    api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
    for k in kernels:
        write_single_fwd_kernel(k, output_dir)
    write_fwd_api(api_pool, output_dir)
def list_blobs(
    file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
) -> None:
    """Append the paths of every file that will be generated to file_path."""
    _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
    gen_root = file_path.parent / GEN_DIR
    filenames = [k.filename for k in kernels] + [FMHA_FWD_API_FILENAME]
    with file_path.open("a") as f:
        for fname in filenames:
            f.write((gen_root / fname).as_posix() + "\n")

View File

@@ -277,13 +277,13 @@ struct fmha_jenga_fwd_traits
float fmha_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&);
// sparge jenga
float sparge_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&);
template <typename Traits_>
float fmha_jenga_fwd_(const ck_tile::stream_config&, fmha_jenga_fwd_args);
float fmha_jenga_fwd(fmha_jenga_fwd_args, const ck_tile::stream_config&);
template <typename Traits_>
void fmha_jenga_fwd_oneshot_(const ck_tile::stream_config&, fmha_jenga_fwd_args);
void fmha_jenga_fwd_oneshot(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&);
// VSA uses the same traits structure as Jenga; aliases for clarity
template <ck_tile::index_t HDim_,
@@ -325,10 +325,10 @@ using fmha_vsa_fwd_traits = fmha_jenga_fwd_traits;
float fmha_vsa_fwd(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&);
// sparge vsa
float sparge_vsa_fwd(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&);
template <typename Traits_>
float fmha_vsa_fwd_(const ck_tile::stream_config&, fmha_vsa_fwd_args);
float fmha_vsa_fwd(fmha_vsa_fwd_args, const ck_tile::stream_config&);
template <typename Traits_>
void fmha_vsa_fwd_oneshot_(const ck_tile::stream_config&, fmha_vsa_fwd_args);
void fmha_vsa_fwd_oneshot(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&);

View File

@@ -1,189 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "jenga_sparge_attention.h"
#include "fmha_fwd_trek.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/device_memory.hpp"
#include <type_traits>
// Host-side convenience wrapper for the Sparge->Jenga sparse-attention
// forward pass: copies Q/K/V and the one-hot block-relation mask to the
// device, launches the kernel, and copies the attention output back into Y.
//
// Layout: i_perm selects BHSD (true) vs BSHD (false) for Q/K/V; o_perm does
// the same for the output. max_seqlen_q == 0 means "use seqlen_q".
// Returns Y (also filled in place).
//
// NOTE(review): max_seqlen_k is defaulted from seqlen_k but never forwarded
// into the kernel args below — confirm whether it is needed at all.
template <typename DataType_>
ck_tile::HostTensor<DataType_>
jenga_sparge_attention(const ck_tile::HostTensor<DataType_>& TQ,
                       const ck_tile::HostTensor<DataType_>& TK,
                       const ck_tile::HostTensor<DataType_>& TV,
                       const ck_tile::HostTensor<uint8_t>& Tblock_relation_onehot,
                       ck_tile::HostTensor<DataType_>& Y,
                       int batch,
                       int nhead,
                       int nhead_k,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       bool i_perm,
                       bool o_perm,
                       int max_seqlen_q,
                       int max_seqlen_k,
                       int log_level)
{
    static_assert(std::is_same_v<DataType_, ck_tile::half_t> ||
                      std::is_same_v<DataType_, ck_tile::bf16_t>,
                  "Jenga sparse attention supports fp16/bf16 only.");
    std::string data_type = "fp16";
    if constexpr(std::is_same_v<DataType_, ck_tile::bf16_t>)
    {
        data_type = "bf16";
    }
    // 0 is a sentinel meaning "same as the actual sequence length".
    if(max_seqlen_q == 0)
        max_seqlen_q = seqlen_q;
    if(max_seqlen_k == 0)
        max_seqlen_k = seqlen_k;
    bool is_v_rowmajor = true;
    // Standard softmax scaling 1/sqrt(d). Use a float literal so no double
    // intermediate is produced and then narrowed (was: 1.0 / ...).
    float scale_s = 1.0f / ck_tile::sqrt(static_cast<float>(hdim_q));
    // "0" decodes to no-mask; window sizes/mask type are forwarded below.
    std::string msk_str = "0";
    mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k);
    const ck_tile::index_t shape_seqlen_q = seqlen_q;
    const ck_tile::index_t shape_seqlen_k = seqlen_k;
    // Default stream, no timing; only the log level is configurable here.
    ck_tile::stream_config stream_config{nullptr,
                                         false, // time_kernel
                                         log_level,
                                         0,
                                         1,
                                         false};
    // Device staging buffers for inputs, mask, and output.
    ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes());
    ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes());
    ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes());
    ck_tile::DeviceMem block_relation_buf(Tblock_relation_onehot.get_element_space_size_in_bytes());
    ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes());
    q_buf.ToDevice(TQ.data());
    k_buf.ToDevice(TK.data());
    v_buf.ToDevice(TV.data());
    block_relation_buf.ToDevice(Tblock_relation_onehot.data());
    // Derive element/head/batch strides from the chosen layouts and fill the
    // kernel argument struct.
    const auto init_args = [&](auto& args) {
        // GQA requires the q-head count to be a multiple of the kv-head count.
        assert(nhead % nhead_k == 0);
        const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
        const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
        const ck_tile::index_t stride_v = [&]() {
            if(is_v_rowmajor)
                return i_perm ? hdim_v : nhead_k * hdim_v;
            else
                return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
        }();
        const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
        const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
        const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q;
        const ck_tile::index_t nhead_stride_v = [&]() {
            if(is_v_rowmajor)
                return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
            else
                return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
        }();
        const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
        const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
        const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q;
        const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k;
        const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
        args.q_ptr = q_buf.GetDeviceBuffer();
        args.k_ptr = k_buf.GetDeviceBuffer();
        args.v_ptr = v_buf.GetDeviceBuffer();
        args.block_relation_onehot_ptr = block_relation_buf.GetDeviceBuffer();
        args.batch = batch;
        args.seqlen_q = shape_seqlen_q;
        args.hdim_q = hdim_q;
        args.hdim_v = hdim_v;
        args.nhead_q = nhead;
        args.nhead_k = nhead_k;
        args.stride_q = stride_q;
        args.stride_k = stride_k;
        args.stride_v = stride_v;
        args.nhead_stride_q = nhead_stride_q;
        args.nhead_stride_k = nhead_stride_k;
        args.nhead_stride_v = nhead_stride_v;
        args.batch_stride_q = batch_stride_q;
        args.batch_stride_k = batch_stride_k;
        args.batch_stride_v = batch_stride_v;
        args.o_ptr = o_buf.GetDeviceBuffer();
        args.seqlen_k = shape_seqlen_k;
        args.max_seqlen_q = max_seqlen_q;
        args.scale_s = scale_s;
        args.stride_o = stride_o;
        args.nhead_stride_o = nhead_stride_o;
        args.batch_stride_o = batch_stride_o;
        args.window_size_left = mask.left;
        args.window_size_right = mask.right;
        args.mask_type = static_cast<ck_tile::index_t>(mask.type);
    };
    const auto init_traits = [&](auto& traits) {
        traits.hdim_q = hdim_q;
        traits.hdim_v = hdim_v;
        traits.data_type = data_type;
        traits.is_v_rowmajor = is_v_rowmajor;
        traits.mask_type = mask.type;
    };
    fmha_jenga_fwd_traits fmha_traits;
    init_traits(fmha_traits);
    fmha_jenga_fwd_args args;
    init_args(args);
    sparge_jenga_fwd(fmha_traits, args, stream_config);
    // Bring the kernel output back into the caller-provided host tensor.
    o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes());
    return Y;
}
// Explicit instantiations: the template definition lives in this .cpp, so the
// two supported element types (fp16 / bf16) are instantiated here once.
template ck_tile::HostTensor<ck_tile::half_t>
jenga_sparge_attention<ck_tile::half_t>(const ck_tile::HostTensor<ck_tile::half_t>&,
                                        const ck_tile::HostTensor<ck_tile::half_t>&,
                                        const ck_tile::HostTensor<ck_tile::half_t>&,
                                        const ck_tile::HostTensor<uint8_t>&,
                                        ck_tile::HostTensor<ck_tile::half_t>&,
                                        int,
                                        int,
                                        int,
                                        int,
                                        int,
                                        int,
                                        int,
                                        bool,
                                        bool,
                                        int,
                                        int,
                                        int);
template ck_tile::HostTensor<ck_tile::bf16_t>
jenga_sparge_attention<ck_tile::bf16_t>(const ck_tile::HostTensor<ck_tile::bf16_t>&,
                                        const ck_tile::HostTensor<ck_tile::bf16_t>&,
                                        const ck_tile::HostTensor<ck_tile::bf16_t>&,
                                        const ck_tile::HostTensor<uint8_t>&,
                                        ck_tile::HostTensor<ck_tile::bf16_t>&,
                                        int,
                                        int,
                                        int,
                                        int,
                                        int,
                                        int,
                                        int,
                                        bool,
                                        bool,
                                        int,
                                        int,
                                        int);

View File

@@ -1,27 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <optional>
#include <cstdint>
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
// Host-side wrapper around the Sparge->Jenga sparse-attention forward kernel.
// Copies Q/K/V and the one-hot block-relation mask to the device, runs the
// kernel, and writes the result into Y (also returned).
// Layouts: i_perm selects BHSD (true) / BSHD (false) inputs; o_perm likewise
// for the output. max_seqlen_q/max_seqlen_k of 0 mean "use seqlen_q/seqlen_k".
// Instantiated for ck_tile::half_t and ck_tile::bf16_t only (see the .cpp).
template <typename DataType_>
ck_tile::HostTensor<DataType_>
jenga_sparge_attention(const ck_tile::HostTensor<DataType_>& TQ,
                       const ck_tile::HostTensor<DataType_>& TK,
                       const ck_tile::HostTensor<DataType_>& TV,
                       const ck_tile::HostTensor<uint8_t>& Tblock_relation_onehot,
                       ck_tile::HostTensor<DataType_>& Y,
                       int batch,
                       int nhead,
                       int nhead_k,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       bool i_perm,
                       bool o_perm,
                       int max_seqlen_q,
                       int max_seqlen_k,
                       int log_level = 0);

View File

@@ -61,6 +61,57 @@ using bmap_fp16_problem = ck_tile::BlockFmhaPipelineProblem<ck_tile::half_t, //
using bmap_fp16_pipeline = ck_tile::SpargeBlockMapPipeline<bmap_fp16_problem>;
using bmap_fp16_kernel = ck_tile::SpargeBlockMapKernel<bmap_fp16_pipeline>;
// ============================================================================
// bf16: D=128, kM0=64, kN0=128
// ============================================================================
// Mirrors the fp16 configuration above with bf16 element types; the tile
// shape, warp layout, and pipeline/kernel wiring are otherwise identical.
using bmap_bf16_block_tile = ck_tile::sequence<64, 128, 128, 128, 128, 128>;
using bmap_bf16_shape =
    ck_tile::TileFmhaShape<bmap_bf16_block_tile,
                           ck_tile::sequence<4, 1, 1>,
                           ck_tile::sequence<16, 16, 16>,
                           ck_tile::sequence<4, 1, 1>,
                           ck_tile::sequence<16, 16, 16>,
                           true>;
using bmap_bf16_trait = ck_tile::TileFmhaTraits<true, // kPadSeqLenQ
                                                true, // kPadSeqLenK
                                                true, // kPadHeadDimQ
                                                true, // kPadHeadDimV
                                                false, // kHasLogitsSoftCap
                                                ck_tile::BlockAttentionBiasEnum::NO_BIAS,
                                                false, // kStoreLSE
                                                false, // kHasDropout
                                                false, // kHasRandVal
                                                ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE,
                                                -1,
                                                false>;
using bmap_bf16_variant = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>;
using bmap_bf16_mask = ck_tile::GenericAttentionMask<false>;
using bmap_bf16_problem = ck_tile::BlockFmhaPipelineProblem<ck_tile::bf16_t, // QDataType
                                                            ck_tile::bf16_t, // KDataType
                                                            ck_tile::bf16_t, // VDataType
                                                            float, // SaccDataType
                                                            float, // SMPLComputeDataType
                                                            ck_tile::bf16_t, // BiasDataType
                                                            uint8_t, // RandValOutputDataType
                                                            float, // LSEDataType
                                                            ck_tile::bf16_t, // PDataType
                                                            float, // OaccDataType
                                                            ck_tile::bf16_t, // ODataType
                                                            bmap_bf16_shape,
                                                            false, // kIsGroupMode
                                                            bmap_bf16_variant,
                                                            bmap_bf16_mask,
                                                            false, // kUseTrLoad
                                                            bmap_bf16_trait>;
using bmap_bf16_pipeline = ck_tile::SpargeBlockMapPipeline<bmap_bf16_problem>;
using bmap_bf16_kernel = ck_tile::SpargeBlockMapKernel<bmap_bf16_pipeline>;
// ============================================================================
// Dispatch
// ============================================================================
@@ -81,8 +132,96 @@ float sparge_blockmap_fwd(sparge_blockmap_traits traits,
s, ck_tile::make_kernel<kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
}
if(traits.data_type == "bf16" && traits.hdim_q == 128)
{
using k_ = bmap_bf16_kernel;
if(s.log_level_ > 0)
std::cout << ", sparge_blockmap_bf16_d128" << std::flush;
auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids<k_>(args);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
}
if(s.log_level_ > 0)
std::cerr << "sparge_blockmap_fwd: unsupported config (data_type=" << traits.data_type
<< ", hdim_q=" << traits.hdim_q << ")" << std::endl;
return -1.f;
}
// ============================================================================
// Oneshot version: launches kernel without timing wrapper
// ============================================================================
// Launches the block-map kernel directly on the caller's stream with no
// timing wrapper; used as a building block for the fused (combined) paths.
// Prints to stderr and does nothing when the (data_type, hdim_q) pair has no
// generated instance.
void sparge_blockmap_fwd_oneshot(sparge_blockmap_traits traits,
                                 sparge_blockmap_args args,
                                 const ck_tile::stream_config& s)
{
    const bool is_d128 = (traits.hdim_q == 128);
    if(is_d128 && traits.data_type == "fp16")
    {
        using kernel_t = bmap_fp16_kernel;
        auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids<kernel_t>(args);
        ck_tile::make_kernel<kernel_t::kBlockPerCu>(
            kernel_t{}, grids, kernel_t::BlockSize(), 0, kargs)(
            ck_tile::stream_config{s.stream_id_});
        return;
    }
    if(is_d128 && traits.data_type == "bf16")
    {
        using kernel_t = bmap_bf16_kernel;
        auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids<kernel_t>(args);
        ck_tile::make_kernel<kernel_t::kBlockPerCu>(
            kernel_t{}, grids, kernel_t::BlockSize(), 0, kargs)(
            ck_tile::stream_config{s.stream_id_});
        return;
    }
    std::cerr << "sparge_blockmap_fwd_oneshot: unsupported config (data_type=" << traits.data_type
              << ", hdim_q=" << traits.hdim_q << ")" << std::endl;
}
// ============================================================================
// Combined functions: blockmap + attention timed together via launch_kernel
// ============================================================================
// Fused path: runs the block-map pass followed by the Jenga attention pass,
// with both launches timed together by a single launch_kernel call.
// Returns the value reported by ck_tile::launch_kernel.
float sparge_jenga_fwd(sparge_blockmap_traits bmap_t, sparge_blockmap_args bmap_a,
                       fmha_jenga_fwd_traits attn_t, fmha_jenga_fwd_args attn_a,
                       const ck_tile::stream_config& s)
{
    if(s.log_level_ > 0)
    {
        std::cout << ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q
                  << ", fmha_jenga_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q
                  << std::flush;
    }
    // Stage 1: build the block map on the measurement stream.
    const auto launch_blockmap = [=](const ck_tile::stream_config& cfg) {
        sparge_blockmap_fwd_oneshot(bmap_t, bmap_a, cfg);
    };
    // Stage 2: consume it in the Jenga sparse-attention kernel.
    const auto launch_attention = [=](const ck_tile::stream_config& cfg) {
        fmha_jenga_fwd_oneshot(attn_t, attn_a, cfg);
    };
    return ck_tile::launch_kernel(s, launch_blockmap, launch_attention);
}
// Fused path: block-map pass followed by the VSA attention pass, timed as one
// unit by launch_kernel. Returns the value reported by ck_tile::launch_kernel.
float sparge_vsa_fwd_combined(sparge_blockmap_traits bmap_t, sparge_blockmap_args bmap_a,
                              fmha_vsa_fwd_traits attn_t, fmha_vsa_fwd_args attn_a,
                              const ck_tile::stream_config& s)
{
    if(s.log_level_ > 0)
    {
        std::cout << ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q
                  << ", fmha_vsa_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q
                  << std::flush;
    }
    // Stage 1: block-map generation; stage 2: VSA sparse attention.
    const auto launch_blockmap = [=](const ck_tile::stream_config& cfg) {
        sparge_blockmap_fwd_oneshot(bmap_t, bmap_a, cfg);
    };
    const auto launch_attention = [=](const ck_tile::stream_config& cfg) {
        fmha_vsa_fwd_oneshot(attn_t, attn_a, cfg);
    };
    return ck_tile::launch_kernel(s, launch_blockmap, launch_attention);
}

View File

@@ -91,3 +91,16 @@ auto sparge_blockmap_create_kargs_and_grids(sparge_blockmap_args args)
// Timed single-kernel launch of the block-map pass. Returns the value from
// ck_tile::launch_kernel (presumably elapsed ms — confirm), or -1 when the
// (data_type, hdim_q) combination has no generated instance.
float sparge_blockmap_fwd(sparge_blockmap_traits traits,
                          sparge_blockmap_args args,
                          const ck_tile::stream_config& stream_config);
// Untimed launch on the caller's stream; building block for the combined
// (fused) entry points below.
void sparge_blockmap_fwd_oneshot(sparge_blockmap_traits traits,
                                 sparge_blockmap_args args,
                                 const ck_tile::stream_config& stream_config);
// Combined functions: blockmap + attention with unified timing
float sparge_jenga_fwd(sparge_blockmap_traits, sparge_blockmap_args,
                       fmha_jenga_fwd_traits, fmha_jenga_fwd_args,
                       const ck_tile::stream_config&);
float sparge_vsa_fwd_combined(sparge_blockmap_traits, sparge_blockmap_args,
                              fmha_vsa_fwd_traits, fmha_vsa_fwd_args,
                              const ck_tile::stream_config&);

View File

@@ -0,0 +1,432 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Unified test for Sparge pipeline: blockmap generation + sparse attention (Jenga/VSA).
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <random>
#include <string>
#include <vector>
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "fmha_fwd_trek.hpp"
#include "sparge_blockmap_trek.hpp"
#include "sparge_tool.hpp"
// ============================================================================
// Helpers
// ============================================================================
// Allocates a Q/K/V host tensor: BHSD layout when i_perm is set, BSHD otherwise.
template <typename T>
ck_tile::HostTensor<T>
make_qkv_tensor(ck_tile::index_t batch, ck_tile::index_t nhead, ck_tile::index_t seqlen, ck_tile::index_t hdim, bool i_perm)
{
    return i_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen, hdim})
                  : ck_tile::HostTensor<T>({batch, seqlen, nhead, hdim});
}
// Copies a 4-D host tensor into canonical BHSD order. is_bhsd tells whether
// the input is already BHSD (copy through) or BSHD (transpose head/seq axes).
template <typename T>
ck_tile::HostTensor<T> to_bhsd(const ck_tile::HostTensor<T>& tensor, bool is_bhsd)
{
    const auto lens = tensor.get_lengths();
    const ck_tile::index_t nb = lens[0];
    const ck_tile::index_t nh = is_bhsd ? lens[1] : lens[2];
    const ck_tile::index_t ns = is_bhsd ? lens[2] : lens[1];
    const ck_tile::index_t nd = lens[3];
    ck_tile::HostTensor<T> out({nb, nh, ns, nd});
    for(ck_tile::index_t ib = 0; ib < nb; ++ib)
        for(ck_tile::index_t ih = 0; ih < nh; ++ih)
            for(ck_tile::index_t is = 0; is < ns; ++is)
                for(ck_tile::index_t id = 0; id < nd; ++id)
                    out(ib, ih, is, id) =
                        is_bhsd ? tensor(ib, ih, is, id) : tensor(ib, is, ih, id);
    return out;
}
template <typename T>
auto get_error_tolerance()
{
double rtol = 1e-2;
double atol = 4e-2;
if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
{
atol = 2e-1;
rtol = 2e-1;
}
return ck_tile::make_tuple(rtol, atol);
}
// Widens a device element type to float for CPU-side comparison.
template <typename T>
float to_float_for_compare(T value)
{
    return static_cast<float>(value);
}
// bf16 specialization: without custom data types, bf16 is a raw 16-bit
// pattern, so convert explicitly via the raw-bit helper.
template <>
float to_float_for_compare<ck_tile::bf16_t>(ck_tile::bf16_t value)
{
#if CK_TILE_USE_CUSTOM_DATA_TYPE
    return static_cast<float>(value);
#else
    return ck_tile::bf16_to_float_raw(ck_tile::bit_cast<ck_tile::bf16_raw_t>(value));
#endif
}
// ============================================================================
// Arg parser
// ============================================================================
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser
.insert("v", "1", "0:no validation, 1:cpu validation")
.insert("pipeline", "jenga", "attention pipeline: jenga / vsa")
.insert("b", "1", "batch size")
.insert("h", "4", "num of head for q")
.insert("h_k", "-1", "num of head for k/v, -1 means equal to h")
.insert("s", "4096", "seqlen_q")
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
.insert("d", "128", "head dim for q, k")
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
.insert("topk", "0.3", "topk ratio for blockmap (fraction of K-blocks to keep)")
.insert("cdfthreshd", "-1", "CDF threshold for blockmap (overrides topk if >= 0)")
.insert("simthreshd1", "0.6", "similarity threshold for blockmap")
.insert("prec", "fp16", "data type: fp16/bf16")
.insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d")
.insert("operm", "1", "permute output")
.insert("seed", "42", "random seed")
.insert("warmup", "5", "warmup iterations")
.insert("repeat", "20", "benchmark iterations")
.insert("kname", "0", "print kernel name");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// ============================================================================
// Main test
// ============================================================================
// Runs the full Sparge pipeline once for element type T:
//   1. builds blockmap + attention arguments from CLI values,
//   2. launches the fused blockmap+attention path ("jenga" or "vsa"),
//   3. reports timing/sparsity and optionally validates against the CPU
//      blocked-attention reference.
// Returns true when validation passes (or is disabled / the config is skipped).
template <typename T>
bool run_test(const ck_tile::ArgParser& arg_parser)
{
    int do_validation = arg_parser.get_int("v");
    std::string pipeline = arg_parser.get_str("pipeline");
    ck_tile::index_t batch = arg_parser.get_int("b");
    ck_tile::index_t nhead = arg_parser.get_int("h");
    ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
    ck_tile::index_t seqlen_q = arg_parser.get_int("s");
    ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
    ck_tile::index_t hdim_q = arg_parser.get_int("d");
    ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
    float topk = arg_parser.get_float("topk");
    float cdfthreshd = arg_parser.get_float("cdfthreshd");
    float simthreshd1 = arg_parser.get_float("simthreshd1");
    bool i_perm = arg_parser.get_bool("iperm");
    bool o_perm = arg_parser.get_bool("operm");
    uint32_t seed = arg_parser.get_uint32("seed");
    int warmup = arg_parser.get_int("warmup");
    int repeat = arg_parser.get_int("repeat");
    int kname = arg_parser.get_int("kname");
    // -1 sentinels mean "same as the corresponding q-side value".
    if(nhead_k < 0) nhead_k = nhead;
    if(seqlen_k < 0) seqlen_k = seqlen_q;
    if(hdim_v < 0) hdim_v = hdim_q;
    // If cdfthreshd >= 0, use CDF mode; otherwise use topk mode
    if(cdfthreshd >= 0.0f)
        topk = -1.0f;
    // Block sizes fixed by the generated kernel instances (64x128 tiles).
    constexpr ck_tile::index_t BLKQ = 64;
    constexpr ck_tile::index_t BLKK = 128;
    if(hdim_q != 128 || hdim_v != 128)
    {
        std::cout << "\n>>> TEST SKIPPED <<<\n"
                  << "Kernel instances are generated for hdim=128 only.\n";
        return true;
    }
    ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ;
    ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK;
    std::string prec_str = std::is_same_v<T, ck_tile::half_t> ? "fp16" : "bf16";
    std::cout << "[" << pipeline << "|" << prec_str
              << "] b=" << batch << " h=" << nhead << " s=" << seqlen_q
              << " d=" << hdim_q << " topk=" << topk
              << " sim1=" << simthreshd1 << std::flush;
    // ---- allocate host tensors ----
    auto q_host = make_qkv_tensor<T>(batch, nhead, seqlen_q, hdim_q, i_perm);
    auto k_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
    auto v_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
    auto output_host = o_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen_q, hdim_v})
                              : ck_tile::HostTensor<T>({batch, seqlen_q, nhead, hdim_v});
    // Blockmap outputs: one-hot map always; LUT + per-row valid counts are
    // only consumed by the VSA pipeline (see bmap_args below).
    ck_tile::HostTensor<uint8_t> block_map_host({batch, nhead, num_q_blocks, num_k_blocks});
    ck_tile::HostTensor<int32_t> lut_host({batch, nhead, num_q_blocks, num_k_blocks});
    ck_tile::HostTensor<int32_t> valid_block_num_host({batch, nhead, num_q_blocks});
    ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed}(q_host);
    ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 1}(k_host);
    ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 2}(v_host);
    // ---- device tensors ----
    ck_tile::DeviceMem q_dev(q_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem k_dev(k_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem v_dev(v_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem o_dev(output_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem block_map_dev(block_map_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem lut_dev(lut_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem valid_bn_dev(valid_block_num_host.get_element_space_size_in_bytes());
    q_dev.ToDevice(q_host.data());
    k_dev.ToDevice(k_host.data());
    v_dev.ToDevice(v_host.data());
    o_dev.SetZero();
    block_map_dev.SetZero();
    lut_dev.SetZero();
    valid_bn_dev.SetZero();
    // ---- strides (BHSD when i_perm=true) ----
    auto q_strides = q_host.get_strides();
    auto k_strides = k_host.get_strides();
    auto v_strides = v_host.get_strides();
    auto o_strides = output_host.get_strides();
    float scale_s = 1.0f / std::sqrt(static_cast<float>(hdim_q));
    // ---- build blockmap args ----
    sparge_blockmap_traits bmap_traits;
    bmap_traits.data_type = std::is_same_v<T, ck_tile::half_t> ? "fp16" : "bf16";
    bmap_traits.hdim_q = hdim_q;
    sparge_blockmap_args bmap_args;
    bmap_args.q_ptr = q_dev.GetDeviceBuffer();
    bmap_args.k_ptr = k_dev.GetDeviceBuffer();
    bmap_args.batch = batch;
    bmap_args.seqlen_q = seqlen_q;
    bmap_args.seqlen_k = seqlen_k;
    bmap_args.hdim_q = hdim_q;
    bmap_args.nhead_q = nhead;
    bmap_args.nhead_k = nhead_k;
    bmap_args.stride_q = q_strides[i_perm ? 2 : 1];
    bmap_args.stride_k = k_strides[i_perm ? 2 : 1];
    bmap_args.nhead_stride_q = q_strides[i_perm ? 1 : 2];
    bmap_args.nhead_stride_k = k_strides[i_perm ? 1 : 2];
    bmap_args.batch_stride_q = q_strides[0];
    bmap_args.batch_stride_k = k_strides[0];
    bmap_args.simthreshd1 = simthreshd1;
    // topk < 0 selects CDF-threshold mode (see the cdfthreshd handling above).
    bmap_args.cdfthreshd = (topk < 0.0f) ? cdfthreshd : -1.0f;
    bmap_args.topk = topk;
    bmap_args.scale = scale_s;
    bmap_args.block_map_ptr = block_map_dev.GetDeviceBuffer();
    bmap_args.lut_ptr = (pipeline == "vsa") ? lut_dev.GetDeviceBuffer() : nullptr;
    bmap_args.valid_block_num_ptr = (pipeline == "vsa") ? valid_bn_dev.GetDeviceBuffer() : nullptr;
    // ---- build attention args ----
    ck_tile::stream_config stream_cfg;
    stream_cfg.stream_id_ = nullptr;
    stream_cfg.time_kernel_ = true;
    stream_cfg.log_level_ = kname;
    stream_cfg.cold_niters_ = warmup;
    stream_cfg.nrepeat_ = repeat;
    float avg_ms = -1.0f;
    if(pipeline == "jenga")
    {
        fmha_jenga_fwd_traits attn_traits;
        attn_traits.hdim_q = hdim_q;
        attn_traits.hdim_v = hdim_v;
        attn_traits.data_type = std::is_same_v<T, ck_tile::half_t> ? "fp16" : "bf16";
        attn_traits.is_v_rowmajor = true;
        attn_traits.mask_type = mask_enum::no_mask;
        fmha_jenga_fwd_args attn_args;
        attn_args.q_ptr = q_dev.GetDeviceBuffer();
        attn_args.k_ptr = k_dev.GetDeviceBuffer();
        attn_args.v_ptr = v_dev.GetDeviceBuffer();
        // Jenga consumes the one-hot block map produced by the blockmap pass.
        attn_args.block_relation_onehot_ptr = block_map_dev.GetDeviceBuffer();
        attn_args.o_ptr = o_dev.GetDeviceBuffer();
        attn_args.seqlen_q = seqlen_q;
        attn_args.seqlen_k = seqlen_k;
        attn_args.batch = batch;
        attn_args.max_seqlen_q = seqlen_q;
        attn_args.hdim_q = hdim_q;
        attn_args.hdim_v = hdim_v;
        attn_args.nhead_q = nhead;
        attn_args.nhead_k = nhead_k;
        attn_args.scale_s = scale_s;
        attn_args.stride_q = q_strides[i_perm ? 2 : 1];
        attn_args.stride_k = k_strides[i_perm ? 2 : 1];
        attn_args.stride_v = v_strides[i_perm ? 2 : 1];
        attn_args.stride_o = o_strides[o_perm ? 2 : 1];
        attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2];
        attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2];
        attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2];
        attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2];
        attn_args.batch_stride_q = q_strides[0];
        attn_args.batch_stride_k = k_strides[0];
        attn_args.batch_stride_v = v_strides[0];
        attn_args.batch_stride_o = o_strides[0];
        // -1/-1 window with mask_type 0: no masking.
        attn_args.window_size_left = -1;
        attn_args.window_size_right = -1;
        attn_args.mask_type = 0;
        avg_ms = sparge_jenga_fwd(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg);
    }
    else if(pipeline == "vsa")
    {
        fmha_vsa_fwd_traits attn_traits;
        attn_traits.hdim_q = hdim_q;
        attn_traits.hdim_v = hdim_v;
        attn_traits.data_type = std::is_same_v<T, ck_tile::half_t> ? "fp16" : "bf16";
        attn_traits.is_v_rowmajor = true;
        attn_traits.mask_type = mask_enum::no_mask;
        fmha_vsa_fwd_args attn_args;
        attn_args.q_ptr = q_dev.GetDeviceBuffer();
        attn_args.k_ptr = k_dev.GetDeviceBuffer();
        attn_args.v_ptr = v_dev.GetDeviceBuffer();
        // VSA consumes the LUT + per-row valid-block counts instead of the
        // one-hot map.
        attn_args.lut_ptr = lut_dev.GetDeviceBuffer();
        attn_args.valid_block_num_ptr = valid_bn_dev.GetDeviceBuffer();
        attn_args.o_ptr = o_dev.GetDeviceBuffer();
        attn_args.seqlen_q = seqlen_q;
        attn_args.seqlen_k = seqlen_k;
        attn_args.batch = batch;
        attn_args.max_seqlen_q = seqlen_q;
        attn_args.hdim_q = hdim_q;
        attn_args.hdim_v = hdim_v;
        attn_args.nhead_q = nhead;
        attn_args.nhead_k = nhead_k;
        attn_args.scale_s = scale_s;
        attn_args.stride_q = q_strides[i_perm ? 2 : 1];
        attn_args.stride_k = k_strides[i_perm ? 2 : 1];
        attn_args.stride_v = v_strides[i_perm ? 2 : 1];
        attn_args.stride_o = o_strides[o_perm ? 2 : 1];
        attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2];
        attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2];
        attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2];
        attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2];
        attn_args.batch_stride_q = q_strides[0];
        attn_args.batch_stride_k = k_strides[0];
        attn_args.batch_stride_v = v_strides[0];
        attn_args.batch_stride_o = o_strides[0];
        attn_args.window_size_left = -1;
        attn_args.window_size_right = -1;
        attn_args.mask_type = 0;
        avg_ms = sparge_vsa_fwd_combined(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg);
    }
    else
    {
        std::cerr << "Unknown pipeline: " << pipeline << " (use jenga or vsa)\n";
        return false;
    }
    // ---- TFLOPS calculation (dense FMHA formula, so sparsity gains show as higher TFLOPS) ----
    std::size_t flop = static_cast<std::size_t>(batch) * nhead *
                       (static_cast<std::size_t>(2) * seqlen_q * seqlen_k * hdim_q +
                        static_cast<std::size_t>(2) * seqlen_q * seqlen_k * hdim_v);
    float tflops = (avg_ms > 0.f) ? static_cast<float>(flop) / 1.E9f / avg_ms : 0.f;
    if(avg_ms > 0.f)
    {
        std::cout << std::fixed << ", " << std::setprecision(3) << avg_ms << " ms, "
                  << std::setprecision(2) << tflops << " TFlops" << std::flush;
    }
    // ---- copy results back ----
    o_dev.FromDevice(output_host.data());
    block_map_dev.FromDevice(block_map_host.data());
    // ---- count active blocks ----
    ck_tile::index_t total_blocks = batch * nhead * num_q_blocks * num_k_blocks;
    ck_tile::index_t active_blocks = 0;
    for(size_t i = 0; i < block_map_host.mData.size(); ++i)
        if(block_map_host.mData[i])
            active_blocks++;
    float actual_sparsity = 1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
    std::cout << ", sparsity=" << std::setprecision(2) << actual_sparsity
              << "(" << active_blocks << "/" << total_blocks << ")" << std::flush;
    // ---- validation ----
    bool pass = true;
    if(do_validation)
    {
        // The CPU reference works in canonical BHSD order regardless of the
        // input permutation.
        auto q_ref = to_bhsd(q_host, i_perm);
        auto k_ref = to_bhsd(k_host, i_perm);
        auto v_ref = to_bhsd(v_host, i_perm);
        ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});
        ck_tile::reference_blocked_attention<T, uint8_t>(
            q_ref, k_ref, v_ref, block_map_host, output_ref, BLKQ, BLKK, scale_s);
        auto [rtol, atol] = get_error_tolerance<T>();
        float max_diff = 0.0f;
        size_t num_errors = 0;
        auto output_host_bhsd = to_bhsd(output_host, o_perm);
        for(size_t i = 0; i < output_host_bhsd.mData.size(); ++i)
        {
            float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]);
            float ref_val = to_float_for_compare(output_ref.mData[i]);
            float diff = std::abs(gpu_val - ref_val);
            float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff;
            max_diff = std::max(max_diff, diff);
            // An element only counts as an error when it misses BOTH bounds.
            if(diff > atol && rel_diff > rtol)
                num_errors++;
        }
        pass = (num_errors == 0);
        std::cout << ", " << (pass ? "PASS" : "FAIL")
                  << "(err=" << num_errors << "/" << output_host_bhsd.mData.size()
                  << " maxdiff=" << max_diff << ")";
    }
    std::cout << std::endl;
    return pass;
}
// ============================================================================
// Main
// ============================================================================
// Entry point: parses CLI arguments, dispatches run_test on the requested
// precision, and maps the boolean result to a process exit code.
int main(int argc, char* argv[])
{
    auto [parsed_ok, arg_parser] = create_args(argc, argv);
    if(!parsed_ok)
    {
        std::cerr << "Failed to parse arguments\n";
        return -1;
    }
    const std::string prec = arg_parser.get_str("prec");
    bool ok = false;
    if(prec == "fp16")
        ok = run_test<ck_tile::half_t>(arg_parser);
    else if(prec == "bf16")
        ok = run_test<ck_tile::bf16_t>(arg_parser);
    else
    {
        std::cerr << "Unsupported precision: " << prec << "\n";
        return -1;
    }
    return ok ? 0 : -1;
}

View File

@@ -1,422 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Demo: Sparge block-map -> Jenga sparse attention
#include <iostream>
#include <vector>
#include <cmath>
#include <random>
#include <string>
#include <algorithm>
#include <numeric>
#include <chrono>
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "jenga_sparge_attention.h"
#include "sparge_tool.hpp"
// ============================================================================
// Helper Functions
// ============================================================================
// Allocates a host tensor for Q/K/V: BHSD layout when i_perm, else BSHD.
template <typename T>
ck_tile::HostTensor<T> make_qkv_tensor(ck_tile::index_t batch,
                                       ck_tile::index_t nhead,
                                       ck_tile::index_t seqlen,
                                       ck_tile::index_t hdim,
                                       bool i_perm)
{
    return i_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen, hdim})
                  : ck_tile::HostTensor<T>({batch, seqlen, nhead, hdim});
}
// Copies a 4-D host tensor into canonical BHSD order; is_bhsd indicates
// whether the source is already BHSD (plain copy) or BSHD (axes swapped).
template <typename T>
ck_tile::HostTensor<T> to_bhsd(const ck_tile::HostTensor<T>& tensor, bool is_bhsd)
{
    const auto lens = tensor.get_lengths();
    const ck_tile::index_t nb = lens[0];
    const ck_tile::index_t nh = is_bhsd ? lens[1] : lens[2];
    const ck_tile::index_t ns = is_bhsd ? lens[2] : lens[1];
    const ck_tile::index_t nd = lens[3];
    ck_tile::HostTensor<T> out({nb, nh, ns, nd});
    for(ck_tile::index_t ib = 0; ib < nb; ++ib)
        for(ck_tile::index_t ih = 0; ih < nh; ++ih)
            for(ck_tile::index_t is = 0; is < ns; ++is)
                for(ck_tile::index_t id = 0; id < nd; ++id)
                    out(ib, ih, is, id) =
                        is_bhsd ? tensor(ib, ih, is, id) : tensor(ib, is, ih, id);
    return out;
}
template <typename T>
auto get_error_tolerance()
{
double rtol = 1e-2;
double atol = 4e-2;
if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
{
atol = 2e-1;
rtol = 2e-1;
}
return ck_tile::make_tuple(rtol, atol);
}
// Widens a device element type to float for CPU-side comparison.
template <typename T>
float to_float_for_compare(T value)
{
    return static_cast<float>(value);
}
// bf16 specialization: without custom data types, bf16 is a raw 16-bit
// pattern, so convert explicitly via the raw-bit helper.
template <>
float to_float_for_compare<ck_tile::bf16_t>(ck_tile::bf16_t value)
{
#if CK_TILE_USE_CUSTOM_DATA_TYPE
    return static_cast<float>(value);
#else
    return ck_tile::bf16_to_float_raw(ck_tile::bit_cast<ck_tile::bf16_raw_t>(value));
#endif
}
// ============================================================================
// Command line argument parser
// ============================================================================
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "0:no validation, 1:cpu validation")
.insert("b", "1", "batch size")
.insert("h", "4", "num of head for q")
.insert("h_k", "-1", "num of head for k/v, -1 means equal to h")
.insert("s", "4096", "seqlen_q")
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
.insert("d", "128", "head dim for q, k")
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
.insert("prec", "fp16", "data type: fp16/bf16")
.insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d")
.insert("operm", "1", "permute output")
.insert("seed", "42", "random seed")
.insert("warmup", "5", "warmup iterations")
.insert("repeat", "20", "benchmark iterations")
.insert("kname", "0", "print kernel name")
// Sparge-specific
.insert("blkq", "64", "Sparge BLKQ")
.insert("blkk", "128", "Sparge BLKK")
.insert("simthreshd1", "0.6", "Sparge sim threshold")
.insert("cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)")
.insert("topk", "-1.0", "Sparge topk ratio in (0,1]; if > 0, overrides cdfthreshd");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// ============================================================================
// Main Test Function
// ============================================================================
template <typename T>
bool run_test(const ck_tile::ArgParser& arg_parser)
{
int do_validation = arg_parser.get_int("v");
ck_tile::index_t batch = arg_parser.get_int("b");
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
bool i_perm = arg_parser.get_bool("iperm");
bool o_perm = arg_parser.get_bool("operm");
uint32_t seed = arg_parser.get_uint32("seed");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int kname = arg_parser.get_int("kname");
// Sparge params
ck_tile::index_t blkq = arg_parser.get_int("blkq");
ck_tile::index_t blkk = arg_parser.get_int("blkk");
float simthreshd1 = arg_parser.get_float("simthreshd1");
float cdfthreshd = arg_parser.get_float("cdfthreshd");
float topk = arg_parser.get_float("topk");
if(nhead_k < 0)
nhead_k = nhead;
if(seqlen_k < 0)
seqlen_k = seqlen_q;
if(hdim_v < 0)
hdim_v = hdim_q;
if(blkq != 64 || blkk != 128 || hdim_q != 128 || hdim_v != 128)
{
std::cout << "\n>>> TEST SKIPPED <<<" << std::endl;
std::cout << "Sparge Jenga kernel instances are generated for BLKQ=64, BLKK=128, "
"hdim_q=128, hdim_v=128 only."
<< std::endl;
std::cout << "TEST SKIPPED" << std::endl;
return true;
}
ck_tile::index_t BLKQ = blkq;
ck_tile::index_t BLKK = blkk;
ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ;
ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK;
std::cout << "============================================================" << std::endl;
std::cout << "[Sparge -> Jenga Sparse Attention Demo]" << std::endl;
std::cout << "============================================================" << std::endl;
std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k
<< std::endl;
std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl;
std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl;
std::cout << " BLKQ=" << BLKQ << ", BLKK=" << BLKK << std::endl;
std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks
<< std::endl;
std::cout << " Sparge(simthreshd1=" << simthreshd1 << ", cdfthreshd=" << cdfthreshd
<< ", topk=" << topk << ")" << std::endl;
std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl;
// Create host tensors
ck_tile::HostTensor<T> q_host = make_qkv_tensor<T>(batch, nhead, seqlen_q, hdim_q, i_perm);
ck_tile::HostTensor<T> k_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
ck_tile::HostTensor<T> v_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
ck_tile::HostTensor<T> output_host =
o_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen_q, hdim_v})
: ck_tile::HostTensor<T>({batch, seqlen_q, nhead, hdim_v});
ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});
std::cout << "\nInitializing tensors..." << std::endl;
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed}(q_host);
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 1}(k_host);
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 2}(v_host);
// Build block map using Sparge tool
std::cout << "Building Sparge block map..." << std::endl;
sparge::SpargeParams p;
p.BLKQ = static_cast<int>(BLKQ);
p.BLKK = static_cast<int>(BLKK);
p.simthreshd1 = simthreshd1;
p.cdfthreshd = cdfthreshd;
p.topk = topk;
p.i_perm = i_perm;
ck_tile::HostTensor<uint8_t> block_relation_onehot =
sparge::build_block_map_meansim(q_host, k_host, p);
// Print actual sparsity
std::size_t total_blocks = 0;
std::size_t active_blocks = 0;
for(ck_tile::index_t b = 0; b < batch; ++b)
{
for(ck_tile::index_t h = 0; h < nhead; ++h)
{
for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
{
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
{
total_blocks++;
if(block_relation_onehot(b, h, qb, kb) != 0)
active_blocks++;
}
}
}
}
float actual_sparsity =
1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/"
<< total_blocks << " blocks active)" << std::endl;
std::cout << "\n--- Running Jenga sparse attention kernel ---" << std::endl;
try
{
if(kname)
{
jenga_sparge_attention<T>(q_host,
k_host,
v_host,
block_relation_onehot,
output_host,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
i_perm,
o_perm,
seqlen_q,
seqlen_k,
1);
}
for(int i = 0; i < warmup; ++i)
{
jenga_sparge_attention<T>(q_host,
k_host,
v_host,
block_relation_onehot,
output_host,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
i_perm,
o_perm,
seqlen_q,
seqlen_k,
0);
}
[[maybe_unused]] auto sync_status1 = hipDeviceSynchronize();
auto start = std::chrono::high_resolution_clock::now();
for(int i = 0; i < repeat; ++i)
{
jenga_sparge_attention<T>(q_host,
k_host,
v_host,
block_relation_onehot,
output_host,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
i_perm,
o_perm,
seqlen_q,
seqlen_k,
0);
}
[[maybe_unused]] auto sync_status2 = hipDeviceSynchronize();
auto end = std::chrono::high_resolution_clock::now();
double avg_time_ms =
std::chrono::duration<double, std::milli>(end - start).count() / repeat;
std::cout << "\n>>>> Jenga sparse attention average time: " << avg_time_ms << " ms <<<<"
<< std::endl;
}
catch(const std::exception& e)
{
std::cerr << "Error during kernel execution: " << e.what() << std::endl;
return false;
}
bool pass = true;
if(do_validation)
{
std::cout << "\n--- Performing CPU validation ---" << std::endl;
float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q));
std::cout << "Computing reference output..." << std::endl;
auto q_ref = to_bhsd(q_host, i_perm);
auto k_ref = to_bhsd(k_host, i_perm);
auto v_ref = to_bhsd(v_host, i_perm);
ck_tile::reference_blocked_attention<T, uint8_t>(
q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale);
auto [rtol, atol] = get_error_tolerance<T>();
float max_diff = 0.0f;
float max_rel_diff = 0.0f;
std::size_t num_errors = 0;
auto output_host_bhsd = to_bhsd(output_host, o_perm);
for(std::size_t i = 0; i < output_host_bhsd.mData.size(); ++i)
{
float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]);
float ref_val = to_float_for_compare(output_ref.mData[i]);
float diff = std::abs(gpu_val - ref_val);
float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff;
max_diff = std::max(max_diff, diff);
max_rel_diff = std::max(max_rel_diff, rel_diff);
if(diff > atol && rel_diff > rtol)
{
num_errors++;
if(num_errors <= 5)
{
std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val
<< ", Ref=" << ref_val << ", Diff=" << diff << std::endl;
}
}
}
std::cout << "\nValidation results:" << std::endl;
std::cout << " Max absolute difference: " << max_diff << std::endl;
std::cout << " Max relative difference: " << max_rel_diff << std::endl;
std::cout << " Number of mismatches: " << num_errors << " / "
<< output_host_bhsd.mData.size() << std::endl;
if(num_errors == 0)
{
std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl;
}
else
{
std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl;
pass = false;
}
}
std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl;
return pass;
}
// ============================================================================
// Main
// ============================================================================
int main(int argc, char* argv[])
{
    // Parse command-line options; bail out early on malformed input.
    auto [parsed_ok, arg_parser] = create_args(argc, argv);
    if(!parsed_ok)
    {
        std::cerr << "Failed to parse arguments" << std::endl;
        return -1;
    }

    // Dispatch to the typed test driver based on the requested precision.
    const std::string prec = arg_parser.get_str("prec");
    if(prec == "fp16")
    {
        return run_test<ck_tile::half_t>(arg_parser) ? 0 : -1;
    }
    if(prec == "bf16")
    {
        return run_test<ck_tile::bf16_t>(arg_parser) ? 0 : -1;
    }

    std::cerr << "Unsupported precision: " << prec << std::endl;
    return -1;
}

View File

@@ -1,597 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Demo: Sparge block-map -> (delta LUT) -> VSA sparse attention (all-in-device)
#include <iostream>
#include <cmath>
#include <string>
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "sparge_blockmap_trek.hpp"
#include "fmha_fwd_trek.hpp"
#include "sparge_tool.hpp"
// ============================================================================
// Helper Functions
// ============================================================================
// Allocate an (uninitialized) Q/K/V host tensor in the requested layout:
// i_perm=true  -> (batch, nhead, seqlen, hdim)  "BHSD"
// i_perm=false -> (batch, seqlen, nhead, hdim)  "BSHD"
template <typename T>
ck_tile::HostTensor<T> make_qkv_tensor(ck_tile::index_t batch,
                                       ck_tile::index_t nhead,
                                       ck_tile::index_t seqlen,
                                       ck_tile::index_t hdim,
                                       bool i_perm)
{
    return i_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen, hdim})
                  : ck_tile::HostTensor<T>({batch, seqlen, nhead, hdim});
}
// Copy a 4-D tensor into canonical BHSD layout. When is_bhsd is true the
// input is already (batch, nhead, seqlen, hdim) and is copied verbatim;
// otherwise the input is (batch, seqlen, nhead, hdim) and the two middle
// axes are swapped element by element.
template <typename T>
ck_tile::HostTensor<T> to_bhsd(const ck_tile::HostTensor<T>& tensor, bool is_bhsd)
{
    const auto dims            = tensor.get_lengths();
    const ck_tile::index_t B   = dims[0];
    const ck_tile::index_t H   = is_bhsd ? dims[1] : dims[2];
    const ck_tile::index_t S   = is_bhsd ? dims[2] : dims[1];
    const ck_tile::index_t D   = dims[3];

    ck_tile::HostTensor<T> result({B, H, S, D});
    for(ck_tile::index_t ib = 0; ib < B; ++ib)
    {
        for(ck_tile::index_t ih = 0; ih < H; ++ih)
        {
            for(ck_tile::index_t is = 0; is < S; ++is)
            {
                for(ck_tile::index_t id = 0; id < D; ++id)
                {
                    result(ib, ih, is, id) =
                        is_bhsd ? tensor(ib, ih, is, id) : tensor(ib, is, ih, id);
                }
            }
        }
    }
    return result;
}
template <typename T>
auto get_error_tolerance()
{
double rtol = 1e-2;
double atol = 4e-2;
if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
{
atol = 2e-1;
rtol = 2e-1;
}
return ck_tile::make_tuple(rtol, atol);
}
// Widen an element to float so GPU output and CPU reference values can be
// compared with a common type (specialized separately for bf16).
template <typename T>
float to_float_for_compare(T value)
{
    const float widened = static_cast<float>(value);
    return widened;
}
// bf16 specialization: with the custom data type enabled, the type's own
// conversion operator performs the widening; otherwise decode the raw 16-bit
// pattern explicitly via bit_cast + bf16_to_float_raw.
template <>
float to_float_for_compare<ck_tile::bf16_t>(ck_tile::bf16_t value)
{
#if CK_TILE_USE_CUSTOM_DATA_TYPE
    return static_cast<float>(value);
#else
    return ck_tile::bf16_to_float_raw(ck_tile::bit_cast<ck_tile::bf16_raw_t>(value));
#endif
}
// ============================================================================
// Command line argument parser
// ============================================================================
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "0:no validation, 1:cpu validation")
.insert("b", "1", "batch size")
.insert("h", "4", "num of head for q")
.insert("h_k", "-1", "num of head for k/v, -1 means equal to h")
.insert("s", "4096", "seqlen_q")
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
.insert("d", "128", "head dim for q, k")
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
.insert("prec", "fp16", "data type: fp16/bf16")
.insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d")
.insert("operm", "1", "permute output")
.insert("seed", "42", "random seed")
.insert("warmup", "5", "warmup iterations")
.insert("repeat", "20", "benchmark iterations")
.insert("kname", "0", "print kernel name")
// Sparge-specific
.insert("blkq", "64", "Sparge BLKQ")
.insert("blkk", "128", "Sparge BLKK")
.insert("simthreshd1", "0.6", "Sparge sim threshold")
.insert("cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)")
.insert("topk", "-1.0", "Sparge topk ratio in (0,1]; if > 0, overrides cdfthreshd");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// ============================================================================
// Main Test Function
// ============================================================================
// End-to-end Sparge -> VSA sparse-attention test for element type T (fp16 or
// bf16). Reads sizes/thresholds from arg_parser, builds the block map and VSA
// LUT entirely on the GPU, runs the VSA sparse-attention kernel, then (when
// -v=1) validates the block map, the LUT, and the attention output against
// CPU golden references. Returns true on pass or skip, false on failure.
template <typename T>
bool run_test(const ck_tile::ArgParser& arg_parser)
{
    int do_validation        = arg_parser.get_int("v");
    ck_tile::index_t batch   = arg_parser.get_int("b");
    ck_tile::index_t nhead   = arg_parser.get_int("h");
    ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
    ck_tile::index_t seqlen_q = arg_parser.get_int("s");
    ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
    ck_tile::index_t hdim_q  = arg_parser.get_int("d");
    ck_tile::index_t hdim_v  = arg_parser.get_int("d_v");
    bool i_perm              = arg_parser.get_bool("iperm");
    bool o_perm              = arg_parser.get_bool("operm");
    uint32_t seed            = arg_parser.get_uint32("seed");
    int warmup               = arg_parser.get_int("warmup");
    int repeat               = arg_parser.get_int("repeat");
    int kname                = arg_parser.get_int("kname");
    // Sparge params
    ck_tile::index_t blkq = arg_parser.get_int("blkq");
    ck_tile::index_t blkk = arg_parser.get_int("blkk");
    float simthreshd1     = arg_parser.get_float("simthreshd1");
    float cdfthreshd      = arg_parser.get_float("cdfthreshd");
    float topk            = arg_parser.get_float("topk");
    // -1 sentinels mean "same as the q-side value".
    if(nhead_k < 0)
        nhead_k = nhead;
    if(seqlen_k < 0)
        seqlen_k = seqlen_q;
    if(hdim_v < 0)
        hdim_v = hdim_q;
    // Only one tile/head-dim configuration has generated kernel instances;
    // anything else is a skip (reported as success), not a failure.
    if(blkq != 64 || blkk != 128 || hdim_q != 128 || hdim_v != 128)
    {
        std::cout << "\n>>> TEST SKIPPED <<<" << std::endl;
        std::cout << "Sparge VSA kernel instances are generated for BLKQ=64, BLKK=128, "
                     "hdim_q=128, hdim_v=128 only."
                  << std::endl;
        std::cout << "TEST SKIPPED" << std::endl;
        return true;
    }
    ck_tile::index_t BLKQ = blkq;
    ck_tile::index_t BLKK = blkk;
    // Ceil-divide sequence lengths into Sparge tiles.
    ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ;
    ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK;
    std::cout << "============================================================" << std::endl;
    std::cout << "[Sparge -> VSA Sparse Attention Demo]" << std::endl;
    std::cout << "============================================================" << std::endl;
    std::cout << "  Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k
              << std::endl;
    std::cout << "  seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl;
    std::cout << "  hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl;
    std::cout << "  BLKQ=" << BLKQ << ", BLKK=" << BLKK << std::endl;
    std::cout << "  num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks
              << std::endl;
    std::cout << "  Sparge(simthreshd1=" << simthreshd1 << ", cdfthreshd=" << cdfthreshd
              << ", topk=" << topk << ")" << std::endl;
    std::cout << "  i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl;
    // Create host tensors and fill with random data
    ck_tile::HostTensor<T> q_host = make_qkv_tensor<T>(batch, nhead, seqlen_q, hdim_q, i_perm);
    ck_tile::HostTensor<T> k_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
    ck_tile::HostTensor<T> v_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
    ck_tile::HostTensor<T> output_host =
        o_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen_q, hdim_v})
               : ck_tile::HostTensor<T>({batch, seqlen_q, nhead, hdim_v});
    std::cout << "\nInitializing tensors..." << std::endl;
    // Distinct seeds so Q/K/V are independently distributed.
    ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed}(q_host);
    ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 1}(k_host);
    ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 2}(v_host);
    // ==================================================================
    // Allocate device memory once, HtoD once
    // ==================================================================
    ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem o_buf(output_host.get_element_space_size_in_bytes());
    q_buf.ToDevice(q_host.data());
    k_buf.ToDevice(k_host.data());
    v_buf.ToDevice(v_host.data());
    // Per-(batch, head) block map (uint8 one-hot), delta LUT (int32), and
    // per-q-block count of active k-blocks (int32).
    const std::size_t bmap_bytes =
        static_cast<std::size_t>(batch) * nhead * num_q_blocks * num_k_blocks * sizeof(uint8_t);
    const std::size_t lut_bytes =
        static_cast<std::size_t>(batch) * nhead * num_q_blocks * num_k_blocks * sizeof(int32_t);
    const std::size_t valid_bytes =
        static_cast<std::size_t>(batch) * nhead * num_q_blocks * sizeof(int32_t);
    ck_tile::DeviceMem bmap_buf(bmap_bytes);
    ck_tile::DeviceMem lut_buf(lut_bytes);
    ck_tile::DeviceMem valid_buf(valid_bytes);
    bmap_buf.SetZero();
    lut_buf.SetZero();
    valid_buf.SetZero();
    // ==================================================================
    // Common stride calculations
    // ==================================================================
    // GQA/MQA requires the q-head count to be a multiple of the kv-head count.
    assert(nhead % nhead_k == 0);
    const float scale_s = 1.0f / std::sqrt(static_cast<float>(hdim_q));
    // Row/head/batch strides in elements, for both BHSD (i_perm/o_perm=1)
    // and BSHD (=0) layouts.
    const ck_tile::index_t stride_q       = i_perm ? hdim_q : nhead * hdim_q;
    const ck_tile::index_t stride_k       = i_perm ? hdim_q : nhead_k * hdim_q;
    const ck_tile::index_t stride_v       = i_perm ? hdim_v : nhead_k * hdim_v;
    const ck_tile::index_t stride_o       = o_perm ? hdim_v : nhead * hdim_v;
    const ck_tile::index_t nhead_stride_q = i_perm ? seqlen_q * hdim_q : hdim_q;
    const ck_tile::index_t nhead_stride_k = i_perm ? seqlen_k * hdim_q : hdim_q;
    const ck_tile::index_t nhead_stride_v = i_perm ? seqlen_k * hdim_v : hdim_v;
    const ck_tile::index_t nhead_stride_o = o_perm ? seqlen_q * hdim_v : hdim_v;
    const ck_tile::index_t batch_stride_q = nhead * seqlen_q * hdim_q;
    const ck_tile::index_t batch_stride_k = nhead_k * seqlen_k * hdim_q;
    const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * seqlen_k;
    const ck_tile::index_t batch_stride_o = nhead * seqlen_q * hdim_v;
    std::string data_type = "fp16";
    if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
        data_type = "bf16";
    // "0" selects the no-mask case; left/right window sizes come from decode.
    std::string msk_str = "0";
    mask_info mask      = mask_info::decode(msk_str, seqlen_q, seqlen_k);
    // ==================================================================
    // GPU: Build block map + VSA LUT (always run, device-only)
    // ==================================================================
    std::cout << "Building Sparge block map + VSA LUT (GPU)..." << std::endl;
    {
        sparge_blockmap_args args;
        args.q_ptr               = q_buf.GetDeviceBuffer();
        args.k_ptr               = k_buf.GetDeviceBuffer();
        args.batch               = batch;
        args.seqlen_q            = seqlen_q;
        args.seqlen_k            = seqlen_k;
        args.hdim_q              = hdim_q;
        args.nhead_q             = nhead;
        args.nhead_k             = nhead_k;
        args.stride_q            = stride_q;
        args.stride_k            = stride_k;
        args.nhead_stride_q      = nhead_stride_q;
        args.nhead_stride_k      = nhead_stride_k;
        args.batch_stride_q      = batch_stride_q;
        args.batch_stride_k      = batch_stride_k;
        args.simthreshd1         = simthreshd1;
        args.cdfthreshd          = cdfthreshd;
        args.topk                = topk;
        args.scale               = scale_s;
        args.block_map_ptr       = bmap_buf.GetDeviceBuffer();
        args.lut_ptr             = lut_buf.GetDeviceBuffer();
        args.valid_block_num_ptr = valid_buf.GetDeviceBuffer();
        sparge_blockmap_traits traits;
        traits.data_type = data_type;
        traits.hdim_q    = hdim_q;
        sparge_blockmap_fwd(traits, args, ck_tile::stream_config{});
    }
    // ==================================================================
    // VSA sparse attention kernel (always run, LUT stays on device)
    // ==================================================================
    std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl;
    fmha_vsa_fwd_args fmha_args;
    fmha_args.q_ptr               = q_buf.GetDeviceBuffer();
    fmha_args.k_ptr               = k_buf.GetDeviceBuffer();
    fmha_args.v_ptr               = v_buf.GetDeviceBuffer();
    fmha_args.lut_ptr             = lut_buf.GetDeviceBuffer();
    fmha_args.valid_block_num_ptr = valid_buf.GetDeviceBuffer();
    fmha_args.o_ptr               = o_buf.GetDeviceBuffer();
    fmha_args.batch               = batch;
    fmha_args.seqlen_q            = seqlen_q;
    fmha_args.seqlen_k            = seqlen_k;
    fmha_args.max_seqlen_q        = seqlen_q;
    fmha_args.hdim_q              = hdim_q;
    fmha_args.hdim_v              = hdim_v;
    fmha_args.nhead_q             = nhead;
    fmha_args.nhead_k             = nhead_k;
    fmha_args.scale_s             = scale_s;
    fmha_args.stride_q            = stride_q;
    fmha_args.stride_k            = stride_k;
    fmha_args.stride_v            = stride_v;
    fmha_args.stride_o            = stride_o;
    fmha_args.nhead_stride_q      = nhead_stride_q;
    fmha_args.nhead_stride_k      = nhead_stride_k;
    fmha_args.nhead_stride_v      = nhead_stride_v;
    fmha_args.nhead_stride_o      = nhead_stride_o;
    fmha_args.batch_stride_q      = batch_stride_q;
    fmha_args.batch_stride_k      = batch_stride_k;
    fmha_args.batch_stride_v      = batch_stride_v;
    fmha_args.batch_stride_o      = batch_stride_o;
    fmha_args.window_size_left    = mask.left;
    fmha_args.window_size_right   = mask.right;
    fmha_args.mask_type           = static_cast<ck_tile::index_t>(mask.type);
    fmha_vsa_fwd_traits fmha_traits;
    fmha_traits.hdim_q        = hdim_q;
    fmha_traits.hdim_v        = hdim_v;
    fmha_traits.data_type     = data_type;
    fmha_traits.is_v_rowmajor = true;
    fmha_traits.mask_type     = mask.type;
    // NOTE(review): positional stream_config fields — only log_level is
    // labeled inline; confirm the remaining field meanings against
    // ck_tile::stream_config before reordering.
    ck_tile::stream_config stream_config{nullptr,
                                         true,
                                         /* log_level = */ kname ? 1 : 0,
                                         warmup,
                                         repeat,
                                         false};
    // Warmup/repeat timing is handled inside the launcher via stream_config.
    float avg_time_ms = sparge_vsa_fwd(fmha_traits, fmha_args, stream_config);
    std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<"
              << std::endl;
    // DtoH: attention output (always needed)
    o_buf.FromDevice(output_host.data(), output_host.get_element_space_size_in_bytes());
    // DtoH: block_map (needed for sparsity stats and validation)
    ck_tile::HostTensor<uint8_t> block_map_gpu({batch, nhead, num_q_blocks, num_k_blocks});
    bmap_buf.FromDevice(block_map_gpu.data(), bmap_bytes);
    // ==================================================================
    // Sparsity statistics (pure CPU, reads block_map HostTensor)
    // ==================================================================
    std::size_t total_blocks  = 0;
    std::size_t active_blocks = 0;
    for(ck_tile::index_t b = 0; b < batch; ++b)
    {
        for(ck_tile::index_t h = 0; h < nhead; ++h)
        {
            for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
            {
                for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
                {
                    total_blocks++;
                    if(block_map_gpu(b, h, qb, kb) != 0)
                        active_blocks++;
                }
            }
        }
    }
    float actual_sparsity =
        1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
    std::cout << "\n  Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/"
              << total_blocks << " blocks active)" << std::endl;
    // ==================================================================
    // Validation (only when -v=1)
    // ==================================================================
    bool pass = true;
    if(do_validation)
    {
        std::cout << "\n--- Performing CPU validation ---" << std::endl;
        // CPU golden: block map + VSA LUT
        std::cout << "Building Sparge block map (CPU golden)..." << std::endl;
        sparge::SpargeParams p;
        p.BLKQ        = static_cast<int>(BLKQ);
        p.BLKK        = static_cast<int>(BLKK);
        p.simthreshd1 = simthreshd1;
        p.cdfthreshd  = cdfthreshd;
        p.topk        = topk;
        p.i_perm      = i_perm;
        ck_tile::HostTensor<uint8_t> block_relation_onehot =
            sparge::build_block_map_meansim(q_host, k_host, p);
        std::cout << "Converting block map to VSA LUT (delta, CPU)..." << std::endl;
        auto vsa_lut_cpu = sparge::block_map_to_vsa_lut_delta(block_relation_onehot);
        // DtoH: LUT + valid_block_num (only for validation)
        sparge::VSALut vsa_lut_gpu{
            ck_tile::HostTensor<int32_t>({batch, nhead, num_q_blocks, num_k_blocks}),
            ck_tile::HostTensor<int32_t>({batch, nhead, num_q_blocks}),
        };
        lut_buf.FromDevice(vsa_lut_gpu.lut.data(), lut_bytes);
        valid_buf.FromDevice(vsa_lut_gpu.valid_block_num.data(), valid_bytes);
        // Validate block map
        std::cout << "\n--- Validating GPU block map vs CPU golden ---" << std::endl;
        {
            std::size_t bmap_mismatches = 0;
            for(ck_tile::index_t b = 0; b < batch; ++b)
            {
                for(ck_tile::index_t h = 0; h < nhead; ++h)
                {
                    for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
                    {
                        for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
                        {
                            if(block_map_gpu(b, h, qb, kb) != block_relation_onehot(b, h, qb, kb))
                            {
                                bmap_mismatches++;
                                // Cap spam: report only the first 10 mismatches.
                                if(bmap_mismatches <= 10)
                                {
                                    std::cout
                                        << "  block_map mismatch at [" << b << "," << h << "," << qb
                                        << "," << kb << "]: GPU="
                                        << static_cast<int>(block_map_gpu(b, h, qb, kb)) << " CPU="
                                        << static_cast<int>(block_relation_onehot(b, h, qb, kb))
                                        << std::endl;
                                }
                            }
                        }
                    }
                }
            }
            std::cout << "  Block map mismatches: " << bmap_mismatches << " / "
                      << (batch * nhead * num_q_blocks * num_k_blocks) << std::endl;
            if(bmap_mismatches > 0)
            {
                std::cout << ">>> GPU BLOCK MAP VALIDATION FAILED <<<" << std::endl;
                pass = false;
            }
            else
            {
                std::cout << ">>> GPU BLOCK MAP VALIDATION PASSED <<<" << std::endl;
            }
        }
        // Validate VSA LUT
        std::cout << "\n--- Validating GPU VSA LUT vs CPU golden ---" << std::endl;
        {
            std::size_t lut_mismatches   = 0;
            std::size_t valid_mismatches = 0;
            for(ck_tile::index_t b = 0; b < batch; ++b)
            {
                for(ck_tile::index_t h = 0; h < nhead; ++h)
                {
                    for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
                    {
                        if(vsa_lut_gpu.valid_block_num(b, h, qb) !=
                           vsa_lut_cpu.valid_block_num(b, h, qb))
                        {
                            valid_mismatches++;
                            if(valid_mismatches <= 5)
                            {
                                std::cout << "  valid_block_num mismatch at [" << b << "," << h
                                          << "," << qb
                                          << "]: GPU=" << vsa_lut_gpu.valid_block_num(b, h, qb)
                                          << " CPU=" << vsa_lut_cpu.valid_block_num(b, h, qb)
                                          << std::endl;
                            }
                        }
                        for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
                        {
                            if(vsa_lut_gpu.lut(b, h, qb, kb) != vsa_lut_cpu.lut(b, h, qb, kb))
                            {
                                lut_mismatches++;
                                if(lut_mismatches <= 10)
                                {
                                    std::cout
                                        << "  LUT mismatch at [" << b << "," << h << "," << qb
                                        << "," << kb << "]: GPU=" << vsa_lut_gpu.lut(b, h, qb, kb)
                                        << " CPU=" << vsa_lut_cpu.lut(b, h, qb, kb) << std::endl;
                                }
                            }
                        }
                    }
                }
            }
            std::cout << "  LUT mismatches: " << lut_mismatches << std::endl;
            std::cout << "  valid_block_num mismatches: " << valid_mismatches << std::endl;
            if(lut_mismatches == 0 && valid_mismatches == 0)
            {
                std::cout << ">>> GPU VSA LUT VALIDATION PASSED <<<" << std::endl;
            }
            else
            {
                std::cout << ">>> GPU VSA LUT VALIDATION FAILED <<<" << std::endl;
                pass = false;
            }
        }
        // Validate attention output
        float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q));
        std::cout << "\nComputing reference attention output..." << std::endl;
        // Reference attention runs in BHSD layout against the CPU block map.
        auto q_ref = to_bhsd(q_host, i_perm);
        auto k_ref = to_bhsd(k_host, i_perm);
        auto v_ref = to_bhsd(v_host, i_perm);
        ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});
        ck_tile::reference_blocked_attention<T, uint8_t>(
            q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale);
        auto [rtol, atol] = get_error_tolerance<T>();
        float max_diff     = 0.0f;
        float max_rel_diff = 0.0f;
        std::size_t num_errors = 0;
        auto output_host_bhsd  = to_bhsd(output_host, o_perm);
        for(std::size_t i = 0; i < output_host_bhsd.mData.size(); ++i)
        {
            float gpu_val  = to_float_for_compare(output_host_bhsd.mData[i]);
            float ref_val  = to_float_for_compare(output_ref.mData[i]);
            float diff     = std::abs(gpu_val - ref_val);
            float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff;
            max_diff       = std::max(max_diff, diff);
            max_rel_diff   = std::max(max_rel_diff, rel_diff);
            // Count as an error only when BOTH tolerances are exceeded.
            if(diff > atol && rel_diff > rtol)
            {
                num_errors++;
                if(num_errors <= 5)
                {
                    std::cout << "  Mismatch at index " << i << ": GPU=" << gpu_val
                              << ", Ref=" << ref_val << ", Diff=" << diff << std::endl;
                }
            }
        }
        std::cout << "\nAttention validation results:" << std::endl;
        std::cout << "  Max absolute difference: " << max_diff << std::endl;
        std::cout << "  Max relative difference: " << max_rel_diff << std::endl;
        std::cout << "  Number of mismatches: " << num_errors << " / "
                  << output_host_bhsd.mData.size() << std::endl;
        if(num_errors == 0)
        {
            std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl;
        }
        else
        {
            std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl;
            pass = false;
        }
    }
    std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl;
    return pass;
}
// ============================================================================
// Main
// ============================================================================
int main(int argc, char* argv[])
{
    // Parse CLI arguments up front; any parse error aborts the run.
    auto [parsed_ok, arg_parser] = create_args(argc, argv);
    if(!parsed_ok)
    {
        std::cerr << "Failed to parse arguments" << std::endl;
        return -1;
    }

    // Select the element type for the test from the "prec" option.
    const std::string prec = arg_parser.get_str("prec");

    bool test_result = false;
    if(prec == "fp16")
    {
        test_result = run_test<ck_tile::half_t>(arg_parser);
    }
    else if(prec == "bf16")
    {
        test_result = run_test<ck_tile::bf16_t>(arg_parser);
    }
    else
    {
        std::cerr << "Unsupported precision: " << prec << std::endl;
        return -1;
    }

    return test_result ? 0 : -1;
}