mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE][FMHA] Add sparse attention VSA (#3341)
* add sparse attention VSA * fix the pre-commit * Add jenga test and pre-commit * add bf16 for vsa * add jenga support bf16 * remove lse arg * split kernel code to block & kernel * fix the pre-commit * fix the pre-commit * fix the copyrights * fix the copyright * fix the copyright & rename block to pipeline * fix the copyright and pipeline * remove lse & dropout & add fmt * fix the jenga&VSA code review * remove the useless code & resolved the comments * remove useless code * remove useless code * Clean up code * Remove more unused code * Re-format .hpp * Refactor codegen scripts --------- Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com> Co-authored-by: asleepzzz <hanwen.chang@amd.com>
This commit is contained in:
156
example/ck_tile/50_sparse_attn/CMakeLists.txt
Normal file
156
example/ck_tile/50_sparse_attn/CMakeLists.txt
Normal file
@@ -0,0 +1,156 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# CMakeLists.txt for sparse attention (Jenga and VSA)

# Use SUPPORTED_GPU_TARGETS directly
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
set(GPU_TARGETS ${SUPPORTED_GPU_TARGETS})

message(STATUS "Sparse Attention: SUPPORTED_GPU_TARGETS=${SUPPORTED_GPU_TARGETS}, INST_TARGETS=${INST_TARGETS}")

# Sparse-attention kernels are only generated for gfx9 / gfx12 class targets.
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx12")
if(NOT INST_TARGETS)
    message(WARNING "Skipping Tile Engine Sparse Attention: No supported GPU targets found")
    return()
endif()

message(STATUS "Building Sparse Attention (Jenga & VSA) for targets: ${INST_TARGETS}")

# Code generation scripts: re-run the configure step when any of them changes.
file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
    ${CMAKE_CURRENT_LIST_DIR}/generate.py
    ${CMAKE_CURRENT_LIST_DIR}/codegen/*.py
)
set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}")

# add_sparse_attn_variant(<variant> <api> <display_name>)
#
# Adds one sparse-attention variant ("jenga" or "vsa"):
#   1. runs generate.py --api <api> --list_blobs at configure time to enumerate
#      the kernel source files that will be generated,
#   2. adds a build-time rule that actually generates those sources,
#   3. builds the OBJECT library  tile_sparse_attn_<variant>_instances,
#   4. builds the example executable  tile_example_<variant>_sparse_attn.
# <display_name> is only used in status/log messages (e.g. "Jenga", "VSA").
function(add_sparse_attn_variant variant api display_name)
    set(code_gen_args
        ${CMAKE_CURRENT_LIST_DIR}/generate.py
        --api ${api}
        --receipt 600
    )

    # Generate list of kernels (at configure time, only the list — no sources yet)
    execute_process(
        COMMAND ${Python3_EXECUTABLE} ${code_gen_args}
                --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/${variant}_blob_list.txt
        RESULT_VARIABLE ret
    )
    # NOTE: a failed launch puts an error *string* into ret, so test the result
    # directly instead of the previous `if(ret AND NOT ret EQUAL 0)` form,
    # which could silently pass for non-numeric results.
    if(NOT ret EQUAL 0)
        message(FATAL_ERROR "Failed to generate ${display_name} kernel list")
    endif()

    file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/${variant}_blob_list.txt gen_blobs)

    # Generate kernel source files at build time
    add_custom_command(
        OUTPUT ${gen_blobs}
        COMMAND ${Python3_EXECUTABLE} ${code_gen_args}
                --output_dir ${CMAKE_CURRENT_BINARY_DIR}
        DEPENDS ${CODE_GEN_SCRIPTS}
        COMMENT "Generate CK Tile ${display_name} Sparse Attention kernels"
        VERBATIM
    )

    message(STATUS "${display_name} kernel files to be generated: ${gen_blobs}")

    # Instance (OBJECT) library holding the generated kernels + dispatch TU
    set(instances "tile_sparse_attn_${variant}_instances")
    add_library(${instances} OBJECT EXCLUDE_FROM_ALL
        ${gen_blobs}
        ${CMAKE_CURRENT_LIST_DIR}/${variant}_sparse_attention.cpp
    )
    target_include_directories(${instances} PRIVATE
        ${CMAKE_CURRENT_LIST_DIR}
        ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn
    )
    set_source_files_properties(${gen_blobs} PROPERTIES LANGUAGE HIP)
    set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/${variant}_sparse_attention.cpp PROPERTIES LANGUAGE HIP)
    set_property(TARGET ${instances} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})

    target_compile_options(${instances} PRIVATE
        -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
        -DCK_TILE_FMHA_FWD_FAST_EXP2
        -Wno-undefined-func-template
        -Wno-float-equal
    )

    # Example executable driving the instance library
    set(example "tile_example_${variant}_sparse_attn")
    message(DEBUG "adding example ${example}")
    add_executable(${example} EXCLUDE_FROM_ALL test_${variant}_sparse_attn.cpp)
    target_link_libraries(${example} PRIVATE ${instances})
    target_include_directories(${example} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
    target_compile_options(${example} PRIVATE
        -Wno-undefined-func-template
        -Wno-float-equal
    )
endfunction()

# ============================================================================
# Jenga Sparse Attention
# ============================================================================
add_sparse_attn_variant(jenga fwd_jenga "Jenga")

# ============================================================================
# VSA Sparse Attention
# ============================================================================
add_sparse_attn_variant(vsa fwd_vsa "VSA")

set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
3
example/ck_tile/50_sparse_attn/codegen/__init__.py
Normal file
3
example/ck_tile/50_sparse_attn/codegen/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
73
example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py
Normal file
73
example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
# Maps a short dtype tag to the C++ sparse-attention forward type-config name.
FWD_DTYPE_MAP = {
    "fp16": "FmhaSparseFwdFp16",
    "bf16": "FmhaSparseFwdBf16",
}

# "simplified" mask implementation: tag -> C++ mask type
_MASK_SIMPLIFIED_MAP = {
    "s_no": "ck_tile::SimplifiedGenericAttentionMask<false>",
    "s_mask": "ck_tile::SimplifiedGenericAttentionMask<true>",
}

# "generic" mask implementation: tag -> C++ mask type
_MASK_MAP = {
    "no": "FmhaMasks::NoMask",
    "causal": "FmhaMasks::CausalMask",
    "generic": "FmhaMasks::GenericMask",
}


def get_mask_map(mask: str) -> dict:
    """Return the tag -> C++ mask-type map for a mask implementation.

    :param mask: mask implementation name, "generic" or "simplified"
    :raises ValueError: for any other implementation name
    """
    if mask == "generic":
        return _MASK_MAP
    if mask == "simplified":
        return _MASK_SIMPLIFIED_MAP
    # Previously `assert False`, which is stripped under `python -O` and
    # would then silently return None; raise so bad input always fails loudly.
    raise ValueError(f"unsupported mask implementation: {mask!r}")


# "generic" mask implementation: tag -> C++ runtime dispatch condition
_MASK_CHECK_MAP = {
    "no": "t.mask_type == mask_enum::no_mask",
    "causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
    "generic": "t.mask_type == mask_enum::window_generic",
}

# "simplified" mask implementation: tag -> C++ runtime dispatch condition
_MASK_SIMPLIFIED_CHECK_MAP = {
    "s_no": "t.mask_type == mask_enum::no_mask",
    "s_mask": "t.mask_type != mask_enum::no_mask",
}


def get_mask_check_map(mask: str) -> dict:
    """Return the tag -> C++ runtime-check map for a mask implementation.

    :param mask: mask implementation name, "generic" or "simplified"
    :raises ValueError: for any other implementation name
    """
    if mask == "generic":
        return _MASK_CHECK_MAP
    if mask == "simplified":
        return _MASK_SIMPLIFIED_CHECK_MAP
    raise ValueError(f"unsupported mask implementation: {mask!r}")


# Only batch mode is supported (kGroupMode == false).
MODE_MAP = {"batch": "false"}

# V-layout tag -> C++ bool for kIsVLayoutRowMajor.
LAYOUT_MAP = {"row": "true", "col": "false"}

# Pipeline tag -> C++ pipeline class.
PIPELINE_MAP = {
    "qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsyncJenga",
    "qr_async_vsa": "ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA",
}

# Pipeline tag -> C++ pipeline enum (both variants share QRKSVS_ASYNC).
PIPELINE_ENUM_MAP = {
    "qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
    "qr_async_vsa": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
}

# Accepts both the "t"/"f" tag convention and real Python bools.
BOOL_MAP = {
    "t": "true",
    "f": "false",
    True: "true",
    False: "false",
}
||||
3
example/ck_tile/50_sparse_attn/codegen/ops/__init__.py
Normal file
3
example/ck_tile/50_sparse_attn/codegen/ops/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
867
example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py
Normal file
867
example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py
Normal file
@@ -0,0 +1,867 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass, field
|
||||
import fnmatch
|
||||
import itertools
|
||||
import os
|
||||
import os.path as path
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from codegen.cpp_symbol_map import (
|
||||
BOOL_MAP,
|
||||
FWD_DTYPE_MAP,
|
||||
LAYOUT_MAP,
|
||||
MODE_MAP,
|
||||
PIPELINE_ENUM_MAP,
|
||||
PIPELINE_MAP,
|
||||
get_mask_check_map,
|
||||
get_mask_map,
|
||||
)
|
||||
|
||||
GEN_DIR = ""
|
||||
|
||||
|
||||
def update_file(file_path, content):
    """Write ``content`` to ``file_path`` only if it differs from what is
    already on disk.

    Leaving an up-to-date file untouched keeps its mtime stable, so the build
    system does not trigger spurious rebuilds of generated sources.
    """
    if path.exists(file_path):
        with open(file_path, "r") as existing:
            if existing.read() == content:
                # Nothing changed — skip the write entirely.
                return
    with open(file_path, "w") as out:
        out.write(content)
||||
|
||||
|
||||
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256}
|
||||
|
||||
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n
|
||||
// auto generated by generate.py
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "fmha_fwd_trek.hpp"
|
||||
#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp"
|
||||
#include "kernel/fmha_fwd_jenga_kernel.hpp"
|
||||
|
||||
"""
|
||||
|
||||
# NOTE: Jenga sparse attention kernel has the following restrictions enforced by static_assert:
|
||||
# - Group mode: NOT supported (batch mode only)
|
||||
# - Bias: NOT supported (NO_BIAS only)
|
||||
# - LSE output: NOT supported (false only)
|
||||
# - Dropout: NOT supported (false only)
|
||||
# - Logits soft-cap: NOT supported (false only)
|
||||
# - FP8 static quantization: NOT supported (NO_SCALE only)
|
||||
# The template below hardcodes these unsupported features accordingly.
|
||||
|
||||
FMHA_FWD_KERNEL_BODY = """
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
|
||||
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
|
||||
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
|
||||
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
|
||||
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
|
||||
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
|
||||
{F_vlayout}>;
|
||||
|
||||
// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum,
|
||||
// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
false, // has_logits_soft_cap - NOT supported
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported
|
||||
false, // store_lse - NOT supported
|
||||
false, // has_dropout - NOT supported
|
||||
false, // has_randval - NOT supported
|
||||
ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported
|
||||
{F_occupancy},
|
||||
false>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported)
|
||||
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
fmha_shape_{F_idx},
|
||||
{F_mode},
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
{F_trload},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
using fmha_pipeline_{F_idx} = {F_pipeline}<
|
||||
fmha_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_epilogue_{F_idx} =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaSparseFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType,
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdJengaKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_jenga_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_jenga_fwd_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << "{F_kernel_name}" << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_FILENAME = "fmha_jenga_fwd_api.cpp"
|
||||
FMHA_FWD_API = """
|
||||
#include <cstdio>
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace {{
|
||||
bool get_num_cus(unsigned& num_cus) {{
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device");
|
||||
return false;
|
||||
}}
|
||||
|
||||
hipDeviceProp_t props{{}};
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device properties");
|
||||
return false;
|
||||
}}
|
||||
|
||||
num_cus = props.multiProcessorCount;
|
||||
return true;
|
||||
}}
|
||||
|
||||
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
|
||||
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
|
||||
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
|
||||
|
||||
return batch * nheads * num_m_blocks * num_n_blocks;
|
||||
}}
|
||||
}} // namespace
|
||||
|
||||
float fmha_jenga_fwd(fmha_jenga_fwd_traits t, fmha_jenga_fwd_args a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
|
||||
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
|
||||
|
||||
unsigned num_cus;
|
||||
if (!get_num_cus(num_cus)) {{
|
||||
return r;
|
||||
}}
|
||||
|
||||
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
|
||||
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
|
||||
}};
|
||||
|
||||
const bool has_load_tr = ck_tile::is_load_tr_supported();
|
||||
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{
|
||||
{F_dtype_case}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) &&
|
||||
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
|
||||
return fmha_jenga_fwd_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
class CppConstraint:
    """A C++ boolean expression used as an extra guard in generated dispatch code.

    ``None`` means "no constraint"; it stringifies to ``"true"`` so it can be
    pasted into a C++ condition unchanged.
    """

    # Fixed annotation: the default is None, so this is Optional[str],
    # not plain str as previously declared.
    bool_expr: Optional[str] = None

    def __str__(self) -> str:
        if self.bool_expr is None:
            return "true"
        else:
            return f"{self.bool_expr}"

    def __and__(self, other) -> "CppConstraint":
        # Conjunction of two constraints; both sides are parenthesized so the
        # combined expression keeps C++ operator precedence intact.
        return CppConstraint(f"({str(self)}) && ({str(other)})")
||||
|
||||
|
||||
@dataclass
class FmhaFwdApiTrait:
    """One dispatchable kernel configuration, as seen by the generated C++ API.

    Field order mirrors the C++ ``fmha_jenga_fwd_traits_<>`` template so the
    runtime dispatch conditions can be generated mechanically.
    """

    pipeline_tag: str
    # sync with fmha_fwd_traits<>, to generate fallback calls
    hdim: str
    dtype: str  # data type
    mode: str  # value from MODE_MAP
    bm0: int  # tile size along q seqlen (block size)
    bn0: int  # tile size along qk seqlen
    bk0: int  # tile size along qk gemm unroll
    bn1: int  # tile size along v head_dim
    bk1: int  # tile size along kv gemm unroll
    bk0max: int
    vlayout: str
    logits: str
    mask: str
    spad: str
    skpad: str
    dpad: str
    dvpad: str
    tr_load: str
    constraint: CppConstraint

    @property
    def name(self) -> str:
        # Unique-ish identifier for this trait. Field order follows the
        # trait template: bm0-bn0-bk0-bn1-bk1-bk0max.
        # FIX: the 4th tile field previously repeated `bn0` instead of `bn1`,
        # so traits differing only in bn1 produced colliding names.
        return (
            f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0max}-"
            + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
        )

    @property
    def scheck(self) -> str:
        # C++ condition for q-seqlen padding. NOTE(review): every branch here
        # currently returns "true" — spad appears to be unconditionally
        # accepted; the branch structure is kept for parity with other traits.
        if self.mode == "group":
            return "true/*group mode spad always true*/"  # group mode only generate spad/skpad == true
        if self.spad == "t":
            return "true"  # always support
        return "true"

    @property
    def seqtune(self) -> str:
        # C++ condition selecting the tile by q seqlen; the largest tile
        # (bm0 == 128) is the unconditional fallback.
        if self.bm0 == 128:
            return "true/*fall back to largest tile*/"
        else:
            return f"a.seqlen_q <= {self.bm0}"

    @property
    def skcheck(self) -> str:
        # C++ condition for k-seqlen padding: padded kernels take the
        # remainder cases, unpadded kernels require an exact multiple of bn0.
        if self.mode == "group":
            return "true/*group mode skpad always true*/"  # group mode only generate spad/skpad == true
        if self.skpad == "t":
            return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0"
        return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0"

    @property
    def dcheck(self) -> str:
        # Vector width in elements for a 128-bit load of this dtype.
        vec = int((32 * 4) / DTYPE_BITS[self.dtype])
        if self.dpad == "t":
            return f"a.hdim_q % {vec} == 0"
        # Only dpad == "t" kernels are generated; anything else is a
        # generator bug.
        assert False

    @property
    def dvcheck(self) -> str:
        # Same 128-bit vector-width rule, applied to the v head dimension.
        vec = int((32 * 4) / DTYPE_BITS[self.dtype])
        if self.dvpad == "t":
            return f"a.hdim_v % {vec} == 0"
        assert False
||||
|
||||
|
||||
@dataclass
class FmhaFwdPipeline:
    """Pipeline configuration for one generated kernel.

    The ``F_*`` string fields use the "t"/"f" tag convention; ``name`` encodes
    them into the kernel's file/symbol suffix. Several features (bias, skip,
    static quant) are unsupported by the sparse-attention kernels, so their
    name components are hardcoded to the "disabled" spelling.
    """

    tag: str

    F_vlayout: str  # row/col
    F_spad: str  # true/false
    F_skpad: str  #
    F_dpad: str  #
    F_dvpad: str  #
    F_logits: str  # t/f
    F_mask: str  # value from MASK_MAP
    F_trload: str  # true/false
    F_constraint: CppConstraint = field(default_factory=CppConstraint)

    @property
    def name(self) -> str:
        # Encodes this pipeline configuration into a stable name suffix.
        # The exact spelling matters: it is part of generated file and
        # symbol names, so do not reorder or reformat the components.
        def pad_name() -> str:
            # Concatenate one tag per enabled padding dim, prefixed with "p"
            # (e.g. spad+dpad -> "psd"); empty string when nothing is padded.
            n = ""
            if self.F_spad == "t":
                n += "s"
            if self.F_skpad == "t":
                n += "sk"
            if self.F_dpad == "t":
                n += "d"
            if self.F_dvpad == "t":
                n += "dv"
            if n != "":
                n = "p" + n
            return n

        pn = pad_name()
        n = f"{self.tag}_v{self.F_vlayout[0]}"
        if pn != "":
            n += f"_{pn}"
        else:
            n += "_npad"

        if self.F_logits == "t":
            n += "_logits"
        else:
            n += "_nlogits"

        # bias is not supported by these kernels -> always "nbias"
        n += "_nbias"

        # Simplified masks are tagged "s_*"; generic masks use their first
        # letter (e.g. "_mc" for causal).
        if self.F_mask[0:2] == "s_":
            if self.F_mask == "s_mask":
                n += "_mask"
            else:
                n += "_nmask"
        else:
            if self.F_mask != "no":
                n += f"_m{self.F_mask[0]}"
            else:
                n += "_nmask"

        # min-seqlen-q skipping: not supported -> always "nskip"
        n += "_nskip"

        # FP8 static quantization: not supported -> always "nsquant"
        n += "_nsquant"

        if self.F_trload == "t":
            n += "_trload"
        else:
            n += "_ntrload"

        return n
||||
|
||||
|
||||
class FmhaFwdApiPool:
    """Collects registered kernel traits and renders the C++ dispatch API.

    Traits are bucketed by dtype, then by (hdim, hdim_v); ``api`` emits the
    nested if/else-if dispatch chain for ``fmha_jenga_fwd()``.
    """

    def __init__(self, mask_impl):
        # pool: dtype -> (hdim, hdim_v) -> [FmhaFwdApiTrait, ...]
        self.pool = dict()
        self.mask_impl = mask_impl

    def register_traits(self, trait: FmhaFwdApiTrait) -> None:
        # TODO: do we need to check duplication?
        if trait.dtype not in self.pool.keys():
            self.pool[trait.dtype] = dict()
        # Bucket key pairs head-dim q with head-dim v (bn1 tiles hdim_v).
        hdim = trait.hdim, trait.bn1
        if hdim not in self.pool[trait.dtype].keys():
            self.pool[trait.dtype][hdim] = list()

        # Copy so later mutation of the caller's trait can't corrupt the pool.
        self.pool[trait.dtype][hdim].append(copy.copy(trait))

    @property
    def api(self) -> str:
        """Render the full C++ source of the dispatch function.

        Nesting order of the generated conditions: tr_load capability ->
        dtype -> (hdim, hdim_v) bound -> per-trait runtime checks. Iteration
        order of the dicts determines the else-if chain order, so it must
        stay deterministic (Python dicts preserve insertion order).
        """
        # "t" kernels require hardware transpose-load support; "f" kernels
        # are unconditional.
        tr_load_cond_map = {"t": "has_load_tr", "f": "true"}

        per_tr_load = str()
        for tr_load in ["t", "f"]:
            per_dtypes = str()
            for i, dtype in enumerate(self.pool.keys()):
                per_hdim_case = str()
                for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
                    # Only traits matching the current tr_load bucket.
                    traits = [
                        t
                        for t in self.pool[dtype][(hdim, hdim_v)]
                        if tr_load == t.tr_load
                    ]
                    inners = str()
                    for k, trait in enumerate(traits):
                        if_k = "if" if k == 0 else "else if"
                        inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
                            F_if=if_k,
                            F_vlayout=LAYOUT_MAP[trait.vlayout],
                            F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
                            # F_logits removed - hardcoded to false (NOT supported)
                            F_mask=get_mask_map(self.mask_impl)[trait.mask],
                            F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
                            F_trload=BOOL_MAP[trait.tr_load],
                            F_scheck=trait.scheck,
                            F_seqtune=trait.seqtune,
                            F_skcheck=trait.skcheck,
                            F_dcheck=trait.dcheck,
                            F_dvcheck=trait.dvcheck,
                            F_constraint=trait.constraint,
                            F_spad=BOOL_MAP[trait.spad],
                            F_skpad=BOOL_MAP[trait.skpad],
                            F_dpad=BOOL_MAP[trait.dpad],
                            F_dvpad=BOOL_MAP[trait.dvpad],
                            F_bm0=trait.bm0,
                            F_bn0=trait.bn0,
                            F_bk0=trait.bk0,
                            F_bn1=trait.bn1,
                            F_bk1=trait.bk1,
                            F_bk0max=trait.bk0max,
                            F_hdim=hdim,
                            F_dtype=FWD_DTYPE_MAP[dtype],
                        )
                    if_j = "if" if j == 0 else "else if"
                    per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
                        F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners
                    )
                if_i = "if" if i == 0 else "else if"
                per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
                    F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
                )
            # NOTE(review): each tr_load block is emitted with F_if="if"
            # (not "else if"); presumably intentional so the "f" chain is
            # still tried when the "t" chain matched nothing — confirm.
            per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(
                F_if="if",
                F_trload_cond=tr_load_cond_map[tr_load],
                F_dtype_case=per_dtypes,
            )
        if not per_tr_load:
            # empty string we add some ignore to suppress warning in api
            per_tr_load += " (void)t ; (void)s ; (void)a;"
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load)
||||
|
||||
|
||||
@dataclass
class FmhaFwdTileSize:
    """Tile/warp partitioning of one kernel; ``name`` encodes it for symbols.

    gemm0 is Q*K^T (softmax input), gemm1 is P*V (output), matching the
    two-gemm structure of the FMHA forward pass.
    """

    F_bm0: int  # tile size along q seqlen (block size)
    F_bn0: int  # tile size along k seqlen
    F_bk0: int  # tile size along qk gemm unroll
    F_bn1: int  # tile size along v head_dim
    F_bk1: int  # tile size along kv gemm unroll
    F_bk0max: int  # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
    F_rm0: int  # number of warps for gemm0 along q seqlen
    F_rn0: int  # number of warps for gemm0 along k seqlen
    F_rk0: int  # number of warps for gemm0 along head dim q (not used)
    F_rm1: int  # number of warps for gemm1 along q seqlen
    F_rn1: int  # number of warps for gemm1 along head dim v
    F_rk1: int  # number of warps for gemm1 along k seqlen (not used)
    F_wm0: int  # gemm0 warp size along m
    F_wn0: int  # gemm0 warp size along n
    F_wk0: int  # gemm0 warp size along k
    F_wm1: int  # gemm1 warp size along m
    F_wn1: int  # gemm1 warp size along n
    F_wk1: int  # gemm1 warp size along k
    F_occupancy: int  # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
    # Optional extra C++ dispatch guard; defaults to the always-true constraint.
    F_constraint: CppConstraint = field(default_factory=CppConstraint)

    @property
    def name(self) -> str:
        # Stable symbol-name fragment: b<block tiles>_r<warp counts>_w<warp
        # sizes>, with "_o<N>" appended only when occupancy is forced.
        return (
            f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
            + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
            + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
            + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
        )
