Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-14 02:02:46 +00:00)

Commit: refactor to combine two kernels
@@ -88,68 +88,6 @@ target_compile_options(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE
-Wno-float-equal
)

# ============================================================================
# Sparge Jenga (64x128 tile)
# ============================================================================
set(SPARGE_JENGA_CODE_GEN_ARGS
${CMAKE_CURRENT_LIST_DIR}/generate.py
--api sparge_fwd_jenga
--receipt 600
)

execute_process(
COMMAND ${Python3_EXECUTABLE} ${SPARGE_JENGA_CODE_GEN_ARGS}
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sparge_jenga_blob_list.txt
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message(FATAL_ERROR "Failed to generate Sparge Jenga kernel list")
endif()

file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sparge_jenga_blob_list.txt SPARGE_JENGA_GEN_BLOBS)

add_custom_command(
OUTPUT ${SPARGE_JENGA_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${SPARGE_JENGA_CODE_GEN_ARGS}
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${CODE_GEN_SCRIPTS}
COMMENT "Generate CK Tile Sparge Jenga kernels"
)

message(STATUS "Sparge Jenga kernel files to be generated: ${SPARGE_JENGA_GEN_BLOBS}")

set(SPARGE_JENGA_INSTANCES "tile_sparge_jenga_instances")

add_library(${SPARGE_JENGA_INSTANCES} OBJECT EXCLUDE_FROM_ALL
${SPARGE_JENGA_GEN_BLOBS}
${CMAKE_CURRENT_LIST_DIR}/jenga_sparge_attention.cpp
)
target_include_directories(${SPARGE_JENGA_INSTANCES} PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn
)
set_source_files_properties(${SPARGE_JENGA_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/jenga_sparge_attention.cpp PROPERTIES LANGUAGE HIP)
set_property(TARGET ${SPARGE_JENGA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})

target_compile_options(${SPARGE_JENGA_INSTANCES} PRIVATE
-DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
-DCK_TILE_FMHA_FWD_FAST_EXP2
-Wno-undefined-func-template
-Wno-float-equal
)

# Sparge + Jenga Example executable
set(EXAMPLE_SPARGE_JENGA_SPARSE_ATTN "tile_example_sparge_jenga_sparse_attn")
message(DEBUG "adding example ${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN}")
add_executable(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_jenga_sparse_attn.cpp)
target_link_libraries(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} ${SPARGE_JENGA_INSTANCES})
target_include_directories(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_compile_options(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE
-Wno-undefined-func-template
-Wno-float-equal
)

# ============================================================================
# VSA Sparse Attention
# ============================================================================
@@ -215,55 +153,6 @@ target_compile_options(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE
-Wno-float-equal
)

# ============================================================================
# Sparge VSA (64x128 tile)
# ============================================================================
set(SPARGE_VSA_CODE_GEN_ARGS
${CMAKE_CURRENT_LIST_DIR}/generate.py
--api sparge_fwd_vsa
--receipt 600
)

execute_process(
COMMAND ${Python3_EXECUTABLE} ${SPARGE_VSA_CODE_GEN_ARGS}
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sparge_vsa_blob_list.txt
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message(FATAL_ERROR "Failed to generate Sparge VSA kernel list")
endif()

file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sparge_vsa_blob_list.txt SPARGE_VSA_GEN_BLOBS)

add_custom_command(
OUTPUT ${SPARGE_VSA_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${SPARGE_VSA_CODE_GEN_ARGS}
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${CODE_GEN_SCRIPTS}
COMMENT "Generate CK Tile Sparge VSA kernels"
)

message(STATUS "Sparge VSA kernel files to be generated: ${SPARGE_VSA_GEN_BLOBS}")

set(SPARGE_VSA_INSTANCES "tile_sparge_vsa_instances")

add_library(${SPARGE_VSA_INSTANCES} OBJECT EXCLUDE_FROM_ALL
${SPARGE_VSA_GEN_BLOBS}
)
target_include_directories(${SPARGE_VSA_INSTANCES} PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn
)
set_source_files_properties(${SPARGE_VSA_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
set_property(TARGET ${SPARGE_VSA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})

target_compile_options(${SPARGE_VSA_INSTANCES} PRIVATE
-DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
-DCK_TILE_FMHA_FWD_FAST_EXP2
-Wno-undefined-func-template
-Wno-float-equal
)

# ============================================================================
# Sparge BlockMap GPU Kernel (hand-written instantiation, no codegen)
# ============================================================================
@@ -289,16 +178,20 @@ target_compile_options(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE
-Wno-float-equal
)

# Sparge + VSA Example executable (now links blockmap kernel too)
set(EXAMPLE_SPARGE_VSA_SPARSE_ATTN "tile_example_sparge_vsa_sparse_attn")
message(DEBUG "adding example ${EXAMPLE_SPARGE_VSA_SPARSE_ATTN}")
add_executable(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_vsa_sparse_attn.cpp)
target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN}
${SPARGE_VSA_INSTANCES}
# ----------------------------------------------------------------------------
# Build unified Sparge test: combines blockmap, Jenga, and VSA attention
# for end-to-end evaluation and timing in a single executable.
# ----------------------------------------------------------------------------
set(EXAMPLE_SPARGE "tile_example_sparge")
message(DEBUG "adding example ${EXAMPLE_SPARGE}")
add_executable(${EXAMPLE_SPARGE} EXCLUDE_FROM_ALL test_sparge.cpp)
target_link_libraries(${EXAMPLE_SPARGE}
${SPARSE_ATTN_JENGA_INSTANCES}
${SPARSE_ATTN_VSA_INSTANCES}
${SPARGE_BLOCKMAP_INSTANCES}
)
target_include_directories(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_compile_options(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE
target_include_directories(${EXAMPLE_SPARGE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_compile_options(${EXAMPLE_SPARGE} PRIVATE
-Wno-undefined-func-template
-Wno-float-equal
)

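The CMake above relies on a two-phase contract with generate.py: at configure time --list_blobs writes the path of every file that will be generated (read back via file(STRINGS ...)), and at build time --output_dir must produce exactly that same set. A minimal Python sketch of the contract, using hypothetical helper names rather than the real generator internals:

# Minimal sketch of the blob-list contract assumed by the CMake above.
# Helper names are hypothetical; the real logic lives in generate.py
# (get_fwd_blobs / list_blobs / write_blobs).
from pathlib import Path


def kernel_filenames(api: str) -> list:
    # Must be deterministic: the configure-time listing and the build-time
    # generation have to agree on the exact same set of files.
    return [f"{api}_example_kernel.cpp", f"{api}_fwd_api.cpp"]


def list_blobs(blob_list: Path, api: str) -> None:
    # Configure time (--list_blobs): CMake reads this file with
    # file(STRINGS ...) to learn the OUTPUT list of add_custom_command.
    gen_dir = blob_list.parent
    lines = [(gen_dir / name).as_posix() for name in kernel_filenames(api)]
    blob_list.write_text("\n".join(lines) + "\n")


def write_blobs(output_dir: Path, api: str) -> None:
    # Build time (--output_dir): must emit exactly the files promised above,
    # otherwise the custom command's output list goes stale.
    for name in kernel_filenames(api):
        (output_dir / name).write_text("// generated kernel instance\n")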
@@ -141,6 +141,17 @@ float fmha_jenga_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_jenga
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
template<>
|
||||
void fmha_jenga_fwd_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_jenga_fwd_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{{s.stream_id_}});
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_FILENAME = "fmha_jenga_fwd_api.cpp"
|
||||
@@ -219,6 +230,45 @@ FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayo
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_ONESHOT_API_FILENAME = "fmha_jenga_fwd_oneshot_api.cpp"
|
||||
FMHA_FWD_ONESHOT_API = """
|
||||
#include "fmha_fwd_trek.hpp"
|
||||
#include <iostream>
|
||||
|
||||
void fmha_jenga_fwd_oneshot(fmha_jenga_fwd_traits t, fmha_jenga_fwd_args a, const ck_tile::stream_config& s){{
|
||||
|
||||
const bool has_load_tr = ck_tile::is_load_tr_supported();
|
||||
|
||||
{F_dispatch}
|
||||
std::cerr << "fmha_jenga_fwd_oneshot: no matching dispatch (dtype=" << t.data_type
|
||||
<< " hdim_q=" << t.hdim_q << " hdim_v=" << t.hdim_v
|
||||
<< " seqlen_q=" << a.seqlen_q << " seqlen_k=" << a.seqlen_k
|
||||
<< " mask=" << static_cast<int>(t.mask_type) << ")" << std::endl;
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_ONESHOT_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{
|
||||
{F_dtype_case}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_ONESHOT_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
FMHA_FWD_ONESHOT_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_ONESHOT_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) &&
|
||||
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
|
||||
fmha_jenga_fwd_oneshot_<trait_>(s, a);
|
||||
return;
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class CppConstraint:
|
||||
@@ -274,10 +324,7 @@ class FmhaFwdApiTrait:
|
||||
|
||||
@property
|
||||
def seqtune(self) -> str:
|
||||
if self.bm0 == 128:
|
||||
return "true/*fall back to largest tile*/" # group mode only generate spad/skpad == true
|
||||
else:
|
||||
return f"a.seqlen_q <= {self.bm0}"
|
||||
return "true"
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
@@ -447,6 +494,67 @@ class FmhaFwdApiPool:
|
||||
per_tr_load += " (void)t ; (void)s ; (void)a;"
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load)
|
||||
|
||||
@property
|
||||
def oneshot_api(self) -> str:
|
||||
tr_load_cond_map = {"t": "has_load_tr", "f": "true"}
|
||||
|
||||
per_tr_load = str()
|
||||
for tr_load in ["t", "f"]:
|
||||
per_dtypes = str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case = str()
|
||||
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
|
||||
traits = [
|
||||
t
|
||||
for t in self.pool[dtype][(hdim, hdim_v)]
|
||||
if tr_load == t.tr_load
|
||||
]
|
||||
inners = str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = "if" if k == 0 else "else if"
|
||||
inners = inners + FMHA_FWD_ONESHOT_API_INNER_DISPATCH.format(
|
||||
F_if=if_k,
|
||||
F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
|
||||
F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
|
||||
F_trload=BOOL_MAP[trait.tr_load],
|
||||
F_scheck=trait.scheck,
|
||||
F_seqtune=trait.seqtune,
|
||||
F_skcheck=trait.skcheck,
|
||||
F_dcheck=trait.dcheck,
|
||||
F_dvcheck=trait.dvcheck,
|
||||
F_constraint=trait.constraint,
|
||||
F_spad=BOOL_MAP[trait.spad],
|
||||
F_skpad=BOOL_MAP[trait.skpad],
|
||||
F_dpad=BOOL_MAP[trait.dpad],
|
||||
F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0,
|
||||
F_bn0=trait.bn0,
|
||||
F_bk0=trait.bk0,
|
||||
F_bn1=trait.bn1,
|
||||
F_bk1=trait.bk1,
|
||||
F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[dtype],
|
||||
)
|
||||
if_j = "if" if j == 0 else "else if"
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_ONESHOT_API_PER_HDIM_CASE.format(
|
||||
F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners
|
||||
)
|
||||
if_i = "if" if i == 0 else "else if"
|
||||
per_dtypes = per_dtypes + FMHA_FWD_ONESHOT_API_PER_DTYPE.format(
|
||||
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
|
||||
)
|
||||
per_tr_load += FMHA_FWD_ONESHOT_API_PER_TRLOAD.format(
|
||||
F_if="if",
|
||||
F_trload_cond=tr_load_cond_map[tr_load],
|
||||
F_dtype_case=per_dtypes,
|
||||
)
|
||||
if not per_tr_load:
|
||||
per_tr_load += " (void)t ; (void)s ; (void)a;"
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_ONESHOT_API.format(F_dispatch=per_tr_load)
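The nesting that oneshot_api builds is easier to see in isolation: one branch per tr_load capability, then per dtype, then per (hdim_q, hdim_v) bucket, then one if / else if arm per registered trait. A toy sketch with made-up names:

# Toy sketch of the dispatch nesting produced by oneshot_api above
# (names and checks are illustrative; real arms come from registered traits).
pool = {"fp16": {(128, 128): ["trait_a", "trait_b"]}}


def render_dispatch(pool):
    out = []
    for i, (dtype, by_hdim) in enumerate(pool.items()):
        out.append(f'{"if" if i == 0 else "else if"}(t.data_type.compare("{dtype}") == 0) {{')
        for j, ((hdim, hdim_v), traits) in enumerate(by_hdim.items()):
            out.append(f'  {"if" if j == 0 else "else if"} (t.hdim_q <= {hdim} && t.hdim_v <= {hdim_v}) {{')
            for k, trait in enumerate(traits):
                out.append(
                    f'    {"if" if k == 0 else "else if"}(/* runtime checks */) '
                    f"{{ fmha_jenga_fwd_oneshot_<{trait}>(s, a); return; }}"
                )
            out.append("  }")
        out.append("}")
    return "\n".join(out)


print(render_dispatch(pool))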
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdTileSize:
|
||||
@@ -582,6 +690,27 @@ class KernelComponentFactory:
|
||||
# FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(128, 128): [
|
||||
FmhaFwdTileSize( # fmt: skip -- 64x128 tile matching blockmap kM0=64, kN0=128
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
-1,
|
||||
),
|
||||
FmhaFwdTileSize( # fmt: skip
|
||||
16,
|
||||
32,
|
||||
@@ -780,7 +909,7 @@ def get_fwd_blobs(
|
||||
for tile, pipeline in itertools.product(
|
||||
tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
|
||||
):
|
||||
if tile.F_bm0 != 128 or tile.F_bn0 != 128:
|
||||
if tile.F_bm0 != 64 or tile.F_bn0 != 128:
|
||||
continue
|
||||
if pipeline.tag != "qr_async":
|
||||
continue
|
||||
@@ -846,6 +975,7 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
|
||||
|
||||
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
|
||||
update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api)
|
||||
update_file(autogen_dir / FMHA_FWD_ONESHOT_API_FILENAME, api_pool.oneshot_api)
|
||||
|
||||
|
||||
def write_blobs(
|
||||
@@ -865,3 +995,4 @@ def list_blobs(
|
||||
for kernel in kernels:
|
||||
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
|
||||
f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n")
|
||||
f.write((file_path.parent / GEN_DIR / FMHA_FWD_ONESHOT_API_FILENAME).as_posix() + "\n")
|
||||
|
||||
@@ -141,6 +141,17 @@ float fmha_vsa_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_vsa_fwd
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
template<>
|
||||
void fmha_vsa_fwd_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_vsa_fwd_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{{s.stream_id_}});
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_FILENAME = "fmha_vsa_fwd_api.cpp"
|
||||
@@ -219,6 +230,45 @@ FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayo
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_ONESHOT_API_FILENAME = "fmha_vsa_fwd_oneshot_api.cpp"
|
||||
FMHA_FWD_ONESHOT_API = """
|
||||
#include "fmha_fwd_trek.hpp"
|
||||
#include <iostream>
|
||||
|
||||
void fmha_vsa_fwd_oneshot(fmha_vsa_fwd_traits t, fmha_vsa_fwd_args a, const ck_tile::stream_config& s){{
|
||||
|
||||
const bool has_load_tr = ck_tile::is_load_tr_supported();
|
||||
|
||||
{F_dispatch}
|
||||
std::cerr << "fmha_vsa_fwd_oneshot: no matching dispatch (dtype=" << t.data_type
|
||||
<< " hdim_q=" << t.hdim_q << " hdim_v=" << t.hdim_v
|
||||
<< " seqlen_q=" << a.seqlen_q << " seqlen_k=" << a.seqlen_k
|
||||
<< " mask=" << static_cast<int>(t.mask_type) << ")" << std::endl;
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_ONESHOT_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{
|
||||
{F_dtype_case}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_ONESHOT_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
FMHA_FWD_ONESHOT_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_ONESHOT_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) &&
|
||||
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
|
||||
fmha_vsa_fwd_oneshot_<trait_>(s, a);
|
||||
return;
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class CppConstraint:
|
||||
@@ -274,10 +324,7 @@ class FmhaFwdApiTrait:
|
||||
|
||||
@property
|
||||
def seqtune(self) -> str:
|
||||
if self.bm0 == 128:
|
||||
return "true/*fall back to largest tile*/" # group mode only generate spad/skpad == true
|
||||
else:
|
||||
return f"a.seqlen_q <= {self.bm0}"
|
||||
return "true"
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
@@ -447,6 +494,67 @@ class FmhaFwdApiPool:
|
||||
per_tr_load += " (void)t ; (void)s ; (void)a;"
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load)
|
||||
|
||||
@property
|
||||
def oneshot_api(self) -> str:
|
||||
tr_load_cond_map = {"t": "has_load_tr", "f": "true"}
|
||||
|
||||
per_tr_load = str()
|
||||
for tr_load in ["t", "f"]:
|
||||
per_dtypes = str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case = str()
|
||||
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
|
||||
traits = [
|
||||
t
|
||||
for t in self.pool[dtype][(hdim, hdim_v)]
|
||||
if tr_load == t.tr_load
|
||||
]
|
||||
inners = str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = "if" if k == 0 else "else if"
|
||||
inners = inners + FMHA_FWD_ONESHOT_API_INNER_DISPATCH.format(
|
||||
F_if=if_k,
|
||||
F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
|
||||
F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
|
||||
F_trload=BOOL_MAP[trait.tr_load],
|
||||
F_scheck=trait.scheck,
|
||||
F_seqtune=trait.seqtune,
|
||||
F_skcheck=trait.skcheck,
|
||||
F_dcheck=trait.dcheck,
|
||||
F_dvcheck=trait.dvcheck,
|
||||
F_constraint=trait.constraint,
|
||||
F_spad=BOOL_MAP[trait.spad],
|
||||
F_skpad=BOOL_MAP[trait.skpad],
|
||||
F_dpad=BOOL_MAP[trait.dpad],
|
||||
F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0,
|
||||
F_bn0=trait.bn0,
|
||||
F_bk0=trait.bk0,
|
||||
F_bn1=trait.bn1,
|
||||
F_bk1=trait.bk1,
|
||||
F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[dtype],
|
||||
)
|
||||
if_j = "if" if j == 0 else "else if"
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_ONESHOT_API_PER_HDIM_CASE.format(
|
||||
F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners
|
||||
)
|
||||
if_i = "if" if i == 0 else "else if"
|
||||
per_dtypes = per_dtypes + FMHA_FWD_ONESHOT_API_PER_DTYPE.format(
|
||||
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
|
||||
)
|
||||
per_tr_load += FMHA_FWD_ONESHOT_API_PER_TRLOAD.format(
|
||||
F_if="if",
|
||||
F_trload_cond=tr_load_cond_map[tr_load],
|
||||
F_dtype_case=per_dtypes,
|
||||
)
|
||||
if not per_tr_load:
|
||||
per_tr_load += " (void)t ; (void)s ; (void)a;"
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_ONESHOT_API.format(F_dispatch=per_tr_load)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdTileSize:
|
||||
@@ -582,6 +690,27 @@ class KernelComponentFactory:
|
||||
# FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(128, 128): [
|
||||
FmhaFwdTileSize( # fmt: skip -- 64x128 tile matching blockmap kM0=64, kN0=128
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
-1,
|
||||
),
|
||||
FmhaFwdTileSize( # fmt: skip
|
||||
16,
|
||||
32,
|
||||
@@ -780,7 +909,7 @@ def get_fwd_blobs(
|
||||
for tile, pipeline in itertools.product(
|
||||
tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
|
||||
):
|
||||
if tile.F_bm0 != 128 or tile.F_bn0 != 128:
|
||||
if tile.F_bm0 != 64 or tile.F_bn0 != 128:
|
||||
continue
|
||||
if pipeline.tag != "qr_async_vsa":
|
||||
continue
|
||||
@@ -846,6 +975,7 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
|
||||
|
||||
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
|
||||
update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api)
|
||||
update_file(autogen_dir / FMHA_FWD_ONESHOT_API_FILENAME, api_pool.oneshot_api)
|
||||
|
||||
|
||||
def write_blobs(
|
||||
@@ -865,3 +995,4 @@ def list_blobs(
|
||||
for kernel in kernels:
|
||||
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
|
||||
f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n")
|
||||
f.write((file_path.parent / GEN_DIR / FMHA_FWD_ONESHOT_API_FILENAME).as_posix() + "\n")
|
||||
|
||||
@@ -1,799 +0,0 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass, field
|
||||
import fnmatch
|
||||
import itertools
|
||||
import os
|
||||
import os.path as path
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from codegen.cpp_symbol_map import (
|
||||
BOOL_MAP,
|
||||
FWD_DTYPE_MAP,
|
||||
LAYOUT_MAP,
|
||||
MODE_MAP,
|
||||
PIPELINE_ENUM_MAP,
|
||||
PIPELINE_MAP,
|
||||
get_mask_check_map,
|
||||
get_mask_map,
|
||||
)
|
||||
|
||||
GEN_DIR = ""
|
||||
|
||||
|
||||
def update_file(file_path, content):
|
||||
"""Update the file at file_path with the given content if it differs from the existing content.
|
||||
|
||||
It avoids unnecessary touching of the file which triggers rebuilds
|
||||
"""
|
||||
|
||||
existing_content = ""
|
||||
if path.exists(file_path):
|
||||
with open(file_path, "r") as file:
|
||||
existing_content = file.read()
|
||||
if existing_content == content:
|
||||
return
|
||||
with open(file_path, "w") as file:
|
||||
file.write(content)
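A small usage sketch of update_file defined above: writing identical content a second time leaves the file untouched, so its mtime stays put and Make/Ninja see nothing to rebuild.

# Usage sketch; assumes update_file() from this module.
import os
import tempfile
from pathlib import Path

out = Path(tempfile.mkdtemp()) / "fmha_jenga_fwd_api.cpp"
update_file(out, "// generated\n")
first_mtime = os.path.getmtime(out)
update_file(out, "// generated\n")  # same content -> early return, no write
assert os.path.getmtime(out) == first_mtime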
|
||||
|
||||
|
||||
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256}
|
||||
|
||||
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n
|
||||
// auto generated by generate.py
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "fmha_fwd_trek.hpp"
|
||||
#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp"
|
||||
#include "kernel/fmha_fwd_jenga_kernel.hpp"
|
||||
|
||||
"""
|
||||
|
||||
# NOTE: Jenga sparse attention kernel has the following restrictions enforced by static_assert:
|
||||
# - Group mode: NOT supported (batch mode only)
|
||||
# - Bias: NOT supported (NO_BIAS only)
|
||||
# - LSE output: NOT supported (false only)
|
||||
# - Dropout: NOT supported (false only)
|
||||
# - Logits soft-cap: NOT supported (false only)
|
||||
# - FP8 static quantization: NOT supported (NO_SCALE only)
|
||||
# The template below hardcodes these unsupported features accordingly.
|
||||
|
||||
FMHA_FWD_KERNEL_BODY = """
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
|
||||
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
|
||||
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
|
||||
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
|
||||
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
|
||||
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
|
||||
{F_vlayout}>;
|
||||
|
||||
// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum,
|
||||
// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
false, // has_logits_soft_cap - NOT supported
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported
|
||||
false, // store_lse - NOT supported
|
||||
false, // has_dropout - NOT supported
|
||||
false, // has_randval - NOT supported
|
||||
ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported
|
||||
{F_occupancy},
|
||||
false>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported)
|
||||
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
fmha_shape_{F_idx},
|
||||
{F_mode},
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
{F_trload},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
using fmha_pipeline_{F_idx} = {F_pipeline}<
|
||||
fmha_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_epilogue_{F_idx} =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaSparseFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType,
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdJengaKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_jenga_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_jenga_fwd_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << "{F_kernel_name}" << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_FILENAME = "sparge_jenga_fwd_api.cpp"
|
||||
FMHA_FWD_API = """
|
||||
#include <cstdio>
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace {{
|
||||
bool get_num_cus(unsigned& num_cus) {{
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device");
|
||||
return false;
|
||||
}}
|
||||
|
||||
hipDeviceProp_t props{{}};
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device properties");
|
||||
return false;
|
||||
}}
|
||||
|
||||
num_cus = props.multiProcessorCount;
|
||||
return true;
|
||||
}}
|
||||
|
||||
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
|
||||
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
|
||||
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
|
||||
|
||||
return batch * nheads * num_m_blocks * num_n_blocks;
|
||||
}}
|
||||
}} // namespace
|
||||
|
||||
float sparge_jenga_fwd(fmha_jenga_fwd_traits t, fmha_jenga_fwd_args a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
|
||||
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
|
||||
|
||||
unsigned num_cus;
|
||||
if (!get_num_cus(num_cus)) {{
|
||||
return r;
|
||||
}}
|
||||
|
||||
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
|
||||
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
|
||||
}};
|
||||
|
||||
const bool has_load_tr = ck_tile::is_load_tr_supported();
|
||||
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{
|
||||
{F_dtype_case}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) &&
|
||||
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
|
||||
return fmha_jenga_fwd_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class CppConstraint:
|
||||
bool_expr: str = None
|
||||
|
||||
def __str__(self):
|
||||
if self.bool_expr is None:
|
||||
return "true"
|
||||
else:
|
||||
return f"{self.bool_expr}"
|
||||
|
||||
def __and__(self, other):
|
||||
return CppConstraint(f"({str(self)}) && ({str(other)})")
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdApiTrait:
|
||||
pipeline_tag: str
|
||||
# sync with fmha_fwd_traits<>, to generate fallback calls
|
||||
hdim: str
|
||||
dtype: str # data type
|
||||
mode: str # value from MODE_MAP
|
||||
bm0: int # tile size along q seqlen (block size)
|
||||
bn0: int # tile size along qk seqlen
|
||||
bk0: int # tile size along qk gemm unroll
|
||||
bn1: int # tile size along v head_dim
|
||||
bk1: int # tile size along kv gemm unroll
|
||||
bk0max: int
|
||||
vlayout: str
|
||||
logits: str
|
||||
mask: str
|
||||
spad: str
|
||||
skpad: str
|
||||
dpad: str
|
||||
dvpad: str
|
||||
tr_load: str
|
||||
constraint: CppConstraint
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
|
||||
)
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.mode == "group":
|
||||
return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true
|
||||
if self.spad == "t":
|
||||
return "true" # always support
|
||||
return "true"
|
||||
|
||||
@property
|
||||
def seqtune(self) -> str:
|
||||
return "true"
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
if self.mode == "group":
|
||||
return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true
|
||||
if self.skpad == "t":
|
||||
return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0"
|
||||
return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0"
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dpad == "t":
|
||||
return f"a.hdim_q % {vec} == 0"
|
||||
assert False
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dvpad == "t":
|
||||
return f"a.hdim_v % {vec} == 0"
|
||||
assert False
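The dcheck/dvcheck conditions above require the head dimensions to be a multiple of the 128-bit vector width for the element type. A worked example:

# Worked example of the emitted padding checks.
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16}


def emitted_dcheck(dtype: str) -> str:
    vec = int((32 * 4) / DTYPE_BITS[dtype])  # elements per 128-bit access
    return f"a.hdim_q % {vec} == 0"


print(emitted_dcheck("fp16"))  # a.hdim_q % 8 == 0
print(emitted_dcheck("fp32"))  # a.hdim_q % 4 == 0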
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdPipeline:
|
||||
tag: str
|
||||
|
||||
F_vlayout: str # row/col
|
||||
F_spad: str # true/false
|
||||
F_skpad: str #
|
||||
F_dpad: str #
|
||||
F_dvpad: str #
|
||||
F_logits: str # t/f
|
||||
F_mask: str # value from MASK_MAP
|
||||
F_trload: str # true/false
|
||||
F_constraint: CppConstraint = field(default_factory=CppConstraint)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ""
|
||||
if self.F_spad == "t":
|
||||
n += "s"
|
||||
if self.F_skpad == "t":
|
||||
n += "sk"
|
||||
if self.F_dpad == "t":
|
||||
n += "d"
|
||||
if self.F_dvpad == "t":
|
||||
n += "dv"
|
||||
if n != "":
|
||||
n = "p" + n
|
||||
return n
|
||||
|
||||
pn = pad_name()
|
||||
n = f"{self.tag}_v{self.F_vlayout[0]}"
|
||||
if pn != "":
|
||||
n += f"_{pn}"
|
||||
else:
|
||||
n += "_npad"
|
||||
|
||||
if self.F_logits == "t":
|
||||
n += "_logits"
|
||||
else:
|
||||
n += "_nlogits"
|
||||
|
||||
n += "_nbias"
|
||||
|
||||
if self.F_mask[0:2] == "s_":
|
||||
if self.F_mask == "s_mask":
|
||||
n += "_mask"
|
||||
else:
|
||||
n += "_nmask"
|
||||
else:
|
||||
if self.F_mask != "no":
|
||||
n += f"_m{self.F_mask[0]}"
|
||||
else:
|
||||
n += "_nmask"
|
||||
|
||||
n += "_nskip"
|
||||
|
||||
n += "_nsquant"
|
||||
|
||||
if self.F_trload == "t":
|
||||
n += "_trload"
|
||||
else:
|
||||
n += "_ntrload"
|
||||
|
||||
return n
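For illustration, here is roughly how the flags combine into a pipeline name suffix for one hypothetical configuration (flags chosen arbitrarily; the authoritative logic is the name()/pad_name() code above):

# Hypothetical configuration run through the naming scheme.
tag, vlayout = "qr_async", "row"
spad, skpad, dpad, dvpad = "t", "f", "t", "t"
logits, mask, trload = "f", "no", "f"

pad = "".join(s for flag, s in [(spad, "s"), (skpad, "sk"), (dpad, "d"), (dvpad, "dv")] if flag == "t")
name = f"{tag}_v{vlayout[0]}" + (f"_p{pad}" if pad else "_npad")
name += "_logits" if logits == "t" else "_nlogits"
name += "_nbias"
if mask.startswith("s_"):
    name += "_mask" if mask == "s_mask" else "_nmask"
else:
    name += f"_m{mask[0]}" if mask != "no" else "_nmask"
name += "_nskip" + "_nsquant"
name += "_trload" if trload == "t" else "_ntrload"
print(name)  # qr_async_vr_psddv_nlogits_nbias_nmask_nskip_nsquant_ntrload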
|
||||
|
||||
|
||||
class FmhaFwdApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.pool = dict()
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_traits(self, trait: FmhaFwdApiTrait) -> None:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.pool.keys():
|
||||
self.pool[trait.dtype] = dict()
|
||||
hdim = trait.hdim, trait.bn1
|
||||
if hdim not in self.pool[trait.dtype].keys():
|
||||
self.pool[trait.dtype][hdim] = list()
|
||||
|
||||
self.pool[trait.dtype][hdim].append(copy.copy(trait))
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
tr_load_cond_map = {"t": "has_load_tr", "f": "true"}
|
||||
|
||||
per_tr_load = str()
|
||||
for tr_load in ["t", "f"]:
|
||||
per_dtypes = str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case = str()
|
||||
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
|
||||
traits = [
|
||||
t
|
||||
for t in self.pool[dtype][(hdim, hdim_v)]
|
||||
if tr_load == t.tr_load
|
||||
]
|
||||
inners = str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = "if" if k == 0 else "else if"
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
|
||||
F_if=if_k,
|
||||
F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
|
||||
# F_logits removed - hardcoded to false (NOT supported)
|
||||
F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
|
||||
F_trload=BOOL_MAP[trait.tr_load],
|
||||
F_scheck=trait.scheck,
|
||||
F_seqtune=trait.seqtune,
|
||||
F_skcheck=trait.skcheck,
|
||||
F_dcheck=trait.dcheck,
|
||||
F_dvcheck=trait.dvcheck,
|
||||
F_constraint=trait.constraint,
|
||||
F_spad=BOOL_MAP[trait.spad],
|
||||
F_skpad=BOOL_MAP[trait.skpad],
|
||||
F_dpad=BOOL_MAP[trait.dpad],
|
||||
F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0,
|
||||
F_bn0=trait.bn0,
|
||||
F_bk0=trait.bk0,
|
||||
F_bn1=trait.bn1,
|
||||
F_bk1=trait.bk1,
|
||||
F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[dtype],
|
||||
)
|
||||
if_j = "if" if j == 0 else "else if"
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
|
||||
F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners
|
||||
)
|
||||
if_i = "if" if i == 0 else "else if"
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
|
||||
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
|
||||
)
|
||||
per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(
|
||||
F_if="if",
|
||||
F_trload_cond=tr_load_cond_map[tr_load],
|
||||
F_dtype_case=per_dtypes,
|
||||
)
|
||||
if not per_tr_load:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_tr_load += " (void)t ; (void)s ; (void)a;"
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdTileSize:
|
||||
F_bm0: int # tile size along q seqlen (block size)
|
||||
F_bn0: int # tile size along k seqlen
|
||||
F_bk0: int # tile size along qk gemm unroll
|
||||
F_bn1: int # tile size along v head_dim
|
||||
F_bk1: int # tile size along kv gemm unroll
|
||||
F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0: int # number of warps for gemm0 along q seqlen
|
||||
F_rn0: int # number of warps for gemm0 along k seqlen
|
||||
F_rk0: int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1: int # number of warps for gemm1 along q seqlen
|
||||
F_rn1: int # number of warps for gemm1 along head dim v
|
||||
F_rk1: int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0: int # gemm0 warp size along m
|
||||
F_wn0: int # gemm0 warp size along n
|
||||
F_wk0: int # gemm0 warp size along k
|
||||
F_wm1: int # gemm1 warp size along m
|
||||
F_wn1: int # gemm1 warp size along n
|
||||
F_wk1: int # gemm1 warp size along k
|
||||
F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
F_constraint: CppConstraint = field(default_factory=CppConstraint)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
|
||||
+ f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
|
||||
+ f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
|
||||
+ ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
)
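A rough sketch of how the 64x128 gemm0 tile registered above maps onto warps, following the field comments (repeat counts are per-warp tile repetitions; illustrative only):

# Sketch of the warp decomposition implied by the FmhaFwdTileSize fields.
bm0, bn0 = 64, 128   # block tile: q rows x k columns
rm0, rn0 = 4, 1      # warps along m / n for gemm0
wm0, wn0 = 16, 16    # per-warp gemm0 tile

m_repeats = bm0 // (rm0 * wm0)  # 64 / (4 * 16) = 1 repeat along q rows
n_repeats = bn0 // (rn0 * wn0)  # 128 / (1 * 16) = 8 repeats along k columns
print(m_repeats, n_repeats)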
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdKernel:
|
||||
F_idx: int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim: int # hdim
|
||||
F_dtype: str # data type
|
||||
F_mode: str # value from MODE_MAP
|
||||
F_tile: FmhaFwdTileSize
|
||||
F_pipeline: FmhaFwdPipeline
|
||||
mask_impl: str
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
# kernel_body removed - unused
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
|
||||
F_idx=self.F_idx,
|
||||
F_hdim=self.F_hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0=self.F_tile.F_bm0,
|
||||
F_bn0=self.F_tile.F_bn0,
|
||||
F_bk0=self.F_tile.F_bk0,
|
||||
F_bn1=self.F_tile.F_bn1,
|
||||
F_bk1=self.F_tile.F_bk1,
|
||||
F_bk0max=self.F_tile.F_bk0max,
|
||||
F_rm0=self.F_tile.F_rm0,
|
||||
F_rn0=self.F_tile.F_rn0,
|
||||
F_rk0=self.F_tile.F_rk0,
|
||||
F_rm1=self.F_tile.F_rm1,
|
||||
F_rn1=self.F_tile.F_rn1,
|
||||
F_rk1=self.F_tile.F_rk1,
|
||||
F_wm0=self.F_tile.F_wm0,
|
||||
F_wn0=self.F_tile.F_wn0,
|
||||
F_wk0=self.F_tile.F_wk0,
|
||||
F_wm1=self.F_tile.F_wm1,
|
||||
F_wn1=self.F_tile.F_wn1,
|
||||
F_wk1=self.F_tile.F_wk1,
|
||||
F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad=BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
# F_logits removed - hardcoded to false in template (NOT supported)
|
||||
F_occupancy=self.F_tile.F_occupancy,
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode=MODE_MAP[self.F_mode],
|
||||
F_pipeline=PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_trload=BOOL_MAP[self.F_pipeline.F_trload],
|
||||
F_kernel_name=self.name,
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return (
|
||||
f"fmha_jenga_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
|
||||
+ self.F_tile.name
|
||||
+ "_"
|
||||
+ self.F_pipeline.name
|
||||
)
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def api_trait(self) -> FmhaFwdApiTrait:
|
||||
return FmhaFwdApiTrait(
|
||||
pipeline_tag=self.F_pipeline.tag,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bm0=self.F_tile.F_bm0,
|
||||
bn0=self.F_tile.F_bn0,
|
||||
bk0=self.F_tile.F_bk0,
|
||||
bn1=self.F_tile.F_bn1,
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0max=self.F_tile.F_bk0max,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
logits=self.F_pipeline.F_logits,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
tr_load=self.F_pipeline.F_trload,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
|
||||
)
|
||||
|
||||
|
||||
class KernelComponentFactory:
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
if dtype == "fp16" or dtype == "bf16":
|
||||
return {
|
||||
# (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
# FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(128, 128): [
|
||||
FmhaFwdTileSize(
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
-1,
|
||||
),
|
||||
],
|
||||
# (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
# (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
# (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
@staticmethod
|
||||
def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# NOTE: logits soft-cap is NOT supported by Jenga sparse attention (enforced by static_assert)
|
||||
pipelines = []
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
for logits, mask in itertools.product(
|
||||
["f"], # logits soft-cap NOT supported, always false
|
||||
get_mask_map(mask_impl).keys(),
|
||||
):
|
||||
if hdim == 256 and hdim_v == 256:
|
||||
# jenga fmha only supports dim <= 192 for now.
|
||||
continue
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline( # fmt: skip
|
||||
"qr_async",
|
||||
"row",
|
||||
"t",
|
||||
"f",
|
||||
"t",
|
||||
"t",
|
||||
logits,
|
||||
mask,
|
||||
"f",
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline( # fmt: skip
|
||||
"qr_async",
|
||||
"row",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
logits,
|
||||
mask,
|
||||
"f",
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
|
||||
class CustomFactory(KernelComponentFactory):
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
|
||||
if dtype == "fp16" or dtype == "bf16":
|
||||
if (128, 128) in result.keys():
|
||||
result[(128, 128)].insert(
|
||||
0,
|
||||
FmhaFwdTileSize(
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
-1,
|
||||
CppConstraint(
|
||||
"get_num_blocks(128) < num_cus * min_cu_util_rate"
|
||||
),
|
||||
),
|
||||
)
|
||||
return result
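A worked example of the CU-utilization constraint attached to the 64-row tile above, reusing the block-count formula from the sparge_jenga_fwd dispatcher template; the CU count and problem sizes below are made up:

# Worked example of "get_num_blocks(128) < num_cus * min_cu_util_rate".
def get_num_thread_blocks(batch, nheads, max_seqlen_q, kM0):
    num_m_blocks = (max_seqlen_q + kM0 - 1) // kM0  # ceil over query rows
    return batch * nheads * num_m_blocks            # num_n_blocks assumed to be 1

num_cus, min_cu_util_rate = 104, 0.8
batch, nheads, max_seqlen_q = 1, 8, 1024

blocks_with_128 = get_num_thread_blocks(batch, nheads, max_seqlen_q, 128)  # 64
# 64 < 104 * 0.8, so the constraint holds: a kM0=128 tile would underfill the
# CUs, and the dispatcher may prefer the kM0=64 tile, doubling the block count.
blocks_with_64 = get_num_thread_blocks(batch, nheads, max_seqlen_q, 64)    # 128
print(blocks_with_128, blocks_with_64)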
|
||||
|
||||
|
||||
def get_fwd_blobs(
|
||||
kernel_filter: Optional[str], receipt, optdim_list, mask_impl
|
||||
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
gen = list()
|
||||
api_pool = FmhaFwdApiPool(mask_impl)
|
||||
|
||||
factory = (
|
||||
CustomFactory
|
||||
if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1"
|
||||
else KernelComponentFactory
|
||||
)
|
||||
|
||||
# Only generate fp16/bf16 kernels for now.
|
||||
# NOTE: Jenga sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert)
|
||||
for dtype in ["fp16", "bf16"]:
|
||||
d = factory.get_hdim_tile_size_dict(dtype)
|
||||
if d is None:
|
||||
continue
|
||||
for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]):
|
||||
for tile, pipeline in itertools.product(
|
||||
tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
|
||||
):
|
||||
if pipeline.tag != "qr_async":
|
||||
continue
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=2,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl,
|
||||
)
|
||||
if kernel_filter != "":
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_logits == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_varlen_fwd) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_fwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
|
||||
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
|
||||
update_file(autogen_dir / kernel.filename, kernel.template)
|
||||
|
||||
|
||||
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
|
||||
update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api)
|
||||
|
||||
|
||||
def write_blobs(
|
||||
output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
|
||||
) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
|
||||
|
||||
def list_blobs(
|
||||
file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
|
||||
) -> None:
|
||||
with file_path.open("a") as f:
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
|
||||
f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n")
|
||||
@@ -1,799 +0,0 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass, field
|
||||
import fnmatch
|
||||
import itertools
|
||||
import os
|
||||
import os.path as path
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from codegen.cpp_symbol_map import (
|
||||
BOOL_MAP,
|
||||
FWD_DTYPE_MAP,
|
||||
LAYOUT_MAP,
|
||||
MODE_MAP,
|
||||
PIPELINE_ENUM_MAP,
|
||||
PIPELINE_MAP,
|
||||
get_mask_check_map,
|
||||
get_mask_map,
|
||||
)
|
||||
|
||||
GEN_DIR = ""
|
||||
|
||||
|
||||
def update_file(file_path, content):
|
||||
"""Update the file at file_path with the given content if it differs from the existing content.
|
||||
|
||||
It avoids unnecessary touching of the file which triggers rebuilds
|
||||
"""
|
||||
|
||||
existing_content = ""
|
||||
if path.exists(file_path):
|
||||
with open(file_path, "r") as file:
|
||||
existing_content = file.read()
|
||||
if existing_content == content:
|
||||
return
|
||||
with open(file_path, "w") as file:
|
||||
file.write(content)
|
||||
|
||||
|
||||
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256}
|
||||
|
||||
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n
|
||||
// auto generated by generate.py
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "fmha_fwd_trek.hpp"
|
||||
#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp"
|
||||
#include "kernel/fmha_fwd_vsa_kernel.hpp"
|
||||
|
||||
"""
|
||||
|
||||
# NOTE: VSA sparse attention kernel has the following restrictions enforced by static_assert:
|
||||
# - Group mode: NOT supported (batch mode only)
|
||||
# - Bias: NOT supported (NO_BIAS only)
|
||||
# - LSE output: NOT supported (false only)
|
||||
# - Dropout: NOT supported (false only)
|
||||
# - Logits soft-cap: NOT supported (false only)
|
||||
# - FP8 static quantization: NOT supported (NO_SCALE only)
|
||||
# The template below hardcodes these unsupported features accordingly.
|
||||
|
||||
FMHA_FWD_KERNEL_BODY = """
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
|
||||
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
|
||||
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
|
||||
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
|
||||
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
|
||||
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
|
||||
{F_vlayout}>;
|
||||
|
||||
// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum,
|
||||
// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
false, // has_logits_soft_cap - NOT supported
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported
|
||||
false, // store_lse - NOT supported
|
||||
false, // has_dropout - NOT supported
|
||||
false, // has_randval - NOT supported
|
||||
ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported
|
||||
{F_occupancy},
|
||||
false>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported)
|
||||
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
fmha_shape_{F_idx},
|
||||
{F_mode},
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
{F_trload},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA<
|
||||
fmha_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_epilogue_{F_idx} =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaSparseFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType,
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdVSAKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_vsa_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_vsa_fwd_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << "{F_kernel_name}" << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_FILENAME = "sparge_vsa_fwd_api.cpp"
|
||||
FMHA_FWD_API = """
|
||||
#include <cstdio>
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace {{
|
||||
bool get_num_cus(unsigned& num_cus) {{
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device");
|
||||
return false;
|
||||
}}
|
||||
|
||||
hipDeviceProp_t props{{}};
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device properties");
|
||||
return false;
|
||||
}}
|
||||
|
||||
num_cus = props.multiProcessorCount;
|
||||
return true;
|
||||
}}
|
||||
|
||||
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
|
||||
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
|
||||
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
|
||||
|
||||
return batch * nheads * num_m_blocks * num_n_blocks;
|
||||
}}
|
||||
}} // namespace
|
||||
|
||||
float sparge_vsa_fwd(fmha_vsa_fwd_traits t, fmha_vsa_fwd_args a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
|
||||
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
|
||||
|
||||
unsigned num_cus;
|
||||
if (!get_num_cus(num_cus)) {{
|
||||
return r;
|
||||
}}
|
||||
|
||||
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
|
||||
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
|
||||
}};
|
||||
|
||||
const bool has_load_tr = ck_tile::is_load_tr_supported();
|
||||
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{
|
||||
{F_dtype_case}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) &&
|
||||
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
|
||||
return fmha_vsa_fwd_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class CppConstraint:
|
||||
bool_expr: str = None
|
||||
|
||||
def __str__(self):
|
||||
if self.bool_expr is None:
|
||||
return "true"
|
||||
else:
|
||||
return f"{self.bool_expr}"
|
||||
|
||||
def __and__(self, other):
|
||||
return CppConstraint(f"({str(self)}) && ({str(other)})")
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdApiTrait:
|
||||
pipeline_tag: str
|
||||
# sync with fmha_fwd_traits<>, to generate fallback calls
|
||||
hdim: str
|
||||
dtype: str # data type
|
||||
mode: str # value from MODE_MAP
|
||||
bm0: int # tile size along q seqlen (block size)
|
||||
bn0: int # tile size along qk seqlen
|
||||
bk0: int # tile size along qk gemm unroll
|
||||
bn1: int # tile size along v head_dim
|
||||
bk1: int # tile size along kv gemm unroll
|
||||
bk0max: int
|
||||
vlayout: str
|
||||
logits: str
|
||||
mask: str
|
||||
spad: str
|
||||
skpad: str
|
||||
dpad: str
|
||||
dvpad: str
|
||||
tr_load: str
|
||||
constraint: CppConstraint
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
|
||||
)
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.mode == "group":
|
||||
return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true
|
||||
if self.spad == "t":
|
||||
return "true" # always support
|
||||
return "true"
|
||||
|
||||
@property
|
||||
def seqtune(self) -> str:
|
||||
return "true"
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
if self.mode == "group":
|
||||
return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true
|
||||
if self.skpad == "t":
|
||||
return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0"
|
||||
return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0"
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dpad == "t":
|
||||
return f"a.hdim_q % {vec} == 0"
|
||||
assert False
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dvpad == "t":
|
||||
return f"a.hdim_v % {vec} == 0"
|
||||
assert False
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdPipeline:
|
||||
tag: str
|
||||
|
||||
F_vlayout: str # row/col
|
||||
F_spad: str # true/false
|
||||
F_skpad: str #
|
||||
F_dpad: str #
|
||||
F_dvpad: str #
|
||||
F_logits: str # t/f
|
||||
F_mask: str # value from MASK_MAP
|
||||
F_trload: str # true/false
|
||||
F_constraint: CppConstraint = field(default_factory=CppConstraint)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ""
|
||||
if self.F_spad == "t":
|
||||
n += "s"
|
||||
if self.F_skpad == "t":
|
||||
n += "sk"
|
||||
if self.F_dpad == "t":
|
||||
n += "d"
|
||||
if self.F_dvpad == "t":
|
||||
n += "dv"
|
||||
if n != "":
|
||||
n = "p" + n
|
||||
return n
|
||||
|
||||
pn = pad_name()
|
||||
n = f"{self.tag}_v{self.F_vlayout[0]}"
|
||||
if pn != "":
|
||||
n += f"_{pn}"
|
||||
else:
|
||||
n += "_npad"
|
||||
|
||||
if self.F_logits == "t":
|
||||
n += "_logits"
|
||||
else:
|
||||
n += "_nlogits"
|
||||
|
||||
n += "_nbias"
|
||||
|
||||
if self.F_mask[0:2] == "s_":
|
||||
if self.F_mask == "s_mask":
|
||||
n += "_mask"
|
||||
else:
|
||||
n += "_nmask"
|
||||
else:
|
||||
if self.F_mask != "no":
|
||||
n += f"_m{self.F_mask[0]}"
|
||||
else:
|
||||
n += "_nmask"
|
||||
|
||||
n += "_nskip"
|
||||
|
||||
n += "_nsquant"
|
||||
|
||||
if self.F_trload == "t":
|
||||
n += "_trload"
|
||||
else:
|
||||
n += "_ntrload"
|
||||
|
||||
return n
|
||||
|
||||
|
||||
class FmhaFwdApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.pool = dict()
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_traits(self, trait: FmhaFwdApiTrait) -> None:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.pool.keys():
|
||||
self.pool[trait.dtype] = dict()
|
||||
hdim = trait.hdim, trait.bn1
|
||||
if hdim not in self.pool[trait.dtype].keys():
|
||||
self.pool[trait.dtype][hdim] = list()
|
||||
|
||||
self.pool[trait.dtype][hdim].append(copy.copy(trait))
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
tr_load_cond_map = {"t": "has_load_tr", "f": "true"}
|
||||
|
||||
per_tr_load = str()
|
||||
for tr_load in ["t", "f"]:
|
||||
per_dtypes = str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case = str()
|
||||
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
|
||||
traits = [
|
||||
t
|
||||
for t in self.pool[dtype][(hdim, hdim_v)]
|
||||
if tr_load == t.tr_load
|
||||
]
|
||||
inners = str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = "if" if k == 0 else "else if"
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
|
||||
F_if=if_k,
|
||||
F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
|
||||
# F_logits removed - hardcoded to false (NOT supported)
|
||||
F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
|
||||
F_trload=BOOL_MAP[trait.tr_load],
|
||||
F_scheck=trait.scheck,
|
||||
F_seqtune=trait.seqtune,
|
||||
F_skcheck=trait.skcheck,
|
||||
F_dcheck=trait.dcheck,
|
||||
F_dvcheck=trait.dvcheck,
|
||||
F_constraint=trait.constraint,
|
||||
F_spad=BOOL_MAP[trait.spad],
|
||||
F_skpad=BOOL_MAP[trait.skpad],
|
||||
F_dpad=BOOL_MAP[trait.dpad],
|
||||
F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0,
|
||||
F_bn0=trait.bn0,
|
||||
F_bk0=trait.bk0,
|
||||
F_bn1=trait.bn1,
|
||||
F_bk1=trait.bk1,
|
||||
F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[dtype],
|
||||
)
|
||||
if_j = "if" if j == 0 else "else if"
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
|
||||
F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners
|
||||
)
|
||||
if_i = "if" if i == 0 else "else if"
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
|
||||
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
|
||||
)
|
||||
per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(
|
||||
F_if="if",
|
||||
F_trload_cond=tr_load_cond_map[tr_load],
|
||||
F_dtype_case=per_dtypes,
|
||||
)
|
||||
if not per_tr_load:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_tr_load += " (void)t ; (void)s ; (void)a;"
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdTileSize:
|
||||
F_bm0: int # tile size along q seqlen (block size)
|
||||
F_bn0: int # tile size along k seqlen
|
||||
F_bk0: int # tile size along qk gemm unroll
|
||||
F_bn1: int # tile size along v head_dim
|
||||
F_bk1: int # tile size along kv gemm unroll
|
||||
F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0: int # number of warps for gemm0 along q seqlen
|
||||
F_rn0: int # number of warps for gemm0 along k seqlen
|
||||
F_rk0: int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1: int # number of warps for gemm1 along q seqlen
|
||||
F_rn1: int # number of warps for gemm1 along head dim v
|
||||
F_rk1: int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0: int # gemm0 warp size along m
|
||||
F_wn0: int # gemm0 warp size along n
|
||||
F_wk0: int # gemm0 warp size along k
|
||||
F_wm1: int # gemm1 warp size along m
|
||||
F_wn1: int # gemm1 warp size along n
|
||||
F_wk1: int # gemm1 warp size along k
|
||||
F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
F_constraint: CppConstraint = field(default_factory=CppConstraint)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
|
||||
+ f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
|
||||
+ f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
|
||||
+ ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdKernel:
|
||||
F_idx: int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim: int # hdim
|
||||
F_dtype: str # data type
|
||||
F_mode: str # value from MODE_MAP
|
||||
F_tile: FmhaFwdTileSize
|
||||
F_pipeline: FmhaFwdPipeline
|
||||
mask_impl: str
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
# kernel_body removed - unused
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
|
||||
F_idx=self.F_idx,
|
||||
F_hdim=self.F_hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0=self.F_tile.F_bm0,
|
||||
F_bn0=self.F_tile.F_bn0,
|
||||
F_bk0=self.F_tile.F_bk0,
|
||||
F_bn1=self.F_tile.F_bn1,
|
||||
F_bk1=self.F_tile.F_bk1,
|
||||
F_bk0max=self.F_tile.F_bk0max,
|
||||
F_rm0=self.F_tile.F_rm0,
|
||||
F_rn0=self.F_tile.F_rn0,
|
||||
F_rk0=self.F_tile.F_rk0,
|
||||
F_rm1=self.F_tile.F_rm1,
|
||||
F_rn1=self.F_tile.F_rn1,
|
||||
F_rk1=self.F_tile.F_rk1,
|
||||
F_wm0=self.F_tile.F_wm0,
|
||||
F_wn0=self.F_tile.F_wn0,
|
||||
F_wk0=self.F_tile.F_wk0,
|
||||
F_wm1=self.F_tile.F_wm1,
|
||||
F_wn1=self.F_tile.F_wn1,
|
||||
F_wk1=self.F_tile.F_wk1,
|
||||
F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad=BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
# F_logits removed - hardcoded to false in template (NOT supported)
|
||||
F_occupancy=self.F_tile.F_occupancy,
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode=MODE_MAP[self.F_mode],
|
||||
F_pipeline=PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_trload=BOOL_MAP[self.F_pipeline.F_trload],
|
||||
F_kernel_name=self.name,
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return (
|
||||
f"fmha_vsa_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
|
||||
+ self.F_tile.name
|
||||
+ "_"
|
||||
+ self.F_pipeline.name
|
||||
)
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def api_trait(self) -> FmhaFwdApiTrait:
|
||||
return FmhaFwdApiTrait(
|
||||
pipeline_tag=self.F_pipeline.tag,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bm0=self.F_tile.F_bm0,
|
||||
bn0=self.F_tile.F_bn0,
|
||||
bk0=self.F_tile.F_bk0,
|
||||
bn1=self.F_tile.F_bn1,
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0max=self.F_tile.F_bk0max,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
logits=self.F_pipeline.F_logits,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
tr_load=self.F_pipeline.F_trload,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
|
||||
)
|
||||
|
||||
|
||||
class KernelComponentFactory:
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
if dtype == "fp16" or dtype == "bf16":
|
||||
return {
|
||||
# (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
# FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(128, 128): [
|
||||
FmhaFwdTileSize(
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
-1,
|
||||
),
|
||||
],
|
||||
# (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
# (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
# (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
@staticmethod
|
||||
def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# NOTE: logits soft-cap is NOT supported by VSA sparse attention (enforced by static_assert)
|
||||
pipelines = []
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
for logits, mask in itertools.product(
|
||||
["f"], # logits soft-cap NOT supported, always false
|
||||
get_mask_map(mask_impl).keys(),
|
||||
):
|
||||
if hdim == 256 and hdim_v == 256:
|
||||
# vsa fmha only supports dim <= 192 for now.
|
||||
continue
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr_async_vsa",
|
||||
"row",
|
||||
"t",
|
||||
"f",
|
||||
"t",
|
||||
"t",
|
||||
logits,
|
||||
mask,
|
||||
"f",
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr_async_vsa",
|
||||
"row",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
logits,
|
||||
mask,
|
||||
"f",
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
|
||||
class CustomFactory(KernelComponentFactory):
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
|
||||
if dtype == "fp16" or dtype == "bf16":
|
||||
if (128, 128) in result.keys():
|
||||
result[(128, 128)].insert(
|
||||
0,
|
||||
FmhaFwdTileSize(
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
-1,
|
||||
CppConstraint(
|
||||
"get_num_blocks(128) < num_cus * min_cu_util_rate"
|
||||
),
|
||||
),
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def get_fwd_blobs(
|
||||
kernel_filter: Optional[str], receipt, optdim_list, mask_impl
|
||||
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
gen = list()
|
||||
api_pool = FmhaFwdApiPool(mask_impl)
|
||||
|
||||
factory = (
|
||||
CustomFactory
|
||||
if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1"
|
||||
else KernelComponentFactory
|
||||
)
|
||||
|
||||
# Only generate fp16/bf16 kernels for now.
|
||||
# NOTE: VSA sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert)
|
||||
for dtype in ["fp16", "bf16"]:
|
||||
d = factory.get_hdim_tile_size_dict(dtype)
|
||||
if d is None:
|
||||
continue
|
||||
for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]):
|
||||
for tile, pipeline in itertools.product(
|
||||
tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
|
||||
):
|
||||
if pipeline.tag != "qr_async_vsa":
|
||||
continue
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=1,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl,
|
||||
)
|
||||
if kernel_filter != "":
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_logits == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_varlen_fwd) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_fwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
|
||||
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
|
||||
update_file(autogen_dir / kernel.filename, kernel.template)
|
||||
|
||||
|
||||
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
|
||||
update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api)
|
||||
|
||||
|
||||
def write_blobs(
|
||||
output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
|
||||
) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
|
||||
|
||||
def list_blobs(
|
||||
file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
|
||||
) -> None:
|
||||
with file_path.open("a") as f:
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
|
||||
f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n")
|
||||
@@ -277,13 +277,13 @@ struct fmha_jenga_fwd_traits
|
||||
|
||||
float fmha_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
// sparge jenga
|
||||
float sparge_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_jenga_fwd_(const ck_tile::stream_config&, fmha_jenga_fwd_args);
|
||||
|
||||
float fmha_jenga_fwd(fmha_jenga_fwd_args, const ck_tile::stream_config&);
|
||||
template <typename Traits_>
|
||||
void fmha_jenga_fwd_oneshot_(const ck_tile::stream_config&, fmha_jenga_fwd_args);
|
||||
|
||||
void fmha_jenga_fwd_oneshot(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
// VSA uses the same traits structure as Jenga; aliases for clarity
|
||||
template <ck_tile::index_t HDim_,
|
||||
@@ -325,10 +325,10 @@ using fmha_vsa_fwd_traits = fmha_jenga_fwd_traits;
|
||||
|
||||
float fmha_vsa_fwd(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
// sparge vsa
|
||||
float sparge_vsa_fwd(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_vsa_fwd_(const ck_tile::stream_config&, fmha_vsa_fwd_args);
|
||||
|
||||
float fmha_vsa_fwd(fmha_vsa_fwd_args, const ck_tile::stream_config&);
|
||||
template <typename Traits_>
|
||||
void fmha_vsa_fwd_oneshot_(const ck_tile::stream_config&, fmha_vsa_fwd_args);
|
||||
|
||||
void fmha_vsa_fwd_oneshot(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
@@ -1,189 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#include "jenga_sparge_attention.h"
|
||||
#include "fmha_fwd_trek.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
template <typename DataType_>
|
||||
ck_tile::HostTensor<DataType_>
|
||||
jenga_sparge_attention(const ck_tile::HostTensor<DataType_>& TQ,
|
||||
const ck_tile::HostTensor<DataType_>& TK,
|
||||
const ck_tile::HostTensor<DataType_>& TV,
|
||||
const ck_tile::HostTensor<uint8_t>& Tblock_relation_onehot,
|
||||
ck_tile::HostTensor<DataType_>& Y,
|
||||
int batch,
|
||||
int nhead,
|
||||
int nhead_k,
|
||||
int seqlen_q,
|
||||
int seqlen_k,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
bool i_perm,
|
||||
bool o_perm,
|
||||
int max_seqlen_q,
|
||||
int max_seqlen_k,
|
||||
int log_level)
|
||||
{
|
||||
static_assert(std::is_same_v<DataType_, ck_tile::half_t> ||
|
||||
std::is_same_v<DataType_, ck_tile::bf16_t>,
|
||||
"Jenga sparse attention supports fp16/bf16 only.");
|
||||
std::string data_type = "fp16";
|
||||
if constexpr(std::is_same_v<DataType_, ck_tile::bf16_t>)
|
||||
{
|
||||
data_type = "bf16";
|
||||
}
|
||||
|
||||
if(max_seqlen_q == 0)
|
||||
max_seqlen_q = seqlen_q;
|
||||
if(max_seqlen_k == 0)
|
||||
max_seqlen_k = seqlen_k;
|
||||
bool is_v_rowmajor = true;
|
||||
float scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q));
|
||||
std::string msk_str = "0";
|
||||
mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k);
|
||||
|
||||
const ck_tile::index_t shape_seqlen_q = seqlen_q;
|
||||
const ck_tile::index_t shape_seqlen_k = seqlen_k;
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
false, // time_kernel
|
||||
log_level,
|
||||
0,
|
||||
1,
|
||||
false};
|
||||
|
||||
ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem block_relation_buf(Tblock_relation_onehot.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes());
|
||||
|
||||
q_buf.ToDevice(TQ.data());
|
||||
k_buf.ToDevice(TK.data());
|
||||
v_buf.ToDevice(TV.data());
|
||||
block_relation_buf.ToDevice(Tblock_relation_onehot.data());
|
||||
|
||||
const auto init_args = [&](auto& args) {
|
||||
assert(nhead % nhead_k == 0);
|
||||
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
|
||||
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
|
||||
const ck_tile::index_t stride_v = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
return i_perm ? hdim_v : nhead_k * hdim_v;
|
||||
else
|
||||
return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
|
||||
}();
|
||||
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
|
||||
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q;
|
||||
const ck_tile::index_t nhead_stride_v = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
|
||||
else
|
||||
return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
|
||||
}();
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q;
|
||||
const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k;
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
|
||||
args.q_ptr = q_buf.GetDeviceBuffer();
|
||||
args.k_ptr = k_buf.GetDeviceBuffer();
|
||||
args.v_ptr = v_buf.GetDeviceBuffer();
|
||||
args.block_relation_onehot_ptr = block_relation_buf.GetDeviceBuffer();
|
||||
|
||||
args.batch = batch;
|
||||
args.seqlen_q = shape_seqlen_q;
|
||||
args.hdim_q = hdim_q;
|
||||
args.hdim_v = hdim_v;
|
||||
args.nhead_q = nhead;
|
||||
args.nhead_k = nhead_k;
|
||||
|
||||
args.stride_q = stride_q;
|
||||
args.stride_k = stride_k;
|
||||
args.stride_v = stride_v;
|
||||
args.nhead_stride_q = nhead_stride_q;
|
||||
args.nhead_stride_k = nhead_stride_k;
|
||||
args.nhead_stride_v = nhead_stride_v;
|
||||
args.batch_stride_q = batch_stride_q;
|
||||
args.batch_stride_k = batch_stride_k;
|
||||
args.batch_stride_v = batch_stride_v;
|
||||
|
||||
args.o_ptr = o_buf.GetDeviceBuffer();
|
||||
|
||||
args.seqlen_k = shape_seqlen_k;
|
||||
args.max_seqlen_q = max_seqlen_q;
|
||||
|
||||
args.scale_s = scale_s;
|
||||
|
||||
args.stride_o = stride_o;
|
||||
args.nhead_stride_o = nhead_stride_o;
|
||||
args.batch_stride_o = batch_stride_o;
|
||||
|
||||
args.window_size_left = mask.left;
|
||||
args.window_size_right = mask.right;
|
||||
args.mask_type = static_cast<ck_tile::index_t>(mask.type);
|
||||
};
|
||||
|
||||
const auto init_traits = [&](auto& traits) {
|
||||
traits.hdim_q = hdim_q;
|
||||
traits.hdim_v = hdim_v;
|
||||
traits.data_type = data_type;
|
||||
traits.is_v_rowmajor = is_v_rowmajor;
|
||||
traits.mask_type = mask.type;
|
||||
};
|
||||
|
||||
fmha_jenga_fwd_traits fmha_traits;
|
||||
init_traits(fmha_traits);
|
||||
|
||||
fmha_jenga_fwd_args args;
|
||||
init_args(args);
|
||||
|
||||
sparge_jenga_fwd(fmha_traits, args, stream_config);
|
||||
|
||||
o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes());
|
||||
|
||||
return Y;
|
||||
}
|
||||
|
||||
template ck_tile::HostTensor<ck_tile::half_t>
|
||||
jenga_sparge_attention<ck_tile::half_t>(const ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
const ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
const ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
const ck_tile::HostTensor<uint8_t>&,
|
||||
ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
bool,
|
||||
bool,
|
||||
int,
|
||||
int,
|
||||
int);
|
||||
|
||||
template ck_tile::HostTensor<ck_tile::bf16_t>
|
||||
jenga_sparge_attention<ck_tile::bf16_t>(const ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
const ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
const ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
const ck_tile::HostTensor<uint8_t>&,
|
||||
ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
bool,
|
||||
bool,
|
||||
int,
|
||||
int,
|
||||
int);
|
||||
@@ -1,27 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
#include <optional>
|
||||
#include <cstdint>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
template <typename DataType_>
|
||||
ck_tile::HostTensor<DataType_>
|
||||
jenga_sparge_attention(const ck_tile::HostTensor<DataType_>& TQ,
|
||||
const ck_tile::HostTensor<DataType_>& TK,
|
||||
const ck_tile::HostTensor<DataType_>& TV,
|
||||
const ck_tile::HostTensor<uint8_t>& Tblock_relation_onehot,
|
||||
ck_tile::HostTensor<DataType_>& Y,
|
||||
int batch,
|
||||
int nhead,
|
||||
int nhead_k,
|
||||
int seqlen_q,
|
||||
int seqlen_k,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
bool i_perm,
|
||||
bool o_perm,
|
||||
int max_seqlen_q,
|
||||
int max_seqlen_k,
|
||||
int log_level = 0);
|
||||
@@ -61,6 +61,57 @@ using bmap_fp16_problem = ck_tile::BlockFmhaPipelineProblem<ck_tile::half_t, //
|
||||
using bmap_fp16_pipeline = ck_tile::SpargeBlockMapPipeline<bmap_fp16_problem>;
|
||||
using bmap_fp16_kernel = ck_tile::SpargeBlockMapKernel<bmap_fp16_pipeline>;
|
||||
|
||||
// ============================================================================
|
||||
// bf16: D=128, kM0=64, kN0=128
|
||||
// ============================================================================
|
||||
|
||||
using bmap_bf16_block_tile = ck_tile::sequence<64, 128, 128, 128, 128, 128>;
|
||||
|
||||
using bmap_bf16_shape =
|
||||
ck_tile::TileFmhaShape<bmap_bf16_block_tile,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
ck_tile::sequence<16, 16, 16>,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
ck_tile::sequence<16, 16, 16>,
|
||||
true>;
|
||||
|
||||
using bmap_bf16_trait = ck_tile::TileFmhaTraits<true, // kPadSeqLenQ
|
||||
true, // kPadSeqLenK
|
||||
true, // kPadHeadDimQ
|
||||
true, // kPadHeadDimV
|
||||
false, // kHasLogitsSoftCap
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false, // kStoreLSE
|
||||
false, // kHasDropout
|
||||
false, // kHasRandVal
|
||||
ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE,
|
||||
-1,
|
||||
false>;
|
||||
|
||||
using bmap_bf16_variant = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
using bmap_bf16_mask = ck_tile::GenericAttentionMask<false>;
|
||||
|
||||
using bmap_bf16_problem = ck_tile::BlockFmhaPipelineProblem<ck_tile::bf16_t, // QDataType
|
||||
ck_tile::bf16_t, // KDataType
|
||||
ck_tile::bf16_t, // VDataType
|
||||
float, // SaccDataType
|
||||
float, // SMPLComputeDataType
|
||||
ck_tile::bf16_t, // BiasDataType
|
||||
uint8_t, // RandValOutputDataType
|
||||
float, // LSEDataType
|
||||
ck_tile::bf16_t, // PDataType
|
||||
float, // OaccDataType
|
||||
ck_tile::bf16_t, // ODataType
|
||||
bmap_bf16_shape,
|
||||
false, // kIsGroupMode
|
||||
bmap_bf16_variant,
|
||||
bmap_bf16_mask,
|
||||
false, // kUseTrLoad
|
||||
bmap_bf16_trait>;
|
||||
|
||||
using bmap_bf16_pipeline = ck_tile::SpargeBlockMapPipeline<bmap_bf16_problem>;
|
||||
using bmap_bf16_kernel = ck_tile::SpargeBlockMapKernel<bmap_bf16_pipeline>;
|
||||
|
||||
// ============================================================================
|
||||
// Dispatch
|
||||
// ============================================================================
|
||||
@@ -81,8 +132,96 @@ float sparge_blockmap_fwd(sparge_blockmap_traits traits,
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
if(traits.data_type == "bf16" && traits.hdim_q == 128)
|
||||
{
|
||||
using k_ = bmap_bf16_kernel;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", sparge_blockmap_bf16_d128" << std::flush;
|
||||
auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids<k_>(args);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
std::cerr << "sparge_blockmap_fwd: unsupported config (data_type=" << traits.data_type
|
||||
<< ", hdim_q=" << traits.hdim_q << ")" << std::endl;
|
||||
return -1.f;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Oneshot version: launches kernel without timing wrapper
|
||||
// ============================================================================
|
||||
|
||||
void sparge_blockmap_fwd_oneshot(sparge_blockmap_traits traits,
|
||||
sparge_blockmap_args args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
if(traits.data_type == "fp16" && traits.hdim_q == 128)
|
||||
{
|
||||
using k_ = bmap_fp16_kernel;
|
||||
auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids<k_>(args);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel<kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
return;
|
||||
}
|
||||
|
||||
if(traits.data_type == "bf16" && traits.hdim_q == 128)
|
||||
{
|
||||
using k_ = bmap_bf16_kernel;
|
||||
auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids<k_>(args);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel<kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
return;
|
||||
}
|
||||
|
||||
std::cerr << "sparge_blockmap_fwd_oneshot: unsupported config (data_type=" << traits.data_type
|
||||
<< ", hdim_q=" << traits.hdim_q << ")" << std::endl;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Combined functions: blockmap + attention timed together via launch_kernel
|
||||
// ============================================================================
|
||||
|
||||
float sparge_jenga_fwd(sparge_blockmap_traits bmap_t, sparge_blockmap_args bmap_a,
|
||||
fmha_jenga_fwd_traits attn_t, fmha_jenga_fwd_args attn_a,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q
|
||||
<< ", fmha_jenga_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q
|
||||
<< std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
[=](const ck_tile::stream_config& s_) {
|
||||
sparge_blockmap_fwd_oneshot(bmap_t, bmap_a, s_);
|
||||
},
|
||||
[=](const ck_tile::stream_config& s_) {
|
||||
fmha_jenga_fwd_oneshot(attn_t, attn_a, s_);
|
||||
});
|
||||
}
|
||||
|
||||
float sparge_vsa_fwd_combined(sparge_blockmap_traits bmap_t, sparge_blockmap_args bmap_a,
|
||||
fmha_vsa_fwd_traits attn_t, fmha_vsa_fwd_args attn_a,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q
|
||||
<< ", fmha_vsa_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q
|
||||
<< std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
[=](const ck_tile::stream_config& s_) {
|
||||
sparge_blockmap_fwd_oneshot(bmap_t, bmap_a, s_);
|
||||
},
|
||||
[=](const ck_tile::stream_config& s_) {
|
||||
fmha_vsa_fwd_oneshot(attn_t, attn_a, s_);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -91,3 +91,16 @@ auto sparge_blockmap_create_kargs_and_grids(sparge_blockmap_args args)
|
||||
float sparge_blockmap_fwd(sparge_blockmap_traits traits,
|
||||
sparge_blockmap_args args,
|
||||
const ck_tile::stream_config& stream_config);
|
||||
|
||||
void sparge_blockmap_fwd_oneshot(sparge_blockmap_traits traits,
|
||||
sparge_blockmap_args args,
|
||||
const ck_tile::stream_config& stream_config);
|
||||
|
||||
// Combined functions: blockmap + attention with unified timing
|
||||
float sparge_jenga_fwd(sparge_blockmap_traits, sparge_blockmap_args,
|
||||
fmha_jenga_fwd_traits, fmha_jenga_fwd_args,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
float sparge_vsa_fwd_combined(sparge_blockmap_traits, sparge_blockmap_args,
|
||||
fmha_vsa_fwd_traits, fmha_vsa_fwd_args,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
432
example/ck_tile/50_sparse_attn/test_sparge.cpp
Normal file
432
example/ck_tile/50_sparse_attn/test_sparge.cpp
Normal file
@@ -0,0 +1,432 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Unified test for Sparge pipeline: blockmap generation + sparse attention (Jenga/VSA).
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
|
||||
#include "fmha_fwd_trek.hpp"
|
||||
#include "sparge_blockmap_trek.hpp"
|
||||
#include "sparge_tool.hpp"
|
||||
|
||||
// ============================================================================
|
||||
// Helpers
|
||||
// ============================================================================
|
||||
|
||||
template <typename T>
|
||||
ck_tile::HostTensor<T>
|
||||
make_qkv_tensor(ck_tile::index_t batch, ck_tile::index_t nhead, ck_tile::index_t seqlen, ck_tile::index_t hdim, bool i_perm)
|
||||
{
|
||||
if(i_perm)
|
||||
return ck_tile::HostTensor<T>({batch, nhead, seqlen, hdim});
|
||||
return ck_tile::HostTensor<T>({batch, seqlen, nhead, hdim});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ck_tile::HostTensor<T> to_bhsd(const ck_tile::HostTensor<T>& tensor, bool is_bhsd)
|
||||
{
|
||||
auto lens = tensor.get_lengths();
|
||||
ck_tile::index_t batch = lens[0];
|
||||
ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1];
|
||||
ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2];
|
||||
ck_tile::index_t hdim = lens[3];
|
||||
|
||||
ck_tile::HostTensor<T> out({batch, nhead, seqlen, hdim});
|
||||
for(ck_tile::index_t b = 0; b < batch; ++b)
|
||||
for(ck_tile::index_t h = 0; h < nhead; ++h)
|
||||
for(ck_tile::index_t s = 0; s < seqlen; ++s)
|
||||
for(ck_tile::index_t d = 0; d < hdim; ++d)
|
||||
out(b, h, s, d) = is_bhsd ? tensor(b, h, s, d) : tensor(b, s, h, d);
|
||||
return out;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto get_error_tolerance()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 4e-2;
|
||||
if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
|
||||
{
|
||||
atol = 2e-1;
|
||||
rtol = 2e-1;
|
||||
}
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
float to_float_for_compare(T value)
|
||||
{
|
||||
return static_cast<float>(value);
|
||||
}
|
||||
|
||||
template <>
|
||||
float to_float_for_compare<ck_tile::bf16_t>(ck_tile::bf16_t value)
|
||||
{
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
return static_cast<float>(value);
|
||||
#else
|
||||
return ck_tile::bf16_to_float_raw(ck_tile::bit_cast<ck_tile::bf16_raw_t>(value));
|
||||
#endif
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Arg parser
|
||||
// ============================================================================
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser
|
||||
.insert("v", "1", "0:no validation, 1:cpu validation")
|
||||
.insert("pipeline", "jenga", "attention pipeline: jenga / vsa")
|
||||
.insert("b", "1", "batch size")
|
||||
.insert("h", "4", "num of head for q")
|
||||
.insert("h_k", "-1", "num of head for k/v, -1 means equal to h")
|
||||
.insert("s", "4096", "seqlen_q")
|
||||
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
|
||||
.insert("d", "128", "head dim for q, k")
|
||||
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
|
||||
.insert("topk", "0.3", "topk ratio for blockmap (fraction of K-blocks to keep)")
|
||||
.insert("cdfthreshd", "-1", "CDF threshold for blockmap (overrides topk if >= 0)")
|
||||
.insert("simthreshd1", "0.6", "similarity threshold for blockmap")
|
||||
.insert("prec", "fp16", "data type: fp16/bf16")
|
||||
.insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d")
|
||||
.insert("operm", "1", "permute output")
|
||||
.insert("seed", "42", "random seed")
|
||||
.insert("warmup", "5", "warmup iterations")
|
||||
.insert("repeat", "20", "benchmark iterations")
|
||||
.insert("kname", "0", "print kernel name");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Main test
|
||||
// ============================================================================
|
||||
template <typename T>
|
||||
bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
std::string pipeline = arg_parser.get_str("pipeline");
|
||||
ck_tile::index_t batch = arg_parser.get_int("b");
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
|
||||
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
float topk = arg_parser.get_float("topk");
|
||||
float cdfthreshd = arg_parser.get_float("cdfthreshd");
|
||||
float simthreshd1 = arg_parser.get_float("simthreshd1");
|
||||
bool i_perm = arg_parser.get_bool("iperm");
|
||||
bool o_perm = arg_parser.get_bool("operm");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
|
||||
if(nhead_k < 0) nhead_k = nhead;
|
||||
if(seqlen_k < 0) seqlen_k = seqlen_q;
|
||||
if(hdim_v < 0) hdim_v = hdim_q;
|
||||
|
||||
// If cdfthreshd >= 0, use CDF mode; otherwise use topk mode
|
||||
if(cdfthreshd >= 0.0f)
|
||||
topk = -1.0f;
|
||||
|
||||
constexpr ck_tile::index_t BLKQ = 64;
|
||||
constexpr ck_tile::index_t BLKK = 128;
|
||||
|
||||
if(hdim_q != 128 || hdim_v != 128)
|
||||
{
|
||||
std::cout << "\n>>> TEST SKIPPED <<<\n"
|
||||
<< "Kernel instances are generated for hdim=128 only.\n";
|
||||
return true;
|
||||
}
|
||||
|
||||
ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ;
|
||||
ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK;
|
||||
|
||||
std::string prec_str = std::is_same_v<T, ck_tile::half_t> ? "fp16" : "bf16";
|
||||
std::cout << "[" << pipeline << "|" << prec_str
|
||||
<< "] b=" << batch << " h=" << nhead << " s=" << seqlen_q
|
||||
<< " d=" << hdim_q << " topk=" << topk
|
||||
<< " sim1=" << simthreshd1 << std::flush;
|
||||
|
||||
// ---- allocate host tensors ----
|
||||
auto q_host = make_qkv_tensor<T>(batch, nhead, seqlen_q, hdim_q, i_perm);
|
||||
auto k_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
|
||||
auto v_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
|
||||
auto output_host = o_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen_q, hdim_v})
|
||||
: ck_tile::HostTensor<T>({batch, seqlen_q, nhead, hdim_v});
|
||||
|
||||
ck_tile::HostTensor<uint8_t> block_map_host({batch, nhead, num_q_blocks, num_k_blocks});
|
||||
ck_tile::HostTensor<int32_t> lut_host({batch, nhead, num_q_blocks, num_k_blocks});
|
||||
ck_tile::HostTensor<int32_t> valid_block_num_host({batch, nhead, num_q_blocks});
|
||||
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 1}(k_host);
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 2}(v_host);
|
||||
|
||||
// ---- device tensors ----
|
||||
ck_tile::DeviceMem q_dev(q_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_dev(k_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_dev(v_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_dev(output_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem block_map_dev(block_map_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem lut_dev(lut_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem valid_bn_dev(valid_block_num_host.get_element_space_size_in_bytes());
|
||||
|
||||
q_dev.ToDevice(q_host.data());
|
||||
k_dev.ToDevice(k_host.data());
|
||||
v_dev.ToDevice(v_host.data());
|
||||
o_dev.SetZero();
|
||||
block_map_dev.SetZero();
|
||||
lut_dev.SetZero();
|
||||
valid_bn_dev.SetZero();
|
||||
|
||||
// ---- strides (BHSD when i_perm=true) ----
|
||||
auto q_strides = q_host.get_strides();
|
||||
auto k_strides = k_host.get_strides();
|
||||
auto v_strides = v_host.get_strides();
|
||||
auto o_strides = output_host.get_strides();
|
||||
|
||||
float scale_s = 1.0f / std::sqrt(static_cast<float>(hdim_q));
|
||||
|
||||
// ---- build blockmap args ----
|
||||
sparge_blockmap_traits bmap_traits;
|
||||
bmap_traits.data_type = std::is_same_v<T, ck_tile::half_t> ? "fp16" : "bf16";
|
||||
bmap_traits.hdim_q = hdim_q;
|
||||
|
||||
sparge_blockmap_args bmap_args;
|
||||
bmap_args.q_ptr = q_dev.GetDeviceBuffer();
|
||||
bmap_args.k_ptr = k_dev.GetDeviceBuffer();
|
||||
bmap_args.batch = batch;
|
||||
bmap_args.seqlen_q = seqlen_q;
|
||||
bmap_args.seqlen_k = seqlen_k;
|
||||
bmap_args.hdim_q = hdim_q;
|
||||
bmap_args.nhead_q = nhead;
|
||||
bmap_args.nhead_k = nhead_k;
|
||||
bmap_args.stride_q = q_strides[i_perm ? 2 : 1];
|
||||
bmap_args.stride_k = k_strides[i_perm ? 2 : 1];
|
||||
bmap_args.nhead_stride_q = q_strides[i_perm ? 1 : 2];
|
||||
bmap_args.nhead_stride_k = k_strides[i_perm ? 1 : 2];
|
||||
bmap_args.batch_stride_q = q_strides[0];
|
||||
bmap_args.batch_stride_k = k_strides[0];
|
||||
bmap_args.simthreshd1 = simthreshd1;
|
||||
bmap_args.cdfthreshd = (topk < 0.0f) ? cdfthreshd : -1.0f;
|
||||
bmap_args.topk = topk;
|
||||
bmap_args.scale = scale_s;
|
||||
bmap_args.block_map_ptr = block_map_dev.GetDeviceBuffer();
|
||||
bmap_args.lut_ptr = (pipeline == "vsa") ? lut_dev.GetDeviceBuffer() : nullptr;
|
||||
bmap_args.valid_block_num_ptr = (pipeline == "vsa") ? valid_bn_dev.GetDeviceBuffer() : nullptr;
|
||||
|
||||
// ---- build attention args ----
|
||||
ck_tile::stream_config stream_cfg;
|
||||
stream_cfg.stream_id_ = nullptr;
|
||||
stream_cfg.time_kernel_ = true;
|
||||
stream_cfg.log_level_ = kname;
|
||||
stream_cfg.cold_niters_ = warmup;
|
||||
stream_cfg.nrepeat_ = repeat;
|
||||
|
||||
float avg_ms = -1.0f;
|
||||
|
||||
if(pipeline == "jenga")
|
||||
{
|
||||
fmha_jenga_fwd_traits attn_traits;
|
||||
attn_traits.hdim_q = hdim_q;
|
||||
attn_traits.hdim_v = hdim_v;
|
||||
attn_traits.data_type = std::is_same_v<T, ck_tile::half_t> ? "fp16" : "bf16";
|
||||
attn_traits.is_v_rowmajor = true;
|
||||
attn_traits.mask_type = mask_enum::no_mask;
|
||||
|
||||
fmha_jenga_fwd_args attn_args;
|
||||
attn_args.q_ptr = q_dev.GetDeviceBuffer();
|
||||
attn_args.k_ptr = k_dev.GetDeviceBuffer();
|
||||
attn_args.v_ptr = v_dev.GetDeviceBuffer();
|
||||
attn_args.block_relation_onehot_ptr = block_map_dev.GetDeviceBuffer();
|
||||
attn_args.o_ptr = o_dev.GetDeviceBuffer();
|
||||
attn_args.seqlen_q = seqlen_q;
|
||||
attn_args.seqlen_k = seqlen_k;
|
||||
attn_args.batch = batch;
|
||||
attn_args.max_seqlen_q = seqlen_q;
|
||||
attn_args.hdim_q = hdim_q;
|
||||
attn_args.hdim_v = hdim_v;
|
||||
attn_args.nhead_q = nhead;
|
||||
attn_args.nhead_k = nhead_k;
|
||||
attn_args.scale_s = scale_s;
|
||||
attn_args.stride_q = q_strides[i_perm ? 2 : 1];
|
||||
attn_args.stride_k = k_strides[i_perm ? 2 : 1];
|
||||
attn_args.stride_v = v_strides[i_perm ? 2 : 1];
|
||||
attn_args.stride_o = o_strides[o_perm ? 2 : 1];
|
||||
attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2];
|
||||
attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2];
|
||||
attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2];
|
||||
attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2];
|
||||
attn_args.batch_stride_q = q_strides[0];
|
||||
attn_args.batch_stride_k = k_strides[0];
|
||||
attn_args.batch_stride_v = v_strides[0];
|
||||
attn_args.batch_stride_o = o_strides[0];
|
||||
attn_args.window_size_left = -1;
|
||||
attn_args.window_size_right = -1;
|
||||
attn_args.mask_type = 0;
|
||||
|
||||
avg_ms = sparge_jenga_fwd(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg);
|
||||
}
|
||||
else if(pipeline == "vsa")
|
||||
{
|
||||
fmha_vsa_fwd_traits attn_traits;
|
||||
attn_traits.hdim_q = hdim_q;
|
||||
attn_traits.hdim_v = hdim_v;
|
||||
attn_traits.data_type = std::is_same_v<T, ck_tile::half_t> ? "fp16" : "bf16";
|
||||
attn_traits.is_v_rowmajor = true;
|
||||
attn_traits.mask_type = mask_enum::no_mask;
|
||||
|
||||
fmha_vsa_fwd_args attn_args;
|
||||
attn_args.q_ptr = q_dev.GetDeviceBuffer();
|
||||
attn_args.k_ptr = k_dev.GetDeviceBuffer();
|
||||
attn_args.v_ptr = v_dev.GetDeviceBuffer();
|
||||
attn_args.lut_ptr = lut_dev.GetDeviceBuffer();
|
||||
attn_args.valid_block_num_ptr = valid_bn_dev.GetDeviceBuffer();
|
||||
attn_args.o_ptr = o_dev.GetDeviceBuffer();
|
||||
attn_args.seqlen_q = seqlen_q;
|
||||
attn_args.seqlen_k = seqlen_k;
|
||||
attn_args.batch = batch;
|
||||
attn_args.max_seqlen_q = seqlen_q;
|
||||
attn_args.hdim_q = hdim_q;
|
||||
attn_args.hdim_v = hdim_v;
|
||||
attn_args.nhead_q = nhead;
|
||||
attn_args.nhead_k = nhead_k;
|
||||
attn_args.scale_s = scale_s;
|
||||
attn_args.stride_q = q_strides[i_perm ? 2 : 1];
|
||||
attn_args.stride_k = k_strides[i_perm ? 2 : 1];
|
||||
attn_args.stride_v = v_strides[i_perm ? 2 : 1];
|
||||
attn_args.stride_o = o_strides[o_perm ? 2 : 1];
|
||||
attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2];
|
||||
attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2];
|
||||
attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2];
|
||||
attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2];
|
||||
attn_args.batch_stride_q = q_strides[0];
|
||||
attn_args.batch_stride_k = k_strides[0];
|
||||
attn_args.batch_stride_v = v_strides[0];
|
||||
attn_args.batch_stride_o = o_strides[0];
|
||||
attn_args.window_size_left = -1;
|
||||
attn_args.window_size_right = -1;
|
||||
attn_args.mask_type = 0;
|
||||
|
||||
avg_ms = sparge_vsa_fwd_combined(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unknown pipeline: " << pipeline << " (use jenga or vsa)\n";
|
||||
return false;
|
||||
}
|
||||
|
||||
// ---- TFLOPS calculation (dense FMHA formula, so sparsity gains show as higher TFLOPS) ----
|
||||
std::size_t flop = static_cast<std::size_t>(batch) * nhead *
|
||||
(static_cast<std::size_t>(2) * seqlen_q * seqlen_k * hdim_q +
|
||||
static_cast<std::size_t>(2) * seqlen_q * seqlen_k * hdim_v);
|
||||
float tflops = (avg_ms > 0.f) ? static_cast<float>(flop) / 1.E9f / avg_ms : 0.f;
|
||||
|
||||
if(avg_ms > 0.f)
|
||||
{
|
||||
std::cout << std::fixed << ", " << std::setprecision(3) << avg_ms << " ms, "
|
||||
<< std::setprecision(2) << tflops << " TFlops" << std::flush;
|
||||
}
|
||||
|
||||
// ---- copy results back ----
|
||||
o_dev.FromDevice(output_host.data());
|
||||
block_map_dev.FromDevice(block_map_host.data());
|
||||
|
||||
// ---- count active blocks ----
|
||||
ck_tile::index_t total_blocks = batch * nhead * num_q_blocks * num_k_blocks;
|
||||
ck_tile::index_t active_blocks = 0;
|
||||
for(size_t i = 0; i < block_map_host.mData.size(); ++i)
|
||||
if(block_map_host.mData[i])
|
||||
active_blocks++;
|
||||
float actual_sparsity = 1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
|
||||
std::cout << ", sparsity=" << std::setprecision(2) << actual_sparsity
|
||||
<< "(" << active_blocks << "/" << total_blocks << ")" << std::flush;
|
||||
|
||||
// ---- validation ----
|
||||
bool pass = true;
|
||||
if(do_validation)
|
||||
{
|
||||
auto q_ref = to_bhsd(q_host, i_perm);
|
||||
auto k_ref = to_bhsd(k_host, i_perm);
|
||||
auto v_ref = to_bhsd(v_host, i_perm);
|
||||
|
||||
ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});
|
||||
ck_tile::reference_blocked_attention<T, uint8_t>(
|
||||
q_ref, k_ref, v_ref, block_map_host, output_ref, BLKQ, BLKK, scale_s);
|
||||
|
||||
auto [rtol, atol] = get_error_tolerance<T>();
|
||||
|
||||
float max_diff = 0.0f;
|
||||
size_t num_errors = 0;
|
||||
|
||||
auto output_host_bhsd = to_bhsd(output_host, o_perm);
|
||||
for(size_t i = 0; i < output_host_bhsd.mData.size(); ++i)
|
||||
{
|
||||
float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]);
|
||||
float ref_val = to_float_for_compare(output_ref.mData[i]);
|
||||
float diff = std::abs(gpu_val - ref_val);
|
||||
float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff;
|
||||
|
||||
max_diff = std::max(max_diff, diff);
|
||||
|
||||
if(diff > atol && rel_diff > rtol)
|
||||
num_errors++;
|
||||
}
|
||||
|
||||
pass = (num_errors == 0);
|
||||
std::cout << ", " << (pass ? "PASS" : "FAIL")
|
||||
<< "(err=" << num_errors << "/" << output_host_bhsd.mData.size()
|
||||
<< " maxdiff=" << max_diff << ")";
|
||||
}
|
||||
|
||||
std::cout << std::endl;
|
||||
return pass;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Main
|
||||
// ============================================================================
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
std::cerr << "Failed to parse arguments\n";
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::string prec = arg_parser.get_str("prec");
|
||||
|
||||
bool test_result = false;
|
||||
if(prec == "fp16")
|
||||
{
|
||||
test_result = run_test<ck_tile::half_t>(arg_parser);
|
||||
}
|
||||
else if(prec == "bf16")
|
||||
{
|
||||
test_result = run_test<ck_tile::bf16_t>(arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported precision: " << prec << "\n";
|
||||
return -1;
|
||||
}
|
||||
|
||||
return test_result ? 0 : -1;
|
||||
}
|
||||
@@ -1,422 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Demo: Sparge block-map -> Jenga sparse attention
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <chrono>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
|
||||
#include "jenga_sparge_attention.h"
|
||||
#include "sparge_tool.hpp"
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions
|
||||
// ============================================================================
|
||||
|
||||
template <typename T>
ck_tile::HostTensor<T> make_qkv_tensor(ck_tile::index_t batch,
                                       ck_tile::index_t nhead,
                                       ck_tile::index_t seqlen,
                                       ck_tile::index_t hdim,
                                       bool i_perm)
{
    if(i_perm)
    {
        return ck_tile::HostTensor<T>({batch, nhead, seqlen, hdim});
    }
    return ck_tile::HostTensor<T>({batch, seqlen, nhead, hdim});
}

template <typename T>
ck_tile::HostTensor<T> to_bhsd(const ck_tile::HostTensor<T>& tensor, bool is_bhsd)
{
    auto lens = tensor.get_lengths();
    ck_tile::index_t batch = lens[0];
    ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1];
    ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2];
    ck_tile::index_t hdim = lens[3];

    ck_tile::HostTensor<T> out({batch, nhead, seqlen, hdim});
    for(ck_tile::index_t b = 0; b < batch; ++b)
    {
        for(ck_tile::index_t h = 0; h < nhead; ++h)
        {
            for(ck_tile::index_t s = 0; s < seqlen; ++s)
            {
                for(ck_tile::index_t d = 0; d < hdim; ++d)
                {
                    out(b, h, s, d) = is_bhsd ? tensor(b, h, s, d) : tensor(b, s, h, d);
                }
            }
        }
    }
    return out;
}

template <typename T>
auto get_error_tolerance()
{
    double rtol = 1e-2;
    double atol = 4e-2;
    if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
    {
        atol = 2e-1;
        rtol = 2e-1;
    }
    return ck_tile::make_tuple(rtol, atol);
}

template <typename T>
float to_float_for_compare(T value)
{
    return static_cast<float>(value);
}

template <>
float to_float_for_compare<ck_tile::bf16_t>(ck_tile::bf16_t value)
{
#if CK_TILE_USE_CUSTOM_DATA_TYPE
    return static_cast<float>(value);
#else
    return ck_tile::bf16_to_float_raw(ck_tile::bit_cast<ck_tile::bf16_raw_t>(value));
#endif
}

// ============================================================================
// Command line argument parser
// ============================================================================

auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("v", "1", "0:no validation, 1:cpu validation")
        .insert("b", "1", "batch size")
        .insert("h", "4", "num of head for q")
        .insert("h_k", "-1", "num of head for k/v, -1 means equal to h")
        .insert("s", "4096", "seqlen_q")
        .insert("s_k", "-1", "seqlen_k, -1 means equal to s")
        .insert("d", "128", "head dim for q, k")
        .insert("d_v", "-1", "head dim for v, -1 means equal to d")
        .insert("prec", "fp16", "data type: fp16/bf16")
        .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d")
        .insert("operm", "1", "permute output")
        .insert("seed", "42", "random seed")
        .insert("warmup", "5", "warmup iterations")
        .insert("repeat", "20", "benchmark iterations")
        .insert("kname", "0", "print kernel name")
        // Sparge-specific
        .insert("blkq", "64", "Sparge BLKQ")
        .insert("blkk", "128", "Sparge BLKK")
        .insert("simthreshd1", "0.6", "Sparge sim threshold")
        .insert("cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)")
        .insert("topk", "-1.0", "Sparge topk ratio in (0,1]; if > 0, overrides cdfthreshd");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}

// ============================================================================
// Main Test Function
// ============================================================================

template <typename T>
bool run_test(const ck_tile::ArgParser& arg_parser)
{
    int do_validation = arg_parser.get_int("v");
    ck_tile::index_t batch = arg_parser.get_int("b");
    ck_tile::index_t nhead = arg_parser.get_int("h");
    ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
    ck_tile::index_t seqlen_q = arg_parser.get_int("s");
    ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
    ck_tile::index_t hdim_q = arg_parser.get_int("d");
    ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
    bool i_perm = arg_parser.get_bool("iperm");
    bool o_perm = arg_parser.get_bool("operm");
    uint32_t seed = arg_parser.get_uint32("seed");
    int warmup = arg_parser.get_int("warmup");
    int repeat = arg_parser.get_int("repeat");
    int kname = arg_parser.get_int("kname");

    // Sparge params
    ck_tile::index_t blkq = arg_parser.get_int("blkq");
    ck_tile::index_t blkk = arg_parser.get_int("blkk");
    float simthreshd1 = arg_parser.get_float("simthreshd1");
    float cdfthreshd = arg_parser.get_float("cdfthreshd");
    float topk = arg_parser.get_float("topk");

    if(nhead_k < 0)
        nhead_k = nhead;
    if(seqlen_k < 0)
        seqlen_k = seqlen_q;
    if(hdim_v < 0)
        hdim_v = hdim_q;

    if(blkq != 64 || blkk != 128 || hdim_q != 128 || hdim_v != 128)
    {
        std::cout << "\n>>> TEST SKIPPED <<<" << std::endl;
        std::cout << "Sparge Jenga kernel instances are generated for BLKQ=64, BLKK=128, "
                     "hdim_q=128, hdim_v=128 only."
                  << std::endl;
        std::cout << "TEST SKIPPED" << std::endl;
        return true;
    }

    ck_tile::index_t BLKQ = blkq;
    ck_tile::index_t BLKK = blkk;

    ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ;
    ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK;

    std::cout << "============================================================" << std::endl;
    std::cout << "[Sparge -> Jenga Sparse Attention Demo]" << std::endl;
    std::cout << "============================================================" << std::endl;
    std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k
              << std::endl;
    std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl;
    std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl;
    std::cout << " BLKQ=" << BLKQ << ", BLKK=" << BLKK << std::endl;
    std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks
              << std::endl;
    std::cout << " Sparge(simthreshd1=" << simthreshd1 << ", cdfthreshd=" << cdfthreshd
              << ", topk=" << topk << ")" << std::endl;
    std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl;

    // Create host tensors
    ck_tile::HostTensor<T> q_host = make_qkv_tensor<T>(batch, nhead, seqlen_q, hdim_q, i_perm);
    ck_tile::HostTensor<T> k_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
    ck_tile::HostTensor<T> v_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
    ck_tile::HostTensor<T> output_host =
        o_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen_q, hdim_v})
               : ck_tile::HostTensor<T>({batch, seqlen_q, nhead, hdim_v});
    ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});

    std::cout << "\nInitializing tensors..." << std::endl;
    ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed}(q_host);
    ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 1}(k_host);
    ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 2}(v_host);

    // Build block map using Sparge tool
    std::cout << "Building Sparge block map..." << std::endl;
    sparge::SpargeParams p;
    p.BLKQ = static_cast<int>(BLKQ);
    p.BLKK = static_cast<int>(BLKK);
    p.simthreshd1 = simthreshd1;
    p.cdfthreshd = cdfthreshd;
    p.topk = topk;
    p.i_perm = i_perm;

    ck_tile::HostTensor<uint8_t> block_relation_onehot =
        sparge::build_block_map_meansim(q_host, k_host, p);

    // Print actual sparsity
    std::size_t total_blocks = 0;
    std::size_t active_blocks = 0;
    for(ck_tile::index_t b = 0; b < batch; ++b)
    {
        for(ck_tile::index_t h = 0; h < nhead; ++h)
        {
            for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
            {
                for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
                {
                    total_blocks++;
                    if(block_relation_onehot(b, h, qb, kb) != 0)
                        active_blocks++;
                }
            }
        }
    }
    float actual_sparsity =
        1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
    std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/"
              << total_blocks << " blocks active)" << std::endl;

    std::cout << "\n--- Running Jenga sparse attention kernel ---" << std::endl;

    try
    {
        if(kname)
        {
            jenga_sparge_attention<T>(q_host,
                                      k_host,
                                      v_host,
                                      block_relation_onehot,
                                      output_host,
                                      batch,
                                      nhead,
                                      nhead_k,
                                      seqlen_q,
                                      seqlen_k,
                                      hdim_q,
                                      hdim_v,
                                      i_perm,
                                      o_perm,
                                      seqlen_q,
                                      seqlen_k,
                                      1);
        }

        for(int i = 0; i < warmup; ++i)
        {
            jenga_sparge_attention<T>(q_host,
                                      k_host,
                                      v_host,
                                      block_relation_onehot,
                                      output_host,
                                      batch,
                                      nhead,
                                      nhead_k,
                                      seqlen_q,
                                      seqlen_k,
                                      hdim_q,
                                      hdim_v,
                                      i_perm,
                                      o_perm,
                                      seqlen_q,
                                      seqlen_k,
                                      0);
        }

        [[maybe_unused]] auto sync_status1 = hipDeviceSynchronize();
        auto start = std::chrono::high_resolution_clock::now();

        for(int i = 0; i < repeat; ++i)
        {
            jenga_sparge_attention<T>(q_host,
                                      k_host,
                                      v_host,
                                      block_relation_onehot,
                                      output_host,
                                      batch,
                                      nhead,
                                      nhead_k,
                                      seqlen_q,
                                      seqlen_k,
                                      hdim_q,
                                      hdim_v,
                                      i_perm,
                                      o_perm,
                                      seqlen_q,
                                      seqlen_k,
                                      0);
        }

        [[maybe_unused]] auto sync_status2 = hipDeviceSynchronize();
        auto end = std::chrono::high_resolution_clock::now();
        double avg_time_ms =
            std::chrono::duration<double, std::milli>(end - start).count() / repeat;

        std::cout << "\n>>>> Jenga sparse attention average time: " << avg_time_ms << " ms <<<<"
                  << std::endl;
    }
    catch(const std::exception& e)
    {
        std::cerr << "Error during kernel execution: " << e.what() << std::endl;
        return false;
    }

    bool pass = true;
    if(do_validation)
    {
        std::cout << "\n--- Performing CPU validation ---" << std::endl;
        float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q));

        std::cout << "Computing reference output..." << std::endl;
        auto q_ref = to_bhsd(q_host, i_perm);
        auto k_ref = to_bhsd(k_host, i_perm);
        auto v_ref = to_bhsd(v_host, i_perm);

        ck_tile::reference_blocked_attention<T, uint8_t>(
            q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale);

        auto [rtol, atol] = get_error_tolerance<T>();

        float max_diff = 0.0f;
        float max_rel_diff = 0.0f;
        std::size_t num_errors = 0;

        auto output_host_bhsd = to_bhsd(output_host, o_perm);
        for(std::size_t i = 0; i < output_host_bhsd.mData.size(); ++i)
        {
            float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]);
            float ref_val = to_float_for_compare(output_ref.mData[i]);
            float diff = std::abs(gpu_val - ref_val);
            float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff;

            max_diff = std::max(max_diff, diff);
            max_rel_diff = std::max(max_rel_diff, rel_diff);

            if(diff > atol && rel_diff > rtol)
            {
                num_errors++;
                if(num_errors <= 5)
                {
                    std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val
                              << ", Ref=" << ref_val << ", Diff=" << diff << std::endl;
                }
            }
        }

        std::cout << "\nValidation results:" << std::endl;
        std::cout << " Max absolute difference: " << max_diff << std::endl;
        std::cout << " Max relative difference: " << max_rel_diff << std::endl;
        std::cout << " Number of mismatches: " << num_errors << " / "
                  << output_host_bhsd.mData.size() << std::endl;

        if(num_errors == 0)
        {
            std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl;
        }
        else
        {
            std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl;
            pass = false;
        }
    }

    std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl;
    return pass;
}

// ============================================================================
// Main
// ============================================================================

int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
    {
        std::cerr << "Failed to parse arguments" << std::endl;
        return -1;
    }

    std::string prec = arg_parser.get_str("prec");

    bool test_result = false;
    if(prec == "fp16")
    {
        test_result = run_test<ck_tile::half_t>(arg_parser);
    }
    else if(prec == "bf16")
    {
        test_result = run_test<ck_tile::bf16_t>(arg_parser);
    }
    else
    {
        std::cerr << "Unsupported precision: " << prec << std::endl;
        return -1;
    }

    return test_result ? 0 : -1;
}
@@ -1,597 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Demo: Sparge block-map -> (delta LUT) -> VSA sparse attention (all-in-device)

#include <iostream>
#include <cmath>
#include <string>
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"

#include "sparge_blockmap_trek.hpp"
#include "fmha_fwd_trek.hpp"
#include "sparge_tool.hpp"

// ============================================================================
// Helper Functions
// ============================================================================

template <typename T>
ck_tile::HostTensor<T> make_qkv_tensor(ck_tile::index_t batch,
                                       ck_tile::index_t nhead,
                                       ck_tile::index_t seqlen,
                                       ck_tile::index_t hdim,
                                       bool i_perm)
{
    if(i_perm)
    {
        return ck_tile::HostTensor<T>({batch, nhead, seqlen, hdim});
    }
    return ck_tile::HostTensor<T>({batch, seqlen, nhead, hdim});
}

template <typename T>
ck_tile::HostTensor<T> to_bhsd(const ck_tile::HostTensor<T>& tensor, bool is_bhsd)
{
    auto lens = tensor.get_lengths();
    ck_tile::index_t batch = lens[0];
    ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1];
    ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2];
    ck_tile::index_t hdim = lens[3];

    ck_tile::HostTensor<T> out({batch, nhead, seqlen, hdim});
    for(ck_tile::index_t b = 0; b < batch; ++b)
    {
        for(ck_tile::index_t h = 0; h < nhead; ++h)
        {
            for(ck_tile::index_t s = 0; s < seqlen; ++s)
            {
                for(ck_tile::index_t d = 0; d < hdim; ++d)
                {
                    out(b, h, s, d) = is_bhsd ? tensor(b, h, s, d) : tensor(b, s, h, d);
                }
            }
        }
    }
    return out;
}

template <typename T>
auto get_error_tolerance()
{
    double rtol = 1e-2;
    double atol = 4e-2;
    if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
    {
        atol = 2e-1;
        rtol = 2e-1;
    }
    return ck_tile::make_tuple(rtol, atol);
}

template <typename T>
float to_float_for_compare(T value)
{
    return static_cast<float>(value);
}

template <>
float to_float_for_compare<ck_tile::bf16_t>(ck_tile::bf16_t value)
{
#if CK_TILE_USE_CUSTOM_DATA_TYPE
    return static_cast<float>(value);
#else
    return ck_tile::bf16_to_float_raw(ck_tile::bit_cast<ck_tile::bf16_raw_t>(value));
#endif
}

// ============================================================================
// Command line argument parser
// ============================================================================

auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("v", "1", "0:no validation, 1:cpu validation")
        .insert("b", "1", "batch size")
        .insert("h", "4", "num of head for q")
        .insert("h_k", "-1", "num of head for k/v, -1 means equal to h")
        .insert("s", "4096", "seqlen_q")
        .insert("s_k", "-1", "seqlen_k, -1 means equal to s")
        .insert("d", "128", "head dim for q, k")
        .insert("d_v", "-1", "head dim for v, -1 means equal to d")
        .insert("prec", "fp16", "data type: fp16/bf16")
        .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d")
        .insert("operm", "1", "permute output")
        .insert("seed", "42", "random seed")
        .insert("warmup", "5", "warmup iterations")
        .insert("repeat", "20", "benchmark iterations")
        .insert("kname", "0", "print kernel name")
        // Sparge-specific
        .insert("blkq", "64", "Sparge BLKQ")
        .insert("blkk", "128", "Sparge BLKK")
        .insert("simthreshd1", "0.6", "Sparge sim threshold")
        .insert("cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)")
        .insert("topk", "-1.0", "Sparge topk ratio in (0,1]; if > 0, overrides cdfthreshd");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}

// ============================================================================
// Main Test Function
// ============================================================================

template <typename T>
bool run_test(const ck_tile::ArgParser& arg_parser)
{
    int do_validation = arg_parser.get_int("v");
    ck_tile::index_t batch = arg_parser.get_int("b");
    ck_tile::index_t nhead = arg_parser.get_int("h");
    ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
    ck_tile::index_t seqlen_q = arg_parser.get_int("s");
    ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
    ck_tile::index_t hdim_q = arg_parser.get_int("d");
    ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
    bool i_perm = arg_parser.get_bool("iperm");
    bool o_perm = arg_parser.get_bool("operm");
    uint32_t seed = arg_parser.get_uint32("seed");
    int warmup = arg_parser.get_int("warmup");
    int repeat = arg_parser.get_int("repeat");
    int kname = arg_parser.get_int("kname");

    // Sparge params
    ck_tile::index_t blkq = arg_parser.get_int("blkq");
    ck_tile::index_t blkk = arg_parser.get_int("blkk");
    float simthreshd1 = arg_parser.get_float("simthreshd1");
    float cdfthreshd = arg_parser.get_float("cdfthreshd");
    float topk = arg_parser.get_float("topk");

    if(nhead_k < 0)
        nhead_k = nhead;
    if(seqlen_k < 0)
        seqlen_k = seqlen_q;
    if(hdim_v < 0)
        hdim_v = hdim_q;

    if(blkq != 64 || blkk != 128 || hdim_q != 128 || hdim_v != 128)
    {
        std::cout << "\n>>> TEST SKIPPED <<<" << std::endl;
        std::cout << "Sparge VSA kernel instances are generated for BLKQ=64, BLKK=128, "
                     "hdim_q=128, hdim_v=128 only."
                  << std::endl;
        std::cout << "TEST SKIPPED" << std::endl;
        return true;
    }

    ck_tile::index_t BLKQ = blkq;
    ck_tile::index_t BLKK = blkk;

    ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ;
    ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK;

    std::cout << "============================================================" << std::endl;
    std::cout << "[Sparge -> VSA Sparse Attention Demo]" << std::endl;
    std::cout << "============================================================" << std::endl;
    std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k
              << std::endl;
    std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl;
    std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl;
    std::cout << " BLKQ=" << BLKQ << ", BLKK=" << BLKK << std::endl;
    std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks
              << std::endl;
    std::cout << " Sparge(simthreshd1=" << simthreshd1 << ", cdfthreshd=" << cdfthreshd
              << ", topk=" << topk << ")" << std::endl;
    std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl;

    // Create host tensors and fill with random data
    ck_tile::HostTensor<T> q_host = make_qkv_tensor<T>(batch, nhead, seqlen_q, hdim_q, i_perm);
    ck_tile::HostTensor<T> k_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
    ck_tile::HostTensor<T> v_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
    ck_tile::HostTensor<T> output_host =
        o_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen_q, hdim_v})
               : ck_tile::HostTensor<T>({batch, seqlen_q, nhead, hdim_v});

    std::cout << "\nInitializing tensors..." << std::endl;
    ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed}(q_host);
    ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 1}(k_host);
    ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 2}(v_host);

    // ==================================================================
    // Allocate device memory once, HtoD once
    // ==================================================================
    ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem o_buf(output_host.get_element_space_size_in_bytes());

    q_buf.ToDevice(q_host.data());
    k_buf.ToDevice(k_host.data());
    v_buf.ToDevice(v_host.data());

    const std::size_t bmap_bytes =
        static_cast<std::size_t>(batch) * nhead * num_q_blocks * num_k_blocks * sizeof(uint8_t);
    const std::size_t lut_bytes =
        static_cast<std::size_t>(batch) * nhead * num_q_blocks * num_k_blocks * sizeof(int32_t);
    const std::size_t valid_bytes =
        static_cast<std::size_t>(batch) * nhead * num_q_blocks * sizeof(int32_t);

    ck_tile::DeviceMem bmap_buf(bmap_bytes);
    ck_tile::DeviceMem lut_buf(lut_bytes);
    ck_tile::DeviceMem valid_buf(valid_bytes);
    bmap_buf.SetZero();
    lut_buf.SetZero();
    valid_buf.SetZero();

    // ==================================================================
    // Common stride calculations
    // ==================================================================
    assert(nhead % nhead_k == 0);
    const float scale_s = 1.0f / std::sqrt(static_cast<float>(hdim_q));

    const ck_tile::index_t stride_q = i_perm ? hdim_q : nhead * hdim_q;
    const ck_tile::index_t stride_k = i_perm ? hdim_q : nhead_k * hdim_q;
    const ck_tile::index_t stride_v = i_perm ? hdim_v : nhead_k * hdim_v;
    const ck_tile::index_t stride_o = o_perm ? hdim_v : nhead * hdim_v;
    const ck_tile::index_t nhead_stride_q = i_perm ? seqlen_q * hdim_q : hdim_q;
    const ck_tile::index_t nhead_stride_k = i_perm ? seqlen_k * hdim_q : hdim_q;
    const ck_tile::index_t nhead_stride_v = i_perm ? seqlen_k * hdim_v : hdim_v;
    const ck_tile::index_t nhead_stride_o = o_perm ? seqlen_q * hdim_v : hdim_v;
    const ck_tile::index_t batch_stride_q = nhead * seqlen_q * hdim_q;
    const ck_tile::index_t batch_stride_k = nhead_k * seqlen_k * hdim_q;
    const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * seqlen_k;
    const ck_tile::index_t batch_stride_o = nhead * seqlen_q * hdim_v;

    std::string data_type = "fp16";
    if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
        data_type = "bf16";

    std::string msk_str = "0";
    mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k);

    // ==================================================================
    // GPU: Build block map + VSA LUT (always run, device-only)
    // ==================================================================
    std::cout << "Building Sparge block map + VSA LUT (GPU)..." << std::endl;
    {
        sparge_blockmap_args args;
        args.q_ptr = q_buf.GetDeviceBuffer();
        args.k_ptr = k_buf.GetDeviceBuffer();
        args.batch = batch;
        args.seqlen_q = seqlen_q;
        args.seqlen_k = seqlen_k;
        args.hdim_q = hdim_q;
        args.nhead_q = nhead;
        args.nhead_k = nhead_k;
        args.stride_q = stride_q;
        args.stride_k = stride_k;
        args.nhead_stride_q = nhead_stride_q;
        args.nhead_stride_k = nhead_stride_k;
        args.batch_stride_q = batch_stride_q;
        args.batch_stride_k = batch_stride_k;
        args.simthreshd1 = simthreshd1;
        args.cdfthreshd = cdfthreshd;
        args.topk = topk;
        args.scale = scale_s;
        args.block_map_ptr = bmap_buf.GetDeviceBuffer();
        args.lut_ptr = lut_buf.GetDeviceBuffer();
        args.valid_block_num_ptr = valid_buf.GetDeviceBuffer();

        sparge_blockmap_traits traits;
        traits.data_type = data_type;
        traits.hdim_q = hdim_q;

        sparge_blockmap_fwd(traits, args, ck_tile::stream_config{});
    }

    // ==================================================================
    // VSA sparse attention kernel (always run, LUT stays on device)
    // ==================================================================
    std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl;

    fmha_vsa_fwd_args fmha_args;
    fmha_args.q_ptr = q_buf.GetDeviceBuffer();
    fmha_args.k_ptr = k_buf.GetDeviceBuffer();
    fmha_args.v_ptr = v_buf.GetDeviceBuffer();
    fmha_args.lut_ptr = lut_buf.GetDeviceBuffer();
    fmha_args.valid_block_num_ptr = valid_buf.GetDeviceBuffer();
    fmha_args.o_ptr = o_buf.GetDeviceBuffer();
    fmha_args.batch = batch;
    fmha_args.seqlen_q = seqlen_q;
    fmha_args.seqlen_k = seqlen_k;
    fmha_args.max_seqlen_q = seqlen_q;
    fmha_args.hdim_q = hdim_q;
    fmha_args.hdim_v = hdim_v;
    fmha_args.nhead_q = nhead;
    fmha_args.nhead_k = nhead_k;
    fmha_args.scale_s = scale_s;
    fmha_args.stride_q = stride_q;
    fmha_args.stride_k = stride_k;
    fmha_args.stride_v = stride_v;
    fmha_args.stride_o = stride_o;
    fmha_args.nhead_stride_q = nhead_stride_q;
    fmha_args.nhead_stride_k = nhead_stride_k;
    fmha_args.nhead_stride_v = nhead_stride_v;
    fmha_args.nhead_stride_o = nhead_stride_o;
    fmha_args.batch_stride_q = batch_stride_q;
    fmha_args.batch_stride_k = batch_stride_k;
    fmha_args.batch_stride_v = batch_stride_v;
    fmha_args.batch_stride_o = batch_stride_o;
    fmha_args.window_size_left = mask.left;
    fmha_args.window_size_right = mask.right;
    fmha_args.mask_type = static_cast<ck_tile::index_t>(mask.type);

    fmha_vsa_fwd_traits fmha_traits;
    fmha_traits.hdim_q = hdim_q;
    fmha_traits.hdim_v = hdim_v;
    fmha_traits.data_type = data_type;
    fmha_traits.is_v_rowmajor = true;
    fmha_traits.mask_type = mask.type;

    ck_tile::stream_config stream_config{nullptr,
                                         true,
                                         /* log_level = */ kname ? 1 : 0,
                                         warmup,
                                         repeat,
                                         false};

    float avg_time_ms = sparge_vsa_fwd(fmha_traits, fmha_args, stream_config);

    std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<"
              << std::endl;

    // DtoH: attention output (always needed)
    o_buf.FromDevice(output_host.data(), output_host.get_element_space_size_in_bytes());

    // DtoH: block_map (needed for sparsity stats and validation)
    ck_tile::HostTensor<uint8_t> block_map_gpu({batch, nhead, num_q_blocks, num_k_blocks});
    bmap_buf.FromDevice(block_map_gpu.data(), bmap_bytes);

    // ==================================================================
    // Sparsity statistics (pure CPU, reads block_map HostTensor)
    // ==================================================================
    std::size_t total_blocks = 0;
    std::size_t active_blocks = 0;
    for(ck_tile::index_t b = 0; b < batch; ++b)
    {
        for(ck_tile::index_t h = 0; h < nhead; ++h)
        {
            for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
            {
                for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
                {
                    total_blocks++;
                    if(block_map_gpu(b, h, qb, kb) != 0)
                        active_blocks++;
                }
            }
        }
    }
    float actual_sparsity =
        1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
    std::cout << "\n Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/"
              << total_blocks << " blocks active)" << std::endl;

    // ==================================================================
    // Validation (only when -v=1)
    // ==================================================================
    bool pass = true;
    if(do_validation)
    {
        std::cout << "\n--- Performing CPU validation ---" << std::endl;

        // CPU golden: block map + VSA LUT
        std::cout << "Building Sparge block map (CPU golden)..." << std::endl;
        sparge::SpargeParams p;
        p.BLKQ = static_cast<int>(BLKQ);
        p.BLKK = static_cast<int>(BLKK);
        p.simthreshd1 = simthreshd1;
        p.cdfthreshd = cdfthreshd;
        p.topk = topk;
        p.i_perm = i_perm;

        ck_tile::HostTensor<uint8_t> block_relation_onehot =
            sparge::build_block_map_meansim(q_host, k_host, p);

        std::cout << "Converting block map to VSA LUT (delta, CPU)..." << std::endl;
        auto vsa_lut_cpu = sparge::block_map_to_vsa_lut_delta(block_relation_onehot);

        // DtoH: LUT + valid_block_num (only for validation)
        sparge::VSALut vsa_lut_gpu{
            ck_tile::HostTensor<int32_t>({batch, nhead, num_q_blocks, num_k_blocks}),
            ck_tile::HostTensor<int32_t>({batch, nhead, num_q_blocks}),
        };
        lut_buf.FromDevice(vsa_lut_gpu.lut.data(), lut_bytes);
        valid_buf.FromDevice(vsa_lut_gpu.valid_block_num.data(), valid_bytes);

        // Validate block map
        std::cout << "\n--- Validating GPU block map vs CPU golden ---" << std::endl;
        {
            std::size_t bmap_mismatches = 0;
            for(ck_tile::index_t b = 0; b < batch; ++b)
            {
                for(ck_tile::index_t h = 0; h < nhead; ++h)
                {
                    for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
                    {
                        for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
                        {
                            if(block_map_gpu(b, h, qb, kb) != block_relation_onehot(b, h, qb, kb))
                            {
                                bmap_mismatches++;
                                if(bmap_mismatches <= 10)
                                {
                                    std::cout
                                        << " block_map mismatch at [" << b << "," << h << "," << qb
                                        << "," << kb << "]: GPU="
                                        << static_cast<int>(block_map_gpu(b, h, qb, kb)) << " CPU="
                                        << static_cast<int>(block_relation_onehot(b, h, qb, kb))
                                        << std::endl;
                                }
                            }
                        }
                    }
                }
            }
            std::cout << " Block map mismatches: " << bmap_mismatches << " / "
                      << (batch * nhead * num_q_blocks * num_k_blocks) << std::endl;
            if(bmap_mismatches > 0)
            {
                std::cout << ">>> GPU BLOCK MAP VALIDATION FAILED <<<" << std::endl;
                pass = false;
            }
            else
            {
                std::cout << ">>> GPU BLOCK MAP VALIDATION PASSED <<<" << std::endl;
            }
        }

        // Validate VSA LUT
        std::cout << "\n--- Validating GPU VSA LUT vs CPU golden ---" << std::endl;
        {
            std::size_t lut_mismatches = 0;
            std::size_t valid_mismatches = 0;
            for(ck_tile::index_t b = 0; b < batch; ++b)
            {
                for(ck_tile::index_t h = 0; h < nhead; ++h)
                {
                    for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
                    {
                        if(vsa_lut_gpu.valid_block_num(b, h, qb) !=
                           vsa_lut_cpu.valid_block_num(b, h, qb))
                        {
                            valid_mismatches++;
                            if(valid_mismatches <= 5)
                            {
                                std::cout << " valid_block_num mismatch at [" << b << "," << h
                                          << "," << qb
                                          << "]: GPU=" << vsa_lut_gpu.valid_block_num(b, h, qb)
                                          << " CPU=" << vsa_lut_cpu.valid_block_num(b, h, qb)
                                          << std::endl;
                            }
                        }
                        for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
                        {
                            if(vsa_lut_gpu.lut(b, h, qb, kb) != vsa_lut_cpu.lut(b, h, qb, kb))
                            {
                                lut_mismatches++;
                                if(lut_mismatches <= 10)
                                {
                                    std::cout
                                        << " LUT mismatch at [" << b << "," << h << "," << qb
                                        << "," << kb << "]: GPU=" << vsa_lut_gpu.lut(b, h, qb, kb)
                                        << " CPU=" << vsa_lut_cpu.lut(b, h, qb, kb) << std::endl;
                                }
                            }
                        }
                    }
                }
            }
            std::cout << " LUT mismatches: " << lut_mismatches << std::endl;
            std::cout << " valid_block_num mismatches: " << valid_mismatches << std::endl;
            if(lut_mismatches == 0 && valid_mismatches == 0)
            {
                std::cout << ">>> GPU VSA LUT VALIDATION PASSED <<<" << std::endl;
            }
            else
            {
                std::cout << ">>> GPU VSA LUT VALIDATION FAILED <<<" << std::endl;
                pass = false;
            }
        }

        // Validate attention output
        float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q));

        std::cout << "\nComputing reference attention output..." << std::endl;
        auto q_ref = to_bhsd(q_host, i_perm);
        auto k_ref = to_bhsd(k_host, i_perm);
        auto v_ref = to_bhsd(v_host, i_perm);

        ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});
        ck_tile::reference_blocked_attention<T, uint8_t>(
            q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale);

        auto [rtol, atol] = get_error_tolerance<T>();

        float max_diff = 0.0f;
        float max_rel_diff = 0.0f;
        std::size_t num_errors = 0;

        auto output_host_bhsd = to_bhsd(output_host, o_perm);
        for(std::size_t i = 0; i < output_host_bhsd.mData.size(); ++i)
        {
            float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]);
            float ref_val = to_float_for_compare(output_ref.mData[i]);
            float diff = std::abs(gpu_val - ref_val);
            float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff;

            max_diff = std::max(max_diff, diff);
            max_rel_diff = std::max(max_rel_diff, rel_diff);

            if(diff > atol && rel_diff > rtol)
            {
                num_errors++;
                if(num_errors <= 5)
                {
                    std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val
                              << ", Ref=" << ref_val << ", Diff=" << diff << std::endl;
                }
            }
        }

        std::cout << "\nAttention validation results:" << std::endl;
        std::cout << " Max absolute difference: " << max_diff << std::endl;
        std::cout << " Max relative difference: " << max_rel_diff << std::endl;
        std::cout << " Number of mismatches: " << num_errors << " / "
                  << output_host_bhsd.mData.size() << std::endl;

        if(num_errors == 0)
        {
            std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl;
        }
        else
        {
            std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl;
            pass = false;
        }
    }

    std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl;
    return pass;
}

// ============================================================================
// Main
// ============================================================================

int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
    {
        std::cerr << "Failed to parse arguments" << std::endl;
        return -1;
    }

    std::string prec = arg_parser.get_str("prec");

    bool test_result = false;
    if(prec == "fp16")
    {
        test_result = run_test<ck_tile::half_t>(arg_parser);
    }
    else if(prec == "bf16")
    {
        test_result = run_test<ck_tile::bf16_t>(arg_parser);
    }
    else
    {
        std::cerr << "Unsupported precision: " << prec << std::endl;
        return -1;
    }

    return test_result ? 0 : -1;
}