Merge commit '1e77695fe87c4d4d979859a91f29fd29aebbbcbc' into develop

This commit is contained in:
assistant-librarian[bot]
2025-10-29 21:11:55 +00:00
parent 26e9ec020f
commit de1ee4af17
51 changed files with 1823 additions and 1241 deletions

View File

@@ -1,8 +1,8 @@
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
# Currently only gfx9 archs are supported by FMHA
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9")
# Currently only gfx9 and gfx12 archs are supported by FMHA
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx12")
if(NOT INST_TARGETS)
message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9, gfx12) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
return()
endif()
@@ -12,6 +12,7 @@ set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
"semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
if(BUILD_TESTING)
# Build instances of all APIs for tests
message(DEBUG "Enabling all FWD APIs of CK Tile FMHA for because testing is enabled")
set(FMHA_FWD_ENABLE_APIS "all")
endif()
if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
@@ -36,15 +37,19 @@ file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
# re-run execute_process `generate.py --list_blobs` if any of the codegen scripts change
set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}")
list(JOIN INST_TARGETS , FMHA_TARGETS_ARG)
string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
set(FMHA_FWD_CODE_GEN_COMMON_ARGS
${CMAKE_CURRENT_LIST_DIR}/generate.py
--targets ${FMHA_TARGETS_ARG}
--api ${FMHA_FWD_APIS}
--optdim 32,64,128,256
# --filter fmha_fwd...
)
set(FMHA_BWD_CODE_GEN_COMMON_ARGS
${CMAKE_CURRENT_LIST_DIR}/generate.py
--targets ${FMHA_TARGETS_ARG}
--api bwd
--receipt 3
--optdim 32,64,96,128,256
@@ -67,7 +72,7 @@ execute_process(
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of FWD kernels via Python.")
message(FATAL_ERROR "CK Tile FMHA FAILED to generate a list of FWD kernels via Python.")
endif()
execute_process(
@@ -76,7 +81,7 @@ execute_process(
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of BWD kernels via Python.")
message(FATAL_ERROR "CK Tile FMHA FAILED to generate a list of BWD kernels via Python.")
endif()
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
@@ -89,6 +94,7 @@ add_custom_command(
COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS}
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${CODE_GEN_SCRIPTS}
COMMENT "Generate CK Tile FMHA FWD kernels"
)
add_custom_command(
@@ -96,6 +102,7 @@ add_custom_command(
COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS}
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${CODE_GEN_SCRIPTS}
COMMENT "Generate CK Tile FMHA BWD kernels"
)
set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances")

View File

@@ -0,0 +1,42 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
from dataclasses import dataclass, field
from typing import Any, List, Callable
@dataclass(frozen=True)
class ArchTrait:
name: str
preprocessor_check: str = field(default=None)
device_name_check: str = field(default=None)
tag: str = field(default=None)
filename_suffix: str = field(default=None)
def __post_init__(self):
if self.preprocessor_check is None:
object.__setattr__(self, "preprocessor_check", f"defined(__{self.name}__)")
if self.device_name_check is None:
object.__setattr__(
self,
"device_name_check",
f'device_name.compare(0, {len(self.name)}, "{self.name}") == 0',
)
if self.tag is None:
object.__setattr__(self, "tag", f"ck_tile::{self.name}_t")
if self.filename_suffix is None:
object.__setattr__(self, "filename_suffix", f"_{self.name}")
def get_factories_for_targets(
targets: List[str], get_factory: Callable[[str], Any]
) -> List[Any]:
factories = dict()
for target in targets:
factory = get_factory(target)
factories[factory.arch.name] = factory
# Place more specific architectures first
factories = sorted(
list(factories.values()), key=lambda f: len(f.arch.name), reverse=True
)
return factories

View File

@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import copy
@@ -21,6 +21,7 @@ from codegen.cpp_symbol_map import (
BOOL_MAP,
PIPELINE_ENUM_MAP,
)
from codegen.utils import update_file
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8}
@@ -441,7 +442,7 @@ class FmhaFwdApiPool:
)
if not per_dtypes:
# empty string we add some ignore to suppress warning in api
per_dtypes += " (void)t ; (void)s ; (void)a;"
per_dtypes += " (void)t; (void)s; (void)a;"
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_dtypes)
@@ -720,15 +721,20 @@ def get_fwd_blobs(
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
update_file(autogen_dir / kernel.filename, kernel.template)
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api)
def write_blobs(
output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
targets: List[str],
output_dir: Path,
kernel_filter: str,
receipt,
optdim_list,
mask_impl,
) -> None:
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
for kernel in kernels:
@@ -737,7 +743,12 @@ def write_blobs(
def list_blobs(
file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
targets: List[str],
file_path: Path,
kernel_filter: str,
receipt,
optdim_list,
mask_impl,
) -> None:
with file_path.open("a") as f:
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)

View File