||||
|
||||
|
||||
@dataclass
class FmhaFwdKernel:
    """One fully-specified kernel instance: dtype + tile + pipeline.

    ``template`` renders the C++ source for this instance; ``api_trait``
    produces the matching dispatch-table entry.
    """

    F_idx: int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim: int  # hdim
    F_dtype: str  # data type
    F_mode: str  # value from MODE_MAP
    F_tile: FmhaFwdTileSize
    F_pipeline: FmhaFwdPipeline
    mask_impl: str

    @property
    def template(self) -> str:
        """Render this kernel's .cpp source from FMHA_FWD_KERNEL_BODY."""
        # kernel_body removed - unused
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
            F_idx=self.F_idx,
            F_hdim=self.F_hdim,
            F_dtype=FWD_DTYPE_MAP[self.F_dtype],
            # block tile sizes
            F_bm0=self.F_tile.F_bm0,
            F_bn0=self.F_tile.F_bn0,
            F_bk0=self.F_tile.F_bk0,
            F_bn1=self.F_tile.F_bn1,
            F_bk1=self.F_tile.F_bk1,
            F_bk0max=self.F_tile.F_bk0max,
            # warp counts (gemm0 then gemm1)
            F_rm0=self.F_tile.F_rm0,
            F_rn0=self.F_tile.F_rn0,
            F_rk0=self.F_tile.F_rk0,
            F_rm1=self.F_tile.F_rm1,
            F_rn1=self.F_tile.F_rn1,
            F_rk1=self.F_tile.F_rk1,
            # warp tile sizes (gemm0 then gemm1)
            F_wm0=self.F_tile.F_wm0,
            F_wn0=self.F_tile.F_wn0,
            F_wk0=self.F_tile.F_wk0,
            F_wm1=self.F_tile.F_wm1,
            F_wn1=self.F_tile.F_wn1,
            F_wk1=self.F_tile.F_wk1,
            # pipeline traits -> C++ spellings
            F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
            F_spad=BOOL_MAP[self.F_pipeline.F_spad],
            F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
            F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
            F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
            # F_logits removed - hardcoded to false in template (NOT supported)
            F_occupancy=self.F_tile.F_occupancy,
            F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
            F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
            F_mode=MODE_MAP[self.F_mode],
            F_pipeline=PIPELINE_MAP[self.F_pipeline.tag],
            F_trload=BOOL_MAP[self.F_pipeline.F_trload],
            F_kernel_name=self.name,
        )

    @property
    def name(self) -> str:
        """Symbol/file base name combining hdim, dtype, mode, tile and pipeline."""
        # TODO: we don't encode idx here
        return (
            f"fmha_jenga_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
            + self.F_tile.name
            + "_"
            + self.F_pipeline.name
        )

    @property
    def filename(self) -> str:
        """Name of the generated .cpp file for this kernel."""
        return self.name + ".cpp"

    def api_trait(self) -> FmhaFwdApiTrait:
        """Build the dispatch-table trait matching this kernel instance."""
        return FmhaFwdApiTrait(
            pipeline_tag=self.F_pipeline.tag,
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
            mode=self.F_mode,
            bm0=self.F_tile.F_bm0,
            bn0=self.F_tile.F_bn0,
            bk0=self.F_tile.F_bk0,
            bn1=self.F_tile.F_bn1,
            bk1=self.F_tile.F_bk1,
            bk0max=self.F_tile.F_bk0max,
            vlayout=self.F_pipeline.F_vlayout,
            mask=self.F_pipeline.F_mask,
            logits=self.F_pipeline.F_logits,
            spad=self.F_pipeline.F_spad,
            skpad=self.F_pipeline.F_skpad,
            dpad=self.F_pipeline.F_dpad,
            dvpad=self.F_pipeline.F_dvpad,
            tr_load=self.F_pipeline.F_trload,
            # tile- and pipeline-level constraints are AND-ed together
            constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
        )
||||
|
||||
|
||||
class KernelComponentFactory:
    """Provides supported tile shapes and pipeline variants per (hdim_q, hdim_v)."""

    # TODO: design a more practical way to do it
    # this is current supported tile size per hdim
    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        # Maps (hdim_q, hdim_v) -> list of candidate tile configs; smaller-bm0
        # entries come first so dispatch can prefer them for short sequences.
        # Returns None for dtypes without a tile table (only fp16/bf16 here).
        if dtype == "fp16" or dtype == "bf16":
            return {
                # (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
                # (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
                #             FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                #             FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
                # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
                (128, 128): [
                    # args: bm0, bn0, bk0, bn1, bk1, bk0max,
                    #       rm0, rn0, rk0, rm1, rn1, rk1,
                    #       wm0, wn0, wk0, wm1, wn1, wk1, occupancy
                    FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
                    FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                    FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                    FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                ],
                # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
                # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
                # (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
                # (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
            }
        else:
            return None

    # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
    # support this in future
    @staticmethod
    def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
        # this function will populate a list possible pipelines
        # TODO: the order of List matters! the later in this list will be also be checked later
        # NOTE: logits soft-cap is NOT supported by Jenga sparse attention (enforced by static_assert)
        # NOTE(review): `receipt` is currently unused here — filtering happens in get_fwd_blobs.
        pipelines = []
        if dtype in ["fp16", "bf16"]:
            for logits, mask in itertools.product(
                ["f"],  # logits soft-cap NOT supported, always false
                get_mask_map(mask_impl).keys(),
            ):
                if hdim == 256 and hdim_v == 256:
                    # jenga fmha only supports dim <= 192 for now.
                    continue
                # Two variants per mask: without and with seqlen_k padding (skpad).
                # FmhaFwdPipeline args: tag, vlayout, spad, skpad, dpad, dvpad, logits, mask, trload
                pipelines.append(
                    FmhaFwdPipeline(
                        "qr_async", "row", "t", "f", "t", "t", logits, mask, "f"
                    )
                )
                pipelines.append(
                    FmhaFwdPipeline(
                        "qr_async", "row", "t", "t", "t", "t", logits, mask, "f"
                    )
                )
        else:
            assert False
        return pipelines
|
||||
|
||||
|
||||
class CustomFactory(KernelComponentFactory):
    """Factory variant that prepends a smaller 64x128 tile guarded by a CU-utilization check."""

    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        # Start from the default table, then give priority to a 64-wide tile
        # whenever launching the 128-wide tile would underutilize the CUs.
        result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
        if dtype == "fp16" or dtype == "bf16":
            if (128, 128) in result.keys():
                result[(128, 128)].insert(
                    0,
                    # args: bm0, bn0, bk0, bn1, bk1, bk0max,
                    #       rm0, rn0, rk0, rm1, rn1, rk1,
                    #       wm0, wn0, wk0, wm1, wn1, wk1, occupancy, constraint
                    FmhaFwdTileSize(
                        64, 128, 64, 128, 64, 128,
                        4, 1, 1, 4, 1, 1,
                        16, 16, 16, 16, 16, 16,
                        -1,
                        # Evaluated in the generated C++ dispatch code, where
                        # get_num_blocks/num_cus/min_cu_util_rate are in scope.
                        CppConstraint(
                            "get_num_blocks(128) < num_cus * min_cu_util_rate"
                        ),
                    ),
                )
        return result
