From 1ecee378d528433f76876a892da41f07733ee935 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 8 Aug 2025 06:19:31 +0000 Subject: [PATCH] remove unnecessary files; rename some files --- .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 4 +- .../01_fmha/codegen/ops/fmha_fwd_decode.py | 867 ----------- example/ck_tile/01_fmha/fmha_fwd.cpp | 3 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 345 ----- include/ck_tile/core/numeric/bfloat16.hpp | 3 +- include/ck_tile/ops/fmha.hpp | 5 +- .../fmha/kernel/fmha_fwd_decode_kernel.hpp | 1334 ----------------- .../pipeline/block_fmha_pipeline_enum.hpp | 6 +- .../pipeline/block_fmha_pipeline_problem.hpp | 107 -- ...k_fmha_pipeline_qr_ks_vs_async_trload.hpp} | 6 +- ...pipeline_qr_ks_vs_async_trload_policy.hpp} | 2 +- .../ops/fmha/pipeline/tile_fmha_traits.hpp | 49 - 12 files changed, 14 insertions(+), 2717 deletions(-) delete mode 100644 example/ck_tile/01_fmha/codegen/ops/fmha_fwd_decode.py delete mode 100644 include/ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp rename include/ck_tile/ops/fmha/pipeline/{block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp => block_fmha_pipeline_qr_ks_vs_async_trload.hpp} (99%) rename include/ck_tile/ops/fmha/pipeline/{block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp => block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp} (99%) diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index eb490d806f..a0f6dd7f58 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -115,7 +115,7 @@ PIPELINE_MAP = { "qr" : "ck_tile::BlockFmhaPipelineQRKSVS", "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", "qs" : "ck_tile::BlockFmhaPipelineQSKSVS", - "qr_async_trload" : "ck_tile::BlockFmhaFwdDecodePipelineQRKSVS", + "qr_async_trload" : "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload", } PIPELINE_ENUM_MAP = { @@ -124,7 +124,7 @@ PIPELINE_ENUM_MAP = { "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS", "qr_pagedkv" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", - "qr_async_trload" : "ck_tile::BlockFmhaPipelineEnum::DECODE_QRKSVS", + "qr_async_trload" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD", } BOOL_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_decode.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_decode.py deleted file mode 100644 index cae6ac74f4..0000000000 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_decode.py +++ /dev/null @@ -1,867 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. -# generate kernel instances to speed up compilation - -import copy -from dataclasses import dataclass -import fnmatch -import itertools -from pathlib import Path -from typing import List, Optional, Tuple, Union - -from codegen.cmake_config import * -from codegen.cpp_symbol_map import * - -from codegen.ops.fmha_fwd import ( - FmhaFwdTileSize, - FmhaFwdApiTrait, - FMHA_FWD_KERNEL_HEADER, - FMHA_FWD_API_PER_DTYPE, - FMHA_FWD_API_PER_HDIM_CASE, -) - - -DTYPE_BITS = { - "fp32": 32, - "fp16": 16, - "bf16": 16, - "fp8" : 8, - "bf8" : 8 -} - -K0_MAX_SUBMAX_MAP = { - 32 : 32, - 64 : 64, - 96 : 128, - 128: 128, - 256: 256 -} - -SEQLENQ_MAP = { - "16" : "16", - "32" : "32", - # "64" : "64" - "128" : "128", - # "256" : "256", -} - -FMHA_FWD_DECODE_PIPELINE_MAP = { - "decode_qr" : "ck_tile::BlockFmhaFwdDecodePipelineQRKSVS", -} - -FMHA_FWD_DECODE_KERNEL_BODY=""" -using fmha_dtype_{F_idx} = {F_dtype}; -using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; -using fmha_mask_{F_idx} = {F_mask}; - -namespace {{ -template -struct instance {{ -using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; - -using fmha_shape = ck_tile::TileFmhaShape, - ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, - ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, - ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, - {F_vlayout}>; - -using fmha_trait = ck_tile::TileFmhaFwdDecodeTraits<{F_spad}, - {F_skpad}, - {F_dpad}, - {F_dvpad}, - {F_logits}, - {F_bias}, - /*kHasBiasGrad=*/false, - {F_lse}, - {F_squant}, - {F_pagedkv}, - kHasUnevenSplits, - kMergeNumHeadGroupsSeqLenQ, - {F_occupancy}>; - -using fmha_pipeline_problem = ck_tile::BlockFmhaFwdDecodePipelineProblem< - typename FmhaFwdTypeConfig::QDataType, - typename FmhaFwdTypeConfig::KDataType, - typename FmhaFwdTypeConfig::VDataType, - typename FmhaFwdTypeConfig::SaccDataType, - typename FmhaFwdTypeConfig::SMPLComputeDataType, - typename FmhaFwdTypeConfig::BiasDataType, - typename FmhaFwdTypeConfig::LSEDataType, - typename FmhaFwdTypeConfig::PDataType, - typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - fmha_shape, - {F_mode}, - fmha_variant_{F_idx}, - fmha_mask_{F_idx}, - fmha_trait>; - -using fmha_pipeline = {F_pipeline}< - fmha_pipeline_problem>; - -/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving -/// store_tile_raw() data corruption issue -using fmha_epilogue = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, - false, false>>; - -using fmha_kernel = - ck_tile::FmhaFwdDecodeKernel; - -static void run(const ck_tile::stream_config& s, fmha_fwd_decode_args a) -{{ - using k_ = fmha_kernel; - auto [kargs, grids] = fmha_fwd_decode_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); -}} -}}; -}} - -using trait_{F_idx} = fmha_fwd_decode_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, - {F_dvpad}>; - -#include - -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wtautological-compare" - -namespace {{ -template -void run_instance(const ck_tile::stream_config& s, fmha_fwd_decode_args a) {{ - if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS - && (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask> - || std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{ - if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{ - instance::run(s, a); - }} else {{ - instance::run(s, a); - }} - }} else {{ - instance::run(s, a); - }} - // instance::run(s, a); -}} -}} // anonymous namespace - -#pragma clang diagnostic pop - -template<> -void fmha_fwd_decode_oneshot_(const ck_tile::stream_config& s, fmha_fwd_decode_args a) -{{ - if constexpr({F_mode} == false) {{ // batch mode - // we don't check every seqlen_k values for kvcache - if (a.seqlen_k_ptr != nullptr) {{ - run_instance(s, a); - // make sure F_bn0 is divisible by F_bk1 - }} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{ - run_instance(s, a); - }} else {{ - run_instance(s, a); - }} - }} else {{ - run_instance(s, a); - }} - // run_instance(s, a); -}} - -template<> -std::string fmha_fwd_decode_get_name_() -{{ - using k_ = instance::fmha_kernel; /// FIXME: choose real kernel type - return k_::GetName(); -}} -""" - -FMHA_FWD_DECODE_COMBINE_KERNEL_BODY=""" -using fmha_dtype_{F_idx} = {F_dtype}; - -namespace {{ -template -struct instance {{ -using fmha_trait = ck_tile::TileFmhaFwdSplitKVCombineTraits<{F_spad}, - {F_dvpad}, - {F_lse}, - {F_squant}, - kLogMaxSplits, - {F_occupancy}>; - -using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem< - typename FmhaFwdTypeConfig::LSEDataType, - typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - {F_hdim}, - {F_mode}, - {F_bn1}, - fmha_trait>; - -using fmha_pipeline = ck_tile::BlockFmhaFwdSplitKVCombinePipeline< - fmha_pipeline_problem>; - -/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving -/// store_tile_raw() data corruption issue -using fmha_epilogue = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, - false, false>>; - -using fmha_kernel = - ck_tile::FmhaFwdSplitKVCombineKernel; - -static void run(const ck_tile::stream_config& s, fmha_fwd_decode_args a) -{{ - using k_ = fmha_kernel; - auto [kargs, grids] = fmha_fwd_decode_combine_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); -}} -}}; -}} - -using trait_{F_idx} = fmha_fwd_decode_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1}, - {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; - -#include - -template<> -void fmha_fwd_decode_combine_oneshot_(const ck_tile::stream_config& s, fmha_fwd_decode_args a) -{{ - if (a.num_splits <= 8) {{ - instance<3>::run(s, a); - }} else if (a.num_splits <= 16) {{ - instance<4>::run(s, a); - }} else if (a.num_splits <= 32) {{ - instance<5>::run(s, a); - }} else if (a.num_splits <= 64) {{ - instance<6>::run(s, a); - }} else if (a.num_splits <= 128) {{ - instance<7>::run(s, a); - }} -}} - -template<> -std::string fmha_fwd_decode_combine_get_name_() -{{ - using k_ = instance<6>::fmha_kernel; /// FIXME: choose real kernel type - return k_::GetName(); -}} -""" - -FMHA_FWD_DECODE_API_FILENAME="fmha_fwd_decode_api.cpp" -FMHA_FWD_DECODE_API=""" -#include - -template -float fmha_fwd_decode_(const ck_tile::stream_config& s, fmha_fwd_decode_args a) -{{ - if(s.log_level_ > 0) - {{ - if (a.num_splits > 1) - {{ - std::cout - << ", " << fmha_fwd_decode_get_name_() - << ", " << fmha_fwd_decode_combine_get_name_() - << std::flush; - }} - else{{ - std::cout - << ", " << fmha_fwd_decode_get_name_() - << std::flush; - }} - }} - - // we don't need combine kernel when we don't split the kv. - if (a.num_splits > 1) - {{ - return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_fwd_decode_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_fwd_decode_combine_oneshot_(s_, a); }} - ); - }} - else{{ - return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_fwd_decode_oneshot_(s_, a); }} - ); - }} -}} - -float fmha_fwd_decode(fmha_fwd_decode_traits t, fmha_fwd_decode_args a, const ck_tile::stream_config& s){{ - float r = -1; -{F_dispatch} - return r; -}} -""" - -FMHA_FWD_DECODE_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) && - ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck})&& ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using traits_ = fmha_fwd_decode_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; - - // get combine kernel tile sizes - using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType; - constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes::kM0; - - // make sure we can reuse the padding flags in combine kernels - static_assert({F_bm0} % kM0 == 0); - static_assert({F_bn1} % 32 == 0); - - if (t.has_lse) {{ - if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{ - return -1; - }} else {{ - using traits2_ = fmha_fwd_decode_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>; - - return fmha_fwd_decode_(s, a); - }} - }} else {{ - using traits2_ = fmha_fwd_decode_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, false, {F_squant}, {F_spad}, {F_dvpad}>; - - return fmha_fwd_decode_(s, a); - }} - }} -""" - -@dataclass -class FmhaFwdDecodeApiTrait: - pipeline_tag : str - # sync with fmha_fwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along qk seqlen - bk0 : int # tile size along qk gemm unroll - bn1 : int # tile size along v head_dim - bk1 : int # tile size along kv gemm unroll - bk0max : int - vlayout : str - mask : str - logits : str - bias : str # - lse : str # - squant : str # - spad : str - skpad : str - dpad : str - dvpad : str - pagedkv : str - - @property - def name(self) -> str: - return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ - f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\ - f'{self.dvpad}-{self.pagedkv}' - - # sequence length as non-fast-changing dimension, we can always relay on instruction level OOB guard - @property - def scheck(self) -> str: - if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag in ['decode_qr', 'qr_nwarp_sshuffle']: - if self.spad == 't' : return 'true' - else : return 'true' - else: assert False - - @property - def seqtune(self) -> str: - if self.bm0 == 128: return 'true/*fall back to largest tile*/' # group mode only generate spad/skpad == true - else: - return f'a.seqlen_q <= {self.bm0}' - - @property - def skcheck(self) -> str: - if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag in ['decode_qr', 'qr_nwarp_sshuffle']: - if self.skpad == 't' : return 'true' # TODO: order of get_pipelines() matters! (ugly) - else : return 'true' - else: assert False - - # head dimension as fast-changing dimension, we assume is multiple of 8 - @property - def dcheck(self) -> str: - if self.pipeline_tag in ['decode_qr', 'qr_nwarp_sshuffle']: - bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_q % {bk0submax} == 0' - # if self.skpad == 't' : return 'true' - # else : return 'true' - else: assert False - - @property - def dvcheck(self) -> str: - if self.pipeline_tag in ['decode_qr', 'qr_nwarp_sshuffle']: - bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_v % {bk0submax} == 0' - # if self.skpad == 't' : return 'true' - # else : return 'true' - else: assert False - -@dataclass -class FmhaFwdDecodePipeline: - tag : str - - F_vlayout : str # row/col - F_spad : str # true/false - F_skpad : str # - F_dpad : str # - F_dvpad : str # - F_logits : str # t/f - F_bias : str # true/false - F_lse : str # - F_squant : str # - F_pagedkv : str # t/f - F_mask : str # value from MASK_MAP - - @property - def name(self) -> str: - def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_skpad == 't' : n += 'sk' - if self.F_dpad == 't' : n += 'd' - if self.F_dvpad == 't' : n += 'dv' - if n != '' : n = 'p' + n - return n - pn = pad_name() - n = f'{self.tag}_v{self.F_vlayout[0]}' - if pn != '' : n += f'_{pn}' - else: n += '_npad' - - if self.F_logits == 't' : n += '_logits' - else: n += '_nlogits' - - if self.F_bias != 'no' : n += f'_{self.F_bias}' - else: n += '_nbias' - - if self.F_mask[0:2] == 's_': - if self.F_mask == 's_mask': n += f'_mask' - else: n += '_nmask' - else: - if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' - else: n += '_nmask' - - if self.F_lse == 't' : n += '_lse' - else: n += '_nlse' - - if self.F_squant == 't' : n += '_squant' - else: n += '_nsquant' - - if self.F_pagedkv == 't' : n += '_pagedkv' - else: n += '_npagedkv' - return n - -@dataclass -class FmhaFwdSplitKVCombinePipeline: - tag : str - - F_spad : str # true/false - F_dvpad : str # - F_lse : str # - F_squant : str # - - @property - def name(self) -> str: - def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_dvpad == 't' : n += 'dv' - if n != '' : n = 'p' + n - return n - pn = pad_name() - n = f'{self.tag}' - if pn != '' : n += f'_{pn}' - else: n += '_npad' - - if self.F_lse == 't' : n += '_lse' - else: n += '_nlse' - - if self.F_squant == 't' : n += '_squant' - else: n += '_nsquant' - return n - -class FmhaFwdDecodeApiPool: - def __init__(self, mask_impl): - self.pool = dict() - self.mask_impl = mask_impl - - def register_traits(self, trait : FmhaFwdDecodeApiTrait) -> None: - # TODO: do we need to check duplication? - if trait.dtype not in self.pool.keys(): - self.pool[trait.dtype] = dict() - if trait.hdim not in self.pool[trait.dtype].keys(): - self.pool[trait.dtype][trait.hdim] = list() - - self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) - - @property - def api(self) -> str: - per_dtypes=str() - for i, dtype in enumerate(self.pool.keys()): - per_hdim_case=str() - for j, hdim in enumerate(self.pool[dtype].keys()): - traits=self.pool[dtype][hdim] - inners=str() - for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - inners = inners + FMHA_FWD_DECODE_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], - F_scheck=trait.scheck,F_seqtune=trait.seqtune, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, - F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, - F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - if not per_dtypes: - # empty string we add some ignore to suppress warning in api - per_dtypes += ' (void)t ; (void)s ; (void)a;' - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_DECODE_API.format(F_dispatch = per_dtypes) - -@dataclass -class FmhaFwdSplitKVCombineTileSize: - F_bn1 : int # tile size along v head_dim - F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy - @property - def name(self) -> str: - return f"b{self.F_bn1}" +\ - ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") - -@dataclass -class FmhaFwdDecodeKernel: - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_mode : str # value from MODE_MAP - F_tile : FmhaFwdTileSize - F_pipeline : FmhaFwdDecodePipeline - mask_impl : str - - @property - def template(self) -> str: - kernel_body = str() - return FMHA_FWD_KERNEL_HEADER + \ - FMHA_FWD_DECODE_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = FWD_DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_tile.F_bm0, - F_bn0 = self.F_tile.F_bn0, - F_bk0 = self.F_tile.F_bk0, - F_bn1 = self.F_tile.F_bn1, - F_bk1 = self.F_tile.F_bk1, - F_bk0max = self.F_tile.F_bk0max, - F_rm0 = self.F_tile.F_rm0, - F_rn0 = self.F_tile.F_rn0, - F_rk0 = self.F_tile.F_rk0, - F_rm1 = self.F_tile.F_rm1, - F_rn1 = self.F_tile.F_rn1, - F_rk1 = self.F_tile.F_rk1, - F_wm0 = self.F_tile.F_wm0, - F_wn0 = self.F_tile.F_wn0, - F_wk0 = self.F_tile.F_wk0, - F_wm1 = self.F_tile.F_wm1, - F_wn1 = self.F_tile.F_wn1, - F_wk1 = self.F_tile.F_wk1, - F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], - F_spad = BOOL_MAP[self.F_pipeline.F_spad], - F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], - F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], - F_logits = BOOL_MAP[self.F_pipeline.F_logits], - F_bias = BIAS_MAP[self.F_pipeline.F_bias], - F_lse = BOOL_MAP[self.F_pipeline.F_lse], - F_squant = BOOL_MAP[self.F_pipeline.F_squant], - F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv], - F_occupancy = self.F_tile.F_occupancy, - F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], - F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], - F_mode = MODE_MAP[self.F_mode], - F_pipeline = FMHA_FWD_DECODE_PIPELINE_MAP[self.F_pipeline.tag]) - - @property - def name(self) -> str: - # TODO: we don't encode idx here - return f"fmha_fwd_decode_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ - self.F_tile.name + '_' + self.F_pipeline.name - - @property - def filename(self) -> str: - return self.name + ".cpp" - - def api_trait(self) -> FmhaFwdDecodeApiTrait: - return FmhaFwdDecodeApiTrait( - pipeline_tag=self.F_pipeline.tag, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bk0=self.F_tile.F_bk0, - bn1=self.F_tile.F_bn1, - bk1=self.F_tile.F_bk1, - bk0max=self.F_tile.F_bk0max, - vlayout=self.F_pipeline.F_vlayout, - logits=self.F_pipeline.F_logits, - mask=self.F_pipeline.F_mask, - bias=self.F_pipeline.F_bias, - lse=self.F_pipeline.F_lse, - squant=self.F_pipeline.F_squant, - pagedkv=self.F_pipeline.F_pagedkv, - spad=self.F_pipeline.F_spad, - skpad=self.F_pipeline.F_skpad, - dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad) - -@dataclass -class FmhaFwdSplitKVCombineKernel: - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_mode : str # value from MODE_MAP - F_tile : FmhaFwdSplitKVCombineTileSize - F_pipeline : FmhaFwdSplitKVCombinePipeline - - @property - def template(self) -> str: - kernel_body = str() - return FMHA_FWD_KERNEL_HEADER + \ - FMHA_FWD_DECODE_COMBINE_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = FWD_DTYPE_MAP[self.F_dtype], - F_bn1 = self.F_tile.F_bn1, - F_spad = BOOL_MAP[self.F_pipeline.F_spad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], - F_lse = BOOL_MAP[self.F_pipeline.F_lse], - F_squant = BOOL_MAP[self.F_pipeline.F_squant], - F_occupancy = self.F_tile.F_occupancy, - F_mode = MODE_MAP[self.F_mode]) - - @property - def name(self) -> str: - # TODO: we don't encode idx here - return f"fmha_fwd_decode_combine_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ - self.F_tile.name + '_' + self.F_pipeline.name - - @property - def filename(self) -> str: - return self.name + ".cpp" - -# TODO: design a more practical way to do it -# this is current supported tile size per hdim -def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: - if dtype == 'fp16' or dtype == 'bf16': - return { - '64': { - # Specialize for different SeqQ - '16': FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), - '32': FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '128': FmhaFwdTileSize(128, 64, 64, 64, 64, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - }, - '128': { - '16': FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), - '32': FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '128': FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - # '256': FmhaFwdTileSize(256, 64, 32, 128, 16, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1), - }, - } - else: - return None - -def get_fmha_fwd_decode_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]: - if dtype == 'fp16' or dtype == 'bf16': - return { - # '32' : FmhaFwdSplitKVCombineTileSize(32, -1), - '64' : FmhaFwdSplitKVCombineTileSize(32, -1), - ### '96' : FmhaFwdSplitKVCombineTileSize(32, -1), - '128' : FmhaFwdSplitKVCombineTileSize(32, -1), - # '256' : FmhaFwdSplitKVCombineTileSize(32, -1), - } - else: - return None - -def get_fwd_decode_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdDecodeApiPool, List[FmhaFwdDecodeKernel]]: - Pipeline = FmhaFwdDecodePipeline - Kernel = FmhaFwdDecodeKernel - - # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad - # support this in future - def get_pipelines(dtype, hdim) -> List[FmhaFwdDecodePipeline]: - # this function will populate a list possible pipelines - # TODO: the order of List matters! the later in this list will be also be checked later - # TODO: currently for qr pipeline, let 't' padding to appear later!! - # TODO: how to design this more generic? - squant = 't' if dtype == 'fp8' else 'f' - pipelines = [] - if dtype in ['fp16', 'bf16']: - for logits, mask, bias, pagedkv in itertools.product(["f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["f"]): - for lse in ['t', 'f']: - if hdim in [64, 128]: ### [32, 64, 96, 128]: - pipelines.append(Pipeline('decode_qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, squant, pagedkv, mask)) - pipelines.append(Pipeline('decode_qr', 'row', 'f', 'f', 't', 't', logits, bias, lse, squant, pagedkv, mask)) - else: - assert False - else: - assert False - return pipelines - - gen = list() - api_pool = FmhaFwdDecodeApiPool(mask_impl) - - for dtype in FWD_DTYPE_MAP.keys(): - d = get_fmha_fwd_tile_dict_from_dtype(dtype) - if d == None: - continue - #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): - for hdim_str, mode, seqlenq in itertools.product(d.keys(), MODE_MAP.keys(), SEQLENQ_MAP.keys()): - tile = d[hdim_str][seqlenq] - hdim = int(hdim_str) - for pipeline in get_pipelines(dtype, hdim): - if mode == "group": - if pipeline.F_spad != 't' or pipeline.F_skpad != 't': - # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not - continue - # logits_soft_cap is only allowed if no bias - if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): - continue - k = Kernel(F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl) - if kernel_filter != '': - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - # Flash attention integration - if receipt == 2: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'alibi'] - cond &= pipeline.F_squant == 'f' - if not cond: - continue - # Aiter(mha_varlen_fwd) integration - elif receipt == 200: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "group" - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' - if not cond: - continue - # aiter::mha_fwd_splikv C++ api integration - elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' - if not cond: - continue - api_pool.register_traits(k.api_trait()) - gen.append(k) - - return (api_pool, gen) - -def get_fwd_decode_combine_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaFwdSplitKVCombineKernel]: - Pipeline = FmhaFwdSplitKVCombinePipeline - Kernel = FmhaFwdSplitKVCombineKernel - - # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad - # support this in future - def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVCombinePipeline]: - # this function will populate a list possible pipelines - # TODO: the order of List matters! the later in this list will be also be checked later - # TODO: currently for qr pipeline, let 't' padding to appear later!! - # TODO: how to design this more generic? - squant = 't' if dtype == 'fp8' else 'f' - pipelines = [] - if dtype in ['fp16', 'bf16']: - # for spad, dvpad, lse in itertools.product(["t", "f"], ["t", "f"], ["t", "f"]): - for spad, dvpad, lse in itertools.product(["f"], ["t", "f"], ["t", "f"]): - pipelines.append(Pipeline('unused', spad, dvpad, lse, squant)) - elif dtype in ['fp8', 'bf8']: - # no need lse kernels - pipelines.append(Pipeline('unused', 'f', 'f', 'f', squant)) - else: - assert False - return pipelines - - gen = list() - - for dtype in FWD_DTYPE_MAP.keys(): - d = get_fmha_fwd_decode_combine_tile_dict_from_dtype(dtype) - if d == None: - continue - #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): - for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): - tile = d[hdim_str] - hdim = int(hdim_str) - for pipeline in get_pipelines(dtype, hdim): - if mode == "group": - if pipeline.F_spad != 't': - # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not - continue - k = Kernel(F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline) - if kernel_filter != '': - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - # Aiter(mha_varlen_fwd) integration - if receipt == 200: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "group" - if not cond: - continue - # aiter::mha_fwd_splikv C++ api integration - elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - if not cond: - continue - gen.append(k) - - return gen - -def write_single_kernel(kernel: Union[FmhaFwdDecodeKernel, FmhaFwdSplitKVCombineKernel], autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) - -def write_fwd_decode_api(api_pool : FmhaFwdDecodeApiPool, autogen_dir: Path) -> None: - file_path = autogen_dir / FMHA_FWD_DECODE_API_FILENAME - file_path.write_text(api_pool.api) - -def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: - filter_list = filter_list.split('@') - filter_list.extend([''] * (2 - len(filter_list))) - assert optdim_list == [-1] - - kernels = get_fwd_decode_combine_blobs(filter_list[0], receipt) - for kernel in kernels: - write_single_kernel(kernel, output_dir) - api_pool, kernels = get_fwd_decode_blobs(filter_list[1], receipt, mask_impl) - for kernel in kernels: - write_single_kernel(kernel, output_dir) - write_fwd_decode_api(api_pool, output_dir) - -def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: - filter_list = filter_list.split('@') - filter_list.extend([''] * (2 - len(filter_list))) - assert optdim_list == [-1] - - with file_path.open('a') as f: - kernels = get_fwd_decode_combine_blobs(filter_list[0], receipt) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - _, kernels = get_fwd_decode_blobs(filter_list[1], receipt, mask_impl) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_DECODE_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 62cd8538ee..d873388876 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1078,8 +1078,7 @@ bool run(const ck_tile::ArgParser& arg_parser) args.drop_seed_offset = std::make_pair(drop_seed, drop_offset); } } - else if constexpr(std::is_same_v> || - std::is_same_v>) + else if constexpr(std::is_same_v>) { args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer(); args.o_acc_ptr = o_acc_buf.GetDeviceBuffer(); diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 59b4906cbd..df1e9e5699 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -473,95 +473,6 @@ struct fmha_batch_prefill_args drop_seed_offset; }; -struct fmha_fwd_decode_args -{ - const void* q_ptr; - const void* k_ptr; - const void* v_ptr; - const void* bias_ptr; // bias or alibi_slope pointer - void* lse_acc_ptr; - void* o_acc_ptr; - void* lse_ptr; - void* o_ptr; - - void* block_table_ptr; - ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr - ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr - bool is_gappy; // differentiate seqstart_k_ptr usage. only used if 'block_table_ptr' is not - // nullptr. - - const void* cache_batch_idx; - - // the real seqlen_q & seqlen_k are decided by following: - // batch mode: seqlen_q = kargs.seqlen_q - // seqlen_k = kargs.seqlen_k - // group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] - // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] - // or kargs.seqlen_k_ptr[b] - // - // batch mode (kvcache): - // seqlen_q = kargs.seqlen_q - // seqlen_k = kargs.seqlen_k_ptr[b] - // group mode (kvcache): - // seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] - // - // when is_gappy=true: - // seqlen_k = kargs.seqlen_k_ptr[b] - // seqstart_k_ptr[b] now store local offset of each batch - // - // when is_gappy=false: - // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] - // or kargs.seqlen_k_ptr[b] - const void* seqstart_q_ptr; - const void* seqstart_k_ptr; - const void* seqlen_k_ptr; - - ck_tile::index_t seqlen_q; - ck_tile::index_t seqlen_k; - ck_tile::index_t batch; - ck_tile::index_t max_seqlen_q; - ck_tile::index_t hdim_q; - ck_tile::index_t hdim_v; - ck_tile::index_t nhead_q; - ck_tile::index_t nhead_k; - ck_tile::index_t num_splits; - - float scale_s; - float scale_p; - float scale_o; - - float logits_soft_cap; - - ck_tile::index_t stride_q; - ck_tile::index_t stride_k; - ck_tile::index_t stride_v; - ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 - ck_tile::index_t stride_o_acc; - ck_tile::index_t stride_o; - ck_tile::index_t nhead_stride_q; - ck_tile::index_t nhead_stride_k; - ck_tile::index_t nhead_stride_v; - ck_tile::index_t nhead_stride_bias; - ck_tile::index_t nhead_stride_lse; - ck_tile::index_t nhead_stride_lse_acc; - ck_tile::index_t nhead_stride_o_acc; - ck_tile::index_t nhead_stride_o; - ck_tile::index_t batch_stride_q; - ck_tile::index_t batch_stride_k; - ck_tile::index_t batch_stride_v; - ck_tile::index_t batch_stride_bias; - ck_tile::index_t batch_stride_lse; - ck_tile::index_t batch_stride_lse_acc; - ck_tile::index_t batch_stride_o_acc; - ck_tile::index_t batch_stride_o; - ck_tile::index_t split_stride_lse_acc; - ck_tile::index_t split_stride_o_acc; - - ck_tile::index_t window_size_left; - ck_tile::index_t window_size_right; - ck_tile::index_t mask_type; -}; - template auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) { @@ -940,168 +851,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args) return ck_tile::make_tuple(kargs, grids); } -template -auto fmha_fwd_decode_create_kargs_and_grids(fmha_fwd_decode_args args) -{ - assert(args.nhead_q % args.nhead_k == 0); - auto kargs = [&] { - // create group mode kernel arguments - if constexpr(Kernel::kIsGroupMode) - { - return Kernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.lse_acc_ptr, - args.o_acc_ptr, - args.batch, - args.seqstart_q_ptr, - args.seqstart_k_ptr, - args.seqlen_k_ptr, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.num_splits, - args.block_table_ptr, - args.batch_stride_block_table, - args.page_block_size, - args.is_gappy, - args.scale_s, - args.scale_p, - args.logits_soft_cap, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_o_acc, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_lse_acc, - args.nhead_stride_o_acc, - args.batch_stride_k, // only used for paged-kvcache - args.batch_stride_v, // only used for paged-kvcache - args.split_stride_lse_acc, - args.split_stride_o_acc, - args.window_size_left, - args.window_size_right, - args.mask_type); - } - else - { // create batch mode kernel arguments - return Kernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.lse_acc_ptr, - // args.o_acc_ptr, - args.o_ptr, // hardcoding - args.batch, - args.seqlen_q, - args.seqlen_k, - args.seqlen_k_ptr, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.num_splits, - args.block_table_ptr, - args.batch_stride_block_table, - args.page_block_size, - args.cache_batch_idx, - args.scale_s, - args.scale_p, - args.logits_soft_cap, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_o_acc, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_lse_acc, - args.nhead_stride_o_acc, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - args.batch_stride_bias, - args.batch_stride_lse_acc, - args.batch_stride_o_acc, - args.split_stride_lse_acc, - args.split_stride_o_acc, - args.window_size_left, - args.window_size_right, - args.mask_type); - } - }(); - - dim3 grids = Kernel::GridSize( - args.batch, args.nhead_q, args.nhead_k, args.max_seqlen_q, args.hdim_v, args.num_splits); - - return ck_tile::make_tuple(kargs, grids); -} - -template -auto fmha_fwd_decode_combine_create_kargs_and_grids(fmha_fwd_decode_args args) -{ - assert(args.nhead_q % args.nhead_k == 0); - auto kargs = [&] { - // create group mode kernel argumentszs - if constexpr(Kernel::kIsGroupMode) - { - return Kernel::MakeKargs(args.lse_acc_ptr, - args.o_acc_ptr, - args.lse_ptr, - args.o_ptr, - args.batch, - args.seqstart_q_ptr, - args.hdim_v, - args.num_splits, - args.scale_o, - args.stride_o_acc, - args.stride_o, - args.nhead_stride_lse_acc, - args.nhead_stride_o_acc, - args.nhead_stride_lse, - args.nhead_stride_o, - args.split_stride_lse_acc, - args.split_stride_o_acc); - } - else - { // create batch mode kernel arguments - return Kernel::MakeKargs(args.lse_acc_ptr, - args.o_acc_ptr, - args.lse_ptr, - args.o_ptr, - args.batch, - args.seqlen_q, - args.hdim_v, - args.num_splits, - args.scale_o, - args.stride_o_acc, - args.stride_o, - args.nhead_stride_lse_acc, - args.nhead_stride_o_acc, - args.nhead_stride_lse, - args.nhead_stride_o, - args.batch_stride_lse_acc, - args.batch_stride_o_acc, - args.batch_stride_lse, - args.batch_stride_o, - args.split_stride_lse_acc, - args.split_stride_o_acc); - } - }(); - - dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); - - return ck_tile::make_tuple(kargs, grids); -} - template auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args) { @@ -1441,84 +1190,6 @@ void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_s template std::string fmha_fwd_splitkv_combine_get_name_(); -template -struct fmha_fwd_decode_traits_ -{ - static constexpr ck_tile::index_t HDim = HDim_; - using DataType = ck_tile::remove_cvref_t; - static constexpr bool kIsGroupMode = kIsGroupMode_; - static constexpr ck_tile::index_t kM0 = kM0_; - static constexpr ck_tile::index_t kN0 = kN0_; - static constexpr ck_tile::index_t kK0 = kK0_; - static constexpr ck_tile::index_t kN1 = kN1_; - static constexpr ck_tile::index_t kK1 = kK1_; - static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; - static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; - static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; - static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_; - using FmhaMask = ck_tile::remove_cvref_t; - static constexpr auto BiasEnum = BiasEnum_; - static constexpr bool kStoreLse = kStoreLse_; - static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; - static constexpr bool kPadS = kPadS_; - static constexpr bool kPadSK = kPadSK_; - static constexpr bool kPadD = kPadD_; - static constexpr bool kPadDv = kPadDv_; - static constexpr bool kIsPagedKV = kIsPagedKV_; -}; - -template -void fmha_fwd_decode_oneshot_(const ck_tile::stream_config&, fmha_fwd_decode_args); - -template -std::string fmha_fwd_decode_get_name_(); - -template -struct fmha_fwd_decode_combine_traits_ -{ - static constexpr ck_tile::index_t HDim = HDim_; - using DataType = ck_tile::remove_cvref_t; - static constexpr bool kIsGroupMode = kIsGroupMode_; - static constexpr ck_tile::index_t kN1 = kN1_; - static constexpr bool kStoreLse = kStoreLse_; - static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; - static constexpr bool kPadS = kPadS_; - static constexpr bool kPadDv = kPadDv_; -}; - -template -void fmha_fwd_decode_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_decode_args); - -template -std::string fmha_fwd_decode_combine_get_name_(); - // this is used to pattern-match internl kernel implementation, not to instantiate kernel template using bf16_t = bfloat16_t; using bf16_raw_t = typename bf16_t::raw_type; #else -#if 1 // ROCm 7.0 +#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 5 && HIP_VERSION_PATCH >= 50421) || \ + (HIP_VERSION_MAJOR >= 7) using bfloat16_t = __bf16; #else using bfloat16_t = ushort; diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 5fc35fc155..32fcd2ec36 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -18,7 +18,6 @@ #include "ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp" -#include "ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp" @@ -50,8 +49,8 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp deleted file mode 100644 index 7a744df825..0000000000 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp +++ /dev/null @@ -1,1334 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/common.hpp" -#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" -#include "ck_tile/ops/fmha/block/variants.hpp" - -#include -#include - -// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] -// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] -// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] -// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k]) -// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k] - -// can remove all bank conflicts, but drop the performance for some cases -// Probably it is limited by compiler optimization. -#define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0 -namespace ck_tile { - -template -struct FmhaFwdDecodeKernel -{ - using FmhaPipeline = ck_tile::remove_cvref_t; - using EpiloguePipeline = ck_tile::remove_cvref_t; - static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; - static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; - static_assert(kBlockPerCu > 0); - static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; - - using QDataType = ck_tile::remove_cvref_t; - using KDataType = ck_tile::remove_cvref_t; - using VDataType = ck_tile::remove_cvref_t; - using BiasDataType = ck_tile::remove_cvref_t; - using LSEDataType = ck_tile::remove_cvref_t; - using SaccDataType = ck_tile::remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - - using VLayout = ck_tile::remove_cvref_t; - - static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; - static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; - static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap; - static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; - static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; - static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; - static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; - static constexpr bool kMergeNumHeadGroupsSeqLenQ = - FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ; - using AttentionVariant = ck_tile::remove_cvref_t; - using FmhaMask = ck_tile::remove_cvref_t; - static constexpr bool kHasMask = FmhaMask::IsMasking; - - static_assert(!kMergeNumHeadGroupsSeqLenQ || - (kMergeNumHeadGroupsSeqLenQ && BiasEnum == BlockAttentionBiasEnum::NO_BIAS && - !kHasMask)); - - // clang-format off - template struct t2s; - template <> struct t2s { static constexpr const char * name = "fp32"; }; - template <> struct t2s { static constexpr const char * name = "fp16"; }; - template <> struct t2s { static constexpr const char * name = "bf16"; }; - template <> struct t2s { static constexpr const char * name = "fp8"; }; - template <> struct t2s { static constexpr const char * name = "bf8"; }; - // clang-format on - - __host__ static std::string GetName() - { - // sync with generate.py - // clang-format off - using bfs = typename FmhaPipeline::BlockFmhaShape; - using g0br = typename bfs::Gemm0BlockWarps; - using g1br = typename bfs::Gemm1BlockWarps; - using g0wt = typename bfs::Gemm0WarpTile; - using g1wt = typename bfs::Gemm1WarpTile; - #define _SS_ std::string - #define _TS_ std::to_string - auto pn = [&] () { - std::string n; - if (kPadSeqLenQ) n += "s"; - if (kPadSeqLenK) n += "sk"; - if (kPadHeadDimQ) n += "d"; - if (kPadHeadDimV) n += "dv"; - return n.empty() ? n : std::string("p") + n; }(); - return - _SS_("fmha_fwd_decode_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + - "_" + (kIsGroupMode ? "group" : "batch") + "_" - "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + - _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + - "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + - "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + - "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" + - "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + - (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + - "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + - (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + - (kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ); - #undef _SS_ - #undef _TS_ - // clang-format on - } - - template // to avoid duplicated base class prblem, introduce an template - // arg - struct EmptyKargs - { - }; - - // kargs use aggregate initializer, so no constructor will provided - // use inheritance to minimize karg size - // user need to use MakeKargs() function to create kargs. - struct CommonKargs - { - const void* q_ptr; - const void* k_ptr; - const void* v_ptr; - void* lse_acc_ptr; - void* o_acc_ptr; - - ck_tile::index_t batch; - - ck_tile::index_t seqlen_q; - ck_tile::index_t seqlen_k; - ck_tile::index_t hdim_q; - ck_tile::index_t hdim_v; - - ck_tile::index_t num_head_q; - // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k - // if this param is larger than 1, indicate MQA/GQA case - ck_tile::index_t nhead_ratio_qk; - ck_tile::index_t num_splits; - - float scale_s; - - ck_tile::index_t stride_q; - ck_tile::index_t stride_k; - ck_tile::index_t stride_v; - ck_tile::index_t stride_o_acc; - - ck_tile::index_t nhead_stride_q; - ck_tile::index_t nhead_stride_k; - ck_tile::index_t nhead_stride_v; - ck_tile::index_t nhead_stride_lse_acc; - ck_tile::index_t nhead_stride_o_acc; - - ck_tile::index_t split_stride_lse_acc; - ck_tile::index_t split_stride_o_acc; - }; - - struct LogitsSoftCapKargs - { - LogitsSoftCapKargs() = default; - - void init_logits_soft_cap(float logits_soft_cap_) - { - if(0 < logits_soft_cap_) - { - logits_soft_cap = logits_soft_cap_; - logits_soft_cap_rcp = 1.f / logits_soft_cap; - } - else - { - logits_soft_cap = 0.f; - logits_soft_cap_rcp = 0.f; - } - } - - float logits_soft_cap; - float logits_soft_cap_rcp; - }; - - struct CommonBiasKargs - { - const void* bias_ptr = nullptr; - ck_tile::index_t stride_bias = 0; - ck_tile::index_t nhead_stride_bias = 0; - }; - - struct BatchModeBiasKargs : CommonBiasKargs - { - ck_tile::index_t batch_stride_bias = 0; - }; - - struct AlibiKargs - { - // alibi is batch*nhead*1, no matter in batch/group mode, they are the same - const void* alibi_slope_ptr; - ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope - }; - - struct MaskKargs - { - // ck_tile::index_t window_size_left, window_size_right; - ck_tile::index_t window_size_left, window_size_right; - ck_tile::GenericAttentionMaskEnum mask_type; - }; - - struct Fp8StaticQuantKargs - { - float scale_p; - }; - - struct CommonPageBlockTableKargs - { - const int32_t* block_table_ptr; - ck_tile::index_t batch_stride_block_table; - ck_tile::index_t page_block_size; - }; - - struct GroupModePageBlockTableKargs : CommonPageBlockTableKargs - { - bool is_gappy = false; - }; - - struct CacheBatchIdxKargs - { - const int32_t* cache_batch_idx; - }; - - struct BatchModeKargs - : CommonKargs, - std::conditional_t>>, - std::conditional_t>, - std::conditional_t>, - std::conditional_t, - std::conditional_t> - { - const int32_t* seqlen_k_ptr; - - ck_tile::index_t batch_stride_q; - ck_tile::index_t batch_stride_k; // when using paged-kvcache, this will be stride/size for - // single kcache page-block - ck_tile::index_t batch_stride_v; // when using paged-kvcache, this will be stride/size for - // single vcache page-block - ck_tile::index_t batch_stride_lse_acc; - ck_tile::index_t batch_stride_o_acc; - }; - - struct GroupModeKargs - : CommonKargs, - std::conditional_t>>, - std::conditional_t>, - std::conditional_t>, - std::conditional_t>, - std::conditional_t> - { - const int32_t* seqstart_q_ptr; - const int32_t* seqstart_k_ptr; - const int32_t* seqlen_k_ptr; - - ck_tile::index_t batch_stride_k; // only used for paged-kvcache, this will be stride/size - // for single kcache page-block - ck_tile::index_t batch_stride_v; // only used for paged-kvcache, this will be stride/size - // for single vcache page-block - }; - - using Kargs = std::conditional_t; - - struct BlockIndices - { - ck_tile::index_t batch_idx; - ck_tile::index_t qo_head_idx; - ck_tile::index_t kv_head_idx; - }; - - template - __host__ static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - void* lse_acc_ptr, /* workspace for lse accumulation when num_splits > 1, otherwise - final lse */ - void* o_acc_ptr, /* workspace for o accumulation when num_splits > 1, otherwise final - o */ - ck_tile::index_t batch, - ck_tile::index_t seqlen_q, - ck_tile::index_t seqlen_k, // only used if 'seqlen_k_ptr' is not specified - const void* seqlen_k_ptr, // only used for (paged-) kvcache - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - ck_tile::index_t num_splits, - const void* block_table_ptr, - ck_tile::index_t batch_stride_block_table, - ck_tile::index_t page_block_size, - const void* cache_batch_idx, - float scale_s, - float scale_p, - float logits_soft_cap, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, - ck_tile::index_t stride_o_acc, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_lse_acc, - ck_tile::index_t nhead_stride_o_acc, - ck_tile::index_t batch_stride_q, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_bias, - ck_tile::index_t batch_stride_lse_acc, - ck_tile::index_t batch_stride_o_acc, - ck_tile::index_t split_stride_lse_acc, - ck_tile::index_t split_stride_o_acc, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) - { - Kargs kargs{{q_ptr, - k_ptr, - v_ptr, - lse_acc_ptr, - o_acc_ptr, - batch, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - num_splits, -#if CK_TILE_FMHA_FWD_FAST_EXP2 - static_cast(scale_s * ck_tile::log2e_v<>), -#else - scale_s, -#endif - stride_q, - stride_k, - stride_v, - stride_o_acc, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_lse_acc, - nhead_stride_o_acc, - split_stride_lse_acc, - split_stride_o_acc}, // args for common karg - {}, // placeholder for bias - {}, // placeholder for mask - {}, // placeholder for fp8_static_quant args - {}, // placeholder for paged-block table or cache_batch_idx - {}, // placeholder for logits_soft_cap - reinterpret_cast(seqlen_k_ptr), - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_lse_acc, - batch_stride_o_acc}; - - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - kargs.bias_ptr = bias_ptr; - kargs.stride_bias = stride_bias; - kargs.nhead_stride_bias = nhead_stride_bias; - kargs.batch_stride_bias = batch_stride_bias; - } - else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - kargs.alibi_slope_ptr = bias_ptr; - kargs.alibi_slope_stride = stride_bias; - } - if constexpr(kHasMask) - { - kargs.window_size_left = window_size_left; - kargs.window_size_right = window_size_right; - kargs.mask_type = static_cast(mask_type); - } - if constexpr(kDoFp8StaticQuant) - { - kargs.scale_p = scale_p; - } - if constexpr(kIsPagedKV) - { - kargs.block_table_ptr = reinterpret_cast(block_table_ptr); - kargs.batch_stride_block_table = batch_stride_block_table; - kargs.page_block_size = page_block_size; - } - else - { - kargs.cache_batch_idx = reinterpret_cast(cache_batch_idx); - } - if constexpr(kHasLogitsSoftCap) - { - kargs.init_logits_soft_cap(logits_soft_cap); - } - - return kargs; - } - - template - __host__ static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - void* lse_acc_ptr, /* workspace for lse accumulation when num_splits > 1, otherwise - final lse */ - void* o_acc_ptr, /* workspace for o accumulation when num_splits > 1, otherwise final - o */ - ck_tile::index_t batch, - const void* seqstart_q_ptr, - const void* seqstart_k_ptr, - const void* seqlen_k_ptr, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - ck_tile::index_t num_splits, - const void* block_table_ptr, - ck_tile::index_t batch_stride_block_table, - ck_tile::index_t page_block_size, - bool is_gappy, - float scale_s, - float scale_p, - float logits_soft_cap, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, - ck_tile::index_t stride_o_acc, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_lse_acc, - ck_tile::index_t nhead_stride_o_acc, - ck_tile::index_t batch_stride_k, // only used for paged-kvcache - ck_tile::index_t batch_stride_v, // only used for paged-kvcache - ck_tile::index_t split_stride_lse_acc, - ck_tile::index_t split_stride_o_acc, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) - { - Kargs kargs{{q_ptr, - k_ptr, - v_ptr, - lse_acc_ptr, - o_acc_ptr, - batch, - -1, // seqlen_q will be updated by another pointer - -1, // seqlen_k will be updated by another pointer - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - num_splits, -#if CK_TILE_FMHA_FWD_FAST_EXP2 - static_cast(scale_s * ck_tile::log2e_v<>), -#else - scale_s, -#endif - stride_q, - stride_k, - stride_v, - stride_o_acc, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_lse_acc, - nhead_stride_o_acc, - split_stride_lse_acc, - split_stride_o_acc}, // args for common karg - {}, // placeholder for bias - {}, // placeholder for mask - {}, // placeholder for fp8_static_quant args - {}, // placeholder for paged-block table - {}, // placeholder for logits_soft_cap - reinterpret_cast(seqstart_q_ptr), - reinterpret_cast(seqstart_k_ptr), - reinterpret_cast(seqlen_k_ptr), - batch_stride_k, - batch_stride_v}; - - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - kargs.bias_ptr = bias_ptr; - kargs.stride_bias = stride_bias; - kargs.nhead_stride_bias = nhead_stride_bias; - } - else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - kargs.alibi_slope_ptr = bias_ptr; - kargs.alibi_slope_stride = stride_bias; - } - if constexpr(kHasMask) - { - kargs.window_size_left = window_size_left; - kargs.window_size_right = window_size_right; - kargs.mask_type = static_cast(mask_type); - } - if constexpr(kDoFp8StaticQuant) - { - kargs.scale_p = scale_p; - } - if constexpr(kIsPagedKV) - { - kargs.block_table_ptr = reinterpret_cast(block_table_ptr); - kargs.batch_stride_block_table = batch_stride_block_table; - kargs.page_block_size = page_block_size; - kargs.is_gappy = is_gappy; - } - if constexpr(kHasLogitsSoftCap) - { - kargs.init_logits_soft_cap(logits_soft_cap); - } - - return kargs; - } - - CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size, - ck_tile::index_t nhead_q, - ck_tile::index_t nhead_kv, - ck_tile::index_t max_seqlen_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_splits) - { - ck_tile::index_t nhead_ = kMergeNumHeadGroupsSeqLenQ ? nhead_kv : nhead_q; - ck_tile::index_t max_seqlen_q_ = - max_seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? nhead_q / nhead_kv : 1); - - // TODO: this may need tuning - return dim3(ck_tile::integer_divide_ceil(max_seqlen_q_, FmhaPipeline::kM0) * - ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits, - nhead_, - batch_size); - } - - CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs) - { - const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); - - const auto f = [](index_t dividend, index_t divisor) { - index_t quotient = dividend / divisor; - index_t modulus = dividend - quotient * divisor; - return ck_tile::make_tuple(quotient, modulus); - }; - - const auto [mn, i_split] = f(blockIdx.x, kargs.num_splits); - const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1); - const index_t i_nhead = blockIdx.y; - const index_t i_batch = blockIdx.z; - - if constexpr(kHasMask) - { - // assume that num_tile_n1 is always 1 - return ck_tile::make_tuple( - (gridDim.x / kargs.num_splits) - 1 - i_tile_m, i_tile_n, i_split, i_nhead, i_batch); - } - else - { - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch); - } - } - - __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } - - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() - { - return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); - } - - CK_TILE_DEVICE void operator()(Kargs kargs) const - { - // TODO: Refine the logical here. - // In Decode case - // 1. we don't expect KV data reused by different ThreadGroups, bypass the cache - // 2. limit the LDS usage, as we want higher occupancy - // In Prefill case - // 1. we expect KV data reused by different ThreadGroups, use cache - // 2. use more LDS, as we want better memory latency hiding - // If SplitKV off, we don't expect Q data reused by different ThreadGroups, bypass the cache - constexpr bool PrefillCase = FmhaPipeline::kM0 >= 128; - // divide problem - const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] = GetTileIndex(kargs); - - const index_t i_m0 = i_tile_m * FmhaPipeline::kM0; - const index_t i_n1 = i_tile_n * FmhaPipeline::kN1; - - long_index_t batch_offset_q = 0; - long_index_t batch_offset_k = 0; // unused for paged-kvcache - long_index_t batch_offset_v = 0; // unused for paged-kvcache - long_index_t batch_offset_bias = 0; - long_index_t batch_offset_lse_acc = 0; - long_index_t batch_offset_o_acc = 0; - // index_t kv_l2p_offset = - // 0; // logical-to-physical offset of seqlen_k coordinate. only used for paged-kvcache - - if constexpr(kIsGroupMode) - { - // get starting offset for each batch - const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - - batch_offset_q = query_start * kargs.stride_q; - batch_offset_k = key_start * kargs.stride_k; - if constexpr(std::is_same_v) - { - batch_offset_v = key_start * kargs.stride_v; - } - else - { - batch_offset_v = key_start; - } - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - batch_offset_bias = query_start * kargs.stride_bias; - } - - batch_offset_lse_acc = query_start; - batch_offset_o_acc = query_start * kargs.stride_o_acc; - - // get real # queries & # keys under group mode - kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; - - // # of required blocks is different in each groups, terminate unnecessary blocks - // earlier - if(kargs.seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) <= i_m0) - { - return; - } - - if(kargs.seqlen_k_ptr != nullptr) - { - kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; - } - else - { - kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch]; - } - - if constexpr(kIsPagedKV) - { - if(kargs.is_gappy) - { - // seqstart_k_ptr has different meaning in this case - // kv_l2p_offset = kargs.seqstart_k_ptr[i_batch]; - } - } - } - else - { - const index_t i_cache_batch = [&, i_batch_ = i_batch] { - if constexpr(kIsPagedKV) - { - return i_batch_; - } - else - { - return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_] - : i_batch_); - } - }(); - - batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; - batch_offset_k = static_cast(i_cache_batch) * kargs.batch_stride_k; - batch_offset_v = static_cast(i_cache_batch) * kargs.batch_stride_v; - batch_offset_lse_acc = static_cast(i_batch) * kargs.batch_stride_lse_acc; - batch_offset_o_acc = static_cast(i_batch) * kargs.batch_stride_o_acc; - - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; - } - - if(kargs.seqlen_k_ptr != nullptr) - { - kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; - } - } - - // for simplicity, batch stride we just modify the pointer - const index_t i_nhead_k = - (kMergeNumHeadGroupsSeqLenQ ? i_nhead : i_nhead / kargs.nhead_ratio_qk); - - const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + - static_cast(i_nhead) * - (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) * - kargs.nhead_stride_q + - batch_offset_q; - const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr) + - static_cast(i_nhead_k) * kargs.nhead_stride_k + - batch_offset_k; - const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + - static_cast(i_nhead_k) * kargs.nhead_stride_v + - batch_offset_v; - - ODataType* o_acc_ptr = reinterpret_cast(kargs.o_acc_ptr) + - static_cast(i_nhead) * - (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) * - kargs.nhead_stride_o_acc + - batch_offset_o_acc + i_split * kargs.split_stride_o_acc; - - // Q/K/V DRAM and DRAM window - const auto q_dram = [&] { - const auto q_dram_naive = [&] { - if constexpr(kMergeNumHeadGroupsSeqLenQ) - { - // reshape: (nhead_ratio_qk, seqlen_q, hdim_q) -> (nhead_ratio_qk * seqlen_q, - // hdim_q) - // We expect Q data reuse among different KVSplited in decode case. - const auto view = make_naive_tensor_view( - q_ptr, - make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_q), - make_tuple(kargs.nhead_stride_q, kargs.stride_q, 1), - number{}, - number<1>{}); - - return transform_tensor_view( - view, - make_tuple( - make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)), - make_pass_through_transform(kargs.hdim_q)), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - else - { - return make_naive_tensor_view( - q_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_q), - make_tuple(kargs.stride_q, 1), - number{}, - number<1>{}); - } - }(); - - if constexpr(FmhaPipeline::kQLoadOnce) - { - const auto seqlen_q = kargs.seqlen_q; - const auto q_dram_pad = pad_tensor_view( - q_dram_naive, - make_tuple(number{}, number{}), - sequence{}); -#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - constexpr index_t LDSLayerSize = 256 / sizeof(QDataType); - constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim); - - if constexpr(XorLengthFold > 1) - { - const auto q_dram_unmerged = transform_tensor_view( - q_dram_pad, - make_tuple(make_unmerge_transform( - make_tuple(seqlen_q / XorLengthFold, XorLengthFold)), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - const auto q_dram_merged = transform_tensor_view( - q_dram_unmerged, - make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold), - make_merge_transform_v3_division_mod(make_tuple( - XorLengthFold, number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - const auto q_dram_unmerged_xor = transform_tensor_view( - q_dram_merged, - make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold), - make_unmerge_transform(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{})); - - const auto q_dram_permuted = transform_tensor_view( - q_dram_unmerged_xor, - make_tuple( - make_xor_transform( - make_tuple(seqlen_q / XorLengthFold, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - const auto q_dram_tmp = transform_tensor_view( - q_dram_permuted, - make_tuple( - make_pass_through_transform(seqlen_q / XorLengthFold), - make_unmerge_transform(make_tuple( - number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - - return transform_tensor_view( - q_dram_tmp, - make_tuple( - make_merge_transform_v3_division_mod( - make_tuple(seqlen_q / XorLengthFold, number{})), - make_merge_transform_v3_division_mod(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - else -#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - { - const auto q_dram_unmerged = transform_tensor_view( - q_dram_pad, - make_tuple( - make_pass_through_transform(seqlen_q), - make_unmerge_transform(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{})); - - const auto q_dram_permuted = transform_tensor_view( - q_dram_unmerged, - make_tuple( - make_xor_transform(make_tuple( - seqlen_q, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - return transform_tensor_view( - q_dram_permuted, - make_tuple( - make_pass_through_transform(seqlen_q), - make_merge_transform_v3_division_mod(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - } - else - { - return pad_tensor_view( - q_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - }(); - - const auto make_k_dram = [&](const KDataType* data, index_t height) { - const auto k_dram_naive = make_naive_tensor_view( - data, // will update this pointer if using paged-kvcache - make_tuple(height, kargs.hdim_q), - make_tuple(kargs.stride_k, 1), - number{}, - number<1>{}); - - const auto k_dram_pad = pad_tensor_view( - k_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - -#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - constexpr index_t LDSLayerSize = 256 / sizeof(KDataType); - constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim); - - if constexpr(XorLengthFold > 1) - { - const auto k_dram_unmerged = transform_tensor_view( - k_dram_pad, - make_tuple( - make_unmerge_transform(make_tuple(height / XorLengthFold, XorLengthFold)), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - const auto k_dram_merged = transform_tensor_view( - k_dram_unmerged, - make_tuple(make_pass_through_transform(height / XorLengthFold), - make_merge_transform_v3_division_mod( - make_tuple(XorLengthFold, number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - const auto k_dram_unmerged_xor = transform_tensor_view( - k_dram_merged, - make_tuple(make_pass_through_transform(height / XorLengthFold), - make_unmerge_transform( - make_tuple(number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{})); - - const auto k_dram_permuted = transform_tensor_view( - k_dram_unmerged_xor, - make_tuple(make_xor_transform( - make_tuple(height / XorLengthFold, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - const auto k_dram_tmp = transform_tensor_view( - k_dram_permuted, - make_tuple(make_pass_through_transform(height / XorLengthFold), - make_unmerge_transform(make_tuple( - number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - - return transform_tensor_view( - k_dram_tmp, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(height / XorLengthFold, number{})), - make_merge_transform_v3_division_mod(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - else -#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - { - const auto k_dram_unmerged = transform_tensor_view( - k_dram_pad, - make_tuple(make_pass_through_transform(height), - make_unmerge_transform(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{})); - - const auto k_dram_permuted = transform_tensor_view( - k_dram_unmerged, - make_tuple(make_xor_transform(make_tuple( - height, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - return transform_tensor_view( - k_dram_permuted, - make_tuple(make_pass_through_transform(height), - make_merge_transform_v3_division_mod(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - }; - const auto k_dram = [&]() { - if constexpr(kIsPagedKV) - { - return make_k_dram(nullptr, kargs.page_block_size); - } - else - { - return make_k_dram(k_ptr, kargs.seqlen_k); - } - }(); - - const auto make_v_dram = [&](const VDataType* data, index_t length) { - const auto v_dram_naive = make_naive_tensor_view( - data, // will update this pointer if using paged-kvcache - make_tuple(length, kargs.hdim_v), - make_tuple(kargs.hdim_v, 1), - number{}, - number<1>{}); - - // TODO: Add kVHeadDim - constexpr index_t XorGroupSize = - FmhaPipeline::Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}); - - const auto v_dram_pad = pad_tensor_view( - v_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - -#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - constexpr index_t LDSLayerSize = 256 / sizeof(VDataType); - constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim); - - if constexpr(XorLengthFold > 1) - { - const auto v_dram_unmerged = transform_tensor_view( - v_dram_pad, - make_tuple( - make_unmerge_transform(make_tuple(length / XorLengthFold, XorLengthFold)), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - const auto v_dram_merged = transform_tensor_view( - v_dram_unmerged, - make_tuple(make_pass_through_transform(length / XorLengthFold), - make_merge_transform_v3_division_mod( - make_tuple(XorLengthFold, number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - const auto v_dram_unmerged_xor = transform_tensor_view( - v_dram_merged, - make_tuple(make_pass_through_transform(length / XorLengthFold), - make_unmerge_transform(make_tuple( - number{}, number{}))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{})); - - const auto v_dram_permuted = transform_tensor_view( - v_dram_unmerged_xor, - make_tuple(make_xor_transform(make_tuple( - length / XorLengthFold, number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - const auto v_dram_tmp = transform_tensor_view( - v_dram_permuted, - make_tuple(make_pass_through_transform(length / XorLengthFold), - make_unmerge_transform( - make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - - return transform_tensor_view( - v_dram_tmp, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(length / XorLengthFold, number{})), - make_merge_transform_v3_division_mod( - make_tuple(number{}, - number{}))), - make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - else -#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - { - const auto v_dram_unmerged = transform_tensor_view( - v_dram_pad, - make_tuple(make_pass_through_transform(length), - make_unmerge_transform( - make_tuple(number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{})); - - const auto v_dram_permuted = transform_tensor_view( - v_dram_unmerged, - make_tuple(make_xor_transform(make_tuple( - length, number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - return transform_tensor_view( - v_dram_permuted, - make_tuple(make_pass_through_transform(length), - make_merge_transform_v3_division_mod( - make_tuple(number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - }; - - const auto v_dram = [&]() { - if constexpr(kIsPagedKV) - { - return make_v_dram(nullptr, kargs.page_block_size); - } - else - { - return make_v_dram(v_ptr, kargs.seqlen_k); - } - }(); - - auto q_dram_window = make_tile_window( - q_dram, - [&]() { - if constexpr(FmhaPipeline::kQLoadOnce) - return make_tuple(number{}, - number{}); - else - return make_tuple(number{}, number{}); - }(), - {i_m0, 0}); - - auto k_dram_window = make_tile_window( - k_dram, make_tuple(number{}, number{}), {0, 0}); - - auto v_dram_window = make_tile_window( - v_dram, make_tuple(number{}, number{}), {0, 0}); - - /// FIXME: Before C++20, capturing structured binding variables are not supported. - /// Remove following copy capture of the 'i_nhead' if in C++20 - const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto bias_dram_window_lengths = - make_tuple(number{}, number{}); - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - const BiasDataType* bias_ptr = - reinterpret_cast(kargs.bias_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_bias + - batch_offset_bias; - - const auto bias_dram = [&]() { - const auto bias_dram_naive = make_naive_tensor_view( - bias_ptr, - make_tuple(kargs.seqlen_q, kargs.seqlen_k), - make_tuple(kargs.stride_bias, 1), - number{}, - number<1>{}); - - return pad_tensor_view( - bias_dram_naive, bias_dram_window_lengths, sequence{}); - }(); - - return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); - } - else - { - return make_null_tile_window(bias_dram_window_lengths); - } - }(); - - // lse acc - auto lse_acc_dram_window = [&, i_nhead_ = i_nhead, i_split_ = i_split]() { - constexpr auto lse_acc_dram_window_lengths = make_tuple(number{}); - LSEDataType* lse_acc_ptr = reinterpret_cast(kargs.lse_acc_ptr) + - static_cast(i_nhead_) * - (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) * - kargs.nhead_stride_lse_acc + - batch_offset_lse_acc + i_split_ * kargs.split_stride_lse_acc; - - const auto lse_acc_dram = [&] { - const auto lse_acc_dram_naive = [&] { - if constexpr(kMergeNumHeadGroupsSeqLenQ) - { - // reshape: (nhead_ratio_qk, seqlen_q) -> (nhead_ratio_qk * seqlen_q) - const auto view = make_naive_tensor_view( - lse_acc_ptr, - make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q), - make_tuple(kargs.nhead_stride_lse_acc, 1), - number<1>{}, - number<1>{}); - - return transform_tensor_view(view, - make_tuple(make_merge_transform(make_tuple( - kargs.nhead_ratio_qk, kargs.seqlen_q))), - make_tuple(sequence<0, 1>{}), - make_tuple(sequence<0>{})); - } - else - { - return make_naive_tensor_view( - lse_acc_ptr, - make_tuple(kargs.seqlen_q), - make_tuple(1), - number<1>{}, - number<1>{}); - } - }(); - return pad_tensor_view( - lse_acc_dram_naive, lse_acc_dram_window_lengths, sequence{}); - }(); - - return make_tile_window(lse_acc_dram, lse_acc_dram_window_lengths, {i_m0}); - }(); - - FmhaMask mask = [&]() { - if constexpr(kHasMask) - return ck_tile::make_generic_attention_mask_from_lr_window( - kargs.window_size_left, - kargs.window_size_right, - kargs.seqlen_q, - kargs.seqlen_k, - kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); - else - return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; - }(); - - // WA i_batch capture structure binding before c++20 - auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - // data loading, shared by entire wg - // TODO: how to use s_read? - SaccDataType slope = - *(reinterpret_cast(kargs.alibi_slope_ptr) + - i_batch_ * kargs.alibi_slope_stride + i_nhead_); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - slope *= ck_tile::log2e_v<>; -#endif - if constexpr(kHasMask) - { - return make_alibi_from_lr_mask(slope, - kargs.window_size_left, - kargs.window_size_right, - kargs.seqlen_q, - kargs.seqlen_k, - kargs.mask_type); - } - else - { - return Alibi{ - slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; - } - } - else - { - return EmptyPositionEncoding{}; - } - }(); - - auto o_acc_tile = [&, i_split_ = i_split]() { - if constexpr(PrefillCase) - { - // allocate double lds - // add __restrict__ here to avoid aliasing - __shared__ char - smem_ptrk0[FmhaPipeline::Policy:: - template GetSmemSizeK()]; - __shared__ char - smem_ptrk1[FmhaPipeline::Policy:: - template GetSmemSizeK()]; - __shared__ char smem_ptrv0 - [FmhaPipeline::Policy::template GetSmemSizeV()]; - __shared__ char smem_ptrv1 - [FmhaPipeline::Policy::template GetSmemSizeV()]; - - return FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - bias_dram_window, - lse_acc_dram_window, - kargs.num_splits, // Remove it - i_split_, // Remove it - mask, - position_encoding, - kargs.scale_s, - smem_ptrk0, - smem_ptrk1, - smem_ptrv0, - smem_ptrv1); - } - else - { - __shared__ char smem_ptr[GetSmemSize()]; - return FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - bias_dram_window, - lse_acc_dram_window, - kargs.num_splits, // Remove it - i_split_, // Remove it - mask, - position_encoding, - kargs.scale_s, - smem_ptr); - } - }(); - - // Oacc DRAM and Oacc DRAM window - auto o_acc_dram = [&] { - const auto o_acc_dram_naive = [&] { - if constexpr(kMergeNumHeadGroupsSeqLenQ) - { - // reshape: (nhead_ratio_qk, seqlen_q, hdim_v) -> (nhead_ratio_qk * - // seqlen_q, hdim_v) - const auto view = make_naive_tensor_view( - o_acc_ptr, - make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_v), - make_tuple(kargs.nhead_stride_o_acc, kargs.stride_o_acc, 1), - number{}, - number<1>{}); - - return transform_tensor_view( - view, - make_tuple( - make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)), - make_pass_through_transform(kargs.hdim_v)), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - else - { - return make_naive_tensor_view( - o_acc_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_v), - make_tuple(kargs.stride_o_acc, 1), - number{}, - number<1>{}); - } - }(); - - return pad_tensor_view( - o_acc_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - }(); - - auto o_acc_dram_window = - make_tile_window(o_acc_dram, - make_tuple(number{}, number{}), - {i_m0, i_n1}); - - EpiloguePipeline{}(o_acc_dram_window, o_acc_tile); - } -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp index ece3306604..45a1c8f4b8 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp @@ -11,7 +11,7 @@ enum class BlockFmhaPipelineEnum QRKSVS = 0, QRKSVS_ASYNC, QSKSVS, - DECODE_QRKSVS, + QRKSVS_ASYNC_TRLOAD, }; template @@ -34,9 +34,9 @@ struct BlockFmhaPipelineEnumToStr }; template <> -struct BlockFmhaPipelineEnumToStr +struct BlockFmhaPipelineEnumToStr { - static constexpr const char* name = "decode_qr"; + static constexpr const char* name = "qr_async_trload"; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 07778355a5..86ac713b6f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -222,113 +222,6 @@ struct BlockFmhaSplitKVCombinePipelineProblem (kM0 * kMaxSplits) % get_warp_size() == 0); }; -template -struct BlockFmhaFwdDecodePipelineProblem -{ - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using BlockFmhaShape = remove_cvref_t; - using AttentionVariant = remove_cvref_t; - using FmhaMask = remove_cvref_t; - using Traits = remove_cvref_t; - - static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps; - static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps; - static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); - - static constexpr bool kIsGroupMode = kIsGroupMode_; - - // attributes from traits - static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; - static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap; - static constexpr auto BiasEnum = Traits::BiasEnum; - static constexpr bool kStoreLSE = Traits::kStoreLSE; - static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; - static constexpr bool kIsPagedKV = Traits::kIsPagedKV; - static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits; - static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ; - static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; -}; - -// extract tile size attributes to remove dependency on traits -template -struct BlockFmhaDecodeCombinePipelineTileSizes -{ - static constexpr index_t MaxVectorSize = 16 / sizeof(OaccDataType_); - - static constexpr index_t kN1 = kN1_; - static constexpr index_t NThreads = kN1 / MaxVectorSize; - static constexpr index_t kM0 = get_warp_size() / NThreads; // MThreadPerWarp -}; - -template -struct BlockFmhaDecodeCombinePipelineProblem - : BlockFmhaDecodeCombinePipelineTileSizes -{ - using BaseType = BlockFmhaDecodeCombinePipelineTileSizes; - - using LSEDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using Traits = remove_cvref_t; - - static_assert(std::is_same_v); - - static constexpr index_t kHeadDimV = HeadDimV_; - static constexpr bool kIsGroupMode = kIsGroupMode_; - - using BaseType::kM0; - using BaseType::kN1; - - static_assert(kN1 <= kHeadDimV && kHeadDimV % kN1 == 0); - - // attributes from traits - static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; - static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; - static constexpr bool kStoreLSE = Traits::kStoreLSE; - static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; - static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; - static constexpr index_t kMaxSplits = Traits::kMaxSplits; - static_assert(8 <= kMaxSplits); - - static constexpr index_t kNumWarps = 4; // always use 4 warps for each workgroup - static constexpr index_t kBlockSize = kNumWarps * get_warp_size(); - - static_assert(get_warp_size() <= (kM0 * kMaxSplits) && - (kM0 * kMaxSplits) % get_warp_size() == 0); -}; - template -struct BlockFmhaFwdDecodePipelineQRKSVS +template +struct BlockFmhaPipelineQRKSVSAsyncTrload { static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp similarity index 99% rename from include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp rename to include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp index 8839c419cd..172fa9116b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp @@ -18,7 +18,7 @@ #define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0 namespace ck_tile { // This pipeline is qkv all located in LDS -struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy +struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy : BlockFmhaPipelineQXKSVSCustomPolicy 1 or fwd training is running */ - bool kDoFp8StaticQuant_, - bool kIsPagedKV_, - bool kHasUnevenSplits_, - bool kMergeNumHeadGroupsSeqLenQ_ = false, - index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */> -struct TileFmhaFwdDecodeTraits -{ - static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; - static constexpr bool kPadSeqLenK = kPadSeqLenK_; - static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; - static constexpr bool kPadHeadDimV = kPadHeadDimV_; - static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_; - static constexpr auto BiasEnum = BiasEnum_; - static constexpr bool kHasBiasGrad = kHasBiasGrad_; - static constexpr bool kStoreLSE = kStoreLSE_; - static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; - static constexpr bool kIsPagedKV = kIsPagedKV_; - // determine if some split (length) is not divisible by tile size - static constexpr bool kHasUnevenSplits = kHasUnevenSplits_; - static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_; - static constexpr index_t kBlockPerCu = kBlockPerCu_; -}; - -template -struct TileFmhaFwdDecodeCombineTraits -{ - static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; - static constexpr bool kPadHeadDimV = kPadHeadDimV_; - static constexpr bool kStoreLSE = kStoreLSE_; - static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; - - static constexpr index_t kMaxSplits = (1 << kLogMaxSplits_); - static_assert(kMaxSplits <= get_warp_size() || kMaxSplits % get_warp_size() == 0); - static constexpr index_t kBlockPerCu = kBlockPerCu_; -}; - template