@@ -3,13 +3,14 @@
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass
import fnmatch
import itertools
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Dict, Literal, Any
from collections import defaultdict
from codegen.arch import ArchTrait, get_factories_for_targets
from codegen.cmake_config import GEN_DIR
from codegen.cpp_symbol_map import (
get_mask_check_map,
@@ -22,16 +23,20 @@ from codegen.cpp_symbol_map import (
BWD_DTYPE_MAP,
BOOL_MAP,
)
from codegen.utils import update_file
from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file
FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include "fmha_bwd.hpp"
"""
FMHA_BWD_DQ_DK_DV_KERNEL_BODY = """
#include <iostream>
#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::
@@ -132,10 +137,8 @@ using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
{F_maxq},
{F_bn0}>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
if(s.log_level_ > 0)
@@ -144,67 +147,68 @@ float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs));
}}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s,
fmha_bwd_args a)
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
template <>
int fmha_bwd_dq_dk_dv_maxq_<dq_dk_dv_trait_{F_idx}>()
int fmha_bwd_dq_dk_dv_maxq_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
return k_::kMaxSeqLenQ;
}}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
return k_::GetName();
}}
#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
"""
FMHA_BWD_API_FILENAME = "fmha_bwd_api.cpp"
FMHA_BWD_API = """
#include <iostream>
template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_dq_trait_>
template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_dq_trait_, typename Arch>
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
if constexpr (!std::is_same_v<convert_dq_trait_, void>)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << "@" << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_, Arch>() << "@" << fmha_bwd_convert_dq_get_name_<convert_dq_trait_, Arch>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_, Arch>() << std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_, Arch>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_, Arch>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_, Arch>(s_, a); }}
);
}}
else
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_, Arch>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_, Arch>() << std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }}
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_, Arch>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_, Arch>(s_, a); }}
);
}}
}}
template <>
float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
[[maybe_unused]] const bool has_load_tr = ck_tile::is_load_tr_supported();
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
float r = -1;
{F_dispatch}
return r;
@@ -212,23 +216,22 @@ float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_conf
"""
def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_=0) -> str:
def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, if_i=0) -> str:
lines = [
f"{'if' if if_ == 0 else 'else if'}({F_cond})",
f"{if_(if_i)}({F_cond})",
"{",
*[" " + line for line in F_body.split("\n") if line.strip() != ""],
indent(F_body),
"}",
]
return "\n".join(" " * indent + line for line in lines) + "\n"
return "\n".join(lines) + "\n"
FMHA_BWD_API_INNER_DISPATCH = """
{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_cond_extra}) {{
FMHA_BWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_max_seq_q_cond}{F_cond_extra}) {{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dvpad} > 0)>;
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}, {F_convert_dq_bn0}>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>>(s, a);
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>, {F_arch.tag}>(s, a);
return r;
}}
"""
@@ -283,6 +286,7 @@ class FmhaBwdDQDKDVTileSize:
@dataclass(frozen=True)
class FmhaBwdDQDKDVKernel:
F_arch: ArchTrait
F_idx: int # this is not a tunable, but a counter to differentiate symbol
F_hdim: int # hdim
F_dtype: str # data type
@@ -302,6 +306,7 @@ class FmhaBwdDQDKDVKernel:
def template(self) -> str:
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format(
F_idx=self.F_idx,
F_arch=self.F_arch,
F_hdim=self.F_hdim,
F_dtype=BWD_DTYPE_MAP[self.F_dtype],
F_bm0=self.F_tile.F_bm0,
@@ -399,43 +404,97 @@ class FmhaBwdDQDKDVKernel:
@property
def filename(self) -> str:
return self.name + ".cpp"
return f"{self.name}{self.F_arch.filename_suffix}.cpp"
# TODO: design a more practical way to do it
# this is current supported tile size.
def get_dq_dk_dv_tiles(dtype: str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]:
if dtype == "fp32" and tr_load == "f":
return [
# bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv,
FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 16, 64, 64, 16, 64, 16, 16, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 16, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1),
] # fmt: skip
elif (dtype == "fp16" or dtype == "bf16") and tr_load == "f":
return [
FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 32, 128, 96, 32, 96, 32, 32, 96, 96, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
# FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
] # fmt: skip
elif (dtype == "fp16" or dtype == "bf16") and tr_load == "t":
return [
FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
FmhaBwdDQDKDVTileSize( 16, 192, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
# FmhaBwdDQDKDVTileSize( 32, 32, 64, 32, 64, 32, 32, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, 1, 32),
FmhaBwdDQDKDVTileSize( 32, 16, 64, 32, 64, 32, 16, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 32),
# FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16),
FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16),
] # fmt: skip
else:
class KernelComponentFactoryBase:
pass
class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
arch = ArchTrait(
"gfx9", preprocessor_check="defined(__gfx9__) && !defined(__gfx950__)"
)
@staticmethod
def get_dq_dk_dv_tiles(dtype: str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]:
if tr_load == "t":
return []
if dtype in ["fp32"]:
return [
# bm0, bn0, bk0, bk1, bk2, bk3, bk4,bhdq,bhdv,
FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 16, 64, 64, 16, 64, 16, 16, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 16, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1),
] # fmt: skip
if dtype in ["fp16", "bf16"]:
return [
FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 32, 128, 96, 32, 96, 32, 32, 96, 96, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
# FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
] # fmt: skip
return []
class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9):
arch = ArchTrait("gfx950")
@staticmethod
def get_dq_dk_dv_tiles(dtype: str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]:
results = KernelComponentFactoryGfx9.get_dq_dk_dv_tiles(dtype, tr_load)
if dtype in ["fp16", "bf16"] and tr_load == "t":
results.extend([
FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
FmhaBwdDQDKDVTileSize( 16, 192, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
# FmhaBwdDQDKDVTileSize( 32, 32, 64, 32, 64, 32, 32, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, 1, 32),
FmhaBwdDQDKDVTileSize( 32, 16, 64, 32, 64, 32, 16, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 32),
# FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16),
FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16),
]) # fmt: skip
return results
class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
arch = ArchTrait("gfx12")
@staticmethod
def get_dq_dk_dv_tiles(dtype: str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]:
if tr_load == "t":
return []
if dtype in ["fp16", "bf16"]:
return [
# bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv,
FmhaBwdDQDKDVTileSize( 32, 64, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaBwdDQDKDVTileSize( 32, 64, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1),
] # fmt: skip
return []
def get_factory(target: str):
# Place more specific architectures first
if target.startswith("gfx950"):
return KernelComponentFactoryGfx950
if target.startswith("gfx9"):
return KernelComponentFactoryGfx9
if target.startswith("gfx12"):
return KernelComponentFactoryGfx12
raise Exception(f"Unsupported device target {target}")
FMHA_BWD_DOT_DO_O_KERNEL_BODY = """
#include <iostream>
#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_bwd_dot_do_o_trait_{F_idx} =
@@ -445,7 +504,7 @@ using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDot
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
/* BlockSize = M0 = */ 64,
/* BlockSize = M0 = */ {F_bm0},
{F_hdim},
{F_mode},
fmha_bwd_dot_do_o_trait_{F_idx}>;
@@ -459,10 +518,8 @@ using fmha_bwd_dot_do_o_kernel_{F_idx} =
using dot_do_o_trait_{F_idx} =
fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;
#include <iostream>
template <>
float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
if(s.log_level_ > 0)
@@ -471,34 +528,38 @@ float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs));
}}
template <>
void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
template <>
std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_{F_idx}>()
std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_{F_idx}, {F_arch.tag}>()
{{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
return k_::GetName();
}}
#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
"""
@dataclass(frozen=True)
class FmhaBwdOGradDotOKernel:
F_arch: ArchTrait
F_idx: int # this is not a tunable, but a counter to differentiate symbol
F_hdim: int # hdim
F_dtype: str # data type
F_bm0: int # tile size along q seqlen (block size)
F_spad: str # true/false
F_dvpad: str #
F_mode: str # value from MODE_MAP
@@ -508,8 +569,10 @@ class FmhaBwdOGradDotOKernel:
def template(self) -> str:
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_DOT_DO_O_KERNEL_BODY.format(
F_idx=self.F_idx,
F_arch=self.F_arch,
F_hdim=self.F_hdim,
F_dtype=BWD_DTYPE_MAP[self.F_dtype],
F_bm0=self.F_bm0,
F_spad=BOOL_MAP[self.F_spad],
F_dvpad=BOOL_MAP[self.F_dvpad],
F_mode=MODE_MAP[self.F_mode],
@@ -529,7 +592,7 @@ class FmhaBwdOGradDotOKernel:
return n
pn = pad_name()
n = f"fmha_bwd_dot_do_o_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}"
n = f"fmha_bwd_dot_do_o_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}_{self.F_mode}_o{self.F_occupancy}"
if pn != "":
n += f"_{pn}"
else:
@@ -538,10 +601,14 @@ class FmhaBwdOGradDotOKernel:
@property
def filename(self) -> str:
return self.name + ".cpp"
return f"{self.name}{self.F_arch.filename_suffix}.cpp"
FMHA_BWD_CONVERT_DQ_KERNEL_BODY = """
#include <iostream>
#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_bwd_convert_dq_trait_{F_idx} =
@@ -573,10 +640,8 @@ using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
{F_deterministic},
{F_bn0}>;
#include <iostream>
template <>
float fmha_bwd_convert_dq_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
float fmha_bwd_convert_dq_<convert_dq_trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
if(s.log_level_ > 0)
@@ -585,32 +650,34 @@ float fmha_bwd_convert_dq_<convert_dq_trait_{F_idx}>(const ck_tile::stream_confi
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs));
}}
template <>
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s,
fmha_bwd_args a)
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
template <>
std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_{F_idx}>()
std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_{F_idx}, {F_arch.tag}>()
{{
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
return k_::GetName();
}}
#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
"""
@dataclass(frozen=True)
class FmhaBwdConvertQGradKernel:
F_arch: ArchTrait
F_idx: int # this is not a tunable, but a counter to differentiate symbol
F_hdim: int # hdim
F_dtype: str # data type
@@ -627,6 +694,7 @@ class FmhaBwdConvertQGradKernel:
def template(self) -> str:
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format(
F_idx=self.F_idx,
F_arch=self.F_arch,
F_hdim=self.F_hdim,
F_dtype=BWD_DTYPE_MAP[self.F_dtype],
F_bm0=self.F_bm0,
@@ -664,11 +732,12 @@ class FmhaBwdConvertQGradKernel:
@property
def filename(self) -> str:
return self.name + ".cpp"
return f"{self.name}{self.F_arch.filename_suffix}.cpp"
@dataclass(frozen=True)
class FmhaBwdApiTrait:
arch: ArchTrait
idx: int # this is not a tunable, but a counter to differentiate symbol
# sync with fmha_bwd_traits<>, to generate fallback calls
hdim: int
@@ -705,10 +774,10 @@ class FmhaBwdApiTrait:
@property
def scheck(self) -> str:
if self.mode == "group":
return "true" # always support
return "true /*spad1d is always true in group mode*/"
elif self.spad1d == "t":
return f"a.seqlen_q % {M0_1D} != 0"
else: # self.spad1d == 'f'
return f"true /*a.seqlen_q % {M0_1D} != 0*/"
else: # self.spad1d == "f"
return f"a.seqlen_q % {M0_1D} == 0"
@property
@@ -725,10 +794,17 @@ class FmhaBwdApiTrait:
else:
return f"a.hdim_v % {self.dvpad} == 0"
@property
def max_seq_q_cond(self) -> str:
if self.tile.max_seq_q != 0:
return f" && (a.seqlen_q <= {self.tile.max_seq_q})"
else:
return ""
@property
def extra_cond(self) -> str:
if self.tr_load == "t" and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128:
return "&& (a.seqlen_k <= 256)"
return " && (a.seqlen_k <= 256)"
else:
return ""
@@ -745,9 +821,11 @@ class FmhaBwdApiTrait:
F_dvpad = "t" if self.dvpad else "f"
return FmhaBwdOGradDotOKernel(
F_arch=self.arch,
F_idx=self.idx,
F_hdim=self.hdim,
F_dtype=self.dtype,
F_bm0=M0_1D,
F_spad=self.spad1d,
F_dvpad=F_dvpad,
F_mode=self.mode,
@@ -757,6 +835,7 @@ class FmhaBwdApiTrait:
@property
def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel:
return FmhaBwdDQDKDVKernel(
F_arch=self.arch,
F_idx=self.idx,
F_hdim=self.hdim,
F_dtype=self.dtype,
@@ -782,6 +861,7 @@ class FmhaBwdApiTrait:
F_dpad = "t" if self.dpad else "f"
return FmhaBwdConvertQGradKernel(
F_arch=self.arch,
F_idx=self.idx,
F_hdim=self.hdim,
F_dtype=self.dtype,
@@ -798,28 +878,25 @@ class FmhaBwdApiTrait:
class FmhaBwdApiPool:
def __init__(self, mask_impl):
self.dq_dk_dv_pool = defaultdict(
lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
)
self.dq_dk_dv_pool = OrderedDict()
self.mask_impl = mask_impl
def register_dq_dk_dv_traits(self, trait: FmhaBwdApiTrait) -> None:
# TODO: do we need to check duplication?
self.dq_dk_dv_pool[trait.tr_load][trait.tile.max_seq_q][trait.dtype][
trait.hdim
].append(copy.copy(trait))
hdim = trait.hdim
ts = (
self.dq_dk_dv_pool.setdefault(trait.arch, OrderedDict())
.setdefault(trait.dtype, OrderedDict())
.setdefault(hdim, [])
)
check_duplicates_and_paddings(ts, trait)
ts.append(copy.copy(trait))
@staticmethod
def if_(i: int) -> str:
return "if" if i == 0 else "else if"
def _api_innders(self, traits: List[FmhaBwdApiTrait]) -> str:
def _api_inners(self, traits: List[FmhaBwdApiTrait]) -> str:
inners = ""
i = 0
for trait in traits:
for i_trait, trait in enumerate(traits):
inners += FMHA_BWD_API_INNER_DISPATCH.format(
F_if=self.if_(i),
F_if=if_(i_trait),
F_arch=trait.arch,
F_mode=MODE_MAP[trait.mode],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
F_mask=get_mask_map(self.mask_impl)[trait.mask],
@@ -840,27 +917,18 @@ class FmhaBwdApiPool:
F_trload=BOOL_MAP[trait.tr_load],
F_maxq=trait.tile.max_seq_q,
F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled],
F_bn0=trait.tile.F_bn0,
F_max_seq_q_cond=trait.max_seq_q_cond,
F_cond_extra=trait.extra_cond,
F_bn0=trait.tile.F_bn0,
F_convert_dq_bn0=trait.convert_dq_bn0,
)
i += 1
return inners
@staticmethod
def trload_sort_key(tf):
return 0 if tf == "t" else 1 # sort 't' before 'f'
@staticmethod
def max_seq_q_sort_key(max_seq_q):
return max_seq_q if max_seq_q != 0 else 1000000 # sort 0 to the end
@staticmethod
def max_seq_q_cond(max_seq_q: int) -> str:
if max_seq_q == 0:
return "true /* no seqlen_q limit */"
else:
return f"a.seqlen_q <= {max_seq_q}"
def max_seq_q_sort_key(trait):
return (
trait.tile.max_seq_q if trait.tile.max_seq_q != 0 else 1000000
) # sort 0 to the end
@staticmethod
def dtype_cond(dtype: str) -> str:
@@ -872,42 +940,34 @@ class FmhaBwdApiPool:
@property
def api(self) -> str:
tr_load_cond_map = {"t": "has_load_tr", "f": "true /* no trload requirement */"}
per_tr_load = ""
for tr_load in sorted(self.dq_dk_dv_pool.keys(), key=self.trload_sort_key):
per_max_seq_q = ""
for max_seq_q in sorted(
self.dq_dk_dv_pool[tr_load].keys(), key=self.max_seq_q_sort_key
):
per_dtypes = ""
for j, dtype in enumerate(self.dq_dk_dv_pool[tr_load][max_seq_q]):
per_hdim_case = ""
for k, hdim in enumerate(
self.dq_dk_dv_pool[tr_load][max_seq_q][dtype]
):
traits = self.dq_dk_dv_pool[tr_load][max_seq_q][dtype][hdim]
inners = self._api_innders(traits)
per_hdim_case += FMHA_BWD_API_COND_STATEMENT(
if_=k, F_cond=self.hdim_cond(hdim), F_body=inners
)
per_dtypes += FMHA_BWD_API_COND_STATEMENT(
if_=j, F_cond=self.dtype_cond(dtype), F_body=per_hdim_case
per_arch = ""
for i_arch, (arch, pool_by_arch) in enumerate(self.dq_dk_dv_pool.items()):
per_dtypes = ""
for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()):
per_hdim_case = ""
for i_hdim, (hdim, pool_by_hdim) in enumerate(pool_by_dtype.items()):
traits = sorted(pool_by_hdim, key=self.max_seq_q_sort_key)
inners = self._api_inners(traits)
per_hdim_case += FMHA_BWD_API_COND_STATEMENT(
if_i=i_hdim, F_cond=self.hdim_cond(hdim), F_body=inners
)
per_max_seq_q += FMHA_BWD_API_COND_STATEMENT(
F_cond=self.max_seq_q_cond(max_seq_q), F_body=per_dtypes
per_dtypes += FMHA_BWD_API_COND_STATEMENT(
if_i=i_dtype, F_cond=self.dtype_cond(dtype), F_body=per_hdim_case
)
per_tr_load += FMHA_BWD_API_COND_STATEMENT(
F_cond=tr_load_cond_map[tr_load], F_body=per_max_seq_q, indent=4
per_arch += FMHA_BWD_API_COND_STATEMENT(
if_i=i_arch, F_cond=arch.device_name_check, F_body=per_dtypes
)
if not per_tr_load:
if not per_arch:
# empty string we add some ignore to suppress warning in api
per_tr_load += " (void)t ; (void)s ; (void)a; (void)has_load_tr;"
result = FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch=per_tr_load)
per_arch = "(void)t; (void)s; (void)a;"
result = FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(
F_dispatch=indent(per_arch)
)
return result.replace("\n\n", "\n")
def get_bwd_blobs(
filter_list: str, receipt, mask_impl, optdim_list
targets: List[str], filter_list: str, receipt, mask_impl, optdim_list
) -> Tuple[
FmhaBwdApiPool,
List[FmhaBwdOGradDotOKernel],
@@ -922,14 +982,19 @@ def get_bwd_blobs(
filter_convert_dq = filters[1]
filter_dq_dk_dv = filters[2]
factories = get_factories_for_targets(targets, get_factory)
# use dict as ordered set
gen_dot_do_o: Dict[FmhaBwdOGradDotOKernel, Literal[True]] = {}
gen_dq_dk_dv: Dict[FmhaBwdDQDKDVKernel, Literal[True]] = {}
gen_convert_dq: Dict[FmhaBwdConvertQGradKernel, Literal[True]] = {}
gen_dot_do_o: Dict[FmhaBwdOGradDotOKernel, Literal[True]] = OrderedDict()
gen_dq_dk_dv: Dict[FmhaBwdDQDKDVKernel, Literal[True]] = OrderedDict()
gen_convert_dq: Dict[FmhaBwdConvertQGradKernel, Literal[True]] = OrderedDict()
api_pool = FmhaBwdApiPool(mask_impl)
for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]):
tiles: Any = get_dq_dk_dv_tiles(dtype, tr_load)
for factory, dtype, tr_load in itertools.product(
factories, BWD_DTYPE_MAP.keys(), ["t", "f"]
):
tiles: Any = factory.get_dq_dk_dv_tiles(dtype, tr_load)
spad1d_options = ["f", "t"]
dpad_options = itertools.product(*([[0, 8, 1]] * 2))
tf = ["t", "f"]
for tile, mode, mask, bias, dbias, dropout, spad1d, (
@@ -942,7 +1007,7 @@ def get_bwd_blobs(
BIAS_MAP.keys(),
tf,
DROPOUT_MAP.keys(),
tf,
spad1d_options,
dpad_options,
tf,
):
@@ -958,6 +1023,8 @@ def get_bwd_blobs(
continue
if "wg32" in dropout:
continue
if spad1d == "f" and tile.max_seq_q != 0 and tile.max_seq_q < M0_1D:
continue # max_seq_q < M0_1D requires padding
if tr_load == "t":
# tr_load can only work with 8 pad
if dpad != dvpad or dpad == 1:
@@ -970,6 +1037,7 @@ def get_bwd_blobs(
if hdim not in optdim_list:
continue
t = FmhaBwdApiTrait(
arch=factory.arch,
idx=0,
hdim=hdim,
dtype=dtype,
@@ -989,10 +1057,10 @@ def get_bwd_blobs(
if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o):
continue
if not fnmatch.fnmatch(t.dq_dk_dv_kernel.name, filter_dq_dk_dv):
continue
if not fnmatch.fnmatch(t.convert_dq_kernel.name, filter_convert_dq):
continue
if not fnmatch.fnmatch(t.dq_dk_dv_kernel.name, filter_dq_dk_dv):
continue
# Flash attention integration
if receipt == 2:
@@ -1076,10 +1144,15 @@ def get_bwd_blobs(
def write_blobs(
output_dir: Path, filter_list: str, receipt, optdim_list, mask_impl
targets: List[str],
output_dir: Path,
filter_list: str,
receipt,
optdim_list,
mask_impl,
) -> None:
api_pool, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs(
filter_list, receipt, mask_impl, optdim_list
targets, filter_list, receipt, mask_impl, optdim_list
)
update_file(output_dir / FMHA_BWD_API_FILENAME, api_pool.api)
for k in kernels_dot_do_o:
@@ -1091,10 +1164,15 @@ def write_blobs(
def list_blobs(
file_path: Path, filter_list: str, receipt, optdim_list, mask_impl
targets: List[str],
file_path: Path,
filter_list: str,
receipt,
optdim_list,
mask_impl,
) -> None:
_, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs(
filter_list, receipt, mask_impl, optdim_list
targets, filter_list, receipt, mask_impl, optdim_list
)
with file_path.open("a") as f:
for k in kernels_dot_do_o:

View File

@@ -1,15 +1,17 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass, field
import fnmatch
import itertools
import os
from collections import OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple
from codegen.arch import ArchTrait, get_factories_for_targets
from codegen.cmake_config import GEN_DIR
from codegen.cpp_symbol_map import (
LAYOUT_MAP,
@@ -23,7 +25,7 @@ from codegen.cpp_symbol_map import (
BIAS_MAP,
get_mask_map,
)
from codegen.utils import update_file
from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8}
@@ -31,13 +33,17 @@ DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8}
K0_MAX_SUBMAX_MAP = {32: 32, 48: 48, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256}
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "fmha_fwd.hpp"
"""
FMHA_FWD_KERNEL_BODY = """
#include <iostream>
#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
@@ -99,10 +105,8 @@ using fmha_kernel_{F_idx} =
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
#include <iostream>
template<>
float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
float fmha_fwd_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
@@ -110,8 +114,10 @@ float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs));
}}
#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
"""
FMHA_FWD_API_FILENAME = "fmha_fwd_api.cpp"
@@ -148,13 +154,13 @@ unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seq
}}
}} // namespace
float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s) {{
float r = -1;
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
unsigned num_cus;
if (!get_num_cus(num_cus)) {{
if(!get_num_cus(num_cus)) {{
return r;
}}
@@ -162,32 +168,33 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config&
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
}};
[[maybe_unused]] const bool has_load_tr = ck_tile::is_load_tr_supported();
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
{F_dispatch}
return r;
}}
"""
FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{
FMHA_FWD_API_PER_ARCH = """{F_if}({F_arch.device_name_check}) {{
{F_dtype_case}
}}
}}
"""
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
FMHA_FWD_API_PER_DTYPE = """{F_if}(t.data_type.compare(\"{F_dtype}\") == 0) {{
{F_hdim_case}
}}
"""
FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
}}
}}
"""
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
return fmha_fwd_<trait_>(s, a);
}}
FMHA_FWD_API_PER_HDIM_CASE = """{F_if}(t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
}}
"""
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
return fmha_fwd_<trait_, {F_arch.tag}>(s, a);
}}
"""
@@ -207,6 +214,7 @@ class CppConstraint:
@dataclass
class FmhaFwdApiTrait:
arch: ArchTrait
pipeline_tag: str
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim: str
@@ -413,40 +421,35 @@ class FmhaFwdPipeline:
class FmhaFwdApiPool:
def __init__(self, mask_impl):
self.pool = dict()
self.pool = OrderedDict()
self.mask_impl = mask_impl
def register_traits(self, trait: FmhaFwdApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.pool.keys():
self.pool[trait.dtype] = dict()
hdim = trait.hdim, trait.bn1
if hdim not in self.pool[trait.dtype].keys():
self.pool[trait.dtype][hdim] = list()
self.pool[trait.dtype][hdim].append(copy.copy(trait))
ts = (
self.pool.setdefault(trait.arch, OrderedDict())
.setdefault(trait.dtype, OrderedDict())
.setdefault(hdim, [])
)
check_duplicates_and_paddings(ts, trait)
ts.append(copy.copy(trait))
@property
def api(self) -> str:
tr_load_cond_map = {"t": "has_load_tr", "f": "true"}
per_tr_load = str()
for tr_load in ["t", "f"]:
per_arch = str()
for i_arch, (arch, pool_by_arch) in enumerate(self.pool.items()):
per_dtypes = str()
for i, dtype in enumerate(self.pool.keys()):
for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()):
per_hdim_case = str()
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
traits = [
t
for t in self.pool[dtype][(hdim, hdim_v)]
if tr_load == t.tr_load
]
max_bm0 = max((t.bm0 for t in traits), default=0)
for i_hdim, ((hdim, hdim_v), pool_by_hdim) in enumerate(
pool_by_dtype.items()
):
max_bm0 = max((t.bm0 for t in pool_by_hdim), default=0)
inners = str()
for k, trait in enumerate(traits):
if_k = "if" if k == 0 else "else if"
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
F_if=if_k,
for i_trait, trait in enumerate(pool_by_hdim):
inners += FMHA_FWD_API_INNER_DISPATCH.format(
F_if=if_(i_trait),
F_arch=arch,
F_mode=MODE_MAP[trait.mode],
F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
@@ -479,23 +482,24 @@ class FmhaFwdApiPool:
F_hdim=hdim,
F_dtype=FWD_DTYPE_MAP[dtype],
)
if_j = "if" if j == 0 else "else if"
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners
per_hdim_case += FMHA_FWD_API_PER_HDIM_CASE.format(
F_if=if_(i_hdim),
F_hdim=hdim,
F_hdim_v=hdim_v,
F_inner_dispatch=indent(inners),
)
if_i = "if" if i == 0 else "else if"
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
per_dtypes += FMHA_FWD_API_PER_DTYPE.format(
F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case)
)
per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(
F_if="if",
F_trload_cond=tr_load_cond_map[tr_load],
F_dtype_case=per_dtypes,
per_arch += FMHA_FWD_API_PER_ARCH.format(
F_if=if_(i_arch),
F_arch=arch,
F_dtype_case=indent(per_dtypes),
)
if not per_tr_load:
if not per_arch:
# empty string we add some ignore to suppress warning in api
per_tr_load += " (void)t ; (void)s ; (void)a;"
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load)
per_arch = "(void)t; (void)s; (void)a;"
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=indent(per_arch))
@dataclass
@@ -533,6 +537,7 @@ class FmhaFwdTileSize:
@dataclass
class FmhaFwdKernel:
F_arch: ArchTrait
F_idx: int # this is not a tunable, but a counter to differentiate symbol
F_hdim: int # hdim
F_dtype: str # data type
@@ -545,6 +550,7 @@ class FmhaFwdKernel:
def template(self) -> str:
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
F_idx=self.F_idx,
F_arch=self.F_arch,
F_hdim=self.F_hdim,
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
F_bm0=self.F_tile.F_bm0,
@@ -596,10 +602,11 @@ class FmhaFwdKernel:
@property
def filename(self) -> str:
return self.name + ".cpp"
return f"{self.name}{self.F_arch.filename_suffix}.cpp"
def api_trait(self) -> FmhaFwdApiTrait:
return FmhaFwdApiTrait(
arch=self.F_arch,
pipeline_tag=self.F_pipeline.tag,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
@@ -627,12 +634,16 @@ class FmhaFwdKernel:
)
class KernelComponentFactory:
class KernelComponentFactoryGfx9:
arch = ArchTrait(
"gfx9", preprocessor_check="defined(__gfx9__) && !defined(__gfx950__)"
)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
@staticmethod
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
if dtype == "fp32":
if dtype in ["fp32"]:
return {
# bm0, bn0, bk0, bn1, bk1,
( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
@@ -645,10 +656,10 @@ class KernelComponentFactory:
(192, 192) : [FmhaFwdTileSize( 64, 64, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
} # fmt: skip
elif dtype == "fp16" or dtype == "bf16":
elif dtype in ["fp16", "bf16"]:
return {
( 32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
( 64, 64) : [FmhaFwdTileSize( 16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
( 32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
( 64, 64) : [FmhaFwdTileSize( 16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
FmhaFwdTileSize( 32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
@@ -656,18 +667,18 @@ class KernelComponentFactory:
FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
# (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
# (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
(192, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(192, 192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
(256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
} # fmt: skip
elif dtype == "fp8" or dtype == "fp8bf16":
elif dtype in ["fp8", "fp8bf16"]:
return {
( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
(128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
(256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
} # fmt: skip
elif dtype == "fp8fp32":
elif dtype in ["fp8fp32"]:
return {
(128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
} # fmt: skip
@@ -680,7 +691,7 @@ class KernelComponentFactory:
def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let 't' padding to appear later!!
# TODO: currently for qr pipeline, let "t" padding to appear later!!
# TODO: how to design this more generic?
pipelines = []
if dtype in ["fp32"]:
@@ -719,18 +730,8 @@ class KernelComponentFactory:
else:
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
if (
(hdim, hdim_v) in [(64, 64), (128, 128)]
and logits == "f"
and bias == "no"
and dropout == "f"
and skip == "f"
):
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip
if receipt == 1 and bias != "bias":
pipelines.append(FmhaFwdPipeline( "qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip
elif dtype in ["fp8", "fp8bf16", "fp8fp32"]:
# no need lse/dropout kernels
for logits, squant, mask, bias in itertools.product(
@@ -746,29 +747,128 @@ class KernelComponentFactory:
return pipelines
class CustomFactory(KernelComponentFactory):
class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9):
arch = ArchTrait("gfx950")
@staticmethod
def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
pipelines = KernelComponentFactoryGfx9.get_pipelines(
dtype, hdim, hdim_v, receipt, mask_impl
)
if dtype in ["fp16", "bf16"]:
squant = "f"
for logits, mask, bias, lse, dropout, skip in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t", "f"],
["t", "f"],
["t", "f"],
):
if (
(hdim, hdim_v) in [(64, 64), (128, 128)]
and logits == "f"
and bias == "no"
and dropout == "f"
and skip == "f"
):
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip
return pipelines
class KernelComponentFactoryGfx12:
arch = ArchTrait("gfx12")
@staticmethod
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
if dtype in ["fp16", "bf16"]:
return {
# bm0, bn0, bk0, bn1, bk1,
( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
} # fmt: skip
elif dtype in ["fp8", "fp8bf16"]:
return {
# bm0, bn0, bk0, bn1, bk1,
( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(256, 256) : [FmhaFwdTileSize( 64, 32, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
} # fmt: skip
elif dtype in ["fp8fp32"]:
return {
# bm0, bn0, bk0, bn1, bk1,
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
} # fmt: skip
else:
return None
@staticmethod
def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
pipelines = []
if dtype in ["fp16", "bf16"]:
squant = "f"
for logits, mask, bias, lse, dropout, skip in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t", "f"],
["t", "f"],
["t", "f"],
):
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
elif dtype in ["fp8", "fp8bf16", "fp8fp32"]:
# no need lse/dropout kernels
for logits, squant, mask, bias in itertools.product(
["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
):
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip
else:
assert False
return pipelines
class CustomFactory(KernelComponentFactoryGfx9):
@staticmethod
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype)
if dtype == "fp16" or dtype == "bf16":
if (128, 128) in result.keys():
result[(128, 128)].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip
return result
def get_factory(target: str):
if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1":
return CustomFactory
# Place more specific architectures first
if target.startswith("gfx950"):
return KernelComponentFactoryGfx950
if target.startswith("gfx9"):
return KernelComponentFactoryGfx9
if target.startswith("gfx12"):
return KernelComponentFactoryGfx12
raise Exception(f"Unsupported device target {target}")
def get_fwd_blobs(
kernel_filter: Optional[str], receipt, optdim_list, mask_impl
targets: List[str], kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
gen = list()
api_pool = FmhaFwdApiPool(mask_impl)
factory = (
CustomFactory
if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1"
else KernelComponentFactory
)
factories = get_factories_for_targets(targets, get_factory)
for dtype in FWD_DTYPE_MAP.keys():
for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()):
d = factory.get_hdim_tile_size_dict(dtype)
if d is None:
continue
@@ -791,7 +891,8 @@ def get_fwd_blobs(
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
if pipeline.F_bias != "no" or pipeline.F_dropout == "t":
continue
if dtype != "fp32":
if factory.arch.name.startswith("gfx9") and dtype != "fp32":
# TODO: update if >=gfx11 archs get qr_async and qr_async_trload support
if pipeline.tag != "qr_async_trload" and (
((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128)
or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)
@@ -811,6 +912,7 @@ def get_fwd_blobs(
):
continue
k = FmhaFwdKernel(
F_arch=factory.arch,
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
@@ -918,19 +1020,33 @@ def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
def write_blobs(
output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
targets: List[str],
output_dir: Path,
kernel_filter: str,
receipt,
optdim_list,
mask_impl,
) -> None:
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
api_pool, kernels = get_fwd_blobs(
targets, kernel_filter, receipt, optdim_list, mask_impl
)
for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir)
write_fwd_api(api_pool, output_dir)
def list_blobs(
file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
targets: List[str],
file_path: Path,
kernel_filter: str,
receipt,
optdim_list,
mask_impl,
) -> None:
with file_path.open("a") as f:
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
_, kernels = get_fwd_blobs(
targets, kernel_filter, receipt, optdim_list, mask_impl
)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")

View File

@@ -1,13 +1,16 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass
import fnmatch
import itertools
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
from codegen.arch import ArchTrait, get_factories_for_targets
from codegen.cmake_config import GEN_DIR
from codegen.cpp_symbol_map import (
FWD_DTYPE_MAP,
@@ -16,16 +19,21 @@ from codegen.cpp_symbol_map import (
LAYOUT_MAP,
ROPE_CHECK_MAP,
)
from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file
from codegen.ops.fmha_fwd import (
FmhaFwdApiTrait,
FMHA_FWD_KERNEL_HEADER,
FMHA_FWD_API_PER_ARCH,
FMHA_FWD_API_PER_DTYPE,
FMHA_FWD_API_PER_HDIM_CASE,
)
FMHA_FWD_APPENDKV_KERNEL_BODY = """
#include <iostream>
#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad},
@@ -55,10 +63,8 @@ using fmha_kernel_{F_idx} = ck_tile::FmhaFwdAppendKVKernel<fmha_pipeline_{F_idx}
using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
#include <iostream>
template<>
float fmha_fwd_appendkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_appendkv_args a)
float fmha_fwd_appendkv_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_appendkv_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
@@ -66,31 +72,37 @@ float fmha_fwd_appendkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fw
auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs));
}}
#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
"""
FMHA_FWD_APPENDKV_API_FILENAME = "fmha_fwd_appendkv_api.cpp"
FMHA_FWD_APPENDKV_API = """
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, const ck_tile::stream_config& s){{
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, const ck_tile::stream_config& s) {{
float r = -1;
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
{F_dispatch}
return r;
}}
"""
FMHA_FWD_APPENDKV_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv})) {{
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
return fmha_fwd_appendkv_<trait_>(s, a);
}}
FMHA_FWD_APPENDKV_API_INNER_DISPATCH = """{F_if}((t.is_v_rowmajor == {F_vlayout}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv})) {{
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
return fmha_fwd_appendkv_<trait_, {F_arch.tag}>(s, a);
}}
"""
@dataclass
class FmhaFwdAppendKVApiTrait:
# sync with fmha_fwd_traits<>, to generate fallback calls
arch: ArchTrait
# sync with fmha_fwd_appendkv_traits, to generate fallback calls
hdim: str
dtype: str # data type
bs: int # tile size along q seqlen
@@ -178,62 +190,70 @@ class FmhaFwdAppendKVPipeline:
class FmhaFwdAppendKVApiPool:
def __init__(self, mask_impl):
self.pool = dict()
self.pool = OrderedDict()
self.mask_impl = mask_impl
def register_traits(self, trait: FmhaFwdApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.pool.keys():
self.pool[trait.dtype] = dict()
if trait.hdim not in self.pool[trait.dtype].keys():
self.pool[trait.dtype][trait.hdim] = list()
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
def register_traits(self, trait: FmhaFwdAppendKVApiTrait) -> None:
hdim = trait.hdim
ts = (
self.pool.setdefault(trait.arch, OrderedDict())
.setdefault(trait.dtype, OrderedDict())
.setdefault(hdim, [])
)
check_duplicates_and_paddings(ts, trait)
ts.append(copy.copy(trait))
@property
def api(self) -> str:
per_dtypes = str()
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case = str()
for j, hdim in enumerate(self.pool[dtype].keys()):
traits = self.pool[dtype][hdim]
inners = str()
for k, trait in enumerate(traits):
if_k = "if" if k == 0 else "else if"
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(
F_if=if_k,
F_vlayout=LAYOUT_MAP[trait.vlayout],
F_scheck=trait.scheck,
F_skcheck=trait.skcheck,
F_dcheck=trait.dcheck,
F_dvcheck=trait.dvcheck,
F_rope_check=ROPE_CHECK_MAP[trait.rope],
F_pagedkv=BOOL_MAP[trait.pagedkv],
F_spad=BOOL_MAP[trait.spad],
F_skpad=BOOL_MAP[trait.skpad],
F_dpad=BOOL_MAP[trait.dpad],
F_dvpad=BOOL_MAP[trait.dvpad],
F_rope=ROPE_MAP[trait.rope],
F_bs=trait.bs,
F_bsk=trait.bsk,
F_bd=trait.bd,
F_bdv=trait.bdv,
per_arch = str()
for i_arch, (arch, pool_by_arch) in enumerate(self.pool.items()):
per_dtypes = str()
for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()):
per_hdim_case = str()
for i_hdim, (hdim, pool_by_hdim) in enumerate(pool_by_dtype.items()):
inners = str()
for i_trait, trait in enumerate(pool_by_hdim):
inners += FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(
F_if=if_(i_trait),
F_arch=arch,
F_vlayout=LAYOUT_MAP[trait.vlayout],
F_scheck=trait.scheck,
F_skcheck=trait.skcheck,
F_dcheck=trait.dcheck,
F_dvcheck=trait.dvcheck,
F_rope_check=ROPE_CHECK_MAP[trait.rope],
F_pagedkv=BOOL_MAP[trait.pagedkv],
F_spad=BOOL_MAP[trait.spad],
F_skpad=BOOL_MAP[trait.skpad],
F_dpad=BOOL_MAP[trait.dpad],
F_dvpad=BOOL_MAP[trait.dvpad],
F_rope=ROPE_MAP[trait.rope],
F_bs=trait.bs,
F_bsk=trait.bsk,
F_bd=trait.bd,
F_bdv=trait.bdv,
F_hdim=hdim,
F_dtype=FWD_DTYPE_MAP[dtype],
)
per_hdim_case += FMHA_FWD_API_PER_HDIM_CASE.format(
F_if=if_(i_hdim),
F_hdim=hdim,
F_dtype=FWD_DTYPE_MAP[dtype],
F_hdim_v=hdim,
F_inner_dispatch=indent(inners),
)
if_j = "if" if j == 0 else "else if"
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners
per_dtypes += FMHA_FWD_API_PER_DTYPE.format(
F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case)
)
if_i = "if" if i == 0 else "else if"
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
per_arch += FMHA_FWD_API_PER_ARCH.format(
F_if=if_(i_arch),
F_arch=arch,
F_dtype_case=indent(per_dtypes),
)
if not per_dtypes:
if not per_arch:
# empty string we add some ignore to suppress warning in api
per_dtypes += " (void)t ; (void)s ; (void)a;"
per_arch = "(void)t; (void)s; (void)a;"
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(
F_dispatch=per_dtypes
F_dispatch=indent(per_arch)
)
@@ -254,6 +274,7 @@ class FmhaFwdAppendKVTileSize:
@dataclass
class FmhaFwdAppendKVKernel:
F_arch: ArchTrait
F_idx: int # this is not a tunable, but a counter to differentiate symbol
F_hdim: int # hdim
F_dtype: str # data type
@@ -265,6 +286,7 @@ class FmhaFwdAppendKVKernel:
def template(self) -> str:
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_KERNEL_BODY.format(
F_idx=self.F_idx,
F_arch=self.F_arch,
F_hdim=self.F_hdim,
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
F_bs=self.F_tile.F_bs,
@@ -293,10 +315,11 @@ class FmhaFwdAppendKVKernel:
@property
def filename(self) -> str:
return self.name + ".cpp"
return f"{self.name}{self.F_arch.filename_suffix}.cpp"
def api_trait(self) -> FmhaFwdAppendKVApiTrait:
return FmhaFwdAppendKVApiTrait(
arch=self.F_arch,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
bs=self.F_tile.F_bs,
@@ -313,31 +336,26 @@ class FmhaFwdAppendKVKernel:
)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype: str) -> Optional[dict]:
if dtype == "fp16" or dtype == "bf16":
return {
"32": FmhaFwdAppendKVTileSize(64, 64, 32, 32, -1),
"64": FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
"128": FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
"256": FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
}
elif dtype == "fp8" or dtype == "bf8":
return {
"64": FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
"128": FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
"256": FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
}
else:
return None
class KernelComponentFactoryBase:
@staticmethod
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
if dtype in ["fp16", "bf16"]:
return {
"32": FmhaFwdAppendKVTileSize(64, 64, 32, 32, -1),
"64": FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
"128": FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
"256": FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
}
elif dtype in ["fp8", "bf8"]:
return {
"64": FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
"128": FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
"256": FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
}
else:
return None
def get_fwd_appendkv_blobs(
kernel_filter: Optional[str], receipt, mask_impl, optdim_list
) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
@staticmethod
def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
@@ -347,19 +365,18 @@ def get_fwd_appendkv_blobs(
if dtype in ["fp16", "bf16"]:
# NOTICE: it will be very complicated if we consider all the hdim_q padding cases while
# applying rotary embedding, so I just use 't' in inter/half pipelines
for vlayout in ["row", "col"]:
for pagedkv in ["t", "f"]:
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "f", "f", "no", pagedkv)) # fmt: skip
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "no", pagedkv)) # fmt: skip
for vlayout, pagedkv in itertools.product(["row"], ["t", "f"]):
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "f", "f", "no", pagedkv)) # fmt: skip
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "no", pagedkv)) # fmt: skip
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "t", "f", "inter", pagedkv)) # fmt: skip
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "inter", pagedkv)) # fmt: skip
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "t", "f", "inter", pagedkv)) # fmt: skip
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "inter", pagedkv)) # fmt: skip
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "t", "f", "half", pagedkv)) # fmt: skip
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "half", pagedkv)) # fmt: skip
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "t", "f", "half", pagedkv)) # fmt: skip
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "half", pagedkv)) # fmt: skip
elif dtype in ["fp8", "bf8"]:
# rope/paged-kv is not supported
pipelines.append(FmhaFwdAppendKVPipeline("col", "t", "t", "t", "t", "no", "f")) # fmt: skip
pipelines.append(FmhaFwdAppendKVPipeline("row", "t", "t", "t", "t", "no", "f")) # fmt: skip
elif dtype in ["fp8fp16", "fp8bf16"]:
# TODO
None
@@ -367,18 +384,45 @@ def get_fwd_appendkv_blobs(
assert False
return pipelines
class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
arch = ArchTrait("gfx9")
class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
arch = ArchTrait("gfx12")
def get_factory(target: str):
# Place more specific architectures first
if target.startswith("gfx9"):
return KernelComponentFactoryGfx9
if target.startswith("gfx12"):
return KernelComponentFactoryGfx12
raise Exception(f"Unsupported device target {target}")
def get_fwd_appendkv_blobs(
targets: List[str], kernel_filter: Optional[str], receipt, mask_impl, optdim_list
) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]:
gen = list()
api_pool = FmhaFwdAppendKVApiPool(mask_impl)
for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype)
factories = get_factories_for_targets(targets, get_factory)
for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()):
d = factory.get_hdim_tile_size_dict(dtype)
if d is None:
continue
for hdim_str in d.keys():
tile = d[hdim_str]
hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim):
for pipeline in factory.get_pipelines(dtype, hdim):
k = FmhaFwdAppendKVKernel(
F_arch=factory.arch,
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
@@ -418,18 +462,23 @@ def get_fwd_appendkv_blobs(
def write_single_kernel(kernel: FmhaFwdAppendKVKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
update_file(autogen_dir / kernel.filename, kernel.template)
def write_fwd_appendkv_api(api_pool: FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api)
update_file(autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME, api_pool.api)
def write_blobs(
output_dir: Path, kernel_filter: Optional[str], receipt, optdim_list, mask_impl
targets: List[str],
output_dir: Path,
kernel_filter: Optional[str],
receipt,
optdim_list,
mask_impl,
) -> None:
api_pool, kernels = get_fwd_appendkv_blobs(
kernel_filter, receipt, mask_impl, optdim_list
targets, kernel_filter, receipt, mask_impl, optdim_list
)
for kernel in kernels:
write_single_kernel(kernel, output_dir)
@@ -437,11 +486,16 @@ def write_blobs(
def list_blobs(
file_path: Path, kernel_filter: Optional[str], receipt, optdim_list, mask_impl
targets: List[str],
file_path: Path,
kernel_filter: Optional[str],
receipt,
optdim_list,
mask_impl,
) -> None:
with file_path.open("a") as f:
_, kernels = get_fwd_appendkv_blobs(
kernel_filter, receipt, mask_impl, optdim_list
targets, kernel_filter, receipt, mask_impl, optdim_list
)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")

View File

@@ -3,12 +3,14 @@
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass
import fnmatch
import itertools
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Union
from codegen.arch import ArchTrait, get_factories_for_targets
from codegen.cmake_config import GEN_DIR
from codegen.cpp_symbol_map import (
PIPELINE_ENUM_MAP,
@@ -21,32 +23,29 @@ from codegen.cpp_symbol_map import (
get_mask_map,
BOOL_MAP,
)
from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file
from codegen.ops.fmha_fwd import (
FmhaFwdTileSize,
DTYPE_BITS,
K0_MAX_SUBMAX_MAP,
FMHA_FWD_KERNEL_HEADER,
FMHA_FWD_API_PER_ARCH,
FMHA_FWD_API_PER_DTYPE,
FMHA_FWD_API_PER_HDIM_CASE,
)
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8}
K0_MAX_SUBMAX_MAP = {
32: 32,
64: 64,
96: 128,
128: 128,
# 160: 160,
256: 256,
}
FMHA_FWD_SPLITKV_PIPELINE_MAP = {
"qr": "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS",
"qr_nwarp_sshuffle": "ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS",
}
FMHA_FWD_SPLITKV_KERNEL_BODY = """
#include <iostream>
#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
using fmha_mask_{F_idx} = {F_mask};
@@ -113,17 +112,15 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
}}
}};
ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
}}
}}; // struct instance
}} // anonymous namespace
using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
{F_dvpad}>;
#include <iostream>
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wtautological-compare"
@@ -147,7 +144,7 @@ void run_instance(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{
#pragma clang diagnostic pop
template<>
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
if constexpr({F_mode} == false) {{ // batch mode
// we don't check every seqlen_k values for kvcache
@@ -165,14 +162,20 @@ void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, f
}}
template<>
std::string fmha_fwd_splitkv_get_name_<trait_{F_idx}>()
std::string fmha_fwd_splitkv_get_name_<trait_{F_idx}, {F_arch.tag}>()
{{
using k_ = instance<true>::fmha_kernel; /// FIXME: choose real kernel type
return k_::GetName();
}}
#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
"""
FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY = """
#include <iostream>
#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
using fmha_dtype_{F_idx} = {F_dtype};
namespace {{
@@ -213,18 +216,16 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
}}
}};
ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
}}
}}; // struct instance
}} // anonymous namespace
using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1},
{F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
#include <iostream>
template<>
void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
if (a.num_splits <= 8) {{
instance<3>::run(s, a);
@@ -240,73 +241,79 @@ void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_conf
}}
template<>
std::string fmha_fwd_splitkv_combine_get_name_<trait_{F_idx}>()
std::string fmha_fwd_splitkv_combine_get_name_<trait_{F_idx}, {F_arch.tag}>()
{{
using k_ = instance<6>::fmha_kernel; /// FIXME: choose real kernel type
return k_::GetName();
}}
#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
"""
FMHA_FWD_SPLITKV_API_FILENAME = "fmha_fwd_splitkv_api.cpp"
FMHA_FWD_SPLITKV_API = """
#include <iostream>
template<typename fmha_fwd_splitkv_traits_, typename fmha_fwd_splitkv_combine_traits_>
template<typename fmha_fwd_splitkv_traits_, typename fmha_fwd_splitkv_combine_traits_, typename Arch>
float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
if(s.log_level_ > 0)
std::cout
<< ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_traits_>()
<< ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_traits_>()
<< std::flush;
std::cout
<< ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_traits_, Arch>()
<< ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_traits_, Arch>()
<< std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); }}
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_, Arch>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_, Arch>(s_, a); }}
);
}}
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s){{
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s) {{
float r = -1;
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
{F_dispatch}
return r;
}}
"""
FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
// get combine kernel tile sizes
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType, /*F_bn1=*/32>::kM0;
// get combine kernel tile sizes
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType, {F_bn1comb}>::kM0;
// make sure we can reuse the padding flags in combine kernels
static_assert({F_bm0} % kM0 == 0);
static_assert({F_bn1} % 32 == 0);
// make sure we can reuse the padding flags in combine kernels
static_assert({F_bm0} % kM0 == 0);
static_assert({F_bn1} % {F_bn1comb} == 0);
if (t.has_lse) {{
if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{
return -1;
}} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>;
if (t.has_lse) {{
if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{
return -1;
}} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1comb}, true, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
}}
}} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, false, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_, {F_arch.tag}>(s, a);
}}
}} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1comb}, false, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
}}
}}
return fmha_fwd_splitkv_<traits_, traits2_, {F_arch.tag}>(s, a);
}}
}}
"""
@dataclass
class FmhaFwdSplitKVApiTrait:
arch: ArchTrait
pipeline_tag: str
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim: str
hdim: int
dtype: str # data type
mode: str # value from MODE_MAP
bm0: int # tile size along q seqlen (block size)
@@ -326,6 +333,7 @@ class FmhaFwdSplitKVApiTrait:
dpad: str
dvpad: str
pagedkv: str
bn1comb: int # tile size along v head_dim of combine kernel
@property
def name(self) -> str:
@@ -523,71 +531,80 @@ class FmhaFwdSplitKVCombinePipeline:
class FmhaFwdSplitKVApiPool:
def __init__(self, mask_impl):
self.pool = dict()
self.pool = OrderedDict()
self.mask_impl = mask_impl
def register_traits(self, trait: FmhaFwdSplitKVApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.pool.keys():
self.pool[trait.dtype] = dict()
if trait.hdim not in self.pool[trait.dtype].keys():
self.pool[trait.dtype][trait.hdim] = list()
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
hdim = trait.hdim
ts = (
self.pool.setdefault(trait.arch, OrderedDict())
.setdefault(trait.dtype, OrderedDict())
.setdefault(hdim, [])
)
check_duplicates_and_paddings(ts, trait)
ts.append(copy.copy(trait))
@property
def api(self) -> str:
per_dtypes = str()
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case = str()
for j, hdim in enumerate(self.pool[dtype].keys()):
traits = self.pool[dtype][hdim]
inners = str()
for k, trait in enumerate(traits):
if_k = "if" if k == 0 else "else if"
inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(
F_if=if_k,
F_mode=MODE_MAP[trait.mode],
F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
F_logits=BOOL_MAP[trait.logits],
F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
F_bias_check=BIAS_CHECK_MAP[trait.bias],
F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse],
F_squant=BOOL_MAP[trait.squant],
F_pagedkv=BOOL_MAP[trait.pagedkv],
F_scheck=trait.scheck,
F_skcheck=trait.skcheck,
F_dcheck=trait.dcheck,
F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad],
F_skpad=BOOL_MAP[trait.skpad],
F_dpad=BOOL_MAP[trait.dpad],
F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0,
F_bn0=trait.bn0,
F_bk0=trait.bk0,
F_bn1=trait.bn1,
F_bk1=trait.bk1,
F_bk0max=trait.bk0max,
per_arch = str()
for i_arch, (arch, pool_by_arch) in enumerate(self.pool.items()):
per_dtypes = str()
for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()):
per_hdim_case = str()
for i_hdim, (hdim, pool_by_hdim) in enumerate(pool_by_dtype.items()):
inners = str()
for i_trait, trait in enumerate(pool_by_hdim):
inners += FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(
F_if=if_(i_trait),
F_arch=arch,
F_mode=MODE_MAP[trait.mode],
F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
F_logits=BOOL_MAP[trait.logits],
F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
F_bias_check=BIAS_CHECK_MAP[trait.bias],
F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse],
F_squant=BOOL_MAP[trait.squant],
F_pagedkv=BOOL_MAP[trait.pagedkv],
F_scheck=trait.scheck,
F_skcheck=trait.skcheck,
F_dcheck=trait.dcheck,
F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad],
F_skpad=BOOL_MAP[trait.skpad],
F_dpad=BOOL_MAP[trait.dpad],
F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0,
F_bn0=trait.bn0,
F_bk0=trait.bk0,
F_bn1=trait.bn1,
F_bk1=trait.bk1,
F_bk0max=trait.bk0max,
F_hdim=hdim,
F_dtype=FWD_DTYPE_MAP[dtype],
F_bn1comb=trait.bn1comb,
)
per_hdim_case += FMHA_FWD_API_PER_HDIM_CASE.format(
F_if=if_(i_hdim),
F_hdim=hdim,
F_dtype=FWD_DTYPE_MAP[dtype],
F_hdim_v=hdim,
F_inner_dispatch=indent(inners),
)
if_j = "if" if j == 0 else "else if"
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners
per_dtypes += FMHA_FWD_API_PER_DTYPE.format(
F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case)
)
if_i = "if" if i == 0 else "else if"
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
per_arch += FMHA_FWD_API_PER_ARCH.format(
F_if=if_(i_arch),
F_arch=arch,
F_dtype_case=indent(per_dtypes),
)
if not per_dtypes:
if not per_arch:
# empty string we add some ignore to suppress warning in api
per_dtypes += " (void)t ; (void)s ; (void)a;"
per_arch = "(void)t; (void)s; (void)a;"
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(
F_dispatch=per_dtypes
F_dispatch=indent(per_arch)
)
@@ -605,6 +622,7 @@ class FmhaFwdSplitKVCombineTileSize:
@dataclass
class FmhaFwdSplitKVKernel:
F_arch: ArchTrait
F_idx: int # this is not a tunable, but a counter to differentiate symbol
F_hdim: int # hdim
F_dtype: str # data type
@@ -615,8 +633,10 @@ class FmhaFwdSplitKVKernel:
@property
def template(self) -> str:
assert self.F_pipeline.F_lse == "t"
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_KERNEL_BODY.format(
F_idx=self.F_idx,
F_arch=self.F_arch,
F_hdim=self.F_hdim,
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
F_bm0=self.F_tile.F_bm0,
@@ -666,36 +686,12 @@ class FmhaFwdSplitKVKernel:
@property
def filename(self) -> str:
return self.name + ".cpp"
def api_trait(self) -> FmhaFwdSplitKVApiTrait:
return FmhaFwdSplitKVApiTrait(
pipeline_tag=self.F_pipeline.tag,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
mode=self.F_mode,
bm0=self.F_tile.F_bm0,
bn0=self.F_tile.F_bn0,
bk0=self.F_tile.F_bk0,
bn1=self.F_tile.F_bn1,
bk1=self.F_tile.F_bk1,
bk0max=self.F_tile.F_bk0max,
vlayout=self.F_pipeline.F_vlayout,
logits=self.F_pipeline.F_logits,
mask=self.F_pipeline.F_mask,
bias=self.F_pipeline.F_bias,
lse=self.F_pipeline.F_lse,
squant=self.F_pipeline.F_squant,
pagedkv=self.F_pipeline.F_pagedkv,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
)
return f"{self.name}{self.F_arch.filename_suffix}.cpp"
@dataclass
class FmhaFwdSplitKVCombineKernel:
F_arch: ArchTrait
F_idx: int # this is not a tunable, but a counter to differentiate symbol
F_hdim: int # hdim
F_dtype: str # data type
@@ -707,6 +703,7 @@ class FmhaFwdSplitKVCombineKernel:
def template(self) -> str:
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format(
F_idx=self.F_idx,
F_arch=self.F_arch,
F_hdim=self.F_hdim,
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
F_bn1=self.F_tile.F_bn1,
@@ -730,85 +727,33 @@ class FmhaFwdSplitKVCombineKernel:
@property
def filename(self) -> str:
return self.name + ".cpp"
return f"{self.name}{self.F_arch.filename_suffix}.cpp"
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_tile_dict_from_dtype(dtype: str) -> Optional[dict]:
if dtype == "fp16" or dtype == "bf16":
return {
"32" : FmhaFwdTileSize( 32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"96" : FmhaFwdTileSize( 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"128": FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
# "160" : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"256": FmhaFwdTileSize( 64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
} # fmt: skip
elif dtype == "fp8" or dtype == "bf8":
return {
"64" : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
"128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
} # fmt: skip
else:
return None
def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype: str) -> Optional[dict]:
if dtype == "fp16" or dtype == "bf16":
return {
"32": FmhaFwdSplitKVCombineTileSize(32, -1),
"64": FmhaFwdSplitKVCombineTileSize(32, -1),
"96": FmhaFwdSplitKVCombineTileSize(32, -1),
"128": FmhaFwdSplitKVCombineTileSize(32, -1),
# '160' : FmhaFwdSplitKVCombineTileSize(32, -1),
"256": FmhaFwdSplitKVCombineTileSize(32, -1),
}
elif dtype == "fp8" or dtype == "bf8":
return {
"64": FmhaFwdSplitKVCombineTileSize(32, -1),
"128": FmhaFwdSplitKVCombineTileSize(32, -1),
"256": FmhaFwdSplitKVCombineTileSize(32, -1),
}
else:
return None
def get_fwd_splitkv_blobs(
kernel_filter: Optional[str], receipt, mask_impl, optdim_list
) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]:
Pipeline = FmhaFwdSplitKVPipeline
Kernel = FmhaFwdSplitKVKernel
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVPipeline]:
class KernelComponentFactoryBase:
@staticmethod
def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdSplitKVPipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let 't' padding to appear later!!
# TODO: currently for qr pipeline, let "t" padding to appear later!!
# TODO: how to design this more generic?
Pipeline = FmhaFwdSplitKVPipeline
squant = "t" if dtype == "fp8" else "f"
pipelines = []
if dtype in ["fp16", "bf16"]:
for logits, mask, bias, pagedkv in itertools.product(
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]
):
pipelines.append(Pipeline( "qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline( "qr", "col", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline( "qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline( "qr", "col", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline( "qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline( "qr", "col", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline( "qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline( "qr", "col", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
elif dtype in ["fp8", "bf8"]:
for logits, mask, bias in itertools.product(
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
):
pipelines.append(Pipeline( "qr", "col", "f", "f", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
elif dtype in ["fp8fp16", "fp8bf16"]:
# TODO
None
@@ -816,18 +761,122 @@ def get_fwd_splitkv_blobs(
assert False
return pipelines
gen = list()
api_pool = FmhaFwdSplitKVApiPool(mask_impl)
@staticmethod
def get_combine_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVCombinePipeline]:
Pipeline = FmhaFwdSplitKVCombinePipeline
squant = "t" if dtype == "fp8" else "f"
pipelines = []
if dtype in ["fp16", "bf16"]:
for spad, dvpad, lse in itertools.product(
["t", "f"], ["t", "f"], ["t", "f"]
):
pipelines.append(Pipeline("unused", spad, dvpad, lse, squant))
elif dtype in ["fp8", "bf8"]:
# no need lse kernels
for spad, dvpad in itertools.product(["t", "f"], ["t", "f"]):
pipelines.append(Pipeline("unused", spad, dvpad, "f", squant))
else:
assert False
return pipelines
for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
@staticmethod
def get_combine_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
# Possible values of F_bn1: 8, 16, 32
if dtype in ["fp16", "bf16"]:
return {
"32": FmhaFwdSplitKVCombineTileSize(32, -1),
"64": FmhaFwdSplitKVCombineTileSize(32, -1),
"96": FmhaFwdSplitKVCombineTileSize(32, -1),
"128": FmhaFwdSplitKVCombineTileSize(32, -1),
# "160" : FmhaFwdSplitKVCombineTileSize(32, -1),
"256": FmhaFwdSplitKVCombineTileSize(32, -1),
}
elif dtype in ["fp8", "bf8"]:
return {
"64": FmhaFwdSplitKVCombineTileSize(32, -1),
"128": FmhaFwdSplitKVCombineTileSize(32, -1),
"256": FmhaFwdSplitKVCombineTileSize(32, -1),
}
else:
return None
class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
arch = ArchTrait("gfx9")
@staticmethod
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
if dtype in ["fp16", "bf16"]:
return {
"32" : FmhaFwdTileSize( 32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"96" : FmhaFwdTileSize( 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"128": FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
# "160" : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"256": FmhaFwdTileSize( 64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
} # fmt: skip
elif dtype in ["fp8", "bf8"]:
return {
"64" : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
"128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
} # fmt: skip
else:
return None
class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
arch = ArchTrait("gfx12")
@staticmethod
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
if dtype in ["fp16", "bf16"]:
return {
# bm0, bn0, bk0, bn1, bk1,
"32" : FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
} # fmt: skip
elif dtype in ["fp8", "bf8"]:
return {
# bm0, bn0, bk0, bn1, bk1,
"64" : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
} # fmt: skip
else:
return None
def get_factory(target: str):
# Place more specific architectures first
if target.startswith("gfx9"):
return KernelComponentFactoryGfx9
if target.startswith("gfx12"):
return KernelComponentFactoryGfx12
raise Exception(f"Unsupported device target {target}")
def get_fwd_splitkv_blobs(
targets: List[str], kernel_filter: Optional[str], receipt, mask_impl, optdim_list
) -> List[FmhaFwdSplitKVKernel]:
Kernel = FmhaFwdSplitKVKernel
gen = list()
factories = get_factories_for_targets(targets, get_factory)
for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()):
d = factory.get_hdim_tile_size_dict(dtype)
if d is None:
continue
# for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
tile = d[hdim_str]
hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim):
for pipeline in factory.get_pipelines(dtype, hdim, mask_impl):
if mode == "group":
if pipeline.F_spad != "t" or pipeline.F_skpad != "t":
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
@@ -839,6 +888,7 @@ def get_fwd_splitkv_blobs(
):
continue
k = Kernel(
F_arch=factory.arch,
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
@@ -892,55 +942,34 @@ def get_fwd_splitkv_blobs(
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
return (api_pool, gen)
return gen
def get_fwd_splitkv_combine_blobs(
kernel_filter: Optional[str], receipt, optdim_list
targets: List[str], kernel_filter: Optional[str], receipt, optdim_list
) -> List[FmhaFwdSplitKVCombineKernel]:
Pipeline = FmhaFwdSplitKVCombinePipeline
Kernel = FmhaFwdSplitKVCombineKernel
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVCombinePipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let 't' padding to appear later!!
# TODO: how to design this more generic?
squant = "t" if dtype == "fp8" else "f"
pipelines = []
if dtype in ["fp16", "bf16"]:
for spad, dvpad, lse in itertools.product(
["t", "f"], ["t", "f"], ["t", "f"]
):
pipelines.append(Pipeline("unused", spad, dvpad, lse, squant))
elif dtype in ["fp8", "bf8"]:
# no need lse kernels
pipelines.append(Pipeline("unused", "f", "f", "f", squant))
else:
assert False
return pipelines
gen = list()
for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype)
factories = get_factories_for_targets(targets, get_factory)
for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()):
d = factory.get_combine_hdim_tile_size_dict(dtype)
if d is None:
continue
# for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
tile = d[hdim_str]
hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim):
for pipeline in factory.get_combine_pipelines(dtype, hdim):
if mode == "group":
if pipeline.F_spad != "t":
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
k = Kernel(
F_arch=factory.arch,
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
@@ -980,43 +1009,102 @@ def get_fwd_splitkv_combine_blobs(
def write_single_kernel(
kernel: Union[FmhaFwdSplitKVKernel, FmhaFwdSplitKVCombineKernel], autogen_dir: Path
) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
update_file(autogen_dir / kernel.filename, kernel.template)
def write_fwd_splitkv_api(api_pool: FmhaFwdSplitKVApiPool, autogen_dir: Path) -> None:
file_path = autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME
file_path.write_text(api_pool.api)
update_file(autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME, api_pool.api)
def write_blobs(
output_dir: Path, filter_list: str, receipt, optdim_list, mask_impl
targets: List[str],
output_dir: Path,
filter_list: str,
receipt,
optdim_list,
mask_impl,
) -> None:
filter_list = filter_list.split("@")
filter_list.extend([""] * (2 - len(filter_list)))
kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list)
for kernel in kernels:
combine_kernels = get_fwd_splitkv_combine_blobs(
targets, filter_list[0], receipt, optdim_list
)
for kernel in combine_kernels:
write_single_kernel(kernel, output_dir)
api_pool, kernels = get_fwd_splitkv_blobs(
filter_list[1], receipt, mask_impl, optdim_list
kernels = get_fwd_splitkv_blobs(
targets, filter_list[1], receipt, mask_impl, optdim_list
)
for kernel in kernels:
write_single_kernel(kernel, output_dir)
api_pool = FmhaFwdSplitKVApiPool(mask_impl)
for kernel in kernels:
combine_ks = [
k
for k in combine_kernels
if k.F_arch == kernel.F_arch
and k.F_hdim == kernel.F_hdim
and k.F_dtype == kernel.F_dtype
and k.F_mode == kernel.F_mode
and k.F_pipeline.F_spad == kernel.F_pipeline.F_spad
and k.F_pipeline.F_dvpad == kernel.F_pipeline.F_dvpad
and k.F_pipeline.F_lse == "f"
and k.F_pipeline.F_squant == kernel.F_pipeline.F_squant
]
assert len(combine_ks) == 1, (
f"{len(combine_ks)} matching FmhaFwdSplitKVCombineKernel for {kernel}"
)
combine_kernel = combine_ks[0]
api_pool.register_traits(
FmhaFwdSplitKVApiTrait(
arch=kernel.F_arch,
pipeline_tag=kernel.F_pipeline.tag,
hdim=kernel.F_hdim,
dtype=kernel.F_dtype,
mode=kernel.F_mode,
bm0=kernel.F_tile.F_bm0,
bn0=kernel.F_tile.F_bn0,
bk0=kernel.F_tile.F_bk0,
bn1=kernel.F_tile.F_bn1,
bk1=kernel.F_tile.F_bk1,
bk0max=kernel.F_tile.F_bk0max,
vlayout=kernel.F_pipeline.F_vlayout,
logits=kernel.F_pipeline.F_logits,
mask=kernel.F_pipeline.F_mask,
bias=kernel.F_pipeline.F_bias,
lse=kernel.F_pipeline.F_lse,
squant=kernel.F_pipeline.F_squant,
pagedkv=kernel.F_pipeline.F_pagedkv,
spad=kernel.F_pipeline.F_spad,
skpad=kernel.F_pipeline.F_skpad,
dpad=kernel.F_pipeline.F_dpad,
dvpad=kernel.F_pipeline.F_dvpad,
bn1comb=combine_kernel.F_tile.F_bn1,
)
)
write_fwd_splitkv_api(api_pool, output_dir)
def list_blobs(
file_path: Path, filter_list: str, receipt, optdim_list, mask_impl
targets: List[str],
file_path: Path,
filter_list: str,
receipt,
optdim_list,
mask_impl,
) -> None:
filter_list = filter_list.split("@")
filter_list.extend([""] * (2 - len(filter_list)))
with file_path.open("a") as f:
kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list)
kernels = get_fwd_splitkv_combine_blobs(
targets, filter_list[0], receipt, optdim_list
)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
_, kernels = get_fwd_splitkv_blobs(
filter_list[1], receipt, mask_impl, optdim_list
kernels = get_fwd_splitkv_blobs(
targets, filter_list[1], receipt, mask_impl, optdim_list
)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")

View File

@@ -3,12 +3,14 @@
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass
import fnmatch
import itertools
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
from codegen.arch import ArchTrait, get_factories_for_targets
from codegen.cmake_config import GEN_DIR
from codegen.cpp_symbol_map import (
LAYOUT_MAP,
@@ -21,24 +23,27 @@ from codegen.cpp_symbol_map import (
BOOL_MAP,
PIPELINE_ENUM_MAP,
)
from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file
from codegen.ops.fmha_fwd import (
DTYPE_BITS,
K0_MAX_SUBMAX_MAP,
FMHA_FWD_KERNEL_HEADER,
FMHA_FWD_API_PER_ARCH,
FMHA_FWD_API_PER_DTYPE,
FMHA_FWD_API_PER_HDIM_CASE,
)
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8}
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
FMHA_FWD_PAGEDKV_PIPELINE_MAP = {
"qr_pagedkv": "ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS"
}
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "fmha_fwd.hpp"
"""
FMHA_FWD_KERNEL_BODY = """
#include <iostream>
#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
@@ -98,10 +103,8 @@ using fmha_kernel_{F_idx} =
using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
#include <iostream>
template<>
float fmha_fwd_pagedkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a)
float fmha_fwd_pagedkv_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
@@ -109,38 +112,35 @@ float fmha_fwd_pagedkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd
auto [kargs, grids] = fmha_fwd_pagedkv_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs));
}}
#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
"""
FMHA_FWD_API_FILENAME = "fmha_fwd_pagedkv_api.cpp"
FMHA_FWD_API = """
float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, const ck_tile::stream_config& s){{
float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, const ck_tile::stream_config& s) {{
float r = -1;
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
{F_dispatch}
return r;
}}
"""
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
}}
"""
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
return fmha_fwd_pagedkv_<trait_>(s, a);
}}
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
return fmha_fwd_pagedkv_<trait_, {F_arch.tag}>(s, a);
}}
"""
@dataclass
class FmhaFwdApiTrait:
arch: ArchTrait
pipeline_tag: str
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim: str
@@ -327,71 +327,79 @@ class FmhaFwdPipeline:
class FmhaFwdApiPool:
def __init__(self, mask_impl):
self.pool = dict()
self.pool = OrderedDict()
self.mask_impl = mask_impl
def register_traits(self, trait: FmhaFwdApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.pool.keys():
self.pool[trait.dtype] = dict()
if trait.hdim not in self.pool[trait.dtype].keys():
self.pool[trait.dtype][trait.hdim] = list()
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
hdim = trait.hdim
ts = (
self.pool.setdefault(trait.arch, OrderedDict())
.setdefault(trait.dtype, OrderedDict())
.setdefault(hdim, [])
)
check_duplicates_and_paddings(ts, trait)
ts.append(copy.copy(trait))
@property
def api(self) -> str:
per_dtypes = str()
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case = str()
for j, hdim in enumerate(self.pool[dtype].keys()):
traits = self.pool[dtype][hdim]
inners = str()
for k, trait in enumerate(traits):
if_k = "if" if k == 0 else "else if"
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
F_if=if_k,
F_mode=MODE_MAP[trait.mode],
F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
F_logits=BOOL_MAP[trait.logits],
F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
F_bias_check=BIAS_CHECK_MAP[trait.bias],
F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse],
F_pagedkv=BOOL_MAP[trait.pagedkv],
F_skip=BOOL_MAP[trait.skip],
F_squant=BOOL_MAP[trait.squant],
F_scheck=trait.scheck,
F_skcheck=trait.skcheck,
F_dcheck=trait.dcheck,
F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad],
F_skpad=BOOL_MAP[trait.skpad],
F_dpad=BOOL_MAP[trait.dpad],
F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0,
F_bn0=trait.bn0,
F_bk0=trait.bk0,
F_bn1=trait.bn1,
F_bk1=trait.bk1,
F_bk0max=trait.bk0max,
per_arch = str()
for i_arch, (arch, pool_by_arch) in enumerate(self.pool.items()):
per_dtypes = str()
for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()):
per_hdim_case = str()
for i_hdim, (hdim, pool_by_hdim) in enumerate(pool_by_dtype.items()):
inners = str()
for i_trait, trait in enumerate(pool_by_hdim):
inners += FMHA_FWD_API_INNER_DISPATCH.format(
F_if=if_(i_trait),
F_arch=arch,
F_mode=MODE_MAP[trait.mode],
F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
F_logits=BOOL_MAP[trait.logits],
F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
F_bias_check=BIAS_CHECK_MAP[trait.bias],
F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse],
F_pagedkv=BOOL_MAP[trait.pagedkv],
F_skip=BOOL_MAP[trait.skip],
F_squant=BOOL_MAP[trait.squant],
F_scheck=trait.scheck,
F_skcheck=trait.skcheck,
F_dcheck=trait.dcheck,
F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad],
F_skpad=BOOL_MAP[trait.skpad],
F_dpad=BOOL_MAP[trait.dpad],
F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0,
F_bn0=trait.bn0,
F_bk0=trait.bk0,
F_bn1=trait.bn1,
F_bk1=trait.bk1,
F_bk0max=trait.bk0max,
F_hdim=hdim,
F_dtype=FWD_DTYPE_MAP[dtype],
)
per_hdim_case += FMHA_FWD_API_PER_HDIM_CASE.format(
F_if=if_(i_hdim),
F_hdim=hdim,
F_dtype=FWD_DTYPE_MAP[dtype],
F_hdim_v=trait.bn1,
F_inner_dispatch=indent(inners),
)
if_j = "if" if j == 0 else "else if"
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners
per_dtypes += FMHA_FWD_API_PER_DTYPE.format(
F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case)
)
if_i = "if" if i == 0 else "else if"
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
per_arch += FMHA_FWD_API_PER_ARCH.format(
F_if=if_(i_arch),
F_arch=arch,
F_dtype_case=indent(per_dtypes),
)
if not per_dtypes:
if not per_arch:
# empty string we add some ignore to suppress warning in api
per_dtypes += " (void)t ; (void)s ; (void)a;"
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_dtypes)
per_arch = "(void)t; (void)s; (void)a;"
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=indent(per_arch))
@dataclass
@@ -428,6 +436,7 @@ class FmhaFwdTileSize:
@dataclass
class FmhaFwdKernel:
F_arch: ArchTrait
F_idx: int # this is not a tunable, but a counter to differentiate symbol
F_hdim: int # hdim
F_dtype: str # data type
@@ -440,6 +449,7 @@ class FmhaFwdKernel:
def template(self) -> str:
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
F_idx=self.F_idx,
F_arch=self.F_arch,
F_hdim=self.F_hdim,
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
F_bm0=self.F_tile.F_bm0,
@@ -490,10 +500,11 @@ class FmhaFwdKernel:
@property
def filename(self) -> str:
return self.name + ".cpp"
return f"{self.name}{self.F_arch.filename_suffix}.cpp"
def api_trait(self) -> FmhaFwdApiTrait:
return FmhaFwdApiTrait(
arch=self.F_arch,
pipeline_tag=self.F_pipeline.tag,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
@@ -519,37 +530,12 @@ class FmhaFwdKernel:
)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_tile_dict_from_dtype(dtype: str) -> Optional[dict]:
if dtype == "fp16" or dtype == "bf16":
return {
# "32": FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# "64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# "96": FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
"128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# "192": FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# "256": FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
} # fmt: skip
elif dtype == "fp8" or dtype == "bf8":
return {
"64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
"128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
"256": FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
} # fmt: skip
else:
return None
def get_fwd_blobs(
kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
class KernelComponentFactoryBase:
@staticmethod
def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdPipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr_pagedkv pipeline, let 't' padding to appear later!!
# TODO: currently for qr_pagedkv pipeline, let "t" padding to appear later!!
# TODO: how to design this more generic?
squant = "t" if dtype == "fp8" else "f"
pipelines = []
@@ -576,19 +562,85 @@ def get_fwd_blobs(
assert False
return pipelines
class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
arch = ArchTrait("gfx9")
@staticmethod
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
if dtype in ["fp16", "bf16"]:
return {
# "32": FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# "64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# "96": FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
"128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# "192": FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# "256": FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
} # fmt: skip
elif dtype in ["fp8", "bf8"]:
return {
"64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
"128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
"256": FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
} # fmt: skip
else:
return None
class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
arch = ArchTrait("gfx12")
@staticmethod
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
if dtype in ["fp16", "bf16"]:
return {
# bm0, bn0, bk0, bn1, bk1,
# "32": FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
# "64": FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
# "192": FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
# "256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
} # fmt: skip
elif dtype in ["fp8", "bf8"]:
return {
# bm0, bn0, bk0, bn1, bk1,
"64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"256": FmhaFwdTileSize( 64, 32, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
} # fmt: skip
else:
return None
def get_factory(target: str):
# Place more specific architectures first
if target.startswith("gfx9"):
return KernelComponentFactoryGfx9
if target.startswith("gfx12"):
return KernelComponentFactoryGfx12
raise Exception(f"Unsupported device target {target}")
def get_fwd_blobs(
targets: List[str], kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
gen = list()
api_pool = FmhaFwdApiPool(mask_impl)
for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
factories = get_factories_for_targets(targets, get_factory)
for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()):
d = factory.get_hdim_tile_size_dict(dtype)
if d is None:
continue
# for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
tile = d[hdim_str]
hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim):
# if pipeline.F_pagedkv == 'f':
for pipeline in factory.get_pipelines(dtype, hdim, mask_impl):
# if pipeline.F_pagedkv == "f":
# continue
if mode == "group":
if pipeline.F_spad != "t" or pipeline.F_skpad != "t":
@@ -605,6 +657,7 @@ def get_fwd_blobs(
):
continue
k = FmhaFwdKernel(
F_arch=factory.arch,
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
@@ -674,27 +727,41 @@ def get_fwd_blobs(
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
update_file(autogen_dir / kernel.filename, kernel.template)
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api)
def write_blobs(
output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
targets: List[str],
output_dir: Path,
kernel_filter: str,
receipt,
optdim_list,
mask_impl,
) -> None:
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
api_pool, kernels = get_fwd_blobs(
targets, kernel_filter, receipt, optdim_list, mask_impl
)
for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir)
write_fwd_api(api_pool, output_dir)
def list_blobs(
file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
targets: List[str],
file_path: Path,
kernel_filter: str,
receipt,
optdim_list,
mask_impl,
) -> None:
with file_path.open("a") as f:
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
_, kernels = get_fwd_blobs(
targets, kernel_filter, receipt, optdim_list, mask_impl
)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")

View File

@@ -2,7 +2,9 @@
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import dataclasses
import os.path as path
import textwrap
def update_file(file_path, content):
@@ -19,3 +21,51 @@ def update_file(file_path, content):
return
with open(file_path, "w") as file:
file.write(content)
def indent(code: str, indent: str = " ") -> str:
return textwrap.indent(code, indent)
def if_(i: int) -> str:
return "if" if i == 0 else "else if"
def check_duplicates_and_paddings(traits, trait):
"""Check
* if the traits list does not contain a trait with the same parameters;
* if paddings are consitent: the previous kernel can be incorrectly called before the new one,
for example, f, _t_, f, t cannot be before f, _f_, f, t.
"""
fields = [f.name for f in dataclasses.fields(trait)]
pad_fields = [f for f in fields if "pad" in f]
non_pad_fields = [f for f in fields if "pad" not in f]
for prev_trait in traits:
if any(getattr(trait, f) != getattr(prev_trait, f) for f in non_pad_fields):
continue
if all(getattr(trait, f) == getattr(prev_trait, f) for f in pad_fields):
raise Exception(f"Duplicate found {trait}")
# Check if the previous kernel can be incorrectly used before the current one
# for example, f, _t_, f, t cannot be before f, _f_, f, t
is_prev_more_restrictive = False
is_curr_more_restrictive = False
for f in pad_fields:
prev_pad = getattr(prev_trait, f)
pad = getattr(trait, f)
if isinstance(prev_pad, str):
prev_pad = 1000000 if prev_pad == "f" else 1
pad = 1000000 if pad == "f" else 1
elif isinstance(prev_pad, int):
prev_pad = 1000000 if prev_pad == 0 else prev_pad
pad = 1000000 if pad == 0 else pad
else:
assert False
if prev_pad < pad:
is_prev_more_restrictive = True
elif prev_pad > pad:
is_curr_more_restrictive = True
if is_prev_more_restrictive and not is_curr_more_restrictive:
raise Exception(
f"Kernel will never be used because paddings are not ordered correctly:\n{prev_trait} supersedes\n{trait}"
)

View File

@@ -453,15 +453,15 @@ struct fmha_bwd_dq_dk_dv_traits_
{
};
template <typename Traits_>
template <typename Traits_, typename Arch = void>
float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
template <typename Traits_, typename Arch = void>
void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
template <typename Traits_, typename Arch = void>
std::string fmha_bwd_dq_dk_dv_get_name_();
template <typename Traits_>
template <typename Traits_, typename Arch = void>
int fmha_bwd_dq_dk_dv_maxq_();
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
@@ -474,13 +474,13 @@ struct fmha_bwd_dot_do_o_traits_
static constexpr bool kPadDv = kPadDv_;
};
template <typename Traits_>
template <typename Traits_, typename Arch = void>
float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
template <typename Traits_, typename Arch = void>
void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
template <typename Traits_, typename Arch = void>
std::string fmha_bwd_dot_do_o_get_name_();
template <ck_tile::index_t HDim_,
@@ -494,13 +494,13 @@ struct fmha_bwd_convert_dq_traits_
{
};
template <typename Traits_>
template <typename Traits_, typename Arch = void>
float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
template <typename Traits_, typename Arch = void>
void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
template <typename Traits_, typename Arch = void>
std::string fmha_bwd_convert_dq_get_name_();
// This is the public API, will be generated by script

View File

@@ -1159,7 +1159,7 @@ struct fmha_fwd_traits_
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
};
template <typename Traits_>
template <typename Traits_, typename Arch = void>
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
template <ck_tile::index_t HDim_,
@@ -1210,7 +1210,7 @@ struct fmha_fwd_pagedkv_traits_
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
};
template <typename Traits_>
template <typename Traits_, typename Arch = void>
float fmha_fwd_pagedkv_(const ck_tile::stream_config&, fmha_fwd_pagedkv_args);
template <ck_tile::index_t HDim_,
@@ -1259,10 +1259,10 @@ struct fmha_fwd_splitkv_traits_
static constexpr bool kIsPagedKV = kIsPagedKV_;
};
template <typename Traits_>
template <typename Traits_, typename Arch = void>
void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);
template <typename Traits_>
template <typename Traits_, typename Arch = void>
std::string fmha_fwd_splitkv_get_name_();
template <ck_tile::index_t HDim_,
@@ -1285,10 +1285,10 @@ struct fmha_fwd_splitkv_combine_traits_
static constexpr bool kPadDv = kPadDv_;
};
template <typename Traits_>
template <typename Traits_, typename Arch = void>
void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);
template <typename Traits_>
template <typename Traits_, typename Arch = void>
std::string fmha_fwd_splitkv_combine_get_name_();
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
@@ -1322,10 +1322,10 @@ struct fmha_fwd_appendkv_traits_
static constexpr bool kIsPagedKV = kIsPagedKV_;
};
template <typename Traits_>
template <typename Traits_, typename Arch = void>
float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args);
template <typename Traits_>
template <typename Traits_, typename Arch = void>
float fmha_batch_prefill_(const ck_tile::stream_config&, fmha_batch_prefill_args);
// This is the public API, will be generated by script

View File

@@ -1200,7 +1200,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
}
};
auto run_appendkv = [&](const ck_tile::stream_config& sc) {
auto run_appendkv = [&]([[maybe_unused]] const ck_tile::stream_config& sc) {
#if CK_TILE_FMHA_FWD_APPENDKV_API
if(need_append_kvcache)
{

View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import argparse
@@ -38,6 +39,7 @@ assert 0 < len(handlers)
def write_blobs(
targets: List[str],
output_dir: Optional[str],
api_list: List[str],
filters_list: List[str],
@@ -54,11 +56,12 @@ def write_blobs(
for api, kernel_filter in zip(api_list, filters_list):
handler = handlers[api][HandlerId.WRITE_BLOBS]
handler(output_dir, kernel_filter, receipt, optdim_list, mask_impl)
handler(targets, output_dir, kernel_filter, receipt, optdim_list, mask_impl)
# list all the files that will be generated
def list_blobs(
targets: List[str],
output_file: Optional[str],
api_list: List[str],
filters_list: List[str],
@@ -74,7 +77,7 @@ def list_blobs(
for api, kernel_filter in zip(api_list, filters_list):
handler = handlers[api][HandlerId.LIST_BLOBS]
handler(file_path, kernel_filter, receipt, optdim_list, mask_impl)
handler(targets, file_path, kernel_filter, receipt, optdim_list, mask_impl)
if __name__ == "__main__":
@@ -82,6 +85,12 @@ if __name__ == "__main__":
prog="generate",
description="gen API for CK fmha kernel",
)
parser.add_argument(
"--targets",
default="gfx9,gfx950",
required=False,
help="list of GPU targets, separated by comma.",
)
parser.add_argument(
"-d",
"--direction", # we keep 'direction' option for backward compatibility
@@ -142,6 +151,7 @@ if __name__ == "__main__":
)
args = parser.parse_args()
targets = args.targets.split(",")
api_list = args.direction.split(",")
filter_list = args.filter.split(",")
filter_list.extend([""] * (len(api_list) - len(filter_list)))
@@ -149,6 +159,7 @@ if __name__ == "__main__":
if args.list_blobs is not None:
list_blobs(
targets,
args.list_blobs,
api_list,
filter_list,
@@ -158,6 +169,7 @@ if __name__ == "__main__":
)
else:
write_blobs(
targets,
args.output_dir,
api_list,
filter_list,

View File

@@ -94,7 +94,7 @@ run_fp8_tests() {
for b in 1 2 ; do
for hdim in 64 128 256 ; do
$EXE -prec=fp8 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
$EXE -prec=fp8 -init=0 -b=$b -h=1 -d=$hdim -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
done ; done ; done ; done
}
@@ -105,7 +105,7 @@ run_fp8bf16_tests() {
for b in 1 2 ; do
for hdim in 64 128 256 ; do
$EXE -prec=fp8bf16 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
$EXE -prec=fp8bf16 -init=0 -b=$b -h=1 -d=$hdim -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
done ; done ; done ; done
}
@@ -114,9 +114,9 @@ run_fp8fp32_tests() {
for perm in 0 1 ; do
for bias in "n" "e" "a" ; do
for b in 1 2 ; do
for hdim in 64 128 256 ; do
for hdim in 128 ; do
$EXE -prec=fp8fp32 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
$EXE -prec=fp8fp32 -init=0 -b=$b -h=1 -d=$hdim -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
done ; done ; done ; done
}