|
||||
|
||||
|
||||
def get_fwd_blobs(
    kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
    """Enumerate all Jenga forward kernels to generate.

    Returns the API dispatch pool plus the kernel instance list, after applying
    the name filter, the hdim filter and the per-receipt integration filters.
    """
    gen = list()
    api_pool = FmhaFwdApiPool(mask_impl)

    # Opt-in factory override via environment variable (configure-time only).
    factory = (
        CustomFactory
        if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1"
        else KernelComponentFactory
    )

    # Only generate fp16/bf16 kernels for now.
    # NOTE: Jenga sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert)
    for dtype in ["fp16", "bf16"]:
        d = factory.get_hdim_tile_size_dict(dtype)
        if d is None:
            continue
        for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]):
            for tile, pipeline in itertools.product(
                tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
            ):
                # Only the 128x128 tile with the qr_async pipeline is emitted.
                if tile.F_bm0 != 128 or tile.F_bn0 != 128:
                    continue
                if pipeline.tag != "qr_async":
                    continue
                k = FmhaFwdKernel(
                    F_idx=2,
                    F_hdim=hdim,
                    F_dtype=dtype,
                    F_mode=mode,
                    F_tile=tile,
                    F_pipeline=pipeline,
                    mask_impl=mask_impl,
                )
                # NOTE(review): assumes kernel_filter is "" (not None) when
                # unused — fnmatch would raise on None; confirm callers.
                if kernel_filter != "":
                    if not fnmatch.fnmatch(k.name, kernel_filter):
                        continue
                if optdim_list != [-1]:
                    if hdim not in optdim_list:
                        continue
                # 2 - Flash attention integration
                if receipt in (2, 3):
                    cond = dtype in ["fp16", "bf16"]
                    cond &= pipeline.F_vlayout == "row"
                    if not cond:
                        continue
                # PyTorch integration
                elif receipt == 4:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= pipeline.F_vlayout == "row"
                    cond &= mode == "batch"
                    cond &= pipeline.F_logits == "f"
                    if not cond:
                        continue
                # Aiter(mha_fwd) integration
                elif receipt == 100:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= mode == "batch"
                    cond &= pipeline.F_vlayout == "row"
                    if not cond:
                        continue
                # Aiter(mha_varlen_fwd) integration
                # NOTE(review): mode is always "batch" above, so receipt == 200
                # currently filters out every kernel — confirm this is intended.
                elif receipt == 200:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= mode == "group"
                    cond &= pipeline.F_vlayout == "row"
                    if not cond:
                        continue
                # aiter::mha_fwd C++ api integration
                elif receipt == 600:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= pipeline.F_vlayout == "row"
                    if not cond:
                        continue

                api_pool.register_traits(k.api_trait())
                gen.append(k)

    return (api_pool, gen)
|
||||
|
||||
|
||||
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
    """Write the generated .cpp for one kernel; update_file() skips identical content."""
    target = autogen_dir / kernel.filename
    update_file(target, kernel.template)
|
||||
|
||||
|
||||
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
    """Write the dispatch API translation unit rendered from the trait pool."""
    target = autogen_dir / FMHA_FWD_API_FILENAME
    update_file(target, api_pool.api)
|
||||
|
||||
|
||||
def write_blobs(
    output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
) -> None:
    """Generate every kernel source plus the dispatch API file into `output_dir`."""
    api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
    for k in kernels:
        write_single_fwd_kernel(k, output_dir)
    write_fwd_api(api_pool, output_dir)
|
||||
|
||||
|
||||
def list_blobs(
    file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
) -> None:
    """Append the paths of all files that would be generated to `file_path`.

    Used by the build system to know the generated-source list up front.
    """
    base = file_path.parent / GEN_DIR
    with file_path.open("a") as out:
        _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
        for k in kernels:
            out.write(f"{base / k.filename}\n")
        out.write(f"{base / FMHA_FWD_API_FILENAME}\n")
|
||||
--- new file: example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py (867 lines) ---
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass, field
|
||||
import fnmatch
|
||||
import itertools
|
||||
import os
|
||||
import os.path as path
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from codegen.cpp_symbol_map import (
|
||||
BOOL_MAP,
|
||||
FWD_DTYPE_MAP,
|
||||
LAYOUT_MAP,
|
||||
MODE_MAP,
|
||||
PIPELINE_ENUM_MAP,
|
||||
PIPELINE_MAP,
|
||||
get_mask_check_map,
|
||||
get_mask_map,
|
||||
)
|
||||
|
||||
# Sub-directory (relative to the listing file's parent) where sources land.
GEN_DIR = ""


def update_file(file_path, content):
    """Write `content` to `file_path` only if it differs from what is on disk.

    Leaving an up-to-date file untouched keeps its timestamp stable and avoids
    triggering unnecessary rebuilds.
    """
    if path.exists(file_path):
        with open(file_path, "r") as existing:
            if existing.read() == content:
                return
    with open(file_path, "w") as out:
        out.write(content)
|
||||
|
||||
|
||||
# Bit width of each supported element type; used to derive the 128-bit
# vector access width in the dcheck/dvcheck dispatch guards.
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16}

# NOTE(review): presumably caps the k0 sub-tile per hdim; not referenced in
# this file's visible code — confirm before relying on it.
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256}
|
||||
|
||||
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n
|
||||
// auto generated by generate.py
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "fmha_fwd_trek.hpp"
|
||||
#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp"
|
||||
#include "kernel/fmha_fwd_vsa_kernel.hpp"
|
||||
|
||||
"""
|
||||
|
||||
# NOTE: VSA sparse attention kernel has the following restrictions enforced by static_assert:
# - Group mode: NOT supported (batch mode only)
# - Bias: NOT supported (NO_BIAS only)
# - LSE output: NOT supported (false only)
# - Dropout: NOT supported (false only)
# - Logits soft-cap: NOT supported (false only)
# - FP8 static quantization: NOT supported (NO_SCALE only)
# The template below hardcodes these unsupported features accordingly.

# C++ template for one kernel instantiation; placeholders ({F_*}) are filled
# by FmhaFwdKernel.template. Doubled braces render as literal C++ braces.
FMHA_FWD_KERNEL_BODY = """
using fmha_dtype_{F_idx} = {F_dtype};

using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;

using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
    ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
    ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
    ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
    ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
    {F_vlayout}>;

// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum,
// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
    {F_skpad},
    {F_dpad},
    {F_dvpad},
    false, // has_logits_soft_cap - NOT supported
    ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported
    false, // store_lse - NOT supported
    false, // has_dropout - NOT supported
    false, // has_randval - NOT supported
    ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported
    {F_occupancy},
    false>;

using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported)

using fmha_mask_{F_idx} = {F_mask};

using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
    typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
    typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
    typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
    typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
    typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
    typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
    typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
    typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
    typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
    typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
    typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
    fmha_shape_{F_idx},
    {F_mode},
    fmha_variant_{F_idx},
    fmha_mask_{F_idx},
    {F_trload},
    fmha_trait_{F_idx}>;

using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA<
    fmha_pipeline_problem_{F_idx}>;

using fmha_epilogue_{F_idx} =
    ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaSparseFwdTypeConfig<{F_dtype}>::OaccDataType,
        typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType,
        {F_spad}, {F_dvpad}>>;

using fmha_kernel_{F_idx} =
    ck_tile::FmhaFwdVSAKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;

using trait_{F_idx} = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
    {F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;

#include <iostream>

template<>
float fmha_vsa_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_vsa_fwd_args a)
{{
    using k_ = fmha_kernel_{F_idx};
    if(s.log_level_ > 0)
        std::cout << ", " << "{F_kernel_name}" << std::flush;
    auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
    const dim3 blocks = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""
|
||||
|
||||
FMHA_FWD_API_FILENAME = "fmha_vsa_fwd_api.cpp"

# Top-level C++ dispatch function; {F_dispatch} receives the nested if/else
# chain rendered by FmhaFwdApiPool.api. Doubled braces are literal C++ braces.
FMHA_FWD_API = """
#include <cstdio>

#include <hip/hip_runtime.h>

namespace {{
bool get_num_cus(unsigned& num_cus) {{
    int device;
    auto status = hipGetDevice(&device);
    if(status != hipSuccess) {{
        fprintf(stderr, "failed to get device");
        return false;
    }}

    hipDeviceProp_t props{{}};
    status = hipGetDeviceProperties(&props, device);
    if(status != hipSuccess) {{
        fprintf(stderr, "failed to get device properties");
        return false;
    }}

    num_cus = props.multiProcessorCount;
    return true;
}}

unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
    const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
    const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1

    return batch * nheads * num_m_blocks * num_n_blocks;
}}
}} // namespace

float fmha_vsa_fwd(fmha_vsa_fwd_traits t, fmha_vsa_fwd_args a, const ck_tile::stream_config& s){{
    float r = -1;

    [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate

    unsigned num_cus;
    if (!get_num_cus(num_cus)) {{
        return r;
    }}

    [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
        return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
    }};

    const bool has_load_tr = ck_tile::is_load_tr_supported();

{F_dispatch}
    return r;
}}
"""

# Dispatch nesting level 1: branch on transposed-load hardware support.
FMHA_FWD_API_PER_TRLOAD = """    {F_if}({F_trload_cond}){{
{F_dtype_case}
    }}
"""

# Dispatch nesting level 2: branch on the runtime data type string.
FMHA_FWD_API_PER_DTYPE = """    {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
    }}
"""
# Dispatch nesting level 3: branch on head-dim upper bounds.
FMHA_FWD_API_PER_HDIM_CASE = """        {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
        }}
"""

# Dispatch nesting level 4: per-trait guards (layout, mask, padding, constraint),
# then the explicit-specialization call for the matching kernel.
FMHA_FWD_API_INNER_DISPATCH = """            {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) &&
                ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
                using trait_ = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
                return fmha_vsa_fwd_<trait_>(s, a);
            }}
"""
|
||||
|
||||
|
||||
@dataclass
class CppConstraint:
    """A C++ boolean expression used as an extra dispatch guard.

    `None` (the default) means "no constraint" and renders as the literal
    C++ expression "true". Constraints compose with `&`.
    """

    # BUGFIX (annotation only): the default is None, so the type is Optional[str].
    bool_expr: Optional[str] = None

    def __str__(self):
        """Render the constraint as a C++ boolean expression."""
        if self.bool_expr is None:
            return "true"
        else:
            return f"{self.bool_expr}"

    def __and__(self, other):
        """Conjunction of two constraints; both operands are parenthesized."""
        return CppConstraint(f"({str(self)}) && ({str(other)})")
|
||||
|
||||
|
||||
@dataclass
class FmhaFwdApiTrait:
    """Per-kernel record used to render the runtime dispatch guards.

    Mirrors the C++ fmha_vsa_fwd_traits_<> template parameters.
    """

    pipeline_tag: str
    # sync with fmha_fwd_traits<>, to generate fallback calls
    hdim: str
    dtype: str  # data type
    mode: str  # value from MODE_MAP
    bm0: int  # tile size along q seqlen (block size)
    bn0: int  # tile size along qk seqlen
    bk0: int  # tile size along qk gemm unroll
    bn1: int  # tile size along v head_dim
    bk1: int  # tile size along kv gemm unroll
    bk0max: int
    vlayout: str
    logits: str
    mask: str
    spad: str
    skpad: str
    dpad: str
    dvpad: str
    tr_load: str
    constraint: CppConstraint

    @property
    def name(self) -> str:
        # BUGFIX: the original emitted bn0 twice and never bn1; the sequence now
        # matches the trait parameter order bm0-bn0-bk0-bn1-bk1-bk0max.
        return (
            f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0max}-"
            + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
        )

    @property
    def scheck(self) -> str:
        # C++ guard for seqlen_q padding. Every path currently yields "true":
        # group mode always pads, and batch kernels are only generated with
        # spad == "t" (see get_pipelines).
        if self.mode == "group":
            return "true/*group mode spad always true*/"  # group mode only generate spad/skpad == true
        if self.spad == "t":
            return "true"  # always support
        return "true"

    @property
    def seqtune(self) -> str:
        # Prefer smaller tiles for short sequences; the 128-wide tile is the
        # unconditional fallback.
        if self.bm0 == 128:
            return "true/*fall back to largest tile*/"
        else:
            return f"a.seqlen_q <= {self.bm0}"

    @property
    def skcheck(self) -> str:
        # C++ guard for seqlen_k padding against the bn0 tile width.
        if self.mode == "group":
            return "true/*group mode skpad always true*/"
        if self.skpad == "t":
            return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0"
        return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0"

    @property
    def dcheck(self) -> str:
        # hdim_q must be a multiple of the 128-bit vector width for this dtype.
        # Kernels are only generated with dpad == "t", hence the assert.
        vec = int((32 * 4) / DTYPE_BITS[self.dtype])
        if self.dpad == "t":
            return f"a.hdim_q % {vec} == 0"
        assert False

    @property
    def dvcheck(self) -> str:
        # hdim_v must be a multiple of the 128-bit vector width for this dtype.
        vec = int((32 * 4) / DTYPE_BITS[self.dtype])
        if self.dvpad == "t":
            return f"a.hdim_v % {vec} == 0"
        assert False
||||
|
||||
|
||||
@dataclass
class FmhaFwdPipeline:
    """Pipeline variant selection: layout, padding flags, mask and load mode."""

    tag: str  # pipeline identifier, e.g. "qr_async_vsa"

    F_vlayout: str  # row/col
    F_spad: str  # true/false
    F_skpad: str  #
    F_dpad: str  #
    F_dvpad: str  #
    F_logits: str  # t/f
    F_mask: str  # value from MASK_MAP
    F_trload: str  # true/false
    F_constraint: CppConstraint = field(default_factory=CppConstraint)

    @property
    def name(self) -> str:
        """Encode this pipeline's configuration into a kernel-name suffix."""

        def pad_name() -> str:
            # Concatenate the enabled padding axes, prefixed with "p".
            n = ""
            if self.F_spad == "t":
                n += "s"
            if self.F_skpad == "t":
                n += "sk"
            if self.F_dpad == "t":
                n += "d"
            if self.F_dvpad == "t":
                n += "dv"
            if n != "":
                n = "p" + n
            return n

        pn = pad_name()
        n = f"{self.tag}_v{self.F_vlayout[0]}"
        if pn != "":
            n += f"_{pn}"
        else:
            n += "_npad"

        if self.F_logits == "t":
            n += "_logits"
        else:
            n += "_nlogits"

        # Bias is never supported by this kernel family.
        n += "_nbias"

        # "s_" masks are the simplified mask implementation; others encode the
        # mask kind's first letter.
        if self.F_mask[0:2] == "s_":
            if self.F_mask == "s_mask":
                n += "_mask"
            else:
                n += "_nmask"
        else:
            if self.F_mask != "no":
                n += f"_m{self.F_mask[0]}"
            else:
                n += "_nmask"

        # min-seqlen-q skipping and static quantization are never enabled.
        n += "_nskip"

        n += "_nsquant"

        if self.F_trload == "t":
            n += "_trload"
        else:
            n += "_ntrload"

        return n
|
||||
|
||||
|
||||
class FmhaFwdApiPool:
    """Collects registered kernel traits and renders the C++ dispatch file."""

    def __init__(self, mask_impl):
        # pool layout: dtype -> (hdim_q, hdim_v) -> [FmhaFwdApiTrait]
        self.pool = dict()
        self.mask_impl = mask_impl

    def register_traits(self, trait: FmhaFwdApiTrait) -> None:
        # TODO: do we need to check duplication?
        if trait.dtype not in self.pool.keys():
            self.pool[trait.dtype] = dict()
        # key is (hdim_q, hdim_v); bn1 covers the whole v head_dim here
        hdim = trait.hdim, trait.bn1
        if hdim not in self.pool[trait.dtype].keys():
            self.pool[trait.dtype][hdim] = list()

        self.pool[trait.dtype][hdim].append(copy.copy(trait))

    @property
    def api(self) -> str:
        """Render the full dispatch source: tr_load -> dtype -> hdim -> trait guards."""
        # Transposed-load kernels require hardware support; plain loads always work.
        tr_load_cond_map = {"t": "has_load_tr", "f": "true"}

        per_tr_load = str()
        for tr_load in ["t", "f"]:
            per_dtypes = str()
            for i, dtype in enumerate(self.pool.keys()):
                per_hdim_case = str()
                for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
                    # Only traits matching the current tr_load branch.
                    traits = [
                        t
                        for t in self.pool[dtype][(hdim, hdim_v)]
                        if tr_load == t.tr_load
                    ]
                    inners = str()
                    for k, trait in enumerate(traits):
                        if_k = "if" if k == 0 else "else if"
                        inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
                            F_if=if_k,
                            F_vlayout=LAYOUT_MAP[trait.vlayout],
                            F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
                            # F_logits removed - hardcoded to false (NOT supported)
                            F_mask=get_mask_map(self.mask_impl)[trait.mask],
                            F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
                            F_trload=BOOL_MAP[trait.tr_load],
                            F_scheck=trait.scheck,
                            F_seqtune=trait.seqtune,
                            F_skcheck=trait.skcheck,
                            F_dcheck=trait.dcheck,
                            F_dvcheck=trait.dvcheck,
                            F_constraint=trait.constraint,
                            F_spad=BOOL_MAP[trait.spad],
                            F_skpad=BOOL_MAP[trait.skpad],
                            F_dpad=BOOL_MAP[trait.dpad],
                            F_dvpad=BOOL_MAP[trait.dvpad],
                            F_bm0=trait.bm0,
                            F_bn0=trait.bn0,
                            F_bk0=trait.bk0,
                            F_bn1=trait.bn1,
                            F_bk1=trait.bk1,
                            F_bk0max=trait.bk0max,
                            F_hdim=hdim,
                            F_dtype=FWD_DTYPE_MAP[dtype],
                        )
                    if_j = "if" if j == 0 else "else if"
                    per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
                        F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners
                    )
                if_i = "if" if i == 0 else "else if"
                per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
                    F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
                )
            per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(
                F_if="if",
                F_trload_cond=tr_load_cond_map[tr_load],
                F_dtype_case=per_dtypes,
            )
        if not per_tr_load:
            # empty string we add some ignore to suppress warning in api
            per_tr_load += " (void)t ; (void)s ; (void)a;"
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load)
|
||||
|
||||
|
||||
@dataclass
class FmhaFwdTileSize:
    """Tile and warp partitioning for one kernel (gemm0 = Q*K^T, gemm1 = P*V)."""

    F_bm0: int  # tile size along q seqlen (block size)
    F_bn0: int  # tile size along k seqlen
    F_bk0: int  # tile size along qk gemm unroll
    F_bn1: int  # tile size along v head_dim
    F_bk1: int  # tile size along kv gemm unroll
    F_bk0max: int  # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
    F_rm0: int  # number of warps for gemm0 along q seqlen
    F_rn0: int  # number of warps for gemm0 along k seqlen
    F_rk0: int  # number of warps for gemm0 along head dim q (not used)
    F_rm1: int  # number of warps for gemm1 along q seqlen
    F_rn1: int  # number of warps for gemm1 along head dim v
    F_rk1: int  # number of warps for gemm1 along k seqlen (not used)
    F_wm0: int  # gemm0 warp size along m
    F_wn0: int  # gemm0 warp size along n
    F_wk0: int  # gemm0 warp size along k
    F_wm1: int  # gemm1 warp size along m
    F_wn1: int  # gemm1 warp size along n
    F_wk1: int  # gemm1 warp size along k
    F_occupancy: int  # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
    F_constraint: CppConstraint = field(default_factory=CppConstraint)

    @property
    def name(self) -> str:
        # Encode block / warp-repeat / warp-tile shapes into the kernel name;
        # occupancy is appended only when explicitly overridden.
        return (
            f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
            + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
            + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
            + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
        )
