From d96f632fa18f61f6b6aba33ffb92f36434a4ad32 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Fri, 5 Dec 2025 10:31:12 +0800 Subject: [PATCH] [CK_TILE][FMHA] Integrate FAv2 & FAv3 (WIP) in the single fmha_fwd() API (#3153) * Let fmha_fwd_v3() compatible with fmha_fwd() * Decouple get_fwd_blobs() and FmhaFwdKernel * Decouple compatibility checks from get_fwd_blobs() * Extract product feature checks out from get_fwd_blobs() * Remove duplicated code in factories and redundant checks * Remove FmhaFwdKernel<>::GetName() * Let FmhaFwdApiPool support pipelines with different mask_impl * Add tile setting for fmha fwd v3 pipeline * Add fwd v3 instances to tile_example_fmha_fwd manually * Remove unused function import * Undo irrelevant changes * Remove fwd v3 instances from tile_example_fmha_fwd * Finish fmha fwd v3 kernel instance codegen * Fix formatting * Remove unused F_idx attribute * Add is_generic_attention_mask<> traits * Add constraints to the fmha fwd v3 pipeline * Unify traits & problem used for fmha fwd v3 * Unify kernel launch code for fmha fwd v2 & v3 * Unify kernel template selection logic * Use same kernel codegen template for both v2 & v3 * Rename api() property as render() method * Allow specifying filter for fmha fwd api pool * Allow specifying function name when rendering api pool items * Separate fmha fwd v3 kernel dispatching logic from v2 * Remove lambda assignment * Add simple v2/v3 dispatch logic * Stop generating empty if-clauses Skip iterating over dictionaries that have no traits, and avoid assigning i_* to them. 
* Use "".join() to concatenate fmha fwd api string content * Add more feature checks for fmha fwd v3 pipeline * Check features before dispatch to fmha_fwd_v3() * Add more feature checks for fmha_fwd_v3() * Add missing filter call * Use Tuple to reserve the dtype orders * Fix wrong pipeline matching logic * Add fmha fwd v3 group mode instances * Add functor_transform<> * Add type constraints to make_tile_window() * Remove fmha fwd v3 example * Fix wrong product(aiter mha_fwd()) config * Fix wrong fmha fwd v2/v3 selection logic * Fix formatting * Add comment to warning v3 kernel users * Fix wrong codegen logics * Remove unnecessary param * Fix format --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> [ROCm/composable_kernel commit: 05292b3604e143e98ec2cb67edb2e3d2ad1d6ecb] --- example/ck_tile/01_fmha/CMakeLists.txt | 34 - .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 20 +- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 827 ++++++++++++------ .../ck_tile/01_fmha/example_fmha_fwd_v3.cpp | 616 ------------- example/ck_tile/01_fmha/fmha_fwd.hpp | 94 ++ example/ck_tile/01_fmha/fmha_fwd_v3.cpp | 60 -- example/ck_tile/01_fmha/fmha_fwd_v3.hpp | 73 -- example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp | 179 ---- .../instances/fmha_fwd_v3_d128_bf16_mask.cpp | 14 - .../instances/fmha_fwd_v3_d128_bf16_nmask.cpp | 14 - .../instances/fmha_fwd_v3_d128_fp16_mask.cpp | 14 - .../instances/fmha_fwd_v3_d128_fp16_nmask.cpp | 14 - .../core/algorithm/coordinate_transform.hpp | 82 ++ include/ck_tile/core/tensor/tile_window.hpp | 9 +- .../ck_tile/ops/fmha/block/block_masking.hpp | 13 + .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 48 - .../ops/fmha/kernel/fmha_fwd_v3_kernel.hpp | 134 +-- .../pipeline/block_fmha_fwd_v3_pipeline.hpp | 32 +- .../pipeline/block_fmha_pipeline_enum.hpp | 1 + .../pipeline/block_fmha_pipeline_problem.hpp | 43 - .../ops/fmha/pipeline/tile_fmha_traits.hpp | 16 - include/ck_tile/remod.py | 2 +- 22 files changed, 890 insertions(+), 1449 
deletions(-) delete mode 100644 example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp delete mode 100644 example/ck_tile/01_fmha/fmha_fwd_v3.cpp delete mode 100644 example/ck_tile/01_fmha/fmha_fwd_v3.hpp delete mode 100644 example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp delete mode 100644 example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp delete mode 100644 example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp delete mode 100644 example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp delete mode 100644 example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 0c8102a70b..6e7d69281d 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -208,40 +208,6 @@ add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp) target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES}) target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -# add fmha_fwd_v3 example -set(EXAMPLE_FMHA_FWD_V3 "tile_example_fmha_fwd_v3") -message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}") - -add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_fmha_fwd_v3.cpp) -target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS - "${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp" -) -target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE - fmha_fwd_v3.cpp - ${FMHA_FWD_V3_INSTANCES} -) - -set(EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS) -list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS - -fgpu-flush-denormals-to-zero - -Wno-undefined-func-template - --save-temps -) -set(EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS) - -check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32) -if(HAS_DISABLE_PACKED_FP32) - list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS - -mllvm --amdgpu-disable-packed-fp32=1 - ) - 
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS - -DCK_TILE_DISABLE_PACKED_FP32=1 - ) -endif() - -target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS}) -target_compile_definitions(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS}) # TODO: we have to turn off this global prop, otherwise the progress bar generated # by cmake will print too many files, execvp: /bin/sh: Argument list too long # however, this property may affect global diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 333579ec8d..a3cfe2622a 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -30,16 +30,24 @@ _MASK_MAP = { } -def get_mask_map(mask: str): - if mask == "generic": +def get_mask_map(mask_impl: str): + if mask_impl == "generic": return _MASK_MAP - elif mask == "simplified": + elif mask_impl == "simplified": return _MASK_SIMPLIFIED_MAP else: assert False return None +def get_mask_impl(mask: str) -> str: + return "simplified" if mask.startswith("s_") else "generic" + + +def get_mask_cpp_type(mask: str) -> str: + return get_mask_map(get_mask_impl(mask))[mask] + + _MASK_CHECK_MAP = { "no": "t.mask_type == mask_enum::no_mask", "causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right", @@ -62,6 +70,10 @@ def get_mask_check_map(mask: str): return None +def get_mask_cpp_check_expr(mask: str) -> str: + return get_mask_check_map(get_mask_impl(mask))[mask] + + QSCALE_MAP = { "no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE", "pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR", @@ -122,6 +134,7 @@ PIPELINE_MAP = { "qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync", "qs": "ck_tile::BlockFmhaPipelineQSKSVS", "qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload", + "qr_async_trload_v3": "ck_tile::BlockFmhaFwdV3Pipeline", } 
PIPELINE_ENUM_MAP = { @@ -131,6 +144,7 @@ PIPELINE_ENUM_MAP = { "qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS", "qr_pagedkv": "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qr_async_trload": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD", + "qr_async_trload_v3": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3", } BOOL_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 17d4f6e1d7..c00bdcea3b 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -8,14 +8,13 @@ import os from collections import OrderedDict from dataclasses import dataclass, field from pathlib import Path -from typing import List, Optional, Tuple +from typing import Callable, ClassVar, Iterable, List, Optional, Tuple from codegen.arch import ArchTrait, get_factories_for_targets from codegen.cmake_config import GEN_DIR from codegen.cpp_symbol_map import ( LAYOUT_MAP, BIAS_CHECK_MAP, - get_mask_check_map, BOOL_MAP, PIPELINE_MAP, PIPELINE_ENUM_MAP, @@ -23,6 +22,8 @@ from codegen.cpp_symbol_map import ( FWD_DTYPE_MAP, BIAS_MAP, get_mask_map, + get_mask_cpp_type, + get_mask_cpp_check_expr, QSCALE_CHECK_MAP, QSCALE_MAP, ) @@ -48,79 +49,79 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT #include "fmha_fwd.hpp" """ -FMHA_FWD_KERNEL_BODY = """ +FMHA_FWD_KERNEL_BODY_TEMPLATE = """ #include #if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) -using fmha_dtype_{F_idx} = {F_dtype}; +using fmha_dtype = {F_dtype}; -using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; +using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; -using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, - ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, - ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, - ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, - {F_vlayout}>; +using 
fmha_shape = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; -using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, - {F_skpad}, - {F_dpad}, - {F_dvpad}, - {F_logits}, - {F_bias}, - false, - {F_lse}, - {F_dropout}, - {F_qscale}, - {F_occupancy}, - {F_skip}>; +using fmha_traits = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_logits}, + {F_bias}, + false, + {F_lse}, + {F_dropout}, + {F_qscale}, + {F_occupancy}, + {F_skip}>; -using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; +using fmha_variant = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; -using fmha_mask_{F_idx} = {F_mask}; +using fmha_mask = {F_mask}; -using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< - typename FmhaFwdTypeConfig::QDataType, - typename FmhaFwdTypeConfig::KDataType, - typename FmhaFwdTypeConfig::VDataType, - typename FmhaFwdTypeConfig::SaccDataType, - typename FmhaFwdTypeConfig::SMPLComputeDataType, - typename FmhaFwdTypeConfig::BiasDataType, - typename FmhaFwdTypeConfig::RandValOutputDataType, - typename FmhaFwdTypeConfig::LSEDataType, - typename FmhaFwdTypeConfig::PDataType, - typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - fmha_shape_{F_idx}, +using fmha_pipeline_problem = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + 
typename FmhaFwdTypeConfig::ODataType, + fmha_shape, {F_mode}, - fmha_variant_{F_idx}, - fmha_mask_{F_idx}, + fmha_variant, + fmha_mask, {F_trload}, - fmha_trait_{F_idx}>; + fmha_traits>; -using fmha_pipeline_{F_idx} = {F_pipeline}< - fmha_pipeline_problem_{F_idx}>; +using fmha_pipeline = {F_pipeline}< + fmha_pipeline_problem>; -using fmha_epilogue_{F_idx} = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, - {F_spad}, {F_dvpad}>>; +using fmha_epilogue = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + {F_spad}, {F_dvpad}>>; -using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdKernel; +using fmha_kernel = {F_kernel}; -using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + +using trait = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; template<> -float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) {{ - using k_ = fmha_kernel_{F_idx}; + using k_ = fmha_kernel; if(s.log_level_ > 0) - std::cout << ", " << k_::GetName() << std::flush; - auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + std::cout << ", {F_kname}" << std::flush; + auto [kargs, grids] = {F_kargs_creator}(a); const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); @@ -130,40 +131,47 @@ float fmha_fwd_(const ck_tile::stream_config& s, fm """ 
FMHA_FWD_API_FILENAME = "fmha_fwd_api.cpp" -FMHA_FWD_API = """ +FMHA_FWD_API_HEADER = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py #include #include -namespace {{ -bool get_num_cus(unsigned& num_cus) {{ +#include "fmha_fwd.hpp" + +namespace { +bool get_num_cus(unsigned& num_cus) { int device; auto status = hipGetDevice(&device); - if(status != hipSuccess) {{ + if(status != hipSuccess) { fprintf(stderr, "failed to get device"); return false; - }} + } - hipDeviceProp_t props{{}}; + hipDeviceProp_t props{}; status = hipGetDeviceProperties(&props, device); - if(status != hipSuccess) {{ + if(status != hipSuccess) { fprintf(stderr, "failed to get device properties"); return false; - }} + } num_cus = props.multiProcessorCount; return true; -}} +} -unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{ +unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) { const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 return batch * nheads * num_m_blocks * num_n_blocks; -}} -}} // namespace - -float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s) {{ +} +} // namespace +""" +FMHA_FWD_API_FUNC_TEMPLATE = """ +namespace {{ +float {F_func_name}([[maybe_unused]] fmha_fwd_traits t, [[maybe_unused]] fmha_fwd_args a, [[maybe_unused]] const ck_tile::stream_config& s) {{ float r = -1; [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate @@ -182,6 +190,28 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& {F_dispatch} return r; }} +}} // namespace +""" +FMHA_FWD_API_FOOTER_TEMPLATE = """ +float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) {{ + const std::string 
device_name = ck_tile::get_device_name(); + + const bool is_swa = (traits.mask_type != mask_enum::no_mask) and + ((0 < args.window_size_left) or (0 < args.window_size_right)); + const bool can_dispatch_v3 = + (device_name.compare(0, 6, "gfx950") == 0) and + (traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and + traits.is_v_rowmajor and (not traits.has_logits_soft_cap) and + (traits.bias_type == bias_enum::no_bias) and (not traits.has_lse) and + (not traits.has_dropout) and (traits.qscale_type == quant_scale_enum::no_scale) and + (not is_swa) and (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and + (args.hdim_v == 128); + if ({F_is_v3_enabled} and can_dispatch_v3) {{ + return fmha_fwd_v3(traits, args, config); + }} else {{ + return fmha_fwd_v2(traits, args, config); + }} +}} """ FMHA_FWD_API_PER_ARCH = """{F_if}({F_arch.device_name_check}) {{ @@ -261,7 +291,7 @@ class FmhaFwdApiTrait: def scheck(self) -> str: if self.mode == "group": return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true - if self.pipeline_tag in ["qr_async", "qr_async_trload"]: + if self.pipeline_tag in ["qr_async", "qr_async_trload", "qr_async_trload_v3"]: if self.spad == "t": return "true" # always support else: @@ -294,7 +324,7 @@ class FmhaFwdApiTrait: return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! 
(ugly) else: return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)" - elif self.pipeline_tag == "qr_async_trload": + elif self.pipeline_tag in ["qr_async_trload", "qr_async_trload_v3"]: if self.skpad == "t": return "true" else: @@ -310,7 +340,7 @@ class FmhaFwdApiTrait: return f"a.hdim_q % {vec} == 0" else: assert False - elif self.pipeline_tag in ["qr", "qs", "qr_async_trload"]: + elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dpad == "t": return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly) @@ -327,7 +357,7 @@ class FmhaFwdApiTrait: return f"a.hdim_v % {vec} == 0" else: assert False - elif self.pipeline_tag in ["qr", "qs", "qr_async_trload"]: + elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dvpad == "t": return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! 
(ugly) @@ -429,9 +459,8 @@ class FmhaFwdPipeline: class FmhaFwdApiPool: - def __init__(self, mask_impl): + def __init__(self): self.pool = OrderedDict() - self.mask_impl = mask_impl def register_traits(self, trait: FmhaFwdApiTrait) -> None: hdim = trait.hdim, trait.bn1 @@ -443,19 +472,60 @@ class FmhaFwdApiPool: check_duplicates_and_paddings(ts, trait) ts.append(copy.copy(trait)) - @property - def api(self) -> str: + def get_num_traits( + self, filter_fn: Optional[Callable[[FmhaFwdApiTrait], bool]] = None + ) -> int: + if filter_fn is None: + + def accept_all(trait: FmhaFwdApiTrait) -> bool: + return True + + filter_fn = accept_all + + return sum( + sum(1 for trait in pool_by_hdim if filter_fn(trait)) + for pool_by_arch in self.pool.values() + for pool_by_dtype in pool_by_arch.values() + for pool_by_hdim in pool_by_dtype.values() + ) + + def render( + self, func_name, filter_fn: Optional[Callable[[FmhaFwdApiTrait], bool]] = None + ) -> str: + if filter_fn is None: + + def accept_all(trait: FmhaFwdApiTrait) -> bool: + return True + + filter_fn = accept_all + + def has_traits(node) -> bool: + """Recursively traverse nested OrderedDicts and lists to determine if any FmhaFwdApiTrait satisfies filter_fn().""" + if isinstance(node, list): + return any(filter_fn(elem) for elem in node) + elif isinstance(node, OrderedDict): + return any(has_traits(val) for val in node.values()) + return False + per_arch = str() - for i_arch, (arch, pool_by_arch) in enumerate(self.pool.items()): + for i_arch, (arch, pool_by_arch) in enumerate( + item for item in self.pool.items() if has_traits(item[1]) + ): per_dtypes = str() - for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()): + for i_dtype, (dtype, pool_by_dtype) in enumerate( + item for item in pool_by_arch.items() if has_traits(item[1]) + ): per_hdim_case = str() for i_hdim, ((hdim, hdim_v), pool_by_hdim) in enumerate( - pool_by_dtype.items() + item for item in pool_by_dtype.items() if has_traits(item[1]) ): - 
max_bm0 = max((t.bm0 for t in pool_by_hdim), default=0) + max_bm0 = max( + (t.bm0 for t in pool_by_hdim if filter_fn(t)), default=0 + ) inners = str() - for i_trait, trait in enumerate(pool_by_hdim): + for i_trait, trait in enumerate( + [trait for trait in pool_by_hdim if filter_fn(trait)] + ): inners += FMHA_FWD_API_INNER_DISPATCH.format( F_if=if_(i_trait), F_arch=arch, @@ -463,8 +533,8 @@ class FmhaFwdApiPool: F_vlayout=LAYOUT_MAP[trait.vlayout], F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], - F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_mask=get_mask_cpp_type(trait.mask), + F_mask_check=get_mask_cpp_check_expr(trait.mask), F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], @@ -506,10 +576,9 @@ class FmhaFwdApiPool: F_arch=arch, F_dtype_case=indent(per_dtypes), ) - if not per_arch: - # empty string we add some ignore to suppress warning in api - per_arch = "(void)t; (void)s; (void)a;" - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=indent(per_arch)) + return FMHA_FWD_API_FUNC_TEMPLATE.format( + F_func_name=func_name, F_dispatch=indent(per_arch) + ) @dataclass @@ -548,18 +617,32 @@ class FmhaFwdTileSize: @dataclass class FmhaFwdKernel: F_arch: ArchTrait - F_idx: int # this is not a tunable, but a counter to differentiate symbol F_hdim: int # hdim F_dtype: str # data type F_mode: str # value from MODE_MAP F_tile: FmhaFwdTileSize F_pipeline: FmhaFwdPipeline - mask_impl: str - @property - def template(self) -> str: - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format( - F_idx=self.F_idx, + _KERNEL_HEADER: ClassVar[str] = FMHA_FWD_KERNEL_HEADER + _KERNEL_BODY_TEMPLATE: ClassVar[str] = FMHA_FWD_KERNEL_BODY_TEMPLATE + + @classmethod + def _get_cpp_kernel_class_name(cls, pipeline_tag): + if pipeline_tag == "qr_async_trload_v3": + return "ck_tile::FmhaFwdV3Kernel" + else: + return 
"ck_tile::FmhaFwdKernel" + + @classmethod + def _get_cpp_kargs_creator_func_name(cls, pipeline_tag): + if pipeline_tag == "qr_async_trload_v3": + return "fmha_fwd_v3_create_kargs_and_grids" + else: + return "fmha_fwd_create_kargs_and_grids" + + def render(self) -> str: + return type(self)._KERNEL_HEADER + type(self)._KERNEL_BODY_TEMPLATE.format( + F_kname=self.name, F_arch=self.F_arch, F_hdim=self.F_hdim, F_dtype=FWD_DTYPE_MAP[self.F_dtype], @@ -594,10 +677,12 @@ class FmhaFwdKernel: F_skip=BOOL_MAP[self.F_pipeline.F_skip], F_occupancy=self.F_tile.F_occupancy, F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], - F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mask=get_mask_cpp_type(self.F_pipeline.F_mask), F_mode=MODE_MAP[self.F_mode], - F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], F_trload=BOOL_MAP[self.F_pipeline.F_trload], + F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], + F_kernel=self._get_cpp_kernel_class_name(self.F_pipeline.tag), + F_kargs_creator=self._get_cpp_kargs_creator_func_name(self.F_pipeline.tag), ) @property @@ -644,16 +729,179 @@ class FmhaFwdKernel: ) -class KernelComponentFactoryGfx9: +@dataclass +class ProblemContext: + dtype: str + mode: str + hdim: int + hdim_v: int + + +@dataclass +class KernelContext: + tile: FmhaFwdTileSize + pipeline: FmhaFwdPipeline + mask_impl: str + + +CompatibilityRule = Callable[[ProblemContext, KernelContext], bool] + + +def is_compatible( + problem_ctx: ProblemContext, + kernel_ctx: KernelContext, + rules: Iterable[CompatibilityRule], +) -> bool: + return all(rule(problem_ctx, kernel_ctx) for rule in rules) + + +def create_kernel( + arch: ArchTrait, problem_ctx: ProblemContext, kernel_ctx: KernelContext +) -> FmhaFwdKernel: + return FmhaFwdKernel( + F_arch=arch, + F_dtype=problem_ctx.dtype, + F_mode=problem_ctx.mode, + F_hdim=problem_ctx.hdim, + F_tile=kernel_ctx.tile, + F_pipeline=kernel_ctx.pipeline, + ) + + +class CompatibilityRuleFactory: + @staticmethod + def get_rules() -> 
list[CompatibilityRule]: + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + def check_mode(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + if problem_ctx.mode == "group": + if ( + kernel_ctx.pipeline.F_spad != "t" + or kernel_ctx.pipeline.F_skpad != "t" + ): + return False + return True + + def check_hdim(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + # NOTE: this is used to speedup deepseek prefill case, we don't gen training + if (problem_ctx.hdim, problem_ctx.hdim_v) == (192, 128): + if ( + kernel_ctx.pipeline.F_bias != "no" + or kernel_ctx.pipeline.F_dropout == "t" + ): + return False + return True + + def check_feature( + problem_ctx: ProblemContext, kernel_ctx: KernelContext + ) -> bool: + # logits_soft_cap is only allowed if no bias + if not ( + ( + kernel_ctx.pipeline.F_logits == "t" + and kernel_ctx.pipeline.F_bias == "no" + ) + or kernel_ctx.pipeline.F_logits == "f" + ): + return False + return True + + return [check_mode, check_hdim, check_feature] + + +class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory): + _AVAILABLE_PIPELINES = frozenset({"qr", "qr_async", "qs"}) + + @classmethod + def get_rules(cls) -> list[CompatibilityRule]: + rules = CompatibilityRuleFactory.get_rules() + + def check_hdim_tile( + problem_ctx: ProblemContext, kernel_ctx: KernelContext + ) -> bool: + if problem_ctx.dtype != "fp32": + # TODO: update if >=gfx11 archs get qr_async and qr_async_trload support + if kernel_ctx.pipeline.tag in cls._AVAILABLE_PIPELINES and ( + ( + (problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128) + and kernel_ctx.tile.F_bn0 != 128 + ) + or ( + (problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128) + and kernel_ctx.tile.F_bm0 != 128 + ) + ): + # non qr_async_trload only support km0=128 tile size when hdim is not 128 + # non qr_async only support kn0=128 tile size when hdim is 128 + return False + return True + + rules.append(check_hdim_tile) + return rules + 
+ +class CompatibilityRuleFactoryGfx950(CompatibilityRuleFactoryGfx9): + _AVAILABLE_PIPELINES = ( + CompatibilityRuleFactoryGfx9._AVAILABLE_PIPELINES + | frozenset({"qr_async_trload", "qr_async_trload_v3"}) + ) + + @classmethod + def get_rules(cls) -> list[CompatibilityRule]: + rules = CompatibilityRuleFactoryGfx9.get_rules() + + def check_tile_pipeline( + problem_ctx: ProblemContext, kernel_ctx: KernelContext + ) -> bool: + if kernel_ctx.pipeline.tag == "qr_async_trload" and ( + ( + (problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128) + and kernel_ctx.tile.F_bn0 == 128 + ) + or ( + (problem_ctx.hdim, problem_ctx.hdim_v) not in [(64, 64), (128, 128)] + ) + ): + return False + + # only qr_async_trload_v3 use km0=256 & 8-warps + is_v3_dedicated_tile = ( + kernel_ctx.tile.F_bm0 == 256 + and (kernel_ctx.tile.F_rm0 * kernel_ctx.tile.F_rn0 * kernel_ctx.tile.F_rk0) == 8 + and (kernel_ctx.tile.F_rm1 * kernel_ctx.tile.F_rn1 * kernel_ctx.tile.F_rk1) == 8 + ) # fmt: skip + is_v3_pipeline = kernel_ctx.pipeline.tag == "qr_async_trload_v3" + return is_v3_dedicated_tile == is_v3_pipeline + + rules.extend([check_tile_pipeline]) + return rules + + +class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): arch = ArchTrait( "gfx9", preprocessor_check="defined(__gfx9__) && !defined(__gfx950__)" ) + _DT_FP32 = ("fp32",) + _DT_FP16_BF16 = ("fp16", "bf16") + _DT_FP8 = ("fp8",) + _DT_FP8BF16 = ("fp8bf16",) + _DT_FP8FP32 = ("fp8fp32",) + + @classmethod + def supported_dtypes(cls) -> Tuple[str]: + return ( + cls._DT_FP32 + + cls._DT_FP16_BF16 + + cls._DT_FP8 + + cls._DT_FP8BF16 + + cls._DT_FP8FP32 + ) + # TODO: design a more practical way to do it # this is current supported tile size per hdim - @staticmethod - def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: - if dtype in ["fp32"]: + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: + if dtype in cls._DT_FP32: return { # bm0, bn0, bk0, bn1, bk1, ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 
32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], @@ -666,7 +914,7 @@ class KernelComponentFactoryGfx9: (192, 192) : [FmhaFwdTileSize( 64, 64, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip - elif dtype in ["fp16", "bf16"]: + elif dtype in cls._DT_FP16_BF16: return { ( 32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], ( 64, 64) : [FmhaFwdTileSize( 16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), @@ -682,30 +930,32 @@ class KernelComponentFactoryGfx9: (192, 192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } # fmt: skip - elif dtype in ["fp8", "fp8bf16"]: + elif dtype in cls._DT_FP8 or dtype in cls._DT_FP8BF16: return { ( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (192, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], } # fmt: skip - elif dtype in ["fp8fp32"]: + elif dtype in cls._DT_FP8FP32: return { (128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], } # fmt: skip else: - return None + raise ValueError(f"unsupported dtype={dtype}") # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future - @staticmethod - def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + @classmethod + def get_pipelines( + cls, dtype, hdim, hdim_v, receipt, mask_impl + ) -> 
List[FmhaFwdPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let "t" padding to appear later!! # TODO: how to design this more generic? pipelines = [] - if dtype in ["fp32"]: + if dtype in cls._DT_FP32: qscale = "no" for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], @@ -718,7 +968,7 @@ class KernelComponentFactoryGfx9: pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip - elif dtype in ["fp16", "bf16"]: + elif dtype in cls._DT_FP16_BF16: qscale = "no" for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], @@ -743,7 +993,7 @@ class KernelComponentFactoryGfx9: pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip if receipt == 1 and bias != "bias": pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip - elif dtype in ["fp8bf16", "fp8fp32"]: + elif dtype in cls._DT_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels for logits, qscale, mask, bias in itertools.product( ["f"], @@ -755,21 +1005,33 @@ class KernelComponentFactoryGfx9: pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip elif dtype in ["fp8", "fp8fp16", "bf8"]: # TODO - None - else: - assert False + pass return pipelines -class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9): +class 
KernelComponentFactoryGfx950( + KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx950 +): arch = ArchTrait("gfx950") - @staticmethod - def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: + result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype) + if dtype in cls._DT_FP16_BF16: + # add tile for qr_async_trload_v3 + if (128, 128) in result.keys(): + result[(128, 128)].append( + FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1)) # fmt: skip + return result + + @classmethod + def get_pipelines( + cls, dtype, hdim, hdim_v, receipt, mask_impl + ) -> List[FmhaFwdPipeline]: pipelines = KernelComponentFactoryGfx9.get_pipelines( dtype, hdim, hdim_v, receipt, mask_impl ) - if dtype in ["fp16", "bf16"]: + if dtype in cls._DT_FP16_BF16: qscale = "no" for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], @@ -788,15 +1050,31 @@ class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9): ): pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t")) # fmt: skip + + # qr_async_trload_v3 only supports hdim=hdim_v=128 for now + if (hdim, hdim_v) == (128, 128): + # qr_async_trload_v3 only supports (generic) causal mask + for mask in ["no", "causal"]: + pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", + F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t")) # fmt: skip + return pipelines -class KernelComponentFactoryGfx12: +class KernelComponentFactoryGfx12(CompatibilityRuleFactory): arch = ArchTrait("gfx12") - @staticmethod - def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: - if 
dtype in ["fp16", "bf16"]: + _DT_FP16_BF16 = ("fp16", "bf16") + _DT_FP8_FP8BF16 = ("fp8", "fp8bf16") + _DT_FP8FP32 = ("fp8fp32",) + + @classmethod + def supported_dtypes(cls) -> Tuple[str]: + return cls._DT_FP16_BF16 + cls._DT_FP8_FP8BF16 + cls._DT_FP8FP32 + + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: + if dtype in cls._DT_FP16_BF16: return { # bm0, bn0, bk0, bn1, bk1, ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], @@ -805,25 +1083,27 @@ class KernelComponentFactoryGfx12: (192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip - elif dtype in ["fp8", "fp8bf16"]: + elif dtype in cls._DT_FP8_FP8BF16: return { # bm0, bn0, bk0, bn1, bk1, ( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (256, 256) : [FmhaFwdTileSize( 64, 32, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip - elif dtype in ["fp8fp32"]: + elif dtype in cls._DT_FP8FP32: return { # bm0, bn0, bk0, bn1, bk1, (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip else: - return None + raise ValueError(f"unsupported dtype={dtype}") - @staticmethod - def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + @classmethod + def get_pipelines( + cls, dtype, hdim, hdim_v, receipt, mask_impl + ) -> List[FmhaFwdPipeline]: pipelines = [] - if dtype in ["fp16", "bf16"]: + if dtype in cls._DT_FP16_BF16: qscale = "no" for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], @@ -835,23 +1115,21 @@ class KernelComponentFactoryGfx12: ): 
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip - elif dtype in ["fp8", "fp8bf16", "fp8fp32"]: + elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels for logits, qscale, mask, bias in itertools.product( ["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"] ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip - else: - assert False return pipelines -class CustomFactory(KernelComponentFactoryGfx9): - @staticmethod - def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: +class CustomFactory(KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx9): + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype) - if dtype == "fp16" or dtype == "bf16": + if dtype in cls._DT_FP16_BF16: if (128, 128) in result.keys(): result[(128, 128)].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip return result @@ -874,150 +1152,162 @@ def get_factory(target: str): raise Exception(f"Unsupported device target {target}") +@dataclass(frozen=True) +class Product: + name: str + rule: CompatibilityRule + + def __call__(self, problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + return self.rule(problem_ctx, kernel_ctx) + + +def get_product(receipt: int) -> Product: + # Flash attention integration + if receipt in (2, 3): + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + 
cond = problem_ctx.dtype in ["fp16", "bf16"] + cond &= kernel_ctx.pipeline.F_vlayout == "row" + cond &= kernel_ctx.pipeline.F_bias in ["no", "alibi"] + cond &= kernel_ctx.pipeline.F_qscale == "no" + cond &= kernel_ctx.pipeline.F_skip == "f" + return cond + + return Product(name="Flash attention integration", rule=fit) + # PyTorch integration + elif receipt == 4: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16"] + cond &= kernel_ctx.pipeline.F_vlayout == "row" + cond &= kernel_ctx.pipeline.F_bias in ["no", "bias"] + cond &= kernel_ctx.pipeline.F_qscale == "no" + cond &= problem_ctx.mode == "batch" + cond &= kernel_ctx.pipeline.F_skip == "f" + cond &= kernel_ctx.pipeline.F_logits == "f" + return cond + + return Product(name="PyTorch integration", rule=fit) + # Aiter(mha_fwd) integration + elif receipt == 100: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16", "fp8bf16"] + cond &= problem_ctx.mode == "batch" + cond &= kernel_ctx.pipeline.F_vlayout == "row" + if problem_ctx.dtype == "fp8bf16": + cond &= problem_ctx.hdim == 128 or problem_ctx.hdim == 192 + return cond + + return Product(name="Aiter(mha_fwd) integration", rule=fit) + # Aiter(mha_varlen_fwd) integration + elif receipt == 200: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16", "fp8bf16"] + cond &= problem_ctx.mode == "group" + cond &= kernel_ctx.pipeline.F_vlayout == "row" + if problem_ctx.dtype == "fp8bf16": + cond &= problem_ctx.hdim == 128 or problem_ctx.hdim == 192 + return cond + + return Product(name="Aiter(mha_varlen_fwd) integration", rule=fit) + # aiter::mha_fwd C++ api integration + elif receipt == 600: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16", "fp8bf16"] + cond &= kernel_ctx.pipeline.F_vlayout == 
"row" + if problem_ctx.dtype == "fp8bf16": + cond &= problem_ctx.hdim == 128 or problem_ctx.hdim == 192 + return cond + + return Product(name="aiter::mha_fwd C++ api integration", rule=fit) + elif receipt == 888: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp8bf16", "fp8fp32"] + cond &= kernel_ctx.pipeline.F_vlayout == "row" + cond &= problem_ctx.hdim == 128 or problem_ctx.hdim == 192 + return cond + + return Product(name="receipt = 888", rule=fit) + # fp32 only, all variations + elif receipt == 800: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype == "fp32" + cond &= kernel_ctx.pipeline.F_skip == "f" + cond &= kernel_ctx.pipeline.F_logits == "f" + return cond + + return Product(name="fp32 only, all variations", rule=fit) + # fp32 only, minimal set of parameters + elif receipt == 801: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype == "fp32" + cond &= problem_ctx.hdim in [48, 128] + cond &= problem_ctx.mode == "batch" + cond &= kernel_ctx.pipeline.F_bias == "no" + cond &= kernel_ctx.pipeline.F_lse == "f" + cond &= kernel_ctx.pipeline.F_dropout == "f" + cond &= kernel_ctx.pipeline.F_skip == "f" + cond &= kernel_ctx.pipeline.F_logits == "f" + cond &= kernel_ctx.pipeline.F_mask == "s_no" + return cond + + return Product(name="fp32 only, minimal set of parameters", rule=fit) + # Don't build fp32 by default + else: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + return problem_ctx.dtype != "fp32" + + return Product(name="Default", rule=fit) + + def get_fwd_blobs( targets: List[str], kernel_filter: Optional[str], receipt, optdim_list, mask_impl ) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: gen = list() - api_pool = FmhaFwdApiPool(mask_impl) + api_pool = FmhaFwdApiPool() factories = get_factories_for_targets(targets, get_factory) - for factory, dtype in 
itertools.product(factories, FWD_DTYPE_MAP.keys()): + for factory, dtype in ((f, t) for f in factories for t in f.supported_dtypes()): d = factory.get_hdim_tile_size_dict(dtype) - if d is None: - continue # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for ((hdim, hdim_v), tiles), mode in itertools.product( d.items(), MODE_MAP.keys() ): + if optdim_list != [-1]: + if hdim not in optdim_list: + continue for tile, next_tile in zip(tiles, tiles[1:]): assert next_tile.F_bm0 >= tile.F_bm0, ( "Tiles must be ordered by increasing bm0" ) + for tile, pipeline in itertools.product( tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) ): - if mode == "group": - if pipeline.F_spad != "t" or pipeline.F_skpad != "t": - # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not - continue - if (hdim, hdim_v) == (192, 128): - # NOTE: this is used to speedup deepseek prefill case, we don't gen training - if pipeline.F_bias != "no" or pipeline.F_dropout == "t": - continue - if factory.arch.name.startswith("gfx9") and dtype != "fp32": - # TODO: update if >=gfx11 archs get qr_async and qr_async_trload support - if pipeline.tag != "qr_async_trload" and ( - ((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) - or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128) - ): - # non qr_async_trload only support km0=128 tile size when hdim is not 128 - # non qr_async only support kn0=128 tile size when hdim is 128 - continue - if pipeline.tag == "qr_async_trload" and ( - ((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) - or ((hdim, hdim_v) not in [(64, 64), (128, 128)]) - ): - continue - # logits_soft_cap is only allowed if no bias - if not ( - (pipeline.F_logits == "t" and pipeline.F_bias == "no") - or pipeline.F_logits == "f" - ): - continue - k = FmhaFwdKernel( - F_arch=factory.arch, - F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, 
- F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl, + problem_ctx = ProblemContext( + dtype=dtype, mode=mode, hdim=hdim, hdim_v=hdim_v ) + kernel_ctx = KernelContext( + tile=tile, pipeline=pipeline, mask_impl=mask_impl + ) + rules = factory.get_rules() + product = get_product(receipt) + + if not is_compatible(problem_ctx, kernel_ctx, [*rules, product]): + continue + + k = create_kernel(factory.arch, problem_ctx, kernel_ctx) if kernel_filter != "": if not fnmatch.fnmatch(k.name, kernel_filter): continue - if optdim_list != [-1]: - if hdim not in optdim_list: - continue - # 2 - Flash attention integration - if receipt in (2, 3): - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_bias in ["no", "alibi"] - cond &= pipeline.F_qscale == "no" - cond &= pipeline.F_skip == "f" - if not cond: - continue - # PyTorch integration - elif receipt == 4: - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_bias in ["no", "bias"] - cond &= pipeline.F_qscale == "no" - cond &= mode == "batch" - cond &= pipeline.F_skip == "f" - cond &= pipeline.F_logits == "f" - if not cond: - continue - # Aiter(mha_fwd) integration - elif receipt == 100: - cond = dtype in ["fp16", "bf16", "fp8bf16"] - cond &= mode == "batch" - cond &= pipeline.F_vlayout == "row" - if dtype == "fp8bf16": - cond &= hdim == 128 or hdim == 192 - if not cond: - continue - # Aiter(mha_varlen_fwd) integration - elif receipt == 200: - cond = dtype in ["fp16", "bf16", "fp8bf16"] - cond &= mode == "group" - cond &= pipeline.F_vlayout == "row" - if dtype == "fp8bf16": - cond &= hdim == 128 or hdim == 192 - if not cond: - continue - # aiter::mha_fwd C++ api integration - elif receipt == 600: - cond = dtype in ["fp16", "bf16", "fp8bf16"] - cond &= pipeline.F_vlayout == "row" - if dtype == "fp8bf16": - cond &= hdim == 128 or hdim == 192 - if not cond: - continue - elif receipt == 888: - cond = dtype in ["fp8bf16", "fp8fp32"] - cond &= 
pipeline.F_vlayout == "row" - cond &= hdim == 128 or hdim == 192 - if not cond: - continue - - # fp32 only, all variations - if receipt == 800: - cond = dtype == "fp32" - cond &= pipeline.F_skip == "f" - cond &= pipeline.F_logits == "f" - if not cond: - continue - # fp32 only, minimal set of parameters - elif receipt == 801: - cond = dtype == "fp32" - cond &= hdim in [48, 128] - cond &= mode == "batch" - cond &= pipeline.F_bias == "no" - cond &= pipeline.F_lse == "f" - cond &= pipeline.F_dropout == "f" - cond &= pipeline.F_skip == "f" - cond &= pipeline.F_logits == "f" - cond &= pipeline.F_mask == "s_no" - if not cond: - continue - else: - # Don't build fp32 by default - if dtype == "fp32": - continue api_pool.register_traits(k.api_trait()) gen.append(k) @@ -1026,11 +1316,34 @@ def get_fwd_blobs( def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: - update_file(autogen_dir / kernel.filename, kernel.template) + update_file(autogen_dir / kernel.filename, kernel.render()) -def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: - update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) +def write_fwd_api( + api_pool: FmhaFwdApiPool, + autogen_dir: Path, +) -> None: + def accept_only_v3(trait: FmhaFwdApiTrait) -> bool: + return trait.pipeline_tag == "qr_async_trload_v3" + + def accept_only_v2(trait: FmhaFwdApiTrait) -> bool: + return not accept_only_v3(trait) + + content = "".join( + [ + FMHA_FWD_API_HEADER, + api_pool.render("fmha_fwd_v2", filter_fn=accept_only_v2), + api_pool.render("fmha_fwd_v3", filter_fn=accept_only_v3), + FMHA_FWD_API_FOOTER_TEMPLATE.format( + F_is_v3_enabled=BOOL_MAP[ + # NOTE: enable v3 pipelines when ready + # 0 < api_pool.get_num_traits(filter_fn=accept_only_v3) + False + ] + ), + ] + ) + update_file(autogen_dir / FMHA_FWD_API_FILENAME, content) def write_blobs( diff --git a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp deleted file mode 
100644 index c510b36bb5..0000000000 --- a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp +++ /dev/null @@ -1,616 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "fmha_fwd.hpp" -#include "fmha_fwd_v3.hpp" -#include "mask.hpp" - -auto parse_cmd_args(int argc, char* argv[]) -> std::pair -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("prec", "fp16", "data type. fp16/bf16") - .insert("b", "2", "batch size") - .insert("h", "8", "num of head, for q") - .insert("h_k", - "-1", - "num of head, for k/v, -1 means equal to h\n" - "if not equal to h, then this is GQA/MQA case") - .insert("s", "3328", "seqlen_q") - .insert("s_k", "-1", "seqlen_k, -1 means equal to s") - .insert("d", "128", "head dim for q & k") - .insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)") - .insert("iperm", - "0", - "permute input\n" - "if true, will be b*h*s*d, else b*s*h*d") - .insert("operm", "0", "permute output") - .insert("causal", "0", "0: no mask, 1: causal mask") - .insert("v", "1", "0:no verify, 1:verify") - .insert("seed", - "11939", - "random seed used for initializing input tensors. 0 for " - "non-deterministic seed") - .insert("warmup", "5", "number of iterations before benchmark the kernel") - .insert("repeat", "30", "number of iterations to benchmark the kernel") - // Optional effective seqlen override (exclude PAD) for batch mode - .insert("q_eff_lens", - "", - "Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n" - "Comma-separated list of length 'b'. If empty, no override.") - .insert("kv_eff_lens", - "", - "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n" - "Comma-separated list of length 'b'. 
If empty, no override."); - - bool result = arg_parser.parse(argc, argv); - return std::make_pair(result, arg_parser); -} - -enum class TensorLayout -{ - bhsd, - bshd, -}; - -std::ostream& operator<<(std::ostream& stream, TensorLayout layout) -{ - switch(layout) - { - case TensorLayout::bhsd: return stream << "bhsd"; - case TensorLayout::bshd: return stream << "bshd"; - default: return stream << "unknown"; - } -} - -struct Problem -{ - explicit Problem(const ck_tile::ArgParser& args) - { - data_type = args.get_str("prec") == "fp16" - ? ck_tile::fmha_fwd_v3_args::data_type_enum::fp16 - : ck_tile::fmha_fwd_v3_args::data_type_enum::bf16; - batch = args.get_int("b"); - seqlen_q = args.get_int("s"); - seqlen_k = args.get_int("s_k"); - if(seqlen_k < 0) - { - seqlen_k = seqlen_q; - } - nhead_q = args.get_int("h"); - nhead_kv = args.get_int("h_k"); - if(nhead_kv < 0) - { - nhead_kv = nhead_q; - } - hdim = args.get_int("d"); - softmax_scale = args.get_float("scale_s"); - if(softmax_scale == .0f) - softmax_scale = 1.0 / ck_tile::sqrt(static_cast(hdim)); - - const auto is_causal = args.get_bool("causal"); - if(is_causal) - { - mask = mask_info::decode("b:-1,0", seqlen_q, seqlen_k); - } - else - { - mask = mask_info::decode("0", seqlen_q, seqlen_k); - } - - input_layout = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd; - output_layout = args.get_int("operm") == 1 ? 
TensorLayout::bhsd : TensorLayout::bshd; - q_eff_lens = args.get_int_vec("q_eff_lens"); - kv_eff_lens = args.get_int_vec("kv_eff_lens"); - } - - std::vector get_query_shape() const - { - if(input_layout == TensorLayout::bhsd) - { - return {batch, nhead_q, seqlen_q, hdim}; - } - else - { - return {batch, seqlen_q, nhead_q, hdim}; - } - } - - std::vector get_key_shape() const - { - if(input_layout == TensorLayout::bhsd) - { - return {batch, nhead_kv, seqlen_k, hdim}; - } - else - { - return {batch, seqlen_k, nhead_kv, hdim}; - } - } - - std::vector get_value_shape() const - { - if(input_layout == TensorLayout::bhsd) - { - return {batch, nhead_kv, seqlen_k, hdim}; - } - else - { - return {batch, seqlen_k, nhead_kv, hdim}; - } - } - - std::vector get_output_shape() const - { - if(output_layout == TensorLayout::bhsd) - { - return {batch, nhead_q, seqlen_q, hdim}; - } - else - { - return {batch, seqlen_q, nhead_q, hdim}; - } - } - - ck_tile::fmha_fwd_v3_args::data_type_enum data_type; - ck_tile::index_t batch; - ck_tile::index_t seqlen_q; - ck_tile::index_t seqlen_k; - ck_tile::index_t nhead_q; - ck_tile::index_t nhead_kv; - ck_tile::index_t hdim; - float softmax_scale; - mask_info mask; - TensorLayout input_layout; - TensorLayout output_layout; - std::vector q_eff_lens; - std::vector kv_eff_lens; -}; - -struct RunConfig -{ - explicit RunConfig(const ck_tile::ArgParser& args) - { - seed = args.get_uint32("seed"); - if(*seed == 0) - { - seed.reset(); - } - - kernel_warmup = args.get_int("warmup"); - kernel_repeat = args.get_int("repeat"); - verify = args.get_bool("v"); - } - - std::optional seed; - int kernel_warmup; - int kernel_repeat; - bool verify; -}; - -template -auto generate_qkv(const Problem& problem, - [[maybe_unused]] std::optional seed = std::nullopt) - -> std::tuple, - ck_tile::HostTensor, - ck_tile::HostTensor> -{ - ck_tile::HostTensor q(problem.get_query_shape()); - ck_tile::HostTensor k(problem.get_key_shape()); - ck_tile::HostTensor 
v(problem.get_value_shape()); - - ck_tile::FillNormalDistribution{0.f, 3.f, seed}(q); - ck_tile::FillNormalDistribution{0.f, 3.f, seed}(k); - ck_tile::FillNormalDistribution{0.f, 3.f, seed}(v); - - return std::make_tuple(q, k, v); -} - -namespace host { -template -CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, - const ck_tile::HostTensor& k_bshd, - const ck_tile::HostTensor& v_bshd, - const mask_info& mask, - ck_tile::HostTensor& o_bshd, - const QElementOp& q_element_op = {}, - const KElementOp& k_element_op = {}, - const VElementOp& v_element_op = {}, - const SAccElementOp& s_acc_element_op = {}) -{ - const int batch_size = q_bshd.mDesc.get_lengths()[0]; - const int seqlen_q = q_bshd.mDesc.get_lengths()[1]; - const int seqlen_kv = k_bshd.mDesc.get_lengths()[1]; - const int nhead_q = q_bshd.mDesc.get_lengths()[2]; - const int nhead_kv = k_bshd.mDesc.get_lengths()[2]; - const int hdim_qk = q_bshd.mDesc.get_lengths()[3]; - const int hdim_v = v_bshd.mDesc.get_lengths()[3]; - - const int nr = nhead_q / nhead_kv; - - ck_tile::HostTensor q_host_ref({nhead_q, seqlen_q, hdim_qk}); - ck_tile::HostTensor k_host_ref({nhead_q, seqlen_kv, hdim_qk}); - ck_tile::HostTensor v_host_ref({nhead_q, hdim_v, seqlen_kv}); - ck_tile::HostTensor o_host_ref({nhead_q, seqlen_q, hdim_v}); - - ck_tile::HostTensor s_host_ref({nhead_q, seqlen_q, seqlen_kv}); - ck_tile::HostTensor p_host_ref({nhead_q, seqlen_q, seqlen_kv}); - - // do computation for each batch - for(int b = 0; b < batch_size; ++b) - { - // copy per-batch data from input tensors - // clang-format off - q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , idx[2]); }); - k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], idx[0] / nr, idx[2]); }); - v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = v_bshd(b, idx[2], idx[0] / nr, idx[1]); }); - // clang-format on - ck_tile::reference_batched_gemm( - q_host_ref, k_host_ref, s_host_ref, q_element_op, 
k_element_op, s_acc_element_op); - - if(mask.type == mask_enum::no_mask) - { - ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv}); - } - else if(mask.type == mask_enum::window_generic) - { - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, mask.right, seqlen_q, seqlen_kv)); - } - else - { - // if left window size is negative, means causal - // else means generic (for current batch) - if(mask.left < 0) - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, - mask.right, - seqlen_q, - seqlen_kv, - mask.type == mask_enum::mask_top_left)); - else - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, - mask.right, - seqlen_q, - seqlen_kv, - mask.type == mask_enum::mask_top_left)); - } - - ck_tile::reference_batched_softmax( - s_host_ref, p_host_ref, ck_tile::identity{}); - - ck_tile::reference_batched_gemm( - p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op); - - // copy resulting per-batch data to the output tensor - o_host_ref.ForEach( - [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); }); - } -} -} // namespace host - -template -bool run_impl(const Problem& problem, const RunConfig& run_config) -{ - auto [q, k, v] = generate_qkv(problem, run_config.seed); - - ck_tile::DeviceMem q_buf(q.get_element_space_size_in_bytes()); - ck_tile::DeviceMem k_buf(k.get_element_space_size_in_bytes()); - ck_tile::DeviceMem v_buf(v.get_element_space_size_in_bytes()); - /// FIXME: use correct size for output tensor. 
just use q size for now since hidm_qk = hdim_v - ck_tile::DeviceMem o_buf(q.get_element_space_size_in_bytes()); - - q_buf.ToDevice(q.data()); - k_buf.ToDevice(k.data()); - v_buf.ToDevice(v.data()); - // Ensure output buffer is zero-initialized so padded regions compare cleanly - o_buf.SetZero(); - - ck_tile::fmha_fwd_v3_args args{}; - - args.data_type = problem.data_type; - args.batch = problem.batch; - args.seqlen_q = problem.seqlen_q; - args.seqlen_k = problem.seqlen_k; - args.nhead_q = problem.nhead_q; - args.nhead_kv = problem.nhead_kv; - args.hdim_qk = problem.hdim; - args.hdim_v = problem.hdim; - args.softmax_scale = problem.softmax_scale; - - args.window_size_left = problem.mask.left; - args.window_size_right = problem.mask.right; - args.mask_type = static_cast(problem.mask.type); - - // bshd: (batch, seqlen_q, nhead_q, hdim) - // bhsd: (batch, nhead_q, seqlen_q, hdim) - args.q_ptr = q_buf.GetDeviceBuffer(); - args.stride_q = - problem.input_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim; - args.nhead_stride_q = - problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_q * problem.hdim; - args.batch_stride_q = problem.seqlen_q * problem.nhead_q * problem.hdim; - - // bshd: (batch, seqlen_k, nhead_kv, hdim) - // bhsd: (batch, nhead_kv, seqlen_k, hdim) - args.k_ptr = k_buf.GetDeviceBuffer(); - args.stride_k = - problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim; - args.nhead_stride_k = - problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim; - args.batch_stride_k = problem.seqlen_k * problem.nhead_kv * problem.hdim; - - // bshd: (batch, seqlen_k, nhead_kv, hdim) - // bhsd: (batch, nhead_kv, seqlen_k, hdim) - args.v_ptr = v_buf.GetDeviceBuffer(); - args.stride_v = - problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim; - args.nhead_stride_v = - problem.input_layout == TensorLayout::bshd ? 
problem.hdim : problem.seqlen_k * problem.hdim; - args.batch_stride_v = problem.seqlen_k * problem.nhead_kv * problem.hdim; - - // bshd: (batch, seqlen_q, nhead_q, hdim) - // bhsd: (batch, nhead_q, seqlen_q, hdim) - args.o_ptr = o_buf.GetDeviceBuffer(); - args.stride_o = - problem.output_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim; - args.nhead_stride_o = problem.output_layout == TensorLayout::bshd - ? problem.hdim - : problem.seqlen_q * problem.hdim; - args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim; - - // Optional cumulative seqlen overrides (exclude PAD) - const bool has_varlen_q = !problem.q_eff_lens.empty() && problem.q_eff_lens[0] != -1; - const bool has_varlen_k = !problem.kv_eff_lens.empty() && problem.kv_eff_lens[0] != -1; - - auto make_effective_vec = [&](const std::vector& opt_vec, ck_tile::index_t fallback) { - std::vector eff; - if(!opt_vec.empty() && opt_vec[0] != -1) - { - eff.assign(opt_vec.begin(), opt_vec.end()); - if(eff.size() < static_cast(problem.batch)) - { - eff.resize(problem.batch, eff.back()); - } - } - else - { - eff.assign(problem.batch, fallback); - } - return eff; - }; - - const auto eff_q_vec = make_effective_vec(problem.q_eff_lens, problem.seqlen_q); - const auto eff_kv_vec = make_effective_vec(problem.kv_eff_lens, problem.seqlen_k); - - // Calculate cumulative sums for kernel arguments if varlen is used - std::vector cuq_cum, cukv_cum; - auto calculate_cumulative = [&](const std::vector& per_batch_vec, - std::vector& cum_vec) { - cum_vec.resize(per_batch_vec.size() + 1); - cum_vec[0] = 0; - for(std::size_t i = 0; i < per_batch_vec.size(); ++i) - cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; - }; - - if(has_varlen_q) - { - calculate_cumulative(eff_q_vec, cuq_cum); - } - if(has_varlen_k) - { - calculate_cumulative(eff_kv_vec, cukv_cum); - } - - ck_tile::DeviceMem cuq_buf(!cuq_cum.empty() ? 
cuq_cum.size() * sizeof(ck_tile::index_t) : 0); - ck_tile::DeviceMem cukv_buf(!cukv_cum.empty() ? cukv_cum.size() * sizeof(ck_tile::index_t) : 0); - cuq_buf.ToDevice(!cuq_cum.empty() ? cuq_cum.data() : nullptr); - cukv_buf.ToDevice(!cukv_cum.empty() ? cukv_cum.data() : nullptr); - args.cu_seqlen_q_ptr = - !cuq_cum.empty() ? reinterpret_cast(cuq_buf.GetDeviceBuffer()) - : nullptr; - args.cu_seqlen_kv_ptr = - !cukv_cum.empty() ? reinterpret_cast(cukv_buf.GetDeviceBuffer()) - : nullptr; - - ck_tile::stream_config stream_config{nullptr, - true, - /*log_level=*/0, - run_config.kernel_warmup, - run_config.kernel_repeat}; - - auto [result, time] = ck_tile::fmha_fwd_v3(args, stream_config); - if(!result) - { - std::cerr << "faild to run fmha_fwd_v3()" << std::endl; - return false; - } - - std::size_t flop = [&] { - if(problem.mask.type == mask_enum::no_mask) - { - return 4 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * - problem.hdim; - } - else - { - /// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2. 
- return 2 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * - problem.hdim; - } - }(); - float tflops = static_cast(flop) / 1.e9 / time; - - std::cout << "[" << problem.data_type << "|"; - if(problem.input_layout == problem.output_layout) - { - std::cout << problem.input_layout; - } - else - { - std::cout << problem.input_layout << "-" << problem.output_layout; - } - std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv - << ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim - << ", scale_s:" << problem.softmax_scale << ", mask:" << problem.mask << std::fixed - << ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << tflops - << " TFlops" << std::endl; - - if(!run_config.verify) - { - return true; - } - - // transpose tensor descriptors from bhsd to bshd if necessary - if(problem.input_layout != TensorLayout::bshd) - { - q = q.transpose({0, 2, 1, 3}); - k = k.transpose({0, 2, 1, 3}); - v = v.transpose({0, 2, 1, 3}); - } - - ck_tile::HostTensor o_ref(problem.get_output_shape()); - if(problem.output_layout != TensorLayout::bshd) - { - o_ref = o_ref.transpose({0, 2, 1, 3}); - } - - // If variable lengths are provided, compute per-batch references - // with the effective lengths; else compute a single full reference. - if(has_varlen_q || has_varlen_k) - { - // Variable-length aware verification: zero-fill padded region and only compute valid part. 
- o_ref.SetZero(); - - for(int b = 0; b < problem.batch; ++b) - { - const ck_tile::index_t seqlen_q_eff = eff_q_vec[b]; - const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b]; - - if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0) - continue; - - // Slice current batch from inputs (bshd) and build single-batch tensors - ck_tile::HostTensor q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); - ck_tile::HostTensor k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); - ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); - ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); - - // Copy effective region - q_b.ForEach([&](auto& self, auto idx) { - // idx: [0, s, h, d] - self(idx) = q(b, idx[1], idx[2], idx[3]); - }); - k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); }); - v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); - - // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) - host::fmha_fwd(q_b, - k_b, - v_b, - problem.mask, - o_b, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales{problem.softmax_scale}); - - // Scatter into o_ref's bshd descriptor memory - for(int s = 0; s < seqlen_q_eff; ++s) - { - for(int h = 0; h < problem.nhead_q; ++h) - { - for(int d = 0; d < problem.hdim; ++d) - { - o_ref(b, s, h, d) = o_b(0, s, h, d); - } - } - } - } - } - else - { - // No varlen override: compute the full reference once - host::fmha_fwd(q, - k, - v, - problem.mask, - o_ref, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales{problem.softmax_scale}); - } - - ck_tile::HostTensor o(problem.get_output_shape()); - o_buf.FromDevice(o.data()); - - const auto [rtol, atol] = [&] { - if constexpr(std::is_same_v) - return std::make_tuple(1e-3, 1e-3); - else - return std::make_tuple(1e-2, 1e-2); - }(); - return ck_tile::check_err(o, o_ref, std::string("found incorrect 
results!"), rtol, atol); -} - -int main(int argc, char* argv[]) -{ - auto [parse_result, args] = parse_cmd_args(argc, argv); - if(!parse_result) - { - std::cerr << "failed to parse command line arguments" << std::endl; - } - - Problem problem(args); - RunConfig run_config(args); - - const auto run = [&] { - if(problem.data_type == ck_tile::fmha_fwd_v3_args::data_type_enum::fp16) - { - return run_impl(problem, run_config); - } - else - { - return run_impl(problem, run_config); - } - }; - - return !run(); -} diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index f279ebfcea..002d0a1035 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -686,6 +686,100 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) } } +template +auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args) +{ + /// NOTICE: This was borrowed from Aiter. Make sure the selected remap_opt setting truly + /// maximizes the kernel's performance. 
+ int remap_opt = 2; + if(args.mask_type != static_cast(mask_enum::no_mask) && + ((args.nhead_q % 8 != 0) || (16384 < args.seqlen_q))) + { + if(65536 <= args.seqlen_q) + { + remap_opt = 0; + } + else + { + remap_opt = 1; + } + } + + auto kargs = [&] { + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + nullptr, // lse_ptr + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_q_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + 0, // nhead_stride_lse + args.nhead_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + remap_opt, + args.cu_seqlen_q_ptr, + args.cu_seqlen_k_ptr); + } + else + { + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + nullptr, // lse_ptr + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + 0, // nhead_stride_lse + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + 0, // batch_stride_lse + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + remap_opt, + args.cu_seqlen_q_ptr, + args.cu_seqlen_k_ptr); + } + }(); + + dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + + return ck_tile::make_tuple(kargs, grids); +} + template auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args) { diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/fmha_fwd_v3.cpp deleted file mode 100644 index 1c0256cc0f..0000000000 --- 
a/example/ck_tile/01_fmha/fmha_fwd_v3.cpp +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "fmha_fwd_v3.hpp" -#include "fmha_fwd_v3_impl.hpp" -#include "mask.hpp" - -namespace ck_tile { - -std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type) -{ - switch(data_type) - { - case fmha_fwd_v3_args::data_type_enum::fp16: return stream << "fp16"; - case fmha_fwd_v3_args::data_type_enum::bf16: return stream << "bf16"; - default: return stream << "unknown"; - } -} - -std::pair fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config) -{ - if(args.data_type == fmha_fwd_v3_args::data_type_enum::fp16) - { - if(args.mask_type == static_cast(mask_enum::no_mask)) - { - using kernel_traits = - fmha_fwd_v3_kernel_traits; - - return fmha_fwd_v3_kernel_dispatch(args, config); - } - else - { - using kernel_traits = - fmha_fwd_v3_kernel_traits; - - return fmha_fwd_v3_kernel_dispatch(args, config); - } - } - else if(args.data_type == fmha_fwd_v3_args::data_type_enum::bf16) - { - if(args.mask_type == static_cast(mask_enum::no_mask)) - { - using kernel_traits = - fmha_fwd_v3_kernel_traits; - - return fmha_fwd_v3_kernel_dispatch(args, config); - } - else - { - using kernel_traits = - fmha_fwd_v3_kernel_traits; - - return fmha_fwd_v3_kernel_dispatch(args, config); - } - } - - return std::make_pair(false, -1.f); -} - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp deleted file mode 100644 index 54cc4960a5..0000000000 --- a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include - -#include "ck_tile/core/numeric/integer.hpp" -#include "ck_tile/host/stream_config.hpp" - -namespace ck_tile { - -struct fmha_fwd_v3_args -{ - enum class data_type_enum - { - fp16, - bf16 - }; - - data_type_enum data_type; - // bool is_varlen; - - index_t batch; - index_t seqlen_q; - index_t seqlen_k; - index_t nhead_q; - index_t nhead_kv; - index_t hdim_qk; - index_t hdim_v; - - float softmax_scale; - - index_t window_size_left; - index_t window_size_right; - index_t mask_type; // should be 0 for no mask; or 2 for causal mask (window_size_left < 0 and - // window_size_right == 0). - - const void* q_ptr; - index_t stride_q; - index_t nhead_stride_q; - index_t batch_stride_q; - - const void* k_ptr; - index_t stride_k; - index_t nhead_stride_k; - index_t batch_stride_k; - - const void* v_ptr; - index_t stride_v; - index_t nhead_stride_v; - index_t batch_stride_v; - - void* o_ptr; - index_t stride_o; - index_t nhead_stride_o; - index_t batch_stride_o; - - // Optional batch-mode cumulative seqlen overrides (exclude PAD) - // If provided, they override per-batch effective lengths to skip tail padding. 
- const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] -}; - -std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type); - -// return value: -// first = whether the kernel was launched (true = launched, false = skipped) -// second = elapsed time (ms) of the kernel launch, valid only if first == true -std::pair fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config); - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp deleted file mode 100644 index 19b8dfed4e..0000000000 --- a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp +++ /dev/null @@ -1,179 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include - -#include "ck_tile/core/numeric/bfloat16.hpp" -#include "ck_tile/core/numeric/half.hpp" -#include "ck_tile/core/container/sequence.hpp" -#include "ck_tile/host/kernel_launch.hpp" -#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" -#include "ck_tile/ops/fmha/block/block_masking.hpp" -#include "ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" -#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" -#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" - -#include "fmha_fwd_v3.hpp" -#include "mask.hpp" - -#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \ - template <> \ - std::pair fmha_fwd_v3_kernel_dispatch( \ - const fmha_fwd_v3_args& args, const stream_config& config) \ - { \ - return std::make_pair(true, \ - fmha_fwd_v3_kernel_launch(args, config)); \ - } - -namespace ck_tile { - -template -struct fmha_fwd_v3_problem_traits; - -template <> -struct fmha_fwd_v3_problem_traits -{ - using qkvp_dtype = ck_tile::half_t; - using acc_dtype 
= float; - using o_dtype = ck_tile::half_t; - using lse_dtype = float; -}; - -template <> -struct fmha_fwd_v3_problem_traits -{ - using qkvp_dtype = ck_tile::bf16_t; - using acc_dtype = float; - using o_dtype = ck_tile::bf16_t; - using lse_dtype = float; -}; - -template -struct fmha_fwd_v3_kernel_traits -{ - static constexpr auto date_type = DataType; - static constexpr bool is_variable_seqlen = IsVariableSeqlen; - static constexpr bool is_masking = IsMasking; - - // M0 N0 K0 N1 K1 - using fmha_block_tile = sequence<256, 32, 128, 128, 32, 128>; - using fmha_warp_gemm_shape = sequence<32, 32, 16>; - using fmha_block_warps = sequence<8, 1, 1>; - - using fmha_shape = TileFmhaShape; - - using fmha_traits = TileFmhaFwdV3Traits; - - using fmha_mask = GenericAttentionMask; - - using fmha_pipeline_problem = - BlockFmhaFwdV3PipelineProblem::qkvp_dtype, - typename fmha_fwd_v3_problem_traits::qkvp_dtype, - typename fmha_fwd_v3_problem_traits::qkvp_dtype, - typename fmha_fwd_v3_problem_traits::acc_dtype, - typename fmha_fwd_v3_problem_traits::acc_dtype, - typename fmha_fwd_v3_problem_traits::lse_dtype, - typename fmha_fwd_v3_problem_traits::qkvp_dtype, - typename fmha_fwd_v3_problem_traits::acc_dtype, - typename fmha_fwd_v3_problem_traits::o_dtype, - fmha_shape, - IsVariableSeqlen, - fmha_mask, - fmha_traits>; - - using fmha_pipeline = BlockFmhaFwdV3Pipeline; - - using epilogue = Default2DEpilogue< - Default2DEpilogueProblem::acc_dtype, - typename fmha_fwd_v3_problem_traits::o_dtype, - true, // kPadM - true, // kPadM - true // UseRawStore - >>; - - using kernel = FmhaFwdV3Kernel; -}; - -template -float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_config& config) -{ - /// NOTICE: This was borrowed from Aiter. Make sure the selected remap_opt setting truly - /// maximizes the kernel's performance. 
- int remap_opt = 2; - if(args.mask_type != static_cast(mask_enum::no_mask) && - ((args.nhead_q % 8 != 0) || (16384 < args.seqlen_q))) - { - if(65536 <= args.seqlen_q) - { - remap_opt = 0; - } - else - { - remap_opt = 1; - } - } - - auto kargs = Kernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.v_ptr, - nullptr, // lse_ptr - args.o_ptr, - args.seqlen_q, - args.seqlen_k, - args.hdim_qk, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_kv, - args.softmax_scale, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - 0, // nhead_stride_lse - args.nhead_stride_o, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - 0, // batch_stride_lse - args.batch_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type, - remap_opt, - args.cu_seqlen_q_ptr, - args.cu_seqlen_kv_ptr); - - dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v); - constexpr dim3 blocks = Kernel::BlockSize(); - constexpr index_t kBlockPerCu = Kernel::kBlockPerCu; - - return launch_kernel(config, make_kernel(Kernel{}, grids, blocks, 0, kargs)); -} - -// return value: -// first = whether the kernel was launched (true = launched, false = skipped) -// second = elapsed time (ms) of the kernel launch, valid only if first == true -template -std::pair fmha_fwd_v3_kernel_dispatch(const fmha_fwd_v3_args& args, - const stream_config& config); - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp deleted file mode 100644 index 463c52b824..0000000000 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#include "fmha_fwd_v3.hpp" -#include "fmha_fwd_v3_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - fmha_fwd_v3_kernel_traits; - -INST_FMHA_FWD_V3_DISPATCH(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp deleted file mode 100644 index acf79e43f4..0000000000 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "fmha_fwd_v3.hpp" -#include "fmha_fwd_v3_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - fmha_fwd_v3_kernel_traits; - -INST_FMHA_FWD_V3_DISPATCH(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp deleted file mode 100644 index a6366209b2..0000000000 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "fmha_fwd_v3.hpp" -#include "fmha_fwd_v3_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - fmha_fwd_v3_kernel_traits; - -INST_FMHA_FWD_V3_DISPATCH(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp deleted file mode 100644 index a83e37cc68..0000000000 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#include "fmha_fwd_v3.hpp" -#include "fmha_fwd_v3_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - fmha_fwd_v3_kernel_traits; - -INST_FMHA_FWD_V3_DISPATCH(kernel_traits) - -} // namespace ck_tile diff --git a/include/ck_tile/core/algorithm/coordinate_transform.hpp b/include/ck_tile/core/algorithm/coordinate_transform.hpp index 81eea60c2f..29a7e2593e 100644 --- a/include/ck_tile/core/algorithm/coordinate_transform.hpp +++ b/include/ck_tile/core/algorithm/coordinate_transform.hpp @@ -1552,6 +1552,81 @@ CK_TILE_HOST_DEVICE static void print(const indexing& printf("}"); } +template +struct functor_transform : public base_transform<1, 1> +{ + using LowerIndex = multi_index<1>; + using UpperIndex = multi_index<1>; + + using UpLengths = decltype(make_tuple(LowLength{})); + + Functor functor_; + UpLengths up_lengths_; + + CK_TILE_HOST_DEVICE constexpr functor_transform() = default; + + CK_TILE_HOST_DEVICE constexpr functor_transform(const Functor& functor, + const LowLength& low_length) + : functor_{functor}, up_lengths_{make_tuple(low_length)} + { + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = functor_(idx_up[number<0>{}]); + } + + template + CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& up_idx) const + { + static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 && + UpIdx::size() == 1, + "wrong! 
inconsistent # of dimension"); + + const auto idx_low_old = idx_low; + calculate_lower_index(idx_low, up_idx); + idx_diff_low = idx_low - idx_low_old; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value; + } + + // Note: When using functor_transform, ensure that the transformed coordinates + // are always valid for vectorized load/store operations. + template + CK_TILE_HOST_DEVICE static constexpr auto + calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths, + const LowVectorStrides& low_vector_strides) + { + return make_tuple(low_vector_lengths, low_vector_strides); + } +}; + //******************************************************************************************************* template @@ -1671,6 +1746,13 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le return offset{low_length, offset_length}; } +template +CK_TILE_HOST_DEVICE constexpr auto make_functor_transform(const Functor& functor, + const LowLength& low_length) +{ + return functor_transform{functor, low_length}; +} + } // namespace ck_tile #include "ck_tile/core/algorithm/indexing_adaptor.hpp" diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index e80267faec..d39da82a62 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -1263,7 +1263,9 @@ struct tile_window_with_static_lengths } }; -template +template >> CK_TILE_DEVICE constexpr auto make_tile_window(const TensorView_& tensor_view, const WindowLengths_& window_lengths, @@ -1310,7 +1312,10 @@ 
make_tile_window(const tile_window_with_static_lengths +template >> CK_TILE_DEVICE constexpr auto make_tile_window(const tile_window_with_static_lengths& tile_window, const StaticTileDistribution& tile_distribution, diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index 1a79aebef5..756968871d 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -600,6 +600,19 @@ struct SimplifiedRatioAttentionMask mdiv y_ratio_mdiv; }; +template +struct is_generic_attention_mask : std::false_type +{ +}; + +template +struct is_generic_attention_mask> : std::true_type +{ +}; + +template +static constexpr bool is_generic_attention_mask_v = is_generic_attention_mask::value; + // TODO: prefer use this function in host code // can convert from the FA style left/right to our generic coordinate // if left_size < 0 && right_size = 0, it is normal causal mask diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 38830ee6fe..9890d1f2e4 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -73,54 +73,6 @@ struct FmhaFwdKernel #endif static constexpr std::string_view kPipelineName = FmhaPipeline::name; - // clang-format off - template struct t2s; - template <> struct t2s { static constexpr const char * name = "fp32"; }; - template <> struct t2s { static constexpr const char * name = "fp16"; }; - template <> struct t2s { static constexpr const char * name = "bf16"; }; - template <> struct t2s { static constexpr const char * name = "fp8"; }; - template <> struct t2s { static constexpr const char * name = "bf8"; }; - template <> struct t2s { static constexpr const char * name = "fp8bf16"; }; - template <> struct t2s { static constexpr const char * name = "fp8fp32"; }; - // clang-format on - - CK_TILE_HOST static std::string 
GetName() - { - // sync with generate.py - // clang-format off - using bfs = typename FmhaPipeline::BlockFmhaShape; - using g0br = typename bfs::Gemm0BlockWarps; - using g1br = typename bfs::Gemm1BlockWarps; - using g0wt = typename bfs::Gemm0WarpTile; - using g1wt = typename bfs::Gemm1WarpTile; - #define _SS_ std::string - #define _TS_ std::to_string - auto pn = [&] () { - std::string n; - if (kPadSeqLenQ) n += "s"; - if (kPadSeqLenK) n += "sk"; - if (kPadHeadDimQ) n += "d"; - if (kPadHeadDimV) n += "dv"; - return n.empty() ? n : std::string("p") + n; }(); - return - _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + - "_" + (kIsGroupMode ? "group" : "batch") + "_" - "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + - _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + - "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + - "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + - "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" + - "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + - (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + - "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + - (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + - (QScaleEnum == BlockAttentionQuantScaleEnum::NO_SCALE ? 
_SS_("_nqscale") : (_SS_("_") + BlockAttentionQuantScaleEnumToStr::name)) + (kUseTrLoad ? "_trload" : "_ntrload"); - #undef _SS_ - #undef _TS_ - // clang-format on - } - template // to avoid duplicated base class prblem, introduce an template // arg struct FmhaFwdEmptyKargs diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp index df17bdd879..f981c54bd8 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -12,6 +12,8 @@ namespace ck_tile { +/// NOTICE: This kernel is a work in progress and is awaiting upcoming compiler fixes and +/// instruction scheduling optimizations. template struct FmhaFwdV3Kernel { @@ -103,8 +105,8 @@ struct FmhaFwdV3Kernel // Optional cumulative sequence length pointers for batch mode // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding. - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] + const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] + const ck_tile::index_t* cu_seqlen_k_ptr = nullptr; // [batch+1] }; struct FmhaFwdGroupModeKargs @@ -114,12 +116,13 @@ struct FmhaFwdV3Kernel { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; + const int32_t* seqlen_q_ptr; const int32_t* seqlen_k_ptr; // Optional cumulative padded sequence starts (including PAD tokens) // Used solely to compute memory offsets when sequences are physically padded. 
- const int32_t* seqstart_padded_q_ptr = nullptr; // [batch+1] - const int32_t* seqstart_padded_k_ptr = nullptr; // [batch+1] + const int32_t* cu_seqlen_q_ptr = nullptr; // [batch+1] + const int32_t* cu_seqlen_k_ptr = nullptr; // [batch+1] }; using Kargs = std::conditional_t; @@ -156,8 +159,8 @@ struct FmhaFwdV3Kernel ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t remap_opt, - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr, - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, @@ -199,8 +202,8 @@ struct FmhaFwdV3Kernel kargs.batch_stride_lse = batch_stride_lse; } - kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr; - kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr; + kargs.cu_seqlen_q_ptr = reinterpret_cast(cu_seqlen_q_ptr); + kargs.cu_seqlen_k_ptr = reinterpret_cast(cu_seqlen_k_ptr); return kargs; } @@ -213,6 +216,7 @@ struct FmhaFwdV3Kernel void* o_ptr, const void* seqstart_q_ptr, const void* seqstart_k_ptr, + const void* seqlen_q_ptr, const void* seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, @@ -232,8 +236,8 @@ struct FmhaFwdV3Kernel ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t remap_opt, - const void* seqstart_padded_q_ptr = nullptr, - const void* seqstart_padded_k_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, @@ -258,6 +262,7 @@ struct FmhaFwdV3Kernel {}, // placeholder for lse reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_q_ptr), reinterpret_cast(seqlen_k_ptr)}; if constexpr(kHasMask) @@ -273,30 +278,29 @@ struct FmhaFwdV3Kernel kargs.nhead_stride_lse = nhead_stride_lse; } - kargs.seqstart_padded_q_ptr = reinterpret_cast(seqstart_padded_q_ptr); - kargs.seqstart_padded_k_ptr = reinterpret_cast(seqstart_padded_k_ptr); + kargs.cu_seqlen_q_ptr = 
reinterpret_cast(cu_seqlen_q_ptr); + kargs.cu_seqlen_k_ptr = reinterpret_cast(cu_seqlen_k_ptr); return kargs; } - CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, - ck_tile::index_t nhead_, - ck_tile::index_t seqlen_q_, - ck_tile::index_t hdim_v_) + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size, + ck_tile::index_t nhead, + ck_tile::index_t max_seqlen_q, + ck_tile::index_t hdim_v) { - // TODO: this may need tuning - if constexpr(kHasMask) + if constexpr(kIsGroupMode) { - return dim3(nhead_, - ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * - ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), - batch_size_); + return dim3(nhead, + batch_size, + ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1)); } else { - return dim3(nhead_, - ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * - ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), - batch_size_); + return dim3(nhead, + ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1), + batch_size); } } @@ -344,13 +348,20 @@ struct FmhaFwdV3Kernel // FmhaPipeline::kN1); // assume that num_tile_n1 is always 1 - if constexpr(kHasMask) + if constexpr(kIsGroupMode) { const index_t i_nhead = blockIdx.x; - const index_t i_block = blockIdx.y; - const index_t i_batch = blockIdx.z; + const index_t i_batch = blockIdx.y; + const index_t i_block = blockIdx.z; - return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch); + if constexpr(kHasMask) + { + return ck_tile::make_tuple(gridDim.z - 1 - i_block, 0, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch); + } } else { @@ -358,7 +369,14 @@ struct FmhaFwdV3Kernel const index_t i_block = blockIdx.y; const index_t i_batch = blockIdx.z; - return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch); + if 
constexpr(kHasMask) + { + return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch); + } } } @@ -390,32 +408,36 @@ struct FmhaFwdV3Kernel if constexpr(kIsGroupMode) { - // get starting offset for each batch - const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch]; + // Use seqstart_q_ptr and seqstart_k_ptr for physical starts + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr - ? kargs.seqstart_padded_q_ptr[i_batch] - : query_start_unpadded; - const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr - ? kargs.seqstart_padded_k_ptr[i_batch] - : key_start_unpadded; - - batch_offset_q = query_start_padded * kargs.stride_q; - batch_offset_k = key_start_padded * kargs.stride_k; - batch_offset_v = key_start_padded * kargs.stride_v; + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + batch_offset_v = key_start * kargs.stride_v; if constexpr(kStoreLSE) { // LSE layout is [nhead, total_seqlen], index by unpadded start - batch_offset_lse = query_start_unpadded; + batch_offset_lse = query_start; } - batch_offset_o = query_start_padded * kargs.stride_o; - - // get real # queries & # keys under group mode - const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + batch_offset_o = query_start * kargs.stride_o; + // real logical lengths (exclude PAD) + // Priority: seqlen_q_ptr > cu_seqlen_q_ptr > calculated from seqstart_q_ptr + if(kargs.seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch]; + } + else if(kargs.cu_seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = + kargs.cu_seqlen_q_ptr[i_batch 
+ 1] - kargs.cu_seqlen_q_ptr[i_batch]; + } + else + { + kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; + } // # of required blocks is different in each groups, terminate unnecessary blocks // earlier if(kargs.seqlen_q <= i_m0) @@ -427,10 +449,14 @@ struct FmhaFwdV3Kernel { kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; } + else if(kargs.cu_seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = + kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch]; + } else { - const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; - kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch]; } } else @@ -450,10 +476,10 @@ struct FmhaFwdV3Kernel kargs.seqlen_q = kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; } - if(kargs.cu_seqlen_kv_ptr != nullptr) + if(kargs.cu_seqlen_k_ptr != nullptr) { kargs.seqlen_k = - kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch]; + kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch]; } } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index 8bf24be386..68ec349694 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -4,6 +4,8 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" @@ -246,6 +248,8 @@ CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs) } } // namespace detail +/// NOTICE: This pipeline is a work in progress and is awaiting upcoming compiler fixes and +/// instruction 
scheduling optimizations. template struct BlockFmhaFwdV3Pipeline { @@ -261,12 +265,16 @@ struct BlockFmhaFwdV3Pipeline using OaccDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; + static_assert(is_generic_attention_mask_v); static_assert(std::is_same_v, "we will the same dist tensor 'sp_compute' for both gemm0 & softmax"); using BlockFmhaShape = ck_tile::remove_cvref_t; + using VLayout = remove_cvref_t; + static_assert(std::is_same_v); + static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize; static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0; @@ -277,14 +285,24 @@ struct BlockFmhaFwdV3Pipeline static constexpr ck_tile::index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr ck_tile::index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; - static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + static_assert(kQKHeaddim == 128 && kSubQKHeaddim == 128, "only supports hdim=hdim_v=128"); - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr auto QScaleEnum = 
Problem::QScaleEnum; + static constexpr bool kSkipMinSeqlenQ = Problem::kSkipMinSeqlenQ; + static_assert((!kHasLogitsSoftCap && BiasEnum == BlockAttentionBiasEnum::NO_BIAS && + !kStoreLSE && !kHasDropout && + (QScaleEnum == ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE) && + !kSkipMinSeqlenQ), + "enable unsupported features"); // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp index da0fa16ee1..659bdd995b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp @@ -12,6 +12,7 @@ enum class BlockFmhaPipelineEnum QRKSVS_ASYNC, QSKSVS, QRKSVS_ASYNC_TRLOAD, + QRKSVS_ASYNC_TRLOAD_V3, }; template diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index b90b760a0d..7c4a921b70 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -264,47 +264,4 @@ struct BlockFmhaFwdAppendKVPipelineProblem static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; -template -struct BlockFmhaFwdV3PipelineProblem -{ - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using BlockFmhaShape = remove_cvref_t; - using FmhaMask = remove_cvref_t; - using Traits = remove_cvref_t; - - static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps; - 
static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps; - static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); - - static constexpr bool kIsGroupMode = kIsGroupMode_; - - // attributes from traits - static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; - static constexpr bool kStoreLSE = Traits::kStoreLSE; - static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; -}; - } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index b9e18de1e5..df33a93696 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -166,20 +166,4 @@ struct TileFmhaBwdConvertQGradTraits static constexpr index_t kBlockPerCu = kBlockPerCu_; }; -template -struct TileFmhaFwdV3Traits -{ - static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; - static constexpr bool kPadSeqLenK = kPadSeqLenK_; - static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; - static constexpr bool kPadHeadDimV = kPadHeadDimV_; - static constexpr bool kStoreLSE = kStoreLSE_; - static constexpr index_t kBlockPerCu = kBlockPerCu_; -}; - } // namespace ck_tile diff --git a/include/ck_tile/remod.py b/include/ck_tile/remod.py index affa6d987b..aeec7bd471 100644 --- a/include/ck_tile/remod.py +++ b/include/ck_tile/remod.py @@ -90,7 +90,7 @@ submodule = submodule_t() # formatting format_procs = [] for x in all_files: - dos2unix = f"python -m dos2unix {str(x)} {str(x)}" + dos2unix = f"python3 -m dos2unix {str(x)} {str(x)}" clang_format = f"clang-format -style=file -i {str(x)}" # One process to avoid race conditions. cmd = f"{dos2unix} && {clang_format}"