mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Support 64x128 tile size in sparge fwd for Jenga and VSA paths
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# CMakeLists.txt for sparse attention (Jenga and VSA)
|
||||
#Copyright(c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
#SPDX - License - Identifier : MIT
|
||||
#CMakeLists.txt for sparse attention(Jenga and VSA)
|
||||
|
||||
# Use SUPPORTED_GPU_TARGETS directly
|
||||
#Use SUPPORTED_GPU_TARGETS directly
|
||||
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
set(GPU_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
|
||||
@@ -16,7 +16,7 @@ endif()
|
||||
|
||||
message(STATUS "Building Sparse Attention (Jenga & VSA) for targets: ${INST_TARGETS}")
|
||||
|
||||
# Code generation scripts
|
||||
#Code generation scripts
|
||||
file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
${CMAKE_CURRENT_LIST_DIR}/codegen/*.py
|
||||
@@ -88,11 +88,62 @@ target_compile_options(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE
|
||||
-Wno-float-equal
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
# Sparge Jenga (64x128 tile)
|
||||
# ============================================================================
|
||||
set(SPARGE_JENGA_CODE_GEN_ARGS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api sparge_fwd_jenga
|
||||
--receipt 600
|
||||
)
|
||||
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${SPARGE_JENGA_CODE_GEN_ARGS}
|
||||
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sparge_jenga_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message(FATAL_ERROR "Failed to generate Sparge Jenga kernel list")
|
||||
endif()
|
||||
|
||||
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sparge_jenga_blob_list.txt SPARGE_JENGA_GEN_BLOBS)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${SPARGE_JENGA_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${SPARGE_JENGA_CODE_GEN_ARGS}
|
||||
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
DEPENDS ${CODE_GEN_SCRIPTS}
|
||||
COMMENT "Generate CK Tile Sparge Jenga kernels"
|
||||
)
|
||||
|
||||
message(STATUS "Sparge Jenga kernel files to be generated: ${SPARGE_JENGA_GEN_BLOBS}")
|
||||
|
||||
set(SPARGE_JENGA_INSTANCES "tile_sparge_jenga_instances")
|
||||
|
||||
add_library(${SPARGE_JENGA_INSTANCES} OBJECT EXCLUDE_FROM_ALL
|
||||
${SPARGE_JENGA_GEN_BLOBS}
|
||||
${CMAKE_CURRENT_LIST_DIR}/jenga_sparge_attention.cpp
|
||||
)
|
||||
target_include_directories(${SPARGE_JENGA_INSTANCES} PRIVATE
|
||||
${CMAKE_CURRENT_LIST_DIR}
|
||||
${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn
|
||||
)
|
||||
set_source_files_properties(${SPARGE_JENGA_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
|
||||
set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/jenga_sparge_attention.cpp PROPERTIES LANGUAGE HIP)
|
||||
set_property(TARGET ${SPARGE_JENGA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
|
||||
|
||||
target_compile_options(${SPARGE_JENGA_INSTANCES} PRIVATE
|
||||
-DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
-DCK_TILE_FMHA_FWD_FAST_EXP2
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
)
|
||||
|
||||
# Sparge + Jenga Example executable
|
||||
set(EXAMPLE_SPARGE_JENGA_SPARSE_ATTN "tile_example_sparge_jenga_sparse_attn")
|
||||
message(DEBUG "adding example ${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN}")
|
||||
add_executable(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_jenga_sparse_attn.cpp)
|
||||
target_link_libraries(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} ${SPARSE_ATTN_JENGA_INSTANCES})
|
||||
target_link_libraries(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} ${SPARGE_JENGA_INSTANCES})
|
||||
target_include_directories(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_compile_options(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE
|
||||
-Wno-undefined-func-template
|
||||
@@ -164,11 +215,62 @@ target_compile_options(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE
|
||||
-Wno-float-equal
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
# Sparge VSA (64x128 tile)
|
||||
# ============================================================================
|
||||
set(SPARGE_VSA_CODE_GEN_ARGS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api sparge_fwd_vsa
|
||||
--receipt 600
|
||||
)
|
||||
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${SPARGE_VSA_CODE_GEN_ARGS}
|
||||
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sparge_vsa_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message(FATAL_ERROR "Failed to generate Sparge VSA kernel list")
|
||||
endif()
|
||||
|
||||
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sparge_vsa_blob_list.txt SPARGE_VSA_GEN_BLOBS)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${SPARGE_VSA_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${SPARGE_VSA_CODE_GEN_ARGS}
|
||||
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
DEPENDS ${CODE_GEN_SCRIPTS}
|
||||
COMMENT "Generate CK Tile Sparge VSA kernels"
|
||||
)
|
||||
|
||||
message(STATUS "Sparge VSA kernel files to be generated: ${SPARGE_VSA_GEN_BLOBS}")
|
||||
|
||||
set(SPARGE_VSA_INSTANCES "tile_sparge_vsa_instances")
|
||||
|
||||
add_library(${SPARGE_VSA_INSTANCES} OBJECT EXCLUDE_FROM_ALL
|
||||
${SPARGE_VSA_GEN_BLOBS}
|
||||
${CMAKE_CURRENT_LIST_DIR}/vsa_sparge_attention.cpp
|
||||
)
|
||||
target_include_directories(${SPARGE_VSA_INSTANCES} PRIVATE
|
||||
${CMAKE_CURRENT_LIST_DIR}
|
||||
${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn
|
||||
)
|
||||
set_source_files_properties(${SPARGE_VSA_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
|
||||
set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/vsa_sparge_attention.cpp PROPERTIES LANGUAGE HIP)
|
||||
set_property(TARGET ${SPARGE_VSA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
|
||||
|
||||
target_compile_options(${SPARGE_VSA_INSTANCES} PRIVATE
|
||||
-DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
-DCK_TILE_FMHA_FWD_FAST_EXP2
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
)
|
||||
|
||||
# Sparge + VSA Example executable
|
||||
set(EXAMPLE_SPARGE_VSA_SPARSE_ATTN "tile_example_sparge_vsa_sparse_attn")
|
||||
message(DEBUG "adding example ${EXAMPLE_SPARGE_VSA_SPARSE_ATTN}")
|
||||
add_executable(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_vsa_sparse_attn.cpp)
|
||||
target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} ${SPARSE_ATTN_VSA_INSTANCES})
|
||||
target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} ${SPARGE_VSA_INSTANCES})
|
||||
target_include_directories(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_compile_options(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE
|
||||
-Wno-undefined-func-template
|
||||
|
||||
799
example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_jenga.py
Normal file
799
example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_jenga.py
Normal file
@@ -0,0 +1,799 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass, field
|
||||
import fnmatch
|
||||
import itertools
|
||||
import os
|
||||
import os.path as path
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from codegen.cpp_symbol_map import (
|
||||
BOOL_MAP,
|
||||
FWD_DTYPE_MAP,
|
||||
LAYOUT_MAP,
|
||||
MODE_MAP,
|
||||
PIPELINE_ENUM_MAP,
|
||||
PIPELINE_MAP,
|
||||
get_mask_check_map,
|
||||
get_mask_map,
|
||||
)
|
||||
|
||||
GEN_DIR = ""
|
||||
|
||||
|
||||
def update_file(file_path, content):
|
||||
"""Update the file at file_path with the given content if it differs from the existing content.
|
||||
|
||||
It avoids unnecessary touching of the file which triggers rebuilds
|
||||
"""
|
||||
|
||||
existing_content = ""
|
||||
if path.exists(file_path):
|
||||
with open(file_path, "r") as file:
|
||||
existing_content = file.read()
|
||||
if existing_content == content:
|
||||
return
|
||||
with open(file_path, "w") as file:
|
||||
file.write(content)
|
||||
|
||||
|
||||
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256}
|
||||
|
||||
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n
|
||||
// auto generated by generate.py
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "fmha_fwd_trek.hpp"
|
||||
#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp"
|
||||
#include "kernel/fmha_fwd_jenga_kernel.hpp"
|
||||
|
||||
"""
|
||||
|
||||
# NOTE: Jenga sparse attention kernel has the following restrictions enforced by static_assert:
|
||||
# - Group mode: NOT supported (batch mode only)
|
||||
# - Bias: NOT supported (NO_BIAS only)
|
||||
# - LSE output: NOT supported (false only)
|
||||
# - Dropout: NOT supported (false only)
|
||||
# - Logits soft-cap: NOT supported (false only)
|
||||
# - FP8 static quantization: NOT supported (NO_SCALE only)
|
||||
# The template below hardcodes these unsupported features accordingly.
|
||||
|
||||
FMHA_FWD_KERNEL_BODY = """
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
|
||||
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
|
||||
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
|
||||
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
|
||||
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
|
||||
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
|
||||
{F_vlayout}>;
|
||||
|
||||
// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum,
|
||||
// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
false, // has_logits_soft_cap - NOT supported
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported
|
||||
false, // store_lse - NOT supported
|
||||
false, // has_dropout - NOT supported
|
||||
false, // has_randval - NOT supported
|
||||
ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported
|
||||
{F_occupancy},
|
||||
false>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported)
|
||||
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
fmha_shape_{F_idx},
|
||||
{F_mode},
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
{F_trload},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
using fmha_pipeline_{F_idx} = {F_pipeline}<
|
||||
fmha_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_epilogue_{F_idx} =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaSparseFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType,
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdJengaKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_jenga_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_jenga_fwd_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << "{F_kernel_name}" << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_FILENAME = "sparge_jenga_fwd_api.cpp"
|
||||
FMHA_FWD_API = """
|
||||
#include <cstdio>
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace {{
|
||||
bool get_num_cus(unsigned& num_cus) {{
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device");
|
||||
return false;
|
||||
}}
|
||||
|
||||
hipDeviceProp_t props{{}};
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device properties");
|
||||
return false;
|
||||
}}
|
||||
|
||||
num_cus = props.multiProcessorCount;
|
||||
return true;
|
||||
}}
|
||||
|
||||
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
|
||||
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
|
||||
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
|
||||
|
||||
return batch * nheads * num_m_blocks * num_n_blocks;
|
||||
}}
|
||||
}} // namespace
|
||||
|
||||
float sparge_jenga_fwd(fmha_jenga_fwd_traits t, fmha_jenga_fwd_args a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
|
||||
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
|
||||
|
||||
unsigned num_cus;
|
||||
if (!get_num_cus(num_cus)) {{
|
||||
return r;
|
||||
}}
|
||||
|
||||
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
|
||||
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
|
||||
}};
|
||||
|
||||
const bool has_load_tr = ck_tile::is_load_tr_supported();
|
||||
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{
|
||||
{F_dtype_case}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) &&
|
||||
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
|
||||
return fmha_jenga_fwd_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class CppConstraint:
|
||||
bool_expr: str = None
|
||||
|
||||
def __str__(self):
|
||||
if self.bool_expr is None:
|
||||
return "true"
|
||||
else:
|
||||
return f"{self.bool_expr}"
|
||||
|
||||
def __and__(self, other):
|
||||
return CppConstraint(f"({str(self)}) && ({str(other)})")
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdApiTrait:
|
||||
pipeline_tag: str
|
||||
# sync with fmha_fwd_traits<>, to generate fallback calls
|
||||
hdim: str
|
||||
dtype: str # data type
|
||||
mode: str # value from MODE_MAP
|
||||
bm0: int # tile size along q seqlen (block size)
|
||||
bn0: int # tile size along qk seqlen
|
||||
bk0: int # tile size along qk gemm unroll
|
||||
bn1: int # tile size along v head_dim
|
||||
bk1: int # tile size along kv gemm unroll
|
||||
bk0max: int
|
||||
vlayout: str
|
||||
logits: str
|
||||
mask: str
|
||||
spad: str
|
||||
skpad: str
|
||||
dpad: str
|
||||
dvpad: str
|
||||
tr_load: str
|
||||
constraint: CppConstraint
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
|
||||
)
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.mode == "group":
|
||||
return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true
|
||||
if self.spad == "t":
|
||||
return "true" # always support
|
||||
return "true"
|
||||
|
||||
@property
|
||||
def seqtune(self) -> str:
|
||||
return "true"
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
if self.mode == "group":
|
||||
return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true
|
||||
if self.skpad == "t":
|
||||
return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0"
|
||||
return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0"
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dpad == "t":
|
||||
return f"a.hdim_q % {vec} == 0"
|
||||
assert False
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dvpad == "t":
|
||||
return f"a.hdim_v % {vec} == 0"
|
||||
assert False
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdPipeline:
|
||||
tag: str
|
||||
|
||||
F_vlayout: str # row/col
|
||||
F_spad: str # true/false
|
||||
F_skpad: str #
|
||||
F_dpad: str #
|
||||
F_dvpad: str #
|
||||
F_logits: str # t/f
|
||||
F_mask: str # value from MASK_MAP
|
||||
F_trload: str # true/false
|
||||
F_constraint: CppConstraint = field(default_factory=CppConstraint)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ""
|
||||
if self.F_spad == "t":
|
||||
n += "s"
|
||||
if self.F_skpad == "t":
|
||||
n += "sk"
|
||||
if self.F_dpad == "t":
|
||||
n += "d"
|
||||
if self.F_dvpad == "t":
|
||||
n += "dv"
|
||||
if n != "":
|
||||
n = "p" + n
|
||||
return n
|
||||
|
||||
pn = pad_name()
|
||||
n = f"{self.tag}_v{self.F_vlayout[0]}"
|
||||
if pn != "":
|
||||
n += f"_{pn}"
|
||||
else:
|
||||
n += "_npad"
|
||||
|
||||
if self.F_logits == "t":
|
||||
n += "_logits"
|
||||
else:
|
||||
n += "_nlogits"
|
||||
|
||||
n += "_nbias"
|
||||
|
||||
if self.F_mask[0:2] == "s_":
|
||||
if self.F_mask == "s_mask":
|
||||
n += "_mask"
|
||||
else:
|
||||
n += "_nmask"
|
||||
else:
|
||||
if self.F_mask != "no":
|
||||
n += f"_m{self.F_mask[0]}"
|
||||
else:
|
||||
n += "_nmask"
|
||||
|
||||
n += "_nskip"
|
||||
|
||||
n += "_nsquant"
|
||||
|
||||
if self.F_trload == "t":
|
||||
n += "_trload"
|
||||
else:
|
||||
n += "_ntrload"
|
||||
|
||||
return n
|
||||
|
||||
|
||||
class FmhaFwdApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.pool = dict()
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_traits(self, trait: FmhaFwdApiTrait) -> None:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.pool.keys():
|
||||
self.pool[trait.dtype] = dict()
|
||||
hdim = trait.hdim, trait.bn1
|
||||
if hdim not in self.pool[trait.dtype].keys():
|
||||
self.pool[trait.dtype][hdim] = list()
|
||||
|
||||
self.pool[trait.dtype][hdim].append(copy.copy(trait))
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
tr_load_cond_map = {"t": "has_load_tr", "f": "true"}
|
||||
|
||||
per_tr_load = str()
|
||||
for tr_load in ["t", "f"]:
|
||||
per_dtypes = str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case = str()
|
||||
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
|
||||
traits = [
|
||||
t
|
||||
for t in self.pool[dtype][(hdim, hdim_v)]
|
||||
if tr_load == t.tr_load
|
||||
]
|
||||
inners = str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = "if" if k == 0 else "else if"
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
|
||||
F_if=if_k,
|
||||
F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
|
||||
# F_logits removed - hardcoded to false (NOT supported)
|
||||
F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
|
||||
F_trload=BOOL_MAP[trait.tr_load],
|
||||
F_scheck=trait.scheck,
|
||||
F_seqtune=trait.seqtune,
|
||||
F_skcheck=trait.skcheck,
|
||||
F_dcheck=trait.dcheck,
|
||||
F_dvcheck=trait.dvcheck,
|
||||
F_constraint=trait.constraint,
|
||||
F_spad=BOOL_MAP[trait.spad],
|
||||
F_skpad=BOOL_MAP[trait.skpad],
|
||||
F_dpad=BOOL_MAP[trait.dpad],
|
||||
F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0,
|
||||
F_bn0=trait.bn0,
|
||||
F_bk0=trait.bk0,
|
||||
F_bn1=trait.bn1,
|
||||
F_bk1=trait.bk1,
|
||||
F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[dtype],
|
||||
)
|
||||
if_j = "if" if j == 0 else "else if"
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
|
||||
F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners
|
||||
)
|
||||
if_i = "if" if i == 0 else "else if"
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
|
||||
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
|
||||
)
|
||||
per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(
|
||||
F_if="if",
|
||||
F_trload_cond=tr_load_cond_map[tr_load],
|
||||
F_dtype_case=per_dtypes,
|
||||
)
|
||||
if not per_tr_load:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_tr_load += " (void)t ; (void)s ; (void)a;"
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdTileSize:
|
||||
F_bm0: int # tile size along q seqlen (block size)
|
||||
F_bn0: int # tile size along k seqlen
|
||||
F_bk0: int # tile size along qk gemm unroll
|
||||
F_bn1: int # tile size along v head_dim
|
||||
F_bk1: int # tile size along kv gemm unroll
|
||||
F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0: int # number of warps for gemm0 along q seqlen
|
||||
F_rn0: int # number of warps for gemm0 along k seqlen
|
||||
F_rk0: int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1: int # number of warps for gemm1 along q seqlen
|
||||
F_rn1: int # number of warps for gemm1 along head dim v
|
||||
F_rk1: int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0: int # gemm0 warp size along m
|
||||
F_wn0: int # gemm0 warp size along n
|
||||
F_wk0: int # gemm0 warp size along k
|
||||
F_wm1: int # gemm1 warp size along m
|
||||
F_wn1: int # gemm1 warp size along n
|
||||
F_wk1: int # gemm1 warp size along k
|
||||
F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
F_constraint: CppConstraint = field(default_factory=CppConstraint)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
|
||||
+ f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
|
||||
+ f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
|
||||
+ ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdKernel:
|
||||
F_idx: int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim: int # hdim
|
||||
F_dtype: str # data type
|
||||
F_mode: str # value from MODE_MAP
|
||||
F_tile: FmhaFwdTileSize
|
||||
F_pipeline: FmhaFwdPipeline
|
||||
mask_impl: str
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
# kernel_body removed - unused
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
|
||||
F_idx=self.F_idx,
|
||||
F_hdim=self.F_hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0=self.F_tile.F_bm0,
|
||||
F_bn0=self.F_tile.F_bn0,
|
||||
F_bk0=self.F_tile.F_bk0,
|
||||
F_bn1=self.F_tile.F_bn1,
|
||||
F_bk1=self.F_tile.F_bk1,
|
||||
F_bk0max=self.F_tile.F_bk0max,
|
||||
F_rm0=self.F_tile.F_rm0,
|
||||
F_rn0=self.F_tile.F_rn0,
|
||||
F_rk0=self.F_tile.F_rk0,
|
||||
F_rm1=self.F_tile.F_rm1,
|
||||
F_rn1=self.F_tile.F_rn1,
|
||||
F_rk1=self.F_tile.F_rk1,
|
||||
F_wm0=self.F_tile.F_wm0,
|
||||
F_wn0=self.F_tile.F_wn0,
|
||||
F_wk0=self.F_tile.F_wk0,
|
||||
F_wm1=self.F_tile.F_wm1,
|
||||
F_wn1=self.F_tile.F_wn1,
|
||||
F_wk1=self.F_tile.F_wk1,
|
||||
F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad=BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
# F_logits removed - hardcoded to false in template (NOT supported)
|
||||
F_occupancy=self.F_tile.F_occupancy,
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode=MODE_MAP[self.F_mode],
|
||||
F_pipeline=PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_trload=BOOL_MAP[self.F_pipeline.F_trload],
|
||||
F_kernel_name=self.name,
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return (
|
||||
f"fmha_jenga_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
|
||||
+ self.F_tile.name
|
||||
+ "_"
|
||||
+ self.F_pipeline.name
|
||||
)
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def api_trait(self) -> FmhaFwdApiTrait:
|
||||
return FmhaFwdApiTrait(
|
||||
pipeline_tag=self.F_pipeline.tag,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bm0=self.F_tile.F_bm0,
|
||||
bn0=self.F_tile.F_bn0,
|
||||
bk0=self.F_tile.F_bk0,
|
||||
bn1=self.F_tile.F_bn1,
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0max=self.F_tile.F_bk0max,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
logits=self.F_pipeline.F_logits,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
tr_load=self.F_pipeline.F_trload,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
|
||||
)
|
||||
|
||||
|
||||
class KernelComponentFactory:
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
if dtype == "fp16" or dtype == "bf16":
|
||||
return {
|
||||
# (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
# FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(128, 128): [
|
||||
FmhaFwdTileSize(
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
-1,
|
||||
),
|
||||
],
|
||||
# (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
# (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
# (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
@staticmethod
|
||||
def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# NOTE: logits soft-cap is NOT supported by Jenga sparse attention (enforced by static_assert)
|
||||
pipelines = []
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
for logits, mask in itertools.product(
|
||||
["f"], # logits soft-cap NOT supported, always false
|
||||
get_mask_map(mask_impl).keys(),
|
||||
):
|
||||
if hdim == 256 and hdim_v == 256:
|
||||
# jenga fmha only supports dim <= 192 for now.
|
||||
continue
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline( # fmt: skip
|
||||
"qr_async",
|
||||
"row",
|
||||
"t",
|
||||
"f",
|
||||
"t",
|
||||
"t",
|
||||
logits,
|
||||
mask,
|
||||
"f",
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline( # fmt: skip
|
||||
"qr_async",
|
||||
"row",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
logits,
|
||||
mask,
|
||||
"f",
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
|
||||
class CustomFactory(KernelComponentFactory):
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
|
||||
if dtype == "fp16" or dtype == "bf16":
|
||||
if (128, 128) in result.keys():
|
||||
result[(128, 128)].insert(
|
||||
0,
|
||||
FmhaFwdTileSize(
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
-1,
|
||||
CppConstraint(
|
||||
"get_num_blocks(128) < num_cus * min_cu_util_rate"
|
||||
),
|
||||
),
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def get_fwd_blobs(
|
||||
kernel_filter: Optional[str], receipt, optdim_list, mask_impl
|
||||
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
gen = list()
|
||||
api_pool = FmhaFwdApiPool(mask_impl)
|
||||
|
||||
factory = (
|
||||
CustomFactory
|
||||
if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1"
|
||||
else KernelComponentFactory
|
||||
)
|
||||
|
||||
# Only generate fp16/bf16 kernels for now.
|
||||
# NOTE: Jenga sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert)
|
||||
for dtype in ["fp16", "bf16"]:
|
||||
d = factory.get_hdim_tile_size_dict(dtype)
|
||||
if d is None:
|
||||
continue
|
||||
for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]):
|
||||
for tile, pipeline in itertools.product(
|
||||
tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
|
||||
):
|
||||
if pipeline.tag != "qr_async":
|
||||
continue
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=2,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl,
|
||||
)
|
||||
if kernel_filter != "":
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_logits == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_varlen_fwd) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_fwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
|
||||
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
|
||||
update_file(autogen_dir / kernel.filename, kernel.template)
|
||||
|
||||
|
||||
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
|
||||
update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api)
|
||||
|
||||
|
||||
def write_blobs(
|
||||
output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
|
||||
) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
|
||||
|
||||
def list_blobs(
|
||||
file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
|
||||
) -> None:
|
||||
with file_path.open("a") as f:
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
|
||||
f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n")
|
||||
799
example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_vsa.py
Normal file
799
example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_vsa.py
Normal file
@@ -0,0 +1,799 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass, field
|
||||
import fnmatch
|
||||
import itertools
|
||||
import os
|
||||
import os.path as path
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from codegen.cpp_symbol_map import (
|
||||
BOOL_MAP,
|
||||
FWD_DTYPE_MAP,
|
||||
LAYOUT_MAP,
|
||||
MODE_MAP,
|
||||
PIPELINE_ENUM_MAP,
|
||||
PIPELINE_MAP,
|
||||
get_mask_check_map,
|
||||
get_mask_map,
|
||||
)
|
||||
|
||||
GEN_DIR = ""
|
||||
|
||||
|
||||
def update_file(file_path, content):
|
||||
"""Update the file at file_path with the given content if it differs from the existing content.
|
||||
|
||||
It avoids unnecessary touching of the file which triggers rebuilds
|
||||
"""
|
||||
|
||||
existing_content = ""
|
||||
if path.exists(file_path):
|
||||
with open(file_path, "r") as file:
|
||||
existing_content = file.read()
|
||||
if existing_content == content:
|
||||
return
|
||||
with open(file_path, "w") as file:
|
||||
file.write(content)
|
||||
|
||||
|
||||
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256}
|
||||
|
||||
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n
|
||||
// auto generated by generate.py
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "fmha_fwd_trek.hpp"
|
||||
#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp"
|
||||
#include "kernel/fmha_fwd_vsa_kernel.hpp"
|
||||
|
||||
"""
|
||||
|
||||
# NOTE: VSA sparse attention kernel has the following restrictions enforced by static_assert:
|
||||
# - Group mode: NOT supported (batch mode only)
|
||||
# - Bias: NOT supported (NO_BIAS only)
|
||||
# - LSE output: NOT supported (false only)
|
||||
# - Dropout: NOT supported (false only)
|
||||
# - Logits soft-cap: NOT supported (false only)
|
||||
# - FP8 static quantization: NOT supported (NO_SCALE only)
|
||||
# The template below hardcodes these unsupported features accordingly.
|
||||
|
||||
FMHA_FWD_KERNEL_BODY = """
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
|
||||
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
|
||||
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
|
||||
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
|
||||
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
|
||||
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
|
||||
{F_vlayout}>;
|
||||
|
||||
// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum,
|
||||
// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
false, // has_logits_soft_cap - NOT supported
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported
|
||||
false, // store_lse - NOT supported
|
||||
false, // has_dropout - NOT supported
|
||||
false, // has_randval - NOT supported
|
||||
ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported
|
||||
{F_occupancy},
|
||||
false>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported)
|
||||
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
fmha_shape_{F_idx},
|
||||
{F_mode},
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
{F_trload},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA<
|
||||
fmha_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_epilogue_{F_idx} =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaSparseFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType,
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdVSAKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_vsa_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_vsa_fwd_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << "{F_kernel_name}" << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_FILENAME = "sparge_vsa_fwd_api.cpp"
|
||||
FMHA_FWD_API = """
|
||||
#include <cstdio>
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace {{
|
||||
bool get_num_cus(unsigned& num_cus) {{
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device");
|
||||
return false;
|
||||
}}
|
||||
|
||||
hipDeviceProp_t props{{}};
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device properties");
|
||||
return false;
|
||||
}}
|
||||
|
||||
num_cus = props.multiProcessorCount;
|
||||
return true;
|
||||
}}
|
||||
|
||||
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
|
||||
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
|
||||
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
|
||||
|
||||
return batch * nheads * num_m_blocks * num_n_blocks;
|
||||
}}
|
||||
}} // namespace
|
||||
|
||||
float sparge_vsa_fwd(fmha_vsa_fwd_traits t, fmha_vsa_fwd_args a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
|
||||
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
|
||||
|
||||
unsigned num_cus;
|
||||
if (!get_num_cus(num_cus)) {{
|
||||
return r;
|
||||
}}
|
||||
|
||||
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
|
||||
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
|
||||
}};
|
||||
|
||||
const bool has_load_tr = ck_tile::is_load_tr_supported();
|
||||
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{
|
||||
{F_dtype_case}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) &&
|
||||
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
|
||||
return fmha_vsa_fwd_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class CppConstraint:
|
||||
bool_expr: str = None
|
||||
|
||||
def __str__(self):
|
||||
if self.bool_expr is None:
|
||||
return "true"
|
||||
else:
|
||||
return f"{self.bool_expr}"
|
||||
|
||||
def __and__(self, other):
|
||||
return CppConstraint(f"({str(self)}) && ({str(other)})")
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdApiTrait:
|
||||
pipeline_tag: str
|
||||
# sync with fmha_fwd_traits<>, to generate fallback calls
|
||||
hdim: str
|
||||
dtype: str # data type
|
||||
mode: str # value from MODE_MAP
|
||||
bm0: int # tile size along q seqlen (block size)
|
||||
bn0: int # tile size along qk seqlen
|
||||
bk0: int # tile size along qk gemm unroll
|
||||
bn1: int # tile size along v head_dim
|
||||
bk1: int # tile size along kv gemm unroll
|
||||
bk0max: int
|
||||
vlayout: str
|
||||
logits: str
|
||||
mask: str
|
||||
spad: str
|
||||
skpad: str
|
||||
dpad: str
|
||||
dvpad: str
|
||||
tr_load: str
|
||||
constraint: CppConstraint
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
|
||||
)
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.mode == "group":
|
||||
return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true
|
||||
if self.spad == "t":
|
||||
return "true" # always support
|
||||
return "true"
|
||||
|
||||
@property
|
||||
def seqtune(self) -> str:
|
||||
return "true"
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
if self.mode == "group":
|
||||
return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true
|
||||
if self.skpad == "t":
|
||||
return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0"
|
||||
return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0"
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dpad == "t":
|
||||
return f"a.hdim_q % {vec} == 0"
|
||||
assert False
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dvpad == "t":
|
||||
return f"a.hdim_v % {vec} == 0"
|
||||
assert False
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdPipeline:
|
||||
tag: str
|
||||
|
||||
F_vlayout: str # row/col
|
||||
F_spad: str # true/false
|
||||
F_skpad: str #
|
||||
F_dpad: str #
|
||||
F_dvpad: str #
|
||||
F_logits: str # t/f
|
||||
F_mask: str # value from MASK_MAP
|
||||
F_trload: str # true/false
|
||||
F_constraint: CppConstraint = field(default_factory=CppConstraint)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ""
|
||||
if self.F_spad == "t":
|
||||
n += "s"
|
||||
if self.F_skpad == "t":
|
||||
n += "sk"
|
||||
if self.F_dpad == "t":
|
||||
n += "d"
|
||||
if self.F_dvpad == "t":
|
||||
n += "dv"
|
||||
if n != "":
|
||||
n = "p" + n
|
||||
return n
|
||||
|
||||
pn = pad_name()
|
||||
n = f"{self.tag}_v{self.F_vlayout[0]}"
|
||||
if pn != "":
|
||||
n += f"_{pn}"
|
||||
else:
|
||||
n += "_npad"
|
||||
|
||||
if self.F_logits == "t":
|
||||
n += "_logits"
|
||||
else:
|
||||
n += "_nlogits"
|
||||
|
||||
n += "_nbias"
|
||||
|
||||
if self.F_mask[0:2] == "s_":
|
||||
if self.F_mask == "s_mask":
|
||||
n += "_mask"
|
||||
else:
|
||||
n += "_nmask"
|
||||
else:
|
||||
if self.F_mask != "no":
|
||||
n += f"_m{self.F_mask[0]}"
|
||||
else:
|
||||
n += "_nmask"
|
||||
|
||||
n += "_nskip"
|
||||
|
||||
n += "_nsquant"
|
||||
|
||||
if self.F_trload == "t":
|
||||
n += "_trload"
|
||||
else:
|
||||
n += "_ntrload"
|
||||
|
||||
return n
|
||||
|
||||
|
||||
class FmhaFwdApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.pool = dict()
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_traits(self, trait: FmhaFwdApiTrait) -> None:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.pool.keys():
|
||||
self.pool[trait.dtype] = dict()
|
||||
hdim = trait.hdim, trait.bn1
|
||||
if hdim not in self.pool[trait.dtype].keys():
|
||||
self.pool[trait.dtype][hdim] = list()
|
||||
|
||||
self.pool[trait.dtype][hdim].append(copy.copy(trait))
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
tr_load_cond_map = {"t": "has_load_tr", "f": "true"}
|
||||
|
||||
per_tr_load = str()
|
||||
for tr_load in ["t", "f"]:
|
||||
per_dtypes = str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case = str()
|
||||
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
|
||||
traits = [
|
||||
t
|
||||
for t in self.pool[dtype][(hdim, hdim_v)]
|
||||
if tr_load == t.tr_load
|
||||
]
|
||||
inners = str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = "if" if k == 0 else "else if"
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
|
||||
F_if=if_k,
|
||||
F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
|
||||
# F_logits removed - hardcoded to false (NOT supported)
|
||||
F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
|
||||
F_trload=BOOL_MAP[trait.tr_load],
|
||||
F_scheck=trait.scheck,
|
||||
F_seqtune=trait.seqtune,
|
||||
F_skcheck=trait.skcheck,
|
||||
F_dcheck=trait.dcheck,
|
||||
F_dvcheck=trait.dvcheck,
|
||||
F_constraint=trait.constraint,
|
||||
F_spad=BOOL_MAP[trait.spad],
|
||||
F_skpad=BOOL_MAP[trait.skpad],
|
||||
F_dpad=BOOL_MAP[trait.dpad],
|
||||
F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0,
|
||||
F_bn0=trait.bn0,
|
||||
F_bk0=trait.bk0,
|
||||
F_bn1=trait.bn1,
|
||||
F_bk1=trait.bk1,
|
||||
F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[dtype],
|
||||
)
|
||||
if_j = "if" if j == 0 else "else if"
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
|
||||
F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners
|
||||
)
|
||||
if_i = "if" if i == 0 else "else if"
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
|
||||
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
|
||||
)
|
||||
per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(
|
||||
F_if="if",
|
||||
F_trload_cond=tr_load_cond_map[tr_load],
|
||||
F_dtype_case=per_dtypes,
|
||||
)
|
||||
if not per_tr_load:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_tr_load += " (void)t ; (void)s ; (void)a;"
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdTileSize:
|
||||
F_bm0: int # tile size along q seqlen (block size)
|
||||
F_bn0: int # tile size along k seqlen
|
||||
F_bk0: int # tile size along qk gemm unroll
|
||||
F_bn1: int # tile size along v head_dim
|
||||
F_bk1: int # tile size along kv gemm unroll
|
||||
F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0: int # number of warps for gemm0 along q seqlen
|
||||
F_rn0: int # number of warps for gemm0 along k seqlen
|
||||
F_rk0: int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1: int # number of warps for gemm1 along q seqlen
|
||||
F_rn1: int # number of warps for gemm1 along head dim v
|
||||
F_rk1: int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0: int # gemm0 warp size along m
|
||||
F_wn0: int # gemm0 warp size along n
|
||||
F_wk0: int # gemm0 warp size along k
|
||||
F_wm1: int # gemm1 warp size along m
|
||||
F_wn1: int # gemm1 warp size along n
|
||||
F_wk1: int # gemm1 warp size along k
|
||||
F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
F_constraint: CppConstraint = field(default_factory=CppConstraint)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
|
||||
+ f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
|
||||
+ f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
|
||||
+ ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdKernel:
|
||||
F_idx: int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim: int # hdim
|
||||
F_dtype: str # data type
|
||||
F_mode: str # value from MODE_MAP
|
||||
F_tile: FmhaFwdTileSize
|
||||
F_pipeline: FmhaFwdPipeline
|
||||
mask_impl: str
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
# kernel_body removed - unused
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
|
||||
F_idx=self.F_idx,
|
||||
F_hdim=self.F_hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0=self.F_tile.F_bm0,
|
||||
F_bn0=self.F_tile.F_bn0,
|
||||
F_bk0=self.F_tile.F_bk0,
|
||||
F_bn1=self.F_tile.F_bn1,
|
||||
F_bk1=self.F_tile.F_bk1,
|
||||
F_bk0max=self.F_tile.F_bk0max,
|
||||
F_rm0=self.F_tile.F_rm0,
|
||||
F_rn0=self.F_tile.F_rn0,
|
||||
F_rk0=self.F_tile.F_rk0,
|
||||
F_rm1=self.F_tile.F_rm1,
|
||||
F_rn1=self.F_tile.F_rn1,
|
||||
F_rk1=self.F_tile.F_rk1,
|
||||
F_wm0=self.F_tile.F_wm0,
|
||||
F_wn0=self.F_tile.F_wn0,
|
||||
F_wk0=self.F_tile.F_wk0,
|
||||
F_wm1=self.F_tile.F_wm1,
|
||||
F_wn1=self.F_tile.F_wn1,
|
||||
F_wk1=self.F_tile.F_wk1,
|
||||
F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad=BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
# F_logits removed - hardcoded to false in template (NOT supported)
|
||||
F_occupancy=self.F_tile.F_occupancy,
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode=MODE_MAP[self.F_mode],
|
||||
F_pipeline=PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_trload=BOOL_MAP[self.F_pipeline.F_trload],
|
||||
F_kernel_name=self.name,
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return (
|
||||
f"fmha_vsa_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
|
||||
+ self.F_tile.name
|
||||
+ "_"
|
||||
+ self.F_pipeline.name
|
||||
)
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def api_trait(self) -> FmhaFwdApiTrait:
|
||||
return FmhaFwdApiTrait(
|
||||
pipeline_tag=self.F_pipeline.tag,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bm0=self.F_tile.F_bm0,
|
||||
bn0=self.F_tile.F_bn0,
|
||||
bk0=self.F_tile.F_bk0,
|
||||
bn1=self.F_tile.F_bn1,
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0max=self.F_tile.F_bk0max,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
logits=self.F_pipeline.F_logits,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
tr_load=self.F_pipeline.F_trload,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
|
||||
)
|
||||
|
||||
|
||||
class KernelComponentFactory:
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
if dtype == "fp16" or dtype == "bf16":
|
||||
return {
|
||||
# (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
# FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(128, 128): [
|
||||
FmhaFwdTileSize(
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
-1,
|
||||
),
|
||||
],
|
||||
# (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
# (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
# (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
@staticmethod
|
||||
def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# NOTE: logits soft-cap is NOT supported by VSA sparse attention (enforced by static_assert)
|
||||
pipelines = []
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
for logits, mask in itertools.product(
|
||||
["f"], # logits soft-cap NOT supported, always false
|
||||
get_mask_map(mask_impl).keys(),
|
||||
):
|
||||
if hdim == 256 and hdim_v == 256:
|
||||
# vsa fmha only supports dim <= 192 for now.
|
||||
continue
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr_async_vsa",
|
||||
"row",
|
||||
"t",
|
||||
"f",
|
||||
"t",
|
||||
"t",
|
||||
logits,
|
||||
mask,
|
||||
"f",
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr_async_vsa",
|
||||
"row",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
logits,
|
||||
mask,
|
||||
"f",
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
|
||||
class CustomFactory(KernelComponentFactory):
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
|
||||
if dtype == "fp16" or dtype == "bf16":
|
||||
if (128, 128) in result.keys():
|
||||
result[(128, 128)].insert(
|
||||
0,
|
||||
FmhaFwdTileSize(
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
-1,
|
||||
CppConstraint(
|
||||
"get_num_blocks(128) < num_cus * min_cu_util_rate"
|
||||
),
|
||||
),
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def get_fwd_blobs(
|
||||
kernel_filter: Optional[str], receipt, optdim_list, mask_impl
|
||||
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
gen = list()
|
||||
api_pool = FmhaFwdApiPool(mask_impl)
|
||||
|
||||
factory = (
|
||||
CustomFactory
|
||||
if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1"
|
||||
else KernelComponentFactory
|
||||
)
|
||||
|
||||
# Only generate fp16/bf16 kernels for now.
|
||||
# NOTE: VSA sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert)
|
||||
for dtype in ["fp16", "bf16"]:
|
||||
d = factory.get_hdim_tile_size_dict(dtype)
|
||||
if d is None:
|
||||
continue
|
||||
for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]):
|
||||
for tile, pipeline in itertools.product(
|
||||
tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
|
||||
):
|
||||
if pipeline.tag != "qr_async_vsa":
|
||||
continue
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=1,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl,
|
||||
)
|
||||
if kernel_filter != "":
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_logits == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_varlen_fwd) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_fwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
|
||||
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
|
||||
update_file(autogen_dir / kernel.filename, kernel.template)
|
||||
|
||||
|
||||
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
|
||||
update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api)
|
||||
|
||||
|
||||
def write_blobs(
|
||||
output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
|
||||
) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
|
||||
|
||||
def list_blobs(
|
||||
file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
|
||||
) -> None:
|
||||
with file_path.open("a") as f:
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
|
||||
f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n")
|
||||
@@ -277,6 +277,9 @@ struct fmha_jenga_fwd_traits
|
||||
|
||||
float fmha_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
// sparge jenga
|
||||
float sparge_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_jenga_fwd_(const ck_tile::stream_config&, fmha_jenga_fwd_args);
|
||||
|
||||
@@ -322,6 +325,9 @@ using fmha_vsa_fwd_traits = fmha_jenga_fwd_traits;
|
||||
|
||||
float fmha_vsa_fwd(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
// sparge vsa
|
||||
float sparge_vsa_fwd(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_vsa_fwd_(const ck_tile::stream_config&, fmha_vsa_fwd_args);
|
||||
|
||||
|
||||
189
example/ck_tile/50_sparse_attn/jenga_sparge_attention.cpp
Normal file
189
example/ck_tile/50_sparse_attn/jenga_sparge_attention.cpp
Normal file
@@ -0,0 +1,189 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#include "jenga_sparge_attention.h"
|
||||
#include "fmha_fwd_trek.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
template <typename DataType_>
|
||||
ck_tile::HostTensor<DataType_>
|
||||
jenga_sparge_attention(const ck_tile::HostTensor<DataType_>& TQ,
|
||||
const ck_tile::HostTensor<DataType_>& TK,
|
||||
const ck_tile::HostTensor<DataType_>& TV,
|
||||
const ck_tile::HostTensor<uint8_t>& Tblock_relation_onehot,
|
||||
ck_tile::HostTensor<DataType_>& Y,
|
||||
int batch,
|
||||
int nhead,
|
||||
int nhead_k,
|
||||
int seqlen_q,
|
||||
int seqlen_k,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
bool i_perm,
|
||||
bool o_perm,
|
||||
int max_seqlen_q,
|
||||
int max_seqlen_k,
|
||||
int log_level)
|
||||
{
|
||||
static_assert(std::is_same_v<DataType_, ck_tile::half_t> ||
|
||||
std::is_same_v<DataType_, ck_tile::bf16_t>,
|
||||
"Jenga sparse attention supports fp16/bf16 only.");
|
||||
std::string data_type = "fp16";
|
||||
if constexpr(std::is_same_v<DataType_, ck_tile::bf16_t>)
|
||||
{
|
||||
data_type = "bf16";
|
||||
}
|
||||
|
||||
if(max_seqlen_q == 0)
|
||||
max_seqlen_q = seqlen_q;
|
||||
if(max_seqlen_k == 0)
|
||||
max_seqlen_k = seqlen_k;
|
||||
bool is_v_rowmajor = true;
|
||||
float scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q));
|
||||
std::string msk_str = "0";
|
||||
mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k);
|
||||
|
||||
const ck_tile::index_t shape_seqlen_q = seqlen_q;
|
||||
const ck_tile::index_t shape_seqlen_k = seqlen_k;
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
false, // time_kernel
|
||||
log_level,
|
||||
0,
|
||||
1,
|
||||
false};
|
||||
|
||||
ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem block_relation_buf(Tblock_relation_onehot.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes());
|
||||
|
||||
q_buf.ToDevice(TQ.data());
|
||||
k_buf.ToDevice(TK.data());
|
||||
v_buf.ToDevice(TV.data());
|
||||
block_relation_buf.ToDevice(Tblock_relation_onehot.data());
|
||||
|
||||
const auto init_args = [&](auto& args) {
|
||||
assert(nhead % nhead_k == 0);
|
||||
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
|
||||
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
|
||||
const ck_tile::index_t stride_v = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
return i_perm ? hdim_v : nhead_k * hdim_v;
|
||||
else
|
||||
return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
|
||||
}();
|
||||
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
|
||||
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q;
|
||||
const ck_tile::index_t nhead_stride_v = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
|
||||
else
|
||||
return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
|
||||
}();
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q;
|
||||
const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k;
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
|
||||
args.q_ptr = q_buf.GetDeviceBuffer();
|
||||
args.k_ptr = k_buf.GetDeviceBuffer();
|
||||
args.v_ptr = v_buf.GetDeviceBuffer();
|
||||
args.block_relation_onehot_ptr = block_relation_buf.GetDeviceBuffer();
|
||||
|
||||
args.batch = batch;
|
||||
args.seqlen_q = shape_seqlen_q;
|
||||
args.hdim_q = hdim_q;
|
||||
args.hdim_v = hdim_v;
|
||||
args.nhead_q = nhead;
|
||||
args.nhead_k = nhead_k;
|
||||
|
||||
args.stride_q = stride_q;
|
||||
args.stride_k = stride_k;
|
||||
args.stride_v = stride_v;
|
||||
args.nhead_stride_q = nhead_stride_q;
|
||||
args.nhead_stride_k = nhead_stride_k;
|
||||
args.nhead_stride_v = nhead_stride_v;
|
||||
args.batch_stride_q = batch_stride_q;
|
||||
args.batch_stride_k = batch_stride_k;
|
||||
args.batch_stride_v = batch_stride_v;
|
||||
|
||||
args.o_ptr = o_buf.GetDeviceBuffer();
|
||||
|
||||
args.seqlen_k = shape_seqlen_k;
|
||||
args.max_seqlen_q = max_seqlen_q;
|
||||
|
||||
args.scale_s = scale_s;
|
||||
|
||||
args.stride_o = stride_o;
|
||||
args.nhead_stride_o = nhead_stride_o;
|
||||
args.batch_stride_o = batch_stride_o;
|
||||
|
||||
args.window_size_left = mask.left;
|
||||
args.window_size_right = mask.right;
|
||||
args.mask_type = static_cast<ck_tile::index_t>(mask.type);
|
||||
};
|
||||
|
||||
const auto init_traits = [&](auto& traits) {
|
||||
traits.hdim_q = hdim_q;
|
||||
traits.hdim_v = hdim_v;
|
||||
traits.data_type = data_type;
|
||||
traits.is_v_rowmajor = is_v_rowmajor;
|
||||
traits.mask_type = mask.type;
|
||||
};
|
||||
|
||||
fmha_jenga_fwd_traits fmha_traits;
|
||||
init_traits(fmha_traits);
|
||||
|
||||
fmha_jenga_fwd_args args;
|
||||
init_args(args);
|
||||
|
||||
sparge_jenga_fwd(fmha_traits, args, stream_config);
|
||||
|
||||
o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes());
|
||||
|
||||
return Y;
|
||||
}
|
||||
|
||||
template ck_tile::HostTensor<ck_tile::half_t>
|
||||
jenga_sparge_attention<ck_tile::half_t>(const ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
const ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
const ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
const ck_tile::HostTensor<uint8_t>&,
|
||||
ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
bool,
|
||||
bool,
|
||||
int,
|
||||
int,
|
||||
int);
|
||||
|
||||
template ck_tile::HostTensor<ck_tile::bf16_t>
|
||||
jenga_sparge_attention<ck_tile::bf16_t>(const ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
const ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
const ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
const ck_tile::HostTensor<uint8_t>&,
|
||||
ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
bool,
|
||||
bool,
|
||||
int,
|
||||
int,
|
||||
int);
|
||||
27
example/ck_tile/50_sparse_attn/jenga_sparge_attention.h
Normal file
27
example/ck_tile/50_sparse_attn/jenga_sparge_attention.h
Normal file
@@ -0,0 +1,27 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
#include <optional>
|
||||
#include <cstdint>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
template <typename DataType_>
|
||||
ck_tile::HostTensor<DataType_>
|
||||
jenga_sparge_attention(const ck_tile::HostTensor<DataType_>& TQ,
|
||||
const ck_tile::HostTensor<DataType_>& TK,
|
||||
const ck_tile::HostTensor<DataType_>& TV,
|
||||
const ck_tile::HostTensor<uint8_t>& Tblock_relation_onehot,
|
||||
ck_tile::HostTensor<DataType_>& Y,
|
||||
int batch,
|
||||
int nhead,
|
||||
int nhead_k,
|
||||
int seqlen_q,
|
||||
int seqlen_k,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
bool i_perm,
|
||||
bool o_perm,
|
||||
int max_seqlen_q,
|
||||
int max_seqlen_k,
|
||||
int log_level = 0);
|
||||
@@ -16,7 +16,7 @@
|
||||
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
|
||||
#include "jenga_sparse_attention.h"
|
||||
#include "jenga_sparge_attention.h"
|
||||
#include "sparge_tool.hpp"
|
||||
|
||||
// ============================================================================
|
||||
@@ -115,7 +115,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("repeat", "20", "benchmark iterations")
|
||||
.insert("kname", "0", "print kernel name")
|
||||
// Sparge-specific
|
||||
.insert("blkq", "128", "Sparge BLKQ")
|
||||
.insert("blkq", "64", "Sparge BLKQ")
|
||||
.insert("blkk", "128", "Sparge BLKK")
|
||||
.insert("simthreshd1", "0.6", "Sparge sim threshold")
|
||||
.insert("cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)")
|
||||
@@ -161,10 +161,10 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
if(hdim_v < 0)
|
||||
hdim_v = hdim_q;
|
||||
|
||||
if(blkq != 128 || blkk != 128 || hdim_q != 128 || hdim_v != 128)
|
||||
if(blkq != 64 || blkk != 128 || hdim_q != 128 || hdim_v != 128)
|
||||
{
|
||||
std::cout << "\n>>> TEST SKIPPED <<<" << std::endl;
|
||||
std::cout << "Jenga/VSA kernel instances are generated for BLKQ=BLKK=128, "
|
||||
std::cout << "Sparge Jenga kernel instances are generated for BLKQ=64, BLKK=128, "
|
||||
"hdim_q=128, hdim_v=128 only."
|
||||
<< std::endl;
|
||||
std::cout << "TEST SKIPPED" << std::endl;
|
||||
@@ -247,7 +247,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
if(kname)
|
||||
{
|
||||
jenga_sparse_attention<T>(q_host,
|
||||
jenga_sparge_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
block_relation_onehot,
|
||||
@@ -268,7 +268,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
for(int i = 0; i < warmup; ++i)
|
||||
{
|
||||
jenga_sparse_attention<T>(q_host,
|
||||
jenga_sparge_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
block_relation_onehot,
|
||||
@@ -292,7 +292,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
for(int i = 0; i < repeat; ++i)
|
||||
{
|
||||
jenga_sparse_attention<T>(q_host,
|
||||
jenga_sparge_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
block_relation_onehot,
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
|
||||
#include "jenga_sparse_attention.h"
|
||||
#include "vsa_sparge_attention.h"
|
||||
#include "sparge_tool.hpp"
|
||||
|
||||
// ============================================================================
|
||||
@@ -115,7 +115,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("repeat", "20", "benchmark iterations")
|
||||
.insert("kname", "0", "print kernel name")
|
||||
// Sparge-specific
|
||||
.insert("blkq", "128", "Sparge BLKQ")
|
||||
.insert("blkq", "64", "Sparge BLKQ")
|
||||
.insert("blkk", "128", "Sparge BLKK")
|
||||
.insert("simthreshd1", "0.6", "Sparge sim threshold")
|
||||
.insert("cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)")
|
||||
@@ -161,10 +161,10 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
if(hdim_v < 0)
|
||||
hdim_v = hdim_q;
|
||||
|
||||
if(blkq != 128 || blkk != 128 || hdim_q != 128 || hdim_v != 128)
|
||||
if(blkq != 64 || blkk != 128 || hdim_q != 128 || hdim_v != 128)
|
||||
{
|
||||
std::cout << "\n>>> TEST SKIPPED <<<" << std::endl;
|
||||
std::cout << "VSA kernel instances are generated for BLKQ=BLKK=128, "
|
||||
std::cout << "Sparge VSA kernel instances are generated for BLKQ=64, BLKK=128, "
|
||||
"hdim_q=128, hdim_v=128 only."
|
||||
<< std::endl;
|
||||
std::cout << "TEST SKIPPED" << std::endl;
|
||||
@@ -251,7 +251,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
if(kname)
|
||||
{
|
||||
vsa_sparse_attention<T>(q_host,
|
||||
vsa_sparge_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
vsa_lut.lut,
|
||||
@@ -273,7 +273,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
for(int i = 0; i < warmup; ++i)
|
||||
{
|
||||
vsa_sparse_attention<T>(q_host,
|
||||
vsa_sparge_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
vsa_lut.lut,
|
||||
@@ -298,7 +298,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
for(int i = 0; i < repeat; ++i)
|
||||
{
|
||||
vsa_sparse_attention<T>(q_host,
|
||||
vsa_sparge_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
vsa_lut.lut,
|
||||
|
||||
195
example/ck_tile/50_sparse_attn/vsa_sparge_attention.cpp
Normal file
195
example/ck_tile/50_sparse_attn/vsa_sparge_attention.cpp
Normal file
@@ -0,0 +1,195 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#include "vsa_sparge_attention.h"
|
||||
#include "fmha_fwd_trek.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
template <typename DataType_>
|
||||
ck_tile::HostTensor<DataType_>
|
||||
vsa_sparge_attention(const ck_tile::HostTensor<DataType_>& TQ,
|
||||
const ck_tile::HostTensor<DataType_>& TK,
|
||||
const ck_tile::HostTensor<DataType_>& TV,
|
||||
const ck_tile::HostTensor<int32_t>& TKV_block_idx,
|
||||
const ck_tile::HostTensor<int32_t>& TKV_blocks,
|
||||
ck_tile::HostTensor<DataType_>& Y,
|
||||
int batch,
|
||||
int nhead,
|
||||
int nhead_k,
|
||||
int seqlen_q,
|
||||
int seqlen_k,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
bool i_perm,
|
||||
bool o_perm,
|
||||
int max_seqlen_q,
|
||||
int max_seqlen_k,
|
||||
int log_level)
|
||||
{
|
||||
static_assert(std::is_same_v<DataType_, ck_tile::half_t> ||
|
||||
std::is_same_v<DataType_, ck_tile::bf16_t>,
|
||||
"VSA sparse attention supports fp16/bf16 only.");
|
||||
std::string data_type = "fp16";
|
||||
if constexpr(std::is_same_v<DataType_, ck_tile::bf16_t>)
|
||||
{
|
||||
data_type = "bf16";
|
||||
}
|
||||
|
||||
if(max_seqlen_q == 0)
|
||||
max_seqlen_q = seqlen_q;
|
||||
if(max_seqlen_k == 0)
|
||||
max_seqlen_k = seqlen_k;
|
||||
bool is_v_rowmajor = true;
|
||||
float scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q));
|
||||
std::string msk_str = "0";
|
||||
mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k);
|
||||
|
||||
const ck_tile::index_t shape_seqlen_q = seqlen_q;
|
||||
const ck_tile::index_t shape_seqlen_k = seqlen_k;
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
false, // time_kernel
|
||||
log_level,
|
||||
0,
|
||||
1,
|
||||
false};
|
||||
|
||||
ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem lut_buf(TKV_block_idx.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem valid_block_num_buf(TKV_blocks.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes());
|
||||
|
||||
q_buf.ToDevice(TQ.data());
|
||||
k_buf.ToDevice(TK.data());
|
||||
v_buf.ToDevice(TV.data());
|
||||
lut_buf.ToDevice(TKV_block_idx.data());
|
||||
valid_block_num_buf.ToDevice(TKV_blocks.data());
|
||||
|
||||
const auto init_args = [&](auto& args) {
|
||||
assert(nhead % nhead_k == 0);
|
||||
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
|
||||
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
|
||||
const ck_tile::index_t stride_v = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
return i_perm ? hdim_v : nhead_k * hdim_v;
|
||||
else
|
||||
return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
|
||||
}();
|
||||
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
|
||||
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q;
|
||||
const ck_tile::index_t nhead_stride_v = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
|
||||
else
|
||||
return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
|
||||
}();
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q;
|
||||
const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k;
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
|
||||
args.q_ptr = q_buf.GetDeviceBuffer();
|
||||
args.k_ptr = k_buf.GetDeviceBuffer();
|
||||
args.v_ptr = v_buf.GetDeviceBuffer();
|
||||
args.lut_ptr = lut_buf.GetDeviceBuffer();
|
||||
args.valid_block_num_ptr = valid_block_num_buf.GetDeviceBuffer();
|
||||
|
||||
args.batch = batch;
|
||||
args.seqlen_q = shape_seqlen_q;
|
||||
args.hdim_q = hdim_q;
|
||||
args.hdim_v = hdim_v;
|
||||
args.nhead_q = nhead;
|
||||
args.nhead_k = nhead_k;
|
||||
|
||||
args.stride_q = stride_q;
|
||||
args.stride_k = stride_k;
|
||||
args.stride_v = stride_v;
|
||||
args.nhead_stride_q = nhead_stride_q;
|
||||
args.nhead_stride_k = nhead_stride_k;
|
||||
args.nhead_stride_v = nhead_stride_v;
|
||||
args.batch_stride_q = batch_stride_q;
|
||||
args.batch_stride_k = batch_stride_k;
|
||||
args.batch_stride_v = batch_stride_v;
|
||||
|
||||
args.o_ptr = o_buf.GetDeviceBuffer();
|
||||
|
||||
args.seqlen_k = shape_seqlen_k;
|
||||
args.max_seqlen_q = max_seqlen_q;
|
||||
|
||||
args.scale_s = scale_s;
|
||||
|
||||
args.stride_o = stride_o;
|
||||
args.nhead_stride_o = nhead_stride_o;
|
||||
args.batch_stride_o = batch_stride_o;
|
||||
|
||||
args.window_size_left = mask.left;
|
||||
args.window_size_right = mask.right;
|
||||
args.mask_type = static_cast<ck_tile::index_t>(mask.type);
|
||||
};
|
||||
|
||||
const auto init_traits = [&](auto& traits) {
|
||||
traits.hdim_q = hdim_q;
|
||||
traits.hdim_v = hdim_v;
|
||||
traits.data_type = data_type;
|
||||
traits.is_v_rowmajor = is_v_rowmajor;
|
||||
traits.mask_type = mask.type;
|
||||
};
|
||||
|
||||
fmha_vsa_fwd_traits fmha_traits;
|
||||
init_traits(fmha_traits);
|
||||
|
||||
fmha_vsa_fwd_args args;
|
||||
init_args(args);
|
||||
|
||||
sparge_vsa_fwd(fmha_traits, args, stream_config);
|
||||
|
||||
o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes());
|
||||
|
||||
return Y;
|
||||
}
|
||||
|
||||
template ck_tile::HostTensor<ck_tile::half_t>
|
||||
vsa_sparge_attention<ck_tile::half_t>(const ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
const ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
const ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
const ck_tile::HostTensor<int32_t>&,
|
||||
const ck_tile::HostTensor<int32_t>&,
|
||||
ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
bool,
|
||||
bool,
|
||||
int,
|
||||
int,
|
||||
int);
|
||||
|
||||
template ck_tile::HostTensor<ck_tile::bf16_t>
|
||||
vsa_sparge_attention<ck_tile::bf16_t>(const ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
const ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
const ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
const ck_tile::HostTensor<int32_t>&,
|
||||
const ck_tile::HostTensor<int32_t>&,
|
||||
ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
bool,
|
||||
bool,
|
||||
int,
|
||||
int,
|
||||
int);
|
||||
28
example/ck_tile/50_sparse_attn/vsa_sparge_attention.h
Normal file
28
example/ck_tile/50_sparse_attn/vsa_sparge_attention.h
Normal file
@@ -0,0 +1,28 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
#include <optional>
|
||||
#include <cstdint>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
template <typename DataType_>
|
||||
ck_tile::HostTensor<DataType_>
|
||||
vsa_sparge_attention(const ck_tile::HostTensor<DataType_>& TQ,
|
||||
const ck_tile::HostTensor<DataType_>& TK,
|
||||
const ck_tile::HostTensor<DataType_>& TV,
|
||||
const ck_tile::HostTensor<int32_t>& TKV_block_idx,
|
||||
const ck_tile::HostTensor<int32_t>& TKV_blocks,
|
||||
ck_tile::HostTensor<DataType_>& Y,
|
||||
int batch,
|
||||
int nhead,
|
||||
int nhead_k,
|
||||
int seqlen_q,
|
||||
int seqlen_k,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
bool i_perm,
|
||||
bool o_perm,
|
||||
int max_seqlen_q,
|
||||
int max_seqlen_k,
|
||||
int log_level = 0);
|
||||
@@ -200,7 +200,7 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
|
||||
int seqlen_k_start = kv_block_idx_ptr[0] * kM0;
|
||||
int seqlen_k_start = kv_block_idx_ptr[0] * kN0;
|
||||
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
|
||||
Reference in New Issue
Block a user