|
||||
|
||||
|
||||
@dataclass
class FmhaFwdKernel:
    """One generated VSA forward-kernel instantiation (template + name + trait)."""

    F_idx: int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim: int  # hdim
    F_dtype: str  # data type
    F_mode: str  # value from MODE_MAP
    F_tile: FmhaFwdTileSize
    F_pipeline: FmhaFwdPipeline
    mask_impl: str

    @property
    def template(self) -> str:
        """Full C++ source for this kernel's translation unit."""
        # kernel_body removed - unused
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
            F_idx=self.F_idx,
            F_hdim=self.F_hdim,
            F_dtype=FWD_DTYPE_MAP[self.F_dtype],
            F_bm0=self.F_tile.F_bm0,
            F_bn0=self.F_tile.F_bn0,
            F_bk0=self.F_tile.F_bk0,
            F_bn1=self.F_tile.F_bn1,
            F_bk1=self.F_tile.F_bk1,
            F_bk0max=self.F_tile.F_bk0max,
            F_rm0=self.F_tile.F_rm0,
            F_rn0=self.F_tile.F_rn0,
            F_rk0=self.F_tile.F_rk0,
            F_rm1=self.F_tile.F_rm1,
            F_rn1=self.F_tile.F_rn1,
            F_rk1=self.F_tile.F_rk1,
            F_wm0=self.F_tile.F_wm0,
            F_wn0=self.F_tile.F_wn0,
            F_wk0=self.F_tile.F_wk0,
            F_wm1=self.F_tile.F_wm1,
            F_wn1=self.F_tile.F_wn1,
            F_wk1=self.F_tile.F_wk1,
            F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
            F_spad=BOOL_MAP[self.F_pipeline.F_spad],
            F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
            F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
            F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
            # F_logits removed - hardcoded to false in template (NOT supported)
            F_occupancy=self.F_tile.F_occupancy,
            F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
            F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
            F_mode=MODE_MAP[self.F_mode],
            F_pipeline=PIPELINE_MAP[self.F_pipeline.tag],
            F_trload=BOOL_MAP[self.F_pipeline.F_trload],
            F_kernel_name=self.name,
        )

    @property
    def name(self) -> str:
        """Unique kernel instance name (also used as file stem and log tag)."""
        # TODO: we don't encode idx here
        return (
            f"fmha_vsa_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
            + self.F_tile.name
            + "_"
            + self.F_pipeline.name
        )

    @property
    def filename(self) -> str:
        """Name of the generated .cpp file for this kernel."""
        return self.name + ".cpp"

    def api_trait(self) -> FmhaFwdApiTrait:
        """Build the dispatch-time trait record matching this kernel."""
        return FmhaFwdApiTrait(
            pipeline_tag=self.F_pipeline.tag,
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
            mode=self.F_mode,
            bm0=self.F_tile.F_bm0,
            bn0=self.F_tile.F_bn0,
            bk0=self.F_tile.F_bk0,
            bn1=self.F_tile.F_bn1,
            bk1=self.F_tile.F_bk1,
            bk0max=self.F_tile.F_bk0max,
            vlayout=self.F_pipeline.F_vlayout,
            mask=self.F_pipeline.F_mask,
            logits=self.F_pipeline.F_logits,
            spad=self.F_pipeline.F_spad,
            skpad=self.F_pipeline.F_skpad,
            dpad=self.F_pipeline.F_dpad,
            dvpad=self.F_pipeline.F_dvpad,
            tr_load=self.F_pipeline.F_trload,
            # combined guard: tile constraint AND pipeline constraint
            constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
        )
|
||||
|
||||
|
||||
class KernelComponentFactory:
    """Provides supported tile shapes and pipeline variants per (hdim_q, hdim_v)."""

    # TODO: design a more practical way to do it
    # this is current supported tile size per hdim
    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        # Maps (hdim_q, hdim_v) -> list of candidate tile configs; smaller-bm0
        # entries come first so dispatch can prefer them for short sequences.
        # Returns None for dtypes without a tile table (only fp16/bf16 here).
        if dtype == "fp16" or dtype == "bf16":
            return {
                # (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
                # (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
                #             FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                #             FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
                # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
                (128, 128): [
                    # args: bm0, bn0, bk0, bn1, bk1, bk0max,
                    #       rm0, rn0, rk0, rm1, rn1, rk1,
                    #       wm0, wn0, wk0, wm1, wn1, wk1, occupancy
                    FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
                    FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                    FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                    FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                ],
                # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
                # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
                # (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
                # (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
            }
        else:
            return None

    # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
    # support this in future
    @staticmethod
    def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
        # this function will populate a list possible pipelines
        # TODO: the order of List matters! the later in this list will be also be checked later
        # NOTE: logits soft-cap is NOT supported by VSA sparse attention (enforced by static_assert)
        # NOTE(review): `receipt` is currently unused here — filtering happens in get_fwd_blobs.
        pipelines = []
        if dtype in ["fp16", "bf16"]:
            for logits, mask in itertools.product(
                ["f"],  # logits soft-cap NOT supported, always false
                get_mask_map(mask_impl).keys(),
            ):
                if hdim == 256 and hdim_v == 256:
                    # vsa fmha only supports dim <= 192 for now.
                    continue
                # Two variants per mask: without and with seqlen_k padding (skpad).
                # FmhaFwdPipeline args: tag, vlayout, spad, skpad, dpad, dvpad, logits, mask, trload
                pipelines.append(
                    FmhaFwdPipeline(
                        "qr_async_vsa", "row", "t", "f", "t", "t", logits, mask, "f"
                    )
                )
                pipelines.append(
                    FmhaFwdPipeline(
                        "qr_async_vsa", "row", "t", "t", "t", "t", logits, mask, "f"
                    )
                )
        else:
            assert False
        return pipelines
|
||||
|
||||
|
||||
class CustomFactory(KernelComponentFactory):
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
|
||||
if dtype == "fp16" or dtype == "bf16":
|
||||
if (128, 128) in result.keys():
|
||||
result[(128, 128)].insert(
|
||||
0,
|
||||
FmhaFwdTileSize(
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
-1,
|
||||
CppConstraint(
|
||||
"get_num_blocks(128) < num_cus * min_cu_util_rate"
|
||||
),
|
||||
),
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def get_fwd_blobs(
|
||||
kernel_filter: Optional[str], receipt, optdim_list, mask_impl
|
||||
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
gen = list()
|
||||
api_pool = FmhaFwdApiPool(mask_impl)
|
||||
|
||||
factory = (
|
||||
CustomFactory
|
||||
if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1"
|
||||
else KernelComponentFactory
|
||||
)
|
||||
|
||||
# Only generate fp16/bf16 kernels for now.
|
||||
# NOTE: VSA sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert)
|
||||
for dtype in ["fp16", "bf16"]:
|
||||
d = factory.get_hdim_tile_size_dict(dtype)
|
||||
if d is None:
|
||||
continue
|
||||
for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]):
|
||||
for tile, pipeline in itertools.product(
|
||||
tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
|
||||
):
|
||||
if tile.F_bm0 != 128 or tile.F_bn0 != 128:
|
||||
continue
|
||||
if pipeline.tag != "qr_async_vsa":
|
||||
continue
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=1,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl,
|
||||
)
|
||||
if kernel_filter != "":
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_logits == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_varlen_fwd) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_fwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
|
||||
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
|
||||
update_file(autogen_dir / kernel.filename, kernel.template)
|
||||
|
||||
|
||||
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
|
||||
update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api)
|
||||
|
||||
|
||||
def write_blobs(
|
||||
output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
|
||||
) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
|
||||
|
||||
def list_blobs(
|
||||
file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
|
||||
) -> None:
|
||||
with file_path.open("a") as f:
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
|
||||
328
example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp
Normal file
328
example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp
Normal file
@@ -0,0 +1,328 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/fmha.hpp"
|
||||
|
||||
#include "01_fmha/mask.hpp"
|
||||
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
namespace ck_tile {
// Transposed-load support is gated on gfx95-class hardware.
// NOTE(review): relies on is_gfx95_supported() from the ck_tile headers.
inline bool is_load_tr_supported() { return is_gfx95_supported(); }
} // namespace ck_tile
|
||||
|
||||
// Empty tag types used to select the element type of the sparse-attention
// forward kernel via FmhaSparseFwdTypeConfig below.
struct FmhaSparseFwdFp16
{
};

struct FmhaSparseFwdBf16
{
};

// Primary template; specialized for each tag type below.
template <typename DataType>
struct FmhaSparseFwdTypeConfig;
|
||||
|
||||
// fp16 element-type configuration: half-precision storage for Q/K/V/P/O with
// fp32 accumulation and softmax compute.
template <>
struct FmhaSparseFwdTypeConfig<FmhaSparseFwdFp16>
{
    using QDataType           = ck_tile::half_t;
    using KDataType           = ck_tile::half_t;
    using VDataType           = ck_tile::half_t;
    using SaccDataType        = float; // data type for first gemm accumulation
    using SMPLComputeDataType = float; // data type for reduction, softmax
    using PDataType           = ck_tile::half_t; // data type for A matrix of second gemm
    using OaccDataType        = float; // data type for second gemm accumulation
    using ODataType           = ck_tile::half_t;
    // Note: The following types are required by BlockFmhaPipelineProblem but not used
    // by sparse attention (bias, dropout, LSE are not supported).
    using BiasDataType          = ck_tile::half_t;
    using RandValOutputDataType = uint8_t;
    using LSEDataType           = float;
};
|
||||
|
||||
// bf16 element-type configuration: bfloat16 storage for Q/K/V/P/O with fp32
// accumulation and softmax compute.
template <>
struct FmhaSparseFwdTypeConfig<FmhaSparseFwdBf16>
{
    using QDataType           = ck_tile::bf16_t;
    using KDataType           = ck_tile::bf16_t;
    using VDataType           = ck_tile::bf16_t;
    using SaccDataType        = float; // data type for first gemm accumulation
    using SMPLComputeDataType = float; // data type for reduction, softmax
    using PDataType           = ck_tile::bf16_t; // data type for A matrix of second gemm
    using OaccDataType        = float; // data type for second gemm accumulation
    using ODataType           = ck_tile::bf16_t;
    // Note: The following types are required by BlockFmhaPipelineProblem but not used
    // by sparse attention (bias, dropout, LSE are not supported).
    using BiasDataType          = ck_tile::bf16_t;
    using RandValOutputDataType = uint8_t;
    using LSEDataType           = float;
};
|
||||
|
||||
// Shorthand aliases for the three attention-mask variants used by the
// dispatch code; template params are GenericAttentionMask<IsMasking, IsLocal>.
struct FmhaMasks
{
    using NoMask      = ck_tile::GenericAttentionMask<false>;
    using GenericMask = ck_tile::GenericAttentionMask<true, true>;
    using CausalMask  = ck_tile::GenericAttentionMask<true, false>;
};
|
||||
|
||||
// jenga
// Host-side argument bundle for the Jenga sparse-attention forward launch.
// Sparsity is described by a dense one-hot Q-block x K-block map.
struct fmha_jenga_fwd_args
{
    const void* q_ptr;
    const void* k_ptr;
    const void* v_ptr;
    const void* block_relation_onehot_ptr; // one-hot block map [B,H,Q_blk,K_blk], 1=active
    void* o_ptr;

    ck_tile::index_t seqlen_q;
    ck_tile::index_t seqlen_k;
    ck_tile::index_t batch;
    ck_tile::index_t max_seqlen_q; // used for grid sizing
    ck_tile::index_t hdim_q;
    ck_tile::index_t hdim_v;
    ck_tile::index_t nhead_q;
    ck_tile::index_t nhead_k; // nhead_q must be divisible by nhead_k (GQA/MQA)

    float scale_s; // softmax scale applied to the QK^T logits

    // Row strides (innermost), per-head strides, and per-batch strides.
    ck_tile::index_t stride_q;
    ck_tile::index_t stride_k;
    ck_tile::index_t stride_v;
    ck_tile::index_t stride_o;
    ck_tile::index_t nhead_stride_q;
    ck_tile::index_t nhead_stride_k;
    ck_tile::index_t nhead_stride_v;
    ck_tile::index_t nhead_stride_o;
    ck_tile::index_t batch_stride_q;
    ck_tile::index_t batch_stride_k;
    ck_tile::index_t batch_stride_v;
    ck_tile::index_t batch_stride_o;

    // Sliding-window / causal mask parameters (see mask_enum).
    ck_tile::index_t window_size_left;
    ck_tile::index_t window_size_right;
    ck_tile::index_t mask_type;

    // Dropout is not supported for sparse attention; keep args minimal.
};
|
||||
|
||||
// vsa
// Host-side argument bundle for the VSA sparse-attention forward launch.
// Sparsity is described by a compact per-Q-block LUT of active K-block
// indices plus a per-Q-block valid count.
struct fmha_vsa_fwd_args
{
    const void* q_ptr;
    const void* k_ptr;
    const void* v_ptr;
    const void* lut_ptr; // delta-encoded K-block indices per Q-block, int32 [B,H,Q_blk,K_blk]
    const void* valid_block_num_ptr; // valid K-block count per Q-block, int32 [B,H,Q_blk]
    void* o_ptr;

    ck_tile::index_t seqlen_q;
    ck_tile::index_t seqlen_k;
    ck_tile::index_t batch;
    ck_tile::index_t max_seqlen_q; // used for grid sizing
    ck_tile::index_t hdim_q;
    ck_tile::index_t hdim_v;
    ck_tile::index_t nhead_q;
    ck_tile::index_t nhead_k; // nhead_q must be divisible by nhead_k (GQA/MQA)

    float scale_s; // softmax scale applied to the QK^T logits

    // Row strides (innermost), per-head strides, and per-batch strides.
    ck_tile::index_t stride_q;
    ck_tile::index_t stride_k;
    ck_tile::index_t stride_v;
    ck_tile::index_t stride_o;
    ck_tile::index_t nhead_stride_q;
    ck_tile::index_t nhead_stride_k;
    ck_tile::index_t nhead_stride_v;
    ck_tile::index_t nhead_stride_o;
    ck_tile::index_t batch_stride_q;
    ck_tile::index_t batch_stride_k;
    ck_tile::index_t batch_stride_v;
    ck_tile::index_t batch_stride_o;

    // Sliding-window / causal mask parameters (see mask_enum).
    ck_tile::index_t window_size_left;
    ck_tile::index_t window_size_right;
    ck_tile::index_t mask_type;

    // Dropout is not supported for sparse attention; keep args minimal.
};
|
||||
|
||||
// Translate host-side Jenga args into the kernel's kargs struct and compute
// the launch grid. The positional order below must match
// FmhaKernel::MakeKargs exactly.
template <typename FmhaKernel>
auto fmha_fwd_create_kargs_and_grids(fmha_jenga_fwd_args args)
{
    // GQA/MQA: nhead_q/nhead_k is the q-heads-per-kv-head ratio passed below.
    assert(args.nhead_q % args.nhead_k == 0);
    auto kargs = FmhaKernel::MakeKargs(args.q_ptr,
                                       args.k_ptr,
                                       args.v_ptr,
                                       args.block_relation_onehot_ptr,
                                       args.o_ptr,
                                       args.seqlen_q,
                                       args.seqlen_k,
                                       args.hdim_q,
                                       args.hdim_v,
                                       args.nhead_q,
                                       args.nhead_q / args.nhead_k,
                                       args.scale_s,
                                       args.stride_q,
                                       args.stride_k,
                                       args.stride_v,
                                       args.stride_o,
                                       args.nhead_stride_q,
                                       args.nhead_stride_k,
                                       args.nhead_stride_v,
                                       args.nhead_stride_o,
                                       args.batch_stride_q,
                                       args.batch_stride_k,
                                       args.batch_stride_v,
                                       args.batch_stride_o,
                                       args.window_size_left,
                                       args.window_size_right,
                                       args.mask_type);

    dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
    return ck_tile::make_tuple(kargs, grids);
}
|
||||
|
||||
// VSA overload: identical to the Jenga variant except the sparsity inputs are
// the LUT + valid-block-count pointers instead of the one-hot block map.
template <typename FmhaKernel>
auto fmha_fwd_create_kargs_and_grids(fmha_vsa_fwd_args args)
{
    // GQA/MQA: nhead_q/nhead_k is the q-heads-per-kv-head ratio passed below.
    assert(args.nhead_q % args.nhead_k == 0);
    auto kargs = FmhaKernel::MakeKargs(args.q_ptr,
                                       args.k_ptr,
                                       args.v_ptr,
                                       args.lut_ptr,
                                       args.valid_block_num_ptr,
                                       args.o_ptr,
                                       args.seqlen_q,
                                       args.seqlen_k,
                                       args.hdim_q,
                                       args.hdim_v,
                                       args.nhead_q,
                                       args.nhead_q / args.nhead_k,
                                       args.scale_s,
                                       args.stride_q,
                                       args.stride_k,
                                       args.stride_v,
                                       args.stride_o,
                                       args.nhead_stride_q,
                                       args.nhead_stride_k,
                                       args.nhead_stride_v,
                                       args.nhead_stride_o,
                                       args.batch_stride_q,
                                       args.batch_stride_k,
                                       args.batch_stride_v,
                                       args.batch_stride_o,
                                       args.window_size_left,
                                       args.window_size_right,
                                       args.mask_type);

    dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
    return ck_tile::make_tuple(kargs, grids);
}
|
||||
|
||||
// this is used to pattern-match internal kernel implementation, not to instantiate kernel
// Compile-time trait bundle describing one concrete Jenga fwd kernel variant:
// head dim, tile shape (M0/N0/K0/N1/K1), layout, pipeline, mask and padding.
template <ck_tile::index_t HDim_,
          typename DataType_,
          ck_tile::index_t kM0_,
          ck_tile::index_t kN0_,
          ck_tile::index_t kK0_,
          ck_tile::index_t kN1_,
          ck_tile::index_t kK1_,
          ck_tile::index_t kK0BlockLength_,
          bool kIsVLayoutRowMajor_,
          ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
          bool kHasLogitsSoftCap_,
          typename FmhaMask_,
          bool kPadS_,
          bool kPadSK_,
          bool kPadD_,
          bool kPadDv_,
          bool kUseTrLoad_>
struct fmha_jenga_fwd_traits_
{
    static constexpr ck_tile::index_t HDim = HDim_;
    using DataType                         = ck_tile::remove_cvref_t<DataType_>;
    static constexpr ck_tile::index_t kM0  = kM0_;
    static constexpr ck_tile::index_t kN0  = kN0_;
    static constexpr ck_tile::index_t kK0  = kK0_;
    static constexpr ck_tile::index_t kN1  = kN1_;
    static constexpr ck_tile::index_t kK1  = kK1_;
    static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
    static constexpr bool kIsVLayoutRowMajor         = kIsVLayoutRowMajor_;
    static constexpr auto FmhaPipelineEnum           = FmhaPipelineEnum_;
    static constexpr bool kHasLogitsSoftCap          = kHasLogitsSoftCap_;
    using FmhaMask                 = ck_tile::remove_cvref_t<FmhaMask_>;
    // kPadS/kPadSK/kPadD/kPadDv: pad seqlen_q / seqlen_k / hdim_q / hdim_v.
    static constexpr bool kPadS    = kPadS_;
    static constexpr bool kPadSK   = kPadSK_;
    static constexpr bool kPadD    = kPadD_;
    static constexpr bool kPadDv   = kPadDv_;
    static constexpr bool kUseTrLoad = kUseTrLoad_;
};
|
||||
|
||||
// Runtime traits used by the host dispatcher to pick a compiled Jenga kernel.
struct fmha_jenga_fwd_traits
{
    int hdim_q;
    int hdim_v;
    std::string data_type; // "fp16" or "bf16"
    bool is_v_rowmajor;
    mask_enum mask_type;
    // TODO: padding check is inside this api
};

// Dispatch entry point: selects and launches a kernel matching the traits;
// returns the measured/elapsed time from the stream_config.
float fmha_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&);

// Per-variant launcher, instantiated in generated sources.
template <typename Traits_>
float fmha_jenga_fwd_(const ck_tile::stream_config&, fmha_jenga_fwd_args);

float fmha_jenga_fwd(fmha_jenga_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
// VSA uses the same traits structure as Jenga; aliases for clarity
template <ck_tile::index_t HDim_,
          typename DataType_,
          ck_tile::index_t kM0_,
          ck_tile::index_t kN0_,
          ck_tile::index_t kK0_,
          ck_tile::index_t kN1_,
          ck_tile::index_t kK1_,
          ck_tile::index_t kK0BlockLength_,
          bool kIsVLayoutRowMajor_,
          ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
          bool kHasLogitsSoftCap_,
          typename FmhaMask_,
          bool kPadS_,
          bool kPadSK_,
          bool kPadD_,
          bool kPadDv_,
          bool kUseTrLoad_>
using fmha_vsa_fwd_traits_ = fmha_jenga_fwd_traits_<HDim_,
                                                    DataType_,
                                                    kM0_,
                                                    kN0_,
                                                    kK0_,
                                                    kN1_,
                                                    kK1_,
                                                    kK0BlockLength_,
                                                    kIsVLayoutRowMajor_,
                                                    FmhaPipelineEnum_,
                                                    kHasLogitsSoftCap_,
                                                    FmhaMask_,
                                                    kPadS_,
                                                    kPadSK_,
                                                    kPadD_,
                                                    kPadDv_,
                                                    kUseTrLoad_>;

using fmha_vsa_fwd_traits = fmha_jenga_fwd_traits;

// Dispatch entry point for the VSA variant (mirrors fmha_jenga_fwd).
float fmha_vsa_fwd(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&);

// Per-variant launcher, instantiated in generated sources.
template <typename Traits_>
float fmha_vsa_fwd_(const ck_tile::stream_config&, fmha_vsa_fwd_args);

float fmha_vsa_fwd(fmha_vsa_fwd_args, const ck_tile::stream_config&);
|
||||
166
example/ck_tile/50_sparse_attn/generate.py
Normal file
166
example/ck_tile/50_sparse_attn/generate.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import argparse
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
import pkgutil
|
||||
from typing import List, Optional
|
||||
|
||||
import codegen.ops
|
||||
|
||||
|
||||
class HandlerId(IntEnum):
    """Index into the (list_blobs, write_blobs) handler tuple registered per op."""

    LIST_BLOBS = 0
    WRITE_BLOBS = 1
|
||||
|
||||
|
||||
# inspect all modules under 'codegen.ops' and register API handlers
|
||||
ops = []
|
||||
for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
|
||||
full_module_name = "%s.%s" % (codegen.ops.__name__, module_name)
|
||||
ops.append(importer.find_spec(module_name).loader.load_module(module_name))
|
||||
unwanted_prefix = "fmha_"
|
||||
handlers = dict(
|
||||
[
|
||||
(
|
||||
op.__name__[len(unwanted_prefix) :]
|
||||
if op.__name__.startswith(unwanted_prefix)
|
||||
else op.__name__,
|
||||
(op.list_blobs, op.write_blobs),
|
||||
)
|
||||
for op in ops
|
||||
]
|
||||
)
|
||||
assert 0 < len(handlers)
|
||||
|
||||
|
||||
def write_blobs(
|
||||
output_dir: Optional[str],
|
||||
api_list: List[str],
|
||||
filters_list: List[str],
|
||||
optdim_list: List[int],
|
||||
receipt,
|
||||
mask_impl,
|
||||
) -> None:
|
||||
if output_dir is None:
|
||||
output_dir = Path(__file__).parent
|
||||
else:
|
||||
output_dir = Path(output_dir)
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for api, kernel_filter in zip(api_list, filters_list):
|
||||
handler = handlers[api][HandlerId.WRITE_BLOBS]
|
||||
handler(output_dir, kernel_filter, receipt, optdim_list, mask_impl)
|
||||
|
||||
|
||||
# list all the files that will be generated
def list_blobs(
    output_file: Optional[str],
    api_list: List[str],
    filters_list: List[str],
    optdim_list: List[int],
    receipt,
    mask_impl,
) -> None:
    """Write the paths of all files that would be generated to output_file.

    Args mirror write_blobs; output_file must be a path string.

    Raises:
        ValueError: if an API name has no registered handler.
    """
    assert output_file is not None
    file_path = Path(output_file)

    # create an empty file / drop its contents if it exists
    file_path.write_text("")

    for api, kernel_filter in zip(api_list, filters_list):
        if api not in handlers:
            # Fail with a clear message instead of a bare KeyError
            # (consistent with write_blobs).
            raise ValueError(f"unknown API '{api}'; supported: {sorted(handlers)}")
        handler = handlers[api][HandlerId.LIST_BLOBS]
        handler(file_path, kernel_filter, receipt, optdim_list, mask_impl)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="generate",
        description="gen API for CK fmha kernel",
    )
    parser.add_argument(
        "-d",
        "--direction",  # we keep 'direction' option for backward compatibility
        "-a",
        "--api",
        default="fwd_jenga",
        required=False,
        # BUGFIX: help previously claimed "default: fwd" but the default is
        # actually "fwd_jenga".
        help="supply API(s) to generate (default: fwd_jenga). separated by comma.",
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        required=False,
        help="write all the blobs into a directory",
    )
    parser.add_argument(
        "-l", "--list_blobs", required=False, help="list all the kernels to a file"
    )
    # TODO: if using filter, must apply same value to output_dir and list_blobs
    parser.add_argument(
        "-f",
        "--filter",
        default="",
        required=False,
        help="filter out kernels that need to generate, using fnmatch module",
    )

    parser.add_argument(
        "-m",
        "--mask",
        default="simplified",
        required=False,
        help="mask implementation, simplified/generic",
    )

    parser.add_argument(
        "-r",
        "--receipt",
        default=0,
        required=False,
        help="codegen receipt. 0: generate only 8xhdim coverage\n"
        + "  1: generate more instance to cover all hdim\n"
        + "  2: Only generate instance for Flash attention integration\n"
        + "  4: Only generate instance for PyTorch integration\n"
        + "  100-199: Only generate instance for Aiter(mha_fwd) integration\n"
        + "  200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n"
        + "  300-399: Only generate instance for Aiter(mha_bwd) integration\n"
        + "  400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n"
        + "  600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration",
    )

    parser.add_argument(
        "--optdim",
        default="-1",
        required=False,
        # BUGFIX: the two concatenated fragments previously rendered as
        # "...default choiceeg. --optdim=..." (missing separator).
        help="only optimize the hdim in the list. separated by comma. -1 is the default choice. "
        + "eg. --optdim=32,64,128,256",
    )

    args = parser.parse_args()
    api_list = args.direction.split(",")
    filter_list = args.filter.split(",")
    # Pad the filter list so zip() pairs every API with a filter (possibly "").
    filter_list.extend([""] * (len(api_list) - len(filter_list)))
    optdim_list = [int(hdim) for hdim in args.optdim.split(",")]

    if args.list_blobs is not None:
        list_blobs(
            args.list_blobs,
            api_list,
            filter_list,
            optdim_list,
            int(args.receipt),
            mask_impl=args.mask,
        )
    else:
        write_blobs(
            args.output_dir,
            api_list,
            filter_list,
            optdim_list,
            int(args.receipt),
            mask_impl=args.mask,
        )
|
||||
199
example/ck_tile/50_sparse_attn/jenga_sparse_attention.cpp
Normal file
199
example/ck_tile/50_sparse_attn/jenga_sparse_attention.cpp
Normal file
@@ -0,0 +1,199 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#include "jenga_sparse_attention.h"
|
||||
#include "fmha_fwd_trek.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
// Host-side driver for the Jenga sparse-attention forward pass:
// uploads Q/K/V and the one-hot block map, builds launch args/traits,
// dispatches fmha_jenga_fwd, and copies the output back into Y.
// i_perm/o_perm select BHSD (true) vs BSHD (false) input/output layouts.
// Returns Y (also filled in place). Batch mode only; no mask ("0"), no
// dropout, no LSE.
template <typename DataType_>
ck_tile::HostTensor<DataType_>
jenga_sparse_attention(const ck_tile::HostTensor<DataType_>& TQ,
                       const ck_tile::HostTensor<DataType_>& TK,
                       const ck_tile::HostTensor<DataType_>& TV,
                       const ck_tile::HostTensor<uint8_t>& Tblock_relation_onehot,
                       ck_tile::HostTensor<DataType_>& Y,
                       int batch,
                       int nhead,
                       int nhead_k,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       bool i_perm,
                       bool o_perm,
                       int max_seqlen_q,
                       int max_seqlen_k,
                       int log_level)
{
    static_assert(std::is_same_v<DataType_, ck_tile::half_t> ||
                      std::is_same_v<DataType_, ck_tile::bf16_t>,
                  "Jenga sparse attention supports fp16/bf16 only.");
    // Determine data type string based on template parameter
    std::string data_type = "fp16";
    if constexpr(std::is_same_v<DataType_, ck_tile::bf16_t>)
    {
        data_type = "bf16";
    }

    // 0 means "use the actual sequence length".
    if(max_seqlen_q == 0)
        max_seqlen_q = seqlen_q;
    if(max_seqlen_k == 0)
        max_seqlen_k = seqlen_k;
    bool is_v_rowmajor = true;
    // Standard attention scaling: 1/sqrt(hdim_q).
    float scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q));
    // "0" == no mask; sparsity comes solely from the block-relation map.
    std::string msk_str = "0";
    mask_info mask      = mask_info::decode(msk_str, seqlen_q, seqlen_k);

    const ck_tile::index_t shape_seqlen_q = seqlen_q;
    const ck_tile::index_t shape_seqlen_k = seqlen_k;

    ck_tile::stream_config stream_config{nullptr,
                                         false, // time_kernel
                                         log_level,
                                         0,
                                         1,
                                         false};

    // Create device memory and copy data to device
    ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes());
    ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes());
    ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes());
    ck_tile::DeviceMem block_relation_buf(Tblock_relation_onehot.get_element_space_size_in_bytes());
    ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes());

    q_buf.ToDevice(TQ.data());
    k_buf.ToDevice(TK.data());
    v_buf.ToDevice(TV.data());
    block_relation_buf.ToDevice(Tblock_relation_onehot.data());

    const auto init_args = [&](auto& args) {
        assert(nhead % nhead_k == 0);
        // Row strides depend on whether tensors are BHSD (i_perm) or BSHD.
        const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
        const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
        const ck_tile::index_t stride_v = [&]() {
            if(is_v_rowmajor)
                return i_perm ? hdim_v : nhead_k * hdim_v;
            else
                return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
        }();
        const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
        // setup nhead_stride_* arguments
        const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
        const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q;
        const ck_tile::index_t nhead_stride_v = [&]() {
            if(is_v_rowmajor)
                return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
            else
                return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
        }();
        const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
        // setup batch_stride_* arguments
        const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
        const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q;
        const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k;
        const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);

        // Use device buffer pointers instead of host tensor data pointers
        args.q_ptr                     = q_buf.GetDeviceBuffer();
        args.k_ptr                     = k_buf.GetDeviceBuffer();
        args.v_ptr                     = v_buf.GetDeviceBuffer();
        args.block_relation_onehot_ptr = block_relation_buf.GetDeviceBuffer();

        args.batch    = batch;
        args.seqlen_q = shape_seqlen_q; // batch mode only
        args.hdim_q   = hdim_q;
        args.hdim_v   = hdim_v;
        args.nhead_q  = nhead;
        args.nhead_k  = nhead_k;

        args.stride_q       = stride_q;
        args.stride_k       = stride_k;
        args.stride_v       = stride_v;
        args.nhead_stride_q = nhead_stride_q;
        args.nhead_stride_k = nhead_stride_k;
        args.nhead_stride_v = nhead_stride_v;
        args.batch_stride_q = batch_stride_q;
        args.batch_stride_k = batch_stride_k;
        args.batch_stride_v = batch_stride_v;

        args.o_ptr = o_buf.GetDeviceBuffer();

        args.seqlen_k     = shape_seqlen_k; // batch mode only
        args.max_seqlen_q = max_seqlen_q;

        args.scale_s = scale_s;

        args.stride_o       = stride_o;
        args.nhead_stride_o = nhead_stride_o;
        args.batch_stride_o = batch_stride_o;

        args.window_size_left  = mask.left;
        args.window_size_right = mask.right;
        args.mask_type         = static_cast<ck_tile::index_t>(mask.type);

        // Dropout not supported for sparse attention.
    };

    const auto init_traits = [&](auto& traits) {
        traits.hdim_q        = hdim_q;
        traits.hdim_v        = hdim_v;
        traits.data_type     = data_type;
        traits.is_v_rowmajor = is_v_rowmajor;

        traits.mask_type = mask.type;
    };

    fmha_jenga_fwd_traits fmha_traits;
    init_traits(fmha_traits);

    fmha_jenga_fwd_args args;
    init_args(args);

    fmha_jenga_fwd(fmha_traits, args, stream_config);

    // Copy output back to host without changing tensor shape
    o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes());

    return Y;
}
|
||||
|
||||
// Explicit template instantiations
// These are the only element types supported (enforced by the static_assert
// inside jenga_sparse_attention).

// fp16 instantiation
template ck_tile::HostTensor<ck_tile::half_t>
jenga_sparse_attention<ck_tile::half_t>(const ck_tile::HostTensor<ck_tile::half_t>&,
                                        const ck_tile::HostTensor<ck_tile::half_t>&,
                                        const ck_tile::HostTensor<ck_tile::half_t>&,
                                        const ck_tile::HostTensor<uint8_t>&,
                                        ck_tile::HostTensor<ck_tile::half_t>&,
                                        int,
                                        int,
                                        int,
                                        int,
                                        int,
                                        int,
                                        int,
                                        bool,
                                        bool,
                                        int,
                                        int,
                                        int);

// bf16 instantiation
template ck_tile::HostTensor<ck_tile::bf16_t>
jenga_sparse_attention<ck_tile::bf16_t>(const ck_tile::HostTensor<ck_tile::bf16_t>&,
                                        const ck_tile::HostTensor<ck_tile::bf16_t>&,
                                        const ck_tile::HostTensor<ck_tile::bf16_t>&,
                                        const ck_tile::HostTensor<uint8_t>&,
                                        ck_tile::HostTensor<ck_tile::bf16_t>&,
                                        int,
                                        int,
                                        int,
                                        int,
                                        int,
                                        int,
                                        int,
                                        bool,
                                        bool,
                                        int,
                                        int,
                                        int);
|
||||
48
example/ck_tile/50_sparse_attn/jenga_sparse_attention.h
Normal file
48
example/ck_tile/50_sparse_attn/jenga_sparse_attention.h
Normal file
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
#include <optional>
|
||||
#include <cstdint>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
// Jenga sparse-attention forward driver (fp16/bf16 only).
// Sparsity is given as a dense one-hot Q-block x K-block relation map.
// i_perm/o_perm select BHSD (true) vs BSHD (false) layouts; max_seqlen_* of 0
// defaults to the actual sequence lengths. Fills and returns Y.
template <typename DataType_>
ck_tile::HostTensor<DataType_>
jenga_sparse_attention(const ck_tile::HostTensor<DataType_>& TQ,
                       const ck_tile::HostTensor<DataType_>& TK,
                       const ck_tile::HostTensor<DataType_>& TV,
                       const ck_tile::HostTensor<uint8_t>& Tblock_relation_onehot,
                       ck_tile::HostTensor<DataType_>& Y,
                       int batch,
                       int nhead,
                       int nhead_k,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       bool i_perm,
                       bool o_perm,
                       int max_seqlen_q,
                       int max_seqlen_k,
                       int log_level = 0);

// VSA sparse-attention forward driver (fp16/bf16 only).
// Sparsity is given as an int32 LUT of K-block indices per Q-block plus an
// int32 valid-block count per Q-block. Other parameters as above.
template <typename DataType_>
ck_tile::HostTensor<DataType_> vsa_sparse_attention(
    const ck_tile::HostTensor<DataType_>& TQ,
    const ck_tile::HostTensor<DataType_>& TK,
    const ck_tile::HostTensor<DataType_>& TV,
    const ck_tile::HostTensor<int32_t>& TKV_block_idx, // LUT must be int32_t
    const ck_tile::HostTensor<int32_t>& TKV_blocks,    // valid_block_num must be int32_t
    ck_tile::HostTensor<DataType_>& Y,
    int batch,
    int nhead,
    int nhead_k,
    int seqlen_q,
    int seqlen_k,
    int hdim_q,
    int hdim_v,
    bool i_perm,
    bool o_perm,
    int max_seqlen_q,
    int max_seqlen_k,
    int log_level = 0);
|
||||
423
example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp
Normal file
423
example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp
Normal file
@@ -0,0 +1,423 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Test for jenga_sparse_attention function
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <chrono>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
|
||||
#include "jenga_sparse_attention.h"
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions
|
||||
// ============================================================================
|
||||
|
||||
template <typename T>
|
||||
ck_tile::HostTensor<T> make_qkv_tensor(ck_tile::index_t batch,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t seqlen,
|
||||
ck_tile::index_t hdim,
|
||||
bool i_perm)
|
||||
{
|
||||
if(i_perm)
|
||||
{
|
||||
return ck_tile::HostTensor<T>({batch, nhead, seqlen, hdim});
|
||||
}
|
||||
return ck_tile::HostTensor<T>({batch, seqlen, nhead, hdim});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ck_tile::HostTensor<T> to_bhsd(const ck_tile::HostTensor<T>& tensor, bool is_bhsd)
|
||||
{
|
||||
auto lens = tensor.get_lengths();
|
||||
ck_tile::index_t batch = lens[0];
|
||||
ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1];
|
||||
ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2];
|
||||
ck_tile::index_t hdim = lens[3];
|
||||
|
||||
ck_tile::HostTensor<T> out({batch, nhead, seqlen, hdim});
|
||||
for(ck_tile::index_t b = 0; b < batch; ++b)
|
||||
{
|
||||
for(ck_tile::index_t h = 0; h < nhead; ++h)
|
||||
{
|
||||
for(ck_tile::index_t s = 0; s < seqlen; ++s)
|
||||
{
|
||||
for(ck_tile::index_t d = 0; d < hdim; ++d)
|
||||
{
|
||||
out(b, h, s, d) = is_bhsd ? tensor(b, h, s, d) : tensor(b, s, h, d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
// Get error tolerance based on data type
|
||||
template <typename T>
|
||||
auto get_error_tolerance()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 4e-2;
|
||||
if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
|
||||
{
|
||||
// bf16 accumulation/rounding can be noisier in sparse patterns
|
||||
atol = 2e-1;
|
||||
rtol = 2e-1;
|
||||
}
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
// Widen a value to float for host-side comparison; specialized for bf16 below.
template <typename T>
float to_float_for_compare(T value)
{
    const float widened = static_cast<float>(value);
    return widened;
}
|
||||
|
||||
template <>
|
||||
float to_float_for_compare<ck_tile::bf16_t>(ck_tile::bf16_t value)
|
||||
{
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
return static_cast<float>(value);
|
||||
#else
|
||||
return ck_tile::bf16_to_float_raw(ck_tile::bit_cast<ck_tile::bf16_raw_t>(value));
|
||||
#endif
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Command line argument parser
|
||||
// ============================================================================
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "0:no validation, 1:cpu validation")
|
||||
.insert("b", "1", "batch size")
|
||||
.insert("h", "4", "num of head for q")
|
||||
.insert("h_k", "-1", "num of head for k/v, -1 means equal to h")
|
||||
.insert("s", "4096", "seqlen_q")
|
||||
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
|
||||
.insert("d", "128", "head dim for q, k")
|
||||
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
|
||||
.insert("block_size", "128", "block size for sparse attention (BLKQ=BLKK)")
|
||||
.insert("sparsity", "0.5", "sparsity ratio (0.0 = dense, 1.0 = fully sparse)")
|
||||
.insert("prec", "fp16", "data type: fp16/bf16")
|
||||
.insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d")
|
||||
.insert("operm", "1", "permute output")
|
||||
.insert("seed", "42", "random seed")
|
||||
.insert("warmup", "5", "warmup iterations")
|
||||
.insert("repeat", "20", "benchmark iterations")
|
||||
.insert("kname", "0", "print kernel name");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Main Test Function
|
||||
// ============================================================================
|
||||
template <typename T>
|
||||
bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
// Parse arguments
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
ck_tile::index_t batch = arg_parser.get_int("b");
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
|
||||
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
ck_tile::index_t block_size = arg_parser.get_int("block_size");
|
||||
float sparsity = arg_parser.get_float("sparsity");
|
||||
bool i_perm = arg_parser.get_bool("iperm");
|
||||
bool o_perm = arg_parser.get_bool("operm");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
|
||||
// Handle default values
|
||||
if(nhead_k < 0)
|
||||
nhead_k = nhead;
|
||||
if(seqlen_k < 0)
|
||||
seqlen_k = seqlen_q;
|
||||
if(hdim_v < 0)
|
||||
hdim_v = hdim_q;
|
||||
|
||||
ck_tile::index_t BLKQ = block_size;
|
||||
ck_tile::index_t BLKK = block_size;
|
||||
|
||||
if(block_size != 128 || hdim_q != 128 || hdim_v != 128)
|
||||
{
|
||||
std::cout << "\n>>> TEST SKIPPED <<<" << std::endl;
|
||||
std::cout << "Jenga kernel instances are generated for block_size=128 and hdim=128 only."
|
||||
<< std::endl;
|
||||
std::cout << "TEST SKIPPED" << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Calculate number of Q and K blocks
|
||||
ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ;
|
||||
ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK;
|
||||
|
||||
std::cout << "============================================================" << std::endl;
|
||||
std::cout << "[Jenga Sparse Attention Test]" << std::endl;
|
||||
std::cout << "============================================================" << std::endl;
|
||||
std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k
|
||||
<< std::endl;
|
||||
std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl;
|
||||
std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl;
|
||||
std::cout << " block_size: " << block_size << " (BLKQ=" << BLKQ << ", BLKK=" << BLKK << ")"
|
||||
<< std::endl;
|
||||
std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks
|
||||
<< std::endl;
|
||||
std::cout << " sparsity: " << sparsity << std::endl;
|
||||
std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl;
|
||||
|
||||
// Create host tensors (using BHSD layout when i_perm=true)
|
||||
ck_tile::HostTensor<T> q_host = make_qkv_tensor<T>(batch, nhead, seqlen_q, hdim_q, i_perm);
|
||||
ck_tile::HostTensor<T> k_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
|
||||
ck_tile::HostTensor<T> v_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
|
||||
ck_tile::HostTensor<T> output_host =
|
||||
o_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen_q, hdim_v})
|
||||
: ck_tile::HostTensor<T>({batch, seqlen_q, nhead, hdim_v});
|
||||
ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});
|
||||
|
||||
// Block relation onehot: [B, H, Q_blocks, K_blocks]
|
||||
ck_tile::HostTensor<uint8_t> block_relation_onehot({batch, nhead, num_q_blocks, num_k_blocks});
|
||||
|
||||
// Initialize tensors with random values
|
||||
std::cout << "\nInitializing tensors..." << std::endl;
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 1}(k_host);
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 2}(v_host);
|
||||
|
||||
// Initialize block_relation_onehot with sparse pattern
|
||||
std::mt19937 rng(seed + 100);
|
||||
std::uniform_real_distribution<float> dist(0.0f, 1.0f);
|
||||
ck_tile::index_t total_blocks = 0;
|
||||
ck_tile::index_t active_blocks = 0;
|
||||
|
||||
for(ck_tile::index_t b = 0; b < batch; ++b)
|
||||
{
|
||||
for(ck_tile::index_t h = 0; h < nhead; ++h)
|
||||
{
|
||||
for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
|
||||
{
|
||||
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
|
||||
{
|
||||
total_blocks++;
|
||||
bool is_diagonal = (qb == kb && qb < num_k_blocks);
|
||||
bool random_active = (dist(rng) > sparsity);
|
||||
|
||||
if(is_diagonal || random_active)
|
||||
{
|
||||
block_relation_onehot(b, h, qb, kb) = static_cast<uint8_t>(1);
|
||||
active_blocks++;
|
||||
}
|
||||
else
|
||||
{
|
||||
block_relation_onehot(b, h, qb, kb) = static_cast<uint8_t>(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float actual_sparsity =
|
||||
1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
|
||||
std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/"
|
||||
<< total_blocks << " blocks active)" << std::endl;
|
||||
|
||||
// Run kernel
|
||||
std::cout << "\n--- Running Jenga sparse attention kernel ---" << std::endl;
|
||||
|
||||
try
|
||||
{
|
||||
if(kname)
|
||||
{
|
||||
jenga_sparse_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
block_relation_onehot,
|
||||
output_host,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
1);
|
||||
}
|
||||
|
||||
// Warmup
|
||||
for(int i = 0; i < warmup; ++i)
|
||||
{
|
||||
jenga_sparse_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
block_relation_onehot,
|
||||
output_host,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
0);
|
||||
}
|
||||
|
||||
// Benchmark
|
||||
[[maybe_unused]] auto sync_status1 = hipDeviceSynchronize();
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
for(int i = 0; i < repeat; ++i)
|
||||
{
|
||||
jenga_sparse_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
block_relation_onehot,
|
||||
output_host,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
0);
|
||||
}
|
||||
|
||||
[[maybe_unused]] auto sync_status2 = hipDeviceSynchronize();
|
||||
auto end = std::chrono::high_resolution_clock::now();
|
||||
double avg_time_ms =
|
||||
std::chrono::duration<double, std::milli>(end - start).count() / repeat;
|
||||
|
||||
std::cout << "\n>>>> Jenga sparse attention average time: " << avg_time_ms << " ms <<<<"
|
||||
<< std::endl;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Error during kernel execution: " << e.what() << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Validation
|
||||
bool pass = true;
|
||||
if(do_validation)
|
||||
{
|
||||
std::cout << "\n--- Performing CPU validation ---" << std::endl;
|
||||
|
||||
float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q));
|
||||
|
||||
std::cout << "Computing reference output..." << std::endl;
|
||||
auto q_ref = to_bhsd(q_host, i_perm);
|
||||
auto k_ref = to_bhsd(k_host, i_perm);
|
||||
auto v_ref = to_bhsd(v_host, i_perm);
|
||||
ck_tile::reference_blocked_attention<T, uint8_t>(
|
||||
q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale);
|
||||
|
||||
// Compare results
|
||||
auto [rtol, atol] = get_error_tolerance<T>();
|
||||
|
||||
float max_diff = 0.0f;
|
||||
float max_rel_diff = 0.0f;
|
||||
size_t num_errors = 0;
|
||||
|
||||
auto output_host_bhsd = to_bhsd(output_host, o_perm);
|
||||
for(size_t i = 0; i < output_host_bhsd.mData.size(); ++i)
|
||||
{
|
||||
float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]);
|
||||
float ref_val = to_float_for_compare(output_ref.mData[i]);
|
||||
float diff = std::abs(gpu_val - ref_val);
|
||||
float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff;
|
||||
|
||||
max_diff = std::max(max_diff, diff);
|
||||
max_rel_diff = std::max(max_rel_diff, rel_diff);
|
||||
|
||||
if(diff > atol && rel_diff > rtol)
|
||||
{
|
||||
num_errors++;
|
||||
if(num_errors <= 5)
|
||||
{
|
||||
std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val
|
||||
<< ", Ref=" << ref_val << ", Diff=" << diff << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "\nValidation results:" << std::endl;
|
||||
std::cout << " Max absolute difference: " << max_diff << std::endl;
|
||||
std::cout << " Max relative difference: " << max_rel_diff << std::endl;
|
||||
std::cout << " Number of mismatches: " << num_errors << " / "
|
||||
<< output_host_bhsd.mData.size() << std::endl;
|
||||
|
||||
if(num_errors == 0)
|
||||
{
|
||||
std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl;
|
||||
pass = false;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl;
|
||||
return pass;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Main
|
||||
// ============================================================================
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
std::cerr << "Failed to parse arguments" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::string prec = arg_parser.get_str("prec");
|
||||
|
||||
bool test_result = false;
|
||||
if(prec == "fp16")
|
||||
{
|
||||
test_result = run_test<ck_tile::half_t>(arg_parser);
|
||||
}
|
||||
else if(prec == "bf16")
|
||||
{
|
||||
test_result = run_test<ck_tile::bf16_t>(arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported precision: " << prec << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
return test_result ? 0 : -1;
|
||||
}
|
||||
486
example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp
Normal file
486
example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp
Normal file
@@ -0,0 +1,486 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Test for vsa_sparse_attention function
|
||||
// Based on the Python test: test_jenga_attention.py
|
||||
|
||||
#include <algorithm>
#include <chrono>
#include <cmath>
#include <iostream>
#include <numeric>
#include <random>
#include <string>
#include <vector>
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
|
||||
#include "jenga_sparse_attention.h"
|
||||
#include "fmha_fwd_trek.hpp"
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions
|
||||
// ============================================================================
|
||||
|
||||
template <typename T>
|
||||
ck_tile::HostTensor<T> make_qkv_tensor(ck_tile::index_t batch,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t seqlen,
|
||||
ck_tile::index_t hdim,
|
||||
bool i_perm)
|
||||
{
|
||||
if(i_perm)
|
||||
{
|
||||
return ck_tile::HostTensor<T>({batch, nhead, seqlen, hdim});
|
||||
}
|
||||
return ck_tile::HostTensor<T>({batch, seqlen, nhead, hdim});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ck_tile::HostTensor<T> to_bhsd(const ck_tile::HostTensor<T>& tensor, bool is_bhsd)
|
||||
{
|
||||
auto lens = tensor.get_lengths();
|
||||
ck_tile::index_t batch = lens[0];
|
||||
ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1];
|
||||
ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2];
|
||||
ck_tile::index_t hdim = lens[3];
|
||||
|
||||
ck_tile::HostTensor<T> out({batch, nhead, seqlen, hdim});
|
||||
for(ck_tile::index_t b = 0; b < batch; ++b)
|
||||
{
|
||||
for(ck_tile::index_t h = 0; h < nhead; ++h)
|
||||
{
|
||||
for(ck_tile::index_t s = 0; s < seqlen; ++s)
|
||||
{
|
||||
for(ck_tile::index_t d = 0; d < hdim; ++d)
|
||||
{
|
||||
out(b, h, s, d) = is_bhsd ? tensor(b, h, s, d) : tensor(b, s, h, d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
// Convert block_relation_onehot to LUT format (similar to triton_block_map_to_lut_kernel)
|
||||
template <typename T>
|
||||
void block_map_to_lut(
|
||||
const ck_tile::HostTensor<T>& block_map, // [B, H, Q_blocks, K_blocks]
|
||||
ck_tile::HostTensor<int32_t>& lut, // [B, H, Q_blocks, K_blocks] - int32_t for kernel
|
||||
ck_tile::HostTensor<int32_t>& valid_block_num, // [B, H, Q_blocks] - int32_t for kernel
|
||||
ck_tile::index_t num_block_k)
|
||||
{
|
||||
auto lengths = block_map.get_lengths();
|
||||
ck_tile::index_t B = lengths[0];
|
||||
ck_tile::index_t H = lengths[1];
|
||||
ck_tile::index_t Q = lengths[2];
|
||||
|
||||
for(ck_tile::index_t b = 0; b < B; ++b)
|
||||
{
|
||||
for(ck_tile::index_t h = 0; h < H; ++h)
|
||||
{
|
||||
for(ck_tile::index_t q = 0; q < Q; ++q)
|
||||
{
|
||||
int32_t valid_count = 0;
|
||||
int32_t prev_block = 0;
|
||||
|
||||
for(ck_tile::index_t k = 0; k < num_block_k; ++k)
|
||||
{
|
||||
T cur_block = block_map(b, h, q, k);
|
||||
if(static_cast<float>(cur_block) > 0.5f)
|
||||
{ // Check if block is active
|
||||
lut(b, h, q, valid_count) = static_cast<int32_t>(k - prev_block);
|
||||
valid_count++;
|
||||
prev_block = static_cast<int32_t>(k);
|
||||
}
|
||||
}
|
||||
valid_block_num(b, h, q) = valid_count;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get error tolerance based on data type
|
||||
template <typename T>
|
||||
auto get_error_tolerance()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 4e-2;
|
||||
if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
|
||||
{
|
||||
// bf16 accumulation/rounding can be noisier in sparse patterns
|
||||
atol = 2e-1;
|
||||
rtol = 2e-1;
|
||||
}
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
// Widen a value to float for host-side comparison; specialized for bf16 below.
template <typename T>
float to_float_for_compare(T value)
{
    const float widened = static_cast<float>(value);
    return widened;
}
|
||||
|
||||
template <>
|
||||
float to_float_for_compare<ck_tile::bf16_t>(ck_tile::bf16_t value)
|
||||
{
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
return static_cast<float>(value);
|
||||
#else
|
||||
return ck_tile::bf16_to_float_raw(ck_tile::bit_cast<ck_tile::bf16_raw_t>(value));
|
||||
#endif
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Command line argument parser
|
||||
// ============================================================================
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "0:no validation, 1:cpu validation")
|
||||
.insert("b", "1", "batch size")
|
||||
.insert("h", "4", "num of head for q")
|
||||
.insert("h_k", "-1", "num of head for k/v, -1 means equal to h")
|
||||
.insert("s", "4096", "seqlen_q")
|
||||
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
|
||||
.insert("d", "128", "head dim for q, k")
|
||||
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
|
||||
.insert("block_size", "128", "block size for sparse attention (BLKQ=BLKK)")
|
||||
.insert("sparsity", "0.5", "sparsity ratio (0.0 = dense, 1.0 = fully sparse)")
|
||||
.insert("prec", "fp16", "data type: fp16/bf16")
|
||||
.insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d")
|
||||
.insert("operm", "1", "permute output")
|
||||
.insert("seed", "42", "random seed")
|
||||
.insert("warmup", "5", "warmup iterations")
|
||||
.insert("repeat", "20", "benchmark iterations")
|
||||
.insert("kname", "0", "print kernel name");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Main Test Function
|
||||
// ============================================================================
|
||||
template <typename T>
|
||||
bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
// Parse arguments
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
ck_tile::index_t batch = arg_parser.get_int("b");
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
|
||||
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
ck_tile::index_t block_size = arg_parser.get_int("block_size");
|
||||
float sparsity = arg_parser.get_float("sparsity");
|
||||
bool i_perm = arg_parser.get_bool("iperm");
|
||||
bool o_perm = arg_parser.get_bool("operm");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
|
||||
// Handle default values
|
||||
if(nhead_k < 0)
|
||||
nhead_k = nhead;
|
||||
if(seqlen_k < 0)
|
||||
seqlen_k = seqlen_q;
|
||||
if(hdim_v < 0)
|
||||
hdim_v = hdim_q;
|
||||
|
||||
ck_tile::index_t BLKQ = block_size;
|
||||
ck_tile::index_t BLKK = block_size;
|
||||
|
||||
if(block_size != 128 || hdim_q != 128 || hdim_v != 128)
|
||||
{
|
||||
std::cout << "\n>>> TEST SKIPPED <<<" << std::endl;
|
||||
std::cout << "VSA kernel instances are generated for block_size=128 and hdim=128 only."
|
||||
<< std::endl;
|
||||
std::cout << "TEST SKIPPED" << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Calculate number of Q and K blocks
|
||||
ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ;
|
||||
ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK;
|
||||
|
||||
std::cout << "============================================================" << std::endl;
|
||||
std::cout << "[VSA Sparse Attention Test]" << std::endl;
|
||||
std::cout << "============================================================" << std::endl;
|
||||
std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k
|
||||
<< std::endl;
|
||||
std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl;
|
||||
std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl;
|
||||
std::cout << " block_size: " << block_size << " (BLKQ=" << BLKQ << ", BLKK=" << BLKK << ")"
|
||||
<< std::endl;
|
||||
std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks
|
||||
<< std::endl;
|
||||
std::cout << " sparsity: " << sparsity << std::endl;
|
||||
std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl;
|
||||
|
||||
// Create host tensors (using BHSD layout when i_perm=true)
|
||||
// Q: [B, H, S_q, D]
|
||||
// K: [B, H_k, S_k, D]
|
||||
// V: [B, H_k, S_k, D_v]
|
||||
ck_tile::HostTensor<T> q_host = make_qkv_tensor<T>(batch, nhead, seqlen_q, hdim_q, i_perm);
|
||||
ck_tile::HostTensor<T> k_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
|
||||
ck_tile::HostTensor<T> v_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
|
||||
ck_tile::HostTensor<T> output_host =
|
||||
o_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen_q, hdim_v})
|
||||
: ck_tile::HostTensor<T>({batch, seqlen_q, nhead, hdim_v});
|
||||
ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});
|
||||
|
||||
// Block relation onehot: [B, H, Q_blocks, K_blocks]
|
||||
ck_tile::HostTensor<uint8_t> block_relation_onehot({batch, nhead, num_q_blocks, num_k_blocks});
|
||||
|
||||
// LUT and valid_block_num (output of block_map_to_lut) - must be int32_t for kernel
|
||||
ck_tile::HostTensor<int32_t> lut_host({batch, nhead, num_q_blocks, num_k_blocks});
|
||||
ck_tile::HostTensor<int32_t> valid_block_num_host({batch, nhead, num_q_blocks});
|
||||
|
||||
// Initialize tensors with random values
|
||||
std::cout << "\nInitializing tensors..." << std::endl;
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 1}(k_host);
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 2}(v_host);
|
||||
|
||||
// Initialize block_relation_onehot with sparse pattern
|
||||
std::mt19937 rng(seed + 100);
|
||||
std::uniform_real_distribution<float> dist(0.0f, 1.0f);
|
||||
ck_tile::index_t total_blocks = 0;
|
||||
ck_tile::index_t active_blocks = 0;
|
||||
|
||||
for(ck_tile::index_t b = 0; b < batch; ++b)
|
||||
{
|
||||
for(ck_tile::index_t h = 0; h < nhead; ++h)
|
||||
{
|
||||
for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
|
||||
{
|
||||
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
|
||||
{
|
||||
total_blocks++;
|
||||
// Each Q block always attends to its diagonal K block (if exists)
|
||||
// Plus random blocks based on sparsity
|
||||
bool is_diagonal = (qb == kb && qb < num_k_blocks);
|
||||
bool random_active = (dist(rng) > sparsity);
|
||||
|
||||
if(is_diagonal || random_active)
|
||||
{
|
||||
block_relation_onehot(b, h, qb, kb) = static_cast<uint8_t>(1);
|
||||
active_blocks++;
|
||||
}
|
||||
else
|
||||
{
|
||||
block_relation_onehot(b, h, qb, kb) = static_cast<uint8_t>(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float actual_sparsity =
|
||||
1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
|
||||
std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/"
|
||||
<< total_blocks << " blocks active)" << std::endl;
|
||||
|
||||
// Convert block_relation_onehot to LUT format
|
||||
std::cout << "Converting block map to LUT format..." << std::endl;
|
||||
block_map_to_lut(block_relation_onehot, lut_host, valid_block_num_host, num_k_blocks);
|
||||
|
||||
// vsa_sparse_attention handles device memory internally
|
||||
|
||||
// Run kernel
|
||||
std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl;
|
||||
|
||||
try
|
||||
{
|
||||
// Print kernel name once by invoking with log_level=1.
|
||||
// This is separate from warmup/benchmark to avoid polluting timing.
|
||||
if(kname)
|
||||
{
|
||||
vsa_sparse_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
lut_host,
|
||||
valid_block_num_host,
|
||||
output_host,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
1);
|
||||
}
|
||||
|
||||
// Warmup
|
||||
for(int i = 0; i < warmup; ++i)
|
||||
{
|
||||
vsa_sparse_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
lut_host,
|
||||
valid_block_num_host,
|
||||
output_host,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
0);
|
||||
}
|
||||
|
||||
// Benchmark
|
||||
[[maybe_unused]] auto sync_status1 = hipDeviceSynchronize();
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
for(int i = 0; i < repeat; ++i)
|
||||
{
|
||||
vsa_sparse_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
lut_host,
|
||||
valid_block_num_host,
|
||||
output_host,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
0);
|
||||
}
|
||||
|
||||
[[maybe_unused]] auto sync_status2 = hipDeviceSynchronize();
|
||||
auto end = std::chrono::high_resolution_clock::now();
|
||||
double avg_time_ms =
|
||||
std::chrono::duration<double, std::milli>(end - start).count() / repeat;
|
||||
|
||||
std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<"
|
||||
<< std::endl;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Error during kernel execution: " << e.what() << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Note: vsa_sparse_attention already returns output in output_host
|
||||
|
||||
// Validation
|
||||
bool pass = true;
|
||||
if(do_validation)
|
||||
{
|
||||
std::cout << "\n--- Performing CPU validation ---" << std::endl;
|
||||
|
||||
// Compute scale factor
|
||||
float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q));
|
||||
|
||||
// Run reference implementation
|
||||
std::cout << "Computing reference output..." << std::endl;
|
||||
auto q_ref = to_bhsd(q_host, i_perm);
|
||||
auto k_ref = to_bhsd(k_host, i_perm);
|
||||
auto v_ref = to_bhsd(v_host, i_perm);
|
||||
ck_tile::reference_blocked_attention<T, uint8_t>(
|
||||
q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale);
|
||||
|
||||
// Compare results
|
||||
auto [rtol, atol] = get_error_tolerance<T>();
|
||||
|
||||
float max_diff = 0.0f;
|
||||
float max_rel_diff = 0.0f;
|
||||
size_t num_errors = 0;
|
||||
|
||||
auto output_host_bhsd = to_bhsd(output_host, o_perm);
|
||||
for(size_t i = 0; i < output_host_bhsd.mData.size(); ++i)
|
||||
{
|
||||
float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]);
|
||||
float ref_val = to_float_for_compare(output_ref.mData[i]);
|
||||
float diff = std::abs(gpu_val - ref_val);
|
||||
float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff;
|
||||
|
||||
max_diff = std::max(max_diff, diff);
|
||||
max_rel_diff = std::max(max_rel_diff, rel_diff);
|
||||
|
||||
if(diff > atol && rel_diff > rtol)
|
||||
{
|
||||
num_errors++;
|
||||
if(num_errors <= 5)
|
||||
{
|
||||
std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val
|
||||
<< ", Ref=" << ref_val << ", Diff=" << diff << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "\nValidation results:" << std::endl;
|
||||
std::cout << " Max absolute difference: " << max_diff << std::endl;
|
||||
std::cout << " Max relative difference: " << max_rel_diff << std::endl;
|
||||
std::cout << " Number of mismatches: " << num_errors << " / "
|
||||
<< output_host_bhsd.mData.size() << std::endl;
|
||||
|
||||
if(num_errors == 0)
|
||||
{
|
||||
std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl;
|
||||
pass = false;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl;
|
||||
return pass;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Main
|
||||
// ============================================================================
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
std::cerr << "Failed to parse arguments" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::string prec = arg_parser.get_str("prec");
|
||||
|
||||
bool test_result = false;
|
||||
if(prec == "fp16")
|
||||
{
|
||||
test_result = run_test<ck_tile::half_t>(arg_parser);
|
||||
}
|
||||
else if(prec == "bf16")
|
||||
{
|
||||
test_result = run_test<ck_tile::bf16_t>(arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported precision: " << prec << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
return test_result ? 0 : -1;
|
||||
}
|
||||
205
example/ck_tile/50_sparse_attn/vsa_sparse_attention.cpp
Normal file
205
example/ck_tile/50_sparse_attn/vsa_sparse_attention.cpp
Normal file
@@ -0,0 +1,205 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#include "jenga_sparse_attention.h"
|
||||
#include "fmha_fwd_trek.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
// Run the VSA sparse-attention forward pass on the device and copy the result
// back into the host tensor Y (which is also returned).
//
// Parameters:
//   TQ, TK, TV        - host tensors for query / key / value inputs.
//   TKV_block_idx     - per-query-block lookup table of selected KV block
//                       indices (int32). NOTE(review): exact layout is defined
//                       by the kernel; not visible here — confirm against
//                       fmha_vsa_fwd.
//   TKV_blocks        - number of valid KV blocks per entry (int32).
//   Y                 - host output tensor; overwritten with the result.
//   batch, nhead, nhead_k, seqlen_q, seqlen_k, hdim_q, hdim_v
//                     - problem dimensions (nhead must be divisible by nhead_k).
//   i_perm, o_perm    - select between [b, h, s, d] (true) and [b, s, h, d]
//                       (false) stride layouts for inputs/output.
//   max_seqlen_q/k    - upper bounds used by the kernel; 0 means "use the
//                       actual sequence length".
//   log_level         - verbosity passed through to the stream config.
template <typename DataType_>
ck_tile::HostTensor<DataType_>
vsa_sparse_attention(const ck_tile::HostTensor<DataType_>& TQ,
                     const ck_tile::HostTensor<DataType_>& TK,
                     const ck_tile::HostTensor<DataType_>& TV,
                     const ck_tile::HostTensor<int32_t>& TKV_block_idx,
                     const ck_tile::HostTensor<int32_t>& TKV_blocks,
                     ck_tile::HostTensor<DataType_>& Y,
                     int batch,
                     int nhead,
                     int nhead_k,
                     int seqlen_q,
                     int seqlen_k,
                     int hdim_q,
                     int hdim_v,
                     bool i_perm,
                     bool o_perm,
                     int max_seqlen_q,
                     int max_seqlen_k,
                     int log_level)
{
    // Only half-precision types are wired up in the dispatcher below.
    static_assert(std::is_same_v<DataType_, ck_tile::half_t> ||
                      std::is_same_v<DataType_, ck_tile::bf16_t>,
                  "VSA sparse attention supports fp16/bf16 only.");
    // Determine data type string based on template parameter
    std::string data_type = "fp16";
    if constexpr(std::is_same_v<DataType_, ck_tile::bf16_t>)
    {
        data_type = "bf16";
    }

    // 0 is a sentinel meaning "no explicit maximum"; fall back to the
    // actual sequence lengths.
    if(max_seqlen_q == 0)
        max_seqlen_q = seqlen_q;
    if(max_seqlen_k == 0)
        max_seqlen_k = seqlen_k;
    bool is_v_rowmajor = true;
    // Standard attention softmax scale 1/sqrt(head_dim).
    float scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q));
    // "0" selects the no-mask variant; sparsity comes from the block LUT,
    // not from a causal/window mask.
    std::string msk_str = "0";
    mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k);

    const ck_tile::index_t shape_seqlen_q = seqlen_q;
    const ck_tile::index_t shape_seqlen_k = seqlen_k;

    ck_tile::stream_config stream_config{nullptr,
                                         false, // time_kernel
                                         log_level,
                                         0,
                                         1,
                                         false};

    // Create device memory and copy data to device
    ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes());
    ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes());
    ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes());
    ck_tile::DeviceMem lut_buf(TKV_block_idx.get_element_space_size_in_bytes());
    ck_tile::DeviceMem valid_block_num_buf(TKV_blocks.get_element_space_size_in_bytes());
    ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes());

    q_buf.ToDevice(TQ.data());
    k_buf.ToDevice(TK.data());
    v_buf.ToDevice(TV.data());
    lut_buf.ToDevice(TKV_block_idx.data());
    valid_block_num_buf.ToDevice(TKV_blocks.data());

    // Fill in the kernel argument struct: device pointers, problem sizes,
    // and the stride layout implied by i_perm / o_perm.
    const auto init_args = [&](auto& args) {
        assert(nhead % nhead_k == 0);
        // Innermost (row) strides depend on whether heads are the
        // second-fastest dimension (i_perm) or interleaved with seqlen.
        const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
        const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
        const ck_tile::index_t stride_v = [&]() {
            if(is_v_rowmajor)
                return i_perm ? hdim_v : nhead_k * hdim_v;
            else
                return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
        }();
        const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
        // setup nhead_stride_* arguments
        const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
        const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q;
        const ck_tile::index_t nhead_stride_v = [&]() {
            if(is_v_rowmajor)
                return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
            else
                return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
        }();
        const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
        // setup batch_stride_* arguments
        const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
        const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q;
        const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k;
        const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);

        // Use device buffer pointers instead of host tensor data pointers
        args.q_ptr = q_buf.GetDeviceBuffer();
        args.k_ptr = k_buf.GetDeviceBuffer();
        args.v_ptr = v_buf.GetDeviceBuffer();
        args.lut_ptr = lut_buf.GetDeviceBuffer();
        args.valid_block_num_ptr = valid_block_num_buf.GetDeviceBuffer();

        args.batch = batch;
        args.seqlen_q = shape_seqlen_q; // batch mode only
        args.hdim_q = hdim_q;
        args.hdim_v = hdim_v;
        args.nhead_q = nhead;
        args.nhead_k = nhead_k;

        args.stride_q = stride_q;
        args.stride_k = stride_k;
        args.stride_v = stride_v;
        args.nhead_stride_q = nhead_stride_q;
        args.nhead_stride_k = nhead_stride_k;
        args.nhead_stride_v = nhead_stride_v;
        args.batch_stride_q = batch_stride_q;
        args.batch_stride_k = batch_stride_k;
        args.batch_stride_v = batch_stride_v;

        args.o_ptr = o_buf.GetDeviceBuffer();

        args.seqlen_k = shape_seqlen_k; // batch mode only
        args.max_seqlen_q = max_seqlen_q;

        args.scale_s = scale_s;

        args.stride_o = stride_o;
        args.nhead_stride_o = nhead_stride_o;
        args.batch_stride_o = batch_stride_o;

        args.window_size_left = mask.left;
        args.window_size_right = mask.right;
        args.mask_type = static_cast<ck_tile::index_t>(mask.type);

        // Dropout not supported for sparse attention.
    };

    // Traits select the kernel instance (head dims, data type, V layout).
    const auto init_traits = [&](auto& traits) {
        traits.hdim_q = hdim_q;
        traits.hdim_v = hdim_v;
        traits.data_type = data_type;
        traits.is_v_rowmajor = is_v_rowmajor;

        traits.mask_type = mask.type;
    };

    fmha_vsa_fwd_traits fmha_traits;
    init_traits(fmha_traits);

    fmha_vsa_fwd_args args;
    init_args(args);

    // Launch the forward kernel (synchronization handled by the stream
    // config / FromDevice below — TODO confirm fmha_vsa_fwd blocks).
    fmha_vsa_fwd(fmha_traits, args, stream_config);

    // Copy output back to host without changing tensor shape
    o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes());

    return Y;
}
|
||||
|
||||
// Explicit template instantiations: the template definition lives in this
// translation unit, so the two supported data types (matching the fp16/bf16
// dispatch in main) must be instantiated here for external callers to link.

// fp16 instantiation
template ck_tile::HostTensor<ck_tile::half_t>
vsa_sparse_attention<ck_tile::half_t>(const ck_tile::HostTensor<ck_tile::half_t>&,
                                      const ck_tile::HostTensor<ck_tile::half_t>&,
                                      const ck_tile::HostTensor<ck_tile::half_t>&,
                                      const ck_tile::HostTensor<int32_t>&,
                                      const ck_tile::HostTensor<int32_t>&,
                                      ck_tile::HostTensor<ck_tile::half_t>&,
                                      int,
                                      int,
                                      int,
                                      int,
                                      int,
                                      int,
                                      int,
                                      bool,
                                      bool,
                                      int,
                                      int,
                                      int);

// bf16 instantiation
template ck_tile::HostTensor<ck_tile::bf16_t>
vsa_sparse_attention<ck_tile::bf16_t>(const ck_tile::HostTensor<ck_tile::bf16_t>&,
                                      const ck_tile::HostTensor<ck_tile::bf16_t>&,
                                      const ck_tile::HostTensor<ck_tile::bf16_t>&,
                                      const ck_tile::HostTensor<int32_t>&,
                                      const ck_tile::HostTensor<int32_t>&,
                                      ck_tile::HostTensor<ck_tile::bf16_t>&,
                                      int,
                                      int,
                                      int,
                                      int,
                                      int,
                                      int,
                                      int,
                                      bool,
                                      bool,
                                      int,
                                      int,
                                      int);
|
||||
@@ -30,4 +30,5 @@ add_subdirectory(36_pooling)
|
||||
add_subdirectory(38_block_scale_gemm)
|
||||
add_subdirectory(40_streamk_gemm)
|
||||
add_subdirectory(41_batched_contraction)
|
||||
add_subdirectory(50_sparse_attn)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user